In [1]:
# Install a pip package in the current Jupyter kernel
import sys
!{sys.executable} -m pip install torch
!{sys.executable} -m pip install torchvision



In [1]:
!pip3 uninstall -y torch torchvision
!pip3 install torch torchvision

Found existing installation: torch 2.2.1
Uninstalling torch-2.2.1:
  Successfully uninstalled torch-2.2.1
Found existing installation: torchvision 0.17.1
Uninstalling torchvision-0.17.1:
  Successfully uninstalled torchvision-0.17.1
Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/59/1f/4975d1ab3ed2244053876321ef65bc02935daed67da76c6e7d65900772a3/torch-2.2.1-cp311-cp311-win_amd64.whl.metadata
  Using cached torch-2.2.1-cp311-cp311-win_amd64.whl.metadata (26 kB)
Collecting torchvision
  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/e7/45/419aa0b37254f1fd62b45bb63836066c5eb81e37d70940e0491e95167eed/torchvision-0.17.1-cp311-cp311-win_amd64.whl.metadata
  Using cached torchvision-0.17.1-cp311-cp311-win_amd64.whl.metadata (6.6 kB)
Using cached torch-2.2.1-cp311-cp311-win_amd64.whl (198.6 MB)
Using cached torchvision-0.17.1-cp311-cp311-win_amd64.whl (1.2 MB)
Installing collected packages

In [2]:
from BayesianMnist.dataset import getSets
from BayesianMnist.viModel import BayesianMnistNet

import numpy as np

import matplotlib.pyplot as plt

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

import os
from pathlib import Path

In [3]:
# The class to ignore during training
filtered_class = 5
# The class to test against that is not the filtered class
test_class = 4
# Directory where the models can be saved or loaded from.
# Defaults to BayesianMnist/models relative to the working directory (assumes the notebook
# runs from the folder containing the BayesianMnist package — TODO confirm for your setup);
# set the BAYESIAN_MNIST_MODEL_DIR environment variable to point somewhere else.
save_dir = Path(os.environ.get("BAYESIAN_MNIST_MODEL_DIR", Path("BayesianMnist") / "models"))
# Load the models directly instead of training
no_train = True
# The number of epochs to train for
n_epochs = 10
# Batch size used for training
n_batch = 64
# The number of passes to use at test time for monte-carlo uncertainty estimation
n_runtests = 50
# The learning rate of the optimizer
learning_rate = 5e-3
# The number of networks to train to make an ensemble
num_networks = 10
# Seed for numpy and torch so that training, the random noise test inputs and the random
# choice of ensemble member at test time are reproducible across "Restart & Run All".
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Loading MNIST dataset

In [4]:
# Train/test splits with `filtered_class` removed (what the ensemble is trained on).
train, test = getSets(filteredClass = filtered_class)
# Same call with removeFiltered=False; `test_filtered` is what the unseen-class test evaluates
# (presumably only samples of `filtered_class` — confirm in BayesianMnist.dataset).
train_filtered, test_filtered = getSets(filteredClass = filtered_class, removeFiltered = False)

# Training-set size; used to scale the mean NLL in the training objective.
N = len(train)

# Shuffle the training data every epoch so mini-batches are not presented in a fixed order.
train_loader = DataLoader(train, batch_size=n_batch, shuffle=True)
test_loader = DataLoader(test, batch_size=n_batch)

# Number of training batches per epoch, and its digit count (used to zero-pad progress output).
batchLen = len(train_loader)
digitsBatchLen = len(str(batchLen))

# Training or loading model

Define functions to save and load trained models

In [5]:
def saveModels(models, savedir):
    """Save each model's weights to ``savedir/model{i}.pth``.

    Args:
        models: iterable of torch modules; the i-th one is written to ``model{i}.pth``
            as ``{"model_state_dict": state_dict}``.
        savedir: target directory (str or path-like). Created if it does not exist yet.
    """
    # torch.save cannot create missing parent directories, so make sure the target exists.
    os.makedirs(savedir, exist_ok=True)

    for i, m in enumerate(models):
        saveFileName = os.path.abspath(os.path.join(savedir, "model{}.pth".format(i)))
        torch.save({"model_state_dict": m.state_dict()}, saveFileName)


def loadModels(savedir):
    """Rebuild the ensemble from the checkpoints in ``savedir``.

    Every ``*.pth`` file in ``savedir`` is loaded into a fresh
    ``BayesianMnistNet(p_mc_dropout=None)``, the same configuration the training cell uses.

    Args:
        savedir: directory holding checkpoints of the form
            ``{"model_state_dict": ...}`` (str or path-like).

    Returns:
        list of models ordered by checkpoint index (model0, model1, ..., model10, ...).
        Empty if the directory contains no ``.pth`` files.
    """
    # Ignore unrelated files (notes, hidden OS files, ...) that torch.load would choke on.
    checkpoints = [f for f in os.listdir(savedir) if f.endswith(".pth")]
    # os.listdir order is arbitrary; sorting by (length, name) yields numeric order for
    # "model{i}.pth" names, so model2.pth comes before model10.pth.
    checkpoints.sort(key=lambda name: (len(name), name))

    models = []
    for f in checkpoints:
        model = BayesianMnistNet(p_mc_dropout=None)
        # map_location="cpu" lets checkpoints saved from a GPU load on a CPU-only machine.
        state = torch.load(os.path.abspath(os.path.join(savedir, f)), map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
        models.append(model)

    return models

Train or load models

In [6]:
models = []
if no_train:
    # Load a previously trained ensemble instead of training a new one.
    models = loadModels(save_dir)
    if not models:
        raise FileNotFoundError(
            "no_train is True but no saved models were found in {}".format(save_dir))
    if len(models) != num_networks:
        print("Loaded {} models but num_networks is {}; using the loaded count.".format(
            len(models), num_networks))
        # Keep num_networks consistent with the ensemble that was actually loaded, so code
        # that samples an ensemble member by index stays within range.
        num_networks = len(models)
else:
    # Progress line; the step counter is zero-padded to the width of batchLen.
    progress = ("\tEpoch {}/{}: Train step {:0" + str(digitsBatchLen) + "d}/{}"
                " prob = {:.4f} model = {:.4f} loss = {:.4f}          ")

    # Train each ensemble member independently.
    for i in range(num_networks):
        print("Training model {}/{}:".format(i+1, num_networks))

        # p_mc_dropout=None will disable MC-Dropout for this bnn, as we found out it makes learning much much slower.
        model = BayesianMnistNet(p_mc_dropout=None)
        # Negative log likelihood will be part of the ELBO.
        loss = torch.nn.NLLLoss(reduction='mean')
        optimizer = Adam(model.parameters(), lr=learning_rate)

        for n in range(n_epochs):
            for batch_id, (images, labels) in enumerate(train_loader):
                pred = model(images, stochastic=True)

                # NLLLoss gives a per-sample mean; multiplying by the training-set size N
                # approximates the NLL summed over the whole dataset.
                logprob = loss(pred, labels)
                l = N*logprob

                # Add the model's own loss term (presumably the KL part of the ELBO —
                # confirm in BayesianMnist.viModel).
                modelloss = model.evalAllLosses()
                l += modelloss

                optimizer.zero_grad()
                l.backward()
                optimizer.step()

                # exp(-mean NLL) is the geometric-mean probability of the true label.
                print("\r", progress.format(
                    n+1, n_epochs,
                    batch_id+1, batchLen,
                    torch.exp(-logprob.detach().cpu()).item(),
                    modelloss.detach().cpu().item(),
                    l.detach().cpu().item()), end="")
        print("")

        models.append(model)

    if save_dir is not None:
        saveModels(models, save_dir)

# Testing model

### Testing against a seen class (`test_class`, which was included in training):

In [None]:
# Evaluate the ensemble on `test_class`, a class the models were trained on.
if test_class != filtered_class:

    train_filtered_seen, test_filtered_seen = getSets(filteredClass=test_class, removeFiltered = False)

    with torch.no_grad():

        # Load the whole seen-class test split as a single batch (local loader name, so the
        # test_loader defined in the data cell is not overwritten).
        seen_loader = DataLoader(test_filtered_seen, batch_size=len(test_filtered_seen))
        images, labels = next(iter(seen_loader))

        # samples[r, s, c] = probability of class c for image s in test run r.
        samples = torch.zeros((n_runtests, len(test_filtered_seen), 10))

        for i in range(n_runtests):
            print("\r", "\tTest run {}/{}".format(i+1, n_runtests), end="")
            # Pick a random ensemble member from the models actually available
            # (num_networks can differ from len(models) when models were loaded from disk).
            model = models[np.random.randint(len(models))]

            # Outputs are exponentiated into probabilities (training fed them to NLLLoss).
            # NOTE(review): training calls model(..., stochastic=True); here the default is
            # used — confirm this is the intended Monte-Carlo sampling behaviour.
            samples[i,:,:] = torch.exp(model(images))

        print("")

        # Per-image mean over runs, shape (n_images, 10).
        withinSampleMean = torch.mean(samples, dim=0)
        # Mean over runs and images, shape (10,).
        samplesMean = torch.mean(samples, dim=(0,1))

        # Spread across runs for the same image: per-image variance over runs, averaged over images.
        withinSampleStd = torch.sqrt(torch.mean(torch.var(samples, dim=0), dim=0))
        # Spread of the per-image mean probabilities across images.
        acrossSamplesStd = torch.std(withinSampleMean, dim=0)

        print("")
        print("Class prediction analysis:")
        print("\tMean class probabilities:")
        print(samplesMean)
        print("\tPrediction standard deviation per sample:")
        print(withinSampleStd)
        print("\tPrediction standard deviation across samples:")
        print(acrossSamplesStd)

        plt.figure("Seen class probabilities")
        plt.bar(np.arange(10), samplesMean.numpy())
        plt.xlabel('digits')
        plt.ylabel('digit prob')
        plt.ylim([0,1])
        plt.xticks(np.arange(10))

        plt.figure("Seen inner and outter sample std")
        plt.bar(np.arange(10)-0.2, withinSampleStd.numpy(), width = 0.4, label="Within sample")
        plt.bar(np.arange(10)+0.2, acrossSamplesStd.numpy(), width = 0.4, label="Across samples")
        plt.legend()
        plt.xlabel('digits')
        plt.ylabel('std digit prob')
        plt.xticks(np.arange(10))

plt.show()

 	Test run 48/50

### Testing against the unseen class (`filtered_class`, which was excluded from training):

In [None]:
# Evaluate the ensemble on `test_filtered`, the split built in the data cell for
# `filtered_class` (the class that was ignored during training).
with torch.no_grad():

    # Load the whole split as a single batch (local loader name, so the test_loader defined
    # in the data cell is not overwritten).
    unseen_loader = DataLoader(test_filtered, batch_size=len(test_filtered))
    images, labels = next(iter(unseen_loader))

    # samples[r, s, c] = probability of class c for image s in test run r.
    samples = torch.zeros((n_runtests, len(test_filtered), 10))

    for i in range(n_runtests):
        print("\r", "\tTest run {}/{}".format(i+1, n_runtests), end="")
        # Pick a random ensemble member from the models actually available
        # (num_networks can differ from len(models) when models were loaded from disk).
        model = models[np.random.randint(len(models))]

        # Outputs are exponentiated into probabilities (training fed them to NLLLoss).
        # NOTE(review): training calls model(..., stochastic=True); here the default is
        # used — confirm this is the intended Monte-Carlo sampling behaviour.
        samples[i,:,:] = torch.exp(model(images))

    print("")

    # Per-image mean over runs, shape (n_images, 10).
    withinSampleMean = torch.mean(samples, dim=0)
    # Mean over runs and images, shape (10,).
    samplesMean = torch.mean(samples, dim=(0, 1))

    # Spread across runs for the same image: per-image variance over runs, averaged over images.
    withinSampleStd = torch.sqrt(torch.mean(torch.var(samples, dim=0), dim=0))
    # Spread of the per-image mean probabilities across images.
    acrossSamplesStd = torch.std(withinSampleMean, dim=0)

    print("")
    print("Class prediction analysis:")
    print("\tMean class probabilities:")
    print(samplesMean)
    print("\tPrediction standard deviation per sample:")
    print(withinSampleStd)
    print("\tPrediction standard deviation across samples:")
    print(acrossSamplesStd)

    plt.figure("Unseen class probabilities")
    plt.bar(np.arange(10), samplesMean.numpy())
    plt.xlabel('digits')
    plt.ylabel('digit prob')
    plt.ylim([0,1])
    plt.xticks(np.arange(10))

    plt.figure("Unseen inner and outter sample std")
    plt.bar(np.arange(10)-0.2, withinSampleStd.numpy(), width=0.4, label="Within sample")
    plt.bar(np.arange(10)+0.2, acrossSamplesStd.numpy(), width=0.4, label="Across samples")
    plt.legend()
    plt.xlabel('digits')
    plt.ylabel('std digit prob')
    plt.xticks(np.arange(10))
plt.show()

### Testing against pure white noise:

In [None]:
# Feed images of pure uniform noise (no digit at all) through the ensemble to see how
# confident it is on inputs unlike anything in the training data.
with torch.no_grad():

    # Number of noise images to generate.
    n_noise_images = 1000
    # Pixel values drawn uniformly from [0, 1), shaped like single-channel 28x28 MNIST images.
    noise_images = torch.rand((n_noise_images, 1, 28, 28))

    # samples[r, s, c] = probability of class c for noise image s in test run r.
    samples = torch.zeros((n_runtests, n_noise_images, 10))

    for i in range(n_runtests):
        print("\r", "\tTest run {}/{}".format(i+1, n_runtests), end="")
        # Pick a random ensemble member from the models actually available
        # (num_networks can differ from len(models) when models were loaded from disk).
        model = models[np.random.randint(len(models))]

        # Outputs are exponentiated into probabilities (training fed them to NLLLoss).
        # NOTE(review): training calls model(..., stochastic=True); here the default is
        # used — confirm this is the intended Monte-Carlo sampling behaviour.
        samples[i, :, :] = torch.exp(model(noise_images))

    print("")

    # Per-image mean over runs, shape (n_noise_images, 10).
    withinSampleMean = torch.mean(samples, dim=0)
    # Mean over runs and images, shape (10,).
    samplesMean = torch.mean(samples, dim=(0, 1))

    # Spread across runs for the same image: per-image variance over runs, averaged over images.
    withinSampleStd = torch.sqrt(torch.mean(torch.var(samples, dim=0), dim=0))
    # Spread of the per-image mean probabilities across images.
    acrossSamplesStd = torch.std(withinSampleMean, dim=0)

    print("")
    print("Class prediction analysis:")
    print("\tMean class probabilities:")
    print(samplesMean)
    print("\tPrediction standard deviation per sample:")
    print(withinSampleStd)
    print("\tPrediction standard deviation across samples:")
    print(acrossSamplesStd)

    plt.figure("White noise class probabilities")
    plt.bar(np.arange(10), samplesMean.numpy())
    plt.xlabel('digits')
    plt.ylabel('digit prob')
    plt.ylim([0,1])
    plt.xticks(np.arange(10))

    plt.figure("White noise inner and outter sample std")
    plt.bar(np.arange(10)-0.2, withinSampleStd.numpy(), width = 0.4, label="Within sample")
    plt.bar(np.arange(10)+0.2, acrossSamplesStd.numpy(), width = 0.4, label="Across samples")
    plt.legend()
    plt.xlabel('digits')
    plt.ylabel('std digit prob')
    plt.xticks(np.arange(10))

plt.show()