In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn import metrics

from alpaca.ue import MCDUE
from alpaca.utils.datasets.builder import build_dataset
from alpaca.utils.ue_metrics import ndcg
from alpaca.ue.masks import BasicBernoulliMask
import alpaca.nn as ann
from alpaca.utils import model_builder

In [2]:
# Load MNIST with a 10k validation split and wrap it in PyTorch loaders.
# Images are reshaped to (N, 1, 28, 28): the NCHW layout nn.Conv2d expects.
SEED = 42
BATCH_SIZE = 512
torch.manual_seed(SEED)  # makes loader shuffling and model initialisation reproducible

mnist = build_dataset('mnist', val_size=10000)
x_train, y_train = mnist.dataset('train')
x_val, y_val = mnist.dataset('val')
x_shape = (-1, 1, 28, 28)

train_ds = TensorDataset(
    torch.as_tensor(x_train.reshape(x_shape), dtype=torch.float32),
    torch.as_tensor(y_train, dtype=torch.long),
)
val_ds = TensorDataset(
    torch.as_tensor(x_val.reshape(x_shape), dtype=torch.float32),
    torch.as_tensor(y_val, dtype=torch.long),
)
# Shuffle training batches so optimisation does not depend on the dataset's storage
# order; keep validation order fixed so the "first val batch" used below is stable.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

In [3]:
class SimpleConv(nn.Module):
    """Small two-conv-layer CNN for single-channel 28x28 images (e.g. MNIST).

    Architecture: conv(1->16, 3x3) -> act -> conv(16->32, 3x3) -> act ->
    2x2 max-pool -> flatten (32*12*12) -> fc(256) -> act -> alpaca Dropout ->
    fc(num_classes).

    Args:
        num_classes: number of output logits.
        activation: callable applied after each hidden layer; defaults to
            ``F.leaky_relu``. May be a plain function or an ``nn.Module``.
        dropout_rate, dropout_mask: accepted for call-site compatibility but
            not used by this constructor; the notebook passes dropout settings
            to ``model_builder.build_model`` instead.
    """

    def __init__(self, num_classes=10, activation=None, dropout_rate=0.5, dropout_mask=None):
        # nn.Module.__init__ must run before any attribute assignment; otherwise
        # assigning an nn.Module activation (e.g. nn.ReLU()) raises AttributeError.
        super().__init__()
        self.activation = F.leaky_relu if activation is None else activation
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        # 28 -> 26 -> 24 after two unpadded 3x3 convs, -> 12 after the 2x2 max-pool.
        self.linear_size = 12*12*32
        self.fc1 = nn.Linear(self.linear_size, 256)
        self.dropout = ann.Dropout()
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to unnormalized logits (batch, num_classes)."""
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # flatten(1) keeps the batch dimension, so a wrongly sized input fails
        # loudly at fc1 instead of being silently re-batched by view(-1, ...).
        x = torch.flatten(x, 1)
        x = self.activation(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [4]:
# Wrap a fresh SimpleConv with alpaca's model_builder, using dropout_rate=0.5
# and no dropout mask.
model = model_builder.build_model(
    SimpleConv(),
    dropout_rate=0.5,
    dropout_mask=None,
)

In [5]:
def train_model(model=None, loader=None, epochs=1):
    """Train a classifier with Adam and cross-entropy, printing one '.' per batch.

    Args:
        model: network to train in place. If None, a fresh
            ``SimpleConv(dropout_mask=BasicBernoulliMask)`` is created (the
            original zero-argument behaviour).
        loader: iterable of ``(inputs, integer_labels)`` batches; defaults to
            the notebook-level ``train_loader``.
        epochs: number of passes over ``loader`` (default 1).

    Returns:
        The trained model (the same object that was passed in, if any).
    """
    if model is None:
        model = SimpleConv(dropout_mask=BasicBernoulliMask)
    if loader is None:
        loader = train_loader

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    model.train()  # ensure dropout-style layers are in training mode

    loss = None  # stays None if the loader yields no batches
    for _ in range(epochs):
        for x_batch, y_batch in loader:
            print('.', end='')
            optimizer.zero_grad()
            loss = criterion(model(x_batch), y_batch)
            loss.backward()
            optimizer.step()

    if loss is None:
        print('\nNo batches seen; model was not updated')
    else:
        print('\nTrain loss on last batch', loss.item())
    return model

In [6]:
# Train the model built above for one epoch; train_model returns the trained network.
model = train_model(model=model)

.
Train loss on last batch 5.570048809051514


In [11]:
# Take the first validation batch and compute the model's softmax class probabilities.
x_batch, y_batch = next(iter(val_loader))
# Inference only: no_grad skips building the autograd graph, so .detach() is unnecessary.
# NOTE(review): nothing here switches the model to eval mode, so dropout may be active
# for this single prediction pass; confirm whether a deterministic prediction is intended.
with torch.no_grad():
    class_preds = F.softmax(model(x_batch), dim=-1).numpy()

# Calculate uncertainty estimation for the same batch with alpaca's MCDUE
# (BALD-normed acquisition).
estimator = MCDUE(model, num_classes=10, acquisition="bald_normed")
predictions, estimations = estimator.estimate(x_batch)

Uncertainty estimation with MCDUE_classification approach: 100%|██████████| 25/25 [00:02<00:00,  8.72it/s]


In [12]:
# Calculate NDCG score for the uncertainty
# Per-sample error = negative log-likelihood of the true class, i.e. the single-sample
# cross-entropy that sklearn's log_loss would give (up to probability clipping), but
# computed in one vectorized step instead of one sklearn call per sample.
# Clipping keeps a zero-probability true class finite instead of +inf.
targets = y_batch.numpy()
true_class_probs = class_preds[np.arange(len(targets)), targets].astype(np.float64)
errors = -np.log(np.clip(true_class_probs, 1e-15, 1.0))
# NDCG measures how well the uncertainty estimates rank the samples by error.
score = ndcg(errors, estimations)
print("Quality score is ", score)

Quality score is  0.7576639516164719


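## Try a different dropout mask

Rebuild and retrain the same network, this time passing `BasicBernoulliMask()` as the dropout mask, and compare the resulting NDCG quality score with the run above (which used no mask).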

In [15]:
# Wrap a fresh SimpleConv with alpaca's model_builder, this time using a
# BasicBernoulliMask instance as the dropout mask (dropout_rate=0.5).
model = model_builder.build_model(
    SimpleConv(),
    dropout_rate=0.5,
    dropout_mask=BasicBernoulliMask(),
)

In [16]:
# Train the masked model built above for one epoch; train_model returns the trained network.
model = train_model(model=model)

.
Train loss on last batch 6.535654544830322


In [17]:
# Take the first validation batch and compute the model's softmax class probabilities.
x_batch, y_batch = next(iter(val_loader))
# Inference only: no_grad skips building the autograd graph, so .detach() is unnecessary.
# NOTE(review): nothing here switches the model to eval mode, so dropout may be active
# for this single prediction pass; confirm whether a deterministic prediction is intended.
with torch.no_grad():
    class_preds = F.softmax(model(x_batch), dim=-1).numpy()

# Calculate uncertainty estimation for the same batch with alpaca's MCDUE
# (BALD-normed acquisition).
estimator = MCDUE(model, num_classes=10, acquisition="bald_normed")
predictions, estimations = estimator.estimate(x_batch)

Uncertainty estimation with MCDUE_classification approach: 100%|██████████| 25/25 [00:03<00:00,  8.26it/s]


In [18]:
# Calculate NDCG score for the uncertainty
# Per-sample error = negative log-likelihood of the true class, i.e. the single-sample
# cross-entropy that sklearn's log_loss would give (up to probability clipping), but
# computed in one vectorized step instead of one sklearn call per sample.
# Clipping keeps a zero-probability true class finite instead of +inf.
targets = y_batch.numpy()
true_class_probs = class_preds[np.arange(len(targets)), targets].astype(np.float64)
errors = -np.log(np.clip(true_class_probs, 1e-15, 1.0))
# NDCG measures how well the uncertainty estimates rank the samples by error.
score = ndcg(errors, estimations)
print("Quality score is ", score)

Quality score is  0.7989338552446705
