In [1]:
# Google Colab-only setup. No need to run this cell in other environments.
use_colab = False

if use_colab:
    # Mount my Google Drive root folder
    from google.colab import drive
    drive.mount('/content/drive')

    # cd to bayesian-dl-experiments directory
    %cd 'drive/My Drive/Colab Notebooks/bayesian-dl-experiments'
    !ls

## Experiment Setup

In [2]:
import torch
import numpy as np

# IPython reloading magic
%load_ext autoreload
%autoreload 2

# Random seeds
# Based on https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(682)
np.random.seed(682)

# torch.device / CUDA Setup
use_cuda = True

if use_cuda and torch.cuda.is_available():
    torch_device = torch.device('cuda')
    torch.backends.cudnn.deterministic = True
    # Note: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
    torch.backends.cudnn.benchmark = False
    use_pin_memory=True # Faster Host to GPU copies with page-locked memory
else:
    torch_device = torch.device('cpu')
    use_pin_memory=False

# Dataset to use
dataset_name = 'yacht'

# Training set size
dataset_train_size = 0.8

# L2 regularization strength
reg_strength = 0.01

# Epochs
n_epochs = 100

# Number of different data splits to try
n_splits = 20

# Data batch sizes
n_training_batch = 10

# Number of test predictions (for each data point)
n_predictions = 10000

## Prepare data

### Get the data as a torch Dataset object

In [3]:
from torch.utils.data import random_split, DataLoader
from ronald_bdl import datasets

if dataset_name == 'MNIST':
    dataset = datasets.MNIST(root_dir='./datasets_files', transform=None, download=True)
else:
    dataset = datasets.UCIDatasets(dataset_name, root_dir='./datasets_files', transform=None, download=True)

# Set the training/test set sizes
train_size = int(dataset_train_size * len(dataset))
test_size = len(dataset) - train_size
    
# Print the size of the dataset
print("dataset size = " + str(dataset.data.shape))
print("training set size = " + str(train_size))
print("testing set size = " + str(test_size))

Using downloaded and verified file: ./datasets_files/yacht/yacht_hydrodynamics.data
dataset size = torch.Size([308, 7])
training set size = 246
testing set size = 62


## Define network

In [4]:
from ronald_bdl import models

network = models.FCNetMCDropout(
    input_dim=len(dataset.features), 
    output_dim=len(dataset.targets),
    hidden_dim=100,
    n_hidden=2,
    dropout_rate=0.01,
)

# Send the whole model to the selected torch.device
network.to(torch_device)

# Print the network structure
print(network)

FCNetMCDropout(
  (input): Linear(in_features=6, out_features=100, bias=True)
  (hidden_layers): ModuleList(
    (0): Linear(in_features=100, out_features=100, bias=True)
    (1): Linear(in_features=100, out_features=100, bias=True)
  )
  (output): Linear(in_features=100, out_features=1, bias=True)
)


## Train the network

### Setup

In [5]:
from torch import nn, optim

# Model to train mode
network.train()

# Mean Squared Error for loss function to minimize
objective = nn.MSELoss()

rmse_non_mc, rmse_mc, test_lls_mc = [], [], []

### Train/test the model

In [6]:
import time

for s in range(n_splits):

    # Prepare new train-test split
    train, test = random_split(dataset, lengths=[train_size, test_size])
    train_loader = DataLoader(train, batch_size=n_training_batch, pin_memory=use_pin_memory)
    if use_pin_memory: test.dataset.data.pin_memory()
    
    # Adam optimizer
    # https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam
    # NOTE: Need to set L2 regularization from here
    optimizer = optim.Adam(
        network.parameters(),
        lr=0.01,
        weight_decay=reg_strength, # L2 regularization
    )
    
    """
    Training
    """
    # Record training start time (for this split)
    tic = time.time()

    for epoch in range(n_epochs): # loop over the dataset multiple times

        for i, data in enumerate(train_loader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, targets = data

            # Store the batch to torch_device's memory
            inputs = inputs.to(torch_device)
            targets = targets.to(torch_device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = network(inputs)

            loss = objective(outputs, targets)
            loss.backward()

            optimizer.step()
            
        print("Running split " + str(s) + ", epoch " + str(epoch) + " loss = " + str(loss.item()))
            
    # Record training end time
    toc = time.time()
            
    """
    Testing
    """
    # Model to eval mode
    network.eval()

    # Get the test data
    inputs, targets = test.dataset[test.indices]

    # Store the batch to torch_device's memory
    inputs = inputs.to(torch_device)
    targets = targets.to(torch_device)

    # Record testing start time (for this split)
    tic_testing = time.time()    
    
    predictions, mean, var, metrics = network.mc_predict(inputs, n_predictions,
                                                         y_test=targets, reg_strength=reg_strength)

    # Record testing end time
    toc_testing = time.time()    
    
    """
    Print results
    """
    print()
    print("Running split " + str(s) + " test:")
    print("Mean = " + str(mean))
    print("Variance = " + str(var))

    # Print and store additional metrics
    if len(metrics) > 0:
        for key, value in metrics.items():
            print(str(key) + " = " + str(value))

            if key == 'rmse_mc': rmse_mc.append(value.item())
            elif key == 'rmse_non_mc': rmse_non_mc.append(value.item())
            elif key == 'test_ll_mc': test_lls_mc.append(value.item())

    # Report the total training time
    print("Split " + str(s) + " training time = " + str(toc - tic) + " seconds")
    
    # Report the total testing time
    print("Split " + str(s) + " testing time = " + str(toc_testing - tic_testing) + " seconds")
    print()

Running split 0, epoch 0 loss = 131.8843231201172
Running split 0, epoch 1 loss = 89.19522094726562
Running split 0, epoch 2 loss = 99.73192596435547
Running split 0, epoch 3 loss = 89.38324737548828
Running split 0, epoch 4 loss = 88.40301513671875
Running split 0, epoch 5 loss = 73.8027572631836
Running split 0, epoch 6 loss = 135.23023986816406
Running split 0, epoch 7 loss = 173.71925354003906
Running split 0, epoch 8 loss = 151.96214294433594
Running split 0, epoch 9 loss = 113.37532806396484
Running split 0, epoch 10 loss = 101.96424102783203
Running split 0, epoch 11 loss = 66.37215423583984
Running split 0, epoch 12 loss = 9.763032913208008
Running split 0, epoch 13 loss = 1.0689550638198853
Running split 0, epoch 14 loss = 0.1310792714357376
Running split 0, epoch 15 loss = 0.1133943721652031
Running split 0, epoch 16 loss = 0.14185978472232819
Running split 0, epoch 17 loss = 0.13907046616077423
Running split 0, epoch 18 loss = 0.1157565712928772
Running split 0, epoch 19 los

Running split 1, epoch 9 loss = 4.853234767913818
Running split 1, epoch 10 loss = 16.769906997680664
Running split 1, epoch 11 loss = 22.50249671936035
Running split 1, epoch 12 loss = 10.962104797363281
Running split 1, epoch 13 loss = 8.828307151794434
Running split 1, epoch 14 loss = 11.922263145446777
Running split 1, epoch 15 loss = 43.97943115234375
Running split 1, epoch 16 loss = 18.17267608642578
Running split 1, epoch 17 loss = 35.06484603881836
Running split 1, epoch 18 loss = 12.654120445251465
Running split 1, epoch 19 loss = 18.465015411376953
Running split 1, epoch 20 loss = 15.636603355407715
Running split 1, epoch 21 loss = 4.646087169647217
Running split 1, epoch 22 loss = 7.359231948852539
Running split 1, epoch 23 loss = 11.338569641113281
Running split 1, epoch 24 loss = 1.0820894241333008
Running split 1, epoch 25 loss = 7.042155742645264
Running split 1, epoch 26 loss = 3.7534139156341553
Running split 1, epoch 27 loss = 1.5935072898864746
Running split 1, epoch

Running split 2, epoch 15 loss = 0.32657507061958313
Running split 2, epoch 16 loss = 1.3118358850479126
Running split 2, epoch 17 loss = 0.12443319708108902
Running split 2, epoch 18 loss = 0.9745640158653259
Running split 2, epoch 19 loss = 0.187264546751976
Running split 2, epoch 20 loss = 0.2190103679895401
Running split 2, epoch 21 loss = 0.20916308462619781
Running split 2, epoch 22 loss = 0.2605677843093872
Running split 2, epoch 23 loss = 0.37369200587272644
Running split 2, epoch 24 loss = 0.12422209978103638
Running split 2, epoch 25 loss = 0.4467083215713501
Running split 2, epoch 26 loss = 0.07628318667411804
Running split 2, epoch 27 loss = 0.6445109844207764
Running split 2, epoch 28 loss = 0.6458765864372253
Running split 2, epoch 29 loss = 0.21458612382411957
Running split 2, epoch 30 loss = 0.6830366253852844
Running split 2, epoch 31 loss = 1.6081757545471191
Running split 2, epoch 32 loss = 4.094783782958984
Running split 2, epoch 33 loss = 1.3888546228408813
Running

Running split 3, epoch 23 loss = 2.6897144317626953
Running split 3, epoch 24 loss = 1.7593235969543457
Running split 3, epoch 25 loss = 0.947662889957428
Running split 3, epoch 26 loss = 4.068247318267822
Running split 3, epoch 27 loss = 0.7063255906105042
Running split 3, epoch 28 loss = 2.7710227966308594
Running split 3, epoch 29 loss = 0.9975423812866211
Running split 3, epoch 30 loss = 1.5673484802246094
Running split 3, epoch 31 loss = 0.7498632073402405
Running split 3, epoch 32 loss = 1.4851771593093872
Running split 3, epoch 33 loss = 2.069140672683716
Running split 3, epoch 34 loss = 1.1711820363998413
Running split 3, epoch 35 loss = 2.039498805999756
Running split 3, epoch 36 loss = 1.3920053243637085
Running split 3, epoch 37 loss = 0.5885021090507507
Running split 3, epoch 38 loss = 0.5349515080451965
Running split 3, epoch 39 loss = 0.25647079944610596
Running split 3, epoch 40 loss = 0.5478507876396179
Running split 3, epoch 41 loss = 1.1050400733947754
Running split 3

Running split 4, epoch 33 loss = 1.9289473295211792
Running split 4, epoch 34 loss = 2.409196615219116
Running split 4, epoch 35 loss = 11.289650917053223
Running split 4, epoch 36 loss = 0.45904600620269775
Running split 4, epoch 37 loss = 0.29191285371780396
Running split 4, epoch 38 loss = 0.87290358543396
Running split 4, epoch 39 loss = 1.2896803617477417
Running split 4, epoch 40 loss = 1.38010835647583
Running split 4, epoch 41 loss = 2.407038927078247
Running split 4, epoch 42 loss = 0.9851943850517273
Running split 4, epoch 43 loss = 0.39238592982292175
Running split 4, epoch 44 loss = 0.8100239634513855
Running split 4, epoch 45 loss = 0.2733204662799835
Running split 4, epoch 46 loss = 0.5247034430503845
Running split 4, epoch 47 loss = 0.9500615000724792
Running split 4, epoch 48 loss = 0.7680804133415222
Running split 4, epoch 49 loss = 0.4486238956451416
Running split 4, epoch 50 loss = 0.283700168132782
Running split 4, epoch 51 loss = 0.3565245568752289
Running split 4,

Running split 5, epoch 40 loss = 2.6255509853363037
Running split 5, epoch 41 loss = 5.948644161224365
Running split 5, epoch 42 loss = 2.556793212890625
Running split 5, epoch 43 loss = 4.5439019203186035
Running split 5, epoch 44 loss = 1.8991647958755493
Running split 5, epoch 45 loss = 0.16278819739818573
Running split 5, epoch 46 loss = 1.1770954132080078
Running split 5, epoch 47 loss = 3.384632110595703
Running split 5, epoch 48 loss = 4.94911527633667
Running split 5, epoch 49 loss = 0.8916200995445251
Running split 5, epoch 50 loss = 1.5980359315872192
Running split 5, epoch 51 loss = 3.6225109100341797
Running split 5, epoch 52 loss = 2.4331893920898438
Running split 5, epoch 53 loss = 1.342391848564148
Running split 5, epoch 54 loss = 3.624521255493164
Running split 5, epoch 55 loss = 2.1207869052886963
Running split 5, epoch 56 loss = 3.8786280155181885
Running split 5, epoch 57 loss = 0.9673458933830261
Running split 5, epoch 58 loss = 0.9096701741218567
Running split 5, e

Running split 6, epoch 51 loss = 0.5767069458961487
Running split 6, epoch 52 loss = 0.36234283447265625
Running split 6, epoch 53 loss = 2.4571831226348877
Running split 6, epoch 54 loss = 1.142189621925354
Running split 6, epoch 55 loss = 1.3431663513183594
Running split 6, epoch 56 loss = 1.1192008256912231
Running split 6, epoch 57 loss = 1.1547783613204956
Running split 6, epoch 58 loss = 2.141148328781128
Running split 6, epoch 59 loss = 1.3485206365585327
Running split 6, epoch 60 loss = 1.5401123762130737
Running split 6, epoch 61 loss = 3.5034573078155518
Running split 6, epoch 62 loss = 0.5510407090187073
Running split 6, epoch 63 loss = 27.259241104125977
Running split 6, epoch 64 loss = 1.3350305557250977
Running split 6, epoch 65 loss = 2.3658828735351562
Running split 6, epoch 66 loss = 2.500030994415283
Running split 6, epoch 67 loss = 3.1018295288085938
Running split 6, epoch 68 loss = 1.647171139717102
Running split 6, epoch 69 loss = 0.3038220703601837
Running split 6

Running split 7, epoch 57 loss = 1.2476394176483154
Running split 7, epoch 58 loss = 0.5612863898277283
Running split 7, epoch 59 loss = 0.28842735290527344
Running split 7, epoch 60 loss = 24.318313598632812
Running split 7, epoch 61 loss = 0.09546119719743729
Running split 7, epoch 62 loss = 1.2212556600570679
Running split 7, epoch 63 loss = 4.795182704925537
Running split 7, epoch 64 loss = 0.4928753077983856
Running split 7, epoch 65 loss = 4.677007675170898
Running split 7, epoch 66 loss = 1.8970271348953247
Running split 7, epoch 67 loss = 5.482811450958252
Running split 7, epoch 68 loss = 2.9855756759643555
Running split 7, epoch 69 loss = 1.2170721292495728
Running split 7, epoch 70 loss = 4.8428754806518555
Running split 7, epoch 71 loss = 0.312299907207489
Running split 7, epoch 72 loss = 2.4874961376190186
Running split 7, epoch 73 loss = 2.07296085357666
Running split 7, epoch 74 loss = 0.7120232582092285
Running split 7, epoch 75 loss = 5.665245532989502
Running split 7, 

Running split 8, epoch 67 loss = 1.00804603099823
Running split 8, epoch 68 loss = 3.7119500637054443
Running split 8, epoch 69 loss = 0.749531090259552
Running split 8, epoch 70 loss = 2.4320733547210693
Running split 8, epoch 71 loss = 0.16709059476852417
Running split 8, epoch 72 loss = 0.05595223978161812
Running split 8, epoch 73 loss = 0.1269858032464981
Running split 8, epoch 74 loss = 0.21929316222667694
Running split 8, epoch 75 loss = 0.5956049561500549
Running split 8, epoch 76 loss = 0.08339807391166687
Running split 8, epoch 77 loss = 2.404799699783325
Running split 8, epoch 78 loss = 0.3406590521335602
Running split 8, epoch 79 loss = 0.22887204587459564
Running split 8, epoch 80 loss = 3.5637366771698
Running split 8, epoch 81 loss = 1.9883605241775513
Running split 8, epoch 82 loss = 1.9253588914871216
Running split 8, epoch 83 loss = 0.44403567910194397
Running split 8, epoch 84 loss = 0.09428847581148148
Running split 8, epoch 85 loss = 0.6140336394309998
Running spli

Running split 9, epoch 71 loss = 0.5826890468597412
Running split 9, epoch 72 loss = 1.1927658319473267
Running split 9, epoch 73 loss = 0.10343566536903381
Running split 9, epoch 74 loss = 0.2814778983592987
Running split 9, epoch 75 loss = 4.298001766204834
Running split 9, epoch 76 loss = 1.3509103059768677
Running split 9, epoch 77 loss = 0.9263637661933899
Running split 9, epoch 78 loss = 2.541912078857422
Running split 9, epoch 79 loss = 4.082550525665283
Running split 9, epoch 80 loss = 0.6527473330497742
Running split 9, epoch 81 loss = 11.617966651916504
Running split 9, epoch 82 loss = 7.356835842132568
Running split 9, epoch 83 loss = 2.433967113494873
Running split 9, epoch 84 loss = 0.5393974184989929
Running split 9, epoch 85 loss = 0.13597151637077332
Running split 9, epoch 86 loss = 0.6289370059967041
Running split 9, epoch 87 loss = 6.896598815917969
Running split 9, epoch 88 loss = 2.6597397327423096
Running split 9, epoch 89 loss = 2.231450080871582
Running split 9, 

Running split 10, epoch 81 loss = 88.97583770751953
Running split 10, epoch 82 loss = 69.1149673461914
Running split 10, epoch 83 loss = 118.00344848632812
Running split 10, epoch 84 loss = 100.61831665039062
Running split 10, epoch 85 loss = 64.38683319091797
Running split 10, epoch 86 loss = 37.93082046508789
Running split 10, epoch 87 loss = 21.363149642944336
Running split 10, epoch 88 loss = 19.0534610748291
Running split 10, epoch 89 loss = 0.37984704971313477
Running split 10, epoch 90 loss = 5.572178363800049
Running split 10, epoch 91 loss = 10.812207221984863
Running split 10, epoch 92 loss = 19.980756759643555
Running split 10, epoch 93 loss = 2.8748903274536133
Running split 10, epoch 94 loss = 8.24890422821045
Running split 10, epoch 95 loss = 18.617177963256836
Running split 10, epoch 96 loss = 7.501241683959961
Running split 10, epoch 97 loss = 5.176393508911133
Running split 10, epoch 98 loss = 16.269723892211914
Running split 10, epoch 99 loss = 5.007132053375244

Runn

Running split 11, epoch 91 loss = 0.2999423146247864
Running split 11, epoch 92 loss = 4.668097496032715
Running split 11, epoch 93 loss = 0.5842078328132629
Running split 11, epoch 94 loss = 3.596137762069702
Running split 11, epoch 95 loss = 1.0837011337280273
Running split 11, epoch 96 loss = 0.5188406109809875
Running split 11, epoch 97 loss = 2.09665584564209
Running split 11, epoch 98 loss = 1.8422412872314453
Running split 11, epoch 99 loss = 4.1399006843566895

Running split 11 test:
Mean = tensor([[ 3.3248e+00],
        [ 1.5045e+00],
        [ 7.3881e+00],
        [ 6.5019e+00],
        [ 1.4892e+01],
        [ 2.0229e+01],
        [ 4.3689e+01],
        [ 5.7790e+00],
        [ 3.2422e-01],
        [ 9.7793e-01],
        [ 9.0391e-01],
        [ 1.5838e+00],
        [ 3.5898e+00],
        [ 1.3556e+01],
        [ 2.1979e-01],
        [ 1.6417e+00],
        [ 2.2969e+01],
        [ 1.1303e+01],
        [ 2.9407e-01],
        [ 1.5614e+00],
        [ 1.1363e+00],
        [ 7.3

Running split 12, epoch 95 loss = 19.94911003112793
Running split 12, epoch 96 loss = 5.828816890716553
Running split 12, epoch 97 loss = 6.3689656257629395
Running split 12, epoch 98 loss = 13.430874824523926
Running split 12, epoch 99 loss = 26.782094955444336

Running split 12 test:
Mean = tensor([[ 5.1402],
        [12.5148],
        [ 4.6167],
        [12.2608],
        [ 0.4981],
        [ 3.7342],
        [ 2.0326],
        [ 4.4457],
        [ 7.7990],
        [ 1.2473],
        [ 7.5267],
        [ 2.9714],
        [23.8658],
        [ 2.7441],
        [ 1.6512],
        [12.0932],
        [ 5.0005],
        [ 0.2595],
        [ 5.3784],
        [57.4966],
        [ 7.2043],
        [ 7.1951],
        [ 2.4548],
        [ 0.9262],
        [ 5.1963],
        [ 1.7939],
        [16.0003],
        [ 3.5819],
        [ 0.3073],
        [ 1.1674],
        [ 4.9730],
        [ 0.4494],
        [ 1.3833],
        [ 1.3642],
        [ 0.4509],
        [ 5.6296],
        [53.6309],
   

Running split 14, epoch 1 loss = 1.2056317329406738
Running split 14, epoch 2 loss = 13.704466819763184
Running split 14, epoch 3 loss = 4.203649997711182
Running split 14, epoch 4 loss = 1.4556607007980347
Running split 14, epoch 5 loss = 6.391076564788818
Running split 14, epoch 6 loss = 2.8538620471954346
Running split 14, epoch 7 loss = 0.25365158915519714
Running split 14, epoch 8 loss = 1.4615856409072876
Running split 14, epoch 9 loss = 0.45517006516456604
Running split 14, epoch 10 loss = 105.56914520263672
Running split 14, epoch 11 loss = 11.744561195373535
Running split 14, epoch 12 loss = 4.859742641448975
Running split 14, epoch 13 loss = 1.2347930669784546
Running split 14, epoch 14 loss = 1.0519187450408936
Running split 14, epoch 15 loss = 18.07596206665039
Running split 14, epoch 16 loss = 1.595359444618225
Running split 14, epoch 17 loss = 2.231757164001465
Running split 14, epoch 18 loss = 5.749242305755615
Running split 14, epoch 19 loss = 0.3186962902545929
Running

Running split 15, epoch 11 loss = 2.7537901401519775
Running split 15, epoch 12 loss = 3.45929217338562
Running split 15, epoch 13 loss = 6.015498638153076
Running split 15, epoch 14 loss = 10.243720054626465
Running split 15, epoch 15 loss = 1.2087990045547485
Running split 15, epoch 16 loss = 14.455620765686035
Running split 15, epoch 17 loss = 5.866008281707764
Running split 15, epoch 18 loss = 3.499898672103882
Running split 15, epoch 19 loss = 7.308389186859131
Running split 15, epoch 20 loss = 1.2260794639587402
Running split 15, epoch 21 loss = 11.646876335144043
Running split 15, epoch 22 loss = 2.646555185317993
Running split 15, epoch 23 loss = 4.7158355712890625
Running split 15, epoch 24 loss = 1.9455280303955078
Running split 15, epoch 25 loss = 9.390621185302734
Running split 15, epoch 26 loss = 1.2269366979599
Running split 15, epoch 27 loss = 2.7177793979644775
Running split 15, epoch 28 loss = 5.889291763305664
Running split 15, epoch 29 loss = 4.359743118286133
Runnin

Running split 16, epoch 21 loss = 10.637430191040039
Running split 16, epoch 22 loss = 1.7139195203781128
Running split 16, epoch 23 loss = 3.3568804264068604
Running split 16, epoch 24 loss = 1.355286717414856
Running split 16, epoch 25 loss = 0.6462387442588806
Running split 16, epoch 26 loss = 15.103985786437988
Running split 16, epoch 27 loss = 4.038244724273682
Running split 16, epoch 28 loss = 2.7867014408111572
Running split 16, epoch 29 loss = 2.2197954654693604
Running split 16, epoch 30 loss = 20.565351486206055
Running split 16, epoch 31 loss = 1.9766902923583984
Running split 16, epoch 32 loss = 15.850598335266113
Running split 16, epoch 33 loss = 5.023073196411133
Running split 16, epoch 34 loss = 6.786973476409912
Running split 16, epoch 35 loss = 6.982786178588867
Running split 16, epoch 36 loss = 1.372833251953125
Running split 16, epoch 37 loss = 6.050296783447266
Running split 16, epoch 38 loss = 1.7589668035507202
Running split 16, epoch 39 loss = 5.15516471862793
Ru

Running split 17, epoch 31 loss = 2.567758321762085
Running split 17, epoch 32 loss = 6.029027462005615
Running split 17, epoch 33 loss = 82.29668426513672
Running split 17, epoch 34 loss = 9.760119438171387
Running split 17, epoch 35 loss = 3.7796285152435303
Running split 17, epoch 36 loss = 2.067023992538452
Running split 17, epoch 37 loss = 0.7501933574676514
Running split 17, epoch 38 loss = 1.3521637916564941
Running split 17, epoch 39 loss = 1.2648463249206543
Running split 17, epoch 40 loss = 0.9655984044075012
Running split 17, epoch 41 loss = 3.949467658996582
Running split 17, epoch 42 loss = 3.7716481685638428
Running split 17, epoch 43 loss = 7.599269866943359
Running split 17, epoch 44 loss = 1.2559309005737305
Running split 17, epoch 45 loss = 5.141685962677002
Running split 17, epoch 46 loss = 1.1813526153564453
Running split 17, epoch 47 loss = 6.756762981414795
Running split 17, epoch 48 loss = 14.819095611572266
Running split 17, epoch 49 loss = 1.1494780778884888
Ru

Running split 18, epoch 37 loss = 19.788314819335938
Running split 18, epoch 38 loss = 28.03169059753418
Running split 18, epoch 39 loss = 15.988357543945312
Running split 18, epoch 40 loss = 0.13143180310726166
Running split 18, epoch 41 loss = 6.825424671173096
Running split 18, epoch 42 loss = 2.807886838912964
Running split 18, epoch 43 loss = 2.084041118621826
Running split 18, epoch 44 loss = 1.5017733573913574
Running split 18, epoch 45 loss = 2.0166895389556885
Running split 18, epoch 46 loss = 0.7080807685852051
Running split 18, epoch 47 loss = 4.738694667816162
Running split 18, epoch 48 loss = 2.53007435798645
Running split 18, epoch 49 loss = 8.040801048278809
Running split 18, epoch 50 loss = 19.78497886657715
Running split 18, epoch 51 loss = 0.14101089537143707
Running split 18, epoch 52 loss = 20.519166946411133
Running split 18, epoch 53 loss = 15.788500785827637
Running split 18, epoch 54 loss = 31.996337890625
Running split 18, epoch 55 loss = 1.402527928352356
Runn

Running split 19, epoch 41 loss = 1.9253168106079102
Running split 19, epoch 42 loss = 0.6483075618743896
Running split 19, epoch 43 loss = 0.3362492620944977
Running split 19, epoch 44 loss = 0.041012849658727646
Running split 19, epoch 45 loss = 1.4853264093399048
Running split 19, epoch 46 loss = 1.6826132535934448
Running split 19, epoch 47 loss = 0.7632927894592285
Running split 19, epoch 48 loss = 0.44244420528411865
Running split 19, epoch 49 loss = 3.3417861461639404
Running split 19, epoch 50 loss = 0.7800478935241699
Running split 19, epoch 51 loss = 4.22045373916626
Running split 19, epoch 52 loss = 6.555812835693359
Running split 19, epoch 53 loss = 4.269006252288818
Running split 19, epoch 54 loss = 0.9270977973937988
Running split 19, epoch 55 loss = 0.5952824950218201
Running split 19, epoch 56 loss = 1.2152602672576904
Running split 19, epoch 57 loss = 0.5066118240356445
Running split 19, epoch 58 loss = 0.28940561413764954
Running split 19, epoch 59 loss = 0.7705036997

### Print statistics

In [7]:
# Copied from DropoutUncertaintyExps repo
print('non-MC RMSE %f +- %f (stddev) +- %f (std error), median %f 25p %f 75p %f \n' % (
        np.mean(rmse_non_mc), np.std(rmse_non_mc), np.std(rmse_non_mc)/np.sqrt(n_splits),
        np.percentile(rmse_non_mc, 50), np.percentile(rmse_non_mc, 25), np.percentile(rmse_non_mc, 75)))

print('MC RMSE %f +- %f (stddev) +- %f (std error), median %f 25p %f 75p %f \n' % (
        np.mean(rmse_mc), np.std(rmse_mc), np.std(rmse_mc)/np.sqrt(n_splits),
        np.percentile(rmse_mc, 50), np.percentile(rmse_mc, 25), np.percentile(rmse_mc, 75)))

print('MC Test Log-likelihood %f +- %f (stddev) +- %f (std error), median %f 25p %f 75p %f \n' % (
        np.mean(test_lls_mc), np.std(test_lls_mc), np.std(test_lls_mc)/np.sqrt(n_splits), 
        np.percentile(test_lls_mc, 50), np.percentile(test_lls_mc, 25), np.percentile(test_lls_mc, 75)))

non-MC RMSE 2.815036 +- 1.235371 (stddev) +- 0.276237 (std error), median 2.383141 25p 2.103087 75p 3.022617 

MC RMSE 2.478810 +- 1.387441 (stddev) +- 0.310241 (std error), median 2.180838 25p 1.633043 75p 2.612591 

MC Test Log-likelihood -3.265878 +- 0.048659 (stddev) +- 0.010880 (std error), median -3.250950 25p -3.260931 75p -3.240249 

