### Workflow

1. Split the dataset into training, validation, and test sets by randomly moving some percentage of the provided data into validation and test directories. Split the held-out data evenly between the two, and keep the training set the largest. Make sure that the training, validation, and test images do not overlap.
1. Use the `DataLoader` class to create the loading mechanism for the training and validation data using the `Dataset` class built in Step 2.
1. Build a training loop using MSE as the loss function. Determine an optimizer (pick between SGD or Adam).
1. Instantiate a model and train the network with the created routine.

#### References

* [Training and Validation Loop - finetuning torch vision models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#model-training-and-validation-code)
* [Simple Training Loop (no validation) - pytorch image classification](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network)

In [1]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.optim import Adam, AdamW
from torchvision import transforms


In [10]:
%load_ext autoreload
%autoreload 2

from utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
model = MRConvNet()

In [4]:
# fall back to the CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

MRConvNet(
  (conv1): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bnorm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop1): Dropout3d(p=0.3, inplace=False)
  (conv2): Conv3d(16, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
)
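The printed layers all use a 3x3x3 kernel with stride 1 and padding 1, so each convolution preserves the spatial size of its input. The sketch below rebuilds those layers in a stand-in module to check that property. `StandInConvNet` is a hypothetical name, and the order of operations in `forward` (including the ReLU) is an assumption, since the real `MRConvNet.forward` lives in `utils` and is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandInConvNet(nn.Module):
    """Hypothetical stand-in that mirrors the printed MRConvNet layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.bnorm = nn.BatchNorm3d(16)
        self.drop1 = nn.Dropout3d(p=0.3)
        self.conv2 = nn.Conv3d(16, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # Assumed ordering: conv -> batch norm -> ReLU -> dropout -> conv
        x = F.relu(self.bnorm(self.conv1(x)))
        return self.conv2(self.drop1(x))


net = StandInConvNet().eval()  # eval mode: dropout off, batch norm uses running stats
with torch.no_grad():
    crop_out = net(torch.zeros(2, 1, 16, 16, 16))  # batch of small cubic crops
    full_out = net(torch.zeros(1, 1, 20, 24, 28))  # one larger, non-cubic volume

print(crop_out.shape, full_out.shape)
```

Because no layer depends on a fixed input size, a network of this shape trained on crops can, in principle, be run on a larger volume, memory permitting.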

In [5]:
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-4)

#### Train / Test / Validation Splitting

In [6]:
# How many files do we have?
nda = NiftiDataset('./data/small/t1', './data/small/t2')
corpus_size = len(nda)

seed = 42

# train, test, validation split
corpus_idxs = range(0, corpus_size)
train_idxs, test_idxs   = train_test_split(corpus_idxs, random_state=seed, train_size=0.75)
train_idxs, valid_idxs  = train_test_split(train_idxs,  random_state=seed, train_size=0.8) #split training into training and validation.

index_dict = {
    'train': train_idxs, 
    'test': test_idxs,
    'valid': valid_idxs
}
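The workflow asks that the three sets do not overlap. A minimal sketch of how that could be checked is shown below, using a made-up corpus size of 20 in place of `len(nda)`. With `train_size=0.75` followed by `train_size=0.8`, the split works out to roughly 60% training, 15% validation, and 25% test.

```python
from sklearn.model_selection import train_test_split

corpus_size = 20  # placeholder; the notebook uses len(nda)
seed = 42

corpus_idxs = list(range(corpus_size))
train_idxs, test_idxs = train_test_split(corpus_idxs, random_state=seed, train_size=0.75)
train_idxs, valid_idxs = train_test_split(train_idxs, random_state=seed, train_size=0.8)

train_set, valid_set, test_set = set(train_idxs), set(valid_idxs), set(test_idxs)

# No index may appear in more than one split
no_overlap = (train_set.isdisjoint(valid_set)
              and train_set.isdisjoint(test_set)
              and valid_set.isdisjoint(test_set))

# Every index must land in exactly one split
covers_all = (train_set | valid_set | test_set) == set(corpus_idxs)

print(len(train_set), len(valid_set), len(test_set), no_overlap, covers_all)
```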



#### Create the DataLoaders and Datasets

In [13]:
batch_size = 10
cube_crop_dim = 60
mult = 7  # achieves an effect similar to data augmentation: multiplies each dataset's length by `mult`. Only useful with random crops.

image_datasets   = {k:NiftiSplitDataset('./data/small/t1', './data/small/t2', 
                                        index_dict[k], 
                                        transforms.Compose([RandomCrop3D(cube_crop_dim), 
                                                            ToNumpy(), 
                                                            AddDim(), 
                                                            ToTensor()]),
                                       mult)
                    for k in ['train', 'valid', 'test']}
dataloaders_dict = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) 
                    for x in ['train', 'valid', 'test']}
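How `NiftiSplitDataset` applies `mult` is defined in `utils` and not shown here. One plausible way to get the same effect is a wrapper that reports a length `mult` times larger and maps each index back onto the underlying items, so that a random-crop transform sees every volume several times per epoch. `RepeatDataset` below is a hypothetical name used only for this sketch.

```python
from torch.utils.data import Dataset


class RepeatDataset(Dataset):
    """Hypothetical wrapper: presents `base` as if it were `mult` times longer."""

    def __init__(self, base, mult=1):
        self.base = base
        self.mult = mult

    def __len__(self):
        return len(self.base) * self.mult

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Wrap around so every base item is visited `mult` times per epoch
        return self.base[idx % len(self.base)]


base = ['vol_a', 'vol_b', 'vol_c']  # placeholder items standing in for image pairs
repeated = RepeatDataset(base, mult=7)
print(len(repeated), repeated[3], repeated[20])
```

Repeating items only adds variety when the transform is random. With a deterministic transform, the same sample would simply be seen `mult` times.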


#### Training / Validation loop

A reminder from the Manning PyTorch book about the optimizer, the loss, and gradient accumulation:

> calling `backward()` will lead derivatives to accumulate at leaf nodes. So if backward has been called earlier, the loss is evaluated again and backward is called again (as in any training loop), the gradient at each leaf will be accumulated (i.e. summed) on top of the one computed at the previous iteration,

In the quote above, the graph is the computation graph. Operations can take other operations or tensors as inputs, and tensors are the leaves of the graph.
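The accumulation described in the quote can be seen on a single scalar parameter. This is a minimal sketch with a made-up loss, `(3w)^2`, whose derivative is `18w`. In the training loop, `optimizer.zero_grad()` plays the role that `w.grad.zero_()` plays here.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # a leaf tensor in the graph

loss = (3.0 * w) ** 2  # d(loss)/dw = 18 * w = 36 at w = 2
loss.backward()
first_grad = w.grad.item()

loss = (3.0 * w) ** 2  # evaluate the loss again, as a second iteration would
loss.backward()
accumulated_grad = w.grad.item()  # summed on top of the previous gradient

w.grad.zero_()  # reset before the next backward pass
loss = (3.0 * w) ** 2
loss.backward()
fresh_grad = w.grad.item()

print(first_grad, accumulated_grad, fresh_grad)
```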

In [14]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    val_acc_history = []
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            elif phase == 'valid':
                model.eval()
    
            running_loss = 0.0
            running_corrects = 0
        
            # loop over all batches of samples / pairs
            for samples in dataloaders[phase]:
            
                # how does this work? 
                # default collate_fn magically handles it - https://pytorch.org/docs/stable/data.html#dataloader-collate-fn

                #put the batch on the device
                source = samples['source'].to(device)
                target = samples['target'].to(device)
            
                # reset parameter gradients before each batch (otherwise they accumulate)
                optimizer.zero_grad()
            
                # you only need to track the computation graph if you are going to calculate 
                # the gradient via backprop
                with torch.set_grad_enabled(phase == 'train'):
                    
                    predicted = model(source)
                    loss = criterion(predicted, target) #this is just a node in the computation graph. Returns tensor.
                    
                    if phase == 'train':
                        loss.backward()  # propagates gradients
                        optimizer.step() # uses the learning rate and the gradients to update the weights. 
                        
                # can we calculate any statistics for this batch? 
                ## - show off how many batches we are processing? 
                running_loss += loss.item() * samples['source'].size(0)
                
            # calculate any statistics for the phase + epoch?
            ## - print training loss, print validation loss.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
        
    # calculate / report on anything for the epoch?
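One option for the per-epoch reporting is to record the loss of each phase and keep a copy of the weights from the epoch with the lowest validation loss. The sketch below is self-contained and runs outside `train_model`. `update_best` is a hypothetical helper, the `nn.Linear` model is a stand-in, and the loss values are placeholder numbers rather than results from this notebook.

```python
import copy

import torch.nn as nn


def update_best(model, valid_loss, best):
    """Store a copy of the weights if `valid_loss` is the lowest seen so far."""
    if valid_loss < best['loss']:
        best['loss'] = valid_loss
        # deepcopy so later optimizer steps do not overwrite the saved weights
        best['state'] = copy.deepcopy(model.state_dict())
        return True
    return False


model = nn.Linear(2, 1)  # stand-in model for illustration
history = {'train': [], 'valid': []}
best = {'loss': float('inf'), 'state': None}

# Placeholder (train_loss, valid_loss) pairs, one per epoch
for train_loss, valid_loss in [(0.66, 0.53), (0.60, 0.52), (0.61, 0.51), (0.62, 0.55)]:
    history['train'].append(train_loss)
    history['valid'].append(valid_loss)
    update_best(model, valid_loss, best)

model.load_state_dict(best['state'])  # restore the best weights after training
print(best['loss'], len(history['valid']))
```

Inside `train_model`, the same calls would go right after `epoch_loss` is computed, with `update_best` invoked only in the `'valid'` phase. The `history` lists can then be plotted against the epoch number.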

In [None]:
train_model(model, dataloaders_dict, criterion, optimizer, 50)


Epoch 0/49
----------
train Loss: 0.6566
valid Loss: 0.5308
Epoch 1/49
----------
train Loss: 0.6008
valid Loss: 0.5160
Epoch 2/49
----------
train Loss: 0.6120
valid Loss: 0.5083
Epoch 3/49
----------
train Loss: 0.6247
valid Loss: 0.4606
Epoch 4/49
----------
train Loss: 0.5950
valid Loss: 0.4466
Epoch 5/49
----------
train Loss: 0.6004
valid Loss: 0.4909
Epoch 6/49
----------
train Loss: 0.5676
valid Loss: 0.4731
Epoch 7/49
----------
train Loss: 0.5804
valid Loss: 0.4278
Epoch 8/49
----------
train Loss: 0.5734
valid Loss: 0.4589
Epoch 9/49
----------
train Loss: 0.6135
valid Loss: 0.4474
Epoch 10/49
----------
train Loss: 0.5853
valid Loss: 0.4698
Epoch 11/49
----------
train Loss: 0.5732
valid Loss: 0.4604
Epoch 12/49
----------
train Loss: 0.6029
valid Loss: 0.4640
Epoch 13/49
----------
train Loss: 0.6017
valid Loss: 0.4495
Epoch 14/49
----------
train Loss: 0.5589
valid Loss: 0.4526
Epoch 15/49
----------
train Loss: 0.5876
valid Loss: 0.4547
Epoch 16/49
----------
train Loss:

In [None]:
### Left off here

'''

I'd like to get this model to not jump around so much. It should 

* Get more data
    - get more random crops
    - can you make smaller crops of the data? This would give you a chance to reuse images. 

* Make your model better
    - check the Manning book on PyTorch. What did their model look like? 
    
* Experiment with the learning rate, batch_size and mult? 
    - can you get the training loss to decrease?
    - how would you graph the validation and training loss?
    - do you want to run for a few epochs with a high learning rate then decrease it? 
    
Questions

* How would you create a full size image once the model is trained? It hasn't trained on any full size images. 
* What is a good validation loss? 
* Is "mult" cheating too much? Can you come up with a better strategy? 
* Do you want to save the "best" model: the one with the lowest validation loss that is not wildly
  overfitting (training loss far below validation loss)? 

'''
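For the idea of starting with a high learning rate and lowering it after a few epochs, PyTorch ships schedulers in `torch.optim.lr_scheduler`. The sketch below uses `StepLR` on a stand-in model. The starting rate, `step_size`, and `gamma` are illustrative values, not tuned for this dataset.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 1)  # stand-in model for illustration
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)  # multiply lr by 0.1 every 5 epochs

lrs = []
for epoch in range(10):
    lrs.append(optimizer.param_groups[0]['lr'])  # learning rate used in this epoch

    # one dummy training step standing in for a full pass over the data
    optimizer.zero_grad()
    loss = model(torch.ones(1, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()

    scheduler.step()  # call once per epoch, after the optimizer has stepped

print(lrs)
```

In `train_model`, `scheduler.step()` would be called once at the end of each epoch, after both the `'train'` and `'valid'` phases.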


### Testing

In [17]:

#what's in parameters? 
print("Named Parameters")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

print("")
print("All Parameters")
        
for param in model.parameters():
    if param.requires_grad:
        print(param.shape)

Named Parameters
conv1.weight torch.Size([16, 1, 3, 3, 3])
conv1.bias torch.Size([16])
bnorm.weight torch.Size([16])
bnorm.bias torch.Size([16])
conv2.weight torch.Size([1, 16, 3, 3, 3])
conv2.bias torch.Size([1])

All Parameters
torch.Size([16, 1, 3, 3, 3])
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([1, 16, 3, 3, 3])
torch.Size([1])
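The shapes above can be turned into a parameter count with `numel()`. The sketch below rebuilds the printed layers in an `nn.ModuleDict` as a stand-in for `MRConvNet`, so it runs without `utils`. On the real model, the same comprehension works with `model.named_parameters()`.

```python
import torch.nn as nn

# Stand-in with the same layer shapes as the printed MRConvNet
layers = nn.ModuleDict({
    'conv1': nn.Conv3d(1, 16, kernel_size=3, padding=1),
    'bnorm': nn.BatchNorm3d(16),
    'drop1': nn.Dropout3d(p=0.3),  # dropout has no learnable parameters
    'conv2': nn.Conv3d(16, 1, kernel_size=3, padding=1),
})

# Number of trainable values per parameter tensor
counts = {name: p.numel() for name, p in layers.named_parameters() if p.requires_grad}
total = sum(counts.values())

for name, n in counts.items():
    print(name, n)
print('total', total)
```

Each convolution weight holds 16 x 27 = 432 values, so the whole network has 913 trainable parameters.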
