In [None]:
import os 
import numpy as np
import pandas as pd
from copy import deepcopy
import seaborn

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
%matplotlib qt

In [None]:
# Bandwidth setting for the basis functions. It is interpolated into the results
# directory / file names that are loaded and is shown as sigma in the figure title.
# The trailing note lists the candidate values (presumably those with saved results).
bw = 6.0 # 4.0, 6.0, 8.0, and 10.0

In [None]:
# Outlier values to animate over (one animation frame per entry).
#   * first piece : 10 evenly spaced values from 90 to 410 inclusive (step 320/9)
#   * second piece: a single value, 101.0 (linspace with num=1 returns only the start)
# NOTE(review): the third argument of np.linspace is a COUNT, not a step. If steps of
# 10 and 1 were intended, np.arange would be needed -- confirm against the result
# files that actually exist, since the values are interpolated into file names.
outlier_list = np.concatenate((
    np.linspace(90., 410., num=10),
    np.linspace(101., 110., num=1),
))

In [None]:
# ---- x-axis domain -------------------------------------------------------
xlimit = (21.0, 155.0)      # x-range applied to both panels
plot_pts_cnt = 4000         # number of points in the plotting grid
# evenly spaced x-values against which the loaded log-density values are drawn
newx = np.linspace(*xlimit, num=plot_pts_cnt)

# ---- names used in file paths and axis labels ----------------------------
basisfunction_name = 'Gaussian'   # part of the results directory name
var_name = 'waiting'              # x-axis label

# ---- font sizes ----------------------------------------------------------
fontsize_label = 20       # axis labels
fontsize_tick = 10        # tick labels
fontsize_info = 20        # "Add <outlier>" text box
fontsize_title = 12       # per-panel titles
fontsize_suptitle = 22    # figure-level title

# ---- line / tick formatting ----------------------------------------------
linewidth = 3.0           # width of the log-density curve
scilimits = (0, 3)        # passed to ticklabel_format for the y-axis

In [None]:
# Two side-by-side panels on one figure; both axes are redrawn on every frame.
# NOTE(review): figsize is (width, height) in inches, so this requests a figure that
# is 20 wide and 40 tall -- confirm that this portrait aspect is intended.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 40), constrained_layout=False)

# Cap the subplot area at 90% of the figure height, leaving the top strip free.
fig.subplots_adjust(top=0.9)

# func
def _draw_density_panel(ax, title, centers_tag, observations, outlier):
    """Draw one log-density panel for a single outlier value.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on; expected to have been cleared by the caller.
    title : str
        Panel title.
    centers_tag : str
        Trailing part of the result file name that identifies where the basis
        functions are centered ('40_100_21' or 'data' in this notebook).
    observations : 1-D array-like
        Values marked with a blue rug along the x-axis (includes the outlier).
    outlier : float
        Outlier value; marked with a red rug tick and a dashed vertical line,
        and interpolated into the name of the result file that is loaded.

    Notes
    -----
    Reads the notebook-level settings ``xlimit``, ``var_name``, ``newx``,
    ``bw``, ``basisfunction_name``, ``linewidth``, ``scilimits`` and the
    ``fontsize_*`` values.
    """
    # title, x-range and axis labels
    ax.set_title(title, fontsize = fontsize_title)
    ax.set_xlim(xlimit)
    ax.set_xlabel(var_name, fontsize = fontsize_label)
    ax.set_ylabel('log density', fontsize = fontsize_label)

    # tick marks and tick labels
    ax.tick_params(axis = 'both', labelsize = fontsize_tick)
    ax.ticklabel_format(axis = 'y', style = 'sci', scilimits = scilimits)

    # rug of all observations in blue, then the outlier on top in red
    seaborn.rugplot(observations, axis = 'x', ax = ax, color = 'tab:blue')
    seaborn.rugplot(np.array([outlier]), axis = 'x', ax = ax, color = 'red')

    # load the saved log-density values for this outlier and draw them on newx
    file_name = f'data/finexpfam_results/{basisfunction_name}_basis_function_bw={bw}/' \
                f'add{outlier}_logdenvals_scorematching_bw={bw}_{centers_tag}.npy'
    denvals = np.load(file_name)
    ax.plot(newx.flatten(), denvals.flatten(), color = 'tab:blue', linewidth = linewidth)

    # dashed vertical line at the outlier, spanning the full axes height
    ax.axvline(outlier, 0, 1, ls = '--', color = 'tab:purple', alpha = 0.5)

    # text box in the top-right corner stating which outlier this frame shows
    ax.text(0.988, 0.988,
            f'Add {outlier}',
            fontsize = fontsize_info,
            multialignment = 'left',
            horizontalalignment = 'right',
            verticalalignment = 'top',
            transform = ax.transAxes,
            bbox = {'facecolor': 'none',
                    'boxstyle': 'Round, pad=0.2'})


def update_plots(outlier):
    """Redraw both panels for one animation frame.

    Loads the first column of ``data/geyser.npy``, replaces every entry equal
    to 108.0 with ``outlier``, clears ``ax1`` and ``ax2`` and redraws them:
    ``ax1`` shows the estimate with basis functions centered at grid points,
    ``ax2`` the estimate with basis functions centered at data points.

    Parameters
    ----------
    outlier : float
        Outlier value for this frame (one element of ``outlier_list``).

    Returns
    -------
    tuple
        ``(ax1, ax2)``, the two redrawn axes.
    """
    # First column of the data set (labelled with var_name on the x-axis).
    # astype already returns a fresh array, so no extra copy is required.
    waiting = np.load('data/geyser.npy').astype(np.float64)[:, 0]
    # NOTE(review): 108.0 is presumably the placeholder outlier stored in the
    # data file -- confirm. Every entry equal to it is swapped for `outlier`.
    waiting = np.where(waiting == 108.0, outlier, waiting)

    ax1.clear()
    ax2.clear()

    _draw_density_panel(ax1,
                        'Basis Functions Centered at Evenly Spaced Grid Points',
                        '40_100_21',
                        waiting,
                        outlier)
    _draw_density_panel(ax2,
                        'Basis Functions Centered at Data Points',
                        'data',
                        waiting,
                        outlier)

    return ax1, ax2

# Build the animation: update_plots is called once per value in outlier_list,
# with a 500 ms delay between frames.
ani = FuncAnimation(
    fig, 
    update_plots, 
    frames = outlier_list, 
    interval = 500)

fig.suptitle(r'Logarithm of Score Matching Density Estimates with $\sigma$={bw}'.format(bw=bw), 
             fontsize = fontsize_suptitle, y = 0.98)

# Save the animation as a gif (comment out the two lines below to skip saving).
# The output directory is created first so saving does not fail on a fresh checkout.
# The 'imagemagick' writer needs ImageMagick to be installed on the machine.
os.makedirs('gif', exist_ok=True)
ani.save(f'gif/waiting_fin_Gaussian_bw={bw}.gif', writer='imagemagick')

plt.show()