In [1]:
import torch
import torchvision
import torchvision.transforms as T
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from probly.losses import ELBOLoss
from probly.representation import Bayesian

### Prepare the data

In [2]:
# The only preprocessing step is the conversion of each image to a tensor.
transforms = T.Compose([T.ToTensor()])

# Training split, served in shuffled mini-batches of 256 samples.
train = torchvision.datasets.FashionMNIST(
    root="~/datasets",
    train=True,
    download=True,
    transform=transforms,
)
train_loader = DataLoader(train, batch_size=256, shuffle=True)

# Held-out split, iterated in a fixed order with the same batch size.
test = torchvision.datasets.FashionMNIST(
    root="~/datasets",
    train=False,
    download=True,
    transform=transforms,
)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

### Define a simple neural network and make it Bayesian

In [3]:
class LeNet(nn.Module):
    """Small LeNet-style convolutional classifier.

    Two 5x5 convolutions, each followed by 2x2 max pooling and a ReLU, feed a
    two-layer fully connected head that produces 10 scores per sample.

    Attributes:
        conv1: nn.Module, convolution from 1 to 8 channels
        conv2: nn.Module, convolution from 8 to 16 channels
        fc1: nn.Module, fully connected layer from 256 to 128 features
        fc2: nn.Module, fully connected layer from 128 features to 10 outputs
        act: nn.Module, ReLU activation shared by all hidden layers
        max_pool: nn.Module, 2x2 max pooling shared by both convolution stages
    """

    def __init__(self) -> None:
        """Creates the layers of the network."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5)
        # 256 input features equal 16 channels x 4 x 4, which is what the two
        # conv/pool stages yield for a 1x28x28 input.
        self.fc1 = nn.Linear(in_features=256, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the network output for a batch of inputs.

        Args:
            x: torch.Tensor, input data
        Returns:
            torch.Tensor, output of the last fully connected layer
        """
        features = x
        # Both convolution stages share the same pattern: conv -> max pool -> ReLU.
        for conv in (self.conv1, self.conv2):
            features = self.act(self.max_pool(conv(features)))
        # Keep the batch dimension and collapse everything else for the dense head.
        flat = torch.flatten(features, start_dim=1)
        hidden = self.act(self.fc1(flat))
        return self.fc2(hidden)


# Seed PyTorch's global RNG before any weights are created, so that a fresh
# "Restart & Run All" starts from the same initialisation and draws the same
# random numbers afterwards (as far as the PyTorch RNG is concerned).
torch.manual_seed(0)

net = LeNet()
# Wrap the plain network with probly's Bayesian representation; `model` is the
# object that is trained and evaluated in the rest of the notebook.
model = Bayesian(net)

### Train the Bayesian neural network using the ELBO loss

In [4]:
epochs = 20
optimizer = optim.Adam(model.parameters())
# NOTE(review): 1e-5 is passed straight to ELBOLoss; presumably the weight of the
# KL term relative to the data-fit term - confirm against the probly documentation.
criterion = ELBOLoss(1e-5)
for epoch in tqdm(range(epochs)):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        # The loss takes the model's current KL divergence alongside the predictions.
        loss = criterion(outputs, targets, model.kl_divergence)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean batch loss of the epoch together with the current KL divergence.
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}, KL: {model.kl_divergence.item()}")

# compute accuracy on test set
correct = 0
total = 0
model.eval()
# Evaluation never calls backward(), so disable autograd to avoid building a
# graph for every test batch.
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # .item() keeps `correct` a plain Python int, so the accuracy below is an
        # ordinary float instead of a float32 tensor.
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")

  5%|▌         | 1/20 [00:08<02:37,  8.27s/it]

Epoch 1, Running loss: 2.349484979852717, KL: 92669.15625


 10%|█         | 2/20 [00:17<02:34,  8.57s/it]

Epoch 2, Running loss: 1.7455554789685188, KL: 91920.96875


 15%|█▌        | 3/20 [00:26<02:28,  8.75s/it]

Epoch 3, Running loss: 1.5924415730415507, KL: 91194.2109375


 20%|██        | 4/20 [00:34<02:20,  8.80s/it]

Epoch 4, Running loss: 1.5001984535379613, KL: 90453.296875


 25%|██▌       | 5/20 [00:43<02:10,  8.73s/it]

Epoch 5, Running loss: 1.4464950074540808, KL: 89695.953125


 30%|███       | 6/20 [00:51<02:00,  8.58s/it]

Epoch 6, Running loss: 1.3982942992068352, KL: 88924.28125


 35%|███▌      | 7/20 [01:00<01:53,  8.70s/it]

Epoch 7, Running loss: 1.3545769600158042, KL: 88125.8515625


 40%|████      | 8/20 [01:09<01:44,  8.69s/it]

Epoch 8, Running loss: 1.3280134743832528, KL: 87297.703125


 45%|████▌     | 9/20 [01:17<01:34,  8.56s/it]

Epoch 9, Running loss: 1.2974397948447693, KL: 86432.1953125


 50%|█████     | 10/20 [01:24<01:21,  8.17s/it]

Epoch 10, Running loss: 1.2703498586695245, KL: 85538.1171875


 55%|█████▌    | 11/20 [01:31<01:09,  7.73s/it]

Epoch 11, Running loss: 1.2467263404359208, KL: 84609.109375


 60%|██████    | 12/20 [01:38<01:00,  7.54s/it]

Epoch 12, Running loss: 1.228259710555381, KL: 83652.640625


 65%|██████▌   | 13/20 [01:47<00:54,  7.76s/it]

Epoch 13, Running loss: 1.2088190291790253, KL: 82676.46875


 70%|███████   | 14/20 [01:56<00:48,  8.16s/it]

Epoch 14, Running loss: 1.1924033408469341, KL: 81684.90625


 75%|███████▌  | 15/20 [02:05<00:42,  8.41s/it]

Epoch 15, Running loss: 1.1743885669302434, KL: 80675.9765625


 80%|████████  | 16/20 [02:14<00:34,  8.64s/it]

Epoch 16, Running loss: 1.1566542879064032, KL: 79658.265625


 85%|████████▌ | 17/20 [02:23<00:26,  8.74s/it]

Epoch 17, Running loss: 1.141962294375643, KL: 78631.078125


 90%|█████████ | 18/20 [02:31<00:17,  8.62s/it]

Epoch 18, Running loss: 1.1245111916927581, KL: 77593.7109375


 95%|█████████▌| 19/20 [02:41<00:09,  9.02s/it]

Epoch 19, Running loss: 1.1074010848999023, KL: 76550.2890625


100%|██████████| 20/20 [02:50<00:00,  8.51s/it]

Epoch 20, Running loss: 1.0937262408276822, KL: 75515.5





Accuracy: 0.8654999732971191
