# Effect of Summary likelihood term

This notebook explores the effect of scaling the summary likelihood term. The models are trained
with `slurm-scripts/submit_mnistc_scale_sl.sh`.

In [None]:
import sys
sys.path.append("./../")

In [None]:
import os
import json
import re
from collections import namedtuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torchmetrics

import methods
import models
import datasets
import transforms

import matplotlib.pyplot as plt
import seaborn as sns
# The 'seaborn' style sheet was renamed to 'seaborn-v0_8' in matplotlib 3.6 and
# the old alias was removed in 3.8; pick whichever name this install provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def get_marginal_predictions(model, dataset, N=10):
    """Evaluate `model` on `dataset` from N sampled sets of predictions.

    Args:
        model: object exposing ``sample_predictions(x, n=N)``. The first
            element of its return value is iterated as one tensor per sample;
            these are treated as log-softmax scores of shape (B, K)
            (assumption taken from how they are used below -- confirm
            against the method's implementation).
        dataset: dataset yielding ``(x, y)`` pairs; it is wrapped in a
            DataLoader with batch size 1024 and no shuffling.
        N: number of prediction samples requested per input. Also used as
            the sample count when scaling the standard deviations.

    Returns:
        The 10-tuple ``(y_mean, y_std, acc_mean, acc_std, ent_mean, ent_std,
        ece_mean, ece_std, ll_mean, ll_std)``:

        * ``y_mean`` / ``y_std``: per-class predicted probability summed over
          the whole dataset, then averaged across samples, and the matching
          standard deviation divided by ``sqrt(N)`` (NumPy arrays).
        * ``acc_*``, ``ent_*``, ``ece_*``, ``ll_*``: mean across samples and
          population standard deviation divided by ``sqrt(N)`` of accuracy,
          mean predictive entropy, calibration error (10 bins) and mean
          log-likelihood respectively.
    """
    dataloader = DataLoader(dataset, batch_size=1024, shuffle=False)

    # Collect the per-batch outputs and concatenate once at the end; growing a
    # tensor with torch.cat inside the loop re-copies all earlier batches.
    label_batches = []
    sample_batches = []  # sample_batches[i] holds the batches of sample i
    with torch.no_grad():
        for x, y in dataloader:
            batch_samples, _ = model.sample_predictions(x, n=N)
            if sample_batches:
                for chunks, sample in zip(sample_batches, batch_samples):
                    chunks.append(sample)
            else:
                sample_batches = [[sample] for sample in batch_samples]
            label_batches.append(y)

    # One (num_examples, K) tensor of log-probabilities per sample.
    ypreds = [torch.cat(chunks, dim=0) for chunks in sample_batches]

    # Convert to softmax score from log_softmax
    yprobs = [torch.exp(log_p) for log_p in ypreds]

    # Per-class probability mass summed over the dataset, one row per sample,
    # then mean and std across the samples.
    yprob_marginal = torch.stack([p.sum(dim=0) for p in yprobs])
    y_std, y_mean = torch.std_mean(yprob_marginal, dim=0)

    # Compute the per-sample metrics.
    ytrue = torch.cat(label_batches, dim=0)
    acc = []
    pred_entropy = []
    ece = []
    ll = []
    for log_p, p in zip(ypreds, yprobs):
        acc.append(torchmetrics.functional.accuracy(p, ytrue).numpy())

        # Entropy from the log-probabilities we already have. Entries whose
        # probability underflowed to 0 contribute 0 instead of 0 * log(0) = NaN.
        p_log_p = torch.where(p > 0, p * log_p, torch.zeros_like(p))
        pred_entropy.append(torch.mean(-torch.sum(p_log_p, dim=1)).numpy())

        ece.append(
            torchmetrics.functional.calibration_error(p, ytrue, n_bins=10).numpy()
        )

        # nll_loss expects log-probabilities; negate to get the log-likelihood.
        ll.append((-torch.nn.functional.nll_loss(log_p, ytrue)).numpy())

    y_mean = y_mean.numpy()
    y_std = y_std.numpy() / np.sqrt(N)

    acc_mean, acc_std = _mean_and_stderr(acc, N)
    ent_mean, ent_std = _mean_and_stderr(pred_entropy, N)
    ece_mean, ece_std = _mean_and_stderr(ece, N)
    ll_mean, ll_std = _mean_and_stderr(ll, N)

    return (
        y_mean, y_std,
        acc_mean, acc_std,
        ent_mean, ent_std,
        ece_mean, ece_std,
        ll_mean, ll_std
    )


def _mean_and_stderr(values, n):
    """Return the mean of `values` and their population std divided by sqrt(n)."""
    return np.mean(values), np.std(values) / np.sqrt(n)

In [None]:
# Directory holding one sub-directory per trained model.
root_dir = "./../zoo/sl-scaling/BinaryMNISTC-53-identity/LeNet/"

# Alternative: restrict to a subset of runs by name, e.g.
# model_dirs = [f for f in os.listdir(root_dir) if re.match(r'.+-nw-[1]-\d+', f)]
# Keep only sub-directories (stray files such as logs would otherwise be treated
# as model directories) and sort them so the evaluation order is reproducible.
model_dirs = sorted(
    entry for entry in os.listdir(root_dir)
    if os.path.isdir(os.path.join(root_dir, entry))
)

In [None]:
# One row of the results table: a trained model evaluated on one corruption.
Result = namedtuple("Result",
    "model_id lam_sl corruption "
    "y_mean_0 y_mean_1 y_std_0 y_std_1 "
    "acc_mean acc_std "
    "ent_mean ent_std "
    "ece_mean ece_std "
    "ll_mean ll_std")
results = []
for _model_dir in model_dirs:
    model_dir = os.path.join(root_dir, _model_dir)
    # Default paths
    config_json = os.path.join(model_dir, "config.json")
    ckpt_file = os.path.join(model_dir, "last.ckpt")

    # Read the training configuration; the context manager closes the file.
    with open(config_json, 'r') as config_fp:
        config = json.load(config_fp)

    # Resolve the classes named in the config from the project modules.
    MethodClass = getattr(methods, config['method'])
    DatasetClass = getattr(datasets, config['dataset'])
    ModelClass = getattr(models, config['model'])
    TransformClass = getattr(transforms, config['transform'])

    # This first test set is only used for its label count and corruption list.
    testset = DatasetClass(**config['ds_params'], split='test', transform=TransformClass())
    K = testset.n_labels

    model = MethodClass.load_from_checkpoint(
            ckpt_file,
            model=ModelClass(K))

    # Test for different corruptions
    # (to evaluate a single corruption, iterate over e.g. ['identity'] instead)
    for corruption in testset.corruptions:
        # NOTE: `testset` is rebound here; after the loop it refers to the
        # test set of the last corruption evaluated.
        testset = DatasetClass(
                    config['ds_params']['labels'],
                    split='test',
                    corruption=corruption,
                    transform=TransformClass())

        (
            y_mean, y_std,
            acc_mean, acc_std,
            ent_mean, ent_std,
            ece_mean, ece_std,
            ll_mean, ll_std
        ) = get_marginal_predictions(model, testset, N=50)

        # Only the first two entries of y_mean / y_std are recorded.
        results.append(
            Result(
                _model_dir, model.lam_sl, corruption,
                y_mean[0], y_mean[1], y_std[0], y_std[1],
                acc_mean, acc_std,
                ent_mean, ent_std,
                ece_mean, ece_std,
                ll_mean, ll_std
            )
        )

In [None]:
# Tabulate the collected results, order the rows by lam_sl and renumber them.
df_results = pd.DataFrame(results)
df_results = df_results.sort_values(by='lam_sl')
df_results = df_results.reset_index(drop=True)

In [None]:
# Display the results table (one row per evaluated model and corruption).
df_results

## Plot predictive marginal distribution

In [None]:
# Keep only the rows for the 'identity' corruption.
identity_runs = df_results[df_results.corruption == 'identity']

fig = plt.figure(figsize=(12, 6))
# One scatter call per lam_sl value; marker size grows with y_std_1.
for lam, runs in identity_runs.groupby(by='lam_sl'):
    marker_sizes = 50 + 30 * runs.y_std_1
    plt.scatter([lam] * len(runs), runs.y_mean_1, marker='o', s=marker_sizes)
# Dotted reference line at testset.n_classes[1] over the fixed x-range [1e-5, 1e2].
plt.hlines(testset.n_classes[1], 1e-5, 1e+2, colors='k', linestyles=':')
plt.xscale('log')
plt.title("Predictive marginal distribution")
plt.xlabel("$\\lambda_{SL}$")
plt.ylabel("Predictive probability")
plt.savefig("pred-mar-vs-lamsl.png", bbox_inches='tight')

## Plot predictive accuracy

In [None]:
# Keep only the rows for the 'identity' corruption.
identity_runs = df_results[df_results.corruption == 'identity']

fig = plt.figure(figsize=(12, 6))
# One scatter call per lam_sl value; marker size grows with acc_std.
for lam, runs in identity_runs.groupby(by='lam_sl'):
    marker_sizes = 50 + 1e5 * runs.acc_std
    plt.scatter([lam] * len(runs), runs.acc_mean, marker='o', s=marker_sizes)
plt.xscale('log')
plt.title("Accuracy of model")
plt.xlabel("$\\lambda_{SL}$")
plt.ylabel("Accuracy")
plt.savefig("pred-acc-vs-lamsl.png", bbox_inches='tight')

## Plot predictive entropy

In [None]:
# Keep only the rows for the 'identity' corruption.
identity_runs = df_results[df_results.corruption == 'identity']

fig = plt.figure(figsize=(12, 6))
# One scatter call per lam_sl value; marker size grows with ent_std.
for lam, runs in identity_runs.groupby(by='lam_sl'):
    marker_sizes = 50 + 1e5 * runs.ent_std
    plt.scatter([lam] * len(runs), runs.ent_mean, marker='o', s=marker_sizes)
plt.xscale('log')
plt.title("Predictive entropy")
plt.xlabel("$\\lambda_{SL}$")
plt.ylabel("Entropy")
plt.savefig("pred-ent-vs-lamsl.png", bbox_inches='tight')

## Plot predictive loglikelihood

In [None]:
# Keep only the rows for the 'identity' corruption.
identity_runs = df_results[df_results.corruption == 'identity']

fig = plt.figure(figsize=(12, 6))
# One scatter call per lam_sl value; marker size grows with ll_std.
for lam, runs in identity_runs.groupby(by='lam_sl'):
    marker_sizes = 50 + 1e5 * runs.ll_std
    plt.scatter([lam] * len(runs), runs.ll_mean, marker='o', s=marker_sizes)
plt.xscale('log')
plt.title("Loglikelihood")
plt.xlabel("$\\lambda_{SL}$")
plt.ylabel("Loglikelihood")
plt.savefig("ll-vs-lamsl.png", bbox_inches='tight')

## Plot expected calibration error

In [None]:
# Keep only the rows for the 'identity' corruption.
identity_runs = df_results[df_results.corruption == 'identity']

fig = plt.figure(figsize=(12, 6))
# One scatter call per lam_sl value; marker size grows with ece_std.
for lam, runs in identity_runs.groupby(by='lam_sl'):
    marker_sizes = 50 + 1e5 * runs.ece_std
    plt.scatter([lam] * len(runs), runs.ece_mean, marker='o', s=marker_sizes)
plt.xscale('log')
plt.title("ECE")
plt.xlabel("$\\lambda_{SL}$")
plt.ylabel("Expected Calibration Error")
plt.savefig("ece-vs-lamsl.png", bbox_inches='tight')

## Plot average accuracy for each corruption

In [None]:
# Average acc_mean over all rows sharing the same (lam_sl, corruption) pair.
df = df_results.groupby(['lam_sl', 'corruption'])['acc_mean'].mean().reset_index()
# Reshape to corruption (rows) x lam_sl (columns). Keyword arguments are
# required: pandas >= 2.0 no longer accepts positional arguments to pivot().
df = df.pivot(index="corruption", columns="lam_sl", values="acc_mean")

fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(df, ax=ax)
ax.set(xlabel="$\\lambda_{SL}$")
plt.title("Average accuracy")

plt.savefig("corruption-acc-vs-lamsl.png", bbox_inches='tight')

## Plot average predictive entropy for different corruptions

In [None]:
# Average ent_mean over all rows sharing the same (lam_sl, corruption) pair.
df = df_results.groupby(['lam_sl', 'corruption'])['ent_mean'].mean().reset_index()
# Reshape to corruption (rows) x lam_sl (columns). Keyword arguments are
# required: pandas >= 2.0 no longer accepts positional arguments to pivot().
df = df.pivot(index="corruption", columns="lam_sl", values="ent_mean")

fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(df, ax=ax)
ax.set(xlabel="$\\lambda_{SL}$")
plt.title("Predictive entropy")

plt.savefig("corruption-ent-vs-lamsl.png", bbox_inches='tight')