In [1]:
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import DataLoader

from bnn_priors.models import DenseNet
from bnn_priors.data import CIFAR10
from bnn_priors import prior
from bnn_priors.models import RegressionModel, LinearPrior, ClassificationModel

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

## Testing ResNet18 model

In [3]:
# Based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True):
        super(PreActBlock, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = batchnorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out += shortcut
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, bn=True):
        super(PreActBottleneck, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = batchnorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = batchnorm(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out = self.bn3(out)
        out = self.conv3(F.relu(out))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, bn=True):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.bn = bn

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, bn=self.bn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def PreActResNet18(bn=True):
    return PreActResNet(PreActBlock, [2,2,2,2], bn=bn)

def PreActResNet34(bn=True):
    return PreActResNet(PreActBlock, [3,4,6,3], bn=bn)

def PreActResNet50(bn=True):
    return PreActResNet(PreActBottleneck, [3,4,6,3], bn=bn)

def PreActResNet101(bn=True):
    return PreActResNet(PreActBottleneck, [3,4,23,3], bn=bn)

def PreActResNet152(bn=True):
    return PreActResNet(PreActBottleneck, [3,8,36,3], bn=bn)

In [4]:
data = CIFAR10(device=device)

dataloader_train = DataLoader(data.norm.train, batch_size=32, shuffle=True, drop_last=True)
dataloader_test = DataLoader(data.norm.test, batch_size=32, shuffle=True, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
num_epochs = 5
lr = 5e-4

net = PreActResNet18(bn=True).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

for epoch in range(num_epochs):
    net.train()
    with tqdm(desc=f"Epoch {epoch}", total=len(dataloader_train), leave=False) as pbar:
        for batch_x, batch_y in dataloader_train:
            optimizer.zero_grad()
            y_pred = net(batch_x)
            loss = criterion(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            pbar.update()
            pbar.set_postfix({"loss": f"{loss.item():.2f}"})
            
    total_acc = 0.
    num_batches = 0
    net.eval()
    with torch.no_grad():
        for batch_x, batch_y in dataloader_test:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            y_pred = net(batch_x)
            total_acc += y_pred.argmax(axis=1).eq(batch_y).float().mean().item()
            num_batches += 1
    acc = total_acc/num_batches
    print(f"Epoch {epoch}: Test accuracy = {acc*100:.1f} %")

Epoch 1:   0%|          | 7/1562 [00:00<00:40, 37.93it/s, loss=0.92]   

Epoch 0: Test accuracy = 54.6 %


Epoch 2:   0%|          | 6/1562 [00:00<00:41, 37.09it/s, loss=0.70]   

Epoch 1: Test accuracy = 68.9 %


Epoch 3:   0%|          | 6/1562 [00:00<00:42, 36.94it/s, loss=0.60]   

Epoch 2: Test accuracy = 76.6 %


Epoch 4:   0%|          | 6/1562 [00:00<00:48, 31.95it/s, loss=0.60]   

Epoch 3: Test accuracy = 77.8 %


                                                                       

Epoch 4: Test accuracy = 80.5 %


## Building Bayesian CNNs analogous to the DenseNet class

In [20]:
class Conv2d(nn.Conv2d):
    def __init__(self, weight_prior, bias_prior=None, stride=1,
            padding=0, dilation=1, groups=1, padding_mode='zeros'):
        nn.Module.__init__(self)
        
        self.stride = nn.modules.utils._pair(stride)
        self.padding = nn.modules.utils._pair(padding)
        self.dilation = nn.modules.utils._pair(dilation)
        self.groups = groups
        self.padding_mode = padding_mode
        self.transposed = False
        self.output_padding = nn.modules.utils._pair(0)
        
        (self.out_channels, in_channels, ksize_0, ksize_1) = weight_prior.p.shape
        self.in_channels = in_channels * self.groups
        self.kernel_size = (ksize_0, ksize_1)
        self.weight_prior = weight_prior
        self.bias_prior = bias_prior

    @property
    def weight(self):
        return self.weight_prior()

    @property
    def bias(self):
        return (None if self.bias_prior is None else self.bias_prior())

In [21]:
def Conv2dPrior(in_channels, out_channels, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, padding_mode='zeros',
            prior_w=prior.Normal, loc_w=0., std_w=1., prior_b=prior.Normal,
            loc_b=0., std_b=1., scaling_fn=None):
    if scaling_fn is None:
        def scaling_fn(std, dim):
            return std/dim**0.5
    kernel_size = nn.modules.utils._pair(kernel_size)
    bias_prior = prior_b((out_channels,), 0., std_b) if prior_b is not None else None
    return Conv2d(weight_prior=prior_w((out_channels, in_channels//groups, kernel_size[0], kernel_size[1]),
                                       loc_w, scaling_fn(std_w, in_channels)),
                  bias_prior=bias_prior,
                 stride=stride, padding=padding, dilation=dilation,
                  groups=groups, padding_mode=padding_mode)

In [22]:
class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True,
                 prior_w=prior.Normal, loc_w=0., std_w=2**.5,
                 prior_b=prior.Normal, loc_b=0., std_b=1.,
                scaling_fn=None):
        super(PreActBlock, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = Conv2dPrior(in_planes, planes, kernel_size=3, stride=stride, padding=1,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)
        self.bn2 = batchnorm(planes)
        self.conv2 = Conv2dPrior(planes, planes, kernel_size=3, stride=1, padding=1,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                Conv2dPrior(in_planes, self.expansion*planes, kernel_size=1, stride=stride,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, bn=True,
                 prior_w=prior.Normal, loc_w=0., std_w=2**.5,
                 prior_b=prior.Normal, loc_b=0., std_b=1.,
                scaling_fn=None):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.bn = bn
        self.prior_w = prior_w
        self.loc_w = loc_w
        self.std_w = std_w
        self.prior_b = prior_b
        self.loc_b = loc_b
        self.std_b = std_b
        self.scaling_fn = scaling_fn

        self.conv1 = Conv2dPrior(3, 64, kernel_size=3, stride=1, padding=1, prior_b=None,
                           prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                           scaling_fn=self.scaling_fn)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = LinearPrior(512*block.expansion, num_classes,
                            prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                            prior_b=self.prior_b, loc_b=self.loc_b, std_b=self.std_b,
                            scaling_fn=self.scaling_fn)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, bn=self.bn,
                                prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                                prior_b=self.prior_b, loc_b=self.loc_b, std_b=self.std_b,
                                scaling_fn=self.scaling_fn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def PreActResNet18(softmax_temp=1.,
             prior_w=prior.Normal, loc_w=0., std_w=2**.5,
             prior_b=prior.Normal, loc_b=0., std_b=1.,
            scaling_fn=None, bn=True):
    return ClassificationModel(PreActResNet(PreActBlock,
                                        [2,2,2,2], bn=bn,
                                        prior_w=prior_w,
                                       loc_w=loc_w,
                                       std_w=std_w,
                                       prior_b=prior_b,
                                       loc_b=loc_b,
                                       std_b=std_b,
                                       scaling_fn=scaling_fn,), softmax_temp)

def PreActResNet34(softmax_temp=1.,
             prior_w=prior.Normal, loc_w=0., std_w=2**.5,
             prior_b=prior.Normal, loc_b=0., std_b=1.,
            scaling_fn=None, bn=True):
    return ClassificationModel(PreActResNet(PreActBlock,
                                        [3,4,6,3], bn=bn,
                                        prior_w=prior_w,
                                       loc_w=loc_w,
                                       std_w=std_w,
                                       prior_b=prior_b,
                                       loc_b=loc_b,
                                       std_b=std_b,
                                       scaling_fn=scaling_fn,), softmax_temp)

## Test new model class

In [3]:
from bnn_priors.models import PreActResNet18

In [21]:
data = CIFAR10(device=device)

dataloader_train = DataLoader(data.norm.train, batch_size=32, shuffle=True, drop_last=True)
dataloader_test = DataLoader(data.norm.test, batch_size=32, shuffle=False, drop_last=False)

Files already downloaded and verified
Files already downloaded and verified


#### SGD training

In [29]:
num_epochs = 5
lr = 5e-4

for criterion in ["nll", "potential"]:
    print(f"Training criterion: {criterion}")
    net = PreActResNet18(bn=True).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        net.train()
        with tqdm(desc=f"Epoch {epoch}", total=len(dataloader_train), leave=False) as pbar:
            for batch_x, batch_y in dataloader_train:
                optimizer.zero_grad()
                if criterion == "nll":
                    loss = -net.log_likelihood_avg(batch_x, batch_y, len(dataloader_train.dataset))
                elif criterion == "potential":
                    loss = net.potential_avg(batch_x, batch_y, len(dataloader_train.dataset))
                loss.backward()
                optimizer.step()
                pbar.update()
                pbar.set_postfix({"loss": f"{loss.item():.2f}"})

        total_acc = 0.
        num_batches = 0
        net.eval()
        with torch.no_grad():
            for batch_x, batch_y in dataloader_test:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                y_pred = net(batch_x).probs.argmax(axis=1)
                total_acc += y_pred.eq(batch_y).float().mean().item()
                num_batches += 1
        acc = total_acc/num_batches
        print(f"Epoch {epoch}: Test accuracy = {acc*100:.1f} %")

Training criterion: nll


Epoch 1:   0%|          | 6/1562 [00:00<00:44, 35.06it/s, loss=1.33]   

Epoch 0: Test accuracy = 54.7 %


Epoch 2:   0%|          | 7/1562 [00:00<00:41, 37.56it/s, loss=1.09]   

Epoch 1: Test accuracy = 66.7 %


Epoch 3:   0%|          | 7/1562 [00:00<00:41, 37.58it/s, loss=0.78]   

Epoch 2: Test accuracy = 74.3 %


Epoch 4:   0%|          | 7/1562 [00:00<00:44, 35.23it/s, loss=0.34]   

Epoch 3: Test accuracy = 74.4 %


Epoch 0:   0%|          | 1/1562 [00:00<00:56, 27.63it/s, loss=268.67] 

Epoch 4: Test accuracy = 77.7 %
Training criterion: potential


Epoch 1:   0%|          | 4/1562 [00:00<01:03, 24.44it/s, loss=-4685.28]   

Epoch 0: Test accuracy = 55.0 %


Epoch 2:   0%|          | 4/1562 [00:00<01:02, 24.76it/s, loss=-38600.84]   

Epoch 1: Test accuracy = 58.9 %


Epoch 3:   0%|          | 5/1562 [00:00<00:57, 27.30it/s, loss=-116083.49]   

Epoch 2: Test accuracy = 58.0 %


Epoch 4:   0%|          | 4/1562 [00:00<01:02, 24.79it/s, loss=-231762.94]   

Epoch 3: Test accuracy = 57.8 %


                                                                             

Epoch 4: Test accuracy = 59.8 %


#### SGLD inference

In [15]:
from bnn_priors.inference import SGLDRunner
import torch as t

In [6]:
model = PreActResNet18(bn=True).to(device)

n_samples = 2
skip = 1
cycles = 1
warmup = 1
burnin = 1
lr = 5e-4
temperature = 1.0
momentum = 0.9
precond_update = None
    
sample_epochs = n_samples * skip // cycles
epochs_per_cycle = warmup + burnin + sample_epochs

In [7]:
mcmc = SGLDRunner(model=model, dataloader=dataloader_train, epochs_per_cycle=epochs_per_cycle,
                  warmup_epochs=warmup, sample_epochs=sample_epochs, learning_rate=lr,
                  skip=skip, sampling_decay=True, cycles=cycles, temperature=temperature,
                  momentum=momentum, precond_update=precond_update)

In [8]:
mcmc.run(progressbar=True)

Cycle 0, Sampling: 100%|██████████| 4/4 [04:05<00:00, 61.29s/it]


In [9]:
samples = mcmc.get_samples()

In [12]:
bn_params = {k:v for k,v in model.state_dict().items() if "bn" in k}

In [41]:
lps = []

for i in range(n_samples):
    sample = dict((k, v[i].to(device)) for k, v in samples.items())
    sampled_state_dict = {**sample, **bn_params}
    with t.no_grad():
        # TODO: get model.using_params() to work with batchnorm params
        model.load_state_dict(sampled_state_dict)
        lps_sample = []
        for batch_x, batch_y in dataloader_test:
            lps_batch = model(batch_x).log_prob(batch_y)
            lps_sample.extend(list(lps_batch.cpu().numpy()))
        lps.append(lps_sample)

lps = t.tensor(lps)

In [46]:
print(f"Log prob = {lps.mean().item():.2f} +/- {lps.std().item():.2f}")

Log prob = -1.49 +/- 1.17


## He init

In [2]:
from bnn_priors.models import DenseNet, PreActResNet18

In [3]:
def he_initialize(model):
    for name, param in model.named_parameters():
        if "weight_prior.p" in name:
            torch.nn.init.kaiming_normal_(param.data, mode='fan_in', nonlinearity='relu')
        elif "bias_prior.p" in name:
            torch.nn.init.zeros_(param.data)

In [4]:
net = DenseNet(20,10,16)

In [7]:
net.net[0].weight_prior.p[:2,:2]

tensor([[-0.0377, -0.2506],
        [-0.1723,  0.1209]], grad_fn=<SliceBackward>)

In [8]:
net.net[0].bias_prior.p[:2]

tensor([-0.2780, -1.3786], grad_fn=<SliceBackward>)

In [9]:
he_initialize(net)

In [10]:
net.net[0].weight_prior.p[:2,:2]

tensor([[-0.0078,  0.2775],
        [-0.3276,  0.1234]], grad_fn=<SliceBackward>)

In [11]:
net.net[0].bias_prior.p[:2]

tensor([0., 0.], grad_fn=<SliceBackward>)

In [12]:
net = PreActResNet18()

In [13]:
net.net.conv1.weight_prior.p[:2,:2,:2,:2]

tensor([[[[-2.3668,  0.5335],
          [ 0.6159, -0.2794]],

         [[-0.7321,  1.1602],
          [ 0.5006,  0.0309]]],


        [[[ 0.6768, -0.6904],
          [-0.2288,  0.2639]],

         [[-1.2357,  1.1778],
          [ 0.7514,  0.4818]]]], grad_fn=<SliceBackward>)

In [15]:
he_initialize(net)

In [16]:
net.net.conv1.weight_prior.p[:2,:2,:2,:2]

tensor([[[[ 0.1949,  0.0349],
          [-0.2098, -0.0787]],

         [[ 0.1343, -0.0710],
          [ 0.0746, -0.1642]]],


        [[[ 0.4675, -0.7776],
          [ 0.3416,  0.3063]],

         [[ 0.2940, -0.1670],
          [-0.2372,  0.4111]]]], grad_fn=<SliceBackward>)