In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from scipy.special import softmax
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
from sklearn import metrics

from alpaca.uncertainty_estimator import build_estimator
from alpaca.model.cnn import SimpleConv
from alpaca.dataloader.builder import build_dataset
from alpaca.analysis.metrics import ndcg


In [None]:
# Load dataset
# NOTE(review): val_size=50_000 presumably moves 50k MNIST images into the 'val' split — confirm in alpaca.
mnist = build_dataset('mnist', val_size=50_000)
x_train, y_train = mnist.dataset('train')
x_val, y_val = mnist.dataset('val')
x_shape = (-1, 1, 28, 28)  # (batch, channel, height, width) for single-channel 28x28 images


def make_dataset(x, y):
    """Wrap numpy arrays as a TensorDataset.

    Images become float64 to match the `SimpleConv().double()` model used below.
    Labels become int64 because nn.CrossEntropyLoss expects class indices of dtype long
    (the original `torch.Double` constructor does not exist).
    """
    return TensorDataset(
        torch.as_tensor(x.reshape(x_shape), dtype=torch.float64),
        torch.as_tensor(y, dtype=torch.long),
    )


train_ds = make_dataset(x_train, y_train)
val_ds = make_dataset(x_val, y_val)
# Shuffle training batches so each optimizer step sees a mix of classes.
train_loader = DataLoader(train_ds, batch_size=512, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=10_000)


In [None]:
# Train model
SEED = 42
torch.manual_seed(SEED)  # seeds torch's global RNG: weight init, DataLoader shuffling, dropout

model = SimpleConv().double()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

model.train()
for x_batch, y_batch in train_loader:  # Train for one epoch
    print('.', end='')
    # The model is float64 and CrossEntropyLoss expects int64 class indices, so cast defensively.
    prediction = model(x_batch.double())
    optimizer.zero_grad()
    loss = criterion(prediction, y_batch.long())
    loss.backward()
    optimizer.step()
print('\nTrain loss on last batch', loss.item())

# Check accuracy on the first validation batch. eval() disables train-mode layers so
# `class_preds` are deterministic; no_grad() avoids building an autograd graph for the batch.
model.eval()
x_batch, y_batch = next(iter(val_loader))
with torch.no_grad():
    class_preds = F.softmax(model(x_batch.double()), dim=-1).numpy()
predictions = np.argmax(class_preds, axis=-1)
print('Accuracy', accuracy_score(y_batch.numpy(), predictions))  # sklearn order: (y_true, y_pred)


In [None]:
# Calculate uncertainty estimation: a "bald_masked" estimator over `model` using MC-dropout masks.
# keep_runs=True presumably retains the per-run outputs read later via estimator.last_mcd_runs().
estimator = build_estimator(
    "bald_masked",
    model,
    num_classes=10,
    dropout_mask='mc_dropout',
    keep_runs=True,
)


In [None]:
# Score the validation batch. Module.double() casts parameters in place and returns the module
# itself, so weights and inputs are both float64 when the estimator samples.
model = model.double()
estimations = estimator.estimate(x_batch.double())

In [None]:
# Calculate NDCG score for the uncertainty: each point's "error" is the cross-entropy of the
# deterministic prediction on its true class; a good uncertainty should rank large errors first.
targets = y_batch.numpy().astype(int).reshape(-1)
true_class_probs = class_preds[np.arange(len(targets)), targets]
# Vectorized per-sample log loss (replaces one sklearn.log_loss call per point);
# clipping guards against log(0) for confidently wrong predictions.
errors = -np.log(np.clip(true_class_probs, 1e-15, 1.0))

score = ndcg(np.asarray(errors), estimations)
print("Quality score is ", score)

# Per-run MC-dropout outputs kept by the estimator; presumably logits (the model is trained with
# CrossEntropyLoss on raw outputs), so softmax over the class axis turns them into probabilities.
runs = estimator.last_mcd_runs()
sampled_probabilities = softmax(runs, axis=-1)


In [None]:
from scipy.special import softmax


In [None]:
from sklearn.metrics import roc_curve
import numpy as np


def entropy(x):
    """Shannon entropy (natural log) along the last axis.

    Probabilities are clipped to [1e-8, 1] before the log so zero entries contribute 0
    instead of producing nan.
    """
    log_x = np.log(np.clip(x, 1e-8, 1))
    return -np.sum(x * log_x, axis=-1)


def mean_entropy(probabilities):
    """Predictive entropy: entropy of the class distribution averaged over sampling runs (axis 1)."""
    averaged_over_runs = np.mean(probabilities, axis=1)
    return entropy(averaged_over_runs)


def bald(probabilities):
    """BALD score: entropy of the run-averaged prediction minus the average per-run entropy.

    For an N x R x C input (points, sampling runs, classes) this returns one score per point.
    """
    per_run_entropy = entropy(probabilities)
    entropy_of_mean = entropy(np.mean(probabilities, axis=1))
    return entropy_of_mean - per_run_entropy.mean(axis=1)

def var_ratio(probabilities):
    """Variation ratio: 1 - (votes for the most frequent argmax class) / (number of runs).

    Each sampling run (axis 1) votes for its argmax class; a point whose runs all agree scores 0.
    """
    n_runs = probabilities.shape[1]
    winners = np.argmax(probabilities, axis=-1)
    # size of the largest vote bucket for every point
    mode_counts = np.array([np.bincount(run_votes).max() for run_votes in winners])
    return 1 - mode_counts / n_runs

def ensemble_max_prob(probabilities):
    """Uncertainty as 1 minus the top class probability of the run-averaged (axis 1) prediction."""
    return 1 - np.max(np.mean(probabilities, axis=1), axis=-1)

In [None]:

# Deterministic class probabilities for the validation batch (the "model without dropout" input of
# the error-detection plot). eval() disables train-mode layers; no_grad() skips the autograd graph.
model.eval()
with torch.no_grad():
    probabilities = F.softmax(model(x_batch.double()), dim=-1).numpy()
labels = np.array(y_batch)

In [None]:
methods = {
    'bald': bald,
    'var_ratio': var_ratio,
    'entropy': mean_entropy,
    'sampled_max_prob': ensemble_max_prob
}


def plot_error_detection(probabilities, labels, sampled_probabilities, ue_methods=None):
    """
    Plot ROC curves showing how well each uncertainty score detects misclassified points.

    N - number of points in the dataset, C - number of classes, R - number of sampling runs
    all arguments expect to be np.array
    :param probabilities:  probabilities by model without dropout, NxC
    :param labels: true labels for classification, N
    :param sampled_probabilities: probabilities sampled by dropout, NxRxC
    :param ue_methods: optional dict name -> function mapping NxRxC sampled probabilities to
        N uncertainty scores; defaults to the module-level `methods` dict
    :return: None, make roc curve plot for error detection
    """
    if ue_methods is None:
        ue_methods = methods
    assert len(probabilities) == len(labels) == len(sampled_probabilities), \
        "probabilities, labels and sampled_probabilities must describe the same N points"

    predictions = np.argmax(probabilities, axis=-1)
    # Positive class for the ROC = misclassified point; higher uncertainty should flag it.
    errors = (labels != predictions).astype('uint8')

    fig, ax = plt.subplots()
    for name, method_function in ue_methods.items():
        fpr, tpr, _ = roc_curve(errors, method_function(sampled_probabilities))
        ax.plot(fpr, tpr, label=f'{name} (AUC={metrics.auc(fpr, tpr):.3f})')
    # Baseline: 1 - max probability of the deterministic (no-dropout) prediction.
    max_prob = 1 - np.max(probabilities, axis=-1)
    fpr, tpr, _ = roc_curve(errors, max_prob)
    ax.plot(fpr, tpr, label=f'max_prob (AUC={metrics.auc(fpr, tpr):.3f})')
    ax.plot([0, 1], [0, 1], linestyle='--', color='grey', label='chance')
    ax.set(title='Error detection ROC', xlabel='False positive rate', ylabel='True positive rate')
    ax.legend()
    plt.show()


plot_error_detection(probabilities, labels, sampled_probabilities, methods)