### Workflow

1. Split the dataset into training, validation, and test sets by randomly assigning a percentage of the provided data to a validation directory and a test directory. Split this held-out data evenly between the two, but keep the training set the largest. Make sure that the training, validation, and test images do not overlap.
1. Use the `DataLoader` class to create the loading mechanism for the training and validation data using the `Dataset` class built in Step 2.
1. Build a training loop that uses MSE as the loss function, and choose an optimizer (either SGD or Adam).
1. Instantiate a model and train the network with the created routine.
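Steps 3 and 4 can be sketched as a loop that alternates a training pass and a validation pass, following a pattern similar to the finetuning tutorial referenced below. This is a minimal sketch under stated assumptions: `run_epoch` and `fit` are hypothetical helper names, the random tensors stand in for the T1/T2 NIfTI data (whose loading is not shown here), and Adam with `lr=0.05` and a 1x1 convolution are arbitrary choices that happen to fit the toy data.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over `loader`: trains when an optimizer is given, otherwise
    only evaluates. Returns the mean loss over the whole loader."""
    training = optimizer is not None
    model.train(training)  # toggles layers such as dropout / batch norm
    total_loss, n_samples = 0.0, 0
    # Track gradients only during the training pass.
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss = criterion(model(inputs), targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # MSELoss averages within the batch, so re-weight by batch size.
            total_loss += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
    return total_loss / n_samples


def fit(model, dataloaders, criterion, optimizer, device, num_epochs):
    """Alternate a training and a validation pass; record the loss of each."""
    history = {"train": [], "valid": []}
    for _ in range(num_epochs):
        history["train"].append(
            run_epoch(model, dataloaders["train"], criterion, device, optimizer))
        history["valid"].append(
            run_epoch(model, dataloaders["valid"], criterion, device))
    return history


# Toy stand-in data (an assumption, not the NIfTI volumes): targets are a
# fixed affine map of the inputs, which a 1x1 convolution can learn exactly.
torch.manual_seed(0)
toy_x = torch.randn(40, 1, 8, 8)
toy_y = 2.0 * toy_x + 1.0
toy_loaders = {
    "train": DataLoader(TensorDataset(toy_x[:30], toy_y[:30]), batch_size=10, shuffle=True),
    "valid": DataLoader(TensorDataset(toy_x[30:], toy_y[30:]), batch_size=10),
}
toy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toy_model = nn.Conv2d(1, 1, kernel_size=1).to(toy_device)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.05)
toy_history = fit(toy_model, toy_loaders, nn.MSELoss(), toy_optimizer, toy_device, num_epochs=20)
print(toy_history["valid"][0], toy_history["valid"][-1])
```

Passing no optimizer is what turns `run_epoch` into an evaluation pass, so the same function serves both phases and the validation data never updates the weights.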

#### References

* [Training and validation loop: Finetuning Torchvision Models tutorial](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#model-training-and-validation-code)
* [Simple training loop (no validation): PyTorch CIFAR-10 image classification tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network)

```python
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
```

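```python
# Reload edited local modules automatically (IPython magics).
%load_ext autoreload
%autoreload 2

# Dataset classes from Step 2 (NiftiDataset, NiftiSplitDataset).
from utils import *
```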

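```python
# Fall back to the CPU when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```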

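```python
# MSE loss, as the workflow specifies; the optimizer (SGD or Adam)
# is created once a model has been instantiated.
criterion = nn.MSELoss()
optimizer = None
```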
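The workflow leaves the choice between SGD and Adam open. A small helper makes that choice a single argument; the sketch below is a hypothetical `make_optimizer`, and its learning rates and momentum value are starting assumptions to tune, not values from this notebook.

```python
import torch
import torch.nn as nn


def make_optimizer(model, name="adam", lr=1e-3, momentum=0.9):
    """Return an SGD or Adam optimizer over all of `model`'s parameters."""
    name = name.lower()
    if name == "sgd":
        # Momentum is commonly added to SGD; pass momentum=0.0 for plain SGD.
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    if name == "adam":
        # 1e-3 is PyTorch's default Adam learning rate; `momentum` is unused here.
        return torch.optim.Adam(model.parameters(), lr=lr)
    raise ValueError(f"unknown optimizer {name!r}; expected 'sgd' or 'adam'")


# Hypothetical placeholder network; swap in the real model once it is built.
toy_model = nn.Linear(4, 4)
toy_criterion = nn.MSELoss()
toy_optimizer = make_optimizer(toy_model, "adam")
```

With the real network, the optimizer would be built the same way right after the model is instantiated in step 4, since the optimizer needs the model's parameters.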

#### Train / Test / Validation Splitting

```python
# How many samples do we have?
nda = NiftiDataset('./data/small/t1', './data/small/t2')
corpus_size = len(nda)

# Train / validation / test split: hold out 30% of the indices, then divide
# the held-out part evenly between validation and test (training stays largest).
corpus_idxs = list(range(corpus_size))
train_idxs, holdout_idxs = train_test_split(corpus_idxs, train_size=0.7)
valid_idxs, test_idxs = train_test_split(holdout_idxs, train_size=0.5)

index_dict = {
    'train': train_idxs,
    'valid': valid_idxs,
    'test': test_idxs,
}
```
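The workflow requires that the three splits do not overlap. A quick sanity check after splitting can confirm that, and also that every index landed in some split. This is a minimal sketch; `check_split` is a hypothetical helper written in plain Python, and it assumes the splits are stored as in `index_dict` (split name mapped to a list of corpus indices).

```python
def check_split(index_dict, corpus_size):
    """Raise ValueError unless the splits are disjoint and together cover
    indices 0..corpus_size-1; return each split's share of the corpus."""
    seen = {}
    for name, idxs in index_dict.items():
        for i in idxs:
            if i in seen:
                raise ValueError(f"index {i} is assigned more than once ({seen[i]!r}, {name!r})")
            seen[i] = name
    expected, assigned = set(range(corpus_size)), set(seen)
    if assigned != expected:
        raise ValueError(f"missing indices {sorted(expected - assigned)}, "
                         f"unexpected indices {sorted(assigned - expected)}")
    return {name: len(idxs) / corpus_size for name, idxs in index_dict.items()}
```

In the notebook, `check_split(index_dict, corpus_size)` would raise on any overlap or gap and otherwise return the fraction of the corpus in each split, which also makes it easy to confirm that the training set is the largest.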



#### Create the DataLoaders and Datasets

```python
batch_size = 10

image_datasets = {k: NiftiSplitDataset('./data/small/t1', './data/small/t2', index_dict[k])
                  for k in ['train', 'valid', 'test']}
# Shuffle only the training data; evaluation order does not affect the loss.
dataloaders_dict = {k: DataLoader(image_datasets[k], batch_size=batch_size, shuffle=(k == 'train'))
                    for k in ['train', 'valid', 'test']}
```
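How a split-specific dataset and its `DataLoader` behave can be checked on toy data. `NiftiSplitDataset` itself is not shown in this section, so the sketch below uses a hypothetical `ToyPairDataset` that only mimics the idea of serving the corpus positions listed for one split; the random tensors stand in for T1/T2 image pairs.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyPairDataset(Dataset):
    """Hypothetical stand-in for a split dataset: serves (source, target)
    pairs for only the corpus indices it is given."""

    def __init__(self, corpus, idxs):
        self.corpus = corpus      # full list of (source, target) tensor pairs
        self.idxs = list(idxs)    # corpus positions belonging to this split

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, i):
        # Map the split-local index i to a position in the full corpus.
        return self.corpus[self.idxs[i]]


torch.manual_seed(0)
toy_corpus = [(torch.randn(1, 16, 16), torch.randn(1, 16, 16)) for _ in range(40)]
toy_index_dict = {"train": range(0, 25), "valid": range(25, 32), "test": range(32, 40)}

toy_loaders = {
    k: DataLoader(ToyPairDataset(toy_corpus, idxs), batch_size=10, shuffle=(k == "train"))
    for k, idxs in toy_index_dict.items()
}

batch_sizes = [src.shape[0] for src, _ in toy_loaders["train"]]
print(batch_sizes)  # 25 samples -> [10, 10, 5]
src, tgt = next(iter(toy_loaders["valid"]))
print(src.shape)    # torch.Size([7, 1, 16, 16])
```

By default the last batch is smaller when the split size is not a multiple of `batch_size`; passing `drop_last=True` to `DataLoader` discards that partial batch instead.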


In [8]:

#training / validation loop