In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

save = True  # when True, every figure produced below is also written to a PNG file

# Apply seaborn's default theme to all subsequent matplotlib figures.
sns.set()

In [None]:
# import os
# from collections import defaultdict
# 
## Parse the VCL experiments folder.
# def parse_raw_vcl(path):
#     results = defaultdict(dict)

#     for exp in os.listdir(path):
#         if not exp.startswith('nn_model'):
#             continue

#         kv = exp.split('_')
#         layers = '_'.join([v.strip() for v in kv[kv.index('hidden') + 2].lstrip('[').rstrip(']').split(',')])
#         coreset = kv[kv.index('coreset') + 2]
#         seed = kv[kv.index('seed') + 1]

#         results[f'vcl_{layers}_coreset_{coreset}'][seed] = np.load(f'{path}/{exp}/test_acc.npz')['acc']

#     for seed in range(5):
#         avg_acc = []
#         acc_mat = results['vcl_100_100_coreset_100'][f'{seed + 1}']
#         for i, row in enumerate(acc_mat):
#             avg_acc.append(np.mean(row[:i + 1]))

#         print('\n'.join([str(v) for v in avg_acc]))
#         # print()

In [None]:
# Maps raw result-column names (as stored in the *_results.csv files) to the
# human-readable approach labels used in plot legends.
cmap = dict(
    vcl_100='VCL, [100]',
    vcl_100_coreset_50='VCL + coreset (50), [100]',
    vcl_100_coreset_100='VCL + coreset (100), [100]',
    vcl_100_100='VCL, [100, 100]',
    vcl_100_100_coreset_50='VCL + coreset (50), [100, 100]',
    vcl_100_100_coreset_100='VCL + coreset (100), [100, 100]',
    var_gp='VAR-GP',
    var_gp_block_diag='VAR-GP (Block Diagonal)',
    var_gp_mle_hypers='VAR-GP (MLE Hyperparameters)',
)

## Split MNIST

In [None]:
# Per-task Split MNIST test accuracies, relabelled with legend-friendly names.
data = pd.read_csv('split_mnist_results.csv').rename(columns=cmap)

# Main comparison: VAR-GP vs. VCL + coreset (100) with one and two hidden layers.
compared = [cmap[name] for name in ('var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100')]
melt_data = data.melt(
    id_vars=['task'],
    value_vars=compared,
    var_name='Approach',
    value_name='vals',
)

plt.figure(figsize=(10, 10))
g = sns.lineplot(data=melt_data, x='task', y='vals', hue='Approach', marker='o', ci='sd')
g.set(xlabel='Task', ylabel='Test Accuracy', ylim=(0.6, 1.005))
g.legend(loc='lower left', fontsize=20)
g.set_yticks(np.arange(0.6, 1.0 + 1e-3, 0.05))
g.set_xticks(range(5))  # ticks at tasks 0-4
if save:
    g.figure.savefig('split_mnist.png', bbox_inches='tight', pad_inches=0.1)

In [None]:
# Split MNIST ablation: full VAR-GP vs. block-diagonal covariance vs. MLE hyperparameters.
# NOTE(review): relies on `data` still holding the Split MNIST frame from the
# previous cell; running the Permuted MNIST cells first would silently change it.
melt_data = data.melt(id_vars=['task'], value_vars=[cmap[k] for k in ['var_gp', 'var_gp_block_diag', 'var_gp_mle_hypers']], var_name='Approach', value_name='vals')

plt.figure(figsize=(10,10))
g = sns.lineplot(x='task', y='vals', hue='Approach', marker='o', data=melt_data, ci='sd')
g.set(xlabel='Task', ylabel='Test Accuracy', ylim=(0.0,1.005))
# 'bottom left' is not a valid Matplotlib legend location; 'lower left' is.
g.legend(loc='lower left', fontsize=20)
# Ticks span the full 0-1 y-range set above (previously they started at 0.6,
# leaving the lower part of the axis unlabeled); matches the Permuted MNIST ablation plot.
g.set_yticks(np.arange(0.0, 1.0 + 1e-3, 0.05))
g.set_xticks(range(5))
if save:
    g.figure.savefig('split_mnist_ep_mean_mle_hypers.png', bbox_inches='tight', pad_inches=0.1)

### Visualizing Inducing Points

In [None]:
def plot_inducing_pts(ckpt_dir, n_tasks=5, n_pts=4, classes_per_task=2, seed=None):
    """Plot a random sample of the learned inducing points after each task.

    For every task checkpoint ``{ckpt_dir}/ckpt{t}.pt`` (t = 0 .. n_tasks - 1)
    this loads the saved state dict and takes the inducing points ``z`` of the
    rows ``classes_per_task * t`` .. ``classes_per_task * (t + 1) - 1``. It
    samples ``n_pts`` point indices (shared across those rows) and shows each
    point as a 28x28 grayscale image in a ``classes_per_task x n_pts`` grid.
    When the module-level flag ``save`` is true, each figure is written to
    ``smnist_viz_{t + 1}.png``.

    Args:
        ckpt_dir: Directory containing ``ckpt0.pt`` ... ``ckpt{n_tasks - 1}.pt``.
        n_tasks: Number of task checkpoints to visualize.
        n_pts: Number of inducing points sampled per row; clamped to the number
            of available points.
        classes_per_task: Number of ``z`` rows (classes) belonging to each task.
        seed: Optional seed for the point sampling, for reproducible figures.
            ``None`` uses torch's global RNG (the original behavior).

    Raises:
        KeyError: If a checkpoint has no ``'z'`` entry.
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    for ckpt in range(n_tasks):
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(f'{ckpt_dir}/ckpt{ckpt}.pt', map_location=torch.device('cpu'))
        # Index directly so a missing key fails with a clear KeyError rather
        # than an AttributeError on None further down.
        z = state_dict['z'].detach()
        # Presumably z is (num_classes, M, 784): M flattened 28x28 inducing
        # points per class -- TODO confirm against the training code.
        n = min(n_pts, z.size(1))
        idx = torch.randperm(z.size(1), generator=generator)[:n]
        rows = slice(classes_per_task * ckpt, classes_per_task * (ckpt + 1))
        ind_pts_subset = z[rows][:, idx, :].reshape(-1, n, 28, 28)

        # squeeze=False keeps `axes` 2-D even for a single row or column.
        fig, axes = plt.subplots(classes_per_task, n, sharey=True, sharex=True,
                                 figsize=(5 * n, 5 * classes_per_task), squeeze=False)
        fig.subplots_adjust(hspace=0.05, wspace=0.001)
        for i in range(classes_per_task):
            for j in range(n):
                ax = axes[i, j]
                ax.imshow(ind_pts_subset[i, j].numpy(), interpolation='bilinear', cmap='gray')
                ax.set(aspect='equal')
                ax.grid(False)
                ax.axis('off')
                ax.margins(x=0.0, y=0.0)

        fig.suptitle(f'After Task {ckpt}', fontsize=50)

        if save:
            fig.savefig(f'smnist_viz_{ckpt + 1}.png', bbox_inches='tight', pad_inches=0)

# Checkpoint directory for the Split MNIST run. Set the SMNIST_CKPT_DIR
# environment variable to point elsewhere instead of editing this path.
ckpt_dir = os.environ.get('SMNIST_CKPT_DIR', '/Users/sanyam/Downloads/ckpt')
if os.path.isdir(ckpt_dir):
    plot_inducing_pts(ckpt_dir)
else:
    # Skip (loudly) so Restart & Run All still completes on other machines.
    print(f'Checkpoint directory {ckpt_dir!r} not found; skipping inducing-point plots.')

## Permuted MNIST

In [None]:
# Per-task Permuted MNIST test accuracies, relabelled with legend-friendly names.
data = pd.read_csv('permuted_mnist_results.csv').rename(columns=cmap)

# Main comparison: VAR-GP vs. VCL + coreset (100) with one and two hidden layers.
main_keys = ('var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100')
melt_data = data.melt(
    id_vars=['task'],
    value_vars=[cmap[key] for key in main_keys],
    var_name='Approach',
    value_name='vals',
)

plt.figure(figsize=(10, 10))
g = sns.lineplot(data=melt_data, x='task', y='vals', hue='Approach', marker='o', ci='sd')
g.set(xlabel='Task', ylabel='Test Accuracy', ylim=(0.6, 1.0))
g.legend(loc='lower left', fontsize=20)
g.set_yticks(np.arange(0.6, 1.0 + 1e-3, 0.05))
g.set_xticks(range(10))  # ticks at tasks 0-9
if save:
    g.figure.savefig('permuted_mnist.png', bbox_inches='tight', pad_inches=0.1)

In [None]:
# Permuted MNIST ablation: full VAR-GP vs. block-diagonal covariance vs. MLE hyperparameters.
ablation = [cmap['var_gp'], cmap['var_gp_block_diag'], cmap['var_gp_mle_hypers']]
melt_data = data.melt(id_vars=['task'], value_vars=ablation,
                      var_name='Approach', value_name='vals')

plt.figure(figsize=(10, 10))
# NOTE(review): 99% confidence band here, whereas every other plot shows 'sd'
# bands -- confirm this difference is intentional.
g = sns.lineplot(data=melt_data, x='task', y='vals', hue='Approach', marker='o', ci=99)
g.set(xlabel='Task', ylabel='Test Accuracy', ylim=(0.0, 1.0))
g.legend(loc='center right', fontsize=20)
g.set_yticks(np.arange(0.0, 1.0 + 1e-3, 0.05))
g.set_xticks(range(10))
if save:
    g.figure.savefig('permuted_mnist_ep_mean_mle_hypers.png', bbox_inches='tight', pad_inches=0.1)