<h1>Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles</h1>

Balaji Lakshminarayanan, Alexander Pritzel, Charles Blundell

<h2>Classification with Wide ResNet and CIFAR10</h2>

In [18]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
# Softplus module bound to the short module-level name `m`.
# NOTE(review): `m` is rebound to nn.LogSoftmax(dim=1) in a later cell -- confirm
# which binding downstream code relies on.
m = nn.Softplus()

In [3]:
# Colab-only setup: mount Google Drive at /content/drive. The dataset
# directory used later in this notebook (data_dir) lives under this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [14]:
class Block(nn.Module):
    """Pre-activation residual block with dropout between its two 3x3 convolutions.

    The main path is BN -> ReLU -> conv3x3 -> Dropout -> BN -> ReLU -> conv3x3,
    with the stride applied by the second convolution. The shortcut is the
    identity unless the stride or the channel count changes, in which case a
    strided 1x1 convolution matches the shapes.

    Args:
      in_channels:  Number of input channels.
      out_channels: Number of output channels.
      dropout_rate: Dropout probability used between the two convolutions.
      stride:       Stride of the second convolution (and of the shortcut).
    """

    def __init__(self, in_channels, out_channels, dropout_rate, stride=1):
        super().__init__()

        # Main (residual) path; the layer order also fixes the state_dict keys.
        main_path = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True),
            nn.Dropout(p=dropout_rate),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True),
        ]
        self.conv = nn.Sequential(*main_path)

        # Shortcut path: an empty Sequential acts as the identity.
        shape_is_preserved = stride == 1 and in_channels == out_channels
        if shape_is_preserved:
            self.skip = nn.Sequential()
        else:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True)
            self.skip = nn.Sequential(projection)

    def forward(self, x):
        """Return ``conv(x) + skip(x)``."""
        shortcut = self.skip(x)
        # Accumulate in place into the freshly computed main-path output.
        return self.conv(x).add_(shortcut)

class GroupOfBlocks(nn.Module):
    """A sequential stack of ``Block`` modules sharing one output width.

    Only the first block receives ``in_channels`` and ``stride``; every
    following block maps ``out_channels`` to ``out_channels`` with stride 1.
    At least one block is always created, even when ``n_blocks`` is below 1.

    Args:
      in_channels:  Number of channels entering the group.
      out_channels: Number of channels produced by every block in the group.
      n_blocks:     Number of blocks; converted with ``int()``.
      dropout_rate: Dropout rate forwarded to every block.
      stride:       Stride of the first block.
    """

    def __init__(self, in_channels, out_channels, n_blocks, dropout_rate, stride=1):
        super().__init__()
        n_followers = int(n_blocks) - 1

        # The first block handles the change of width and resolution ...
        entry = Block(in_channels, out_channels, dropout_rate, stride)
        # ... the remaining ones keep the shape unchanged.
        followers = [
            Block(out_channels, out_channels, dropout_rate, 1)
            for _ in range(n_followers)
        ]

        # After construction this attribute holds the width leaving the group.
        self.in_channels = out_channels
        self.group = nn.Sequential(entry, *followers)

    def forward(self, x):
        """Apply the blocks in order."""
        return self.group(x)

class WideResNet(nn.Module):
    """Wide ResNet image classifier built from three ``GroupOfBlocks`` stages.

    Layout: 3x3 stem convolution (3 -> 16 channels) -> three groups of
    residual blocks (the last two downsample with stride 2) -> BatchNorm +
    ReLU -> global average pooling -> linear classifier.

    Args:
      depth:        Network depth; must be of the form 6n+4, giving n blocks
                    per group.
      widen_factor: Multiplier k; the groups are 16k, 32k and 64k channels wide.
      dropout_rate: Dropout rate forwarded to every block.
      num_classes:  Number of output logits.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes=10):
        super(WideResNet, self).__init__()
        assert ((depth-4)%6 == 0), "Depth should be 6n+4."
        # Floor division keeps the block count an int (the assert above
        # guarantees the division is exact).
        blocks_per_group = (depth - 4) // 6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=nStages[0], kernel_size=3, stride=1, padding=1, bias=True)
        self.group1 = GroupOfBlocks(nStages[0], nStages[1], blocks_per_group, dropout_rate)
        self.group2 = GroupOfBlocks(nStages[1], nStages[2], blocks_per_group, dropout_rate, stride=2)
        self.group3 = GroupOfBlocks(nStages[2], nStages[3], blocks_per_group, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3])

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nStages[3], num_classes)
        # Kept for backward compatibility with code that reads this attribute.
        self.nStage3 = nStages[3]

        # Initialize weights: conv weights ~ N(0, sqrt(2 / fan_out)),
        # BatchNorm as the identity transform (scale 1, shift 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, np.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` logits."""
        x = self.conv1(x)

        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.relu(self.bn1(x))
        # Global average pooling. For 32x32 inputs the final feature map is
        # 8x8, so this equals the former fixed 8x8 pooling window; for other
        # input sizes it no longer folds spatial positions into the batch
        # dimension when flattening.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.flatten(1)
        return self.fc(x)

In [5]:
import torchvision
import torchvision.transforms as transforms
# CIFAR-10 lives on the mounted Google Drive.
data_dir = '/content/drive/My Drive/AALTO/cs4875-research/data/'

# Convert images to tensors, then rescale with mean 0.5 / std 0.5
# (min-max scaling to [-1, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Both splits share the same location, download flag and preprocessing;
# the training split is created first, then the test split.
_cifar10_kwargs = dict(root=data_dir, download=True, transform=transform)
trainset = torchvision.datasets.CIFAR10(train=True, **_cifar10_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **_cifar10_kwargs)

classes = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Shuffled mini-batches of 32 for training; ordered batches of 5 for testing.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=32,
    shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=5,
    shuffle=False,
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that `model.to(device)` does not fail on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Loss applied to both the clean and the adversarial batch during training.
loss_func = nn.CrossEntropyLoss()
# Log-softmax over dimension 1 (rebinds the module-level name `m`).
m = nn.LogSoftmax(dim=1)
# Adversarial step size: 1% of the input range, which spans (-1, 1).
eps = 0.01*2 # input ranges from (-1, 1)
# Learning rate handed to the optimizer.
learning_rate = 0.01

def ensembleWithAdversarial(model, optimizer):
    """Train ``model`` on clean plus fast-gradient-sign adversarial examples.

    For every mini-batch of the module-level ``trainloader`` the objective is
    ``loss_func`` on the clean inputs plus ``loss_func`` on inputs shifted by
    ``eps * sign(d loss / d x)``. Training runs for the module-level
    ``numEpochs`` epochs on the module-level ``device``; after each epoch the
    loss of that epoch's last mini-batch is printed.

    The model's train/eval mode is left as the caller set it.

    Args:
        model: Network to train; expected to already be on ``device``.
        optimizer: Optimizer over the parameters of ``model``.

    Returns:
        The combined (clean + adversarial) loss of the last mini-batch of the
        final epoch as a float. ``0.0`` when ``numEpochs`` is 0, ``nan`` when
        ``trainloader`` yields no batches in the final epoch.
    """
    running_loss = 0.0
    for epoch in range(numEpochs):
        last_loss = None
        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            x = x.clone().detach().requires_grad_(True)
            optimizer.zero_grad()

            output = model(x)
            clean_loss = loss_func(output, y)
            # Only the gradient w.r.t. the input is needed to build the
            # adversarial example, so parameter gradients are not computed
            # here. The graph is retained because clean_loss is also part of
            # the training objective below.
            (input_grad,) = torch.autograd.grad(clean_loss, x, retain_graph=True)
            # Detach so the second backward pass does not propagate through
            # the perturbation back into the input.
            x_prime = (x + eps * input_grad.sign()).detach()

            output_prime = model(x_prime)
            loss = clean_loss + loss_func(output_prime, y)
            loss.backward()
            optimizer.step()
            last_loss = loss.detach()

        # An empty loader leaves no loss to report; use NaN rather than
        # failing on an unbound name.
        epoch_loss = float('nan') if last_loss is None else last_loss.item()
        if epoch == (numEpochs - 1):
            running_loss = epoch_loss
        print('Loss at epoch {} is {}'.format(epoch, epoch_loss))
    return running_loss


# Ensemble training driver: fit four WideResNet-28-10 members one after
# another and report the mean of their final losses and the total wall time.
numEpochs = 40  # read as a global by ensembleWithAdversarial
training_loss = []
t0 = time.time()
for i in range(4):
    # nn.Module.to returns the module itself, so construction and device
    # placement can be chained.
    model = WideResNet(depth=28, widen_factor=10, dropout_rate=0.5).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    member_loss = ensembleWithAdversarial(model, optimizer)
    training_loss.append(member_loss)
mean_loss = np.mean(training_loss)
print('NLL Loss is {}'.format(mean_loss))
elapsed = time.time() - t0
print('Training time: {} seconds'.format(elapsed))


Loss at epoch 0 is 2.8912243843078613
Loss at epoch 1 is 3.298201084136963
Loss at epoch 2 is 2.4600324630737305
Loss at epoch 3 is 2.1528806686401367
Loss at epoch 4 is 1.593316674232483
Loss at epoch 5 is 1.1246696710586548
Loss at epoch 6 is 1.7617130279541016
Loss at epoch 7 is 1.0838863849639893
Loss at epoch 8 is 1.858198642730713
Loss at epoch 9 is 1.5756192207336426
Loss at epoch 10 is 1.4269468784332275
