<h1>BatchEnsemble: An Alternative Approach To Efficient Ensemble and Lifelong Learning<h1/>

Yeming Wen, Dustin Tran & Jimmy Ba

<h2>Classification with Wide ResNet and CIFAR10<h2/>

In [1]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
m = nn.Softplus()

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [16]:
import torchvision
import torchvision.transforms as transforms
data_dir = '/content/drive/My Drive/AALTO/cs4875-research/data/'
transform = transforms.Compose([
    transforms.ToTensor(),  # Transform to tensor
    transforms.Normalize((0.5,), (0.5,))  # Min-max scaling to [-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [21]:
class Cov2dEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, num_models=4, first_layer=False):
      super(Cov2dEnsemble, self).__init__()
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.num_models = num_models
      self.first_layer = first_layer
      self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
      self.alpha = nn.Parameter(torch.Tensor(num_models, in_channels))
      self.gamma = nn.Parameter(torch.Tensor(num_models, out_channels))
      nn.init.normal_(self.alpha, mean=1., std=0.5)
      nn.init.normal_(self.gamma, mean=1., std=0.5)

    def forward(self, x):
      if not self.training and self.first_layer:
        x = torch.cat([x for i in range(self.num_models)], dim=0)
      examples_per_model = int(x.size(0) / self.num_models)
      alpha = torch.cat([self.alpha for i in range(examples_per_model)], dim=1).view([-1, self.in_channels])
      alpha.unsqueeze_(-1).unsqueeze_(-1)
      gamma = torch.cat([self.gamma for i in range(examples_per_model)], dim=1).view([-1, self.out_channels])
      gamma.unsqueeze_(-1).unsqueeze_(-1)
      return self.conv1(x*alpha)*gamma

class FCEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, num_models=4):
      super(FCEnsemble, self).__init__()
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.num_models = num_models
      self.fc = nn.Linear(in_channels, out_channels, bias=False)
      self.alpha = nn.Parameter(torch.Tensor(num_models, in_channels))
      self.gamma = nn.Parameter(torch.Tensor(num_models, out_channels))
      nn.init.normal_(self.alpha, mean=1., std=0.5)
      nn.init.normal_(self.gamma, mean=1., std=0.5)

    def forward(self, x):
      examples_per_model = int(x.size(0) / self.num_models)
      alpha = torch.cat([self.alpha for i in range(examples_per_model)], dim=1).view([-1, self.in_channels])
      gamma = torch.cat([self.gamma for i in range(examples_per_model)], dim=1).view([-1, self.out_channels])
      return self.fc(x*alpha)*gamma


In [22]:
class BlockBatchEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate, stride=1, num_models=4):
        super(BlockBatchEnsemble, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = Cov2dEnsemble(in_channels, out_channels, 3, stride=1, padding=1, num_models=num_models)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = Cov2dEnsemble(out_channels, out_channels, 3, stride=stride, padding=1, num_models=num_models)
        self.num_models = num_models
        # self.convs = [self.conv1, self.conv2]
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                # nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True),
                Cov2dEnsemble(in_channels, out_channels, 1, stride=stride, padding=0, num_models=num_models),
            )
    # def update_indices(self, indices):
    #     for m_conv in self.convs:
    #         m_conv.update_indices(indices)

    def forward(self, x):
        curr_bs = x.size(0)
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.skip(x)
        return out

class GroupBlockBatchEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, n_blocks, dropout_rate, stride=1, num_models=4):
        super(GroupBlockBatchEnsemble, self).__init__()
        strides = [stride] + [1]*(int(n_blocks) - 1)
        self.in_channels = in_channels
        group = []

        for stride in strides:
            group.append(BlockBatchEnsemble(self.in_channels, out_channels, dropout_rate, stride))
            self.in_channels = out_channels

        self.group = nn.Sequential(*group)

    def forward(self, x):
        return self.group(x)

class WideResNetBatchEnsemble(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=10, num_models=4):
        super(WideResNetBatchEnsemble, self).__init__()
        assert ((depth-4)%6 == 0), "Depth should be 6n+4."
        n = (depth - 4)/6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]

        self.conv1 = Cov2dEnsemble(in_channels=3, out_channels=nStages[0], kernel_size=1, stride=1, padding=0, num_models=num_models, first_layer=True)
        self.group1 = GroupBlockBatchEnsemble(nStages[0], nStages[1], n, dropout_rate, stride=1, num_models=num_models)
        self.group2 = GroupBlockBatchEnsemble(nStages[1], nStages[2], n, dropout_rate, stride=2, num_models=num_models)
        self.group3 = GroupBlockBatchEnsemble(nStages[2], nStages[3], n, dropout_rate, stride=2, num_models=num_models)
        self.bn1 = nn.BatchNorm2d(nStages[3])

        self.relu = nn.ReLU(inplace=True)
        self.fc = FCEnsemble(nStages[3], num_classes, num_models)
        self.nStage3 = nStages[3]

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.relu(self.bn1(x))
        x = F.avg_pool2d(x, 8)
        x = x.view(-1, self.nStage3)
        return self.fc(x)


In [27]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)
eps = 0.01*2 # input ranges from (-1, 1)
learning_rate = 0.01

def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleInBatch(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  for epoch in range(numEpochs):
    # model.train()
    brier_score = 0.0
    total = 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Loss at epoch {} is {}'.format(epoch, loss.item()))
    print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 40
training_loss = []
training_brier = []
t0 = time.time()
# for i in range(4):
model = WideResNetBatchEnsemble(28, 4, 0.5)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss, brier = ensembleInBatch(model, optimizer)
  # training_loss.append(loss)
  # training_brier.append(brier)
print('NLL Loss is {}'.format(np.mean(loss)))
print('Brier score is {}'.format(np.mean(brier)))
print('Training time: {} seconds'.format(time.time() - t0))


Loss at epoch 0 is 1.9276663064956665
Brier score at epoch 0 is 0.1217960625076294
Loss at epoch 1 is 1.76310133934021
Brier score at epoch 1 is 0.1121630000114441
Loss at epoch 2 is 1.448449969291687
Brier score at epoch 2 is 0.0917960625076294
Loss at epoch 3 is 1.3844554424285889
Brier score at epoch 3 is 0.0789084375
Loss at epoch 4 is 1.2089399099349976
Brier score at epoch 4 is 0.0642072187614441
Loss at epoch 5 is 1.0314240455627441
Brier score at epoch 5 is 0.05087234375
Loss at epoch 6 is 0.6248108744621277
Brier score at epoch 6 is 0.04362690625190735
Loss at epoch 7 is 0.7734381556510925
Brier score at epoch 7 is 0.03828131249904632
Loss at epoch 8 is 0.6490518450737
Brier score at epoch 8 is 0.0342365625
Loss at epoch 9 is 0.5573354959487915
Brier score at epoch 9 is 0.030890593752861024
Loss at epoch 10 is 0.5704811215400696
Brier score at epoch 10 is 0.02788315625190735
Loss at epoch 11 is 0.7347951531410217
Brier score at epoch 11 is 0.02486296875
Loss at epoch 12 is 0.4