In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Global seaborn styling applied to every figure drawn in this notebook.
sns.set(font_scale=2, style='whitegrid')
# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Put the parent directory on the import path (only once) so that modules
# located there can be imported by later cells -- presumably this is where the
# `var_gp` package lives; confirm against the repository layout.
if os.path.abspath('..') not in sys.path:
    sys.path.append(os.path.abspath('..'))

In [None]:
from var_gp.vargp import VARGP
from var_gp.datasets import SplitMNIST, PermutedMNIST
from var_gp.train_utils import set_seeds, compute_acc_ent, compute_bwt

In [None]:
# Maps result-column keys to the human-readable method names shown in figure
# legends. `plot_test_acc` renames DataFrame columns with this mapping and looks
# up the display name of every requested method in it.
cmap = {
    'vcl_100': 'VCL, [100]',
    'vcl_100_coreset_50': 'VCL + coreset (50), [100]',
    'vcl_100_coreset_100': 'VCL + coreset (100), [100]',
    'vcl_100_100': 'VCL, [100, 100]',
    'vcl_100_100_coreset_50': 'VCL + coreset (50), [100, 100]',
    'vcl_100_100_coreset_100': 'VCL + coreset (100), [100, 100]',
    'var_gp': 'VAR-GP',
    'var_gp_block_diag': 'VAR-GP (Block Diagonal)',
    'var_gp_mle_hypers': 'VAR-GP (MLE Hyperparameters)',
    'var_gp_global': 'VAR-GP (Global)',
    'var_gp_dkl_mlp': 'VAR-GP (DKL)'
}

def plot_test_acc(data, clist, name=None, xticks=None, loc='best'):
    """Plot test accuracy against task index for a selection of methods.

    Parameters
    ----------
    data : pandas.DataFrame
        Results table containing a ``task`` column plus columns named by the
        keys of ``cmap``. The frame is left untouched; renaming happens on a copy.
    clist : list of str
        Keys of ``cmap`` selecting which methods are drawn (one hue per method).
    name : str, optional
        Intended output file name. Saving is currently disabled, so the value
        is accepted but not used.
    xticks : sequence, optional
        Explicit x-axis tick positions (list, ``range`` or array). ``None`` or
        an empty sequence keeps matplotlib's default ticks.
    loc : str, optional
        Legend location forwarded to ``Axes.legend``.

    Raises
    ------
    KeyError
        If an entry of ``clist`` is not a key of ``cmap``, or the corresponding
        column is missing from ``data``.
    """
    # Rename on a copy: an in-place rename would change the caller's DataFrame
    # as a side effect and leave later cells looking at different column names.
    labelled = data.rename(columns=cmap)

    melt_data = labelled.melt(id_vars=['task'], value_vars=[cmap[k] for k in clist],
                              var_name='Method', value_name='vals')

    fig, ax = plt.subplots(figsize=(9,9))
    # ci='sd' shades one standard deviation around the mean of repeated rows per task.
    # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of errorbar='sd';
    # kept as-is so the notebook still runs on the seaborn version it was written for.
    sns.lineplot(ax=ax, data=melt_data, x='task', y='vals', hue='Method', marker='o',
                 ci='sd', markersize=15, palette=sns.color_palette("tab10", len(clist)))
    ax.set_xlabel('Task', fontsize=30)
    ax.set_ylabel('Test Accuracy', fontsize=30)
    # Explicit None/length check: plain truthiness raises for multi-element arrays.
    if xticks is not None and len(xticks) > 0:
        ax.set_xticks(xticks)
    ax.legend(loc=loc, title='Method')

    fig.tight_layout()

    if name is not None:
        # Saving is intentionally disabled; uncomment to write the figure to `name`.
        # fig.savefig(name, bbox_inches='tight')
        pass

## Split MNIST

### Test Accuracy Curves

In [None]:
# Load the Split MNIST results table consumed by the two plots below.
data = pd.read_csv('results/smnist.csv')

# VAR-GP against the two VCL configurations that use a coreset of size 100.
plot_test_acc(data, ['var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100'],
              name='split_mnist.pdf')

# VAR-GP against its block-diagonal, global and MLE-hyperparameter variants.
plot_test_acc(data, ['var_gp', 'var_gp_block_diag', 'var_gp_global', 'var_gp_mle_hypers'],
              name='split_mnist_ep_mean_mle_hypers.pdf', loc='lower left')

### DKL Curves

In [None]:
# Reload the Split MNIST results and compare VAR-GP with its DKL variant.
data = pd.read_csv('results/smnist.csv')
plot_test_acc(data, ['var_gp', 'var_gp_dkl_mlp'],
              name='split_mnist_dkl.pdf')

### Visualizing Inducing Points

In [None]:
def plot_inducing_pts(ckpt_dir, n_tasks=5, n_per_class=4):
    """Show a random sample of stored inducing points for each task checkpoint.

    For every checkpoint ``{ckpt_dir}/ckpt{t}.pt`` with ``t = 0 .. n_tasks - 1``
    one figure is drawn with ``n_per_class`` rows and two columns. The columns
    show slices ``2*t`` and ``2*t + 1`` along the first axis of the checkpoint's
    ``z`` tensor; every selected entry is reshaped to a 28x28 image.

    Parameters
    ----------
    ckpt_dir : str
        Directory containing the ``ckpt{t}.pt`` files.
    n_tasks : int, optional
        Number of checkpoints to visualise (default 5).
    n_per_class : int, optional
        Number of randomly chosen inducing points per column (default 4).

    Raises
    ------
    KeyError
        If a checkpoint does not contain a ``z`` entry.
    """
    for ckpt in range(n_tasks):
        ckpt_path = f'{ckpt_dir}/ckpt{ckpt}.pt'
        # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        z = state_dict.get('z')
        if z is None:
            # Fail with a clear message instead of an opaque error on None below.
            raise KeyError(f"checkpoint '{ckpt_path}' has no 'z' entry")

        # assumes z is laid out as (class, inducing point, 784 pixels) -- TODO confirm
        chosen = torch.randperm(z.size(1))[:n_per_class]
        ind_pts_subset = z[2*ckpt:2*ckpt+2][:, chosen, :].view(-1, n_per_class, 28, 28)

        # squeeze=False keeps `axes` two-dimensional even when n_per_class == 1;
        # the height scales with the row count and equals 10 for the default of 4.
        fig, axes = plt.subplots(figsize=(5, 2.5 * n_per_class), nrows=n_per_class, ncols=2,
                                 sharey=True, sharex=True, squeeze=False)
        fig.subplots_adjust(wspace=-0.05, hspace=0.01)
        for i in range(n_per_class):
            for j in range(2):
                axes[i, j].imshow(ind_pts_subset[j, i], interpolation='bilinear', cmap='gray')
                axes[i, j].set_aspect('equal')
                axes[i, j].grid(False)
                axes[i, j].axis('off')

        fig.suptitle(f'After Task {ckpt}', fontsize=50)

        # fig.savefig(f'smnist_viz_{ckpt + 1}.pdf', bbox_inches='tight')

# Directory holding the per-task checkpoint files (`ckpt0.pt`, `ckpt1.pt`, ...)
# that `plot_inducing_pts` reads; change it to point at your own run.
ckpt_dir = 'results/vargp-smnist'
plot_inducing_pts(ckpt_dir)

### Accuracy

In [None]:
# Evaluate each of the five saved Split MNIST checkpoints on each of the five
# test tasks, collecting accuracy and predictive entropy per (checkpoint, task).
run_dir = 'results/vargp-smnist'
ds = SplitMNIST('/tmp', train=False)

# One row per checkpoint (train task), one column per test task.
acc_mat = []
ent_mat = []

# Checkpoints of the earlier tasks; passed to VARGP.create_clf when the model
# for checkpoint t is rebuilt.
prev_params = []
for t in tqdm(range(5), desc='Train Task'):
    mean_acc_list = []
    mean_ent_list = []
    
    with torch.no_grad():
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=device)
        # Build a fresh model from the earlier checkpoints, then load the
        # parameters saved for task t into it.
        gp = VARGP.create_clf(ds, M=60, n_f=50, n_var_samples=20, prev_params=prev_params).to(device)
        gp.load_state_dict(cur_params)

        for task in tqdm(range(5), leave=False, desc='Test Task'):
            # Restrict the test set to the two classes 2*task and 2*task + 1.
            ds.filter_by_class([2 * task, 2 * task + 1])

            mean_acc, mean_ent = compute_acc_ent(ds, gp, batch_size=256, device=device)

            mean_acc_list.append(mean_acc)
            mean_ent_list.append(mean_ent)
            
    acc_mat.append(mean_acc_list)
    ent_mat.append(mean_ent_list)
    
    # Called without arguments before the next model is created -- presumably
    # this clears the class filter again; TODO confirm in var_gp.datasets.
    ds.filter_by_class()
    prev_params.append(cur_params)
    

acc_mat = torch.Tensor(acc_mat).numpy()
# Entropies are divided by log(10) -- presumably the maximum entropy of a
# 10-class prediction, so values fall in [0, 1]; TODO confirm the log base
# used by compute_acc_ent.
norm_ent_mat = torch.Tensor(ent_mat).numpy() / np.log(10.0)

# The "Predictive Entropy Matrix" cell loads `ent` from this file, so the save
# has to be enabled at least once for that cell to find it.
# np.savez(f'{run_dir}/test_acc_and_ent.npz', acc=acc_mat, ent=norm_ent_mat)

In [None]:
# Summarise the accuracy matrix from the previous cell with `compute_bwt`
# (presumably backward transfer -- confirm in var_gp.train_utils).
compute_bwt(acc_mat)

### Predictive Entropy Matrix

In [None]:
# Side-by-side heatmaps of the saved `ent` matrices (rows: train tasks,
# columns: test tasks) for VAR-GP and the VCL baseline on Split MNIST.
fig, axes = plt.subplots(figsize=(10,5), nrows=1, ncols=2, sharex=True)

panels = [
    ('results/vargp-smnist/test_acc_and_ent.npz', 'VAR-GP (ours)'),
    ('results/vcl-smnist-seed1/test_acc_and_ent.npz', 'VCL'),
]
for col, (npz_path, panel_title) in enumerate(panels):
    ax = axes[col]
    norm_ent_mat = np.load(npz_path)['ent']
    sns.heatmap(
        ax=ax,
        data=norm_ent_mat,
        linewidths=2,
        cmap=sns.color_palette("summer", as_cmap=True),
        cbar=False,
    )
    ax.set_aspect('equal')
    ax.set_xlabel('Test Tasks')
    # Only the left-hand panel carries the y-axis label.
    if col == 0:
        ax.set_ylabel('Train Tasks')
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=25)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=25)
    ax.set_title(panel_title)

fig.tight_layout()

# fig.savefig(f'smnist_norm_entropy.pdf', bbox_inches='tight')

### Varying Inducing Points

In [None]:
# Mean test accuracy for each number of inducing points M, reshaped into long
# form (one row per (M, task) pair) for seaborn.
data = pd.read_csv('results/varying_M.csv')

mean_acc_per_M = data.groupby(['M']).mean().to_numpy()
# Task ids cycle 0..4 within each M; every M value is repeated once per task.
task_ids = np.tile(np.arange(5), 10)
m_values = np.repeat(np.arange(20, 201, 20), 5)

plot_data = pd.DataFrame({
    'Test Accuracy': mean_acc_per_M.flatten(),
    'Task': task_ids,
    'M': m_values,
})
# Drop the smallest setting (M = 20) from the figure.
plot_data = plot_data[plot_data['M'] > 20]

fig, ax = plt.subplots(figsize=(9,9))
line_style = dict(marker='o', markersize=10, linewidth=3, alpha=.75)
sns.lineplot(
    ax=ax,
    data=plot_data,
    x='Task',
    y='Test Accuracy',
    hue='M',
    ci='sd',
    palette=sns.color_palette("tab10", 9),
    **line_style,
)
ax.set_xlabel('Task', fontsize=30)
ax.set_ylabel('Test Accuracy', fontsize=30)
ax.legend(loc='best', title='$M$')

fig.tight_layout()
# fig.savefig('smnist_varying_M.pdf', bbox_inches='tight')

## Permuted MNIST

In [None]:
# Load the Permuted MNIST results table consumed by the two plots below.
data = pd.read_csv('results/pmnist.csv')

# VAR-GP against the two VCL configurations that use a coreset of size 100;
# ticks are placed at every task index 0..9.
plot_test_acc(data, ['var_gp', 'vcl_100_coreset_100', 'vcl_100_100_coreset_100'],
              name='permuted_mnist.pdf', xticks=range(10))

# VAR-GP against its block-diagonal, global and MLE-hyperparameter variants.
plot_test_acc(data, ['var_gp', 'var_gp_block_diag', 'var_gp_global', 'var_gp_mle_hypers'],
              name='permuted_mnist_ep_mean_mle_hypers.pdf', xticks=range(10))

### Accuracy

In [None]:
# Evaluate each of the ten saved Permuted MNIST checkpoints on every test task,
# collecting accuracy and predictive entropy per (checkpoint, task).

# Seed before drawing the permutations so the task list is reproducible --
# presumably it must match the seed used in training; TODO confirm.
set_seeds(1)
# The first task is the identity ordering of the 784 pixels, followed by the
# tasks returned by PermutedMNIST.create_tasks.
tasks = [torch.arange(784)] + PermutedMNIST.create_tasks(n=9)

run_dir = 'results/vargp-pmnist-seed1'
ds = PermutedMNIST('/tmp', train=False)

# One row per checkpoint (train task), one column per test task.
acc_mat = []
ent_mat = []

# Checkpoints of the earlier tasks; passed to VARGP.create_clf when the model
# for checkpoint t is rebuilt.
prev_params = []
for t in tqdm(range(10), desc='Train Task'):
    mean_acc_list = []
    mean_ent_list = []

    with torch.no_grad():
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=device)
        gp = VARGP.create_clf(ds, M=100, n_f=50, n_var_samples=20, prev_params=prev_params).to(device)
        gp.load_state_dict(cur_params)

        # Wrap the list itself so tqdm knows the total; wrapping enumerate()
        # hides the length and the bar shows no progress fraction.
        for task in tqdm(tasks, leave=False, desc='Test Task'):
            # A fresh dataset is created for every task, as in the original
            # code -- presumably set_task permutes the stored data; TODO
            # confirm whether a single instance could be reused.
            ds = PermutedMNIST('/tmp', train=False)
            ds.set_task(task)

            mean_acc, mean_ent = compute_acc_ent(ds, gp, batch_size=256, device=device)

            mean_acc_list.append(mean_acc)
            mean_ent_list.append(mean_ent)

    acc_mat.append(mean_acc_list)
    ent_mat.append(mean_ent_list)

    prev_params.append(cur_params)


acc_mat = torch.Tensor(acc_mat).numpy()
# Entropies are divided by log(10) -- presumably the maximum entropy of a
# 10-class prediction; TODO confirm the log base used by compute_acc_ent.
norm_ent_mat = torch.Tensor(ent_mat).numpy() / np.log(10.0)

# The "Predictive Entropy Matrix" cell loads `ent` from this file as the
# normalised matrix, so save `norm_ent_mat` rather than the raw `ent_mat` list.
# np.savez(f'{run_dir}/test_acc_and_ent.npz', acc=acc_mat, ent=norm_ent_mat)

In [None]:
# Summarise the Permuted MNIST accuracy matrix from the previous cell with
# `compute_bwt` (presumably backward transfer -- confirm in var_gp.train_utils).
compute_bwt(acc_mat)

### Predictive Entropy Matrix

In [None]:
# Side-by-side heatmaps of the saved `ent` matrices (rows: train tasks,
# columns: test tasks) for VAR-GP and the VCL baseline on Permuted MNIST.
fig, axes = plt.subplots(figsize=(10,5), nrows=1, ncols=2, sharex=True)

panels = [
    ('results/vargp-pmnist-seed1/test_acc_and_ent.npz', 'VAR-GP (ours)'),
    ('results/vcl-pmnist-seed1/test_acc_and_ent.npz', 'VCL'),
]
for col, (npz_path, panel_title) in enumerate(panels):
    ax = axes[col]
    norm_ent_mat = np.load(npz_path)['ent']
    sns.heatmap(
        ax=ax,
        data=norm_ent_mat,
        linewidths=2,
        cmap=sns.color_palette("summer", as_cmap=True),
        cbar=False,
    )
    ax.set_aspect('equal')
    ax.set_xlabel('Test Tasks')
    # Only the left-hand panel carries the y-axis label.
    if col == 0:
        ax.set_ylabel('Train Tasks')
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=25)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=25)
    ax.set_title(panel_title)

fig.tight_layout()

# fig.savefig('pmnist_norm_entropy.pdf', bbox_inches='tight')