In [None]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from pathlib import Path

import torch
import numpy as np
import numpy.linalg as la
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score
from scipy.special import softmax as softmax
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralCoclustering
from dppy.finite_dpps import FiniteDPP

from dataloader.builder import build_dataset
from experiment_setup import build_estimator
from uncertainty_estimator.masks import build_masks, DEFAULT_MASKS
from analysis.metrics import uq_ndcg

from model.cnn import SimpleConv, MediumConv, StrongConv
from model.trainer import Trainer, EnsembleTrainer


# torch.cuda.set_device(1)
#

In [None]:

# Use a white figure background for every plot in this notebook.
plt.rcParams['figure.facecolor'] = 'white'

# Experiment settings. Keys are read by the cells below via config[...];
# some (use_cuda, n_models, batch_size, log_interval, lr) are not referenced
# in the cells visible here.
config = dict(
    use_cuda=True,
    seed=1,

    nn_runs=150,          # passed to build_estimator as nn_runs
    patience=5,           # passed to trainer.fit as patience
    dropout_uq=0.5,       # passed to build_estimator as dropout_rate
    dropout_train=0.3,    # passed to Trainer as dropout_train

    n_models=10,

    # Supported by the reshape cell: 'mnist' or 'cifar_10'
    dataset='cifar_10',

    model_class=StrongConv,
    train_samples=45_000,  # size of the random training subset
    epochs=50,
    batch_size=256,
    log_interval=150,
    lr=1e-2,
    # One label ('9') is cut out as out-of-distribution data further down,
    # presumably leaving 9 of the original 10 classes -- confirm.
    num_classes=9,
)

# When True (and a checkpoint file exists) the model is loaded instead of trained.
restore_model = False


In [None]:
# Build the configured dataset and unpack the (inputs, labels) pair of the
# 'train' and 'val' splits, in that order.
dataset = build_dataset(config['dataset'], val_size=10_000)
(x_train, y_train), (x_val, y_val) = (
    dataset.dataset(split) for split in ('train', 'val'))

def cut_class(x, y, class_num):
    """Split a labelled dataset around a single class.

    Args:
        x: array of samples, aligned with ``y``.
        y: array of labels; compared elementwise against ``class_num``.
        class_num: the label to hold out.

    Returns:
        A tuple ``(kept_x, kept_y, held_out_x)``: the samples and labels whose
        label differs from ``class_num``, and the samples whose label equals it.
    """
    keep_idx = np.where(y != class_num)
    held_out_idx = np.where(y == class_num)
    return x[keep_idx], y[keep_idx], x[held_out_idx]

# Remove label '9' from both splits. The training samples of that class are
# kept as x_ood; the validation samples of that class are discarded.
x_train, y_train, x_ood = cut_class(x=x_train, y=y_train, class_num='9')
x_val, y_val, _ = cut_class(x=x_val, y=y_val, class_num='9')



In [None]:
def scale(images):
    """Shift and rescale pixel values: ``(images - 128) / 128``.

    Values in [0, 255] are mapped to [-1, 127/128].

    Integer NumPy arrays are promoted to float64 before the subtraction.
    Without that, ``images - 128`` on an unsigned array such as uint8 wraps
    around (0 becomes 128), so dark pixels would end up positive. The result
    dtype for integer input is float64 either way, because true division of
    integers yields float64.

    Args:
        images: array (or scalar) of pixel intensities.

    Returns:
        The rescaled values, same shape as ``images``.
    """
    dtype = getattr(images, 'dtype', None)
    if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.integer):
        images = images.astype(np.float64)
    return (images - 128) / 128
# Apply the same rescaling to all three arrays.
# Note: the names are rebound in place, so re-running this cell rescales
# already-rescaled data.
x_train, x_val, x_ood = (scale(split) for split in (x_train, x_val, x_ood))

In [None]:
# Target array shape per supported dataset; -1 lets reshape infer the number
# of samples. Presumably (N, channels, height, width) -- confirm against the
# model's expected layout.
INPUT_SHAPES = {
    'mnist': (-1, 1, 28, 28),
    'cifar_10': (-1, 3, 32, 32),
}
# Fail loudly on an unsupported dataset instead of hitting a NameError below
# (or silently reusing an input_shape left over from an earlier run).
if config['dataset'] not in INPUT_SHAPES:
    raise ValueError(
        f"Unsupported dataset {config['dataset']!r}; "
        f"expected one of {sorted(INPUT_SHAPES)}")
input_shape = INPUT_SHAPES[config['dataset']]

x_train = x_train.reshape(input_shape)
x_val = x_val.reshape(input_shape)
x_ood = x_ood.reshape(input_shape)

# Labels become flat integer vectors.
y_train = y_train.astype('long').reshape(-1)
y_val = y_val.astype('long').reshape(-1)



In [None]:
train_samples = config['train_samples']
epochs = config['epochs']
val_samples = 2000
patience = config['patience']

# Draw the subsets from a generator seeded with config['seed'] so the same
# train/val subsets are selected on every run of the notebook.
subset_rng = np.random.RandomState(config['seed'])

# Sampling without replacement raises if more items are requested than exist,
# so cap the requested sizes at what is available after the class cut.
idxs = subset_rng.choice(
    len(x_train), min(train_samples, len(x_train)), replace=False)
train_set = (x_train[idxs], y_train[idxs])

idxs = subset_rng.choice(
    len(x_val), min(val_samples, len(x_val)), replace=False)
val_set = (x_val[idxs], y_val[idxs])
print(train_set[0].shape)
print(val_set[0].shape)

In [None]:
# Instantiate the configured architecture with CELU as its activation.
model_class = config['model_class']
model = model_class(
    config['num_classes'],
    activation=torch.nn.functional.celu,
)

# Checkpoint location: data/model_<dataset>_<model class name>.pt
# (alternative directory used previously: Path('experiments') / 'data')
model_dir = Path('data')
model_path = model_dir.joinpath(
    f"model_{config['dataset']}_{config['model_class'].__name__}.pt")


In [None]:
trainer = Trainer(model, dropout_train=config['dropout_train'])

# Reuse a saved checkpoint only when explicitly requested and present;
# otherwise train from scratch and save the resulting weights.
if restore_model and os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
else:
    trainer.fit(
        train_set, val_set, epochs=epochs, verbose=True, patience=patience)
    # torch.save does not create missing parent directories; make sure the
    # target directory exists so a finished training run is not lost.
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    torch.save(model.state_dict(), model_path)

In [None]:
# Build the dropout mask under study and wrap the trained model in a
# 'bald_masked' estimator that uses it.
mask_type = 'l_dpp_htnorm'
masks = build_masks([mask_type])
mask = masks[mask_type]
estimator = build_estimator(
    'bald_masked',
    trainer,
    nn_runs=config['nn_runs'],
    dropout_mask=mask,
    dropout_rate=config['dropout_uq'],
    num_classes=config['num_classes'],
)


In [None]:
# Run the estimator on the validation inputs; the return value is shown as
# the cell output.
# NOTE(review): the cells below read mask.layer_correlations, mask.norm and
# mask.layer_runs, which are presumably filled in as a side effect of this
# call -- confirm.
estimator.estimate(val_set[0])

In [None]:
# Entry 1 of the mask's layer_correlations; used below as the heatmap input
# and as the L kernel of the DPP.
# NOTE(review): presumably the correlation matrix of the layer with index 1 -- confirm.
corr = mask.layer_correlations[1]
corr.shape

In [None]:
# Entry 1 of the mask's norm, detached from the autograd graph, moved to the
# CPU and converted to a NumPy array.
# NOTE(review): presumably per-unit normalisation weights for the same layer -- confirm.
norm = mask.norm[1].detach().cpu().numpy()
norm.shape


In [None]:
# Heatmap of the raw (unsorted) corr matrix on a 20x14 inch figure.
plt.figure(figsize=(20, 14))
sns.heatmap(corr)

In [None]:
# Reorder the rows and columns of corr by spectral co-clustering labels so
# that clustered units appear as blocks in the heatmap. The range currently
# covers only n = 10 clusters.
for n in range(10, 11):
    plt.figure(figsize=(14, 12))
    plt.title(n)
    # Deliberately not named `model`: that name holds the neural network,
    # which the training/restore cell still needs.
    coclustering = SpectralCoclustering(n_clusters=n, random_state=0)
    coclustering.fit(corr)
    fit_data = corr[np.argsort(coclustering.row_labels_)]
    fit_data = fit_data[:, np.argsort(coclustering.column_labels_)]
    sns.heatmap(fit_data)
    plt.show()

In [None]:
def get_averages(samples, values, norm=None):
    """Estimate the mean of ``values`` from subsets of its entries.

    For every sample (a collection of indices into ``values``) a weight vector
    is built that is zero outside the sample. The weighted values are then
    averaged over *all* positions of ``values``.

    Args:
        samples: iterable of index collections, one per estimate.
        values: 1-D array-like of numbers.
        norm: optional array of per-index weights. When ``None`` every sampled
            index gets the weight ``len(values) / len(sample)``; otherwise
            index ``j`` gets ``norm[j]``.

    Returns:
        A list with one average per sample. An empty sample selects nothing
        and therefore contributes an average of 0.
    """
    values = np.asarray(values)
    # Float inputs keep their precision; anything else is promoted so that
    # fractional weights are not truncated by an integer mask.
    if np.issubdtype(values.dtype, np.floating):
        weight_dtype = values.dtype
    else:
        weight_dtype = np.float64

    avgs = []
    for sample in samples:
        sample_mask = np.zeros(values.shape, dtype=weight_dtype)
        # Skip empty samples: the mask stays all-zero and the naive weight
        # would otherwise divide by zero.
        if len(sample) > 0:
            if norm is None:
                sample_mask[sample] = len(values) / len(sample)
            else:
                sample_mask[sample] = norm[sample]
        sample_values = values * sample_mask
        avgs.append(np.average(sample_values))
    return avgs


In [None]:
# Compare three ways of estimating the mean of a vector of values from a
# random subset of its entries:
#   * DPP sample with naive weights (len(values) / len(sample)),
#   * DPP sample with the weights stored in mask.norm[1],
#   * the 'basic_bern' mask.
layer_runs = mask.layer_runs[1]

true_avgs, naive_avgs, ht_avgs, bern_avgs = [], [], [], []
naive_stds, ht_stds, bern_stds = [], [], []

norm = mask.norm[1].detach().cpu().numpy()
bern_mask = build_masks(['basic_bern'])['basic_bern']

ht_raws, bern_raws = [], []

n_points = 75    # number of entries of layer_runs[0] to evaluate
n_draws = 100    # random subsets drawn per entry and per method

for i in range(n_points):
    print(i, end=' ')
    values = layer_runs[0][i]

    # A fresh DPP object per point, so list_of_samples holds exactly the
    # n_draws subsets drawn for this point.
    dpp_1 = FiniteDPP('likelihood', **{'L': corr})
    for _draw in range(n_draws):
        dpp_1.sample_exact()
    samples = dpp_1.list_of_samples

    true_avgs.append(np.average(values))

    # Computed once and reused for both the mean and the std.
    naive_average = get_averages(samples, values)
    naive_avgs.append(np.average(naive_average))
    naive_stds.append(np.std(naive_average))

    ht_average = get_averages(samples, values, norm)
    ht_avgs.append(np.average(ht_average))
    ht_stds.append(np.std(ht_average))
    ht_raws.append(ht_average)

    bern_average = [
        np.average(torch.Tensor(values) * bern_mask(torch.Tensor(values)))
        for _ in range(n_draws)]
    bern_avgs.append(np.average(bern_average))
    bern_stds.append(np.std(bern_average))
    bern_raws.append(bern_average)


In [None]:
import pandas as pd

# One column per method: estimated average minus the true average, one row
# per evaluated point.
df = pd.DataFrame({
    label: np.array(estimates) - np.array(true_avgs)
    for label, estimates in (
        ('DPP naive norm', naive_avgs),
        ('DPP HT norm', ht_avgs),
        ('Bernoulli', bern_avgs),
    )
})
plt.figure(figsize=(12, 8))
plt.title("Deviation of average for different normalizations, cifar")
sns.barplot(data=df)


In [None]:
# One column per method: standard deviation of the per-sample averages, one
# row per evaluated point.
df = pd.DataFrame({
    label: np.array(stds)
    for label, stds in (
        ('DPP naive norm', naive_stds),
        ('DPP HT norm', ht_stds),
        ('Bernoulli', bern_stds),
    )
})
plt.figure(figsize=(12, 8))
plt.title("Deviation std of average for different normalizations, cifar")
sns.barplot(data=df)


In [None]:
# For every evaluated point, compare the spread of the per-sample averages
# obtained with mask.norm weights (x axis) and with the Bernoulli mask (y axis).
plt.figure(figsize=(12, 8))
plt.title('Std, ht vs bern')
plt.xlabel('Point std by HT')
plt.ylabel('Point std by Bernoulli')
plt.scatter(np.std(ht_raws, axis=1), np.std(bern_raws, axis=1))


In [None]:
# Scatter the individual per-sample averages of both methods for one point.
# The range currently selects only the point with index 6.
# NOTE(review): index 6 may have been picked from the argmax cell at the end
# of the notebook -- confirm.
for i in range(6, 7):
    ht_points = ht_raws[i]
    bern_points = bern_raws[i]
    plt.figure(figsize=(12, 8))
    plt.title('Sample averages, ht vs bern')
    plt.xlabel('Sample average by HT')
    plt.ylabel('Sample average by Bernoulli')
    plt.scatter(ht_points, bern_points)
    plt.show()


In [None]:
# Index of the evaluated point where the std of the HT-weighted averages
# exceeds the std of the Bernoulli averages by the largest margin.
np.argmax(np.std(ht_raws, axis=1) - np.std(bern_raws, axis=1))
