# Analyze model predictions

In [None]:
import sys
sys.path.append("./../")

In [None]:
import os
import json

import numpy as np
import torch
from torch.utils.data import DataLoader

import methods
import models
import datasets
import transforms

import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*' and
# later removed the old names; prefer the new name and fall back to the legacy
# one so the notebook runs on both old and new matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def get_marginal_predictions(model, dataset, N=10, batch_size=1024):
    """Estimate the predictive marginal over labels for a whole dataset.

    For each of the ``N`` prediction samples returned by
    ``model.sample_predictions``, the per-example class probabilities are
    summed over every example in ``dataset``, giving one vector of summed
    probabilities per sample.  The mean and the standard error of these
    vectors across the ``N`` samples are returned.

    Args:
        model: Object exposing ``sample_predictions(x, n=N)``.  The first
            element of its return value must iterate over per-sample
            log-probability tensors (each of shape ``(B, K)`` according to
            the original code comment); the second element is ignored.
        dataset: Dataset yielding ``(x, y)`` pairs; the targets are unused.
        N: Number of prediction samples requested per input batch.
        batch_size: Number of examples per forward pass.  Defaults to the
            previously hard-coded value of 1024.

    Returns:
        Tuple ``(y_mean, y_std)`` of NumPy arrays: the mean over samples of
        the summed class probabilities, and the sample standard deviation
        divided by ``sqrt(N)``.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Running per-sample sums of class probabilities.  Only these sums are
    # needed downstream, so accumulating them batch by batch avoids both
    # storing every per-example prediction and re-concatenating growing
    # tensors on each iteration.
    totals = None
    with torch.no_grad():
        for x, _ in dataloader:
            sample_log_probs, _ = model.sample_predictions(x, n=N)
            # Convert to softmax score from log_softmax, then sum over the batch
            batch_totals = torch.stack(
                [torch.exp(log_probs).sum(dim=0) for log_probs in sample_log_probs]
            )
            totals = batch_totals if totals is None else totals + batch_totals

    if totals is None:
        raise ValueError("dataset is empty; cannot compute marginal predictions")

    # Compute mean and std across the prediction samples
    y_std, y_mean = torch.std_mean(totals, dim=0)

    y_mean = y_mean.numpy()
    # Standard error of the mean over the N samples
    y_std = y_std.numpy() / np.sqrt(N)

    return y_mean, y_std

In [None]:
# Maps each model's display name (reused as its plot legend label) to the
# directory holding that run's "config.json" and "last.ckpt".
_zoo_dir = "./../zoo/samples/BinaryMNISTC-53-identity/LeNet"
models_dict = {
    display_name: _zoo_dir + "/" + run_name
    for display_name, run_name in (
        ("MFVI", "mfvi-20220328203213"),
        ("MFVI ($\\lambda = 1.0$)", "mfvi-lam1.0-auto-20220328211605"),
        ("Summary Likelihood", "sl-20220328204536"),
        ("Summary Likelihood ($\\lambda = 1.0$)", "sl-lam1.0-auto-20220328205742"),
    )
}

In [None]:
# Evaluate every model in `models_dict` on the test split and store the
# (mean, standard error) of its predictive marginal under the model's name.
results = {}
for model_name, model_dir in models_dict.items():
    # Default paths
    config_json = os.path.join(model_dir, "config.json")
    ckpt_file = os.path.join(model_dir, "last.ckpt")

    # Use a context manager so the config file handle is closed right away
    with open(config_json, 'r') as config_fp:
        config = json.load(config_fp)

    # The config stores class names; resolve them in the project modules
    MethodClass = getattr(methods, config['method'])
    DatasetClass = getattr(datasets, config['dataset'])
    ModelClass = getattr(models, config['model'])
    TransformClass = getattr(transforms, config['transform'])

    testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
    K = testset.n_labels

    model = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K))

    y_mean, y_std = get_marginal_predictions(model, testset, N=50)

    results[model_name] = (y_mean, y_std)

In [None]:
# Colormap used to give each bar series its own colour.
# `plt.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` returns the same colormap and exists in old and new versions.
pen_colors = plt.get_cmap('tab10')

In [None]:
# Grouped bar chart: one group per label (0 and 1), one bar per model.
x_vals = np.arange(2)
# One slot more than the number of models: slot 0 (and colour 0) is kept free
# for an optional "True labels" bar series that is currently not drawn, which
# is why the bars below use colour index i+1.
n = len(results) + 1
w = 2 / (3 * n)

# Create the figure exactly once; a second plt.figure() call would leave an
# empty figure behind.
plt.figure()

for i, (label, result) in enumerate(results.items()):
    y_mean, y_std = result
    plt.bar(
        x = x_vals + i * w,
        height = y_mean,
        yerr = y_std,
        width = w * 0.85,
        label = label,
        color = pen_colors(i+1)
    )

# Dotted reference line across each label group at testset.n_classes[k]
# (presumably the true number of test examples with label k - confirm).
for k, x0 in enumerate(x_vals):
    plt.hlines(testset.n_classes[k],
               x0 - w, x0 + n * w,
               colors='k', linestyles=':')

plt.legend()
# Centre each tick under its group of len(results) bars
plt.xticks(ticks=x_vals+n/2*w-1*w, labels=x_vals)
plt.ylim(750, 1100)
plt.title("Predictive marginal distribution")
plt.xlabel("Labels")
# NOTE(review): the plotted heights are class probabilities summed over the
# test set, not probabilities in [0, 1] - confirm the axis wording.
plt.ylabel("Predictive probability")

In [None]:
# Evaluate every model on each corrupted variant of the test split.
# results_corruptions[model_name][corruption] = (mean, standard error).
results_corruptions = {}

# Snapshot the corruption names once: `testset` is rebound inside the loop
# below, so iterating its attribute directly would depend on whichever
# dataset happened to be assigned last.
corruption_names = list(testset.corruptions)

for model_name, model_dir in models_dict.items():
    # Default paths
    config_json = os.path.join(model_dir, "config.json")
    ckpt_file = os.path.join(model_dir, "last.ckpt")

    # Use a context manager so the config file handle is closed right away
    with open(config_json, 'r') as config_fp:
        config = json.load(config_fp)

    # Resolve every class from this model's own config instead of relying on
    # a DatasetClass left over from an earlier cell.
    MethodClass = getattr(methods, config['method'])
    DatasetClass = getattr(datasets, config['dataset'])
    ModelClass = getattr(models, config['model'])
    TransformClass = getattr(transforms, config['transform'])

    # K is the label count of the most recently built test set
    model = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(K))

    model_results = results_corruptions.setdefault(model_name, {})
    for corruption in corruption_names:
        testset = DatasetClass(
                    config['ds_params']['labels'],
                    split='test',
                    corruption=corruption,
                    transform=TransformClass())
        K = testset.n_labels

        y_mean, y_std = get_marginal_predictions(model, testset, N=50)

        model_results[corruption] = (y_mean, y_std)

In [None]:
# Grouped bar chart: one group per corruption, one bar per model, showing the
# label-1 entry of each model's predictive marginal with its error bar.
plt.figure(figsize=(30, 6))
corruption_names = testset.corruptions
x_vals = np.arange(len(corruption_names))
n = len(models_dict)
w = 2 / (3 * n)
bar_width = w * 0.85

for j, (label, per_corruption) in enumerate(results_corruptions.items()):
    # Each stored value is a (y_mean, y_std) pair; index 1 picks label 1.
    heights = []
    errors = []
    for name in corruption_names:
        mean_vec, err_vec = per_corruption[name][0], per_corruption[name][1]
        heights.append(mean_vec[1])
        errors.append(err_vec[1])

    plt.bar(
        x_vals + j * w,
        height=heights,
        yerr=errors,
        width=bar_width,
        label=label,
        color=pen_colors(j + 1),
    )

# Dotted reference line at testset.n1 across the whole plot
# (presumably the true count of label-1 test examples - confirm).
line_start = x_vals[0] - w
line_end = x_vals[-1] + n * 0.2
plt.hlines(testset.n1, line_start, line_end, colors='k', linestyles=':')

plt.legend()
# Centre each tick under its group of n bars
tick_positions = x_vals + n / 2 * w - 0.5 * w
plt.xticks(ticks=tick_positions, labels=corruption_names)
plt.ylim(800, 1200)
plt.title("Predictive marginal distribution for label 1 with different corruptions")
plt.xlabel("Corruptions")
plt.ylabel("Predictive probability")