# Bayesian Neural Networks

We used a [github repo](https://github.com/kumar-shridhar/PyTorch-BayesianCNN)...

Below is code that imports the libraries, sets the device, imports the data, transforms the 

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms
import torch.utils.tensorboard as tb
import torch.nn.functional as F
# Code from paper
from BCNN.layers.misc import ModuleWrapper
from BCNN.layers.BBB import BBBConv
from BCNN.layers.BBB import BBBLinear

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineRenderer.figure_format = 'retina'

log_dir = 'logs'

device = None
if not torch.backends.mps.is_available(): # Check if mac
    if not torch.backends.mps.is_built(): # Error on Mac
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # For GPU
else:
        device = torch.device("mps") # For GPU



print(f"Using device: {device}")

transform = transforms.Compose([
    transforms.ToImageTensor(),
    transforms.ConvertImageDtype(),
    # Depending on your torchvision version you may need to change these:
    # - If you don't have torchvision.transforms.v2, then import torchvision.transforms
    #   instead and use ToTensor() to replace _both_ of the transforms above.
    # - If you have v2 but it says ToImage() is undefined, then use ToImageTensor() instead.
])

cifar = torchvision.datasets.CIFAR10("cifar", download=True, transform=transform) # Download data
train_size = int(0.9 * len(cifar)) # 90/10 split of dataz
train_data, valid_data = torch.utils.data.random_split(cifar, [train_size, len(cifar) - train_size])

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

mean = []
for x, _ in cifar:
    mean.append(torch.mean(x, dim=(1, 2)))
mean = torch.stack(mean, dim=0).mean(dim=0)
std = []
for x, _ in cifar:
    std.append(((x - mean[:,np.newaxis,np.newaxis]) ** 2).mean(dim=(1, 2)))
std = torch.stack(std, dim=0).mean(dim=0).sqrt()

cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)

normalize = transforms.Normalize(cifar_mean, cifar_std)

augments = transforms.Compose([
            transforms.RandomHorizontalFlip(0.05),
            transforms.RandomGrayscale(0.03),
            transforms.ColorJitter(
                # I've cranked up the brightness adjustment to make it more obvious
                brightness=0.08,
                contrast=0.031,
                saturation=0.031,
                hue=0),
            transforms.Normalize(cifar_mean, cifar_std)
        ])

def train(model_class, 
            model_type,
          lr=1e-3, 
          epochs=10,
          reg=0,
          train_batch_size=32, 
          val_batch_size = 1000):
 
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=val_batch_size, shuffle=False)

    train_accs = []
    valid_accs = []
    
    network = model_class().to(device)
    logger = tb.SummaryWriter(log_dir + '/' + model_type + '-lr-' + str(lr) + '-epochs-' + str(epochs))
    loss = nn.CrossEntropyLoss()

    opt = optim.AdamW(network.parameters(), lr=lr, weight_decay=reg)
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.85)

    global_step = 0
    for i in range(epochs):


        train_acc = []
        network.train()
        # The data loader makes batching easy
        for batch_xs, batch_ys in data_loader:
            batch_xs = batch_xs.to(device)
            preds = network(augments(batch_xs))
            loss_val = loss(preds, batch_ys.to(device))


            # Reset the gradients of all of parameters
            opt.zero_grad()
            # backward() call computes gradients using backpropagation
            loss_val.backward()
            # step() changes the parameters.
            opt.step()
            preds = network(normalize(batch_xs))
            train_acc.append((preds.argmax(dim=1) == batch_ys.to(device)).float().mean())
            # Logging
            logger.add_scalar('loss', loss_val, global_step=global_step)
            logger.add_scalar('training accuracy', (preds.argmax(dim=1) == batch_ys.to(device)).float().mean(), global_step=global_step)
            global_step += 1
            
        train_accs = np.mean([tensor.item() for tensor in train_acc])

        # Mesaure the validation accuracy.
        network.eval()
        val_acc = []
        for batch_xs, batch_ys in valid_loader:
            preds = network(normalize(batch_xs.to(device)))
            val_acc.append((preds.argmax(dim=1) == batch_ys.to(device)).float().mean())
        
        valid_accs = np.mean([tensor.item() for tensor in val_acc])
        logger.add_scalar('validation accuracy', valid_accs, global_step=global_step)

        scheduler.step()
        
        print("Epoch:", i + 1, "\nTrain accuracy:", train_accs, "\nValidation accuracy", valid_accs, "\n--------------------------------------------")

    return network

Using device: mps
Files already downloaded and verified


# Models

Below is the code that defines out CNN and a bayesian CNN model. They are intentionally a similar structure. 

In [5]:
class CNN(nn.Module):
    def __init__(self, arch=None, activation=F.relu):
        super().__init__()
        # Code from pytorch site
        self.activation = activation
        self.conv1 =  nn.Conv2d(3, 6, 5) # Could also add stridding and padding, but not for this dataset
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(0.05)

    def forward(self, x):
        x = self.pool(self.activation(self.conv1(x)))
        x = self.pool(self.activation(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = self.dropout(x)
        x = self.activation(self.fc1(x))
        x = self.dropout(x)
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
class BNN(ModuleWrapper):
  def __init__(self, activation=F.relu):
    super().__init__()
    self.activation = activation
    self.conv1 = BBBConv.BBBConv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = BBBConv.BBBConv2d(6, 16, 5)
    self.fc1 = BBBLinear.BBBLinear(16 * 5 * 5, 120)
    self.fc2 = BBBLinear.BBBLinear(120, 84)
    self.fc3 = BBBLinear.BBBLinear(84, 10)
    self.dropout = nn.Dropout(0.05)


  def forward(self, x):
    x = self.pool(self.activation(self.conv1(x)))
    x = self.pool(self.activation(self.conv2(x)))
    x = torch.flatten(x, 1) # flatten all dimensions except batch
    x = self.dropout(x)
    x = self.activation(self.fc1(x))
    x = self.dropout(x)
    x = self.activation(self.fc2(x))
    x = self.fc3(x)
    return x

# Training

The below trains the models... 

In [None]:
cnn_model = train( model_class=CNN, 
                    model_type = "CNN", 
                    lr=4e-4, 
                    reg=0.00001,
                    epochs=500, 
                    train_batch_size=32, 
                    val_batch_size = 5000)

In [68]:
bnn_model = train( model_class=BNN, 
                    model_type = "BNN", 
                    lr=4e-4, 
                    reg=0.00001,
                    epochs=500, 
                    train_batch_size=32, 
                    val_batch_size = 5000)

Epoch: 1 
Train accuracy: 0.2435589907604833 
Validation accuracy 0.3158000111579895 
--------------------------------------------
Epoch: 2 
Train accuracy: 0.34717039800995025 
Validation accuracy 0.373199999332428 
--------------------------------------------
Epoch: 3 
Train accuracy: 0.38324004975124376 
Validation accuracy 0.3970000147819519 
--------------------------------------------
Epoch: 4 
Train accuracy: 0.40780472636815923 
Validation accuracy 0.4169999957084656 
--------------------------------------------
Epoch: 5 
Train accuracy: 0.42641702203269366 
Validation accuracy 0.4052000045776367 
--------------------------------------------
Epoch: 6 
Train accuracy: 0.44642857142857145 
Validation accuracy 0.44920000433921814 
--------------------------------------------
Epoch: 7 
Train accuracy: 0.46095415778251597 
Validation accuracy 0.46000000834465027 
--------------------------------------------
Epoch: 8 
Train accuracy: 0.47237029140014214 
Validation accuracy 0.4828000

KeyboardInterrupt: 

# Save and check model

In [None]:
torch.save(cnn_model.state_dict(), "cnn_model.pt")
torch.save(bnn_model.state_dict(), "bnn_model.pt")


In [None]:
%reload_ext tensorboard

# Load the tensorboard extension for Jupyter
%load_ext tensorboard
# Start tensorboard and tell it where to look for logs. It will auto-update every second.
%tensorboard --logdir {log_dir} --reload_interval 1 