In [None]:
# setup: load the Walmart data and add calendar columns
import pandas as pd

# Silence pandas' SettingWithCopyWarning for the whole session.
pd.options.mode.chained_assignment = None

df = pd.read_csv('Walmart.csv')

# Dates are stored day-month-year in the CSV.
df['Date'] = pd.to_datetime(df['Date'], format='%d-%m-%Y')

# Derive calendar parts and make the holiday flag a real boolean (used as a mask later).
calendar = df['Date'].dt
df = df.assign(
    Day=calendar.day,
    Month=calendar.month,
    Year=calendar.year,
    Holiday_Flag=df['Holiday_Flag'].astype('bool'),
)

# Lower-case every column name (e.g. 'Holiday_Flag' -> 'holiday_flag').
df.columns = df.columns.map(str.lower)

In [None]:
# install packages
%pip install matplotlib --user
%pip install seaborn --user
%pip install scipy --user
%pip install statsmodels --user
%pip install celluloid --user
%pip install ipywidgets --user
%pip install ffmpeg
%pip install scikit-learn

%matplotlib inline
%config InlineBackend.figure_format ='retina' # improves resolution
%matplotlib notebook

In [None]:
# imports
import os  # for saving files

import matplotlib
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import seaborn as sns
import statsmodels.formula.api as smf
from scipy.signal import find_peaks
# NOTE(review): not used in the cells shown here; the data is reduced for the animation
# by resampling, not by a model. Kept in case other cells rely on it.
from sklearn.linear_model import LogisticRegression

# for animation
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation, PillowWriter  # PillowWriter: saving as GIF
from celluloid import Camera
from IPython.display import HTML, display, clear_output
from ipywidgets import widgets

# NOTE(review): `downloads_path` is not referenced by the save cell shown here, which
# writes to the current working directory.
home_directory = os.path.expanduser('~')
downloads_path = os.path.join(home_directory, 'Downloads')

# formatting seaborn plots
sns.set(font_scale=1.2, style="white")

In [None]:
def smoothGaussian(listin, degree):
    """Smooth a 1-D sequence with a truncated Gaussian-weighted moving average.

    Output point ``i`` is the weighted mean of the inputs within ``degree - 1`` positions
    of ``i`` (a window of ``2 * degree - 1`` points). The weight is
    ``exp(-(4 * offset / window) ** 2)`` with ``offset = k - i``. Near either end the window
    is truncated and the overlapping weights are renormalised. The kernel therefore stays
    centred on ``i`` and keeps the same width instead of drifting toward the interior.

    Parameters
    ----------
    listin : array-like of numbers
        Values to smooth (list, numpy array or pandas Series; read positionally).
    degree : int
        Half-width of the window; must be >= 1 (``degree == 1`` returns the input as floats).

    Returns
    -------
    numpy.ndarray
        Float array with the same length as ``listin``.

    Raises
    ------
    ValueError
        If ``degree < 1`` (the window would be empty).
    """
    if degree < 1:
        raise ValueError(f"degree must be >= 1, got {degree!r}")

    values = np.asarray(listin, dtype=float)
    if values.size == 0:
        return np.zeros(0)

    window = 2 * degree - 1
    offsets = np.arange(window) - (degree - 1)  # -(degree-1) .. +(degree-1)
    weights = np.exp(-(4 * offsets / window) ** 2)

    # A 'full' convolution works for any input length (unlike 'same', whose output takes the
    # longer operand's length). Output point i sits at index i + degree - 1.
    # Convolving a vector of ones gives the sum of the weights that overlap real data,
    # which renormalises the truncated windows at both ends.
    centre = slice(degree - 1, degree - 1 + values.size)
    numerator = np.convolve(values, weights, mode='full')[centre]
    denominator = np.convolve(np.ones_like(values), weights, mode='full')[centre]
    return numerator / denominator

# Distinct dates whose rows are flagged as holidays, kept as a module-level array.
holiday_dates = df.loc[df['holiday_flag'], 'date'].unique()

# Example: complete_animation(df, 'Weekly_Sales_Over_Time', 'date', 'weekly_sales', degree=4)
def complete_animation(df, plotTitle, xCol, yCol, degree, holidayCol='holiday_flag'):
    """Animate weekly summary statistics of ``yCol`` and display them as JS/HTML.

    Each frame reveals one more resampled week and draws:
    - the Gaussian-smoothed max / Q3 / median / mean / Q1 / min lines;
    - shaded bands between adjacent quantile lines;
    - the leading rows of the raw data;
    - dashed vertical markers for holiday dates reached so far.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; ``xCol`` must be datetime-like (it is the resample key).
    plotTitle : str
        Figure title.
    xCol, yCol : str
        Names of the date column and the value column.
    degree : int
        Half-width passed to ``smoothGaussian``.
    holidayCol : str, optional
        Column whose truthy rows mark holidays; if absent, no holiday markers are drawn.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation. Keep a reference to it, e.g. to ``.save()`` it later.
    """
    stat_names = ['max', 'min', 'mean', 'q1', 'median', 'q3']

    # Resample to weekly extrema/quartiles/mean because the original df is too large.
    df_resampled = df.resample('W', on=xCol).agg({yCol: ['max', 'min', 'mean',
                                                         lambda x: np.percentile(x, 25),
                                                         lambda x: np.percentile(x, 50),
                                                         lambda x: np.percentile(x, 75)]}).reset_index()
    df_resampled.columns = [xCol] + stat_names

    # Gaussian-smoothed copy of every statistic: smoothed_max, smoothed_min, ...
    for stat in stat_names:
        df_resampled['smoothed_' + stat] = smoothGaussian(df_resampled[stat], degree)
    smoothed_cols = ['smoothed_' + stat for stat in stat_names]

    # Throwaway figure used only to let matplotlib pick date ticks/labels for the full
    # x-range. They are pinned on every frame so the axis does not jump while the data grows.
    tick_fig, tick_ax = plt.subplots()
    tick_ax.plot(df[xCol], df[yCol])
    tick_fig.canvas.draw()  # synchronous render so the tick label text is populated
    xticks = tick_ax.get_xticks()
    xticklabels = [label.get_text() for label in tick_ax.get_xticklabels()]
    plt.close(tick_fig)  # release it instead of leaving an empty figure open

    fig, ax = plt.subplots(figsize=(10, 6))
    fig.subplots_adjust(bottom=0.2)  # room for the rotated date labels

    # Fixed bounds so every frame uses the same axes.
    x_min, x_max = df_resampled[xCol].min(), df_resampled[xCol].max()
    y_min = df_resampled[smoothed_cols].min().min()
    y_max = df_resampled[smoothed_cols].max().max()

    # colors
    color_palette = sns.color_palette("mako", n_colors=7)
    colors = {
        'max': color_palette[0],
        'q3': color_palette[3],
        'mean': 'black',
        'median': color_palette[4],
        'q1': color_palette[5],
        'min': color_palette[6],
        'time_series': color_palette[2]
    }

    # (stat, legend prefix, line width), in drawing and legend order.
    line_specs = [('max', 'Max ', 1), ('q3', 'Q3 ', 1.5), ('median', 'Median ', 2.5),
                  ('mean', 'Mean ', 3.5), ('q1', 'Q1 ', 1.5), ('min', 'Min ', 1)]
    # (lower stat, upper stat, colour key, alpha) for the shaded bands.
    band_specs = [('min', 'q1', 'min', 0.1), ('q1', 'median', 'q1', 0.2),
                  ('median', 'q3', 'median', 0.3), ('q3', 'max', 'q3', 0.2)]

    # Holidays come from the df being animated, not from a global.
    if holidayCol in df.columns:
        holidays = pd.DatetimeIndex(df.loc[df[holidayCol].astype(bool), xCol].unique())
    else:
        holidays = pd.DatetimeIndex([])

    # Legend proxies: built once, re-attached each frame because ax.clear() drops the legend.
    legend_handles = [
        mlines.Line2D([], [], color='grey', linestyle='--', linewidth=1, label='Holiday'),
        mlines.Line2D([], [], color=colors['time_series'], label='Time Series ', linewidth=5),
    ]
    legend_handles += [mlines.Line2D([], [], color=colors[stat], label=prefix + yCol, linewidth=width)
                       for stat, prefix, width in line_specs]

    def animate(i):
        """Redraw frame ``i``: everything up to and including resampled week ``i``."""
        ax.clear()
        df_index = min(i, len(df) - 1)  # for the original df
        week_index = min(i, len(df_resampled) - 1)  # for the resampled df
        weeks = df_resampled.iloc[:week_index + 1]

        # Holiday markers for every holiday on or before the current week, compared by date.
        current_week = weeks[xCol].iloc[-1]
        for holiday_date in holidays[holidays <= current_week]:
            ax.axvline(x=holiday_date, color='grey', linestyle='--', linewidth=1)

        # NOTE(review): this plots the first i+1 rows of `df` in row order, not a per-date
        # aggregate. It only forms a single time series if those rows are consecutive dates
        # (e.g. one group's rows come first in the file) — confirm this is intended.
        ax.plot(df[xCol].iloc[:df_index + 1], df[yCol].iloc[:df_index + 1],
                color=colors['time_series'], label='Time Series ', linewidth=5)

        for stat, prefix, width in line_specs:
            ax.plot(weeks[xCol], weeks['smoothed_' + stat],
                    color=colors[stat], label=prefix + yCol, linewidth=width)

        # Shade with the same slice as the lines so bands and lines stay in step.
        for lower, upper, color_key, alpha in band_specs:
            ax.fill_between(weeks[xCol], weeks['smoothed_' + lower], weeks['smoothed_' + upper],
                            color=colors[color_key], alpha=alpha)

        # formatting ticks, labels, and bounds (on this axes, not pyplot's "current" one)
        ax.set_xlabel(xCol, fontsize=20)
        ax.set_ylabel(yCol, fontsize=20)
        ax.set_title(plotTitle, fontsize=20)
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, rotation=45)
        ax.legend(handles=legend_handles, loc='upper right')

    # call animator and display
    ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(df_resampled), interval=100, repeat=True)

    display(HTML(ani.to_jshtml(default_mode='loop')))
    return ani

# Output widget so text/HTML produced inside the click callback is rendered under the button.
animation_output = widgets.Output()

def on_button_clicked(b):
    """Generate the animation and publish it as the global ``ani`` for the save cell."""
    global ani  # the "save animation" cell reads this name
    with animation_output:
        print('Generating animation...')
        ani = complete_animation(df, 'Weekly_Sales_Over_Time', 'date', 'weekly_sales', degree=4)

button = widgets.Button(description='Generate Animation')
button.on_click(on_button_clicked)
display(button, animation_output)

In [None]:
#  CELL TO SAVE ANIMATION
# Output widget so messages from the click callback are rendered under the button.
save_output = widgets.Output()

def on_second_button_clicked(b):
    """Save the most recently generated animation (global ``ani``) as a GIF.

    Prints a hint instead of raising NameError when no animation has been generated yet.
    """
    with save_output:
        current_ani = globals().get('ani')
        if current_ani is None:
            print('No animation to save yet - click "Generate Animation" first.')
            return
        print('Saving animation...')
        # 10 fps matches the interval=100 ms the animation is displayed with.
        current_ani.save('saved_animation.gif', writer=PillowWriter(fps=10))
        print('Saved saved_animation.gif')

button2 = widgets.Button(description='Save Animation')
button2.on_click(on_second_button_clicked)
display(button2, save_output)