<h1>Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs<h1/>

Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson


<h2>Classification with Wide ResNet and CIFAR10<h2/>

In [1]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
m = nn.Softplus()

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torchvision
import torchvision.transforms as transforms
data_dir = '/content/drive/My Drive/AALTO/cs4875-research/data/'
transform = transforms.Compose([
    transforms.ToTensor(),  # Transform to tensor
    transforms.Normalize((0.5,), (0.5,))  # Min-max scaling to [-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate, stride=1):
        """
        Args:
          in_channels:  Number of input channels.
          out_channels: Number of output channels.
          dropout_rate:  Dropout Rate
          stride:       Controls the stride.
        """
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False, padding = 1),
            nn.Dropout(p = dropout_rate),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False, stride = stride, padding = 1)
        )
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
               nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        out = self.conv(x)
        out += self.skip(x)
        return out

class GroupOfBlocks(nn.Module):
    def __init__(self, in_channels, out_channels, n_blocks, dropout_rate, stride=1):
        super(GroupOfBlocks, self).__init__()
        strides = [stride] + [1]*(int(n_blocks) - 1)
        self.in_channels = in_channels
        group = []

        for stride in strides:
            group.append(Block(self.in_channels, out_channels, dropout_rate, stride))
            self.in_channels = out_channels

        self.group = nn.Sequential(*group)

    def forward(self, x):
        return self.group(x)

class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=10):
        super(WideResNet, self).__init__()
        assert ((depth-4)%6 == 0), "Depth should be 6n+4."
        n = (depth - 4)/6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=nStages[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.group1 = GroupOfBlocks(nStages[0], nStages[1], n, dropout_rate)
        self.group2 = GroupOfBlocks(nStages[1], nStages[2], n, dropout_rate, stride=2)
        self.group3 = GroupOfBlocks(nStages[2], nStages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3])

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nStages[3], num_classes)
        self.nStage3 = nStages[3]

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)

        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.relu(self.bn1(x))
        x = F.avg_pool2d(x, 8)
        x = x.view(-1, self.nStage3)
        return self.fc(x)

In [5]:
# code adapted from https://github.com/timgaripov/dnn-mode-connectivity

def learning_rate_schedule(base_lr, epoch, total_epochs):
    alpha = epoch / total_epochs
    if alpha <= 0.5:
        factor = 1.0
    elif alpha <= 0.9:
        factor = 1.0 - (alpha - 0.5) / 0.4 * 0.99
    else:
        factor = 0.01
    return factor * base_lr

def adjust_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def cyclic_learning_rate(epoch, cycle, alpha_1, alpha_2):
    def schedule(iter):
        t = ((epoch % cycle) + iter) / cycle
        if t < 0.5:
            return alpha_1 * (1.0 - 2.0 * t) + alpha_2 * 2.0 * t
        else:
            return alpha_1 * (2.0 * t - 1.0) + alpha_2 * (2.0 - 2.0 * t)
    return schedule

def save_checkpoint(dir, epoch, name='checkpoint', **kwargs):
    state = {
        'epoch': epoch,
    }
    state.update(kwargs)
    filepath = os.path.join(dir, '%s-%d.pt' % (name, epoch))
    torch.save(state, filepath)

def compute_accuracy(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total


In [None]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)


def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleEndpoint(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  startEpoch = 1
  for epoch in range(startEpoch, numEpochs+1):
    model.train()
    learning_rate = learning_rate_schedule(0.01, epoch, numEpochs)
    adjust_learning_rate(optimizer, learning_rate)
    brier_score = 0.0
    total = 0
    for iter, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Loss at epoch {} is {}'.format(epoch, loss.item()))
    print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 100
lr = 0.1
model = WideResNet(28, 4, 0.5)
model.to(device)
optimizer = torch.optim.SGD(
    filter(lambda param: param.requires_grad, model.parameters()),
    lr=lr,
    momentum=0.9,
    weight_decay=5e-4
)
t0 = time.time()
loss, brier = ensembleEndpoint(model, optimizer)

# print('NLL Loss is {}'.format(loss))
# print('Brier score is {}'.format(brier))
# print('Training endpoint time: {} seconds'.format(time.time() - t0))
# accuracy = compute_accuracy(model, testloader)
# print('Accuracy of the network on the test images: %.3f' % accuracy)


Loss at epoch 1 is 1.0212482213974
Brier score at epoch 1 is 0.387283125
Loss at epoch 2 is 1.1412148475646973
Brier score at epoch 2 is 0.2485925
Loss at epoch 3 is 0.6456626057624817
Brier score at epoch 3 is 0.18753375
Loss at epoch 4 is 1.2378424406051636
Brier score at epoch 4 is 0.156889375
Loss at epoch 5 is 0.49486008286476135
Brier score at epoch 5 is 0.1425975
Loss at epoch 6 is 0.43205174803733826
Brier score at epoch 6 is 0.1276625
Loss at epoch 7 is 0.6529366970062256
Brier score at epoch 7 is 0.11148125
Loss at epoch 8 is 1.37787926197052
Brier score at epoch 8 is 0.102945625
Loss at epoch 9 is 0.6913275718688965
Brier score at epoch 9 is 0.096009375
Loss at epoch 10 is 0.44628769159317017
Brier score at epoch 10 is 0.090103125
Loss at epoch 11 is 0.5780378580093384
Brier score at epoch 11 is 0.082464375
Loss at epoch 12 is 0.5613792538642883
Brier score at epoch 12 is 0.07808875
Loss at epoch 13 is 0.28072506189346313
Brier score at epoch 13 is 0.075476875
Loss at epoch 

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/fge_ensemble-wrn28-4-100.pth')
print('Model saved to %s.' % ('fge_ensemble-wrn28-4-100.pth'))

Model saved to fge_ensemble-wrn28-4-100.pth.


In [None]:
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
m = nn.LogSoftmax(dim=1)


def compute_brier_score(p, y):
  brier_score = torch.mean((y-torch.argmax(p, 1).float())**2)
  return brier_score

def ensembleEndpoint(model, optimizer):
  running_loss = 0.0
  running_brier = 0.0
  startEpoch = 1
  for epoch in range(startEpoch, numEpochs+1):
    model.train()
    # learning_rate = learning_rate_schedule(0.01, epoch, numEpochs)
    # adjust_learning_rate(optimizer, learning_rate)
    brier_score = 0.0
    total = 0
    for iter, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        batch_brier_score = compute_brier_score(output, y)
        brier_score += torch.sum(batch_brier_score, 0).cpu().numpy().item()
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Loss at epoch {} is {}'.format(epoch, loss.item()))
    print('Brier score at epoch {} is {}'.format(epoch, brier_score/total))
  return running_loss, brier_score/total


numEpochs = 40
lr = 0.1
model = WideResNet(28, 4, 0.5)
model.to(device)
# optimizer = torch.optim.SGD(
#     filter(lambda param: param.requires_grad, model.parameters()),
#     lr=lr,
#     momentum=0.9,
#     weight_decay=5e-4
# )
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
t0 = time.time()
loss, brier = ensembleEndpoint(model, optimizer)
print('Training time: {} seconds'.format(time.time() - t0))
# print('NLL Loss is {}'.format(loss))
# print('Brier score is {}'.format(brier))
# print('Training endpoint time: {} seconds'.format(time.time() - t0))
# accuracy = compute_accuracy(model, testloader)
# print('Accuracy of the network on the test images: %.3f' % accuracy)


Loss at epoch 1 is 1.4954572916030884
Brier score at epoch 1 is 0.401963125
Loss at epoch 2 is 1.1136302947998047
Brier score at epoch 2 is 0.2404475
Loss at epoch 3 is 1.0627468824386597
Brier score at epoch 3 is 0.17326875
Loss at epoch 4 is 0.4522371292114258
Brier score at epoch 4 is 0.14803
Loss at epoch 5 is 0.7965518236160278
Brier score at epoch 5 is 0.129105
Loss at epoch 6 is 1.1651707887649536
Brier score at epoch 6 is 0.108628125
Loss at epoch 7 is 0.5128366947174072
Brier score at epoch 7 is 0.098068125
Loss at epoch 8 is 0.4307057857513428
Brier score at epoch 8 is 0.088875
Loss at epoch 9 is 0.37643587589263916
Brier score at epoch 9 is 0.079535
Loss at epoch 10 is 0.7994173169136047
Brier score at epoch 10 is 0.071788125
Loss at epoch 11 is 0.26644033193588257
Brier score at epoch 11 is 0.06568625
Loss at epoch 12 is 0.293373703956604
Brier score at epoch 12 is 0.059528125
Loss at epoch 13 is 0.13931550085544586
Brier score at epoch 13 is 0.05539375
Loss at epoch 14 is 

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/fge_ensemble-adam-40.pth')
print('Model saved to %s.' % ('fge_ensemble-adam-40.pth'))

Model saved to fge_ensemble-adam-40.pth.


In [None]:
def ensembleFGE(model, optimizer):
  startEpoch = 1 
  cycle=4
  ensemble_size = 0
  t0 = time.time()
  for epoch in range(startEpoch, numEpochs+1):
    num_iters = len(trainloader)
    model.train()
    lr_schedule = cyclic_learning_rate(epoch, cycle, lr_1, lr_2)
    learning_rate = learning_rate_schedule(0.01, epoch, numEpochs)
    adjust_learning_rate(optimizer, learning_rate)
    total = 0
    for iter, (x, y) in enumerate(trainloader):
        lr = lr_schedule(iter / num_iters)
        adjust_learning_rate(optimizer, lr)
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Training loss at epoch {} is {}'.format(epoch, loss.item()))

    if (epoch % cycle + 1) == cycle // 2:
      ensemble_size += 1
      accuracy = compute_accuracy(model, testloader)
      print('Testing accuracy at epoch {} is {}'.format(epoch, accuracy))
      print('Training FGE time: {} seconds with {} emsemble size'.format((time.time() - t0), ensemble_size))

    if (epoch + 1) % (cycle // 2) == 0:
      save_checkpoint(
            '/content/drive/My Drive/AALTO/cs4875-research/archive/fge2/',
            startEpoch + epoch,
            name='fge',
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict()
        )
  print('Number of models in ensemble is {}'.format(ensemble_size))

  return running_loss


numEpochs = 40
training_loss = []
training_brier = []
lr_1=0.05 
lr_2=0.01
optimizer = torch.optim.SGD(
    filter(lambda param: param.requires_grad, model.parameters()),
    lr=lr_1,
    momentum=0.9,
    weight_decay=5e-4
)
loss = ensembleFGE(model, optimizer)



Training loss at epoch 1 is 0.013775290921330452
Testing accuracy at epoch 1 is 0.8789
Training FGE time: 85.59196066856384 seconds with 1 emsemble size
Training loss at epoch 2 is 0.022440077736973763
Training loss at epoch 3 is 0.05244971439242363
Training loss at epoch 4 is 0.04833020642399788
Training loss at epoch 5 is 0.45623528957366943
Testing accuracy at epoch 5 is 0.8538
Training FGE time: 398.8974268436432 seconds with 2 emsemble size
Training loss at epoch 6 is 0.34786996245384216
Training loss at epoch 7 is 0.07584859430789948
Training loss at epoch 8 is 0.26434627175331116
Training loss at epoch 9 is 0.7955345511436462
Testing accuracy at epoch 9 is 0.8341
Training FGE time: 711.6412851810455 seconds with 3 emsemble size
Training loss at epoch 10 is 0.1291847974061966
Training loss at epoch 11 is 0.14213556051254272
Training loss at epoch 12 is 0.1704719513654709
Training loss at epoch 13 is 0.02439102716743946
Testing accuracy at epoch 13 is 0.8535
Training FGE time: 102

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/fge2_ensemble-80-0112.pth')
print('Model saved to %s.' % ('fge2_ensemble-80-0112.pth'))

Model saved to fge2_ensemble-80-0112.pth.


In [6]:
# train to get data - training time and accuracy
def load_model(model, filename, device):
    model.load_state_dict(torch.load(filename, map_location=lambda storage, loc: storage))
    print('Model loaded from %s.' % filename)
    model.to(device)
    model.eval()
device = torch.device('cuda:0')
loss_func = nn.CrossEntropyLoss()
fge2_ensemble = WideResNet(28, 4, 0.5)
load_model(fge2_ensemble, '/content/drive/My Drive/AALTO/cs4875-research/archive/fge2_ensemble-80-0112.pth', device)

def ensembleFGE(model, optimizer):
  startEpoch = 1 
  cycle=4
  ensemble_size = 0
  t0 = time.time()
  for epoch in range(startEpoch, numEpochs+1):
    num_iters = len(trainloader)
    model.train()
    lr_schedule = cyclic_learning_rate(epoch, cycle, lr_1, lr_2)
    learning_rate = learning_rate_schedule(0.01, epoch, numEpochs)
    adjust_learning_rate(optimizer, learning_rate)
    total = 0
    for iter, (x, y) in enumerate(trainloader):
        lr = lr_schedule(iter / num_iters)
        adjust_learning_rate(optimizer, lr)
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Training loss at epoch {} is {}'.format(epoch, loss.item()))

    if (epoch % cycle + 1) == cycle // 2:
      ensemble_size += 1
      accuracy = compute_accuracy(model, testloader)
      print('Testing accuracy at epoch {} is {}'.format(epoch, accuracy))
      print('Training FGE time: {} seconds with {} emsemble size'.format((time.time() - t0), ensemble_size))

    if (epoch + 1) % (cycle // 2) == 0:
      save_checkpoint(
            '/content/drive/My Drive/AALTO/cs4875-research/archive/fge3/',
            startEpoch + epoch,
            name='fge',
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict()
        )
  print('Number of models in ensemble is {}'.format(ensemble_size))

  return running_loss


numEpochs = 30
training_loss = []
training_brier = []
lr_1=0.05 
lr_2=0.01
optimizer = torch.optim.SGD(
    filter(lambda param: param.requires_grad, fge2_ensemble.parameters()),
    lr=lr_1,
    momentum=0.9,
    weight_decay=5e-4
)
loss = ensembleFGE(fge2_ensemble, optimizer)



Model loaded from /content/drive/My Drive/AALTO/cs4875-research/archive/fge2_ensemble-80-0112.pth.
Training loss at epoch 1 is 0.13976126909255981
Testing accuracy at epoch 1 is 0.862
Training FGE time: 83.92398881912231 seconds with 1 emsemble size
Training loss at epoch 2 is 0.6214767098426819
Training loss at epoch 3 is 0.4708058536052704
Training loss at epoch 4 is 0.024934757500886917
Training loss at epoch 5 is 0.20621371269226074
Testing accuracy at epoch 5 is 0.8819
Training FGE time: 397.3350954055786 seconds with 2 emsemble size
Training loss at epoch 6 is 1.0988523960113525
Training loss at epoch 7 is 0.9478001594543457
Training loss at epoch 8 is 0.3338427245616913
Training loss at epoch 9 is 0.2269493043422699
Testing accuracy at epoch 9 is 0.8762
Training FGE time: 710.32301902771 seconds with 3 emsemble size
Training loss at epoch 10 is 0.5277655720710754
Training loss at epoch 11 is 0.43826571106910706
Training loss at epoch 12 is 1.214897632598877
Training loss at epoc

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/fge2_ensemble-100-0113.pth')
print('Model saved to %s.' % ('fge2_ensemble-100-0113.pth'))

In [None]:
def ensembleFGE(model, optimizer):
  startEpoch = 1 
  cycle=4
  ensemble_size = 0
  t0 = time.time()
  for epoch in range(startEpoch, numEpochs+1):
    num_iters = len(trainloader)
    model.train()
    lr_schedule = cyclic_learning_rate(epoch, cycle, lr_1, lr_2)
    learning_rate = learning_rate_schedule(0.01, epoch, numEpochs)
    adjust_learning_rate(optimizer, learning_rate)
    total = 0
    for iter, (x, y) in enumerate(trainloader):
        lr = lr_schedule(iter / num_iters)
        adjust_learning_rate(optimizer, lr)
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
    if epoch == (numEpochs-1):
      running_loss = loss.item()
    print('Training loss at epoch {} is {}'.format(epoch, loss.item()))

    if (epoch % cycle + 1) == cycle // 2:
      ensemble_size += 1
      accuracy = compute_accuracy(model, testloader)
      print('Testing accuracy at epoch {} is {}'.format(epoch, accuracy))
      print('Training FGE time: {} seconds with {} emsemble size'.format((time.time() - t0), ensemble_size))

    if (epoch + 1) % (cycle // 2) == 0:
      save_checkpoint(
            '/content/drive/My Drive/AALTO/cs4875-research/archive/fge/',
            startEpoch + epoch,
            name='fge',
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict()
        )
  print('Number of models in ensemble is {}'.format(ensemble_size))

  return running_loss


numEpochs = 40
training_loss = []
training_brier = []
lr_1=0.05 
lr_2=0.01
optimizer = torch.optim.SGD(
    filter(lambda param: param.requires_grad, model.parameters()),
    lr=lr_1,
    momentum=0.9,
    weight_decay=5e-4
)
loss = ensembleFGE(model, optimizer)



Training loss at epoch 1 is 0.4041699171066284
Testing accuracy at epoch 1 is 0.7915
Training FGE time: 86.42820262908936 seconds with 1 emsemble size
Training loss at epoch 2 is 0.7606653571128845
Training loss at epoch 3 is 1.0812578201293945
Training loss at epoch 4 is 0.26077938079833984
Training loss at epoch 5 is 0.34065520763397217
Testing accuracy at epoch 5 is 0.8569
Training FGE time: 400.45215249061584 seconds with 2 emsemble size
Training loss at epoch 6 is 0.8679198026657104
Training loss at epoch 7 is 0.3255753815174103
Training loss at epoch 8 is 0.14015129208564758
Training loss at epoch 9 is 0.5593568682670593
Testing accuracy at epoch 9 is 0.87
Training FGE time: 712.8726127147675 seconds with 3 emsemble size
Training loss at epoch 10 is 0.5582277178764343
Training loss at epoch 11 is 1.3076554536819458
Training loss at epoch 12 is 0.14766204357147217
Training loss at epoch 13 is 0.31516849994659424
Testing accuracy at epoch 13 is 0.8732
Training FGE time: 1024.930534

In [None]:
torch.save(model.state_dict(), '/content/drive/My Drive/AALTO/cs4875-research/archive/fge_ensemble-140-0111.pth')
print('Model saved to %s.' % ('fge_ensemble-140-0111.pth'))

Model saved to fge_ensemble-140-0111.pth.
