In [1]:
# Standard library
import os
import types

# PyTorch
import torch
import torchmetrics
import torchvision

In [2]:
import sys
# Make the project's ../src/ modules (presumably where losses and utils live) importable.
# The path is relative to the notebook's working directory. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if '../src/' not in sys.path:
    sys.path.append('../src/')

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import losses
import utils

In [3]:
# Root of the CIFAR-10 data. Set the CIFAR10_DIR environment variable to run
# somewhere other than the Tufts cluster; the cluster path remains the default.
dataset_directory = os.environ.get('CIFAR10_DIR', '/cluster/tufts/hugheslab/eharve06/CIFAR-10')
# Training-set size. It is also the denominator when kappa is computed further down.
n = 1000
# Presumably True selects a held-out validation split and False the test split
# (see the name val_or_test_dataset). Confirm in utils.get_cifar10_datasets.
tune = False
# Seed forwarded to utils.get_cifar10_datasets.
random_state = 1001
augmented_train_dataset, train_dataset, val_or_test_dataset = utils.get_cifar10_datasets(dataset_directory, n, tune, random_state)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_size = 128
num_workers = 0

# The training loaders never request a batch larger than their dataset. This matters
# most for the augmented loader: with drop_last=True, an oversized batch would yield
# zero batches on a tiny training subset.
augmented_train_loader = torch.utils.data.DataLoader(
    augmented_train_dataset,
    batch_size=min(batch_size, len(augmented_train_dataset)),
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=min(batch_size, len(train_dataset)),
    num_workers=num_workers,
)

# Evaluation loader. shuffle is left at its default (False), so every pass visits
# examples in the same order, and the final partial batch is kept.
val_or_test_loader = torch.utils.data.DataLoader(
    val_or_test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [5]:
# Use the first GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [6]:
# Fine-tuned VI checkpoint. The filename encodes its training configuration
# (kappa, lr_0, n, random_state).
first_checkpoint_path = '/cluster/tufts/hugheslab/eharve06/data-emphasized-ELBo/experiments/retrained_CIFAR-10_VI/l2-sp_kappa=23528.522_lr_0=0.001_n=1000_random_state=1001.pt'
# Fail loudly if the config cell (n, random_state) was changed but this path was not.
# Otherwise a checkpoint trained under a different configuration would be loaded silently.
assert f'_n={n}_' in first_checkpoint_path and f'_random_state={random_state}.pt' in first_checkpoint_path, (
    f'checkpoint path does not match configured n={n}, random_state={random_state}: {first_checkpoint_path}'
)
# weights_only=False allows arbitrary unpickling; only load checkpoints from trusted sources.
first_checkpoint = torch.load(first_checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

In [7]:
num_classes = 10

model = torchvision.models.resnet50()
model.fc = torch.nn.Linear(in_features=2048, out_features=num_classes, bias=True)
# Parameter count D of the plain network (ResNet-50 backbone plus the new head), taken
# before any variational parameters are attached. For a 10-way head D = 23,528,522.
# kappa = D / n below therefore reproduces the previously hard-coded 23528522/n, and the
# kappa=23528.522 in the checkpoint filename, while staying correct if num_classes changes.
num_params = sum(p.numel() for p in model.parameters())
# Unconstrained posterior-scale parameter. log(expm1(x)) is the inverse of softplus, so a
# softplus of this value gives 1e-4 (presumably utils applies softplus; confirm). Because
# it is registered on the model, the strict load_state_dict below overwrites it with the
# checkpoint's value.
model.sigma_param = torch.nn.Parameter(torch.log(torch.expm1(torch.tensor(1e-4, device=device))))
utils.add_variational_layers(model, model.sigma_param)
# Bind utils.use_posterior as a method so later cells can call model.use_posterior(...).
model.use_posterior = types.MethodType(utils.use_posterior, model)
model.load_state_dict(first_checkpoint)
model.to(device)

# Location tensor for the L2-SP penalty. Presumably these are the pretrained backbone
# weights; confirm.
bb_loc = torch.load('/cluster/tufts/hugheslab/eharve06/resnet50_torchvision/resnet50_torchvision_mean.pt', map_location=torch.device('cpu'), weights_only=False).to(device)
kappa = num_params / n
criterion = losses.L2KappaELBOLoss(bb_loc, kappa, model.sigma_param)

# use_posterior(True) has not been called yet, so this presumably evaluates the
# deterministic weights; confirm in utils.
val_or_test_metrics = utils.evaluate(model, criterion, val_or_test_loader, num_classes=num_classes)
print(val_or_test_metrics['acc'])
print(val_or_test_metrics['nll'])

0.8739000558853149
0.38879576792716974


In [8]:
# From here on the model uses the variational posterior (presumably drawing fresh weights
# on each forward pass; confirm in utils.use_posterior).
model.use_posterior(True)
num_samples = 50
# Seed torch's RNGs on all devices so the posterior draws below are reproducible. This
# reuses the seed from the config cell. cuDNN kernels may still add minor GPU nondeterminism.
torch.manual_seed(random_state)
# One full evaluation pass over val_or_test_loader per posterior draw.
sample_metrics = [utils.evaluate(model, criterion, val_or_test_loader, num_classes=num_classes) for _ in range(num_samples)]

In [9]:
# Leading axis of each stacked tensor indexes the posterior draw (one entry per element
# of sample_metrics).
labels = torch.stack([torch.stack(metrics['labels']) for metrics in sample_metrics])
logits = torch.stack([torch.stack(metrics['logits']) for metrics in sample_metrics])
# Normalise over the last (class) axis.
probs = torch.nn.functional.softmax(logits, dim=-1)
# The evaluation loader is not shuffled, so every pass must yield identical labels in the
# same order. The ensemble cells that follow score against labels[0] only, so check that.
assert (labels == labels[0]).all(), 'labels differ across posterior passes; labels[0] is not a valid reference'

In [10]:
# Class-balanced ("macro") accuracy: the mean of per-class recalls. On a class-balanced
# evaluation set this equals plain accuracy.
# NOTE(review): utils.evaluate's 'acc' printed earlier may be plain (micro) accuracy. The
# two agree only when classes are balanced; confirm before comparing them.
acc = torchmetrics.Accuracy(
    task='multiclass',
    average='macro',
    num_classes=num_classes,
)

In [11]:
# Ensemble of the first `num_samples` posterior draws: average the predictive
# probabilities, then score against the labels shared by all passes.
num_samples = 1
ensemble_probs = probs[:num_samples].mean(dim=0)
ensemble_acc = acc(ensemble_probs, labels[0])
print(ensemble_acc.item())

0.8718000650405884


In [12]:
# Ensemble of the first `num_samples` posterior draws: average the predictive
# probabilities, then score against the labels shared by all passes.
num_samples = 5
ensemble_probs = probs[:num_samples].mean(dim=0)
ensemble_acc = acc(ensemble_probs, labels[0])
print(ensemble_acc.item())

0.8752999305725098


In [13]:
# Ensemble of the first `num_samples` posterior draws: average the predictive
# probabilities, then score against the labels shared by all passes.
num_samples = 10
ensemble_probs = probs[:num_samples].mean(dim=0)
ensemble_acc = acc(ensemble_probs, labels[0])
print(ensemble_acc.item())

0.8759999871253967


In [14]:
# Ensemble of the first `num_samples` posterior draws (all 50 here): average the
# predictive probabilities, then score against the labels shared by all passes.
num_samples = 50
ensemble_probs = probs[:num_samples].mean(dim=0)
ensemble_acc = acc(ensemble_probs, labels[0])
print(ensemble_acc.item())

0.8747000098228455
