In [None]:
%load_ext autoreload
%autoreload 2

from   cnn4cmr.architectures import Unet, FCN
import cnn4cmr.data_loaders.segmentation_dl as seg_dl
import cnn4cmr.data_loaders.augmentations as augs
import cnn4cmr.data_loaders.utils as utils
from   cnn4cmr.architectures import custom_losses

import os
from pathlib import Path
from time import time
from pprint import pprint

import torch
import numpy as np
import matplotlib.pyplot as plt
import shapely
from   shapely.geometry import Point, Polygon, mapping, shape
from   imgaug import augmenters as iaa
import json

In [None]:
# make dataset

# Collect one (suid, sop) pair per file <ds_path>/Imgs/<suid>/<sop>.dcm.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so both levels are sorted
# to make the train/test split below reproducible across machines and runs.
# NOTE: sorting changes which images fall into each split compared with an unsorted listing;
# models trained with the old (unsorted) order were split differently.
ds_path = '/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/T1_dataset'
suidlist = sorted(p for p in os.listdir(os.path.join(ds_path, 'Imgs')) if '.DS' not in p)
suid_sop_list = []
for suid in suidlist:
    # Skip macOS '.DS_Store' files inside the study folders as well, not only at the top level.
    sop_names = sorted(p.replace('.dcm', '') for p in os.listdir(os.path.join(ds_path, 'Imgs', suid))
                       if '.DS' not in p)
    for sop in sop_names:
        suid_sop_list.append((suid, sop))

# define train, val, test
# Cumulative split boundaries: 60 %, 80 %, 100 % of all images.
n_imgs = len(suid_sop_list)
n_train, n_val, n_test = int(n_imgs*0.6), int(n_imgs*0.8), n_imgs

# NOTE(review): training uses the first 80 % and the validation slice [n_val:n_val] is always
# empty (n_train is unused) — looks deliberate (no validation set for T1); confirm.
# The split is per image, so the suid at the 80 % boundary can have images in both train and test.
suid_sop_list_train = suid_sop_list[:n_val]
suid_sop_list_val   = suid_sop_list[n_val:n_val]
suid_sop_list_test  = suid_sop_list[n_val:]

print('T1')
print('Train images: ', len(suid_sop_list_train))
print('Val images:   ', len(suid_sop_list_val))
print('Test images:  ', len(suid_sop_list_test))

In [None]:
##### hyperparameters

# Fixed T1 hyperparameters: dropout rate, heatmap variances, augmentation strengths and weights w_1..w_7 (one scalar each)
def init_t1_parameter_config():
    """Return a new dict with the default T1 training hyperparameters.

    Every call builds a fresh dict, so callers may mutate the result freely.
    Keys (in insertion order): ``droprate``, the heatmap variances ``heatmap_v1..3``,
    the affine augmentation ranges (``rotation``, ``scale``, ``translate``, ``shear``),
    further augmentation settings (``poolsize``, ``noise``, ``blur``, ``mult_range``,
    ``contrast``) and the weights ``w_1`` .. ``w_7``, all set to 0.5.
    """
    params = dict(
        droprate=0.15,
        # heatmap variances
        heatmap_v1=40,
        heatmap_v2=20,
        heatmap_v3=40,
        # affine 1
        rotation=270,
        scale=0.3,
        # affine 2
        translate=0.15,
        shear=8,
        # other augmentations
        poolsize=5,
        noise=0.5,
        blur=3.5,
        mult_range=0.4,
        contrast=0.3,
    )
    # Weights w_1 .. w_7 share the same initial value.
    params.update({'w_' + str(idx): 0.5 for idx in range(1, 8)})
    return params

hp_config = init_t1_parameter_config()
for k, v in hp_config.items():
    print(k, v)

In [None]:
# make unet config

def init_unet_config(t1_parameter_config):
    """Return the constructor arguments for the T1 contour U-Net.

    All architecture settings are fixed; only the dropout rate ``d_rate`` is taken
    from ``t1_parameter_config['droprate']`` (KeyError if that key is missing).
    """
    return dict(net_depth=6,
                in_ch=1,
                st_ch=16,
                out_ch=1,
                d_rate=t1_parameter_config['droprate'],
                max_ch=512,
                deep_supervision=True)

# make Unet
#config = init_unet_config(hp_config)
#cont_cnn = Unet(config['net_depth'], config['in_ch'], config['st_ch'], config['out_ch'], 
#               d_rate=config['d_rate'], max_ch=config['max_ch'], deep_supervision=config['deep_supervision'])

In [None]:
def init_cnn_cont(device):
    """Create a fresh contour U-Net together with everything needed to train it.

    The batch generator is built over the notebook-level ``suid_sop_list_train`` and
    ``ds_path``, with bounding boxes read from the ``Bbox_Epoch_115_trainedonall`` folder.

    Args:
        device: torch device the new U-Net is moved to.

    Returns:
        (cnn, optimizer, generator, hp_config): the U-Net, an Adam optimizer (lr=0.001)
        over its parameters, the data generator and the default T1 hyperparameters.
    """
    # Fixed seed 0 (the starting epoch): per the original intent, the images come in a random
    # but reproducible order, identical for every network built with this function.
    batch_generator = seg_dl.random_img_goldanno_generator(
        seg_dl.load_mapping_img_gold_anno,
        ds_path,
        suid_sop_list_train,
        256,
        bounding_box=True,
        path_to_boundingbox='/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/T1_dataset/Predictions/Bbox_Epoch_115_trainedonall',
        seed=0,
    )

    hyperparams = init_t1_parameter_config()
    unet_cfg = init_unet_config(hyperparams)

    model = Unet(unet_cfg['net_depth'], unet_cfg['in_ch'], unet_cfg['st_ch'], unet_cfg['out_ch'],
                 d_rate=unet_cfg['d_rate'],
                 max_ch=unet_cfg['max_ch'],
                 deep_supervision=unet_cfg['deep_supervision'])
    model.to(device)

    adam = torch.optim.Adam(model.parameters(), lr=0.001)
    return model, adam, batch_generator, hyperparams

In [None]:
def train_step_cont(cnn, optimizer, generator, hp_config, device, epochs, suid_sop_list, batchsize=8):
    """Train the contour U-Net with deep supervision for ``epochs`` epochs.

    Args:
        cnn: network built with ``deep_supervision=True``; its forward pass is unpacked into
            three predictions that are compared with three target masks.
        optimizer: optimizer over ``cnn.parameters()``; the learning rate is whatever it was
            created with (nothing here changes it).
        generator: sample generator consumed by ``seg_dl.get_batch_mapping_imgs_masks``.
        hp_config: hyperparameter dict passed to ``augs.make_augmenter``.
        device: torch device the batches are moved to.
        epochs: number of epochs to run.
        suid_sop_list: only its length is used, to set the number of batches per epoch
            (``len // batchsize``); the images themselves come from ``generator``.
        batchsize: images per batch (default 8, the previously hard-coded value).

    Returns:
        list[float]: mean training loss of each epoch (NaN for an epoch with no batches,
        i.e. when ``len(suid_sop_list) < batchsize``).
    """
    # 6 weights summing to 1 for the two loss modules — presumably 3 outputs x 2 losses
    # (TODO confirm against DeepSupervisionLoss).
    criterion     = custom_losses.DeepSupervisionLoss(weights      = np.asarray([4/14, 2/14, 1/14, 4/14, 2/14, 1/14]),
                                                      loss_modules = [custom_losses.DiceLoss(), custom_losses.BCELoss()])
    augmentations = augs.make_augmenter(hp_config)

    batches_per_epoch = len(suid_sop_list) // batchsize
    epoch_losses = []
    for epoch in range(1, epochs+1):
        print(); print('\tEpoch: ', epoch, ' of ', epochs, end='\n\t')
        epoch_loss = 0.0
        cnn.train()
        for i_batch in range(batches_per_epoch):
            if i_batch % 10 == 9: print('.', end='')  # progress dot every 10 batches
            _, imgs, masks, masks2, masks3 = seg_dl.get_batch_mapping_imgs_masks(
                generator, augmentations, batchsize=batchsize, prepare_channels=True, deep_supervision=True)
            imgs, masks, masks2, masks3 = (torch.from_numpy(arr).to(device) for arr in (imgs, masks, masks2, masks3))

            optimizer.zero_grad()
            masks_pred, masks_pred2, masks_pred3 = cnn(imgs)
            loss = criterion(masks_pred, masks_pred2, masks_pred3, masks, masks2, masks3)
            loss.backward()
            optimizer.step()
            # loss is a scalar (backward() is called without a gradient argument).
            epoch_loss += loss.item()

        mean_loss = epoch_loss / batches_per_epoch if batches_per_epoch else float('nan')
        epoch_losses.append(mean_loss)
        print('\n\tMean loss: {:.4f}'.format(mean_loss))
    return epoch_losses

In [None]:
def save_checkpoint(path, epoch, prefix, model, optimizer, hp_config):
    """Save model, optimizer and hyperparameters under ``<path>/epoch_<epoch>/``.

    The folder (and any missing parents) is created if needed. Three files are written,
    replacing existing files of the same name:
    ``<prefix>model_state_dict.pth``, ``<prefix>optim_state_dict.pth`` and
    ``<prefix>hp_config.pth``.
    """
    folder = os.path.join(path, 'epoch_' + str(epoch))
    Path(folder).mkdir(parents=True, exist_ok=True)

    # Store a plain (shallow) dict copy built from the config's items.
    hp_config = {k: v for k, v in hp_config.items()}

    torch.save(model.state_dict(),     os.path.join(folder, prefix + 'model_state_dict.pth'))
    torch.save(optimizer.state_dict(), os.path.join(folder, prefix + 'optim_state_dict.pth'))
    torch.save(hp_config,              os.path.join(folder, prefix + 'hp_config.pth'))


def load_checkpoint(folder_path, prefix, map_location='cpu'):
    """Load the three files written by ``save_checkpoint`` from one ``epoch_<N>`` folder.

    Args:
        folder_path: path of the epoch folder; a trailing path separator is tolerated.
        prefix: file-name prefix used when saving (e.g. ``'cont_'``).
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so checkpoints saved
            from an accelerator (e.g. ``'mps'``) load on any machine; ``load_state_dict``
            later places the tensors on the model's/parameters' device.

    Returns:
        (model_state_dict, optimizer_state_dict, hp_config, epoch) where ``hp_config`` is the
        default T1 config overridden by the stored values and ``epoch`` is parsed from the
        folder name (the part after the last ``'_'``).
    """
    def _load(suffix):
        # Load one checkpoint file with the shared prefix and device mapping.
        return torch.load(os.path.join(folder_path, prefix + suffix), map_location=map_location)

    model_state_dict     = _load('model_state_dict.pth')
    optimizer_state_dict = _load('optim_state_dict.pth')
    hp_config_values     = _load('hp_config.pth')

    # normpath strips a trailing separator, otherwise basename() would return ''.
    epoch = int(os.path.basename(os.path.normpath(folder_path)).split('_')[-1])

    # Start from the defaults so keys added since the checkpoint was written are still present.
    hp_config = init_t1_parameter_config()
    hp_config.update(hp_config_values)
    return model_state_dict, optimizer_state_dict, hp_config, epoch

def reload_checkpoint(folder_path, prefix, device=None):
    """Rebuild the contour U-Net and its Adam optimizer from a saved epoch folder.

    Args:
        folder_path: ``epoch_<N>`` folder written by ``save_checkpoint``.
        prefix: file-name prefix used when saving.
        device: torch device the model is moved to. If omitted, the notebook-level
            ``device`` variable is used (the previous implicit behaviour); a NameError is
            raised when neither is available.

    Returns:
        (epoch, cnn, optimizer, hp_config)
    """
    if device is None:
        if 'device' not in globals():
            raise NameError("reload_checkpoint() needs a device: pass device=... or define a global 'device'")
        device = globals()['device']

    model_state_dict, optimizer_state_dict, hp_config, epoch = load_checkpoint(folder_path, prefix)

    # U-Net with the stored architecture/dropout settings and weights.
    config = init_unet_config(hp_config)
    cnn    = Unet(config['net_depth'], config['in_ch'], config['st_ch'], config['out_ch'],
                  d_rate=config['d_rate'], max_ch=config['max_ch'], deep_supervision=config['deep_supervision'])
    cnn.load_state_dict(model_state_dict)
    cnn.to(device)

    # Build the optimizer after moving the model; load_state_dict then restores the saved
    # param_groups (including the learning rate, overriding lr=0.001) and per-parameter state.
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
    optimizer.load_state_dict(optimizer_state_dict)

    return epoch, cnn, optimizer, hp_config

In [None]:
device            = 'mps'
cont_storage_path = '/Users/thomas/Documents/GitHub/cnn4cmr/trained_models/T1/Cont'
n_epochs          = 15
start_epoch       = 100  # epoch of the checkpoint training continues from ('epoch_100'); use 0 when starting from scratch

# cont_cnn / cont_optimizer / cont_generator / cont_hp_config are NOT created in this cell.
# Either uncomment the next line to start from scratch, or first run the reload cell ('epoch_100')
# and then the cell that builds cont_generator and aliases the reloaded cnn/optimizer/hp_config.
#cont_cnn, cont_optimizer, cont_generator, cont_hp_config = init_cnn_cont(device)

# suid_sop_list only sets the number of batches per epoch; images are drawn from cont_generator.
# NOTE(review): init_cnn_cont builds that generator over suid_sop_list_train only, so the test
# images are not actually trained on despite the 'trainedonall' naming — confirm intent.
train_step_cont(cont_cnn, cont_optimizer, cont_generator, cont_hp_config,
                device, epochs=n_epochs, suid_sop_list=suid_sop_list_train+suid_sop_list_test)

# Label the checkpoint with the epoch actually reached instead of a hard-coded number.
save_checkpoint(cont_storage_path, start_epoch + n_epochs, 'cont_trainedonall_', cont_cnn, cont_optimizer, cont_hp_config)

In [None]:
# Only the data generator is taken from init_cnn_cont; the U-Net, optimizer and hp_config it
# also builds are discarded. Indexing avoids rebinding IPython's `_` (last-output) variable.
# NOTE(review): this generator covers suid_sop_list_train only.
cont_generator = init_cnn_cont(device)[2]

# Continue from the objects created by the reload cell (run it first): cnn, optimizer, hp_config.
cont_cnn       = cnn
cont_optimizer = optimizer
cont_hp_config = hp_config

In [None]:
device = 'mps'
cont_storage_path = '/Users/thomas/Documents/GitHub/cnn4cmr/trained_models/T1/Cont'

# Restore the contour U-Net, its Adam state and hyperparameters saved in the 'epoch_100' folder.
epoch, cnn, optimizer, hp_config = reload_checkpoint(
    folder_path=os.path.join(cont_storage_path, 'epoch_100'),
    prefix='cont_',
)

In [None]:
def to_annotation(type_name_2_mask, other_info):
    """Convert predicted masks into a dict of annotation entries keyed by name.

    Args:
        type_name_2_mask: maps ``(annotype, name, lcc)`` to a 2-D mask.
            - ``'keypoint'``: the mask is turned into a point via ``utils.to_keypoint``
              (``lcc`` is ignored).
            - ``'contour'``: the mask must be ``np.uint8``; it is turned into a polygon via
              ``utils.to_polygon``. If ``lcc`` is true and the result is a MultiPolygon, only
              the part with the largest area is kept.
            Other annotation types are skipped.
        other_info: must provide ``'contType'``, ``'pixelSize'`` and ``'imageSize'``, which are
            copied into every entry.

    Returns:
        dict: ``name -> {'cont': mapping(geometry), 'contType', 'subpixelResolution': 1.0,
        'pixelSize', 'imageSize'}``.

    Raises:
        AssertionError: if a mask is not 2-D, or a contour mask is not ``np.uint8``.
    """
    def _entry(geometry):
        # Annotation record shared by keypoints and contours.
        return {'cont':               mapping(geometry),
                'contType':           other_info['contType'],
                'subpixelResolution': 1.0,
                'pixelSize':          other_info['pixelSize'],
                'imageSize':          other_info['imageSize']}

    ret = dict()
    for (annotype, name, lcc), mask in type_name_2_mask.items():
        if annotype == 'keypoint':
            assert len(mask.shape) == 2, 'Mask for ' + annotype + ' ' + name + ' has more than two dimensions.'
            ret[name] = _entry(utils.to_keypoint(mask))
        elif annotype == 'contour':
            assert len(mask.shape) == 2,   'Mask for ' + annotype + ' ' + name + ' has more than two dimensions.'
            assert mask.dtype == np.uint8, 'Mask for ' + annotype + ' ' + name + ' is not of type np.uint8.'
            poly = utils.to_polygon(mask)
            if lcc and poly.geom_type == 'MultiPolygon':
                # Keep only the largest connected component.
                poly = max(poly.geoms, key=lambda p: p.area)
            ret[name] = _entry(poly)
    return ret

In [None]:
# Visualizations
import matplotlib.patches as patches
import pydicom
from pprint import pprint

import itertools

# Run the reloaded contour U-Net on the test images and overlay the predicted myocardium
# contour (exterior ring plus any interior rings) on the network input.
N_VIS_IMAGES = 201  # generator indices 0..200

cnn = cnn.eval()
noaugs = augs.make_noaugs_augmenter()
test_generator  = seg_dl.random_img_goldanno_generator(seg_dl.load_mapping_img_gold_anno,
                                                       ds_path,
                                                       suid_sop_list_test,
                                                       256, bounding_box=True,
                                                       path_to_boundingbox='/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/T1_dataset/Additional_Info',
                                                       seed=epoch)

# islice bounds the loop independently of the try/except below, so a failure on the last image
# cannot skip the stop condition (which would never end if the generator does not terminate).
for i, ((suid, sop), (img, conts, kps)) in enumerate(itertools.islice(test_generator, N_VIS_IMAGES)):
    try:
        imgs = noaugs(images=[img])
        net_input = torch.from_numpy(np.asarray(imgs)).to(device).reshape(1, 1, 256, 256)
        # Inference only: no autograd graph. The network returns several outputs; index 0 is shown.
        with torch.no_grad():
            pred = cnn(net_input)[0].cpu().detach().numpy().squeeze()

        pixelspacing = pydicom.dcmread(os.path.join(ds_path, 'Imgs', suid, sop + '.dcm'),
                                       stop_before_pixels=True).PixelSpacing
        other_info = {'imageSize': img.shape,
                      'pixelSize': pixelspacing,
                      'contType':  'MYO'}
        type_name_2_mask = {('contour', 'lv_myo', True): (pred > 0.5).astype(np.uint8)}
        anno = to_annotation(type_name_2_mask, other_info)
        myo = shape(anno['lv_myo']['cont'])

        fig, axes = plt.subplots(1, 2, figsize=(14, 7))
        for ax in axes: ax.axis('off')
        axes[0].imshow(img, cmap='gray', interpolation='nearest')
        axes[1].imshow(imgs[0], cmap='gray', interpolation='nearest')
        axes[1].plot(*myo.exterior.xy)
        # Plot every hole (e.g. the blood pool); a polygon without holes no longer raises.
        for ring in myo.interiors:
            axes[1].plot(*ring.xy)

        fig.tight_layout()
        plt.show()
    except Exception as e:
        # Best effort: report which image failed and move on to the next one.
        print(f'{suid}/{sop}: {type(e).__name__}: {e}')