In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

from probly.losses import EvidenceLowerBound
from probly.representation import Bayesian

In [2]:
# Only ToTensor is applied: no normalization or augmentation.
transforms = T.Compose([T.ToTensor()])

# Both splits share the same download location and transform; only `train=` differs.
_dataset_kwargs = {'root': '~/datasets', 'download': True, 'transform': transforms}
train = torchvision.datasets.FashionMNIST(train=True, **_dataset_kwargs)
test = torchvision.datasets.FashionMNIST(train=False, **_dataset_kwargs)

# Shuffle the training split every epoch; keep the test split in a fixed order.
train_loader, test_loader = (
    DataLoader(train, batch_size=256, shuffle=True),
    DataLoader(test, batch_size=256, shuffle=False),
)

In [3]:
class LeNet(nn.Module):
    """Small LeNet-style CNN: two conv + max-pool stages followed by two linear layers.

    The hard-coded ``fc1`` input size of 256 (= 16 channels * 4 * 4) implies 28x28
    inputs such as FashionMNIST: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.

    Args:
        in_channels: Number of input image channels (default 1, i.e. grayscale).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Creation order is kept stable so parameter ordering / initialisation
        # matches across runs and any wrapper that walks the submodules.
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Map a batch of shape (N, in_channels, 28, 28) to (N, num_classes) logits."""
        # ReLU is applied after pooling; equivalent to pool(relu(.)) because ReLU is monotonic.
        x = self.act(self.max_pool(self.conv1(x)))
        x = self.act(self.max_pool(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        # Raw logits: no softmax here, the loss is expected to handle normalisation.
        return self.fc2(x)

# Seed torch's default generator so the weight initialisation below and later draws
# from it (e.g. DataLoader shuffling) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

net = LeNet()
model = Bayesian(net)

In [4]:
epochs = 20
optimizer = optim.Adam(model.parameters())
# NOTE(review): 1e-5 is presumably the KL-term weight of the ELBO; confirm against
# EvidenceLowerBound's signature before tuning.
criterion = EvidenceLowerBound(1e-5)
for epoch in tqdm(range(epochs)):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets, model.kl_divergence)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # tqdm.write prints above the progress bar instead of interleaving with it.
    tqdm.write(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}, KL: {model.kl_divergence.item()}")

# compute accuracy on test set
correct = 0
total = 0
model.eval()
# No gradients are needed for evaluation; avoid building autograd graphs per batch.
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # .item() keeps `correct` a Python int, so the ratio below is an exact float.
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")

  5%|▌         | 1/20 [00:08<02:37,  8.27s/it]

Epoch 1, Running loss: 2.3561857304674514, KL: 92668.828125


 10%|█         | 2/20 [00:16<02:29,  8.28s/it]

Epoch 2, Running loss: 1.7392033242164775, KL: 92112.0


 15%|█▌        | 3/20 [00:23<02:13,  7.88s/it]

Epoch 3, Running loss: 1.606832157804611, KL: 91531.53125


 20%|██        | 4/20 [00:32<02:07,  7.95s/it]

Epoch 4, Running loss: 1.5253044782800877, KL: 90895.75


 25%|██▌       | 5/20 [00:39<01:58,  7.91s/it]

Epoch 5, Running loss: 1.4693947609434737, KL: 90238.53125


 30%|███       | 6/20 [00:47<01:49,  7.81s/it]

Epoch 6, Running loss: 1.4213418691716295, KL: 89578.828125


 35%|███▌      | 7/20 [00:54<01:38,  7.60s/it]

Epoch 7, Running loss: 1.3798426364330536, KL: 88896.25


 40%|████      | 8/20 [01:01<01:29,  7.43s/it]

Epoch 8, Running loss: 1.3495702794257631, KL: 88179.40625


 45%|████▌     | 9/20 [01:08<01:20,  7.36s/it]

Epoch 9, Running loss: 1.321896786385394, KL: 87413.84375


 50%|█████     | 10/20 [01:16<01:13,  7.35s/it]

Epoch 10, Running loss: 1.2917746751866441, KL: 86606.09375


 55%|█████▌    | 11/20 [01:23<01:05,  7.31s/it]

Epoch 11, Running loss: 1.2667150472072846, KL: 85759.421875


 60%|██████    | 12/20 [01:31<00:59,  7.40s/it]

Epoch 12, Running loss: 1.2445375472941298, KL: 84879.765625


 65%|██████▌   | 13/20 [01:38<00:51,  7.34s/it]

Epoch 13, Running loss: 1.2240049737565062, KL: 83973.625


 70%|███████   | 14/20 [01:45<00:43,  7.26s/it]

Epoch 14, Running loss: 1.202864813297353, KL: 83047.5


 75%|███████▌  | 15/20 [01:52<00:36,  7.21s/it]

Epoch 15, Running loss: 1.185164464788234, KL: 82094.3125


 80%|████████  | 16/20 [01:59<00:28,  7.16s/it]

Epoch 16, Running loss: 1.1659616622518987, KL: 81113.1171875


 85%|████████▌ | 17/20 [02:06<00:21,  7.12s/it]

Epoch 17, Running loss: 1.1479205669240748, KL: 80098.15625


 90%|█████████ | 18/20 [02:13<00:14,  7.19s/it]

Epoch 18, Running loss: 1.133693779783046, KL: 79072.078125


 95%|█████████▌| 19/20 [02:20<00:07,  7.17s/it]

Epoch 19, Running loss: 1.1155941230185489, KL: 78027.9609375


100%|██████████| 20/20 [02:28<00:00,  7.43s/it]

Epoch 20, Running loss: 1.0981770977060845, KL: 76987.140625





Accuracy: 0.8680999875068665
