In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

from probly.losses import ELBOLoss
from probly.representation import Bayesian

### Prepare the data

In [2]:
# Data-pipeline configuration, gathered in one place so it is easy to tune.
SEED = 0
DATA_ROOT = "~/datasets"
BATCH_SIZE = 256

# Seed PyTorch's global RNG before anything random happens. The shuffling
# DataLoader below and the weight initialisation of modules created later both
# draw from it, so a fresh "Restart & Run All" reproduces the same run.
torch.manual_seed(SEED)

# ToTensor converts each image to a float tensor scaled to [0, 1].
transforms = T.Compose([T.ToTensor()])
train = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT, train=True, download=True, transform=transforms
)
test = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT, train=False, download=True, transform=transforms
)

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

### Define a simple neural network and make it Bayesian

In [3]:
class LeNet(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by two linear layers.

    The first linear layer expects 256 flattened features, which is what a
    single-channel 28x28 input yields after two 5x5 convolutions, each followed
    by 2x2 max pooling (16 channels of 4x4).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        # Fully connected classifier head (10 output classes).
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 10)
        # Parameter-free operations reused by several stages.
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the raw class scores (logits), one row of 10 values per sample."""
        features = x
        # Both convolutional stages share the same conv -> pool -> ReLU pattern.
        for conv in (self.conv1, self.conv2):
            features = self.act(self.max_pool(conv(features)))
        # Keep the batch dimension and flatten everything else for the linear head.
        hidden = self.act(self.fc1(features.flatten(1)))
        return self.fc2(hidden)


# Build the deterministic base network, then wrap it in probly's Bayesian
# representation. The wrapper is what gets trained and evaluated below; it
# exposes `kl_divergence`, which the training loop feeds to the ELBO loss.
# NOTE(review): presumably Bayesian replaces the point-estimate weights of the
# wrapped layers with learned distributions — confirm against the probly docs.
net = LeNet()
model = Bayesian(net)

### Train the Bayesian neural network using the ELBO loss and evaluate it on the test set

In [4]:
# Training setup: epoch count, optimiser over the Bayesian model's parameters,
# and the ELBO criterion.
# NOTE(review): 1e-5 is presumably the weight of the KL term — confirm against
# the probly ELBOLoss documentation.
epochs = 20
optimizer = optim.Adam(model.parameters())
criterion = ELBOLoss(1e-5)

for epoch in tqdm(range(epochs)):
    model.train()
    running_loss = 0.0  # sum of the per-batch losses seen in this epoch
    for images, labels in train_loader:
        optimizer.zero_grad()
        # Forward pass first, then read the KL term of the current parameters.
        batch_loss = criterion(model(images), labels, model.kl_divergence)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    # Report the mean batch loss of the epoch and the KL term after the last update.
    print(
        f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}, KL: {model.kl_divergence.item()}"
    )

# Compute the classification accuracy on the test set.
# NOTE(review): each batch gets a single forward pass; if the Bayesian wrapper
# still samples weights in eval mode, this is a one-sample estimate — confirm.
correct = 0
total = 0
model.eval()
# Inference only: disabling autograd avoids building a computation graph for
# every test batch, which saves memory and time.
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # .item() keeps `correct` a Python int, so the ratio below is an exact
        # Python float division rather than float32 tensor arithmetic (which
        # printed e.g. 0.8633000254631042 instead of 0.8633).
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")

  5%|▌         | 1/20 [00:06<02:01,  6.41s/it]

Epoch 1, Running loss: 2.3722202463352935, KL: 92954.234375


 10%|█         | 2/20 [00:12<01:55,  6.41s/it]

Epoch 2, Running loss: 1.8185499328248045, KL: 92349.6953125


 15%|█▌        | 3/20 [00:19<01:47,  6.35s/it]

Epoch 3, Running loss: 1.6587683038508638, KL: 91693.109375


 20%|██        | 4/20 [00:25<01:42,  6.40s/it]

Epoch 4, Running loss: 1.570698143066244, KL: 90968.0


 25%|██▌       | 5/20 [00:32<01:36,  6.42s/it]

Epoch 5, Running loss: 1.5103234534567975, KL: 90174.3203125


 30%|███       | 6/20 [00:39<01:33,  6.69s/it]

Epoch 6, Running loss: 1.4633521445254063, KL: 89331.296875


 35%|███▌      | 7/20 [00:46<01:29,  6.89s/it]

Epoch 7, Running loss: 1.4185041214557403, KL: 88426.96875


 40%|████      | 8/20 [00:53<01:22,  6.89s/it]

Epoch 8, Running loss: 1.3790995871767084, KL: 87477.1328125


 45%|████▌     | 9/20 [01:00<01:16,  6.96s/it]

Epoch 9, Running loss: 1.343233585357666, KL: 86485.6640625


 50%|█████     | 10/20 [01:07<01:09,  6.92s/it]

Epoch 10, Running loss: 1.3143750200880335, KL: 85458.140625


 55%|█████▌    | 11/20 [01:14<01:02,  6.99s/it]

Epoch 11, Running loss: 1.2883331948138297, KL: 84404.890625


 60%|██████    | 12/20 [01:21<00:56,  7.01s/it]

Epoch 12, Running loss: 1.257472261469415, KL: 83325.8984375


 65%|██████▌   | 13/20 [01:27<00:47,  6.78s/it]

Epoch 13, Running loss: 1.2303120171770137, KL: 82239.9921875


 70%|███████   | 14/20 [01:34<00:40,  6.70s/it]

Epoch 14, Running loss: 1.2083908106418366, KL: 81128.109375


 75%|███████▌  | 15/20 [01:40<00:33,  6.67s/it]

Epoch 15, Running loss: 1.1887103826441663, KL: 80017.53125


 80%|████████  | 16/20 [01:48<00:27,  7.00s/it]

Epoch 16, Running loss: 1.166912561781863, KL: 78908.1875


 85%|████████▌ | 17/20 [01:56<00:21,  7.27s/it]

Epoch 17, Running loss: 1.1465019439129118, KL: 77776.703125


 90%|█████████ | 18/20 [02:03<00:14,  7.11s/it]

Epoch 18, Running loss: 1.128332474383902, KL: 76652.0703125


 95%|█████████▌| 19/20 [02:09<00:06,  6.95s/it]

Epoch 19, Running loss: 1.1069046969109393, KL: 75530.375


100%|██████████| 20/20 [02:16<00:00,  6.85s/it]

Epoch 20, Running loss: 1.0914933455751297, KL: 74418.8828125





Accuracy: 0.8633000254631042
