# Analyze models' predictions

Analyze the models' predictions.

In [None]:
# Append the parent of the current working directory to the module search
# path; presumably this is where the project-local packages imported in the
# next cell (methods, models, datasets, transforms) live.
import sys
sys.path.append("./../")

In [None]:
import os
import json
import glob

import numpy as np
import torch
from torch.utils.data import DataLoader

import methods
import models
import datasets
import transforms

import matplotlib.pyplot as plt
# plt.style.use('seaborn')

In [None]:
def get_predictions(model, dataset, N=10):
    """Sample ``N`` predictive distributions for every item in ``dataset``.

    Args:
        model: object exposing ``sample_predictions(x, n=N)`` that returns
            ``(preds, _)``, where ``preds`` is a sequence of ``N`` tensors of
            shape ``(B, K)`` for a batch of ``B`` inputs. They are assumed to
            be log-probabilities (hence the ``torch.exp`` below) -- confirm
            against the method implementation.
        dataset: dataset yielding ``(x, y)`` pairs; the labels are ignored.
        N: number of predictions to sample per input.

    Returns:
        A list of ``N`` tensors, each of shape ``(len(dataset), K)``, holding
        probabilities in dataset order. An empty list if ``dataset`` is empty.
    """
    dataloader = DataLoader(dataset, batch_size=1024, shuffle=False)
    # Fixed seed so repeated calls sample the same predictions.
    torch.manual_seed(42)

    # chunks[i] collects the per-batch tensors of the i-th sampled prediction.
    # Concatenating once at the end avoids re-copying every earlier row on
    # each batch, which torch.cat on a running result would do.
    chunks = []
    with torch.no_grad():
        for x, _ in dataloader:
            batch_preds, _ = model.sample_predictions(x, n=N)
            if not chunks:
                chunks = [[pred] for pred in batch_preds]
            else:
                for chunk, pred in zip(chunks, batch_preds):
                    chunk.append(pred)

    # Convert log_softmax outputs to softmax probabilities.
    return [torch.exp(torch.cat(chunk, dim=0)) for chunk in chunks]

Edit `model_dir` variable to point to the model you want.

In [None]:
# model_dir = "./../zoo/experiments-termwise-ablation/BinaryMNISTC-53-identity/LeNet/term-all-beta1010-20220402213949/"
model_dir = "./../zoo/experiments-termwise-ablation/BinaryMNISTC-53-identity/LeNet/term_1_3-20220402202026/"
# Default paths
config_json = os.path.join(model_dir, "config.json")
# glob() returns paths in a filesystem-dependent order, so sort before taking
# the last one to make the choice reproducible. The sort is lexicographic: it
# matches step order only when the step numbers have the same number of digits.
ckpt_files = sorted(glob.glob(os.path.join(model_dir, 'step=*.ckpt')))
if not ckpt_files:
    raise FileNotFoundError(f"No 'step=*.ckpt' checkpoint found in {model_dir!r}")
ckpt_file = ckpt_files[-1]

# Read the run configuration and close the file handle promptly.
with open(config_json, 'r') as fp:
    config = json.load(fp)

# Resolve the classes named in the config from the project modules.
MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset']) # Will automatically detect the dataset for testing
ModelClass = getattr(models, config['model'])
TransformClass = getattr(transforms, config['transform'])

testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
K = testset.n_labels

model = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K), strict=False)

## Get predictions from model

You can sample $N$ predictions from the model, one for each realization of its parameters, by setting $N$. The returned
variable is a list with $N$ entries, one per realization. To get the average over all realizations,
average over the list entries.

## Select images

In [None]:
torch.manual_seed(24)
batch_size = 1000
dataloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
# Take the first batch of the test set (in dataset order, since shuffle=False).
batch = next(iter(dataloader))
sample_images, sample_target = batch
from sklearn.manifold import TSNE
# Flatten each image to a vector. Use the actual batch length: if the test set
# has fewer than `batch_size` items, the first batch is shorter and
# reshape(batch_size, -1) would fail.
X = sample_images.reshape(sample_images.shape[0], -1).numpy()
# torch.manual_seed does not seed scikit-learn. Pass random_state so the random
# initialization, and therefore the 2-D embedding, is reproducible.
X_embedded = TSNE(n_components=2, learning_rate='auto', init='random', random_state=24).fit_transform(X)
# 2-D coordinates of each image, used as the plotting grid below.
x = X_embedded[:, 0]
y = X_embedded[:, 1]

## Plot

In [None]:
model_dir = "models/term_1_3-20220402202026/"
# Default paths
config_json = os.path.join(model_dir, "config.json")
# glob() order is filesystem-dependent; sort so "[-1]" picks the same file on
# every machine. The sort is lexicographic, so it follows step order only when
# the step numbers in the file names have the same number of digits.
ckpt_files = sorted(glob.glob(os.path.join(model_dir, 'step=*.ckpt')))
if not ckpt_files:
    raise FileNotFoundError(f"No 'step=*.ckpt' checkpoint found in {model_dir!r}")
ckpt_file = ckpt_files[-1]

with open(config_json, 'r') as fp:
    config = json.load(fp)

MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset']) # Will automatically detect the dataset for testing
ModelClass = getattr(models, config['model'])
TransformClass = getattr(transforms, config['transform'])

testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
K = testset.n_labels

model_1 = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K), strict=False)
# Re-seed before sampling so every model starts from the same RNG state
# (relevant only if sample_predictions draws random numbers).
torch.manual_seed(24)
# Inference only: do not build an autograd graph for the whole batch.
with torch.no_grad():
    _ypreds, _ = model_1.sample_predictions(sample_images, n=1)
# Exponentiate the single sampled log-probability tensor and keep the column
# for label index 1 (one value per image in sample_images).
yprobs = torch.exp(_ypreds[0].detach())
z_1 = yprobs[:, 1].numpy()

In [None]:
model_dir = "models/term-all-beta0101-20220402214708/"
# Default paths
config_json = os.path.join(model_dir, "config.json")
# glob() order is filesystem-dependent; sort so "[-1]" picks the same file on
# every machine. The sort is lexicographic, so it follows step order only when
# the step numbers in the file names have the same number of digits.
ckpt_files = sorted(glob.glob(os.path.join(model_dir, 'step=*.ckpt')))
if not ckpt_files:
    raise FileNotFoundError(f"No 'step=*.ckpt' checkpoint found in {model_dir!r}")
ckpt_file = ckpt_files[-1]

with open(config_json, 'r') as fp:
    config = json.load(fp)

MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset']) # Will automatically detect the dataset for testing
ModelClass = getattr(models, config['model'])
TransformClass = getattr(transforms, config['transform'])

testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
K = testset.n_labels

model_2 = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K), strict=False)
# Re-seed before sampling so every model starts from the same RNG state
# (relevant only if sample_predictions draws random numbers).
torch.manual_seed(24)
# Inference only: do not build an autograd graph for the whole batch.
with torch.no_grad():
    _ypreds, _ = model_2.sample_predictions(sample_images, n=1)
# Exponentiate the single sampled log-probability tensor and keep the column
# for label index 1 (one value per image in sample_images).
yprobs = torch.exp(_ypreds[0].detach())
z_2 = yprobs[:, 1].numpy()

In [None]:
model_dir = "models/term-all-beta1010-20220402213949/"
# Default paths
config_json = os.path.join(model_dir, "config.json")
# glob() order is filesystem-dependent; sort so "[-1]" picks the same file on
# every machine. The sort is lexicographic, so it follows step order only when
# the step numbers in the file names have the same number of digits.
ckpt_files = sorted(glob.glob(os.path.join(model_dir, 'step=*.ckpt')))
if not ckpt_files:
    raise FileNotFoundError(f"No 'step=*.ckpt' checkpoint found in {model_dir!r}")
ckpt_file = ckpt_files[-1]

with open(config_json, 'r') as fp:
    config = json.load(fp)

MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset']) # Will automatically detect the dataset for testing
ModelClass = getattr(models, config['model'])
TransformClass = getattr(transforms, config['transform'])

testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
K = testset.n_labels

model_3 = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K), strict=False)
# Re-seed before sampling so every model starts from the same RNG state
# (relevant only if sample_predictions draws random numbers).
torch.manual_seed(24)
# Inference only: do not build an autograd graph for the whole batch.
with torch.no_grad():
    _ypreds, _ = model_3.sample_predictions(sample_images, n=1)
# Exponentiate the single sampled log-probability tensor and keep the column
# for label index 1 (one value per image in sample_images).
yprobs = torch.exp(_ypreds[0].detach())
z_3 = yprobs[:, 1].numpy()

In [None]:
model_dir = "models/term-all-beta5050-20220402213330/"
# Default paths
config_json = os.path.join(model_dir, "config.json")
# glob() order is filesystem-dependent; sort so "[-1]" picks the same file on
# every machine. The sort is lexicographic, so it follows step order only when
# the step numbers in the file names have the same number of digits.
ckpt_files = sorted(glob.glob(os.path.join(model_dir, 'step=*.ckpt')))
if not ckpt_files:
    raise FileNotFoundError(f"No 'step=*.ckpt' checkpoint found in {model_dir!r}")
ckpt_file = ckpt_files[-1]

with open(config_json, 'r') as fp:
    config = json.load(fp)

MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset']) # Will automatically detect the dataset for testing
ModelClass = getattr(models, config['model'])
TransformClass = getattr(transforms, config['transform'])

testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
K = testset.n_labels

model_4 = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K), strict=False)
# Re-seed before sampling so every model starts from the same RNG state
# (relevant only if sample_predictions draws random numbers).
torch.manual_seed(24)
# Inference only: do not build an autograd graph for the whole batch.
with torch.no_grad():
    _ypreds, _ = model_4.sample_predictions(sample_images, n=1)
# Exponentiate the single sampled log-probability tensor and keep the column
# for label index 1 (one value per image in sample_images).
yprobs = torch.exp(_ypreds[0].detach())
z_4 = yprobs[:, 1].numpy()

In [None]:
fig_l, axes_l = plt.subplots(nrows=1, ncols=4, figsize=(16, 3))
ax_00, ax_01, ax_02, ax_03 = axes_l.flatten()

# One panel per model: thin black contour lines drawn over filled contours of
# that model's z values on the (x, y) embedding points. Each panel draws the
# lines first, then the filled contours.
panel_data = zip((ax_00, ax_01, ax_02, ax_03), (z_1, z_2, z_3, z_4))
for panel, z in panel_data:
    panel.tricontour(x, y, z, levels=14, linewidths=0.5, colors='k')
    cntr2 = panel.tricontourf(x, y, z, levels=14, cmap="RdBu_r")

In [None]:
fig_l, axes_l = plt.subplots(nrows=1, ncols=4, figsize=(18, 3))
ax_00, ax_01, ax_02, ax_03 = axes_l.flatten()

# Shared view window, number of contour levels and scatter marker size.
xlim, ylim = (-20, 20), (-20, 20)
levels, ms = 13, 1

# For each model's panel: contour lines, filled contours with a colorbar,
# a fixed view window, and the embedded sample points as small black dots.
for panel, z in zip((ax_00, ax_01, ax_02, ax_03), (z_1, z_2, z_3, z_4)):
    panel.tricontour(x, y, z, levels=levels, linewidths=0.5, colors='k')
    cntr2 = panel.tricontourf(x, y, z, levels=levels, cmap="RdBu_r")
    fig_l.colorbar(cntr2, ax=panel)
    panel.set(xlim=xlim, ylim=ylim)
    panel.plot(x, y, 'ko', ms=ms)