In [None]:
%load_ext autoreload
%autoreload 2

from   cnn4cmr.architectures import Unet, FCN
import cnn4cmr.data_loaders.segmentation_dl as seg_dl
import cnn4cmr.data_loaders.augmentations as augs
import cnn4cmr.data_loaders.utils as utils
from   cnn4cmr.architectures import custom_losses

import os
from pathlib import Path
from time import time
from pprint import pprint

import torch
import numpy as np
import matplotlib.pyplot as plt
from   shapely.geometry import Point, Polygon, mapping
from   imgaug import augmenters as iaa

In [None]:
# make dataset: list all (study UID, SOP instance) pairs and split them 60/20/20

# Dataset root; override with the CNN4CMR_T1_DATASET environment variable on other machines.
ds_path = os.environ.get('CNN4CMR_T1_DATASET', '/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/T1_dataset')


def _list_visible(folder):
    """Return the entries of `folder` sorted by name, skipping macOS '.DS_Store'-style files.

    Sorting makes the listing (and therefore the train/val/test split below)
    independent of the filesystem-dependent order of `os.listdir`.
    """
    return sorted(name for name in os.listdir(folder) if '.DS' not in name)


# get suid,sop list (the '.dcm' extension is stripped from the SOP file names)
suidlist = _list_visible(os.path.join(ds_path, 'Imgs'))
suid_sop_list = [(suid, sop.replace('.dcm', ''))
                 for suid in suidlist
                 for sop in _list_visible(os.path.join(ds_path, 'Imgs', suid))]

# define train, val, test (60% / 20% / 20% of the images)
# NOTE: n_train and n_val are slice END indices, not counts.
# NOTE(review): the split is per image; images of one suid are contiguous, so a study that
# straddles a boundary lands in two subsets — consider splitting by suid to avoid leakage.
n_imgs = len(suid_sop_list)
n_train, n_val, n_test = int(n_imgs*0.6), int(n_imgs*0.8), n_imgs

suid_sop_list_train = suid_sop_list[:n_train]
suid_sop_list_val   = suid_sop_list[n_train:n_val]
suid_sop_list_test  = suid_sop_list[n_val:]
assert len(suid_sop_list_train) + len(suid_sop_list_val) + len(suid_sop_list_test) == n_imgs

print('Train images: ', len(suid_sop_list_train))
print('Val images:   ', len(suid_sop_list_val))
print('Test images:  ', len(suid_sop_list_test))

In [None]:
# ----- hyperparameters -----

def init_t1_parameter_config():
    """Build a fresh dict holding the default T1 training hyperparameters.

    Every entry is a plain number (no limits/prior/mutation spec):
      * 'droprate'                    -- dropout rate, forwarded to the U-Net as d_rate
      * 'heatmap_v1' .. 'heatmap_v3'  -- heatmap variances
      * 'rotation', 'scale'           -- first affine augmentation
      * 'translate', 'shear'          -- second affine augmentation
      * 'poolsize', 'noise', 'blur', 'mult_range', 'contrast' -- other augmentations
      * 'w_1' .. 'w_7'                -- weights: 0.5 for w_1 and w_2, 0.2 for w_3 .. w_7
                                         (purpose not visible in this notebook — TODO document)

    NOTE(review): an older header comment listed the fields "limits, prior, increments,
    prob of prior, prob of mutation, operator, init value"; only the init values are used here.
    """
    params = dict(
        droprate=0.2,
        # heatmap variances
        heatmap_v1=19, heatmap_v2=9, heatmap_v3=15,
        # affine 1  (TODO: author's "continue here" marker)
        rotation=360, scale=0.25,
        # affine 2
        translate=0.1, shear=7,
        # other augs
        poolsize=5, noise=0.3, blur=2.5, mult_range=0.2, contrast=0.3,
    )
    params.update(('w_' + str(idx), 0.5 if idx < 3 else 0.2) for idx in range(1, 8))
    return params


hp_config = init_t1_parameter_config()
for name, value in hp_config.items():
    print(name, value)

In [None]:
# make unet config

def init_unet_config(t1_parameter_config):
    """Return the keyword settings used to construct the U-Net.

    Only the dropout rate ('d_rate') is taken from `t1_parameter_config['droprate']`,
    since that hyperparameter evolves; everything else is fixed.
    """
    return dict(
        net_depth=6, in_ch=1, st_ch=16, out_ch=1,
        d_rate=t1_parameter_config['droprate'],  # the evolving parameter
        max_ch=512, deep_supervision=True,
    )

# make Unet
#config = init_unet_config(hp_config)
#cont_cnn = Unet(config['net_depth'], config['in_ch'], config['st_ch'], config['out_ch'], 
#               d_rate=config['d_rate'], max_ch=config['max_ch'], deep_supervision=config['deep_supervision'])

In [None]:
def init_cnn_bbox(device, lr=0.001, seed=0):
    """Create the U-Net, its Adam optimizer, a training-image generator and the hyperparameters.

    Args:
        device: torch device string/object the network is moved to (e.g. 'mps', 'cuda', 'cpu').
        lr: Adam learning rate (default 0.001, the previously hard-coded value).
        seed: seed for the training-image generator (default 0, the previous value).

    Returns:
        (cnn, optimizer, generator, hp_config)
    """
    # Seeding with the epoch index makes the image order random but identical for every
    # cnn in a population. The generator yields 256-pixel samples from the training split.
    generator = seg_dl.random_img_goldanno_generator(ds_path, suid_sop_list_train, 256,
                                                      bounding_box=False, seed=seed)

    # hyperparameters (fresh copy per cnn) and the U-Net built from them
    hp_config = init_t1_parameter_config()
    config = init_unet_config(hp_config)
    cnn = Unet(config['net_depth'], config['in_ch'], config['st_ch'], config['out_ch'],
               d_rate=config['d_rate'], max_ch=config['max_ch'],
               deep_supervision=config['deep_supervision'])
    cnn.to(device)

    optimizer = torch.optim.Adam(cnn.parameters(), lr=lr)
    return cnn, optimizer, generator, hp_config

In [None]:
def train_step_bbox(cnn, optimizer, generator, hp_config, device, epochs, suid_sop_list, batchsize=8):
    """Train `cnn` for `epochs` epochs using deep supervision, printing progress and mean loss.

    Args:
        cnn: network returning three outputs (deep supervision) in train mode.
        optimizer: torch optimizer over cnn's parameters; its own learning rate is used.
        generator: sample generator consumed by seg_dl.get_batch_imgs_epimasks.
        hp_config: hyperparameter dict used to build the augmenter.
        device: device the batch tensors are moved to.
        epochs: number of epochs to run.
        suid_sop_list: only its length is used, to size an epoch (len // batchsize batches).
        batchsize: images per batch (default 8, the previously hard-coded value).

    Returns:
        list of float: mean training loss of each epoch (NaN if an epoch had no full batch).
    """
    # 6 weights for 3 network outputs x 2 loss modules — TODO confirm ordering in DeepSupervisionLoss
    criterion = custom_losses.DeepSupervisionLoss(
        weights=np.asarray([4/14, 2/14, 1/14, 4/14, 2/14, 1/14]),
        loss_modules=[custom_losses.DiceLoss(), custom_losses.BCELoss()])
    augmentations = augs.make_augmenter(hp_config)

    batches_per_epoch = len(suid_sop_list) // batchsize
    epoch_losses = []
    for epoch in range(1, epochs+1):
        print(); print('\tEpoch: ', epoch, ' of ', epochs, end='\n\t')
        epoch_loss = 0.0
        cnn.train()
        for i_batch in range(batches_per_epoch):
            if i_batch % 10 == 9:
                print('.', end='')  # progress dot every 10 batches
            _, imgs, masks, masks2, masks3 = seg_dl.get_batch_imgs_epimasks(
                generator, augmentations, batchsize=batchsize,
                prepare_channels=True, deep_supervision=True)
            imgs, masks, masks2, masks3 = (torch.from_numpy(arr).to(device)
                                           for arr in (imgs, masks, masks2, masks3))

            optimizer.zero_grad()
            masks_pred, masks_pred2, masks_pred3 = cnn(imgs)
            loss = criterion(masks_pred, masks_pred2, masks_pred3, masks, masks2, masks3)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()  # loss is a scalar (backward() is called without a gradient arg)

        mean_loss = epoch_loss / batches_per_epoch if batches_per_epoch else float('nan')
        epoch_losses.append(mean_loss)
        print('\n\tMean loss: ', mean_loss)
    return epoch_losses

In [None]:
def save_checkpoint(path, epoch, nr, prefix, model, optimizer, hp_config):
    """Store model/optimizer state dicts and hyperparameters in <path>/epoch_<epoch>/cnn_<nr>/.

    Files written (each name prefixed with `prefix`): 'model_state_dict.pth',
    'optim_state_dict.pth' and 'hp_config.pth'. Missing folders are created.

    `hp_config` values may be plain numbers (as built by init_t1_parameter_config) or
    objects exposing a `.value` attribute; only the raw values are stored.
    """
    folder_path = os.path.join(path, 'epoch_'+str(epoch), 'cnn_'+str(nr))
    Path(folder_path).mkdir(parents=True, exist_ok=True)

    # plain values for storage; getattr fallback avoids AttributeError on float entries
    hp_values = {k: getattr(v, 'value', v) for k, v in hp_config.items()}

    torch.save(model.state_dict(),     os.path.join(folder_path, prefix+'model_state_dict.pth'))
    torch.save(optimizer.state_dict(), os.path.join(folder_path, prefix+'optim_state_dict.pth'))
    torch.save(hp_values,              os.path.join(folder_path, prefix+'hp_config.pth'))


def _epoch_from_checkpoint_path(folder_path):
    """Parse N from a '<...>/epoch_N/cnn_M' checkpoint folder; return None if it does not match."""
    parent = Path(folder_path).parent.name
    if parent.startswith('epoch_'):
        try:
            return int(parent[len('epoch_'):])
        except ValueError:
            pass
    return None


def load_checkpoint(folder_path, prefix, map_location=None):
    """Load a checkpoint folder written by save_checkpoint.

    Args:
        folder_path: the '<path>/epoch_<epoch>/cnn_<nr>' folder.
        prefix: file-name prefix used when saving.
        map_location: forwarded to torch.load, e.g. 'cpu' to load an MPS/CUDA checkpoint
            on a machine without that device (default None keeps torch's behaviour).

    Returns:
        (model_state_dict, optimizer_state_dict, hp_config, epoch) where hp_config is the
        default config with the stored values applied, and epoch is parsed from the parent
        'epoch_<N>' folder name (None if the folder does not follow that layout).
    """
    model_state_dict       = torch.load(os.path.join(folder_path, prefix+'model_state_dict.pth'), map_location=map_location)
    optimizer_state_dict   = torch.load(os.path.join(folder_path, prefix+'optim_state_dict.pth'), map_location=map_location)
    hp_config_values       = torch.load(os.path.join(folder_path, prefix+'hp_config.pth'), map_location=map_location)
    # rebuild hp_config: entries may be objects with `.value` or plain numbers
    hp_config = init_t1_parameter_config()
    for k, v in hp_config_values.items():
        if hasattr(hp_config.get(k), 'value'):
            hp_config[k].value = v
        else:
            hp_config[k] = v
    # previously returned an undefined name `epoch` (NameError)
    epoch = _epoch_from_checkpoint_path(folder_path)
    return model_state_dict, optimizer_state_dict, hp_config, epoch

In [None]:
# pick the best available accelerator instead of assuming Apple-silicon MPS
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('\tInitializing cnns and training environment')
# must run: the training call below and the visualisation cell need these names
bbox_cnn, bbox_optimizer, bbox_generator, bbox_hp_config = init_cnn_bbox(device)

print('\tTraining starts')
train_step_bbox(bbox_cnn, bbox_optimizer, bbox_generator, bbox_hp_config,
                device, epochs=3, suid_sop_list=suid_sop_list_train)

In [None]:
from itertools import islice

# Visual sanity check: run the trained bbox_cnn on a few samples from the training generator.
# The cell previously built a no-augmentation augmenter and immediately overwrote it with
# the training augmenter; this flag makes the choice explicit (True keeps that behaviour).
USE_TRAIN_AUGMENTATIONS = True
N_PREVIEW = 11  # the loop previously stopped after index 10, i.e. 11 samples

viz_augmenter = augs.make_augmenter(hp_config) if USE_TRAIN_AUGMENTATIONS else augs.make_noaugs_augmenter()

# Switch dropout (and any batch-norm) layers to inference behaviour; train_step_bbox
# calls cnn.train() again at the start of every epoch.
bbox_cnn.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for (suid, sop), (img, _conts, _kps) in islice(bbox_generator, N_PREVIEW):
        imgs = viz_augmenter(images=[img])
        inp = torch.from_numpy(np.asarray(imgs)).to(device).reshape(1, 1, 256, 256)
        # [0]: first element of the network output (first deep-supervision output when
        # several are returned — TODO confirm what Unet returns in eval mode)
        pred = bbox_cnn(inp)[0].cpu().detach().numpy().squeeze()
        pred_mask = pred > 0.5

        fig, axes = plt.subplots(2, 2, figsize=(10, 10))
        for ax in axes.ravel():
            ax.axis('off')
        axes[0, 0].imshow(img, cmap='gray')
        axes[0, 0].set_title(f'original {suid} / {sop}')
        axes[0, 1].imshow(imgs[0], cmap='gray')
        axes[0, 1].set_title('network input')
        axes[1, 0].imshow(pred_mask, cmap='gray')
        axes[1, 0].set_title('prediction > 0.5')
        axes[1, 1].imshow(imgs[0] + 2*pred_mask.astype(np.float32), cmap='gray')
        axes[1, 1].set_title('overlay')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop