<h1>Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs<h1/>

Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson


<h2>Classification with Wide ResNet and CIFAR10<h2/>

In [2]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
m = nn.Softplus()

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import torchvision
import torchvision.transforms as transforms
data_dir = '/content/drive/My Drive/AALTO/cs4875-research/data/'
transform = transforms.Compose([
    transforms.ToTensor(),  # Transform to tensor
    transforms.Normalize((0.5,), (0.5,))  # Min-max scaling to [-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate, stride=1):
        """
        Args:
          in_channels:  Number of input channels.
          out_channels: Number of output channels.
          dropout_rate:  Dropout Rate
          stride:       Controls the stride.
        """
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False, padding = 1),
            nn.Dropout(p = dropout_rate),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False, stride = stride, padding = 1)
        )
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
               nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        out = self.conv(x)
        out += self.skip(x)
        return out

class GroupOfBlocks(nn.Module):
    def __init__(self, in_channels, out_channels, n_blocks, dropout_rate, stride=1):
        super(GroupOfBlocks, self).__init__()
        strides = [stride] + [1]*(int(n_blocks) - 1)
        self.in_channels = in_channels
        group = []

        for stride in strides:
            group.append(Block(self.in_channels, out_channels, dropout_rate, stride))
            self.in_channels = out_channels

        self.group = nn.Sequential(*group)

    def forward(self, x):
        return self.group(x)

class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=10):
        super(WideResNet, self).__init__()
        assert ((depth-4)%6 == 0), "Depth should be 6n+4."
        n = (depth - 4)/6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=nStages[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.group1 = GroupOfBlocks(nStages[0], nStages[1], n, dropout_rate)
        self.group2 = GroupOfBlocks(nStages[1], nStages[2], n, dropout_rate, stride=2)
        self.group3 = GroupOfBlocks(nStages[2], nStages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3])

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nStages[3], num_classes)
        self.nStage3 = nStages[3]

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)

        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.relu(self.bn1(x))
        x = F.avg_pool2d(x, 8)
        x = x.view(-1, self.nStage3)
        return self.fc(x)

In [6]:
# code adapted from https://github.com/timgaripov/dnn-mode-connectivity

def learning_rate_schedule(base_lr, epoch, total_epochs):
    alpha = epoch / total_epochs
    if alpha <= 0.5:
        factor = 1.0
    elif alpha <= 0.9:
        factor = 1.0 - (alpha - 0.5) / 0.4 * 0.99
    else:
        factor = 0.01
    return factor * base_lr

def adjust_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr



In [9]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)


def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleFGE(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  startEpoch = 1
  for epoch in range(startEpoch, numEpochs):
    # model.train()
    learning_rate = learning_rate_schedule(0.01, epoch, numEpochs)
    adjust_learning_rate(optimizer, learning_rate)
    brier_score = 0.0
    total = 0
    for iter, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Loss at epoch {} is {}'.format(epoch, loss.item()))
    print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 200
lr = 0.1
training_loss = []
training_brier = []
t0 = time.time()
model = WideResNet(28, 10, 0.5)
model.to(device)
optimizer = torch.optim.SGD(
    filter(lambda param: param.requires_grad, model.parameters()),
    lr=lr,
    momentum=0.9,
    weight_decay=5e-4
)
loss, brier = ensembleFGE(model, optimizer)
print('NLL Loss is {}'.format(np.mean(loss)))
print('Brier score is {}'.format(np.mean(brier)))
print('Training time: {} seconds'.format(time.time() - t0))


Loss at epoch 1 is 1.4771322011947632
Brier score at epoch 1 is 0.10701743749618531


KeyboardInterrupt: ignored

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/fge_ensemble.pth')
print('Model saved to %s.' % ('fge_ensemble.pth'))