## Experiment Setup

In [1]:
# Experiment configuration: every later cell reads these two names.
#   dataset_name - 'MNIST' selects the MNIST wrapper; any other value is
#                  passed to the UCI dataset wrapper as the dataset to load.
#   n_epochs     - number of full passes over the training loader.
dataset_name, n_epochs = 'yacht', 10

## Prepare data

### Get the data as a torch Dataset object

In [2]:
from ronald_bdl import datasets

# Build the torch Dataset for the configured experiment: MNIST has its own
# wrapper, every other name is handed to the UCI wrapper. Both cache their
# files under ./datasets_files and download them when missing.
dataset = (
    datasets.MNIST(root_dir='./datasets_files', transform=None, download=True)
    if dataset_name == 'MNIST'
    else datasets.UCIDatasets(
        dataset_name,
        root_dir='./datasets_files',
        transform=None,
        download=True,
    )
)

Using downloaded and verified file: ./datasets_files/yacht/yacht_hydrodynamics.data


### Split into training/test sets with `random_split` and wrap each in a torch DataLoader

In [3]:
import torch
from torch.utils.data import random_split, DataLoader

# Seed torch's global RNG before splitting: random_split draws its permutation
# from it, so without a fixed seed the train/test membership (and every metric
# computed afterwards) changes on each Restart & Run All.
torch.manual_seed(0)

# 80/20 split; the test set takes the remainder so the two sizes always add
# up to len(dataset) even when 0.8 * len(dataset) is not an integer.
training_size = int(0.8 * len(dataset))
test_size = len(dataset) - training_size

training, test = random_split(dataset, lengths=[training_size, test_size])

# DataLoader defaults apply here (batch_size=1, shuffle=False), i.e. one
# sample per optimisation step, visited in the same order every epoch.
training_loader = DataLoader(training)
test_loader = DataLoader(test)

## Define network

In [4]:
from ronald_bdl import models

# Fully connected MC-dropout regressor from ronald_bdl.
# NOTE(review): the meaning of each keyword below is inferred from its name;
# confirm against ronald_bdl.models.FCNetMCDropout.
fcnet_mc_dropout = models.FCNetMCDropout(
    # One input per column of dataset.data except one
    # (presumably the regression target column - TODO confirm).
    input_dim=dataset.data.shape[1] - 1,
    output_dim=1,
    # Presumably the number of hidden layers and the width of each.
    n_hidden=2,
    hidden_dim=50,
    # Presumably the number of stochastic forward passes per prediction.
    n_predictions=10000,
)

## Train the network

### Setup optimizer

In [5]:
from torch import nn, optim

# Mean-squared-error objective for the regression target.
objective = nn.MSELoss()

# Plain SGD with momentum over every trainable parameter of the network.
optimizer = optim.SGD(
    params=fcnet_mc_dropout.parameters(),
    lr=0.001,
    momentum=0.9,
)

### Run the optimizer

In [8]:
# Adapted from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

# Number of mini-batches between progress reports. This used to be derived
# from n_epochs / 2, which tied the reporting cadence (and the averaging
# divisor) to the epoch count: n_epochs = 1 never printed anything and odd
# values produced a fractional modulus. 5 reproduces the old cadence for
# n_epochs = 10.
log_interval = 5

for epoch in range(n_epochs):  # loop over the dataset multiple times

    # Loss accumulated since the last report; restarted at each epoch, so a
    # trailing partial window at the end of an epoch is not reported.
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(training_loader):
        # zero the parameter gradients left over from the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = fcnet_mc_dropout(inputs)

        # NOTE(review): MSELoss broadcasts when outputs and labels differ in
        # shape; confirm both have the same shape for this dataset.
        loss = objective(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the last `log_interval` batches
        running_loss += loss.item()
        if (i + 1) % log_interval == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0

[1,     5] loss: 928.423
[1,    10] loss: 835.654
[1,    15] loss: 5.045
[1,    20] loss: 581.796
[1,    25] loss: 40.247
[1,    30] loss: 16.628
[1,    35] loss: 280.452
[1,    40] loss: 222.594
[1,    45] loss: 865.988
[1,    50] loss: 20.032
[1,    55] loss: 54.809
[1,    60] loss: 10.904
[1,    65] loss: 662.392
[1,    70] loss: 1113.044
[1,    75] loss: 438.268
[1,    80] loss: 9.713
[1,    85] loss: 105.772
[1,    90] loss: 1300.023
[1,    95] loss: 4.424
[1,   100] loss: 33.783
[1,   105] loss: 243.279
[1,   110] loss: 33.074
[1,   115] loss: 89.325
[1,   120] loss: 811.764
[1,   125] loss: 214.969
[1,   130] loss: 8.555
[1,   135] loss: 525.846
[1,   140] loss: 1044.437
[1,   145] loss: 107.690
[1,   150] loss: 13.464
[1,   155] loss: 28.599
[1,   160] loss: 298.412
[1,   165] loss: 26.042
[1,   170] loss: 42.500
[1,   175] loss: 176.868
[1,   180] loss: 117.912
[1,   185] loss: 33.341
[1,   190] loss: 5.670
[1,   195] loss: 3.517
[1,   200] loss: 73.948
[1,   205] loss: 1102.3

[8,    35] loss: 280.452
[8,    40] loss: 222.594
[8,    45] loss: 865.988
[8,    50] loss: 20.032
[8,    55] loss: 54.809
[8,    60] loss: 10.904
[8,    65] loss: 662.392
[8,    70] loss: 1113.044
[8,    75] loss: 438.268
[8,    80] loss: 9.713
[8,    85] loss: 105.772
[8,    90] loss: 1300.023
[8,    95] loss: 4.424
[8,   100] loss: 33.783
[8,   105] loss: 243.279
[8,   110] loss: 33.074
[8,   115] loss: 89.325
[8,   120] loss: 811.764
[8,   125] loss: 214.969
[8,   130] loss: 8.555
[8,   135] loss: 525.846
[8,   140] loss: 1044.437
[8,   145] loss: 107.690
[8,   150] loss: 13.464
[8,   155] loss: 28.599
[8,   160] loss: 298.412
[8,   165] loss: 26.042
[8,   170] loss: 42.500
[8,   175] loss: 176.868
[8,   180] loss: 117.912
[8,   185] loss: 33.341
[8,   190] loss: 5.670
[8,   195] loss: 3.517
[8,   200] loss: 73.948
[8,   205] loss: 1102.326
[8,   210] loss: 290.296
[8,   215] loss: 1107.549
[8,   220] loss: 325.206
[8,   225] loss: 608.951
[8,   230] loss: 494.088
[8,   235] loss: 