In [28]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
from src.probly.models.bayesian import Bayesian
import sklearn.datasets as sd
from torch.utils.data import DataLoader, TensorDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
# make a simple 2d dataset
X, y = sd.make_moons(n_samples=200, noise=0.1)
X_train, y_train, X_test, y_test = X[:100], y[:100], X[100:], y[100:]

# convert to torch tensors and make dataloader
train = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)
)
train_loader = DataLoader(train, batch_size=32, shuffle=True)
test = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long)
)
test_loader = DataLoader(test, batch_size=32, shuffle=False)


# small fully connected neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 2)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

In [30]:
model = Bayesian(net)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for _ in range(100):
    model.train()
    print(model.model[0].weight_mu[0,0].item())
    for x, y in train_loader:
        optimizer.zero_grad()
        y_pred = model(x, 1)
        y_pred = torch.squeeze(y_pred)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
# compute accuracy on test set
correct = 0
total = 0
for x, y in test_loader:
    model.eval()
    y_pred = model(x, 100).mean(axis=1)
    y_pred = y_pred.argmax(axis=1)
    correct += (y_pred == y).sum().item()
    total += y.size(0)
print(f"Accuracy: {correct / total}")

0.2622973918914795
0.2911498248577118
0.3001435399055481
0.3061884641647339
0.3100169897079468
0.31634020805358887
0.3191973865032196
0.3026317358016968
0.2873613238334656
0.2808489501476288
0.2790534198284149
0.2750765383243561
0.27342697978019714
0.27872350811958313
0.286266565322876
0.3014010190963745
0.31093505024909973
0.3174271881580353
0.32149747014045715
0.31916511058807373
0.3126566410064697
0.3080095052719116
0.30203521251678467
0.29645809531211853
0.2943454086780548
0.2931782305240631
0.2923352122306824
0.2823910117149353
0.2686168849468231
0.25816547870635986
0.25165632367134094
0.24588578939437866
0.24139894545078278
0.2401825487613678
0.2404068261384964
0.2406255155801773
0.24242015182971954
0.2432687133550644
0.24184399843215942
0.240226149559021
0.24004517495632172
0.24223890900611877
0.24642835557460785
0.2576693892478943
0.26653531193733215
0.27401480078697205
0.2773234248161316
0.2789287269115448
0.2788134515285492
0.27806270122528076
0.2828178107738495
0.28587850928