In [None]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch.utils.data import Subset, DataLoader
from uncertainty.ggn import GGNMatVecOperator
from uncertainty.evaluation_slu import SLUEvaluator
from sketch.sketch_srft import SRFTSketcher
from solvers.sketched_lanczos import SketchedLanczos
import numpy as np

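### Workflow
1. Load the model and the datasets (in-distribution: MNIST; out-of-distribution: KMNIST).
2. Define the GGN-vector product.
3. Generate the sketched Lanczos basis `Us`.
4. Compute the SLU (sketched Lanczos uncertainty) scores and the ID-vs-OoD AUROC.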

In [2]:
# Select the compute device: Apple's Metal backend (MPS) when PyTorch reports it
# as available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"device: {device}")

device: mps


In [None]:
# 1. Load MNIST
# ToTensor converts each image to a float tensor; the Lambda flattens it to a
# 1-D vector (28 * 28 = 784 values) so it matches the MLP's Linear(784, ...) input.
transform = T.Compose([T.ToTensor(), T.Lambda(lambda x: x.view(-1))])
mnist = MNIST(root="./data", train=True, download=True, transform=transform)
# Shuffle with a dedicated, seeded generator so the same 2000-sample batch is
# drawn on every "Restart & Run All"; without it every run would use a different
# batch and the downstream numbers could not be reproduced.
loader = DataLoader(
    mnist,
    batch_size=2000,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
# Only the first batch is used for the rest of the notebook.
X_batch, Y_batch = next(iter(loader))
X_batch = X_batch.float().to(device)
Y_batch = Y_batch.long().to(device)

# 2. Small model
class SmallNet(nn.Module):
    """Fully connected classifier with layer widths 784 -> 32 -> 16 -> 10.

    A ReLU follows every linear layer except the last one. The layers live in
    ``self.net`` (an ``nn.Sequential``) at indices 0..4, so state-dict keys are
    ``net.0.*``, ``net.2.*`` and ``net.4.*``.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 32, 16, 10)
        last_linear = len(widths) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # No activation after the output layer.
            if idx != last_linear:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        out = self.net(x)
        return out
# device = torch.device("cpu")

In [123]:
# Rebuild the architecture and restore the weights stored in models/mlp.pt.
# map_location="cpu" lets the checkpoint load even if it was saved from a device
# (e.g. CUDA/MPS) that is not available on this machine; load_state_dict then
# copies the values into the freshly constructed CPU parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model = SmallNet()
model.load_state_dict(torch.load("models/mlp.pt", map_location="cpu"))
# Switch to evaluation mode; as the cell's last expression this also displays
# the module summary.
model.eval()

SmallNet(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=10, bias=True)
  )
)

In [124]:
# Workflow step 2: wrap the model and the MNIST batch in the project's GGN
# matrix-vector operator; its `numpy_interface` is what the Lanczos solver
# later receives as `G_matvec`.
# NOTE(review): `model` was built on the CPU while X_batch/Y_batch live on
# `device`; presumably the operator moves the model itself — confirm.
ggn_op = GGNMatVecOperator(model, X_batch, Y_batch, device=device)

  self.fmodel, self.params = make_functional(model)


In [125]:
# Workflow step 3: run sketched Lanczos on the GGN operator and keep the basis.
#
# Seed both RNG libraries first so the sketch and the resulting basis are the
# same on every run. The sketcher/solver implementations are not visible in this
# notebook, so both torch and numpy are seeded to cover either choice.
torch.manual_seed(0)
np.random.seed(0)

num_params = sum(p.numel() for p in model.parameters())
steps = 100    # number of Lanczos iterations passed to solver.run
s = 2 * steps  # sketch size handed to SRFTSketcher
srft = SRFTSketcher(p = num_params, s=s)
# Debugging aid: to bypass sketching, replace `srft` with an object exposing
# `p == s == num_params` and an `apply_sketch(x)` method that returns `x` unchanged.
solver = SketchedLanczos(G_matvec=ggn_op.numpy_interface, p=num_params, sketch=srft)
solver.run(num_steps=steps)
Us = solver.get_basis()

In [126]:
# 1. Prepare ID data
id_X = X_batch[:500]  # In-distribution set: first 500 samples of the MNIST batch

# 2. Prepare OoD data (KMNIST)
transform = transforms.Compose([
    transforms.ToTensor()
])

ood_dataset = datasets.KMNIST(root='./data', train=False, download=True, transform=transform)
# Seeded shuffle so the same OoD samples (and hence the same AUROC inputs) are
# drawn on every run.
ood_loader = torch.utils.data.DataLoader(
    ood_dataset,
    batch_size=500,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

# Collect batches until at least len(id_X) samples are available. Counting the
# samples actually received (instead of assuming a fixed batch size) stops as
# soon as enough data is loaded, whatever batch_size the loader uses.
ood_X_list = []
num_collected = 0
for ood_batch, _ in ood_loader:
    ood_X_list.append(ood_batch)
    num_collected += len(ood_batch)
    if num_collected >= len(id_X):
        break
ood_X = torch.cat(ood_X_list, dim=0)[:len(id_X)]  # match size with ID
# Flatten each image to a vector, mirroring the flattening applied to MNIST.
# NOTE(review): ood_X stays on the CPU while id_X is on `device`; presumably the
# evaluator handles device placement — confirm.
ood_X = ood_X.view(len(ood_X), -1)

In [127]:
# Workflow step 4 (setup): build the evaluator from the model, the basis `Us`
# returned by the sketched Lanczos solver, and the same sketch object (`srft`)
# that the solver was given.
slu_evaler = SLUEvaluator(model, Us, srft, device=device)

  self.fmodel, self.params = make_functional(self.model)


In [128]:
# Workflow step 4: AUROC computed from the ID and OoD inputs; as the cell's last
# expression, the returned value is displayed as the cell output.
# NOTE(review): presumably this measures how well the SLU score separates ID
# from OoD samples — confirm the score direction in SLUEvaluator.
slu_evaler.compute_auroc(X_id=id_X, X_ood=ood_X)

  J = jacrev(lambda p: self._compute_outputs(p, x))(self.params_flat)


0.686968