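# Bayesian Neural Networks

This notebook turns a small LeNet-style CNN into a Bayesian neural network with `probly`'s `bayesian` transformation, trains it on FashionMNIST using the ELBO loss, and reports its accuracy on the test set.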

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from probly.train.bayesian import collect_kl_divergence
from probly.train.losses import ELBOLoss
from probly.transformation import bayesian

### Prepare the data

In [2]:
# Images are only converted to tensors; no further preprocessing is applied.
transforms = T.Compose([T.ToTensor()])

# Both splits share the same storage location, download flag and transform.
dataset_kwargs = {"root": "~/datasets", "download": True, "transform": transforms}
train = torchvision.datasets.FashionMNIST(train=True, **dataset_kwargs)
test = torchvision.datasets.FashionMNIST(train=False, **dataset_kwargs)

# Only the training split is shuffled; the test split keeps its stored order.
train_loader = DataLoader(train, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

### Define a simple neural network and make it Bayesian

In [3]:
class LeNet(nn.Module):
    """Small LeNet-style convolutional classifier.

    Two 5x5 convolutions, each followed by 2x2 max pooling and a ReLU, feed a
    two-layer fully connected head that emits 10 output scores.

    Attributes:
        conv1: nn.Module, first convolutional layer (1 -> 8 channels)
        conv2: nn.Module, second convolutional layer (8 -> 16 channels)
        fc1: nn.Module, first fully connected layer (256 -> 128 features)
        fc2: nn.Module, second fully connected layer (128 -> 10 features)
        act: nn.Module, ReLU activation shared by all hidden layers
        max_pool: nn.Module, 2x2 max pooling shared by both conv stages
    """

    def __init__(self) -> None:
        """Creates the layers of the network."""
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5)
        # Fully connected head. fc1 expects 256 flattened features, i.e. 16
        # channels of 4x4, which is what a 28x28 input is reduced to.
        self.fc1 = nn.Linear(in_features=256, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        # Parameter-free modules reused at several points in forward().
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the network on a batch of inputs.

        Args:
            x: torch.Tensor, input data
        Returns:
            torch.Tensor, output of the last fully connected layer
        """
        features = self.act(self.max_pool(self.conv1(x)))
        features = self.act(self.max_pool(self.conv2(features)))
        # Keep the batch dimension and collapse everything else.
        hidden = self.act(self.fc1(features.flatten(start_dim=1)))
        return self.fc2(hidden)


# Seed PyTorch's global RNG before any weights are created so that a fresh
# "Restart Kernel -> Run All" starts from the same initialisation instead of a
# different random one on every run.
torch.manual_seed(0)

net = LeNet()
# Apply probly's `bayesian` transformation to the deterministic network; the
# returned object is the model that is trained and evaluated below.
# NOTE(review): whether `bayesian` copies `net` or modifies it in place is not
# visible here - confirm before reusing `net` afterwards.
model = bayesian(net)

### Train the Bayesian neural network using the ELBO loss

In [4]:
# Number of full passes over the training set.
epochs = 20
# Adam with its default hyperparameters over all parameters exposed by the model.
optimizer = optim.Adam(model.parameters())
# NOTE(review): 1e-5 is presumably the weight given to the KL term inside
# ELBOLoss - confirm against the probly documentation.
criterion = ELBOLoss(1e-5)
for epoch in tqdm(range(epochs)):
    model.train()
    # Sum of the per-batch losses of this epoch; averaged when reported below.
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        # Collected after the forward pass and handed to the criterion together
        # with the predictions and targets.
        kl = collect_kl_divergence(model)
        loss = criterion(outputs, targets, kl)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # The loss shown is the mean over all batches of the epoch; the KL value is
    # the one computed for the last batch only.
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}, KL: {kl.item()}")

# Compute accuracy on the test set.
model.eval()
correct = 0
total = 0
# Inference only: without no_grad() every forward pass would record an autograd
# graph that is never used.
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # The predicted class is the index of the largest output score.
        # .item() keeps `correct` a plain Python int, so the ratio below is
        # computed in double precision rather than as a float32 tensor.
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")

  5%|▌         | 1/20 [00:06<02:05,  6.63s/it]

Epoch 1, Running loss: 2.3810682504735095, KL: 94051.484375


 10%|█         | 2/20 [00:13<02:04,  6.89s/it]

Epoch 2, Running loss: 1.777575656200977, KL: 94141.2421875


 15%|█▌        | 3/20 [00:20<01:53,  6.70s/it]

Epoch 3, Running loss: 1.6280197787792126, KL: 94218.8828125


 20%|██        | 4/20 [00:26<01:47,  6.71s/it]

Epoch 4, Running loss: 1.5569966189404751, KL: 94296.90625


 25%|██▌       | 5/20 [00:33<01:41,  6.78s/it]

Epoch 5, Running loss: 1.4976627172307766, KL: 94373.8359375


 30%|███       | 6/20 [00:40<01:34,  6.77s/it]

Epoch 6, Running loss: 1.45882188107105, KL: 94453.9375


 35%|███▌      | 7/20 [00:47<01:28,  6.80s/it]

Epoch 7, Running loss: 1.4279117746556058, KL: 94536.4765625


 40%|████      | 8/20 [00:54<01:20,  6.73s/it]

Epoch 8, Running loss: 1.4015605378658214, KL: 94621.0078125


 45%|████▌     | 9/20 [01:00<01:14,  6.74s/it]

Epoch 9, Running loss: 1.3790245954026568, KL: 94709.5234375


 50%|█████     | 10/20 [01:07<01:06,  6.67s/it]

Epoch 10, Running loss: 1.3584139996386588, KL: 94795.90625


 55%|█████▌    | 11/20 [01:14<01:02,  6.92s/it]

Epoch 11, Running loss: 1.3437195194528457, KL: 94884.265625


 60%|██████    | 12/20 [01:22<00:56,  7.08s/it]

Epoch 12, Running loss: 1.3318564759924056, KL: 94974.2578125


 65%|██████▌   | 13/20 [01:29<00:50,  7.15s/it]

Epoch 13, Running loss: 1.3164951283881006, KL: 95061.4296875


 70%|███████   | 14/20 [01:36<00:42,  7.02s/it]

Epoch 14, Running loss: 1.3108408080770615, KL: 95156.484375


 75%|███████▌  | 15/20 [01:42<00:34,  6.86s/it]

Epoch 15, Running loss: 1.301673905900184, KL: 95251.2734375


 80%|████████  | 16/20 [01:49<00:27,  6.79s/it]

Epoch 16, Running loss: 1.290849932711175, KL: 95348.140625


 85%|████████▌ | 17/20 [01:56<00:20,  6.86s/it]

Epoch 17, Running loss: 1.2835565100324915, KL: 95446.265625


 90%|█████████ | 18/20 [02:03<00:13,  6.98s/it]

Epoch 18, Running loss: 1.2772524615551564, KL: 95545.34375


 95%|█████████▌| 19/20 [02:11<00:07,  7.10s/it]

Epoch 19, Running loss: 1.2710959459872957, KL: 95646.484375


100%|██████████| 20/20 [02:18<00:00,  6.94s/it]

Epoch 20, Running loss: 1.2659006844175622, KL: 95751.0078125





Accuracy: 0.8733999729156494
