In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn import metrics

from alpaca.uncertainty_estimator import build_estimator
from alpaca.model.cnn import SimpleConv
from alpaca.dataloader.builder import build_dataset
from alpaca.analysis.metrics import ndcg


In [None]:
# Load dataset
BATCH_SIZE = 512  # shared by the train and validation loaders

mnist = build_dataset('mnist', val_size=10_000)
x_train, y_train = mnist.dataset('train')
x_val, y_val = mnist.dataset('val')
# -1 lets reshape infer the sample count; the remaining dims are assumed to be
# the (channels, height, width) layout SimpleConv expects -- TODO confirm.
x_shape = (-1, 1, 28, 28)

train_ds = TensorDataset(torch.FloatTensor(x_train.reshape(x_shape)), torch.LongTensor(y_train))
val_ds = TensorDataset(torch.FloatTensor(x_val.reshape(x_shape)), torch.LongTensor(y_val))
# Shuffle the training samples so batches do not depend on the order the
# dataset happens to be stored in. The validation loader keeps a fixed order
# so the batch scored later is always the same one.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)


In [None]:
# Train model
# Seed torch's global RNG so a Restart & Run All repeats the same run.
SEED = 42
torch.manual_seed(SEED)

model = SimpleConv()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

model.train()  # be explicit about the mode instead of relying on the default
for x_batch, y_batch in train_loader: # Train for one epoch
    print('.', end='')
    optimizer.zero_grad()  # clear gradients left over from the previous step
    prediction = model(x_batch)
    loss = criterion(prediction, y_batch)
    loss.backward()
    optimizer.step()
print('\nTrain loss on last batch', loss.item())

# Check accuracy on the first validation batch
x_batch, y_batch = next(iter(val_loader))

# Score in inference mode: eval() switches mode-dependent layers (dropout,
# batch norm) to their deterministic behaviour, and no_grad() skips building
# the autograd graph. The previous mode is restored afterwards so later cells
# see the model exactly as the training cell left it.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        class_preds = F.softmax(model(x_batch), dim=-1).numpy()
finally:
    model.train(was_training)

predictions = np.argmax(class_preds, axis=-1)
# sklearn's signature is accuracy_score(y_true, y_pred); pass numpy arrays.
print('Accuracy', accuracy_score(y_batch.numpy(), predictions))


In [None]:
# Calculate uncertainty estimation
# Build a "bald" estimator around the trained model, configured with the
# 'mc_dropout' mask and 10 output classes, then score the validation batch.
# NOTE(review): presumably yields one uncertainty value per sample -- confirm
# against the alpaca documentation.
estimator = build_estimator(
    "bald",
    model,
    dropout_mask='mc_dropout',
    num_classes=10,
)
estimations = estimator.estimate(x_batch)


In [None]:
# Calculate NDCG score for the uncertainty
def _sample_log_loss(probs, target, n_classes=10):
    """Return sklearn's log loss for a single sample.

    `probs` is reshaped into one row of `n_classes` values and `target` into a
    single-entry column before being handed to `metrics.log_loss`; the full
    label set is passed explicitly because a single sample only ever contains
    one class.
    """
    return metrics.log_loss(
        target.reshape(-1, 1),
        probs.reshape((-1, n_classes)),
        labels=list(range(n_classes)),
    )


# One error value per sample of the validation batch.
errors = [
    _sample_log_loss(sample_probs, sample_target)
    for sample_probs, sample_target in zip(class_preds, y_batch.numpy())
]

# NOTE(review): presumably ndcg rewards estimates that rank high-error samples
# as most uncertain -- confirm against alpaca.analysis.metrics.
score = ndcg(np.array(errors), estimations)
print("Quality score is ", score)