<h1>BatchEnsemble: An Alternative Approach To Efficient Ensemble and Lifelong Learning<h1/>

Yeming Wen, Dustin Tran & Jimmy Ba

<h2>Classification with Wide ResNet and CIFAR10<h2/>

In [1]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
m = nn.Softplus()

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torchvision
import torchvision.transforms as transforms
data_dir = '/content/drive/My Drive/AALTO/cs4875-research/data/'
transform = transforms.Compose([
    transforms.ToTensor(),  # Transform to tensor
    transforms.Normalize((0.5,), (0.5,))  # Min-max scaling to [-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class Cov2dEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, num_models=4, first_layer=False):
      super(Cov2dEnsemble, self).__init__()
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.num_models = num_models
      self.first_layer = first_layer
      self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
      self.alpha = nn.Parameter(torch.Tensor(num_models, in_channels))
      self.gamma = nn.Parameter(torch.Tensor(num_models, out_channels))
      nn.init.normal_(self.alpha, mean=1., std=0.5)
      nn.init.normal_(self.gamma, mean=1., std=0.5)

    def forward(self, x):
      if not self.training and self.first_layer:
        x = torch.cat([x for i in range(self.num_models)], dim=0)
      examples_per_model = int(x.size(0) / self.num_models)
      alpha = torch.cat([self.alpha for i in range(examples_per_model)], dim=1).view([-1, self.in_channels])
      alpha.unsqueeze_(-1).unsqueeze_(-1)
      gamma = torch.cat([self.gamma for i in range(examples_per_model)], dim=1).view([-1, self.out_channels])
      gamma.unsqueeze_(-1).unsqueeze_(-1)
      if extra != 0:
        alpha = torch.cat([alpha, alpha[:extra]], dim=0)
        gamma = torch.cat([gamma, gamma[:extra]], dim=0)
      return self.conv1(x*alpha)*gamma

class DenseEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, num_models=4):
      super(DenseEnsemble, self).__init__()
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.num_models = num_models
      self.fc = nn.Linear(in_channels, out_channels, bias=False)
      self.alpha = nn.Parameter(torch.Tensor(num_models, in_channels))
      self.gamma = nn.Parameter(torch.Tensor(num_models, out_channels))
      nn.init.normal_(self.alpha, mean=1., std=0.5)
      nn.init.normal_(self.gamma, mean=1., std=0.5)

    def forward(self, x):
      examples_per_model = int(x.size(0) / self.num_models)
      alpha = torch.cat([self.alpha for i in range(examples_per_model)], dim=1).view([-1, self.in_channels])
      gamma = torch.cat([self.gamma for i in range(examples_per_model)], dim=1).view([-1, self.out_channels])
      if extra != 0:
        alpha = torch.cat([alpha, alpha[:extra]], dim=0)
        gamma = torch.cat([gamma, gamma[:extra]], dim=0)
      return self.fc(x*alpha)*gamma


In [27]:
class BlockBatchEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate, stride=1, num_models=4):
        super(BlockBatchEnsemble, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = Cov2dEnsemble(in_channels, out_channels, 3, stride=1, padding=1, num_models=num_models)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = Cov2dEnsemble(out_channels, out_channels, 3, stride=stride, padding=1, num_models=num_models)
        self.num_models = num_models
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                Cov2dEnsemble(in_channels, out_channels, 1, stride=stride, padding=0, num_models=num_models),
            )

    def forward(self, x):
        curr_bs = x.size(0)
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.skip(x)
        return out

class GroupBlockBatchEnsemble(nn.Module):
    def __init__(self, in_channels, out_channels, n_blocks, dropout_rate, stride=1, num_models=4):
        super(GroupBlockBatchEnsemble, self).__init__()
        strides = [stride] + [1]*(int(n_blocks) - 1)
        self.in_channels = in_channels
        group = []

        for stride in strides:
            group.append(BlockBatchEnsemble(self.in_channels, out_channels, dropout_rate, stride))
            self.in_channels = out_channels

        self.group = nn.Sequential(*group)

    def forward(self, x):
        return self.group(x)

class WideResNetBatchEnsemble(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=10, num_models=4):
        super(WideResNetBatchEnsemble, self).__init__()
        assert ((depth-4)%6 == 0), "Depth should be 6n+4."
        n = (depth - 4)/6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]
        self.num_models = num_models
        self.num_classes = num_classes

        self.conv1 = Cov2dEnsemble(in_channels=3, out_channels=nStages[0], kernel_size=1, stride=1, padding=0, num_models=num_models, first_layer=True)
        self.group1 = GroupBlockBatchEnsemble(nStages[0], nStages[1], n, dropout_rate, stride=1, num_models=num_models)
        self.group2 = GroupBlockBatchEnsemble(nStages[1], nStages[2], n, dropout_rate, stride=2, num_models=num_models)
        self.group3 = GroupBlockBatchEnsemble(nStages[2], nStages[3], n, dropout_rate, stride=2, num_models=num_models)
        self.bn1 = nn.BatchNorm2d(nStages[3])

        self.relu = nn.ReLU(inplace=True)
        self.fc = DenseEnsemble(nStages[3], num_classes, num_models)
        self.nStage3 = nStages[3]

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.relu(self.bn1(x))
        x = F.avg_pool2d(x, 8)
        x = x.view(-1, self.nStage3)
        x = self.fc(x)
        if not self.training:
            x=F.softmax(x, dim=1)
            return x.view([self.num_models, -1, self.num_classes]).mean(dim=0)
        return x


def compute_accuracy(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [25]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)
learning_rate = 0.01

def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleInBatch(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  model.train()
  for epoch in range(numEpochs):
    brier_score = 0.0
    total = 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Loss at epoch {} is {}'.format(epoch, loss.item()))
    print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 40
t0 = time.time()
model = WideResNetBatchEnsemble(28, 4, 0.5)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss, brier = ensembleInBatch(model, optimizer)
time_one = time.time() - t0
accuracy = compute_accuracy(model, testloader)

print('Accuracy of the network on the test images: %.3f' % accuracy)
print('NLL Loss is {}'.format(loss))
print('Brier score is {}'.format(brier))
print('Training time: {} seconds'.format(time_one))


Loss at epoch 0 is 1.978553056716919
Brier score at epoch 0 is 0.1231996875
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Size([5, 10])
torch.Size([5])
torch.Si

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/batch_ensemble.pth')
print('Model saved to %s.' % ('batch_ensemble.pth'))

Reference: https://github.com/giannifranchi/LP_BNN/blob/d324ba8d0ade75e5bfe9a14c670fe71469f49db6/networks/batchensemble_layers.py