In [83]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

from probly.representation import Bayesian

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
transforms = T.Compose([T.ToTensor(), torch.flatten])
train = torchvision.datasets.FashionMNIST(root='~/datasets', train=True, download=True, transform=transforms)
test = torchvision.datasets.FashionMNIST(root='~/datasets', train=False, download=True, transform=transforms)
train_loader = DataLoader(train, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

In [85]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
model = Bayesian(net)

In [86]:
epochs = 20
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
lmbda = 1e-4
for epoch in tqdm(range(epochs)):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets) + lmbda * model.kl_divergence
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(model.kl_divergence)
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}")

# compute accuracy on test set
correct = 0
total = 0
model.eval()
for inputs, targets in test_loader:
    outputs = model(inputs)
    correct += (outputs.argmax(1) == targets).sum()
    total += targets.size(0)
print(f"Accuracy: {correct / total}")

  5%|▌         | 1/20 [00:02<00:54,  2.86s/it]

tensor(205859.6406, grad_fn=<AddBackward0>)
Epoch 1, Running loss: 23.293643886484997


 10%|█         | 2/20 [00:05<00:51,  2.84s/it]

tensor(189104.7344, grad_fn=<AddBackward0>)
Epoch 2, Running loss: 20.762263724144468


 15%|█▌        | 3/20 [00:08<00:50,  2.97s/it]

tensor(173067.7969, grad_fn=<AddBackward0>)
Epoch 3, Running loss: 18.840434362533244


 20%|██        | 4/20 [00:11<00:48,  3.02s/it]

tensor(157825.2656, grad_fn=<AddBackward0>)
Epoch 4, Running loss: 17.14894571507231


100%|██████████| 20/20 [01:00<00:00,  3.02s/it]

tensor(34708.5508, grad_fn=<AddBackward0>)
Epoch 20, Running loss: 4.127033773381659





Accuracy: 0.8012999892234802


In [87]:
posterior_std = 0.05
torch.log(torch.exp(torch.tensor(posterior_std)) - 1)

tensor(-2.9706)