In [9]:
# Google Colab-only setup. No need to run this cell in other environments.
use_colab = False

if use_colab:
    # Mount my Google Drive root folder
    from google.colab import drive
    drive.mount('/content/drive')

    # cd to bayesian-dl-experiments directory
    %cd 'drive/My Drive/Colab Notebooks/bayesian-dl-experiments'
    !ls

## Experiment Setup

In [10]:
import torch
import numpy as np

# IPython reloading magic
%load_ext autoreload
%autoreload 2

# Random seeds
# Based on https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(682)
np.random.seed(682)

# torch.device / CUDA Setup
use_cuda = True

if use_cuda and torch.cuda.is_available():
    torch_device = torch.device('cuda')
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
else:
    torch_device = torch.device('cpu')

# Dataset to use
dataset_name = 'yacht'

# Training set size
dataset_train_size = 0.8

# L2 regularization strength
reg_strength = 0.01

# Epochs
n_epochs = 1000

# Number of different data splits to try
n_splits = 20

# Data batch sizes
n_training_batch = 10

# Number of test predictions (for each data point)
n_predictions = 10000

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Prepare data

### Get the data as a torch Dataset object

In [11]:
from torch.utils.data import random_split, DataLoader
from ronald_bdl import datasets

if dataset_name == 'MNIST':
    dataset = datasets.MNIST(root_dir='./datasets_files', transform=None, download=True)
else:
    dataset = datasets.UCIDatasets(dataset_name, root_dir='./datasets_files', transform=None, download=True)

train_size = int(dataset_train_size * len(dataset))
test_size = len(dataset) - train_size
    
# Print the size of the dataset
print("dataset size = " + str(dataset.data.shape))
print("training set size = " + str(train_size))
print("testing set size = " + str(test_size))

Using downloaded and verified file: ./datasets_files/yacht/yacht_hydrodynamics.data
dataset size = torch.Size([308, 7])


## Define network

In [13]:
from ronald_bdl import models

network = models.FCNetMCDropout(
    input_dim=len(dataset.features), 
    output_dim=len(dataset.targets),
    hidden_dim=100,
    n_hidden=2,
    dropout_rate=0.01,
)

# Send the whole model to the selected torch.device
network.to(torch_device)

# Print the network structure
print(network)

FCNetMCDropout(
  (input): Linear(in_features=6, out_features=100, bias=True)
  (hidden_layers): ModuleList(
    (0): Linear(in_features=100, out_features=100, bias=True)
    (1): Linear(in_features=100, out_features=100, bias=True)
  )
  (output): Linear(in_features=100, out_features=1, bias=True)
)


## Train the network

### Setup optimizer

In [14]:
from torch import nn, optim

# Model to train mode
network.train()

# Mean Squared Error for loss function to minimize
objective = nn.MSELoss()

# Adam optimizer
# https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam
# NOTE: Need to set L2 regularization from here
optimizer = optim.Adam(
    network.parameters(),
    lr=0.01, 
    weight_decay=reg_strength # L2 regularization
)

### Train the model

In [15]:
# Partially adapted from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
for s in range(n_splits):

    # Prepare new train-test split
    train, test = random_split(dataset, lengths=[train_size, test_size])
    train_loader = DataLoader(train, batch_size=n_training_batch)
    test_loader = DataLoader(test, batch_size=test_size)
    
    for epoch in range(n_epochs): # loop over the dataset multiple times

        print("Running epoch " + str(epoch))

        for i, data in enumerate(train_loader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, targets = data

            # Store the batch to torch_device's memory
            inputs = inputs.to(torch_device)
            targets = targets.to(torch_device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = network(inputs)

            loss = objective(outputs, targets)
            loss.backward()

            optimizer.step()

Running epoch 0
Running epoch 1
Running epoch 2
Running epoch 3
Running epoch 4
Running epoch 5
Running epoch 6
Running epoch 7
Running epoch 8
Running epoch 9
Running epoch 10
Running epoch 11
Running epoch 12
Running epoch 13
Running epoch 14
Running epoch 15
Running epoch 16
Running epoch 17
Running epoch 18
Running epoch 19
Running epoch 20
Running epoch 21
Running epoch 22
Running epoch 23
Running epoch 24
Running epoch 25
Running epoch 26
Running epoch 27
Running epoch 28
Running epoch 29
Running epoch 30
Running epoch 31
Running epoch 32
Running epoch 33
Running epoch 34
Running epoch 35
Running epoch 36
Running epoch 37
Running epoch 38
Running epoch 39
Running epoch 40
Running epoch 41
Running epoch 42
Running epoch 43
Running epoch 44
Running epoch 45
Running epoch 46
Running epoch 47
Running epoch 48
Running epoch 49
Running epoch 50
Running epoch 51
Running epoch 52
Running epoch 53
Running epoch 54
Running epoch 55
Running epoch 56
Running epoch 57
Running epoch 58
Running

Running epoch 462
Running epoch 463
Running epoch 464
Running epoch 465
Running epoch 466
Running epoch 467
Running epoch 468
Running epoch 469
Running epoch 470
Running epoch 471
Running epoch 472
Running epoch 473
Running epoch 474
Running epoch 475
Running epoch 476
Running epoch 477
Running epoch 478
Running epoch 479
Running epoch 480
Running epoch 481
Running epoch 482
Running epoch 483
Running epoch 484
Running epoch 485
Running epoch 486
Running epoch 487
Running epoch 488
Running epoch 489
Running epoch 490
Running epoch 491
Running epoch 492
Running epoch 493
Running epoch 494
Running epoch 495
Running epoch 496
Running epoch 497
Running epoch 498
Running epoch 499
Running epoch 500
Running epoch 501
Running epoch 502
Running epoch 503
Running epoch 504
Running epoch 505
Running epoch 506
Running epoch 507
Running epoch 508
Running epoch 509
Running epoch 510
Running epoch 511
Running epoch 512
Running epoch 513
Running epoch 514
Running epoch 515
Running epoch 516
Running ep

Running epoch 918
Running epoch 919
Running epoch 920
Running epoch 921
Running epoch 922
Running epoch 923
Running epoch 924
Running epoch 925
Running epoch 926
Running epoch 927
Running epoch 928
Running epoch 929
Running epoch 930
Running epoch 931
Running epoch 932
Running epoch 933
Running epoch 934
Running epoch 935
Running epoch 936
Running epoch 937
Running epoch 938
Running epoch 939
Running epoch 940
Running epoch 941
Running epoch 942
Running epoch 943
Running epoch 944
Running epoch 945
Running epoch 946
Running epoch 947
Running epoch 948
Running epoch 949
Running epoch 950
Running epoch 951
Running epoch 952
Running epoch 953
Running epoch 954
Running epoch 955
Running epoch 956
Running epoch 957
Running epoch 958
Running epoch 959
Running epoch 960
Running epoch 961
Running epoch 962
Running epoch 963
Running epoch 964
Running epoch 965
Running epoch 966
Running epoch 967
Running epoch 968
Running epoch 969
Running epoch 970
Running epoch 971
Running epoch 972
Running ep

## Make predictions

In [16]:
# Model to eval mode
network.eval()

for i, data in enumerate(test_loader):
    # Get the test data
    inputs, targets = data

    # Store the batch to torch_device's memory
    inputs = inputs.to(torch_device)
    targets = targets.to(torch_device)
    
    predictions, mean, var, metrics = network.mc_predict(inputs, n_predictions,
                                                         y_test=targets, reg_strength=reg_strength)

    print("Test " + str(i))
    print("Mean = " + str(mean))
    print("Variance = " + str(var))
    
    # Print additional metrics
    if len(metrics) > 0:
        for key, value in metrics.items():
            print(str(key) + " = " + str(value))
    print()

Test 0
Mean = tensor([[ 4.1532],
        [ 1.3944],
        [ 3.4415],
        [20.9207],
        [ 2.0312],
        [ 0.7376],
        [ 6.6230],
        [41.4376],
        [ 3.1191],
        [16.9545],
        [ 2.7652],
        [15.6248],
        [ 0.6775],
        [ 2.1834],
        [ 0.6753],
        [ 0.6622],
        [13.1642],
        [23.0704],
        [ 0.7039],
        [ 5.9353],
        [ 1.1530],
        [ 4.4447],
        [37.2074],
        [49.2215],
        [ 1.3395],
        [ 2.3190],
        [ 0.6812],
        [ 0.6558],
        [15.5937],
        [ 2.0959],
        [28.7827],
        [13.6086],
        [ 0.6744],
        [ 3.1608],
        [ 2.3843],
        [26.1005],
        [ 4.4633],
        [ 1.0374],
        [ 0.6937],
        [18.3128],
        [ 3.6354],
        [ 0.6607],
        [ 1.8770],
        [ 7.4597],
        [ 3.3278],
        [33.1474],
        [ 2.2746],
        [ 0.6999],
        [ 0.6636],
        [23.0520],
        [ 3.6975],
        [25.9697]