<h1>Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles<h1/>

Balaji Lakshminarayanan Alexander Pritzel Charles Blundell

<h2>Classification with Wide ResNet and CIFAR10<h2/>

In [2]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
m = nn.Softplus()

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate, stride=1):
        """
        Args:
          in_channels:  Number of input channels.
          out_channels: Number of output channels.
          dropout_rate:  Dropout Rate
          stride:       Controls the stride.
        """
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False, padding = 1),
            nn.Dropout(p = dropout_rate),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False, stride = stride, padding = 1)
        )
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
               nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        out = self.conv(x)
        out += self.skip(x)
        return out

class GroupOfBlocks(nn.Module):
    def __init__(self, in_channels, out_channels, n_blocks, dropout_rate, stride=1):
        super(GroupOfBlocks, self).__init__()
        strides = [stride] + [1]*(int(n_blocks) - 1)
        self.in_channels = in_channels
        group = []

        for stride in strides:
            group.append(Block(self.in_channels, out_channels, dropout_rate, stride))
            self.in_channels = out_channels

        self.group = nn.Sequential(*group)

    def forward(self, x):
        return self.group(x)

class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=10):
        super(WideResNet, self).__init__()
        assert ((depth-4)%6 == 0), "Depth should be 6n+4."
        n = (depth - 4)/6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=nStages[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.group1 = GroupOfBlocks(nStages[0], nStages[1], n, dropout_rate)
        self.group2 = GroupOfBlocks(nStages[1], nStages[2], n, dropout_rate, stride=2)
        self.group3 = GroupOfBlocks(nStages[2], nStages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3])

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nStages[3], num_classes)
        self.nStage3 = nStages[3]

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)

        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.relu(self.bn1(x))
        x = F.avg_pool2d(x, 8)
        x = x.view(-1, self.nStage3)
        return self.fc(x)

# This function computes the accuracy on the test dataset
def compute_accuracy_ood(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            images = images.clone().detach().requires_grad_(True)
            images = images + eps*(torch.sign(images))
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def compute_accuracy(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [5]:
import torchvision
import torchvision.transforms as transforms
data_dir = '/content/drive/My Drive/AALTO/cs4875-research/data/'
transform = transforms.Compose([
    transforms.ToTensor(),  # Transform to tensor
    transforms.Normalize((0.5,), (0.5,))  # Min-max scaling to [-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


<h2>Ensemble without adversarial training<h2/>

In [14]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)
eps = 0.01*2 # input ranges from (-1, 1)
learning_rate = 0.01

def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleWithoutAdversarial(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  model.train()
  for epoch in range(numEpochs):
    brier_score = 0.0
    total = 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        x = x.clone().detach().requires_grad_(True)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    if epoch == (numEpochs-1):
      print('Loss at epoch {} is {}'.format(epoch, loss.item()))
      print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 40
training_loss = []
training_brier = []
t0 = time.time()
accuracys = []
accuracy_oods = []
time_all = []
for i in range(4):
  model = WideResNet(28, 4, 0.5)
  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  loss, brier = ensembleWithoutAdversarial(model, optimizer)
  time_one = time.time() - t0
  training_loss.append(loss)
  training_brier.append(brier)
  accuracy = compute_accuracy(model, testloader)
  accuracy_ood = compute_accuracy_ood(model, testloader)

  accuracys.append(accuracy)
  accuracy_oods.append(accuracy_ood)
  time_all.append(time_one)
  print('Accuracy of the network on the test images: %.3f' % accuracy)
  print('Accuracy of the network on the OOD test images: %.3f' % accuracy_ood)
  print('NLL Loss is {}'.format(loss))
  print('Brier score is {}'.format(brier))
  print('Training time: {} seconds'.format(time_one))
print('Mean:')
print('Accuracy of the network on the test images: %.3f' % np.mean(accuracys))
print('Accuracy of the network on the OOD test images: %.3f' % np.mean(accuracy_oods))
print('NLL Loss is {}'.format(np.mean(training_loss)))
print('Brier score is {}'.format(np.mean(training_brier)))
print('Training time: {} seconds'.format(np.mean(time_all)))


Loss at epoch 0 is 1.8651790618896484
Brier score at epoch 0 is 0.399615
Accuracy of the network on the test images: 0.167
Accuracy of the network on the OOD test images: 0.168
NLL Loss is [1.8651790618896484]
Brier score is [0.399615]
Training time: 250.71372509002686 seconds


KeyboardInterrupt: ignored

<h2>Ensemble with adversarial training<h2/>

In [None]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)
eps = 0.01*2 # input ranges from (-1, 1)
learning_rate = 0.01

def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleWithAdversarial(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  model.train()
  for epoch in range(numEpochs):
    brier_score = 0.0
    total = 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        x = x.clone().detach().requires_grad_(True)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward(retain_graph=True)
        x_prime = x + eps*(torch.sign(x.grad.data))
        optimizer.zero_grad()
        output_prime = model(x_prime)
        loss = loss_func(output, y) + loss_func(output_prime, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    if epoch == (numEpochs-1):
      print('Loss at epoch {} is {}'.format(epoch, loss.item()))
      print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 40
training_loss = []
training_brier = []
t0 = time.time()
accuracys = []
accuracy_oods = []
time_all = []
for i in range(4):
  model = WideResNet(28, 4, 0.5)
  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  loss, brier = ensembleWithAdversarial(model, optimizer)
  time_one = time.time() - t0
  training_loss.append(loss)
  training_brier.append(brier)
  accuracy = compute_accuracy(model, testloader)
  accuracy_ood = compute_accuracy_ood(model, testloader)
  accuracys.append(accuracy)
  accuracy_oods.append(accuracy_ood)
  time_all.append(time_one)
  print('Accuracy of the network on the test images: %.3f' % accuracy)
  print('Accuracy of the network on the OOD test images: %.3f' % accuracy_ood)
  print('NLL Loss is {}'.format(loss))
  print('Brier score is {}'.format(brier))
  print('Training time: {} seconds'.format(time_one))
print('Mean:')
print('Accuracy of the network on the test images: %.3f' % np.mean(accuracys))
print('Accuracy of the network on the OOD test images: %.3f' % np.mean(accuracy_oods))
print('NLL Loss is {}'.format(np.mean(training_loss)))
print('Brier score is {}'.format(np.mean(training_brier)))
print('Training time: {} seconds'.format(np.mean(time_all)))


In [None]:
# torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/simple_ensemble.pth')
# print('Model saved to %s.' % ('simple_ensemble.pth'))

In [None]:
# device = torch.device('cuda:0')
# loss_func = nn.CrossEntropyLoss()
# m = nn.LogSoftmax(dim=1)
# eps = 0.01*2 # input ranges from (-1, 1)
# learning_rate = 0.01

# def compute_brier_score(p, y):
#   brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
#   return brier_score
#   # return torch.mean(torch.pow((y-torch.argmax(p, 1)), 2))

# def ensembleWithAdversarial(model, optimizer):
#   running_loss = 0.0
#   running_brier = 0.0
#   for epoch in range(numEpochs):
#     model.train()
#     brier_score = 0.0
#     total = 0
#     for x, y in trainloader:
#         x, y = x.to(device), y.to(device)
#         x = x.clone().detach().requires_grad_(True)
#         optimizer.zero_grad()
#         output = model(x)
#         batch_brier_score = compute_brier_score(output, y)
#         brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
#         loss = loss_func(output, y)
#         loss.backward(retain_graph=True)
#         x_prime = x + eps*(torch.sign(x.grad.data))
#         optimizer.zero_grad()
#         output_prime = model(x_prime)
#         loss = loss_func(output, y) + loss_func(output_prime, y)
#         loss.backward()
#         optimizer.step()
#         total += y.size(0)
#     if epoch == (numEpochs-1):
#       running_loss = loss.item()
#     print('Loss at epoch {} is {}'.format(epoch, loss.item()))
#     print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
#   return running_loss, brier_score/total


# numEpochs = 40
# training_loss = []
# training_brier = []
# t0 = time.time()
# for i in range(4):
#   model = WideResNet(28, 4, 0.5)
#   model.to(device)
#   optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#   loss, brier = ensembleWithAdversarial(model, optimizer)
#   training_loss.append(loss)
#   training_brier.append(brier)
# print('NLL Loss is {}'.format(np.mean(training_loss)))
# print('Brier score is {}'.format(np.mean(training_brier)))
# print('Training time: {} seconds'.format(time.time() - t0))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        [-8.8664e-01, -7.5564e-01,  9.0047e-01, -1.3845e-02,  1.1947e+00,
          4.4006e-01,  1.1900e+00,  4.5717e-01, -1.5454e+00, -3.5307e-01],
        [ 1.1567e-02,  1.6259e-01,  1.1097e-01, -1.7223e-01, -8.7033e-02,
          5.8350e-03, -4.3786e-01,  6.7232e-02, -1.0791e-01,  3.5288e-01],
        [ 7.3839e-01,  3.6759e-01,  2.7580e-01, -4.4718e-01, -5.8736e-01,
         -1.2363e+00, -1.2823e+00,  1.2149e-01,  4.9667e-01,  7.5280e-01],
        [-5.2336e-01, -3.1251e-01,  4.2621e-01, -4.9986e-02,  5.4083e-01,
          3.8285e-01,  4.9916e-01,  2.3079e-01, -8.4243e-01, -5.8232e-03],
        [-1.0497e+00, -9.5799e-01,  1.1410e+00,  5.8903e-04,  1.5245e+00,
          4.0535e-01,  1.4866e+00,  5.4226e-01, -1.8370e+00, -5.0133e-01],
        [-1.1302e-01,  1.5053e-01, -7.3189e-02, -1.1630e-01, -1.6622e-01,
          2.7531e-01, -3.2297e-01,  2.6245e-02, -5.8601e-02,  3.9682e-01],
        [-2.8370e-01, -1.3541e-01,  2.663

In [None]:
training_loss