# Synthetic example

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

## Create synthetic data set

In [None]:
D = 10  # dimensionality of every data point
R = 10  # how many copies of each distinct point go into the data set

# D distinct points: row j is +1 at coordinate j and -1 everywhere else.
X0 = np.where(np.eye(D, dtype=bool), 1.0, -1.0)
# R consecutive copies of each row of X0 -> shape (D * R, D).
X = np.repeat(X0, R, axis=0)
# Copy r of point j gets label 1 when r >= j and label 0 otherwise
# (an upper-triangular mask, flattened row by row).
y = (np.arange(R)[np.newaxis, :] >= np.arange(D)[:, np.newaxis]).astype(int).reshape(-1)

## Train an FCN

In [None]:
import sys
sys.path.append('./../') # for bayesian torch

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer

from bayesian_torch.layers import LinearReparameterization

### Build model

In [None]:
class FCN(nn.Module):
    """Fully connected two-class network built from three
    ``LinearReparameterization`` layers.

    Layer widths are ``in_dim -> 2*in_dim -> 2*in_dim -> 2``. All layers are
    created with the same prior / posterior-initialisation settings.
    """

    def __init__(self, in_dim):
        super().__init__()
        hidden = 2 * in_dim
        # Settings shared by all three layers.
        shared = dict(prior_mean=0.0,
                      prior_variance=1.0,
                      posterior_mu_init=0.0,
                      posterior_rho_init=-3.0)

        self.fc1 = LinearReparameterization(in_features=in_dim,
                                            out_features=hidden,
                                            **shared)
        self.fc2 = LinearReparameterization(in_features=hidden,
                                            out_features=hidden,
                                            **shared)
        self.fc3 = LinearReparameterization(in_features=hidden,
                                            out_features=2,
                                            **shared)

        self.num_classes = 2

    def forward(self, x):
        """Return ``(log_probs, kl_sum)``.

        ``log_probs`` is the log-softmax over dim 1 of the last layer's
        output; ``kl_sum`` accumulates the second value returned by each of
        the three layers.
        """
        kl_sum = 0

        # Hidden layers: linear map followed by ReLU.
        for layer in (self.fc1, self.fc2):
            x, kl = layer(x)
            kl_sum += kl
            x = F.relu(x)

        # Output layer: no ReLU before the log-softmax.
        x, kl = self.fc3(x)
        kl_sum += kl

        return F.log_softmax(x, dim=1), kl_sum

### Build dataloader

In [None]:
class dataset(Dataset):
    """In-memory dataset of ``(features, label)`` pairs.

    Parameters
    ----------
    x : array-like
        Features; converted to a float32 tensor. The first axis indexes
        samples.
    y : array-like
        Integer class labels; converted to an int64 (``torch.long``) tensor.
    """

    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32)
        # Force int64: an integer array with the platform default width
        # (e.g. int32) would otherwise produce non-long targets, which
        # PyTorch classification losses such as nll_loss reject.
        self.y = torch.tensor(y, dtype=torch.long)
        self.length = self.x.shape[0]

        self.n_labels = 2

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.x[idx], self.y[idx]

    def __len__(self):
        """Number of samples (size of the first axis of ``x``)."""
        return self.length

In [None]:
batch_size = 10
trainset = dataset(X,y)

# DataLoaders over the same dataset:
#   trainloader - shuffled mini-batches of `batch_size` samples
#   testloader  - unshuffled batches ten times larger
# NOTE(review): testloader wraps trainset as well; no separate held-out set is built here.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(trainset, batch_size=10*batch_size, shuffle=False)

### Training for MFVI

In [None]:
from methods import MFVI

In [None]:
epochs = 50  # training epochs; handed to Trainer as max_epochs in the next cell
mc_samples = 16  # forwarded to MFVI as its mc_samples argument

model = FCN(D)
# Start from an empty config dict; populate_missing_params presumably fills in
# defaults derived from the dataset (see methods.MFVI - not shown here).
method_params = MFVI.populate_missing_params({}, trainset)
# Wrap the network in the MFVI module that is passed to trainer.fit below.
pl_model = MFVI(model, **method_params, mc_samples=mc_samples)

In [None]:
# Lightning trainer limited to `epochs` epochs, with the progress bar switched off.
trainer = Trainer(max_epochs=epochs, enable_progress_bar=False)

In [None]:
# The training loader is passed twice: once as the training data and once as
# the third positional argument (the validation dataloaders in Lightning's fit).
trainer.fit(pl_model, trainloader, trainloader)

### Test MFVI

In [None]:
m = 1000  # number of forward passes drawn for every distinct data point

results = {}

with torch.no_grad():
    for j, row in enumerate(X0):  # one entry per distinct point
        x = torch.tensor(row.reshape(1, -1), dtype=torch.float32)

        # Repeated forward passes on the same input; presumably each pass
        # samples new weights in the Bayesian layers, so the outputs vary.
        draws = []
        for _ in range(m):
            log_probs, _kl = model(x)
            # exp(log-softmax) at column 1 = predicted probability of class 1
            draws.append(torch.exp(log_probs[:, 1]))
        preds = torch.cat(draws).numpy()

        # Reference labels for this point: label 0 with probability j / D,
        # label 1 otherwise, one draw per prediction.
        p0 = j / D
        labels = (np.random.rand(*preds.shape) > p0).astype(int)

        results[j] = (preds, labels)

## Plot calibration

In [None]:
from sklearn.calibration import calibration_curve
from torchmetrics.functional import calibration_error

import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and
# later releases removed the old names, so use the new name when it exists
# and fall back to the legacy one on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def plot_calibration_curve(results, title, nbins=20):
    """Draw a reliability diagram for pooled predictions and compute the ECE.

    Parameters
    ----------
    results : dict
        Maps any key to a sequence whose first item is an array of predicted
        class-1 probabilities and whose second item is the matching array of
        0/1 labels. The arrays of all entries are concatenated in dict order.
    title : str
        Text placed before the ECE value in the figure title.
    nbins : int, optional
        Number of equal-width probability bins used for both the calibration
        curve and the prediction histogram (default 20).

    Returns
    -------
    r : dict
        ``prob_true`` / ``prob_pred`` as returned by sklearn's
        ``calibration_curve``, ``edges`` (the histogram bin *centres*),
        ``hist`` (fraction of predictions per bin) and ``ece``.
    fig : matplotlib.figure.Figure
        The figure that was drawn.
    """
    ypred = np.concatenate([v[0] for v in results.values()])
    ytrue = np.concatenate([v[1] for v in results.values()])

    prob_true, prob_pred = calibration_curve(ytrue, ypred, n_bins=nbins)

    # linspace gives exactly nbins + 1 edges on [0, 1] for any nbins, which a
    # floating-point arange step does not guarantee.
    hist, edges = np.histogram(ypred, bins=np.linspace(0.0, 1.0, nbins + 1))
    hist = hist / hist.sum()
    edges = 0.50 * (edges[:-1] + edges[1:])  # bin centres, used as bar positions

    # NOTE(review): recent torchmetrics releases require a `task` argument
    # here (e.g. task='binary') - confirm against the installed version.
    ece = calibration_error(torch.tensor(ypred), torch.tensor(ytrue))

    fig = plt.figure(figsize=(6, 6))
    # Predicted probability on x and observed frequency on y, so the points
    # share the x axis with the histogram of predictions drawn below.
    plt.scatter(prob_pred, prob_true)
    plt.plot([0, 1], [0, 1], ls=':', c='k', alpha=0.25)  # perfect calibration
    plt.bar(edges, hist, width=1.0/nbins, align='center', color='k', alpha=0.10)
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.0)
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.title('{} (ECE = {:.3f})'.format(title, ece))

    r = {
        'prob_true': prob_true,
        'prob_pred': prob_pred,
        'edges': edges,
        'hist': hist,
        'ece': ece
    }

    return r, fig

In [None]:
# Calibration plot for the MFVI samples gathered in `results`; the summary dict is kept as r_mfvi.
r_mfvi, fig = plot_calibration_curve(results, title="MFVI")

### Test SL - Sweep $\lambda_{SL}$ value

In [None]:
from methods import SummaryLikelihood as SL

In [None]:
epochs = 50
mc_samples = 16
m = 1000  # number of forward passes drawn for every distinct data point

results_sl = {}


def run_experiment(lam_sl):
    """Train a fresh FCN with the SL method for one ``lam_sl`` value, then
    evaluate its calibration.

    Returns the ``(summary dict, figure)`` pair produced by
    ``plot_calibration_curve``.
    """
    # ---- Train ----
    model = FCN(D)
    sl_config = {'beta': True, 'a': 1, 'b': 1, 'alpha': 100, 'lam_sl': lam_sl}
    method_params = SL.populate_missing_params(sl_config, trainset)
    pl_model = SL(model, **method_params, mc_samples=mc_samples)
    trainer = Trainer(max_epochs=epochs, enable_progress_bar=False)
    trainer.fit(pl_model, trainloader, trainloader)

    # ---- Evaluate: m forward passes for each distinct point ----
    per_point = {}
    with torch.no_grad():
        for idx, row in enumerate(X0):
            point = torch.tensor(row.reshape(1, -1), dtype=torch.float32)

            draws = []
            for _ in range(m):
                log_probs, _kl = model(point)
                # exp(log-softmax) at column 1 = predicted probability of class 1
                draws.append(torch.exp(log_probs[:, 1]))
            preds = torch.cat(draws).numpy()

            # Reference labels: label 0 with probability idx / D, else label 1.
            p0 = idx / D
            labels = (np.random.rand(*preds.shape) > p0).astype(int)

            per_point[idx] = (preds, labels)

    plot_title = "SL ($\\lambda_{} = {:.0e}$)".format('{SL}', lam_sl)
    return plot_calibration_curve(per_point, title=plot_title)

In [None]:
lam_sl = [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0]

# One full train/evaluate run per weight, in list order.
# Each weight maps to the (summary dict, figure) pair returned by run_experiment.
results = {v: run_experiment(lam_sl=v) for v in lam_sl}

#### Plot $\lambda_{SL}$ vs ECE

In [None]:
# Gather the swept weights and their ECE values. Dedicated names are used so
# this cell does not overwrite the training labels `y` defined earlier.
lams = list(results)
eces = [float(entry[0]['ece']) for entry in results.values()]

plt.loglog(lams, eces, marker='o', label='SL')
# lambda = 0 has no position on a logarithmic axis, so the ECE of that run is
# drawn as a horizontal reference line instead.
# NOTE(review): the line is labelled 'MFVI' although it comes from the SL run
# with lam_sl = 0, not from r_mfvi - confirm that the two are equivalent.
plt.plot(lams, float(results[0.0][0]['ece']) * np.ones(len(lams)), ls=':', label='MFVI')
plt.legend()
# Raw string: "\l" is not a valid escape sequence in a normal string literal.
plt.title(r"$\lambda_{SL}$ vs ECE")