In [None]:
import os
import sys
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Plot theme: enlarged fonts on a white grid background.
sns.set(style='whitegrid', font_scale=2)

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Export toggle; not referenced in the visible cells (the savefig calls are commented out).
save = False

# Make the parent directory importable so the `var_gp` imports in the next cell resolve.
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
from var_gp.datasets import ToyDataset
from var_gp.vargp import VARGP
from var_gp.train_utils import set_seeds

In [None]:
set_seeds(1)

toy_ds = ToyDataset()
# Long-form frame for seaborn: one row per 2-D point, with its class label for colouring.
df = pd.DataFrame({
    'x': toy_ds.data[:, 0].numpy(),
    'y': toy_ds.data[:, 1].numpy(),
    'Class': toy_ds.targets.numpy(),
})

fig, ax = plt.subplots(figsize=(9, 9))
sns.scatterplot(data=df, x='x', y='y', hue='Class', ax=ax,
                palette='Set2', s=200, edgecolor='black', linewidth=2)
ax.set(xlabel='', ylabel='')

# Restyle the legend markers so they match the enlarged, black-outlined scatter points.
handles, labels = ax.get_legend_handles_labels()
for handle in handles:
    handle.set_edgecolor('black')
    handle.set_linewidth(2)
    handle.set_sizes([200])
ax.legend(handles=handles, labels=labels, title='Class');
# fig.savefig('toy_data.pdf', bbox_inches='tight')

In [None]:
# Evaluation grid over [-3, 3) x [-3, 3) in steps of 0.1 (60 x 60 points, last dim = (x, y)).
# After the permute, grid_data[row, col] == (axis[col], axis[row]): columns follow x and rows
# follow y, which is the Z layout contourf expects with origin='lower'.
grid_data = torch.cat([v.unsqueeze(-1) for v in torch.meshgrid([torch.arange(-3,3,0.1), torch.arange(-3,3,0.1)])], dim=-1).permute(1, 0, 2)

def plot_task(preds):
    """Plot per-class prediction maps over `grid_data`, one row per model snapshot.

    Args:
        preds: tensor whose first dim indexes the snapshot (row label "After Task r") and
            whose last dim indexes the class. The elements in between must reshape to
            `grid_data.shape[:-1]` (see the reshape below). May live on any device.

    Returns:
        (fig, axes), where `axes` is always a 2-D array of shape (n_rows, n_classes).

    Side effects:
        Calls `toy_ds.filter_by_class([i])` on the global dataset for each panel and
        restores it with `toy_ds.filter_by_class()` before returning, even on error.
    """
    # contourf/numpy need host memory; detach()/cpu() are no-ops for plain CPU tensors.
    preds = preds.detach().cpu()
    n_rows, n_classes = preds.size(0), preds.size(-1)
    out = preds.reshape(n_rows, *grid_data.shape[:-1], -1)

    # Panel grid follows the data (2 x 4 for two tasks / four classes, as before);
    # squeeze=False keeps `axes` 2-D even with a single row or column.
    fig, axes = plt.subplots(n_rows, n_classes, sharey=True, sharex=True,
                             figsize=(10 * n_classes, 10 * n_rows), squeeze=False)
    cmap = sns.color_palette("viridis", as_cmap=True)

    try:
        for r in range(n_rows):
            for i in range(n_classes):
                # Restrict the dataset to class i to overlay its training points on this panel.
                toy_ds.filter_by_class([i])
                class_points = toy_ds.data[toy_ds.task_ids]

                ax = axes[r, i]
                ax.contourf(out[r, ..., i], cmap=cmap, extent=(-3, 3, -3, 3), origin='lower')
                ax.set(aspect='equal')
                ax.set_xlim(-3, 3)
                ax.set_ylim(-3, 3)
                ax.grid(False)
                ax.set_xticks([])
                ax.set_yticks([])
                if r == 0:
                    ax.set_title(f'Class {i}', fontsize=75)

                ax.scatter(class_points[:, 0], class_points[:, 1],
                           marker='o', facecolor='red', s=400, edgecolor='black', linewidth=2)

            axes[r, 0].set_ylabel(f'After Task {r}', fontsize=75)
    finally:
        # Reset filter, also when plotting fails, so the global dataset is not left filtered.
        toy_ds.filter_by_class()

    fig.tight_layout()
    return fig, axes

In [None]:
run_dir = 'results/vargp-toy-seed1'
n_tasks = 2

# Evaluate on the model's device; `grid_data` itself is built on the CPU.
grid_points = grid_data.reshape(-1, 2).to(device)

prev_params = []
preds = []
for t in range(n_tasks):
    with torch.no_grad():
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine (and vice versa).
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=device)
        # prev_params holds the state dicts of all earlier tasks and is passed to create_clf.
        gp = VARGP.create_clf(toy_ds, M=20, n_f=100, n_var_samples=20, prev_params=prev_params).to(device)
        gp.load_state_dict(cur_params)

        # Bring predictions back to the CPU for plotting.
        preds.append(gp.predict(grid_points).cpu())

    prev_params.append(cur_params)

# Stack into (n_tasks, ...) so row r of the plot shows the model after task r.
preds = torch.stack(preds, dim=0)

fig, _ = plot_task(preds)
# fig.savefig(f'toy_vargp_density.pdf', bbox_inches='tight')

In [None]:
from var_gp.vargp_retrain import VARGPRetrain

run_dir = 'results/re-vargp-toy'
n_tasks = 2

# Evaluate on the model's device; `grid_data` itself is built on the CPU.
grid_points = grid_data.reshape(-1, 2).to(device)

prev_params = []
preds = []
for t in range(n_tasks):
    with torch.no_grad():
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine (and vice versa).
        cur_params = torch.load(f'{run_dir}/ckpt{t}.pt', map_location=device)
        # prev_params holds the state dicts of all earlier tasks and is passed to create_clf.
        gp = VARGPRetrain.create_clf(toy_ds, M=20, n_f=100, n_var_samples=20, prev_params=prev_params).to(device)
        gp.load_state_dict(cur_params)

        # Bring predictions back to the CPU for plotting.
        preds.append(gp.predict(grid_points).cpu())

    prev_params.append(cur_params)

# Stack into (n_tasks, ...) so row r of the plot shows the model after task r.
preds = torch.stack(preds, dim=0)

fig, _ = plot_task(preds)
# fig.savefig(f'toy_vargp_retrain_density.pdf', bbox_inches='tight')

In [None]:
pred_dump = 'results/vcl-toy-seed1'
n_tasks = 2

preds = []
for t in range(n_tasks):
    # The context manager closes the .npz file handle; indexing reads `probs` into memory first.
    with np.load(f'{pred_dump}/grid_pred_probs_{t}.npz') as data:
        probs = data['probs']
    # Drop the trailing size-1 axis, transpose, and convert to a float32 tensor.
    preds.append(torch.from_numpy(np.squeeze(probs, axis=-1).T).float())

# Stack into (n_tasks, ...) so row r of the plot shows the model after task r.
preds = torch.stack(preds, dim=0)

fig, _ = plot_task(preds)
# fig.savefig(f'toy_vcl_density.pdf', bbox_inches='tight')