In [None]:
def train(epoch, train_loader, model, optimizer, seed=None, log_interval=280):
    """Train the network for one pass over ``train_loader``.

    For every batch the total loss that is back-propagated is the sum of:

    * ``deformation_motion_loss`` computed from the input clips and the
      model's motion output;
    * the segmentation and motion terms returned by ``motion_seg_loss`` for
      each sample, each averaged over the samples in the batch;
    * the mean of two ``binary_cross_entropy_with_logits`` terms, one for the
      segmentation output at each sample's ED clip index and one at its ES
      clip index, against the one-hot encoded ED/ES labels.

    A running Dice score (label 1) for the ED and ES frames is accumulated and
    printed, but not returned.

    Args:
        epoch: Epoch number; only used in the progress print-out.
        train_loader: Iterable of batches. ``batch[0]`` holds the video clips
            and ``batch[1]`` is a 10-tuple ``(filename, EF, es_clip_index,
            ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label,
            ed_label)``. Must support ``len()`` for the progress print-out.
        model: Callable returning ``(segmentation_output, motion_output)`` for
            a batch of clips; switched to training mode here.
        optimizer: Optimizer whose gradients are reset and stepped once per
            batch.
        seed: Value forwarded to ``np.random.seed``. ``None`` (the default)
            keeps the original behaviour of reseeding from fresh OS entropy.
            Passing the same integer on every call restarts the same NumPy
            random stream each epoch, so pass a per-epoch value (for example
            ``base_seed + epoch``) if epochs should differ.
        log_interval: Print running statistics every ``log_interval`` batches.
            A falsy value (``0``/``None``) disables the print-out.

    Returns:
        list[float]: The total loss of every batch, in iteration order.
    """
    # Training mode: lets layers such as dropout / batch-norm behave as in
    # training and their parameters be updated by the optimizer.
    model.train()

    # Total loss of every batch processed in this epoch.
    epoch_loss = []

    # Running sums of the per-batch Dice scores for the ED and ES frames.
    # They are divided by the number of batches seen so far when printed.
    ed_lv_dice = 0
    es_lv_dice = 0

    # Reseed NumPy's global RNG. With seed=None this draws fresh entropy (not
    # reproducible); pass an integer for a reproducible run.
    np.random.seed(seed)

    # NOTE(review): the original author expects batches of 4 clips shaped
    # (4, 3, 32, 112, 112) = (batch, channels, frames, height, width). That is
    # not verifiable from this function -- confirm against the dataset code.
    for batch_idx, batch in enumerate(train_loader, 1):
        # Convert the raw clips to the module-level ``Tensor`` type (defined
        # elsewhere in the file) before feeding them to the model.
        video_clips = torch.Tensor(batch[0])
        video_clips = video_clips.type(Tensor)
        batch_size = video_clips.shape[0]

        # Per-sample metadata and labels accompanying the clips. Only the clip
        # indices and the ED/ES labels are used below.
        (filename, EF, es_clip_index, ed_clip_index, es_index, ed_index,
         es_frame, ed_frame, es_label, ed_label) = batch[1]

        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()

        # Forward pass: one segmentation output and one motion output.
        segmentation_output, motion_output = model(video_clips)

        # NOTE(review): the original author asks how deformation_motion_loss
        # differs from motion_seg_loss. From this function only the inputs are
        # visible: the former sees the clips and the motion output, the latter
        # additionally sees the ED/ES labels, clip indices and segmentation.
        loss = 0
        deform_loss = deformation_motion_loss(video_clips, motion_output)
        loss += deform_loss

        # Sums of the per-sample terms; divided by the batch size below.
        segmentation_loss = 0
        motion_loss = 0

        # Convert the whole label batch once (a channel axis is inserted at
        # position 1) instead of repeating the conversion for every sample.
        ed_labels_int = np.expand_dims(ed_label.numpy(), 1).astype("int")
        es_labels_int = np.expand_dims(es_label.numpy(), 1).astype("int")

        for i in range(batch_size):
            # Keep a leading batch axis of size 1 so every helper call sees a
            # single-sample batch.
            label_ed = np.expand_dims(ed_labels_int[i], 0)
            label_es = np.expand_dims(es_labels_int[i], 0)

            motion_one_output = motion_output[i].unsqueeze(0)
            segmentation_one_output = segmentation_output[i].unsqueeze(0)

            ed_one_index = ed_clip_index[i]
            es_one_index = es_clip_index[i]

            segmentation_one_loss, motion_one_loss = motion_seg_loss(
                label_ed, label_es,
                ed_one_index, es_one_index,
                motion_one_output, segmentation_one_output,
                0, video_clips.shape[2],
                F.binary_cross_entropy_with_logits)
            segmentation_loss += segmentation_one_loss
            motion_loss += motion_one_loss
        loss += (segmentation_loss / batch_size)
        loss += (motion_loss / batch_size)

        # Gather, for every sample, the segmentation output at its ED and ES
        # clip index (third axis of the output) into one tensor each.
        ed_segmentations = torch.Tensor([]).type(Tensor)
        es_segmentations = torch.Tensor([]).type(Tensor)
        for i in range(len(ed_clip_index)):
            ed_one_index = ed_clip_index[i]
            es_one_index = es_clip_index[i]

            ed_seg = segmentation_output[i, :, ed_one_index].unsqueeze(0)
            ed_segmentations = torch.cat([ed_segmentations, ed_seg])

            es_seg = segmentation_output[i, :, es_one_index].unsqueeze(0)
            es_segmentations = torch.cat([es_segmentations, es_seg])

        # Direct supervision of the ED/ES frames: BCE-with-logits against the
        # one-hot labels (second argument 2 of convert_to_1hot is presumably
        # the number of classes -- confirm), averaged over the two frames.
        ed_es_seg_loss = 0
        ed_es_seg_loss += F.binary_cross_entropy_with_logits(
            ed_segmentations,
            convert_to_1hot(np.expand_dims(ed_label.numpy().astype("int"), 1), 2),
            reduction="mean")

        ed_es_seg_loss += F.binary_cross_entropy_with_logits(
            es_segmentations,
            convert_to_1hot(np.expand_dims(es_label.numpy().astype("int"), 1), 2),
            reduction="mean")
        ed_es_seg_loss /= 2

        loss += ed_es_seg_loss

        loss.backward()

        optimizer.step()

        # Store a plain float so the autograd graph of this batch is released.
        epoch_loss.append(loss.item())

        # Predicted class per pixel = arg-max over axis 1 of the gathered
        # segmentation outputs.
        ed_segmentation_argmax = torch.argmax(ed_segmentations, 1).cpu().detach().numpy()
        es_segmentation_argmax = torch.argmax(es_segmentations, 1).cpu().detach().numpy()

        ed_lv_dice += categorical_dice(ed_segmentation_argmax, ed_label.numpy(), 1)
        es_lv_dice += categorical_dice(es_segmentation_argmax, es_label.numpy(), 1)

        # Periodically print running statistics for the epoch so far.
        if log_interval and batch_idx % log_interval == 0:
            print('\nTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(video_clips), len(train_loader) * len(video_clips),
                100. * batch_idx / len(train_loader), np.mean(epoch_loss)))

            print("ED LV: {:.3f}".format(ed_lv_dice / batch_idx))
            print("ES LV: {:.3f}".format(es_lv_dice / batch_idx))

            # The values below belong to the current batch only.
            print("On a particular batch:")
            print("Deform loss: ", deform_loss)
            print("Segmentation loss: ", ed_es_seg_loss)
            print("Seg Motion loss: ", segmentation_loss / batch_size, motion_loss / batch_size)

    return epoch_loss