In [1]:
"""
Quick example: A small second-order optimizer with BackPACK
on the classic MNIST example from PyTorch,
https://github.com/pytorch/examples/blob/master/mnist/main.py

The optimizer we implement uses a constant damping parameter
and uses the diagonal of the GGN/Fisher matrix as a preconditioner;

```
x_{t+1} = x_t - (G_t + bI)^{-1} g_t
```

- `x_t` are the parameters of the model
- `G_t` is the diagonal of the Gauss-Newton/Fisher matrix at `x_t`
- `b` is a damping parameter
- `g_t` is the gradient

"""

import os

import numpy as np
import torch
import torchvision
# The main BackPACK functionalities
#from backpack import backpack, extend
# The diagonal GGN extension
#from backpack.extensions import DiagGGNMC
# This layer did not exist in Pytorch 1.0
#from backpack.core.layers import Flatten

# Hyperparameters
# Generic defaults for the notebook; the CIFAR-100 run defines its own
# `*_CIFAR100_DN` values in a later cell.
BATCH_SIZE = 64
STEP_SIZE_TRAIN = 1e-5  # lowered from 0.01
STEP_SIZE_HESS = 1e-4
DAMPING = 1.0
MAX_ITER = 6

# Seed PyTorch's global RNG so that runs from a fresh kernel are repeatable.
torch.manual_seed(0)

def get_accuracy(output, targets):
    """Return the fraction of samples whose arg-max over dim 1 equals the target.

    Args:
        output: score tensor; the predicted class is the arg-max along dim 1.
        targets: tensor of class indices with one entry per row of `output`.

    Returns:
        The mean of the element-wise match indicator as a Python float.
    """
    top_class = torch.argmax(output, dim=1).view_as(targets)
    hits = (top_class == targets).float()
    return hits.mean().item()

In [2]:
"""
Step 2: Create the optimizer.

After we call the backward pass with backpack,
every parameter will have a `diag_ggn_mc` field
in addition to a `grad` field.

We can use it to compute the search direction for that parameter,
```
step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
```
and update the weights
"""


class DiagGGNOptimizer(torch.optim.Optimizer):
    """Gradient descent preconditioned by a damped diagonal curvature estimate.

    Each parameter `p` is updated as

        p <- p - step_size * p.grad / (p.diag_ggn_mc + damping)

    `p.diag_ggn_mc` is not computed here; it has to be attached to every
    parameter by the backward pass before `step` is called.

    Args:
        parameters: iterable of parameters (or parameter-group dicts).
        step_size: scalar multiplier applied to the preconditioned gradient.
        damping: constant added to the curvature diagonal before dividing.
    """

    def __init__(self, parameters, step_size, damping):
        super().__init__(
            parameters,
            dict(step_size=step_size, damping=damping)
        )

    def step(self, closure=None):
        """Apply one preconditioned update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, following the `torch.optim.Optimizer.step` convention.

        Returns:
            The value returned by `closure`, or `None` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                # Parameters that took no part in the backward pass have no
                # gradient and are left untouched.
                if p.grad is None:
                    continue
                step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
                # Keyword `alpha` replaces the deprecated add_(Number, Tensor) overload.
                p.data.add_(step_direction, alpha=-group["step_size"])
        return loss


In [3]:
def get_second_order_infos(model, data_loader, optimizer, num_classes, batch_size, save_path):
    """Average the last layer's parameters and curvature diagonal over one pass.

    For every batch the function runs a forward/backward pass inside a
    `backpack(DiagGGNMC())` context, lets `optimizer` take a step, and then
    accumulates the value and the `diag_ggn_mc` attribute of the last two
    parameters found in `optimizer.param_groups`. After the pass the model's
    state dict is written to `save_path`.

    The names `extend`, `backpack`, `DiagGGNMC`, `loss_function` and
    `get_accuracy` are read from the enclosing (notebook) namespace and must be
    defined before this function is called. NOTE(review): the `backpack`
    import lines in the first cell are currently commented out.

    Args:
        model: network to analyse; it is updated in place by `optimizer.step()`.
        data_loader: iterable of `(x, y)` batches exposing a `.dataset` with a length.
        optimizer: optimizer stepped after each backward pass.
        num_classes: leading dimension used to reshape the returned tensors.
        batch_size: only used to compute the batch count shown in progress output.
        save_path: destination passed to `torch.save` for the model state dict.

    Returns:
        Tuple `(mus_w_mean, mus_b_mean, sigmas_w_mean, sigmas_b_mean)`: per-batch
        means of the second-to-last / last parameter and of their `diag_ggn_mc`,
        viewed as `(num_classes, -1)` and `(num_classes,)` respectively.
        Assumes those two parameters are the final layer's weight and bias
        -- TODO confirm against the model passed in.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    print("get the second order information for the trained model")

    extend(model)
    extend(loss_function)

    # Batches are moved to wherever the model lives (no-op on CPU).
    device = next(model.parameters()).device
    # Loop-invariant: number of batches, for progress output only.
    max_len = int(np.ceil(len(data_loader.dataset) / batch_size))

    # Running sums of [weight, bias, weight curvature, bias curvature].
    running = None
    num_batches = 0

    for hessian_idx, (x, y) in enumerate(data_loader):
        x, y = x.to(device), y.to(device)

        # Without this, gradients from earlier batches accumulate in p.grad.
        optimizer.zero_grad()

        output = model(x)
        accuracy = get_accuracy(output, y)

        with backpack(DiagGGNMC()):
            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()

        print("{}/{} batches \t acc: {:3f}".format(hessian_idx, max_len, accuracy))

        params = [p for g in optimizer.param_groups for p in g['params']]
        weight, bias = params[-2], params[-1]
        current = (
            weight.detach(),
            bias.detach(),
            weight.diag_ggn_mc.detach(),
            bias.diag_ggn_mc.detach(),
        )
        if running is None:
            # Clone: `detach()` shares storage with the parameter, which the
            # optimizer keeps modifying in place on later batches.
            running = [t.clone() for t in current]
        else:
            for total, value in zip(running, current):
                total.add_(value)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    print("saving model at: {}".format(save_path))
    torch.save(model.state_dict(), save_path)

    # The feature width is inferred (-1) instead of being hard-coded.
    mus_w_mean = (running[0] / num_batches).view(num_classes, -1)
    mus_b_mean = (running[1] / num_batches).view(num_classes)
    sigmas_w_mean = (running[2] / num_batches).view(num_classes, -1)
    sigmas_b_mean = (running[3] / num_batches).view(num_classes)
    print("mus w mean: ", mus_w_mean.size())
    print("mus b mean: ", mus_b_mean.size())
    print("sigmas w mean: ", sigmas_w_mean.size())
    print("sigmas b mean: ", sigmas_b_mean.size())

    return(mus_w_mean, mus_b_mean, sigmas_w_mean, sigmas_b_mean)


In [4]:
def get_mu_from_Dirichlet(alpha):
    """Return the centred log of `alpha`: `mu_i = log(alpha_i) - mean_j log(alpha_j)`.

    Args:
        alpha: 1-D tensor (or sequence) of positive values.

    Returns:
        1-D tensor with one entry per element of `alpha`. Floating-point
        inputs keep their dtype and device.
    """
    alpha = torch.as_tensor(alpha)
    if not alpha.is_floating_point():
        alpha = alpha.to(torch.get_default_dtype())

    K = len(alpha)
    log_alpha = torch.log(alpha)
    if K == 0:
        return log_alpha
    # The sum of logs is computed once instead of once per element.
    return log_alpha - 1 / K * torch.sum(log_alpha)

def get_Sigma_from_Dirichlet(alpha):
    """Build the K x K matrix with entries

        Sigma[k][l] = delta_kl / alpha_k
                      - 1/K * (1/alpha_k + 1/alpha_l - 1/K * sum_j 1/alpha_j)

    Args:
        alpha: 1-D tensor (or sequence) of K non-zero values.

    Returns:
        Tensor of shape (K, K). Floating-point inputs keep their dtype and device.
    """
    alpha = torch.as_tensor(alpha)
    if not alpha.is_floating_point():
        alpha = alpha.to(torch.get_default_dtype())

    K = len(alpha)
    inv_alpha = 1 / alpha
    sum_of_inv = 1 / K * torch.sum(inv_alpha)
    # Broadcasting a column against a row yields every (k, l) pair at once.
    pairwise = inv_alpha.unsqueeze(1) + inv_alpha.unsqueeze(0) - sum_of_inv
    return torch.diag(inv_alpha) - 1 / K * pairwise


def get_alpha_from_Normal(mu, Sigma):
    """Compute, for every index k,

        alpha_k = 1/Sigma[k][k] * (1 - 2/K + exp(mu_k)/K**2 * sum_j exp(-mu_j))

    Only the diagonal of `Sigma` is read.

    Args:
        mu: 1-D tensor (or sequence) of length K.
        Sigma: 2-D tensor (or nested sequence) of shape (K, K).

    Returns:
        1-D tensor of length K.
    """
    mu = torch.as_tensor(mu)
    Sigma = torch.as_tensor(Sigma)

    K = len(Sigma[0])
    sum_exp = torch.sum(torch.exp(-1 * mu))
    variances = torch.diagonal(Sigma)
    return 1 / variances * (1 - 2 / K + torch.exp(mu) / K**2 * sum_exp)

def get_Gaussian_output(x, mu_w, mu_b, sigma_w, sigma_b):
    """Push a flattened input through a linear layer with diagonal parameter variance.

    For every output row i:

        mu_i    = mu_w[i] . x + mu_b[i]
        Sigma_ii = sum_j sigma_w[i][j] * x_j**2 + sigma_b[i]

    Args:
        x: input tensor; it is flattened to 1-D before use.
        mu_w: 2-D tensor, one row per output.
        mu_b: 1-D tensor, one entry per output.
        sigma_w: 2-D tensor with the same shape as `mu_w`.
        sigma_b: 1-D tensor with the same shape as `mu_b`.

    Returns:
        Tuple `(mu, Sigma)`: `mu` is 1-D and `Sigma` is the square diagonal
        matrix of the per-output values above. Both are detached copies.
    """
    x = x.reshape(-1)

    mu = (torch.matmul(mu_w, x) + mu_b).detach().clone()

    # x^T diag(s) x equals s . (x * x); this avoids building a dense
    # len(x) x len(x) matrix for every output row.
    per_class_sigmas = torch.matmul(sigma_w, x * x) + sigma_b
    Sigma = torch.diag(per_class_sigmas.detach().clone())
    return(mu, Sigma)

# Round-trip sanity check: mapping alpha to (mu, Sigma) and back should
# reproduce alpha. Compared with a tolerance because the intermediate values
# are floating point and exact equality depends on rounding.
alpha_test = torch.Tensor([1.3, 1.3])
mu_test = get_mu_from_Dirichlet(alpha_test)
Sigma_test = get_Sigma_from_Dirichlet(alpha_test)
alpha_end = get_alpha_from_Normal(mu_test, Sigma_test)
assert torch.allclose(alpha_test, alpha_end), (alpha_test, alpha_end)

In [5]:
# Settings for the CIFAR-100 / DenseNet run.
BATCH_SIZE_CIFAR100_DN = 128
HESS_BATCH_SIZE_CIFAR100_DN = 128
TEST_BATCH_SIZE_CIFAR100_DN = 1
MAX_ITER_CIFAR100_DN = 70
# Written out directly: `10e-6` is 1e-5 and `10e-8` is 1e-7.
STEP_SIZE_TRAIN_CIFAR100_DN = 1e-5
STEP_SIZE_HESS_CIFAR100_DN = 1e-7

In [6]:
# One preprocessing pipeline for every split: convert to tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR100(
    root='./cifar100', train=True, download=True, transform=transform)

# 80 % of the official training set is used for training, the rest for validation.
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
CIFAR100_train_dataset, CIFAR100_val_dataset = torch.utils.data.random_split(
    trainset, [train_size, val_size])

CIFAR100_train_loader = torch.utils.data.DataLoader(
    CIFAR100_train_dataset,
    batch_size=BATCH_SIZE_CIFAR100_DN,
    shuffle=True,
)

CIFAR100_val_loader = torch.utils.data.DataLoader(
    CIFAR100_val_dataset,
    batch_size=BATCH_SIZE_CIFAR100_DN,
    shuffle=True,
)

testset = torchvision.datasets.CIFAR100(
    root='./cifar100', train=False, download=True, transform=transform)
CIFAR100_test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=TEST_BATCH_SIZE_CIFAR100_DN,
    shuffle=False,
)

# Names of the 100 CIFAR-100 classes, in alphabetical order.
# Presumably position i is the name for integer label i of the dataset
# -- TODO confirm against torchvision's `CIFAR100.classes`.
CIFAR100_classes = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
    'worm'
]

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# just rip the code of densenet on pytorch
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torchvision.models.utils import load_state_dict_from_url


# Names this cell would export if it were a standalone module.
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

# Checkpoint URL per architecture name; `_densenet` looks the name up here
# when it is asked for pretrained weights.
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Sequential):
    """One dense layer: BN-ReLU-1x1 conv bottleneck, then BN-ReLU-3x3 conv.

    `forward` accepts any number of feature tensors, which are concatenated
    along dim 1 before the bottleneck. With `memory_efficient=True` the
    bottleneck is evaluated through `torch.utils.checkpoint` whenever any
    input requires grad.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super().__init__()
        bottleneck_channels = bn_size * growth_rate
        stages = (
            ('norm1', nn.BatchNorm2d(num_input_features)),
            ('relu1', nn.ReLU(inplace=True)),
            ('conv1', nn.Conv2d(num_input_features, bottleneck_channels,
                                kernel_size=1, stride=1, bias=False)),
            ('norm2', nn.BatchNorm2d(bottleneck_channels)),
            ('relu2', nn.ReLU(inplace=True)),
            ('conv2', nn.Conv2d(bottleneck_channels, growth_rate,
                                kernel_size=3, stride=1, padding=1, bias=False)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        bottleneck = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        use_checkpoint = self.memory_efficient and any(
            feature.requires_grad for feature in prev_features)
        if use_checkpoint:
            squeezed = cp.checkpoint(bottleneck, *prev_features)
        else:
            squeezed = bottleneck(*prev_features)

        grown = self.conv2(self.relu2(self.norm2(squeezed)))
        if self.drop_rate > 0:
            grown = F.dropout(grown, p=self.drop_rate, training=self.training)
        return grown


class _DenseBlock(nn.Module):
    """Stack of `_DenseLayer`s in which each layer sees all earlier feature maps.

    Layer number n (1-based, registered as `denselayer<n>`) receives
    `num_input_features + (n - 1) * growth_rate` input channels. The block
    output is the channel-wise concatenation of the block input and every
    layer's output.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super().__init__()
        for position in range(1, num_layers + 1):
            in_channels = num_input_features + (position - 1) * growth_rate
            self.add_module(
                'denselayer%d' % position,
                _DenseLayer(
                    in_channels,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, init_features):
        collected = [init_features]
        for dense_layer in self.children():
            collected.append(dense_layer(*collected))
        return torch.cat(collected, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Compared with the upstream implementation, ``forward`` takes an extra
    ``phi`` flag that returns the pooled features instead of the logits.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):

        super().__init__()

        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
        stem = OrderedDict()
        stem['conv0'] = nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                  padding=3, bias=False)
        stem['norm0'] = nn.BatchNorm2d(num_init_features)
        stem['relu0'] = nn.ReLU(inplace=True)
        stem['pool0'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.features = nn.Sequential(stem)

        # Dense blocks; every block but the last is followed by a transition
        # that halves the channel count.
        channels = num_init_features
        last_index = len(block_config) - 1
        for index, depth in enumerate(block_config):
            self.features.add_module(
                'denseblock%d' % (index + 1),
                _DenseBlock(
                    num_layers=depth,
                    num_input_features=channels,
                    bn_size=bn_size,
                    growth_rate=growth_rate,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )
            channels = channels + depth * growth_rate
            if index != last_index:
                self.features.add_module(
                    'transition%d' % (index + 1),
                    _Transition(num_input_features=channels,
                                num_output_features=channels // 2),
                )
                channels = channels // 2

        # Closing batch norm and the linear classifier head.
        self.features.add_module('norm5', nn.BatchNorm2d(channels))
        self.classifier = nn.Linear(channels, num_classes)

        # Weight initialisation as in the torch repository.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x, phi=False):
        """Run the network; return pooled features if `phi` is truthy, else logits."""
        pooled = F.adaptive_avg_pool2d(
            F.relu(self.features(x), inplace=True), (1, 1))
        flat = torch.flatten(pooled, 1)
        if phi:
            # Feature vector fed to the classifier, without applying it.
            return flat
        return self.classifier(flat)


def _load_state_dict(model, model_url, progress):
    """Download a checkpoint, rename legacy dense-layer keys, and load it into `model`.

    Older checkpoints use keys such as 'norm.1' / 'conv.2' inside dense
    layers. Module names may no longer contain '.', so the dot between the
    layer type and its index is removed ('norm.1.weight' -> 'norm1.weight')
    before the state dict is loaded.
    """
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys because the dict is modified in the loop.
    for old_key in list(state_dict):
        found = legacy_key.match(old_key)
        if found is None:
            continue
        prefix, suffix = found.groups()
        state_dict[prefix + suffix] = state_dict.pop(old_key)
    model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
              **kwargs):
    """Build a `DenseNet` and, if `pretrained` is truthy, load the weights registered for `arch`.

    Extra keyword arguments are forwarded to the `DenseNet` constructor.
    """
    net = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if not pretrained:
        return net
    _load_state_dict(net, model_urls[arch], progress)
    return net


def densenet121(pretrained=False, progress=True, **kwargs):
    r"""Build the Densenet-121 variant described in
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint registered for 'densenet121'
        progress (bool): forwarded to the checkpoint download (progress display)
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet121',
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )



def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Build the Densenet-161 variant described in
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint registered for 'densenet161'
        progress (bool): forwarded to the checkpoint download (progress display)
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet161',
        growth_rate=48,
        block_config=(6, 12, 36, 24),
        num_init_features=96,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )



def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Build the Densenet-169 variant described in
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint registered for 'densenet169'
        progress (bool): forwarded to the checkpoint download (progress display)
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet169',
        growth_rate=32,
        block_config=(6, 12, 32, 32),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )



def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Build the Densenet-201 variant described in
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint registered for 'densenet201'
        progress (bool): forwarded to the checkpoint download (progress display)
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet201',
        growth_rate=32,
        block_config=(6, 12, 48, 32),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )

In [8]:
# Pretrained DenseNet-201 on the GPU.
# NOTE(review): no `num_classes` is passed, so the classifier keeps the
# constructor default of 1000 outputs while the data has 100 classes -- confirm
# this is intended.
CIFAR100_model_DN = densenet201(pretrained=True).cuda()
loss_function = torch.nn.CrossEntropyLoss()

# Plain SGD; Adam with the same learning rate was the earlier alternative:
#   torch.optim.Adam(CIFAR100_model_DN.parameters(), lr=STEP_SIZE_TRAIN_CIFAR100_DN)
CIFAR100_DN_train_optimizer = torch.optim.SGD(
    CIFAR100_model_DN.parameters(),
    lr=STEP_SIZE_TRAIN_CIFAR100_DN,
)

In [None]:
# Fine-tune the model and keep the checkpoint with the best mean validation accuracy.
CIFAR100_DN_PATH = "models/CIFAR100_DN_model_SGD.pth"
current_best_mva = 0

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CIFAR100_DN_PATH), exist_ok=True)

# Batch counts per pass; only used in the progress output.
train_max_len = int(np.ceil(len(CIFAR100_train_loader.dataset) / BATCH_SIZE_CIFAR100_DN))
val_max_len = int(np.ceil(len(CIFAR100_val_loader.dataset) / BATCH_SIZE_CIFAR100_DN))

# `epoch` instead of `iter`, which would shadow the builtin for later cells.
for epoch in range(MAX_ITER_CIFAR100_DN):
    # --- training pass ---
    CIFAR100_model_DN.train()
    train_acc = []
    for batch_idx, (x, y) in enumerate(CIFAR100_train_loader):
        x, y = x.cuda(), y.cuda()

        # Gradients accumulate across backward() calls unless they are cleared.
        CIFAR100_DN_train_optimizer.zero_grad()

        output = CIFAR100_model_DN(x, phi=False)

        accuracy = get_accuracy(output, y)
        train_acc.append(accuracy)

        loss = loss_function(output, y)
        loss.backward()
        CIFAR100_DN_train_optimizer.step()

        print(
            "Iteration {}; {}/{} \t".format(epoch, batch_idx, train_max_len) +
            "Minibatch Loss %.3f  " % (loss.item()) +
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )

    # --- validation pass ---
    # eval() stops BatchNorm from updating its running statistics on
    # validation data; no_grad() avoids building an unused autograd graph.
    CIFAR100_model_DN.eval()
    val_acc = []
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(CIFAR100_val_loader):
            x, y = x.cuda(), y.cuda()
            output = CIFAR100_model_DN(x, phi=False)

            accuracy = get_accuracy(output, y)
            val_acc.append(accuracy)
            print("Iteration {}; {}/{} \t".format(epoch, batch_idx, val_max_len) +
                  "Validation accuracy: {:.03f}".format(accuracy))

    mva = np.mean(val_acc)
    mta = np.mean(train_acc)
    print("mean train acc: ", mta)
    print("mean val acc: ", mva)

    # Checkpoint only when the validation accuracy improves.
    if mva > current_best_mva:
        current_best_mva = mva
        print("saving model at: {}".format(CIFAR100_DN_PATH))
        torch.save(CIFAR100_model_DN.state_dict(), CIFAR100_DN_PATH)



Iteration 0; 0/313 	Minibatch Loss 9.373  Accuracy 1%
Iteration 0; 1/313 	Minibatch Loss 9.071  Accuracy 0%
