### Workflow

1. Split the dataset into training, validation, and test sets by randomly moving a percentage of the provided data into separate validation and test directories. Divide the held-out data evenly between validation and test, but keep the training set the largest. Make sure the training, validation, and test images do not overlap.
1. Use the `DataLoader` class to build loaders for the training and validation data from the `Dataset` class built in Step 2.
1. Build a training loop that uses MSE as the loss function, and choose an optimizer (SGD or Adam).
1. Instantiate a model and train the network with the created routine.
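Step 1 can be sketched in plain Python as a disjoint random split of a list of file names. This is a minimal sketch, not the notebook's own code: the function name `split_files`, the `holdout_frac` parameter, and the 30% default are illustrative assumptions.

```python
import random
from typing import Dict, List, Sequence


def split_files(files: Sequence[str], holdout_frac: float = 0.3,
                seed: int = 0) -> Dict[str, List[str]]:
    """Randomly split `files` into disjoint train/valid/test lists (hypothetical helper).

    `holdout_frac` of the files is held out and divided evenly between
    validation and test (sizes differ by at most one); the rest is training.
    """
    if not 0 < holdout_frac < 0.5:
        # holding out less than half the data guarantees training is the largest set
        raise ValueError("holdout_frac must be in (0, 0.5)")
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)      # reproducible shuffle
    n_holdout = int(len(shuffled) * holdout_frac)
    n_valid = n_holdout // 2
    valid = shuffled[:n_valid]
    test = shuffled[n_valid:n_holdout]
    train = shuffled[n_holdout:]
    return {"train": train, "valid": valid, "test": test}
```

Because every file appears in exactly one slice of the shuffled list, the three sets cannot overlap. The returned lists could then drive the actual file moves into `train/`, `valid/`, and `test/` directories (for example with `shutil.move`); those directory names are hypothetical.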

#### References

* [Training and validation loop: Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#model-training-and-validation-code)
* [Simple training loop (no validation): Training a Classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network)

```python
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.optim import Adam, AdamW
```


```python
%load_ext autoreload
%autoreload 2

# utils provides MRConvNet, NiftiDataset and NiftiSplitDataset used below
from utils import *
```


```python
model = MRConvNet()
```

```python
# Fall back to the CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
```

```text
MRConvNet(
  (conv1): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bnorm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop1): Dropout3d(p=0.3, inplace=False)
  (conv2): Conv3d(16, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
)
```
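Both `Conv3d` layers in the summary use `kernel_size=3`, `stride=1`, and `padding=1`, so by the standard output-size formula floor((in + 2·padding − kernel_size) / stride) + 1 = in, each layer keeps depth, height, and width unchanged; only the channel count goes 1 → 16 → 1. The network therefore maps a single-channel volume to a single-channel volume of the same size, which is what a voxel-wise MSE between source and target needs. Below is a quick check with standalone layers configured as printed. It is a sketch: it does not call `MRConvNet.forward` (whose layer ordering is not shown here), and the 8×8×8 volume is an arbitrary stand-in.

```python
import torch
import torch.nn as nn

# Standalone layers configured like the printed MRConvNet summary
conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1)
conv2 = nn.Conv3d(16, 1, kernel_size=3, stride=1, padding=1)

# Dummy batch laid out as (N, C, D, H, W); the sizes are arbitrary
x = torch.randn(2, 1, 8, 8, 8)
h = conv1(x)   # channels 1 -> 16, spatial size unchanged
y = conv2(h)   # channels 16 -> 1, spatial size unchanged
print(h.shape, y.shape)
```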

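```python
criterion = nn.MSELoss()                        # instantiate the loss module, not just the class
optimizer = Adam(model.parameters(), lr=1e-3)   # 1e-3 is Adam's default learning rate
```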
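With its default `reduction='mean'`, `nn.MSELoss` averages the squared error over every element of the batch, that is, over every voxel of every volume, so the loss value does not grow with batch size or volume size. A small sketch that checks this against a hand-written mean, using random tensors as stand-ins for `model(source)` and the target volumes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pred = torch.randn(2, 1, 4, 4, 4)    # stand-in for model(source)
target = torch.randn(2, 1, 4, 4, 4)  # stand-in for the target volumes

criterion = nn.MSELoss()             # default reduction='mean'
loss = criterion(pred, target)

# Same quantity by hand: mean of squared differences over all voxels
manual = ((pred - target) ** 2).mean()
print(loss.item(), manual.item())
```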

#### Train / Test / Validation Splitting

```python
# How many files do we have?
nda = NiftiDataset('./data/small/t1', './data/small/t2')
corpus_size = len(nda)

# Train/test split, then carve a validation set out of the training indices
corpus_idxs = list(range(corpus_size))
train_idxs, test_idxs = train_test_split(corpus_idxs, train_size=0.75)
train_idxs, valid_idxs = train_test_split(train_idxs, train_size=0.8)

index_dict = {
    'train': train_idxs,
    'test': test_idxs,
    'valid': valid_idxs,
}
```
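Because the validation set is carved out of the training share, the two `train_test_split` calls give approximately these fractions of the corpus (sklearn rounds to whole samples):

$$
\begin{aligned}
\text{train} &= 0.75 \times 0.8 = 0.60\\
\text{valid} &= 0.75 \times 0.2 = 0.15\\
\text{test}  &= 0.25
\end{aligned}
$$

The training set is the largest, but validation and test are not split evenly as the workflow describes. One way to get an even split (roughly 70/15/15) would be `train_size=0.7` in the first call, followed by a second call with `train_size=0.5` on the held-out indices, assigning one half to validation and the other to test.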



#### Create the DataLoaders and Datasets

```python
batch_size = 10

image_datasets = {k: NiftiSplitDataset('./data/small/t1', './data/small/t2', index_dict[k])
                  for k in ['train', 'valid', 'test']}
# Shuffle only the training data; keep validation and test order fixed
dataloaders_dict = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train'))
                    for x in ['train', 'valid', 'test']}
```
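`torch.utils.data.Subset` is an alternative to a dedicated split class such as `NiftiSplitDataset`: it wraps one dataset together with a list of indices, so a single `NiftiDataset` could serve all three splits. A minimal sketch, with a random `TensorDataset` standing in for the NIfTI data (it yields `(source, target)` tuples rather than the `'source'`/`'target'` dicts used in this notebook) and hypothetical index lists in place of `index_dict`:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy stand-in for NiftiDataset: 20 (source, target) pairs of 1x4x4x4 volumes
full = TensorDataset(torch.randn(20, 1, 4, 4, 4), torch.randn(20, 1, 4, 4, 4))

# Hypothetical index lists playing the role of index_dict
split_idxs = {"train": list(range(0, 12)),
              "valid": list(range(12, 15)),
              "test": list(range(15, 20))}

# Subset only stores the indices; the underlying data is not copied
split_sets = {k: Subset(full, idxs) for k, idxs in split_idxs.items()}
loaders = {k: DataLoader(ds, batch_size=4, shuffle=(k == "train"))
           for k, ds in split_sets.items()}

source, target = next(iter(loaders["valid"]))
print(len(split_sets["train"]), source.shape)
```

The validation subset holds three samples, so its single batch is smaller than `batch_size`; `DataLoader` keeps that last partial batch unless `drop_last=True` is passed.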


#### Training / Validation loop

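```python
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    """Train for num_epochs, evaluating on the validation set after each epoch.

    Returns the model and the per-epoch validation loss history.
    Uses the global `device` defined above.
    """
    val_loss_history = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()   # enable dropout, update batch-norm statistics
            else:
                model.eval()    # disable dropout, use running batch-norm statistics

            running_loss = 0.0

            # get a batch of source/target pairs
            for samples in dataloaders[phase]:
                # put the batch on the device
                source = samples['source'].to(device)
                target = samples['target'].to(device)

                # reset parameter gradients before each batch
                optimizer.zero_grad()

                # track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(source)
                    loss = criterion(output, target)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # the loss is a mean over the batch, so weight it by the batch size
                running_loss += loss.item() * source.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            if phase == 'valid':
                val_loss_history.append(epoch_loss)

    return model, val_loss_history
```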

```python
# LEFT OFF HERE: DataLoader's default collate function cannot batch the nibabel
# images the Dataset currently returns; it needs tensors or NumPy arrays.
# Option 1: write a ToTensor transform, as shown in
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# Option 2: return NumPy arrays from the Dataset instead.
train_model(model, dataloaders_dict, criterion, optimizer)
```

```text
Epoch 0/24
----------

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'nibabel.nifti1.Nifti1Image'>
```
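The `TypeError` comes from `DataLoader`'s default collate function, which can only stack tensors, NumPy arrays, numbers, dicts, or lists, and here receives `Nifti1Image` objects. Following the ToTensor option in the note above, this is a minimal sketch of such a transform. The class name, the `'source'`/`'target'` dict layout (taken from the training loop), and the added channel axis are assumptions rather than the notebook's `utils` code, and `FakeImage` is a hypothetical stub so the sketch runs without NIfTI files. The cast to `float32` matters because nibabel's `get_fdata()` returns `float64` by default, while the model's weights are `float32`.

```python
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate


class ToTensor:
    """Hypothetical transform: convert each image in a sample dict to a tensor.

    Assumes every value exposes `get_fdata()` (as nibabel images do) and holds a
    3-D volume; a leading channel axis is added so Conv3d(1, ...) accepts it.
    """

    def __call__(self, sample):
        return {key: torch.from_numpy(np.asarray(img.get_fdata(), dtype=np.float32)).unsqueeze(0)
                for key, img in sample.items()}


class FakeImage:
    """Hypothetical stand-in for a nibabel image."""

    def __init__(self, shape):
        self._data = np.random.rand(*shape)   # float64, like get_fdata()

    def get_fdata(self):
        return self._data


to_tensor = ToTensor()
samples = [to_tensor({"source": FakeImage((4, 4, 4)), "target": FakeImage((4, 4, 4))})
           for _ in range(2)]
batch = default_collate(samples)   # dict of stacked tensors, as DataLoader would build
print(batch["source"].shape, batch["source"].dtype)
```

How the transform is applied inside `NiftiSplitDataset.__getitem__` is not shown here; returning the converted dict from `__getitem__` is enough for `DataLoader` to batch it.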

### Testing