### Plan: step through the `train` function, focusing on a single iteration over `train_loader`

In [1]:
import os
# NOTE(review): hardcoded, machine-specific absolute path. The relative paths used later in this
# notebook ("fold_indexes/...", "tmp_save_models/...") resolve against this directory, and
# presumably the `src` / `echonet` imports below rely on it as well -- consider making it configurable.
os.chdir("/home/wang/workspace/JupyterNoteBooksAll/fully-automated-multi-heartbeat-echocardiography-video-segmentation-and-motion-tracking")

# IPython magic (notebook-only syntax, not plain Python): disable jedi-based tab completion.
%config Completer.use_jedi = False

import echonet
from echonet.datasets import Echo

import torch.nn.functional as F
from torchvision.models.video import r2plus1d_18
from torch.utils.data import Dataset, DataLoader, Subset
from multiprocessing import cpu_count

from src.utils.torch_utils import TransformDataset, torch_collate
from src.transform_utils import generate_2dmotion_field
from src.loss_functions import huber_loss, convert_to_1hot, convert_to_1hot_tensor
from src.model.R2plus1D_18_MotionNet import R2plus1D_18_MotionNet
from src.echonet_dataset import EchoNetDynamicDataset
from src.clasfv_losses import deformation_motion_loss, motion_seg_loss, DiceLoss, categorical_dice
from src.train_test import train, test

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import random
import pickle
import time

# Stopwatch aliases: both names refer to time.time.
tic, toc = (time.time, time.time)

# Pre-sampled index lists selecting the train / validation subsets.
# NOTE(review): pickle.load can execute arbitrary code -- only load these files from a trusted source.
# The `with` statement closes each file on exit, so no explicit close() is needed.
with open("fold_indexes/stanford_train_sampled_indices", "rb") as infile:
    train_mask = pickle.load(infile)

with open("fold_indexes/stanford_valid_sampled_indices", "rb") as infile:
    valid_mask = pickle.load(infile)

batch_size = 4
# At least 4 DataLoader workers; half the CPU count on machines with more than 8 CPUs.
num_workers = max(4, cpu_count() // 2)

def worker_init_fn_valid(worker_id):
    """Reseed numpy's global RNG in a validation DataLoader worker.

    The new seed is the first word of the worker's current MT19937 key (as reported by
    ``np.random.get_state()``) plus ``worker_id``, so each worker gets a distinct seed.
    """
    mt_key = np.random.get_state()[1]
    np.random.seed(mt_key[0] + worker_id)
    

def worker_init_fn(worker_id):
    """Seed numpy and ``random`` in a DataLoader worker from torch's per-worker seed.

    torch's initial seed is reduced modulo 2**32 because ``np.random.seed`` only accepts
    32-bit seeds. ``worker_id`` is unused; torch already derives a distinct seed per worker.
    """
    # See here: https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    # and the original post of the problem: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817373837
    seed_32bit = torch.initial_seed() % (2 ** 32)
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(seed_32bit)
    

def permuter(list1, list2):
    """Lazily yield every ``(a, b)`` pair with ``a`` from ``list1`` and ``b`` from ``list2``.

    Pairs are produced in row-major order: all of ``list2`` for the first element of
    ``list1``, then for the next, and so on. ``list2`` is re-iterated once per element of
    ``list1``, so it should be a re-iterable container rather than a one-shot iterator.
    """
    yield from ((first, second) for first in list1 for second in list2)
            

# Keyword-argument bundles for DataLoader construction, grouped by split in `paramLoader`.
# NOTE(review): the two DataLoaders built at the bottom of this cell do not use these dicts.
param_trainLoader = {'collate_fn': torch_collate,
                     'batch_size': batch_size,
                     'num_workers': num_workers,
                     'worker_init_fn': worker_init_fn}

param_testLoader = {'collate_fn': torch_collate,
                    'batch_size': batch_size,
                    'shuffle': False,
                    'num_workers': num_workers,
                    # NOTE(review): valid_dataloader below seeds its workers with worker_init_fn_valid
                    # instead -- confirm which initializer evaluation loaders are meant to use.
                    'worker_init_fn': worker_init_fn}

paramLoader = {'train': param_trainLoader,
               'valid': param_testLoader,
               'test':  param_testLoader}

# Train / val splits, presumably restricted to the index lists in train_mask / valid_mask.
train_dataset = EchoNetDynamicDataset(split='train', subset_indices=train_mask, period=1)
valid_dataset = EchoNetDynamicDataset(split='val', subset_indices=valid_mask, period=1)

# `pin_memory` is a boolean flag (a non-empty string such as "cuda" is merely truthy).
train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True, pin_memory=True,
                              worker_init_fn=worker_init_fn,
                              drop_last=True)  # drop the final incomplete batch

valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=False, pin_memory=True,
                              worker_init_fn=worker_init_fn_valid)

100%|██████████| 16/16 [00:01<00:00, 13.98it/s]
100%|██████████| 16/16 [00:01<00:00, 12.46it/s]


In [2]:
# Joint segmentation + motion-tracking network, wrapped in DataParallel and placed on CUDA.
# nn.Module.to() moves the parameters in place and returns the same module object.
model = torch.nn.DataParallel(R2plus1D_18_MotionNet())
model = model.to("cuda")

print(f'R2+1D MotionNet has {sum(weight.numel() for weight in model.parameters() if weight.requires_grad)} parameters.')

# Initial learning rate for Adam.
lr_T = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=lr_T)

R2+1D MotionNet has 31575731 parameters.


In [3]:
def generate_2dmotion_field_PLAY(x, offset):
    """Verbose debugging copy of a motion-field generator: build a sampling grid for F.grid_sample.

    An identity grid spanning [-1, 1] is built from the spatial size of ``x`` (dims 2 and 3),
    and the two channels of ``offset`` (split along dim 1) are added to it. Every intermediate
    shape is printed.

    Args:
        x: tensor whose dims 2 and 3 give the grid height and width, e.g. (b, c, h, w).
        offset: displacement tensor with 2 channels on dim 1, presumably (b, 2, h, w) -- TODO confirm.
            The grid is created on ``offset.device``, so CPU and any CUDA device work.

    Returns:
        (N, h, w, 2) tensor of sampling locations, with N the leading size left after the
        ``view(-1, h, w)`` below (N == b for an offset of shape (b, 2, h, w)).
    """
    # Qin's code for joint_motion_seg learning works fine on our purpose too
    # Same idea https://discuss.pytorch.org/t/warp-video-frame-from-optical-flow/6013/5
    
    x_shape = x.size()
    print(f'x_shape: {x_shape}')
    
    # torch.meshgrid defaults to 'ij' indexing: grid_w varies along dim 0 (height) and grid_h along
    # dim 1 (width) -- the names are swapped relative to the axis each one actually indexes.
    grid_w, grid_h = torch.meshgrid([torch.linspace(-1, 1, x_shape[2]), torch.linspace(-1, 1, x_shape[3])])  # (h, w)
    print(f'grid_w.shape (meshgrid): {grid_w.shape}')
    print(f'grid_h.shape (meshgrid): {grid_h.shape}')
    
    # Put the grids on the same device as `offset` (instead of a hard-coded .cuda()) so that the
    # additions below never mix devices, and keep them in float32.
    grid_w = grid_w.to(offset.device).float()
    grid_h = grid_h.to(offset.device).float()
    
    print(f'grid_w.shape .to(offset.device).float(): {grid_w.shape}')
    print(f'grid_h.shape .to(offset.device).float(): {grid_h.shape}')

    # Frozen (requires_grad=False) Parameter wrapping: the values are unchanged and no gradient
    # flows into the grids.
    grid_w = nn.Parameter(grid_w, requires_grad=False)
    grid_h = nn.Parameter(grid_h, requires_grad=False)
    print(f'grid_w.shape (nn.Param): {grid_w.shape}')
    print(f'grid_h.shape (nn.Param): {grid_h.shape}')


    # Channel 0 -> offset_h, channel 1 -> offset_w.
    offset_h, offset_w = torch.split(offset, 1, 1)
    
    print(f'offset_h.shape (split): {offset_h.shape}')
    print(f'offset_w.shape (split): {offset_w.shape}')
    
    offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
    offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
    
    print(f'offset_h.shape (contiguous): {offset_h.shape}')
    print(f'offset_w.shape (contiguous): {offset_w.shape}')
    
    offset_w = grid_w + offset_w
    offset_h = grid_h + offset_h
    
    print(f'offset_w (grid_w + offset_w): {offset_w.shape}')
    print(f'offset_h (grid_h + offset_h): {offset_h.shape}')
    
    # F.grid_sample reads grid[..., 0] as x (width) and grid[..., 1] as y (height); grid_h varies
    # along width, so offset channel 0 acts as the x displacement and channel 1 as the y displacement.
    offsets = torch.stack((offset_h, offset_w), 3)
    
    print(f'offsets (stack): {offsets.shape}')

    print('leaving generate_2dmotion_field')
    return offsets

In [4]:
def motion_seg_loss_PLAY(label_ed, label_es, ed_index, es_index, motion_output, seg_softmax, 
                    start=0, end=32, seg_criterion=DiceLoss()):
    """
        SGS loss that spatially transforms the true ED and true ES labels forward to the end of the
        video and backward to its beginning, then compares the forward- and backward-warped pseudo
        labels with the segmentation output at every frame.

        Verbose debugging copy: only the first loop (ED -> end of video) prints its intermediates
        (shapes only) and uses generate_2dmotion_field_PLAY; the other three loops call
        generate_2dmotion_field without printing.

        Args:
            label_ed, label_es: integer label maps for the ED / ES frames, one-hot encoded here with
                convert_to_1hot(..., 2). The caller in this notebook passes (1, 1, H, W) arrays.
            ed_index, es_index: frame indices of ED / ES in the clip; used in range(), so they must
                be integers. The OTS bookkeeping below assumes ed_index < es_index.
            motion_output: motion head output; channels 0:2 are used as forward motion and 2:4 as
                backward motion, indexed by frame on dim 2.
            seg_softmax: segmentation output indexed as [:, :, frame, ...].
            start, end: frame range the labels are propagated over.
            seg_criterion: callable(prediction, target) used for the per-frame comparisons.
                NOTE: the default DiceLoss() instance is created once, at definition time.

        Returns:
            (flow_loss, OTS_loss): the summed per-frame seg_criterion terms divided by
            2 * (motion_output.shape[2] - 2), and the sum of the two OTS Dice terms
            (ED -> ES and ES -> ED) divided by 2.
    """
    flow_source = convert_to_1hot(label_ed, 2)
    print(f"flow_source.shape: {flow_source.shape}")
    
    loss_forward = 0
    OTS_loss = 0
    OTS_criterion = DiceLoss()
    
    print('forward from ed to end of video')
    # Forward from ed to the end of video
    for frame_index in range(ed_index, end - 1):
        
        forward_motion = motion_output[:, :2, frame_index,...]
        print(f'forward_motion.shape: {forward_motion.shape}')
        
        print('entering generate_2dmotion_field')
        # Print shapes, not full tensors: dumping the values every frame floods the output.
        print('generate_2dmotion_field input shapes:')
        print(f'flow_source: {flow_source.shape}')
        print(f'forward_motion: {forward_motion.shape}')
        motion_field = generate_2dmotion_field_PLAY(flow_source, forward_motion)
        print('generated_2dmotion_field output variable:')
        print(f'motion_field.shape: {motion_field.shape}')
        
        print('entering torch.nn.functional.grid_sample')
        print('F.grid_sample input variable shapes:')
        print(f'flow_source: {flow_source.shape}')
        print(f'motion_field: {motion_field.shape}')
        next_label = F.grid_sample(flow_source, motion_field, align_corners=False, mode="bilinear", padding_mode='border')
        print('F.grid_sample output variabel shape:')
        print(f'next_label.shape: {next_label.shape}')
        
        
        # The label warped from frame es-1 lands on ES: score it against the true ES label (OTS)
        # instead of against the network's segmentation.
        if frame_index == (es_index - 1):
            one_hot_ES = convert_to_1hot(label_es, 2)
            OTS_loss += OTS_criterion(next_label, one_hot_ES)
        else:
            loss_forward += seg_criterion(seg_softmax[:, :, frame_index + 1, ...], next_label)
        flow_source = next_label
    
    print('Forward from es to the end of video')
    # Forward from es to the end of video
    flow_source = convert_to_1hot(label_es, 2)
    for frame_index in range(es_index, end - 1):
        forward_motion = motion_output[:, :2, frame_index,...]
        motion_field = generate_2dmotion_field(flow_source, forward_motion)
        next_label = F.grid_sample(flow_source, motion_field, align_corners=False, mode="bilinear", padding_mode='border')

        loss_forward += seg_criterion(seg_softmax[:, :, frame_index + 1, ...], next_label)
        flow_source = next_label

    flow_source = convert_to_1hot(label_es, 2)
    loss_backward = 0
    
    print('Backward from es to the beginning of video')
    # Backward from es to the beginning of video
    for frame_index in range(es_index, start, -1):
        backward_motion = motion_output[:, 2:, frame_index,...]
        motion_field = generate_2dmotion_field(flow_source, backward_motion)
        next_label = F.grid_sample(flow_source, motion_field, align_corners=False, mode="bilinear", padding_mode='border')
        
        # The label warped back from frame ed+1 lands on ED: score it against the true ED label (OTS).
        if frame_index == ed_index + 1:
            one_hot_ED = convert_to_1hot(label_ed, 2)
            OTS_loss += OTS_criterion(next_label, one_hot_ED)
        else:
            loss_backward += seg_criterion(seg_softmax[:, :, frame_index - 1, ...], next_label)
        flow_source = next_label
    
    flow_source = convert_to_1hot(label_ed, 2)
    
    print('Backward from ed to the beginning of video')
    # Backward from ed to the beginning of video
    for frame_index in range(ed_index, start, -1):
        backward_motion = motion_output[:, 2:, frame_index,...]
        motion_field = generate_2dmotion_field(flow_source, backward_motion)
        next_label = F.grid_sample(flow_source, motion_field, align_corners=False, mode="bilinear", padding_mode='border')
        
        loss_backward += seg_criterion(seg_softmax[:, :, frame_index - 1, ...], next_label)
        flow_source = next_label
        
    # Averaging the resulting dice: with start=0 and end=T the four loops contribute
    # 2 * (T - 2) seg_criterion terms in total.
    flow_loss = (loss_forward + loss_backward) / ((motion_output.shape[2] - 2) * 2)
    OTS_loss = OTS_loss / 2 
    
    print('leaving motion_seg_loss')
    
    return flow_loss, OTS_loss

In [5]:
def train_PLAY(epoch, train_loader, model, optimizer, max_batches=1):
    """ Training function for the network (verbose debugging copy that prints intermediates).

    Args:
        epoch: epoch number; only used in the periodic statistics printout.
        train_loader: iterable of batches where ``batch[0]`` holds the video clips and ``batch[1]``
            is the 10-tuple unpacked below (filename, EF, clip/frame indices, frames, labels).
        model: network returning ``(segmentation_output, motion_output)`` for a batch of clips.
        optimizer: optimizer stepped once per batch.
        max_batches: return after this many batches. The default (1) keeps the
            "stop after one iteration through train_loader" behaviour; pass ``None`` to run the
            full epoch, which also makes the periodic statistics printout reachable.

    Returns:
        list holding ``loss.item()`` for every processed batch.
    """
    model.train()
    epoch_loss = []
    ed_lv_dice = 0
    es_lv_dice = 0
    
    # Inputs go to the device holding the model's parameters, rather than relying on a global
    # `Tensor` type that is only defined in a later cell.
    device = next(model.parameters()).device
    
    # Reseed numpy's global RNG in this process from fresh OS entropy.
    np.random.seed()
    
    print(f"enter for loop enumerate train_loader - len(train_loader): {len(train_loader)}")
    for batch_idx, batch in enumerate(train_loader, 1):
        
        video_clips = torch.Tensor(batch[0])
        
        print(f'video_clips.shape step 1: {video_clips.shape}')
        
        video_clips = video_clips.to(device=device, dtype=torch.float32)
        
        print(f'video_clips.shape step 2: {video_clips.shape}')
        
        filename, EF, es_clip_index, ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label, ed_label = batch[1]

        optimizer.zero_grad()
        
        print('pass video_clips to model')
       
        print(f'video_clips.shape before pass into model: {video_clips.shape}')
        
        # Get the motion tracking output from the motion tracking head using the feature map
        segmentation_output, motion_output = model(video_clips)
        
        
        print(f'segmentation_output.shape: {segmentation_output.shape}')
        print(f'motion_output.shape: {motion_output.shape}')
        
        
        # Total loss = deformation loss + mean per-clip (seg, motion) losses + ED/ES supervised loss.
        loss = 0
        deform_loss = deformation_motion_loss(video_clips, motion_output)
        print(f'get deformation_motion_loss: {deform_loss}')
        
        loss += deform_loss

        
        segmentation_loss = 0
        motion_loss = 0
        
        print(f'enter loop thru video_clips.shape[0]: {video_clips.shape[0]}')
        for i in range(video_clips.shape[0]):
            print(f'\n\n video_clips.shape[0] ind: {i}')
            print(f'before ed/es labels:\ned: {ed_label.shape}\nes: {es_label.shape}')
            
            label_ed = np.expand_dims(ed_label.numpy(), 1).astype("int")
            label_es = np.expand_dims(es_label.numpy(), 1).astype("int")
            
            print(f'transform step 1\ned: {label_ed.shape}\nes: {label_es.shape} ')
            
            label_ed = label_ed[i]
            label_es = label_es[i]
            
            print(f'transform step 2\ned: {label_ed.shape}\nes: {label_es.shape} ')
            
            label_ed = np.expand_dims(label_ed, 0)
            label_es = np.expand_dims(label_es, 0)
            
            print(f'transform step 3\ned: {label_ed.shape}\nes: {label_es.shape} ')
                  
            motion_one_output = motion_output[i].unsqueeze(0)
            segmentation_one_output = segmentation_output[i].unsqueeze(0)
            
            print(f'motion_one_output.shape: {motion_one_output.shape}')
            print(f'segmentation_one_output.shape: {segmentation_one_output.shape}')

            ed_one_index = ed_clip_index[i]
            es_one_index = es_clip_index[i]
            
            print(f'ed_one_index: {ed_one_index}')
            print(f'es_one_index: {es_one_index}')

            print('entering motion_seg_loss')
            segmentation_one_loss, motion_one_loss = motion_seg_loss_PLAY(label_ed, label_es, 
                                                                     ed_one_index, es_one_index, 
                                                                     motion_one_output, segmentation_one_output, 
                                                                     0, video_clips.shape[2], 
                                                                     F.binary_cross_entropy_with_logits)
            
            print(f'segmentation_one_loss.shape: {segmentation_one_loss.shape}')      
            print(f'motion_one_loss.shape: {motion_one_loss.shape}')
            print(f'segmentation_one_loss.item(): {segmentation_one_loss.item()}')      
            print(f'motion_one_loss.item(): {motion_one_loss.item()}')

        
            segmentation_loss += segmentation_one_loss
            motion_loss += motion_one_loss
            print(f'segmentation_loss: {segmentation_loss}')
            print(f'motion_loss: {motion_loss}')

            print('end of current video_clips.shape[0] loop\n\n')
        
        print(f'loss before: {loss}')
        loss += (segmentation_loss / video_clips.shape[0])
        
        print(f'loss step 1 (add mean seg): {loss}')
        
        loss += (motion_loss / video_clips.shape[0])              
        
        print(f'loss step 2 (add mean motion): {loss}')
        
        
        # Collect each clip's ED / ES segmentation frames into (batch, classes, H, W) stacks.
        ed_segmentations = torch.empty(0, dtype=torch.float32, device=device)
        es_segmentations = torch.empty(0, dtype=torch.float32, device=device)
        
        print(f'ed_segmentations.shape: {ed_segmentations.shape}')
        print(f'es_segmentations.shape: {es_segmentations.shape}')
        
        
        print(f'entering into loop for len(ed_clip_index): {len(ed_clip_index)}')

        for i in range(len(ed_clip_index)):
            print(f'\n\n start of loop len(ed_clip_index) ind: {i}')
                
            ed_one_index = ed_clip_index[i]
            es_one_index = es_clip_index[i]
            
            print(f'ed_one_index: {ed_one_index}')
            print(f'es_one_index: {es_one_index}')
            
            ed_seg = segmentation_output[i, :, ed_one_index].unsqueeze(0)
            ed_segmentations = torch.cat([ed_segmentations, ed_seg])
            
            print(f'ed_seg.shape: {ed_seg.shape}')
            print(f'ed_segmentations.shape: {ed_segmentations.shape}')

            
            es_seg = segmentation_output[i, :, es_one_index].unsqueeze(0)
            es_segmentations = torch.cat([es_segmentations, es_seg])
            
            print(f'es_seg.shape: {es_seg.shape}')
            print(f'es_segmentations.shape: {es_segmentations.shape}')

            print('leaving loop of len(ed_clip_index)\n\n')
            
        # Supervised loss on the labelled ED / ES frames, averaged over the two.
        ed_es_seg_loss = 0
        ed_es_seg_loss += F.binary_cross_entropy_with_logits(ed_segmentations, 
                                                             convert_to_1hot(np.expand_dims(ed_label.numpy().astype("int"), 1), 2), 
                                                             reduction="mean") 
        
        print(f'ed_es_seg_loss for ed_segmentations: {ed_es_seg_loss}')
        
        ed_es_seg_loss += F.binary_cross_entropy_with_logits(es_segmentations, 
                                                             convert_to_1hot(np.expand_dims(es_label.numpy().astype("int"), 1), 2), 
                                                             reduction="mean") 
        
        print(f'adding es_segmentations loss to ed_es_seg_loss: {ed_es_seg_loss}')
        
        ed_es_seg_loss /= 2
        
        print(f'ed_es_seg_loss /= 2 : {ed_es_seg_loss}')
        
        loss += ed_es_seg_loss
        
        print(f'loss += ed_es_seg_loss : {loss}')

        loss.backward()
        
        optimizer.step()
        
        epoch_loss.append(loss.item())
        
        ed_segmentation_argmax = torch.argmax(ed_segmentations, 1).cpu().detach().numpy()
        es_segmentation_argmax = torch.argmax(es_segmentations, 1).cpu().detach().numpy()
        
        
        print(f'ed_segmentation_argmax.shape: {ed_segmentation_argmax.shape}')
        print(f'es_segmentation_argmax.shape: {es_segmentation_argmax.shape}')
        
            
        # LV Dice (class 1) on the ED / ES frames, accumulated over batches.
        ed_lv_dice += categorical_dice(ed_segmentation_argmax, ed_label.numpy(), 1)
        es_lv_dice += categorical_dice(es_segmentation_argmax, es_label.numpy(), 1)
        
        
        print(f'ed_lv_dice: {ed_lv_dice}')
        print(f'es_lv_dice: {es_lv_dice}')

        
        # Leave the function once `max_batches` batches have been processed (default: one batch).
        if max_batches is not None and batch_idx >= max_batches:
            print('leaving train_PLAY function')
            return epoch_loss
        
        
        # Printing the intermediate training statistics
        if batch_idx % 280 == 0:
            print('\nTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(video_clips), len(train_loader) * len(video_clips),
                100. * batch_idx / len(train_loader), np.mean(epoch_loss)))

            print("ED LV: {:.3f}".format(ed_lv_dice / batch_idx))
            print("ES LV: {:.3f}".format(es_lv_dice / batch_idx))
            
            print("On a particular batch:")
            print("Deform loss: ", deform_loss)
            print("Segmentation loss: ", ed_es_seg_loss)
            print("Seg Motion loss: ", segmentation_loss / video_clips.shape[0], motion_loss / video_clips.shape[0])
    
    return epoch_loss

In [6]:
# Legacy CUDA float32 tensor type; kept as a global because code in earlier cells may reference it.
Tensor = torch.cuda.FloatTensor

model_save_path = "tmp_save_models/TESTING_R2plus1DMotionSegNet_model.pth"

train_loss_list = []
valid_loss_list = []

n_epoch = 10
min_loss = 1e5
# Debug switch: leave the epoch loop right after the first train_PLAY call
# (skipping timing, validation, checkpointing and the learning-rate drop).
stop_after_first_train_call = True
for epoch in range(1, n_epoch + 1):
    print("-" * 32 + 'Epoch {}'.format(epoch) + "-" * 32)
    start = time.time()
    train_loss = train_PLAY(epoch, train_loader=train_dataloader, model=model, optimizer=optimizer)
    if stop_after_first_train_call:
        break
    
    train_loss_list.append(np.mean(train_loss))
    end = time.time()
    print("training took {:.8f} seconds".format(end-start))
    valid_loss = test(epoch, test_loader=valid_dataloader, model=model, optimizer=optimizer)
    mean_valid_loss = np.mean(valid_loss)
    valid_loss_list.append(mean_valid_loss)
    
    # Checkpoint model + optimizer whenever the mean validation loss improves on the best so far.
    if mean_valid_loss < min_loss:
        min_loss = mean_valid_loss
        torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, model_save_path)
        
    # After epoch 3, lower the learning rate by 10x.
    # NOTE(review): building a new Adam discards its accumulated moment estimates; updating
    # `optimizer.param_groups[...]["lr"]` would keep them -- confirm which behaviour is intended.
    if epoch == 3:
        lr_T = 1e-5
        optimizer = optim.Adam(model.parameters(), lr=lr_T)

--------------------------------Epoch 1--------------------------------
enter for loop enumerate train_loader - len(train_loader): 1833
video_clips.shape step 1: torch.Size([4, 3, 32, 112, 112])
video_clips.shape step 2: torch.Size([4, 3, 32, 112, 112])
pass video_clips to model
video_clips.shape before pass into model: torch.Size([4, 3, 32, 112, 112])
segmentation_output.shape: torch.Size([4, 2, 32, 112, 112])
motion_output.shape: torch.Size([4, 4, 32, 112, 112])
get deformation_motion_loss: 0.023224124684929848
enter loop thru video_clips.shape[0]: 4


 video_clips.shape[0] ind: 0
before ed/es labels:
ed: torch.Size([4, 112, 112])
es: torch.Size([4, 112, 112])
transform step 1
ed: (4, 1, 112, 112)
es: (4, 1, 112, 112) 
transform step 2
ed: (1, 112, 112)
es: (1, 112, 112) 
transform step 3
ed: (1, 1, 112, 112)
es: (1, 1, 112, 112) 
motion_one_output.shape: torch.Size([1, 4, 32, 112, 112])
segmentation_one_output.shape: torch.Size([1, 2, 32, 112, 112])
ed_one_index: 0
es_one_index: 15
