In [75]:
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import DataLoader

from bnn_priors.models import DenseNet
from bnn_priors.data import CIFAR10
from bnn_priors import prior
from bnn_priors.models import RegressionModel

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

## Testing ResNet18 model

In [11]:
# Based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True):
        super(PreActBlock, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = batchnorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out += shortcut
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, bn=True):
        super(PreActBottleneck, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = batchnorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = batchnorm(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out = self.bn3(out)
        out = self.conv3(F.relu(out))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, bn=True):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.bn = bn

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, bn=self.bn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def PreActResNet18(bn=True):
    return PreActResNet(PreActBlock, [2,2,2,2], bn=bn)

def PreActResNet34(bn=True):
    return PreActResNet(PreActBlock, [3,4,6,3], bn=bn)

def PreActResNet50(bn=True):
    return PreActResNet(PreActBottleneck, [3,4,6,3], bn=bn)

def PreActResNet101(bn=True):
    return PreActResNet(PreActBottleneck, [3,4,23,3], bn=bn)

def PreActResNet152(bn=True):
    return PreActResNet(PreActBottleneck, [3,8,36,3], bn=bn)

In [9]:
data = CIFAR10()

dataloader_train = DataLoader(data.norm.train, batch_size=32, shuffle=True, drop_last=True)
dataloader_test = DataLoader(data.norm.test, batch_size=32, shuffle=True, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
num_epochs = 5
lr = 5e-4

net = PreActResNet18(bn=True).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

for epoch in range(num_epochs):
    net.train()
    with tqdm(desc=f"Epoch {epoch}", total=len(dataloader_train), leave=False) as pbar:
        for batch_x, batch_y in dataloader_train:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            y_pred = net(batch_x)
            loss = criterion(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            pbar.update()
            pbar.set_postfix({"loss": f"{loss.item():.2f}"})
            
    total_acc = 0.
    num_batches = 0
    net.eval()
    with torch.no_grad():
        for batch_x, batch_y in dataloader_test:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            y_pred = net(batch_x)
            total_acc += y_pred.argmax(axis=1).eq(batch_y).float().mean().item()
            num_batches += 1
    acc = total_acc/num_batches
    print(f"Epoch {epoch}: Test accuracy = {acc*100:.1f} %")

Epoch 1:   0%|          | 7/1562 [00:00<00:38, 40.04it/s, loss=1.40]   

Epoch 0: Test accuracy = 56.4 %


Epoch 2:   0%|          | 7/1562 [00:00<00:40, 38.02it/s, loss=1.18]   

Epoch 1: Test accuracy = 69.6 %


Epoch 3:   0%|          | 7/1562 [00:00<00:39, 39.38it/s, loss=0.63]   

Epoch 2: Test accuracy = 77.6 %


Epoch 4:   0%|          | 7/1562 [00:00<00:39, 39.86it/s, loss=0.42]   

Epoch 3: Test accuracy = 77.0 %


                                                                       

Epoch 4: Test accuracy = 79.4 %


## Building Bayesian CNNs analogous to the DenseNet class

In [16]:
class Linear(nn.Linear):
    def __init__(self, weight_prior, bias_prior=None):
        nn.Module.__init__(self)
        (self.out_features, self.in_features) = weight_prior.p.shape
        self.weight_prior = weight_prior
        self.bias_prior = bias_prior

    @property
    def weight(self):
        return self.weight_prior()

    @property
    def bias(self):
        return (None if self.bias_prior is None else self.bias_prior())

In [18]:
def LinearPrior(in_dim, out_dim, prior_w=prior.Normal, loc_w=0., std_w=1.,
                     prior_b=prior.Normal, loc_b=0., std_b=1., scaling_fn=None):
    if scaling_fn is None:
        def scaling_fn(std, dim):
            return std/dim**0.5
    return Linear(prior_w((out_dim, in_dim), loc_w, scaling_fn(std_w, in_dim)),
                  prior_b((out_dim,), 0., std_b))

In [61]:
class Conv2d(nn.Conv2d):
    def __init__(self, weight_prior, bias_prior=None, stride=1,
            padding=0, dilation=1, groups=1, padding_mode='zeros'):
        nn.Module.__init__(self)
        
        self.stride = nn.modules.utils._pair(stride)
        self.padding = nn.modules.utils._pair(padding)
        self.dilation = nn.modules.utils._pair(dilation)
        self.groups = groups
        self.padding_mode = padding_mode
        self.transposed = False
        self.output_padding = nn.modules.utils._pair(0)
        
        (self.out_channels, in_channels, ksize_0, ksize_1) = weight_prior.p.shape
        self.in_channels = in_channels * self.groups
        self.kernel_size = (ksize_0, ksize_1)
        self.weight_prior = weight_prior
        self.bias_prior = bias_prior

    @property
    def weight(self):
        return self.weight_prior()

    @property
    def bias(self):
        return (None if self.bias_prior is None else self.bias_prior())

In [78]:
def Conv2dPrior(in_channels, out_channels, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, padding_mode='zeros',
            prior_w=prior.Normal, loc_w=0., std_w=1., prior_b=prior.Normal,
            loc_b=0., std_b=1., scaling_fn=None):
    if scaling_fn is None:
        def scaling_fn(std, dim):
            return std/dim**0.5
    kernel_size = nn.modules.utils._pair(kernel_size)
    bias_prior = prior_b((out_channels,), 0., std_b) if prior_b is not None else None
    return Conv2d(weight_prior=prior_w((out_channels, in_channels//groups, kernel_size[0], kernel_size[1]),
                                       loc_w, scaling_fn(std_w, in_channels)),
                  bias_prior=bias_prior,
                 stride=stride, padding=padding, dilation=dilation,
                  groups=groups, padding_mode=padding_mode)

In [79]:
conv_test = Conv2dPrior(3, 16, 3, padding=1)

In [80]:
conv_test.to(device)

Conv2d(
  3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
  (weight_prior): Normal()
  (bias_prior): Normal()
)

In [81]:
conv_test(batch_x).shape

torch.Size([32, 16, 32, 32])

In [91]:
class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True,
                 prior_w=prior.Normal, loc_w=0., std_w=2**.5,
                 prior_b=prior.Normal, loc_b=0., std_b=1.,
                scaling_fn=None):
        super(PreActBlock, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = Conv2dPrior(in_planes, planes, kernel_size=3, stride=stride, padding=1,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)
        self.bn2 = batchnorm(planes)
        self.conv2 = Conv2dPrior(planes, planes, kernel_size=3, stride=1, padding=1,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, bn=True,
                 prior_w=prior.Normal, loc_w=0., std_w=2**.5,
                 prior_b=prior.Normal, loc_b=0., std_b=1.,
                scaling_fn=None):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.bn = bn
        self.prior_w = prior_w
        self.loc_w = loc_w
        self.std_w = std_w
        self.prior_b = prior_b
        self.loc_b = loc_b
        self.std_b = std_b
        self.scaling_fn = scaling_fn

        self.conv1 = Conv2dPrior(3, 64, kernel_size=3, stride=1, padding=1, prior_b=None,
                           prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                           scaling_fn=self.scaling_fn)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = LinearPrior(512*block.expansion, num_classes,
                            prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                            prior_b=self.prior_b, loc_b=self.loc_b, std_b=self.std_b,
                            scaling_fn=self.scaling_fn)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, bn=self.bn,
                                prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                                prior_b=self.prior_b, loc_b=self.loc_b, std_b=self.std_b,
                                scaling_fn=self.scaling_fn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def PreActResNet18(noise_std=1.,
             prior_w=prior.Normal, loc_w=0., std_w=2**.5,
             prior_b=prior.Normal, loc_b=0., std_b=1.,
            scaling_fn=None, bn=True):
    return RegressionModel(PreActResNet(PreActBlock,
                                        [2,2,2,2], bn=bn,
                                        prior_w=prior_w,
                                       loc_w=loc_w,
                                       std_w=std_w,
                                       prior_b=prior_b,
                                       loc_b=loc_b,
                                       std_b=std_b,
                                       scaling_fn=scaling_fn,), noise_std)

def PreActResNet34(noise_std=1.,
             prior_w=prior.Normal, loc_w=0., std_w=2**.5,
             prior_b=prior.Normal, loc_b=0., std_b=1.,
            scaling_fn=None, bn=True):
    return RegressionModel(PreActResNet(PreActBlock,
                                        [3,4,6,3], bn=bn,
                                        prior_w=prior_w,
                                       loc_w=loc_w,
                                       std_w=std_w,
                                       prior_b=prior_b,
                                       loc_b=loc_b,
                                       std_b=std_b,
                                       scaling_fn=scaling_fn,), noise_std)

In [95]:
net_test = PreActResNet18(bn=False)

In [96]:
net_test.to(device)

RegressionModel(
  (net): PreActResNet(
    (conv1): Conv2d(
      3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (weight_prior): Normal()
    )
    (layer1): Sequential(
      (0): PreActBlock(
        (bn1): Identity()
        (conv1): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (weight_prior): Normal()
        )
        (bn2): Identity()
        (conv2): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (weight_prior): Normal()
        )
      )
      (1): PreActBlock(
        (bn1): Identity()
        (conv1): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (weight_prior): Normal()
        )
        (bn2): Identity()
        (conv2): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (weight_prior): Normal()
        )
      )
    )
    (layer2): Sequential(
 

In [97]:
net_test(batch_x).mean

tensor([[-1.2338e+08,  2.4759e+07,  5.2194e+07, -2.1589e+07, -9.9581e+07,
          5.3191e+07,  3.0906e+07, -4.0408e+07, -1.6005e+07,  7.4893e+07],
        [-1.1931e+08,  1.0338e+07,  3.8562e+07, -4.2010e+07, -8.4916e+07,
          5.5311e+07,  8.9196e+06, -3.5700e+06, -4.6934e+06,  6.2197e+07],
        [-2.0265e+08,  2.5444e+07,  5.9272e+07, -6.0113e+07, -1.4750e+08,
          1.0304e+08,  6.5790e+07, -1.4920e+07, -8.4617e+05,  9.4435e+07],
        [-1.9261e+08,  3.0068e+07,  6.2641e+07, -6.2738e+07, -1.3065e+08,
          1.3270e+08,  4.1909e+07, -4.4233e+07, -2.7670e+07,  1.1862e+08],
        [-1.3359e+08,  1.0180e+07,  3.4459e+07, -3.2965e+07, -8.7067e+07,
          7.0907e+07,  2.6929e+07, -1.4885e+07, -1.8933e+07,  7.6678e+07],
        [-1.9609e+08,  3.5271e+07,  6.6672e+07, -4.2920e+07, -1.6181e+08,
          9.5549e+07,  4.1004e+07, -3.4805e+07, -1.2914e+07,  8.2691e+07],
        [-1.3120e+08,  2.5098e+07,  4.3022e+07, -3.4870e+07, -1.1319e+08,
          5.9634e+07,  3.6942e+0