### Workflow

1. Split the dataset into training, validation, and test sets by randomly moving a percentage of the provided data into validation and test directories. Divide the held-out data evenly between validation and test, and keep the training set the largest. Make sure that the training, validation, and test images do not overlap.
1. Use the `DataLoader` class to create the loading mechanism for the training and validation data using the `Dataset` class built in Step 2.
1. Build a training loop that uses MSE as the loss function, and choose an optimizer (SGD or Adam).
1. Instantiate a model and train the network with the created routine.

#### References

* [Training and Validation Loop - Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#model-training-and-validation-code)
* [Simple Training Loop (no validation) - pytorch image classification](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network)

In [1]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.optim import Adam, AdamW
from torchvision import transforms


In [2]:
%load_ext autoreload
%autoreload 2

from utils import *

In [3]:
model = MRConvNet()

In [4]:
# fall back to the CPU when no CUDA device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

MRConvNet(
  (conv1): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bnorm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop1): Dropout3d(p=0.3, inplace=False)
  (conv2): Conv3d(16, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
)
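Every convolution in `MRConvNet` uses kernel size 3, stride 1, and padding 1, which leaves each spatial dimension unchanged, and `conv2` maps the 16 feature channels back to one. `BatchNorm3d` and `Dropout3d` do not change shape either. Assuming `forward` applies the layers in the printed order (the method lives in `utils` and is not shown), a single-channel 60³ crop should come out as a single-channel 60³ volume, which is what the MSE loss needs to compare it voxel by voxel with the target. The helper below (`conv_output_size`, a hypothetical name) is a small sketch that evaluates PyTorch's documented Conv3d output-size formula along one axis.

```python
def conv_output_size(size, kernel_size=3, stride=1, padding=1, dilation=1):
    """Output length along one spatial axis of a PyTorch Conv3d layer."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# A 60-voxel edge passes through conv1 and conv2 (kernel 3, stride 1, padding 1).
crop = 60
after_conv1 = conv_output_size(crop)
after_conv2 = conv_output_size(after_conv1)
print(after_conv1, after_conv2)
```

With stride 2 the same formula would halve the edge (`conv_output_size(60, stride=2)` gives 30), so any future downsampling layer would need a matching upsampling step before the loss.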

In [5]:
criterion = nn.MSELoss()
optimizer = Adam(model.parameters())
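`nn.MSELoss()` with its default `reduction='mean'` averages the squared voxel-wise difference over every element of the batch, and `Adam(model.parameters())` uses PyTorch's default learning rate of 1e-3. The plain-Python sketch below computes the same mean squared error on flat lists, purely to illustrate the quantity being minimized; `mse` is a hypothetical helper, not a PyTorch function.

```python
def mse(prediction, target):
    """Mean of squared differences over all elements (like reduction='mean')."""
    if len(prediction) != len(target):
        raise ValueError('prediction and target must have the same number of elements')
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


# Squared errors are 0, 0.25 and 4, so the mean is 4.25 / 3.
print(mse([1.0, 2.0, 3.0], [1.0, 2.5, 5.0]))  # ≈ 1.4167
```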

#### Train / Test / Validation Splitting

In [6]:
# How many files do we have?
nda = NiftiDataset('./data/small/t1', './data/small/t2')
corpus_size = len(nda)

# train / validation / test split: keep 60% for training, then divide the
# remaining 40% evenly so validation and test each get 20%
corpus_idxs = range(0, corpus_size)
train_idxs, holdout_idxs = train_test_split(corpus_idxs, train_size=0.6)
valid_idxs, test_idxs = train_test_split(holdout_idxs, test_size=0.5)

index_dict = {
    'train': train_idxs, 
    'test': test_idxs,
    'valid': valid_idxs
}
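Workflow step 1 asks for disjoint splits in which validation and test are the same size and training is the largest. A stdlib-only sketch of that rule, plus a check that the index lists are disjoint and together cover the whole corpus, might look like this; `three_way_split` and `check_disjoint` are hypothetical helpers, not part of `utils`.

```python
import random


def three_way_split(n, train_frac=0.6, seed=None):
    """Split indices 0..n-1 into disjoint train / valid / test lists.

    The held-out remainder is divided evenly between validation and test,
    so training is roughly the largest split whenever train_frac > 1/3.
    """
    idxs = list(range(n))
    random.Random(seed).shuffle(idxs)
    n_train = int(round(train_frac * n))
    holdout = idxs[n_train:]
    half = len(holdout) // 2
    return {'train': idxs[:n_train], 'valid': holdout[:half], 'test': holdout[half:]}


def check_disjoint(index_dict, n):
    """Raise ValueError if two splits share an index or an index is missing."""
    seen = set()
    for name, idxs in index_dict.items():
        overlap = seen.intersection(idxs)
        if overlap:
            raise ValueError(f'{name} overlaps earlier splits at {sorted(overlap)}')
        seen.update(idxs)
    if seen != set(range(n)):
        raise ValueError('splits do not cover the whole corpus')


splits = three_way_split(10, seed=0)
check_disjoint(splits, 10)
print({k: len(v) for k, v in splits.items()})  # {'train': 6, 'valid': 2, 'test': 2}
```

Because `check_disjoint` only needs a dict of index lists, calling `check_disjoint(index_dict, corpus_size)` after the split cell would be a cheap sanity check on the real indices.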



#### Create the DataLoaders and Datasets

In [7]:
batch_size = 10
cube_crop_dim = 60

image_datasets   = {k:NiftiSplitDataset('./data/small/t1', './data/small/t2', 
                                        index_dict[k], transforms.Compose([RandomCrop3D(cube_crop_dim), ToTensor()])) 
                    for k in ['train', 'valid', 'test']}
# shuffle only the training data; evaluation order does not affect the loss
dataloaders_dict = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train')) 
                    for x in ['train', 'valid', 'test']}
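`RandomCrop3D` comes from `utils` and is not shown here. For paired T1/T2 volumes, the same crop window has to be applied to the source and the target; otherwise the MSE compares different anatomy. A hypothetical stand-in (`random_crop_slices`, an illustrative name) that draws one window per pair could look like this, assuming volumes that support NumPy-style slicing.

```python
import random


def random_crop_slices(shape, crop, rng=None):
    """Pick one random cube of edge `crop` inside a 3D volume of `shape`.

    Returns a tuple of slice objects; applying the same tuple to both the
    source and the target volume keeps the pair voxel-aligned.
    """
    rng = rng if rng is not None else random.Random()
    if any(dim < crop for dim in shape):
        raise ValueError(f'volume {shape} is smaller than crop size {crop}')
    # randint is inclusive, so start + crop never exceeds the axis length
    starts = [rng.randint(0, dim - crop) for dim in shape]
    return tuple(slice(s, s + crop) for s in starts)


# Hypothetical volume shape; e.g. source[window] and target[window] on arrays.
volume_shape = (100, 120, 90)
window = random_crop_slices(volume_shape, 60, random.Random(0))
print([(s.start, s.stop) for s in window])
```

Because the crop is random, validation and test losses computed through `RandomCrop3D` will vary from epoch to epoch; a fixed crop for those splits would make them directly comparable.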


#### Training / Validation loop

In [8]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    # MSE is a regression loss, so track validation loss rather than accuracy
    val_loss_history = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            elif phase == 'valid':
                model.eval()

            running_loss = 0.0

            # get a batch of (source, target) pairs
            for samples in dataloaders[phase]:

                # put the batch on the device
                source = samples['source'].to(device)
                target = samples['target'].to(device)

                # clear gradients left over from the previous batch
                optimizer.zero_grad()
            
                # LEFT OFF HERE
            

In [9]:
train_model(model, dataloaders_dict, criterion, optimizer)

Epoch 0/24
----------
Epoch 1/24
----------
Epoch 2/24
----------
Epoch 3/24
----------
Epoch 4/24
----------
Epoch 5/24
----------
Epoch 6/24
----------
Epoch 7/24
----------
Epoch 8/24
----------
Epoch 9/24
----------
Epoch 10/24
----------
Epoch 11/24
----------
Epoch 12/24
----------
Epoch 13/24
----------
Epoch 14/24
----------
Epoch 15/24
----------
Epoch 16/24
----------
Epoch 17/24
----------
Epoch 18/24
----------
Epoch 19/24
----------
Epoch 20/24
----------
Epoch 21/24
----------
Epoch 22/24
----------
Epoch 23/24
----------
Epoch 24/24
----------
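The `train_model` cell stops at the `# LEFT OFF HERE` marker right after `optimizer.zero_grad()`, so the run above only prints epoch banners and never updates the weights. Below is a minimal sketch of one way to finish the loop, modeled loosely on the finetuning tutorial's training and validation code listed in the references. Assumptions: `device` is passed in as an argument instead of read from the notebook's global, the best weights are chosen by lowest validation loss (accuracy does not apply to MSE regression), and the toy `nn.Conv3d` model and synthetic `make_batches` data are hypothetical stand-ins for `MRConvNet` and the NIfTI loaders.

```python
import copy

import torch
import torch.nn as nn
from torch.optim import Adam


def train_model(model, dataloaders, criterion, optimizer, device, num_epochs=25):
    """Train on dataloaders['train'] and evaluate on dataloaders['valid'] every epoch.

    Returns the model loaded with the lowest-validation-loss weights and the
    per-epoch validation-loss history.
    """
    best_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    val_loss_history = []

    for epoch in range(num_epochs):
        for phase in ['train', 'valid']:
            # train mode enables dropout and batch statistics; eval mode disables them
            model.train(phase == 'train')
            running_loss = 0.0
            n_samples = 0

            for samples in dataloaders[phase]:
                source = samples['source'].to(device)
                target = samples['target'].to(device)
                optimizer.zero_grad()

                # build the autograd graph only during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(source)
                    loss = criterion(output, target)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # MSELoss returns a mean, so weight by batch size for a per-sample epoch average
                running_loss += loss.item() * source.size(0)
                n_samples += source.size(0)

            epoch_loss = running_loss / n_samples
            print(f'Epoch {epoch}/{num_epochs - 1} {phase} loss: {epoch_loss:.4f}')

            if phase == 'valid':
                val_loss_history.append(epoch_loss)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_weights = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_weights)
    return model, val_loss_history


# Toy usage on synthetic data: a single Conv3d learning target = 2 * source.
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
toy_model = nn.Conv3d(1, 1, kernel_size=3, padding=1).to(device)


def make_batches(n_batches, batch_size=2, side=8):
    """Lists of dict batches stand in for DataLoaders yielding 'source'/'target'."""
    batches = []
    for _ in range(n_batches):
        src = torch.randn(batch_size, 1, side, side, side)
        batches.append({'source': src, 'target': 2.0 * src})
    return batches


toy_loaders = {'train': make_batches(4), 'valid': make_batches(2)}
toy_model, history = train_model(toy_model, toy_loaders, nn.MSELoss(),
                                 Adam(toy_model.parameters(), lr=1e-2), device, num_epochs=5)
```

Assuming `NiftiSplitDataset` yields `source` and `target` tensors of matching shape, the same function could be called on the notebook's objects as `train_model(model, dataloaders_dict, criterion, optimizer, device)`.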


### Testing