In [1]:
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import torch as t
from torchvision import models
from torch.utils.data import DataLoader

from bnn_priors.models import DenseNet
from bnn_priors.data import CIFAR10

In [15]:
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

## Testing ResNet18 model

In [2]:
# Based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def PreActResNet18():
    return PreActResNet(PreActBlock, [2,2,2,2])

def PreActResNet34():
    return PreActResNet(PreActBlock, [3,4,6,3])

def PreActResNet50():
    return PreActResNet(PreActBottleneck, [3,4,6,3])

def PreActResNet101():
    return PreActResNet(PreActBottleneck, [3,4,23,3])

def PreActResNet152():
    return PreActResNet(PreActBottleneck, [3,8,36,3])

In [7]:
data = CIFAR10()

dataloader_train = DataLoader(data.norm.train, batch_size=32, shuffle=True, drop_last=True)
dataloader_test = DataLoader(data.norm.train, batch_size=32, shuffle=True, drop_last=True)

In [22]:
num_epochs = 5
lr = 5e-4

net = PreActResNet18().cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

for epoch in range(num_epochs):
    with tqdm(desc=f"Epoch {epoch}", total=len(dataloader_train), leave=False) as pbar:
        for batch_x, batch_y in dataloader_train:
            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
            optimizer.zero_grad()
            y_pred = net(batch_x)
            loss = criterion(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            pbar.update()
            pbar.set_postfix({"loss": f"{loss.item():.2f}"})
            
    total_acc = 0.
    num_batches = 0
    for batch_x, batch_y in dataloader_test:
        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
        y_pred = net(batch_x)
        total_acc += y_pred.argmax(axis=1).eq(batch_y).float().mean().item()
        num_batches += 1
    acc = total_acc/num_batches
    print(f"Epoch {epoch}: Test accuracy = {acc*100:.1f} %")

Epoch 1:   0%|          | 7/1562 [00:00<00:40, 38.72it/s, loss=1.24]   

Epoch 0: Test accuracy = 58.5 %


Epoch 2:   0%|          | 6/1562 [00:00<00:42, 36.94it/s, loss=0.70]   

Epoch 1: Test accuracy = 73.8 %


Epoch 3:   0%|          | 7/1562 [00:00<00:40, 38.21it/s, loss=0.55]   

Epoch 2: Test accuracy = 80.2 %


Epoch 4:   0%|          | 7/1562 [00:00<00:38, 40.07it/s, loss=0.40]   

Epoch 3: Test accuracy = 84.7 %


                                                                       

Epoch 4: Test accuracy = 86.8 %


## Building Bayesian CNNs analogous to the DenseNet class