### Workflow

1. Split the dataset into training, validation, and test sets by randomly moving a percentage of the provided data into validation and test directories. Split the held-out data evenly between the two, and keep the training set the largest. Make sure that the training, validation, and test images do not overlap.
1. Use the `DataLoader` class to create the loading mechanism for the training and validation data using the `Dataset` class built in Step 2.
1. Build a training loop using MSE as the loss function. Choose an optimizer (SGD or Adam).
1. Instantiate a model and train the network with the created routine.

#### References

* [Training and Validation Loop - finetuning torch vision models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#model-training-and-validation-code)
* [Simple Training Loop (no validation) - pytorch image classification](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network)

In [10]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.optim import Adam, AdamW
from torchvision import transforms


In [23]:
%load_ext autoreload
%autoreload 2

from utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
model = MRConvNet()

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU when no GPU is present
model.to(device)

MRConvNet(
  (conv1): Conv3d(1, 16, kernel_size=(60, 60, 60), stride=(1, 1, 1), padding=(1, 1, 1))
  (bnorm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop1): Dropout3d(p=0.3, inplace=False)
  (conv2): Conv3d(16, 1, kernel_size=(60, 60, 60), stride=(1, 1, 1), padding=(1, 1, 1))
)

In [14]:
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-3)

#### Train / Test / Validation Splitting

In [15]:
# How many files do we have?
nda = NiftiDataset('./data/small/t1', './data/small/t2')
corpus_size = len(nda)

seed = 42

# train, test, validation split
corpus_idxs = range(0, corpus_size)
train_idxs, test_idxs   = train_test_split(corpus_idxs, random_state=seed, train_size=0.75)
train_idxs, valid_idxs  = train_test_split(train_idxs,  random_state=seed, train_size=0.8) #split training into training and validation.

index_dict = {
    'train': train_idxs, 
    'test': test_idxs,
    'valid': valid_idxs
}
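
The workflow asks that the three sets do not overlap. Because the second `train_test_split` call only re-splits the training indices, the sets should already be disjoint; with `train_size=0.75` followed by `train_size=0.8`, the proportions come out to roughly 60% / 15% / 25% for train / valid / test. A minimal sketch of an explicit check is below. The helper `assert_disjoint` and the `example_splits` values are hypothetical stand-ins, not part of `utils.py`.

```python
from itertools import combinations


def assert_disjoint(splits):
    """Raise AssertionError if any two splits share an index."""
    for (name_a, idxs_a), (name_b, idxs_b) in combinations(splits.items(), 2):
        shared = set(idxs_a) & set(idxs_b)
        assert not shared, f"{name_a} and {name_b} overlap on {sorted(shared)}"


# Hypothetical stand-in for index_dict (10 files split 6 / 2 / 2).
example_splits = {
    'train': [0, 3, 4, 7, 8, 9],
    'valid': [1, 6],
    'test': [2, 5],
}

assert_disjoint(example_splits)
split_sizes = {name: len(idxs) for name, idxs in example_splits.items()}
```

In the notebook the same call would presumably be `assert_disjoint(index_dict)`.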



#### Create the DataLoaders and Datasets

In [24]:
batch_size = 10
cube_crop_dim = 60

image_datasets   = {k:NiftiSplitDataset('./data/small/t1', './data/small/t2', 
                                        index_dict[k], transforms.Compose([RandomCrop3D(cube_crop_dim), AddDim(), ToTensor()])) 
                    for k in ['train', 'valid', 'test']}
dataloaders_dict = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) 
                    for x in ['train', 'valid', 'test']}
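
A quick way to see what the `DataLoader` hands to the training loop is to pull one batch and inspect it. The sketch below assumes the dataset returns a dict with `'source'` and `'target'` tensors, which is what the training loop further down expects. `ToyPairDataset` is a hypothetical stand-in for `NiftiSplitDataset` so the example runs without the NIfTI files.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairDataset(Dataset):
    """Hypothetical stand-in for NiftiSplitDataset: returns dict samples."""

    def __init__(self, n_items=5, side=4):
        self.n_items = n_items
        self.side = side

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        shape = (1, self.side, self.side, self.side)  # (C, D, H, W)
        return {'source': torch.zeros(shape), 'target': torch.ones(shape)}


loader = DataLoader(ToyPairDataset(), batch_size=2, shuffle=False)

# The default collate function stacks each dict entry along a new batch axis,
# so the batch is itself a dict, not a tensor.
batch = next(iter(loader))
batch_type = type(batch).__name__
source_shape = tuple(batch['source'].shape)  # (N, C, D, H, W)
num_batches = len(loader)                    # the last batch may be smaller
```

Since the batch is a dict, the batch size has to be read from one of its tensors, for example `batch['source'].size(0)`.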


#### Training / Validation loop

A reminder from the Manning PyTorch book about the optimizer, the loss, and the accumulation of gradients:

> calling backward will lead derivatives to accumulate at leaf nodes. So if backward has been called earlier, the loss is evaluated again and backward is called again (as in any training loop), the gradient at each leaf will be accumulated (i.e. summed) on top of the one computed at the previous iteration,

In the quote above, the graph is the computation graph. Operations can take other operations or tensors as inputs, and tensors are the leaves of the graph.
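
The accumulation described in the quote can be seen on a single leaf tensor. This is a minimal sketch with a made-up scalar loss, not the notebook's model. It shows why the training loop calls `optimizer.zero_grad()` for every batch.

```python
import torch

# A leaf tensor: created directly by the user, not by an operation.
w = torch.tensor(2.0, requires_grad=True)

# loss = (3w)^2 = 9w^2, so d(loss)/dw = 18w = 36 at w = 2.
loss = (3.0 * w) ** 2
loss.backward()
first_grad = w.grad.item()

# Evaluate the loss and call backward again without zeroing:
# the new gradient is summed onto the old one.
loss = (3.0 * w) ** 2
loss.backward()
accumulated_grad = w.grad.item()

# Zero the gradient first (optimizer.zero_grad() does this for every
# parameter it manages), then backward gives a fresh value.
w.grad.zero_()
loss = (3.0 * w) ** 2
loss.backward()
fresh_grad = w.grad.item()
```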

In [25]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    val_acc_history = []
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            elif phase == 'valid':
                model.eval()
    
            running_loss = 0.0
            running_corrects = 0
        
            # loop over all batches of samples / pairs
            for samples in dataloaders[phase]:
            
                #put the batch on the device
                source = samples['source'].to(device)
                target = samples['target'].to(device)
            
                # reset parameter gradients after each batch 
                optimizer.zero_grad()
            
                # you only need to track the computation graph if you are going to calculate
                # the gradient via backprop
                with torch.set_grad_enabled(phase == 'train'):
                    
                    predicted = model(source)
                    loss = criterion(predicted, target) #this is just a node in the computation graph. Returns tensor.
                    
                    if phase == 'train':
                        loss.backward()  
                        optimizer.step()
                        
                # can we calculate any statistics for this batch? 
                ## - show off how many batches we are processing? 
                running_loss += loss.item() * source.size(0)  # samples is a dict, so take the batch size from a tensor
                
            # calculate any statistics for the phase + epoch?
            ## - print training loss, print validation loss.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
        
    # calculate / report on anything for the epoch?

In [26]:
train_model(model, dataloaders_dict, criterion, optimizer)

#LEFT off here
# - issue with batch and channel.  Do you want to unsqueeze a tensor somewhere?   (N, C, D, H, W):  N = batch size.
# - issue with arguments to Conv3d?  Conv3d(in_channels, out_channels, kernel_size)
# - see MRConvNet code in utils.py

Epoch 0/24
----------


AttributeError: 'numpy.ndarray' object has no attribute 'get_fdata'
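
The traceback says `get_fdata` was called on an object that was already a NumPy array, so the fix for that presumably lives in the dataset or transform code in `utils.py`. Separately, the two shape questions in the notes above can be checked in isolation. The sketch below uses a small random volume instead of the real crops, and `conv_out_size` is a hypothetical helper that applies the output-size formula from the `nn.Conv3d` documentation to one spatial axis.

```python
import torch
import torch.nn as nn


def conv_out_size(size, kernel, padding=0, stride=1, dilation=1):
    """Output length of one spatial axis after a convolution."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


volume = torch.randn(8, 8, 8)   # (D, H, W), as read from a single image
sample = volume.unsqueeze(0)    # (C=1, D, H, W): add the channel axis
batch = sample.unsqueeze(0)     # (N=1, C=1, D, H, W): add the batch axis

# kernel_size=3 with padding=1 keeps the spatial size unchanged.
conv = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
out_shape = tuple(conv(batch).shape)

# Effect of kernel_size on a 60-voxel crop with padding=1.
same_size = conv_out_size(60, kernel=3, padding=1)
shrunk = conv_out_size(60, kernel=60, padding=1)
```

In the notebook, the `DataLoader` adds the batch axis, so only the channel axis needs to be added per sample. Passing the crop size as `kernel_size` shrinks each axis from 60 to 3, which would leave a second layer with the same kernel size too little input to work with.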

### Testing

In [17]:

#what's in parameters? 
print("Named Parameters")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

print("")
print("All Parameters")
        
for param in model.parameters():
    if param.requires_grad:
        print(param.shape)

Named Parameters
conv1.weight torch.Size([16, 1, 3, 3, 3])
conv1.bias torch.Size([16])
bnorm.weight torch.Size([16])
bnorm.bias torch.Size([16])
conv2.weight torch.Size([1, 16, 3, 3, 3])
conv2.bias torch.Size([1])

All Parameters
torch.Size([16, 1, 3, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([1, 16, 3, 3, 3])
torch.Size([1])
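
The shapes printed above are enough to count the trainable parameters by hand. A minimal sketch follows, with the shapes copied from the output rather than read from the live model.

```python
from math import prod

# Shapes copied from the "Named Parameters" output above.
shapes = {
    'conv1.weight': (16, 1, 3, 3, 3),
    'conv1.bias': (16,),
    'bnorm.weight': (16,),
    'bnorm.bias': (16,),
    'conv2.weight': (1, 16, 3, 3, 3),
    'conv2.bias': (1,),
}

# Number of elements per parameter tensor, then the total.
counts = {name: prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())
```

On the live model the same total should come from `sum(p.numel() for p in model.parameters() if p.requires_grad)`.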
