In [None]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default theme globally so every figure below shares one style.
sns.set()

In [None]:
# Maps result-table column keys to the human-readable method labels used in
# plot legends. Despite the name, this is a column/label map, not a colormap.

# Variational Continual Learning baselines; the bracketed list is part of the label.
_vcl_labels = {
    'vcl_100': 'VCL, [100]',
    'vcl_100_coreset_50': 'VCL + coreset (50), [100]',
    'vcl_100_coreset_100': 'VCL + coreset (100), [100]',
    'vcl_100_100': 'VCL, [100, 100]',
    'vcl_100_100_coreset_50': 'VCL + coreset (50), [100, 100]',
    'vcl_100_100_coreset_100': 'VCL + coreset (100), [100, 100]',
}

# VAR-GP and its variants.
_var_gp_labels = {
    'var_gp': 'VAR-GP',
    'var_gp_block_diag': 'VAR-GP (Block Diagonal)',
    'var_gp_mle_hypers': 'VAR-GP (MLE Hyperparameters)',
    'var_gp_global': 'VAR-GP (Global)',
}

# Merge in the original order: VCL entries first, then VAR-GP entries.
cmap = {**_vcl_labels, **_var_gp_labels}

In [None]:
def plot_test_acc(data, clist, min_y=0.0, max_x=1, save=None, loc='best'):
    """Plot test accuracy against task index for a selection of methods.

    Args:
        data: DataFrame with a ``task`` column plus one accuracy column per
            method. Method columns may be named either by their ``cmap`` key
            or by the corresponding display label. The frame is left
            unmodified.
        clist: Iterable of ``cmap`` keys selecting which methods to draw.
        min_y: Lower bound of the y-axis and position of the first y tick.
        max_x: Number of tasks; x ticks are placed at ``0 .. max_x - 1``.
        save: Optional output path. When truthy, the figure is written there.
        loc: Legend location forwarded to ``Axes.legend``.

    Raises:
        KeyError: If an entry of ``clist`` is not a key of ``cmap``.
    """
    # Rename on a copy: mutating the caller's frame in place would make the
    # result of this function depend on how often the cell had been run.
    labelled = data.rename(columns=cmap)
    method_labels = [cmap[key] for key in clist]

    long_form = labelled.melt(
        id_vars=['task'],
        value_vars=method_labels,
        var_name='Method',
        value_name='vals',
    )

    fig, ax = plt.subplots(figsize=(10, 10))
    # NOTE(review): `ci=` is deprecated since seaborn 0.12 in favour of
    # errorbar=('ci', 99); switch once the minimum seaborn version allows it.
    sns.lineplot(x='task', y='vals', hue='Method', marker='o',
                 data=long_form, ci=99, ax=ax)

    # The lower limit follows `min_y` so the axis range and the tick range
    # agree instead of relying on set_yticks to stretch a hard-coded limit.
    ax.set(xlabel='Task', ylabel='Test Accuracy', ylim=(min_y, 1.005))
    ax.legend(loc=loc, fontsize=20)
    ax.set_yticks(np.arange(min_y, 1.0 + 1e-3, 0.05))
    ax.set_xticks(range(max_x))

    if save:
        fig.savefig(save, bbox_inches='tight', pad_inches=0.0)

## Split MNIST Test Accuracy Curves

In [None]:
# Load the Split MNIST results table (one column per method plus `task`).
data = pd.read_csv('split_mnist_results.csv')

# Figure 1: VAR-GP against the VCL baselines that use a coreset of 100.
plot_test_acc(
    data,
    clist=['var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100'],
    min_y=0.6,
    max_x=5,
    save='split_mnist.png',
)

# Figure 2: VAR-GP against its own variants.
plot_test_acc(
    data,
    clist=['var_gp', 'var_gp_block_diag', 'var_gp_global', 'var_gp_mle_hypers'],
    max_x=5,
    save='split_mnist_ep_mean_mle_hypers.png',
    loc='lower left',
)

### Visualizing Inducing Points

In [None]:
def plot_inducing_pts(ckpt_dir, save=False, n_tasks=5, n_per_class=4):
    """Draw a random sample of inducing points for each task's two classes.

    For every task index ``t`` in ``range(n_tasks)`` this loads
    ``{ckpt_dir}/ckpt{t}.pt``, takes rows ``2*t`` and ``2*t + 1`` of the
    stored ``z`` tensor, picks ``n_per_class`` random entries along its second
    dimension, and shows them as 28x28 grayscale images in a two-row grid
    (one row per selected ``z`` row).

    Args:
        ckpt_dir: Directory containing ``ckpt0.pt``, ``ckpt1.pt``, ...
        save: When truthy, each figure is written to
            ``smnist_viz_{t + 1}.png`` in the working directory.
        n_tasks: Number of checkpoints to visualise (default 5).
        n_per_class: Number of images per row (default 4); capped at the
            number of entries available along ``z``'s second dimension.

    Raises:
        KeyError: If a checkpoint does not contain a ``'z'`` entry.
    """
    for ckpt in range(n_tasks):
        ckpt_path = f'{ckpt_dir}/ckpt{ckpt}.pt'
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))

        z = state_dict.get('z')
        if z is None:
            # Fail with a message that names the file instead of an
            # AttributeError on None further down.
            raise KeyError(f"checkpoint '{ckpt_path}' has no 'z' entry")
        # Detach so imshow can convert the data even if it tracks gradients.
        z = z.detach()

        # assumes z is (n_classes, n_inducing, 784) -- TODO confirm against the
        # model that wrote the checkpoint.
        n_cols = min(n_per_class, z.size(1))
        # The selection is random and unseeded, so figures differ between runs.
        picked = torch.randperm(z.size(1))[:n_cols]
        ind_pts_subset = z[2 * ckpt:2 * ckpt + 2][:, picked, :].reshape(-1, n_cols, 28, 28)

        # squeeze=False keeps `axes` two-dimensional even when n_cols == 1.
        fig, axes = plt.subplots(2, n_cols, sharey=True, sharex=True,
                                 figsize=(20, 10), squeeze=False)
        fig.subplots_adjust(hspace=0.05, wspace=0.001)
        for i in range(2):
            for j in range(n_cols):
                ax = axes[i, j]
                ax.imshow(ind_pts_subset[i, j], interpolation='bilinear', cmap='gray')
                ax.set(aspect='equal')
                ax.grid(False)
                ax.axis('off')
                ax.margins(x=0.0, y=0.0)

        # NOTE(review): the title uses the zero-based task index while the
        # saved file name below is one-based; confirm which is intended.
        fig.suptitle(f'After Task {ckpt}', fontsize=50)

        if save:
            fig.savefig(f'smnist_viz_{ckpt + 1}.png', bbox_inches='tight', pad_inches=0)

# Machine-specific location of the ckpt<N>.pt files; change this to wherever
# the checkpoints live on your system before running the cell.
ckpt_dir = '/Users/sanyam/Downloads/ckpt'
plot_inducing_pts(ckpt_dir, save=False)

In [None]:
from torch.utils.data import DataLoader
from torch.distributions import Categorical
from continual_gp.train_utils import create_class_vargp
from continual_gp.datasets import SplitMNIST


# Builds `ent_mat`, where ent_mat[t][k] is the mean predictive entropy on the
# test data of task k (classes 2k and 2k+1) for the model loaded from
# checkpoint t.

# NOTE(review): machine-specific path; point this at the run directory that
# contains ckpt0.pt ... ckpt4.pt.
run_dir = '/Users/sanyam/Downloads/s-mnist'
ds = SplitMNIST('/tmp', train=False)

N_TASKS = 5          # number of checkpoints and of two-class tasks evaluated
EVAL_BATCH_SIZE = 1024

ent_mat = []

# Parameters of all earlier checkpoints, handed to the model constructor.
prev_params = []
for t in range(N_TASKS):
    mean_ent_list = []
    with torch.no_grad():
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=torch.device('cpu'))
        gp = create_class_vargp(ds, M=60, n_f=100, n_var_samples=20, prev_params=prev_params)
        gp.load_state_dict(cur_params)

        for task in range(N_TASKS):
            # Restrict the dataset to the two classes belonging to this task.
            ds.filter_by_class([2 * task, 2 * task + 1])

            # Accumulate as a Python float so ent_mat ends up as a plain
            # numeric matrix rather than nested lists of 0-d tensors.
            all_entropy = 0.0

            for x, _ in DataLoader(ds, batch_size=EVAL_BATCH_SIZE):
                dist = Categorical(probs=gp.predict(x))
                all_entropy += dist.entropy().sum().item()

            mean_ent_list.append(all_entropy / len(ds))

    # Remove the class filter before the dataset is reused for the next checkpoint.
    ds.filter_by_class()
    prev_params.append(cur_params)
    ent_mat.append(mean_ent_list)

In [None]:
# log(10) is the largest entropy a 10-way categorical can have, so the ratio
# below lies in [0, 1] if predictions are over 10 classes (assumed -- confirm).
max_entropy = torch.tensor(10.0).log()
normalised_entropy = torch.Tensor(ent_mat) / max_entropy

# Show 1 - normalised entropy as a grayscale matrix: brighter cells mean lower
# entropy. Rows follow ent_mat's outer list, columns its inner lists.
im = plt.imshow(1.0 - normalised_entropy, cmap='gray')
plt.grid(False)
plt.colorbar(im)

## Permuted MNIST

In [None]:
# Load the Permuted MNIST results table (one column per method plus `task`).
data = pd.read_csv('permuted_mnist_results.csv')

# Figure 1: VAR-GP against the VCL baselines that use a coreset of 100.
plot_test_acc(
    data,
    clist=['var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100'],
    min_y=0.6,
    max_x=10,
    save='permuted_mnist.png',
)

# Figure 2: VAR-GP against its own variants.
plot_test_acc(
    data,
    clist=['var_gp', 'var_gp_block_diag', 'var_gp_global', 'var_gp_mle_hypers'],
    max_x=10,
    save='permuted_mnist_ep_mean_mle_hypers.png',
)