In [4]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

from probly.losses import ELBOLoss
from probly.representation import Bayesian

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Prepare the data

In [5]:
transforms = T.Compose([T.ToTensor()])
train = torchvision.datasets.FashionMNIST(
    root="~/datasets", train=True, download=True, transform=transforms
)
test = torchvision.datasets.FashionMNIST(
    root="~/datasets", train=False, download=True, transform=transforms
)
train_loader = DataLoader(train, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

### Define a simple neural network and make it Bayesian

In [6]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        self.fc1 = nn.Linear(256, 128, False)
        self.fc2 = nn.Linear(128, 10)
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.act(self.max_pool(self.conv1(x)))
        x = self.act(self.max_pool(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x


net = torchvision.models.resnet18()
# net = LeNet()
model = Bayesian(net)

### Train the Bayesian neural network using the ELBO loss

In [7]:
epochs = 1
optimizer = optim.Adam(model.parameters())
criterion = ELBOLoss(1e-5)
for epoch in tqdm(range(epochs)):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets, model.kl_divergence)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(
        f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}, KL: {model.kl_divergence.item()}"
    )

# compute accuracy on test set
correct = 0
total = 0
model.eval()
for inputs, targets in test_loader:
    outputs = model(inputs)
    correct += (outputs.argmax(1) == targets).sum()
    total += targets.size(0)
print(f"Accuracy: {correct / total}")

  0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[256, 1, 28, 28] to have 3 channels, but got 1 channels instead