# Analyze model predictions

In [None]:
import sys
sys.path.append("./../")

In [None]:
import os
import json

import numpy as np
import torch
from torch.utils.data import DataLoader

import methods
import models
import datasets
import transforms

import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and later
# releases dropped the old names; fall back to the old name on older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def plot_marginal_predictions(model, dataset, N=10, title="Predictive marginal distribution"):
    """Plot the predicted marginal label counts next to ``dataset.n_classes``.

    The model is sampled ``N`` times on every example of ``dataset``. For each
    sample the per-class probabilities are summed over all examples, giving one
    vector of expected label counts per sample. The bar chart shows the mean of
    these vectors, with the standard error over the ``N`` samples as error
    bars, side by side with the values in ``dataset.n_classes``.

    Args:
        model: Object exposing ``sample_predictions(x, n=N)``. The call must
            return a pair whose first element iterates over ``N`` tensors of
            per-class log-probabilities for the batch (assumed shape
            ``(B, K)`` -- TODO confirm against the method implementation).
        dataset: Dataset yielding ``(x, y)`` pairs. ``dataset.n_classes`` is
            used as the bar heights of the "True labels" series and is
            concatenated with a Python list, so it is presumably a list with
            one count per class -- TODO confirm.
        N: Number of predictive samples requested per batch.
        title: Title of the figure.

    Returns:
        The matplotlib figure holding the bar chart.
    """
    dataloader = DataLoader(dataset, batch_size=1024, shuffle=False)

    # Only the per-class sums are needed, so reduce every batch right away
    # instead of concatenating (and keeping) all predictions in memory.
    batch_counts = []
    with torch.no_grad():
        for x, _ in dataloader:
            log_probs, _ = model.sample_predictions(x, n=N)
            # exp() turns log-softmax scores into probabilities; summing over
            # the batch axis gives one (K,) vector per predictive sample.
            batch_counts.append(
                torch.stack([torch.exp(lp).sum(dim=0) for lp in log_probs])
            )

    # (num_batches, N, K) -> (N, K): expected label counts for every sample
    yprob_marginal = torch.stack(batch_counts).sum(dim=0)
    y_std, y_mean = torch.std_mean(yprob_marginal, dim=0)

    mean_counts = y_mean.numpy()

    # Plot
    fig = plt.figure()
    # One bar group per class; derived from the predictions rather than
    # hard-coded so that non-binary datasets work as well.
    x_vals = np.arange(len(mean_counts))
    plt.bar(
        x=x_vals - 0.125,
        height=mean_counts,
        yerr=y_std.numpy() / np.sqrt(N),  # standard error over the N samples
        width=0.25,
        label="Marginal labels",
    )
    plt.bar(
        x=x_vals + 0.125,
        height=dataset.n_classes,
        width=0.225,
        label="True labels",
    )

    # Set ylim: snap the limits to the multiples of 50 that enclose all bars
    all_heights = mean_counts.tolist() + dataset.n_classes
    ymin = int(np.min(all_heights)) // 50 * 50
    ymax = (int(np.max(all_heights)) // 50 + 1) * 50
    plt.ylim(ymin, ymax)

    plt.xticks(x_vals)
    plt.legend()
    plt.title(title)
    plt.xlabel("Label")

    return fig

In [None]:
# Run directory to analyze. The commented-out lines are alternative runs; swap
# the active assignment to inspect a different one.
# model_dir = "./../zoo/samples/BinaryMNISTC-53-identity/LeNet/mfvi-20220328203213"
# model_dir = "./../zoo/samples/BinaryMNISTC-53-identity/LeNet/sl-20220328204536"
# model_dir = "./../zoo/samples/BinaryMNISTC-53-identity/LeNet/sl-lam1.0-auto-20220328205742"
model_dir = "./../zoo/samples/BinaryMNISTC-53-identity/LeNet/mfvi-lam1.0-auto-20220328211605"

# Default file names inside the run directory: the saved configuration and
# the last checkpoint
config_json, ckpt_file = (
    os.path.join(model_dir, file_name)
    for file_name in ("config.json", "last.ckpt")
)

In [None]:
# Read the run configuration; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open(config_json, 'r') as config_file:
    config = json.load(config_file)

# Resolve the classes named in the configuration from the project modules
MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset'])
ModelClass = getattr(models, config['model'])
TransformClass = getattr(transforms, config['transform'])

In [None]:
# Test split built from the dataset parameters stored in the run configuration
testset = DatasetClass(
    **config['ds_params'],
    split='test',
    transform=TransformClass(),
)
# The dataset's n_labels value is what the model is constructed with below
K = testset.n_labels

In [None]:
# Restore the method from the run's last checkpoint (ckpt_file points to
# "last.ckpt" inside model_dir), handing it a freshly built ModelClass(K)
model = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K))

## Plot for the test set

In [None]:
# Marginal predictions on the test set loaded above, titled with the method name
fig = plot_marginal_predictions(
    model,
    testset,
    N=10,
    title=MethodClass.__name__,
)

## Plot for all corruptions

In [None]:
# One figure per corruption listed on the dataset class. Only the labels are
# taken from the stored dataset parameters; the corruption is set explicitly.
for corruption in DatasetClass.corruptions:
    testset = DatasetClass(
        labels=config['ds_params']['labels'],
        split='test',
        corruption=corruption,
        transform=TransformClass(),
    )
    # Title reads "<method name> - <corruption>"
    fig = plot_marginal_predictions(
        model,
        testset,
        N=10,
        title=" - ".join([MethodClass.__name__, testset.corruption]),
    )