In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Apply seaborn's default theme to every figure in this notebook.
sns.set()

# Evaluate on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Put the parent of the current working directory on the import path (once), so the
# `continual_gp` imports in later cells resolve against the repository checkout.
if not any(entry == os.path.abspath('..') for entry in sys.path):
    sys.path.append(os.path.abspath('..'))

In [None]:
cmap = {
    'vcl_100': 'VCL, [100]',
    'vcl_100_coreset_50': 'VCL + coreset (50), [100]',
    'vcl_100_coreset_100': 'VCL + coreset (100), [100]',
    'vcl_100_100': 'VCL, [100, 100]',
    'vcl_100_100_coreset_50': 'VCL + coreset (50), [100, 100]',
    'vcl_100_100_coreset_100': 'VCL + coreset (100), [100, 100]',
    'var_gp': 'VAR-GP',
    'var_gp_block_diag': 'VAR-GP (Block Diagonal)',
    'var_gp_mle_hypers': 'VAR-GP (MLE Hyperparameters)',
    'var_gp_global': 'VAR-GP (Global)',
}

def plot_test_acc(data, clist, min_y=0.0, max_x=1, save=False, name=None, loc='best'):
    """Plot per-task test-accuracy curves for a subset of methods.

    Args:
        data: Results table with a ``task`` column and one accuracy column per
            method, named by the raw method ids that key ``cmap``. The frame is
            NOT modified; columns already renamed to display labels also work.
        clist: Raw method ids (keys of ``cmap``) to plot, in legend order.
        min_y: Lower y-axis limit; y-ticks run from ``min_y`` to 1.0 in 0.05 steps.
        max_x: Number of tasks; x-ticks are placed at ``0 .. max_x - 1``.
        save: When True and ``name`` is given, write the figure to ``name``.
        name: Output path for the saved figure.
        loc: Matplotlib legend location.

    Raises:
        KeyError: If an entry of ``clist`` is not a key of ``cmap``.
    """
    # Rename on a copy. Renaming in place would silently rewrite the caller's
    # DataFrame, so a later cell would see display labels instead of raw ids.
    renamed = data.rename(columns=cmap)
    labels = [cmap[k] for k in clist]

    melt_data = renamed.melt(id_vars=['task'], value_vars=labels,
                             var_name='Method', value_name='vals')

    plt.figure(figsize=(10, 10))
    # seaborn aggregates rows sharing a ``task`` value and shades a 99% CI around
    # the mean (presumably across seeds). seaborn >= 0.12 deprecates ``ci`` in favour
    # of ``errorbar=('ci', 99)``; kept as-is for compatibility with older versions.
    g = sns.lineplot(x='task', y='vals', hue='Method', marker='o', data=melt_data, ci=99)
    g.set(ylim=(min_y, 1.005))
    g.legend(loc=loc, fontsize=20)
    g.set_xlabel(xlabel='Task', fontsize=20)
    g.set_ylabel(ylabel='Test Accuracy', fontsize=20)
    # The + 1e-3 makes the half-open arange include the 1.0 tick.
    g.set_yticks(np.arange(min_y, 1.0 + 1e-3, 0.05))
    g.set_xticks(range(max_x))

    if save and name is not None:
        g.figure.savefig(name, bbox_inches='tight', pad_inches=0.1)

## Split MNIST Test Accuracy Curves

In [None]:
data = pd.read_csv('results/smnist.csv')

# Split MNIST: VAR-GP against the VCL + coreset (100) baselines.
plot_test_acc(
    data,
    ['var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100'],
    max_x=5,
    min_y=0.7,
    save=False,
    name='split_mnist.png',
)

# Split MNIST: VAR-GP variants, y-axis from the default 0.0.
plot_test_acc(
    data,
    ['var_gp', 'var_gp_block_diag', 'var_gp_global', 'var_gp_mle_hypers'],
    max_x=5,
    loc='lower left',
    save=False,
    name='split_mnist_ep_mean_mle_hypers.png',
)

### Visualizing Inducing Points

In [None]:
def plot_inducing_pts(ckpt_dir, save=False, n_tasks=5, n_samples=4, classes_per_task=2,
                      generator=None):
    """Show a random sample of the learned inducing points after each task.

    For checkpoint ``t`` (``{ckpt_dir}/ckpt{t}.pt``) the figure has one row per
    class owned by task ``t`` (rows ``classes_per_task * t`` onward of ``z``) and
    one column per randomly chosen inducing point (the same draw for every row).

    Args:
        ckpt_dir: Directory holding ``ckpt0.pt`` ... ``ckpt{n_tasks - 1}.pt``.
        save: When True, write each figure to ``smnist_viz_{t + 1}.png``.
        n_tasks: Number of checkpoints to visualize.
        n_samples: Inducing points shown per class; clamped to the number available.
        classes_per_task: Leading rows of ``z`` that belong to each task.
        generator: Optional ``torch.Generator`` to make the random subset reproducible.

    Raises:
        KeyError: If a checkpoint has no ``'z'`` entry.
    """
    for ckpt in range(n_tasks):
        path = f'{ckpt_dir}/ckpt{ckpt}.pt'
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        # Fail with a clear message instead of an AttributeError on None further down.
        if 'z' not in state_dict:
            raise KeyError(f"checkpoint {path} has no 'z' entry")
        z = state_dict['z']

        n_show = min(n_samples, z.size(1))
        first = classes_per_task * ckpt
        cols = torch.randperm(z.size(1), generator=generator)[:n_show]
        # NOTE(review): the 28x28 reshape assumes each inducing point is a flattened MNIST image.
        ind_pts_subset = z[first:first + classes_per_task][:, cols, :].view(-1, n_show, 28, 28)

        # squeeze=False keeps `axes` 2-D even when there is a single row or column.
        fig, axes = plt.subplots(classes_per_task, n_show, sharey=True, sharex=True,
                                 figsize=(20, 10), squeeze=False)
        fig.subplots_adjust(hspace=0.05, wspace=0.001)
        for i in range(classes_per_task):
            for j in range(n_show):
                ax = axes[i, j]
                ax.imshow(ind_pts_subset[i, j], interpolation='bilinear', cmap='gray')
                ax.set(aspect='equal')
                ax.grid(False)
                ax.axis('off')
                ax.margins(x=0.0, y=0.0)

        fig.suptitle(f'After Task {ckpt}', fontsize=50)

        if save:
            fig.savefig(f'smnist_viz_{ckpt + 1}.png', bbox_inches='tight', pad_inches=0)

# Directory with the Split MNIST VAR-GP checkpoints (ckpt0.pt, ckpt1.pt, ...).
ckpt_dir = 'results/vargp-smnist'
plot_inducing_pts(ckpt_dir=ckpt_dir, save=False)

### Normalized Predictive Entropies

In [None]:
from torch.utils.data import DataLoader
from torch.distributions import Categorical
from continual_gp.train_utils import create_class_vargp
from continual_gp.datasets import SplitMNIST

# Evaluate each Split MNIST VAR-GP checkpoint (one per training task) on the test split
# of every task. Result: ent_mat[t][k] is the mean predictive entropy on test task k
# of the model checkpointed after training task t.
run_dir = 'results/vargp-smnist'
ds = SplitMNIST('/tmp', train=False)

ent_mat = []

# Checkpoints of earlier tasks, passed to create_class_vargp when building task t's model.
prev_params = []
for t in tqdm(range(5)):
    mean_ent_list = []
    with torch.no_grad():
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=device)
        # NOTE(review): M / n_f / n_var_samples presumably must match the training run,
        # otherwise load_state_dict below would hit shape mismatches — confirm.
        gp = create_class_vargp(ds, M=60, n_f=50, n_var_samples=20, prev_params=prev_params).to(device)
        gp.load_state_dict(cur_params)

        for task in tqdm(range(5), leave=False):
            # Test task k is restricted to classes 2k and 2k + 1.
            ds.filter_by_class([2 * task, 2 * task + 1])
  
            all_entropy = 0.0
            
            # Sum per-example entropies of the predicted class distribution
            # (assumes gp.predict returns per-example class probabilities).
            for x, _ in tqdm(DataLoader(ds, batch_size=256), leave=False):
                dist = Categorical(probs=gp.predict(x.to(device)))
                all_entropy += dist.entropy().sum()
            
            # NOTE(review): assumes the filtered split is non-empty; otherwise all_entropy
            # stays a Python float (no .cpu()) and len(ds) is 0.
            mean_ent_list.append(all_entropy.cpu().item() / len(ds))

    # Called with no classes to clear the task filter before the next checkpoint
    # (presumably restores the full test set — confirm in continual_gp.datasets).
    ds.filter_by_class()
    # Appended only after this task's model was built, so task t's model sees checkpoints 0..t-1.
    prev_params.append(cur_params)
    ent_mat.append(mean_ent_list)

In [None]:
# Mean entropies are divided by log(10), the entropy of a uniform 10-class prediction
# (assuming gp.predict yields 10 classes), so every cell lies in [0, 1].
# The colour range is pinned to [0, 1]: imshow would otherwise stretch the grey scale
# to this matrix's own min/max, making shades incomparable across heatmaps.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(torch.Tensor(ent_mat) / torch.tensor(10.0).log(), cmap='gray', vmin=0.0, vmax=1.0)
ax.set_xlabel('Test Tasks', fontsize=20)
ax.set_ylabel('Train Tasks', fontsize=20)
ax.grid(False);
# plt.colorbar(im)
# plt.savefig(f'smnist_norm_entropy.png', bbox_inches='tight', pad_inches=0);

## Permuted MNIST

In [None]:
data = pd.read_csv('results/pmnist.csv')

# Permuted MNIST: VAR-GP against the VCL + coreset (100) baselines.
plot_test_acc(
    data,
    ['var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100'],
    max_x=10,
    min_y=0.7,
    save=False,
    name='permuted_mnist.png',
)

# Permuted MNIST: VAR-GP variants, y-axis from the default 0.0.
plot_test_acc(
    data,
    ['var_gp', 'var_gp_block_diag', 'var_gp_global', 'var_gp_mle_hypers'],
    max_x=10,
    save=False,
    name='permuted_mnist_ep_mean_mle_hypers.png',
)

### Normalized Predictive Entropies

In [None]:
from torch.utils.data import DataLoader
from torch.distributions import Categorical
from continual_gp.train_utils import create_class_vargp, set_seeds
from continual_gp.datasets import PermutedMNIST

# Rebuild the evaluation tasks. Seeding first presumably makes create_tasks return the
# same permutations as the `seed1` training run whose checkpoints are loaded below —
# confirm against the training script.
set_seeds(1)
# Task 0 is indexed by torch.arange(784), i.e. the identity ordering (presumably plain MNIST).
tasks = [torch.arange(784)] + PermutedMNIST.create_tasks(n=9)

run_dir = 'results/vargp-pmnist-seed1'
ds = PermutedMNIST('/tmp', train=False)

# ent_mat[t][k]: mean predictive entropy on test task k of the model checkpointed after task t.
ent_mat = []

# Checkpoints of earlier tasks, passed to create_class_vargp when building task t's model.
prev_params = []
for t in tqdm(range(10)):
    mean_ent_list = []
    with torch.no_grad():
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=device)
        # NOTE(review): for t > 0, `ds` here is the dataset last rebound in the inner loop
        # (last task's permutation); presumably create_class_vargp only reads dataset
        # metadata, so this is harmless — confirm.
        gp = create_class_vargp(ds, M=100, n_f=50, n_var_samples=20, prev_params=prev_params).to(device)
        gp.load_state_dict(cur_params)

        for i, task in tqdm(enumerate(tasks), leave=False):
            # A fresh test set per (checkpoint, task) pair: 10 x 10 = 100 constructions.
            # NOTE(review): presumably needed because set_task modifies the dataset in
            # place — confirm before hoisting the construction out of the loop.
            ds = PermutedMNIST('/tmp', train=False)
            ds.set_task(task)
  
            all_entropy = 0.0
            
            # Sum per-example entropies of the predicted class distribution
            # (assumes gp.predict returns per-example class probabilities).
            for x, _ in tqdm(DataLoader(ds, batch_size=256), leave=False):
                dist = Categorical(probs=gp.predict(x.to(device)))
                all_entropy += dist.entropy().sum()
            
            mean_ent_list.append(all_entropy.cpu().item() / len(ds))

    # Appended only after this task's model was built, so task t's model sees checkpoints 0..t-1.
    prev_params.append(cur_params)
    ent_mat.append(mean_ent_list)

In [None]:
# Mean entropies are divided by log(10), the entropy of a uniform 10-class prediction
# (assuming gp.predict yields 10 classes), so every cell lies in [0, 1].
# The colour range is pinned to [0, 1]: imshow would otherwise stretch the grey scale
# to this matrix's own min/max, making shades incomparable across heatmaps.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(torch.Tensor(ent_mat) / torch.tensor(10.0).log(), cmap='gray', vmin=0.0, vmax=1.0)
ax.set_xlabel('Test Tasks', fontsize=20)
ax.set_ylabel('Train Tasks', fontsize=20)
ax.grid(False);
# plt.colorbar(im)
# plt.savefig(f'pmnist_norm_entropy.png', bbox_inches='tight', pad_inches=0);

## VCL Predictive Entropy

### Split MNIST

In [None]:
# Entropy matrix saved by the VCL Split MNIST run (rows: train task, columns: test task).
# The `with` closes the .npz archive once the array is read; indexing the result of
# np.load directly leaves the NpzFile handle open.
# NOTE(review): unlike the VAR-GP heatmaps this is not divided by log(10) and the grey
# range is auto-scaled; confirm 'ent' is stored already normalized before comparing shades.
with np.load('results/vcl-smnist-seed1/test_acc_and_ent.npz') as vcl_results:
    ent_mat = vcl_results['ent']

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(ent_mat, cmap='gray')
ax.set_xlabel('Test Tasks', fontsize=20)
ax.set_ylabel('Train Tasks', fontsize=20)
ax.grid(False);
# plt.colorbar(im)
# plt.savefig(f'vcl_smnist_norm_entropy.png', bbox_inches='tight', pad_inches=0);

### Permuted MNIST

In [None]:
# Entropy matrix saved by the VCL Permuted MNIST run (rows: train task, columns: test task).
# The `with` closes the .npz archive once the array is read; indexing the result of
# np.load directly leaves the NpzFile handle open.
# NOTE(review): unlike the VAR-GP heatmaps this is not divided by log(10) and the grey
# range is auto-scaled; confirm 'ent' is stored already normalized before comparing shades.
with np.load('results/vcl-pmnist-seed1/test_acc_and_ent.npz') as vcl_results:
    ent_mat = vcl_results['ent']

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(ent_mat, cmap='gray')
ax.set_xlabel('Test Tasks', fontsize=20)
ax.set_ylabel('Train Tasks', fontsize=20)
ax.grid(False);
# plt.colorbar(im)
# plt.savefig(f'vcl_pmnist_norm_entropy.png', bbox_inches='tight', pad_inches=0);