## Understanding the Training Function

Go through the training function and comment each line. <br>

Where are the loss functions called?

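The train function only ever calls the model once per batch (`model(video_clips)`, which runs the `forward` method shown below), and the model itself computes no losses. The loss functions are called directly inside `train` (and identically inside `test`):
- `deformation_motion_loss` on the whole batch of clips.
- `motion_seg_loss` once per clip.
- `F.binary_cross_entropy_with_logits` on the ED/ES segmentations.

`categorical_dice` is then used as the Dice metric.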

## The Model

```python

    def forward(self, x):
        """Return segmentation logits and motion predictions for a batch of clips.

        The R(2+1)D backbone yields five feature maps (stem + layer1..layer4).
        Each map is trilinearly upsampled back to the input's (d, h, w),
        concatenated along dim 1 (channels, per the channel counts noted below),
        fused by two comb_* stages (layer -> batch norm -> ReLU), and fed to the
        segmentation and motion heads.

        NOTE(review): the hard-coded ``scale_factor`` values only restore the
        input size when the clip is (d, h, w) = (32, 112, 112); presumably x is
        (batch, channels, 32, 112, 112) -- confirm against the data loader.

        Returns:
            tuple: ``(segmentation_output, motion_output)``. ``motion_output``
            goes through ``tanh`` so its values lie in (-1, 1).
        """
        # Spatio-temporal size of x assumed to be (d, h, w) = (32, 112, 112); the
        # shape comments below give (d, h, w) of each feature map for that input.

        # Encoder: run the pretrained R(2+1)D backbone stage by stage, keeping every stage's output

        # Features output from stem channels == 64, shape = (32, 56, 56)
        output_stem = self.r2plus1d_model.stem(x)
        # Features output from block 1 channels == 64, shape = (32, 56, 56)
        output_layer_1 = self.r2plus1d_model.layer1(output_stem)
        # Features output from block 2 channels == 128, shape = (16, 28, 28)
        output_layer_2 = self.r2plus1d_model.layer2(output_layer_1)
        # Features output from block 3 channels == 256, shape = (8, 14, 14)
        output_layer_3 = self.r2plus1d_model.layer3(output_layer_2)
        # Features output from block 4 channels == 512, shape = (4, 7, 7)
        output_layer_4 = self.r2plus1d_model.layer4(output_layer_3)

        # Decoder: upsample every feature map, fuse them, then apply the task heads

        # Upsample all 5 feature maps to the input's (d, h, w) = (32, 112, 112)
        # Stem (32, 56, 56) -> (32, 112, 112)
        up_stem = F.interpolate(output_stem, scale_factor=[1, 2, 2], mode='trilinear', align_corners=True)
        # block 1 (32, 56, 56) -> (32, 112, 112)
        up_layer_1 = F.interpolate(output_layer_1, scale_factor=[1, 2, 2], mode='trilinear', align_corners=True)
        # block 2 (16, 28, 28) -> (32, 112, 112)
        up_layer_2 = F.interpolate(output_layer_2, scale_factor=[2, 4, 4], mode='trilinear', align_corners=True)
        # block 3 (8, 14, 14) -> (32, 112, 112)
        up_layer_3 = F.interpolate(output_layer_3, scale_factor=[4, 8, 8], mode='trilinear', align_corners=True)
        # block 4 (4, 7, 7) -> (32, 112, 112)
        up_layer_4 = F.interpolate(output_layer_4, scale_factor=[8, 16, 16], mode='trilinear', align_corners=True)

        # Concatenate the upsampled maps along dim 1: 64 + 64 + 128 + 256 + 512 = 1024 channels
        cat_output = torch.cat([up_stem, up_layer_1, up_layer_2, up_layer_3, up_layer_4], 1)

        # Fusion stage 1: 1024 -> 64 channels
        x = self.comb_1_layer(cat_output)
        x = self.comb_batch_norm_1(x)
        x = self.comb_relu_1(x)

        # Fusion stage 2: 64 -> 64 channels
        x = self.comb_2_layer(x)
        x = self.comb_batch_norm_2(x)
        x = self.comb_relu_2(x)

        # Segmentation output: 64 -> 2 [Background, LV] (raw logits, no activation applied here)
        segmentation_output = self.segmentation_head(x)

        # Motion output: 64 -> 4 [Forward x, y, backward x, y]
        motion_output = self.motion_head(x)
        # Squash the motion predictions into (-1, 1)
        motion_output = torch.tanh(motion_output)

        return segmentation_output, motion_output
```

### The training cycle

In [None]:
import os

# Legacy CUDA tensor type. train() and test() read this global to cast their
# input clips, so this notebook requires a CUDA device.
Tensor = torch.cuda.FloatTensor

model_save_path = "save_models/R2plus1DMotionSegNet_model.pth"

train_loss_list = []  # mean training loss of each epoch
valid_loss_list = []  # mean validation loss of each epoch

n_epoch = 10
# Lowest mean validation loss seen so far. Starting at +inf guarantees the first
# epoch produces a checkpoint no matter how large its loss is.
min_loss = float("inf")
for epoch in range(1, n_epoch + 1):
    print("-" * 32 + 'Epoch {}'.format(epoch) + "-" * 32)
    start = time.time()

    # One optimisation pass over the training set; returns the per-batch losses.
    train_loss = train(epoch, train_loader=train_dataloader, model=model, optimizer=optimizer)

    train_loss_list.append(np.mean(train_loss))
    end = time.time()
    print("training took {:.8f} seconds".format(end-start))

    # One evaluation pass over the validation set (no weight updates).
    valid_loss = test(epoch, test_loader=valid_dataloader, model=model, optimizer=optimizer)

    # Mean over all validation batches, computed once and reused below.
    mean_valid_loss = np.mean(valid_loss)
    valid_loss_list.append(mean_valid_loss)

    # Checkpoint model + optimizer only when the mean validation loss improves.
    if mean_valid_loss < min_loss:
        min_loss = mean_valid_loss

        # torch.save does not create missing directories.
        save_dir = os.path.dirname(model_save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, model_save_path)

    # After a few epochs, lower the learning rate so later steps are smaller and
    # do not overshoot. The existing optimizer is updated in place instead of
    # constructing a new Adam. A fresh optimizer would discard Adam's running
    # moment estimates, as well as any non-default hyper-parameters the original
    # optimizer was created with.
    if epoch == 3:
        lr_T = 1e-5
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_T

## Train function

In [None]:
def train(epoch, train_loader, model, optimizer):
    """Run one training epoch over ``train_loader`` and return the per-batch losses.

    For every batch, the optimised loss is the sum of three terms:
      * ``deformation_motion_loss`` on the whole batch of clips.
      * The batch means of the two terms returned by ``motion_seg_loss``,
        evaluated clip by clip at that clip's ED/ES frame indices.
      * The mean of the ED and ES segmentation BCE losses against the labels.

    ED/ES LV Dice scores (``categorical_dice`` with class 1) are accumulated for
    the periodic progress printout.

    Args:
        epoch: current epoch number; only used in the progress messages.
        train_loader: iterable of batches ``(clips, metadata)``, where
            ``metadata`` unpacks into the ten fields named below.
        model: network whose call returns ``(segmentation_output, motion_output)``.
        optimizer: optimizer that is zeroed and stepped once per batch.

    Returns:
        list of float: total loss of every batch, in iteration order.
    """
    model.train()    # put the model in training mode
    epoch_loss = []  # total loss of every batch
    ed_lv_dice = 0   # running sum of per-batch ED LV Dice
    es_lv_dice = 0   # running sum of per-batch ES LV Dice

    np.random.seed()  # no argument: re-seed NumPy's global RNG from fresh entropy
    for batch_idx, batch in enumerate(train_loader, 1):
        # Cast the clips to the global `Tensor` type (set to torch.cuda.FloatTensor
        # in the training-cycle cell).
        video_clips = torch.Tensor(batch[0]).type(Tensor)
        filename, EF, es_clip_index, ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label, ed_label = batch[1]
        batch_size = video_clips.shape[0]

        optimizer.zero_grad()  # clear gradients left over from the previous step

        # Calling the module runs its forward pass; both outputs are torch tensors.
        segmentation_output, motion_output = model(video_clips)

        # Deformation loss over the whole batch of clips, driven by the predicted motion.
        deform_loss = deformation_motion_loss(video_clips, motion_output)

        # Label arrays with a singleton axis inserted at position 1, built once per
        # batch rather than once per clip. Clip i is selected with the slice
        # i:i + 1, which also keeps a leading axis of size 1 (as motion_seg_loss
        # received before).
        ed_labels = np.expand_dims(ed_label.numpy(), 1).astype("int")
        es_labels = np.expand_dims(es_label.numpy(), 1).astype("int")

        # motion_seg_loss handles one clip at a time; sum its two terms over the batch.
        segmentation_loss = 0
        motion_loss = 0
        for i in range(batch_size):
            segmentation_one_loss, motion_one_loss = motion_seg_loss(ed_labels[i:i + 1], es_labels[i:i + 1],
                                                                     ed_clip_index[i], es_clip_index[i],
                                                                     motion_output[i].unsqueeze(0),
                                                                     segmentation_output[i].unsqueeze(0),
                                                                     0, video_clips.shape[2],
                                                                     F.binary_cross_entropy_with_logits)
            segmentation_loss += segmentation_one_loss
            motion_loss += motion_one_loss

        # Average the summed per-clip terms over the number of clips in the batch.
        mean_segmentation_loss = segmentation_loss / batch_size
        mean_motion_loss = motion_loss / batch_size

        # Segmentation logits at each clip's labelled ED / ES frame, stacked into one
        # tensor with a leading per-clip axis. A single torch.stack avoids
        # re-allocating the result with torch.cat on every iteration.
        n_labelled = len(ed_clip_index)
        ed_segmentations = torch.stack([segmentation_output[i, :, ed_clip_index[i]] for i in range(n_labelled)])
        es_segmentations = torch.stack([segmentation_output[i, :, es_clip_index[i]] for i in range(n_labelled)])

        # Supervised BCE on the labelled ED and ES frames (targets one-hot encoded
        # with 2 classes), averaged over the two frames.
        ed_es_seg_loss = (F.binary_cross_entropy_with_logits(ed_segmentations,
                                                             convert_to_1hot(np.expand_dims(ed_label.numpy().astype("int"), 1), 2),
                                                             reduction="mean")
                          + F.binary_cross_entropy_with_logits(es_segmentations,
                                                               convert_to_1hot(np.expand_dims(es_label.numpy().astype("int"), 1), 2),
                                                               reduction="mean")) / 2

        # Total objective for this batch.
        loss = deform_loss + mean_segmentation_loss + mean_motion_loss + ed_es_seg_loss

        loss.backward()   # back-propagate the total loss
        optimizer.step()  # update the weights using the computed gradients

        epoch_loss.append(loss.item())  # store the batch loss as a Python float

        # Hard label maps (argmax over dim 1) as NumPy arrays for the Dice metric.
        # detach() keeps the metric out of the autograd graph; cpu() moves the
        # data to host memory so .numpy() can be used.
        ed_segmentation_argmax = torch.argmax(ed_segmentations, 1).detach().cpu().numpy()
        es_segmentation_argmax = torch.argmax(es_segmentations, 1).detach().cpu().numpy()

        # Dice overlap of the predicted ED/ES LV (class 1) with the labels.
        ed_lv_dice += categorical_dice(ed_segmentation_argmax, ed_label.numpy(), 1)
        es_lv_dice += categorical_dice(es_segmentation_argmax, es_label.numpy(), 1)

        # Periodic progress printout.
        if batch_idx % 280 == 0:
            print('\nTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(video_clips), len(train_loader) * len(video_clips),
                100. * batch_idx / len(train_loader), np.mean(epoch_loss)))

            print("ED LV: {:.3f}".format(ed_lv_dice / batch_idx))
            print("ES LV: {:.3f}".format(es_lv_dice / batch_idx))

            print("On a particular batch:")
            print("Deform loss: ", deform_loss)
            print("Segmentation loss: ", ed_es_seg_loss)
            print("Seg Motion loss: ", mean_segmentation_loss, mean_motion_loss)

    return epoch_loss

In [None]:
def test(epoch, test_loader, model, optimizer):
    """Evaluate ``model`` on ``test_loader`` and return the per-batch losses.

    Computes the same per-batch loss as ``train``: the deformation loss, plus
    the batch-mean ``motion_seg_loss`` terms, plus the averaged ED/ES
    segmentation BCE. No gradients are computed and no weights are updated.
    Prints the mean ED/ES LV Dice and the mean loss at the end.

    Args:
        epoch: current epoch number; accepted for symmetry with ``train`` (unused).
        test_loader: iterable of batches ``(clips, metadata)``, where
            ``metadata`` unpacks into the ten fields named below.
        model: network whose call returns ``(segmentation_output, motion_output)``.
        optimizer: unused; accepted so the call signature mirrors ``train``.

    Returns:
        list of float: total loss of every batch, in iteration order.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    epoch_loss = []  # total loss of every batch
    ed_lv_dice = 0   # running sum of per-batch ED LV Dice
    es_lv_dice = 0   # running sum of per-batch ES LV Dice
    n_batches = 0

    # Evaluation never needs gradients. Disabling autograd for the whole loop
    # (forward pass and losses included) means no graph is built, so intermediate
    # activations are not kept alive for a backward pass that never happens.
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader, 1):
            n_batches = batch_idx
            filename, EF, es_clip_index, ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label, ed_label = batch[1]
            # Cast the clips to the global `Tensor` type (set in the training-cycle cell).
            video_clips = torch.Tensor(batch[0]).type(Tensor)
            batch_size = video_clips.shape[0]

            # Forward pass.
            segmentation_output, motion_output = model(video_clips)

            # Deformation loss over the whole batch of clips, driven by the predicted motion.
            deform_loss = deformation_motion_loss(video_clips, motion_output)

            # Label arrays with a singleton axis at position 1, built once per batch.
            # Clip i is selected with the slice i:i + 1, keeping a leading axis of size 1.
            ed_labels = np.expand_dims(ed_label.numpy(), 1).astype("int")
            es_labels = np.expand_dims(es_label.numpy(), 1).astype("int")

            # Sum the per-clip segmentation and motion terms over the batch.
            segmentation_loss = 0
            motion_loss = 0
            for i in range(batch_size):
                segmentation_one_loss, motion_one_loss = motion_seg_loss(ed_labels[i:i + 1], es_labels[i:i + 1],
                                                                         ed_clip_index[i], es_clip_index[i],
                                                                         motion_output[i].unsqueeze(0),
                                                                         segmentation_output[i].unsqueeze(0),
                                                                         0, video_clips.shape[2],
                                                                         F.binary_cross_entropy_with_logits)
                segmentation_loss += segmentation_one_loss
                motion_loss += motion_one_loss

            # Segmentation logits at each clip's labelled ED / ES frame, one entry per clip.
            n_labelled = len(ed_clip_index)
            ed_segmentations = torch.stack([segmentation_output[i, :, ed_clip_index[i]] for i in range(n_labelled)])
            es_segmentations = torch.stack([segmentation_output[i, :, es_clip_index[i]] for i in range(n_labelled)])

            # BCE between the predicted ED/ES logits and the one-hot (2-class) ground
            # truth, averaged over the two frames.
            ed_es_seg_loss = (F.binary_cross_entropy_with_logits(ed_segmentations,
                                                                 convert_to_1hot(np.expand_dims(ed_label.numpy().astype("int"), 1), 2),
                                                                 reduction="mean")
                              + F.binary_cross_entropy_with_logits(es_segmentations,
                                                                   convert_to_1hot(np.expand_dims(es_label.numpy().astype("int"), 1), 2),
                                                                   reduction="mean")) / 2

            # Total loss for this batch: deformation + batch-mean seg/motion terms + ED/ES BCE.
            loss = deform_loss + segmentation_loss / batch_size + motion_loss / batch_size + ed_es_seg_loss
            epoch_loss.append(loss.item())

            # Dice overlap of the predicted ED/ES LV (class 1) with the ground truth.
            ed_segmentation_argmax = torch.argmax(ed_segmentations, 1).cpu().numpy()
            es_segmentation_argmax = torch.argmax(es_segmentations, 1).cpu().numpy()

            ed_lv_dice += categorical_dice(ed_segmentation_argmax, ed_label.numpy(), 1)
            es_lv_dice += categorical_dice(es_segmentation_argmax, es_label.numpy(), 1)

    if n_batches == 0:
        raise ValueError("test_loader yielded no batches")

    # Validation summary.
    print("-" * 30 + "Validation" + "-" * 30)
    print("\nED LV: {:.3f}".format(ed_lv_dice / n_batches))
    print("ES LV: {:.3f}".format(es_lv_dice / n_batches))

    print('\nValid set: Average loss: {:.4f}\n'.format(np.mean(epoch_loss)))

    return epoch_loss