In [None]:
%load_ext autoreload
%autoreload 2

import pathlib
import os
import json
import traceback

import matplotlib.pyplot as plt
import numpy as np
import torch
import pydicom
from shapely.geometry import mapping, Polygon
from shapely.affinity import scale, translate
from skimage.transform import resize

from   cnn4cmr import data_loaders 
import cnn4cmr.data_loaders.augmentations as augs
from   cnn4cmr.architectures import Unet, FCN
from   cnn4cmr.data_loaders import utils

In [None]:
def init_unet_config():
    """Return the U-Net constructor hyperparameters as a new dict.

    A fresh dict is built on every call, so callers may modify the result
    without affecting later calls.
    """
    return dict(
        net_depth=6,
        in_ch=1,
        st_ch=16,
        out_ch=3,
        d_rate=0.2,  # this parameter evolves (presumably the dropout rate -- confirm in Unet)
        max_ch=512,
        deep_supervision=True,
    )
    
def load_checkpoint(folder_path, prefix, map_location='cpu'):
    """Load the model and optimizer state dicts of a saved checkpoint.

    Reads ``<prefix>model_state_dict.pth`` and ``<prefix>optim_state_dict.pth``
    from ``folder_path``.

    Args:
        folder_path: directory containing the checkpoint files.
        prefix: filename prefix of the checkpoint files (e.g. ``'cont_'``).
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint written on a CUDA machine can be opened on a CPU/MPS-only
            machine. ``nn.Module.load_state_dict`` copies values into the module's
            existing parameters and ``Optimizer.load_state_dict`` casts state to the
            parameters' device, so loading to CPU first is safe.

    Returns:
        Tuple ``(model_state_dict, optimizer_state_dict)``.
    """
    model_state_dict     = torch.load(os.path.join(folder_path, f'{prefix}model_state_dict.pth'), map_location=map_location)
    optimizer_state_dict = torch.load(os.path.join(folder_path, f'{prefix}optim_state_dict.pth'), map_location=map_location)
    # '<prefix>hp_config.pth' is deliberately not read: nothing here consumes it,
    # and loading it only added I/O and a hard failure when the file is missing.
    return model_state_dict, optimizer_state_dict

def reload_checkpoint(folder_path, prefix, device):
    """Rebuild the U-Net and its Adam optimizer from a saved checkpoint.

    Args:
        folder_path: directory holding the ``<prefix>*.pth`` checkpoint files.
        prefix: filename prefix, passed through to ``load_checkpoint``.
        device: torch device the network is moved to before the optimizer is
            created and its state restored.

    Returns:
        Tuple ``(cnn, optimizer)`` with the restored network and optimizer.
    """
    model_sd, optim_sd = load_checkpoint(folder_path, prefix)

    cfg = init_unet_config()
    net = Unet(
        cfg['net_depth'], cfg['in_ch'], cfg['st_ch'], cfg['out_ch'],
        d_rate=cfg['d_rate'],
        max_ch=cfg['max_ch'],
        deep_supervision=cfg['deep_supervision'],
    )
    net.load_state_dict(model_sd)
    net.to(device)

    # lr here is only a placeholder: load_state_dict replaces the param_groups
    # (learning rate included) with the saved ones.
    adam = torch.optim.Adam(net.parameters(), lr=0.001)
    adam.load_state_dict(optim_sd)
    return net, adam

In [None]:
def _largest_polygon(geom):
    """Collapse a MultiPolygon to its largest-area member; any other geometry is returned unchanged."""
    if geom.geom_type == 'MultiPolygon':
        return max(geom.geoms, key=lambda g: g.area)
    return geom


def _contour_entry(geom, pixel_size, image_size, cont_type='FREE'):
    """Build one contour record of the per-image annotation JSON."""
    return {'cont': mapping(geom), 'subpixelResolution': 1.0, 'pixelSize': pixel_size,
            'imageSize': image_size, 'contType': cont_type}


def _report_failure(path, dcm, contours):
    """Print diagnostics for a file whose annotation failed and preview it with any LV myo contour.

    Must be called from inside an ``except`` clause so ``traceback.format_exc`` sees the
    active exception. ``dcm`` is None when the file itself could not be read; ``contours``
    only holds contours computed for THIS file before the failure.
    """
    # Identify the file by path rather than PatientName to keep PHI out of notebook output.
    print('\nFailed for: ', path)
    if dcm is not None:
        print(getattr(dcm, 'StudyInstanceUID', None), getattr(dcm, 'SOPInstanceUID', None))
    print(traceback.format_exc())

    for label, key in (('LV ENDO', 'lv_endo'), ('LV MYO', 'lv_myo'), ('RV ENDO', 'rv_endo')):
        if key in contours:
            print(label)
            print(contours[key])

    if dcm is None:
        return
    # The preview must never raise: an exception here would escape the per-file handler.
    try:
        img = dcm.pixel_array.astype(np.float32)
    except Exception:
        print('Pixel data could not be decoded; no preview.')
        return
    if img.ndim != 2:
        print('No preview: expected a 2D image, got shape', img.shape)
        return

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.axis('off')
    h, w = img.shape
    ax.imshow(img, cmap='gray', extent=(0, w, h, 0))
    myo = contours.get('lv_myo')
    if myo is not None and myo.geom_type == 'Polygon' and not myo.is_empty:
        ax.plot(*myo.exterior.xy)
        for ring in myo.interiors:
            ax.plot(*ring.xy)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures over many failures
    print()


def annotate_saxcine_images(path_to_images, path_to_storage,
                            augmenter,
                            cont_cnn,
                            device,
                            threshold=0.5):
    """Segment every DICOM under ``path_to_images`` and write one JSON annotation per image.

    For each ``*.dcm`` file (searched recursively) the image is passed through ``augmenter``
    and segmented by ``cont_cnn``. Output channels 0, 1 and 2 are thresholded and turned into
    LV endocardium, LV myocardium and RV endocardium polygons (largest component kept). The
    polygons are shifted back into original image coordinates. LV epicardium is the filled
    outer boundary of the myocardium. The stored myocardium uses that boundary with the
    endocardium as a hole. Results go to
    ``<path_to_storage>/Annos/<StudyInstanceUID>/<SOPInstanceUID>.json``.
    A failure on one file is reported and the loop continues with the next file.

    Args:
        path_to_images: root directory searched recursively for ``*.dcm`` files.
        path_to_storage: output root; an ``Annos`` sub-directory is created in it.
        augmenter: called as ``augmenter(images=[img])[0]``. The coordinate shift assumes it
            center-pads/crops the image (TODO confirm for the augmenter in use).
        cont_cnn: segmentation network; ``cont_cnn(x)[0]`` is used as the prediction.
        device: torch device the input tensor is sent to (should match ``cont_cnn``).
        threshold: per-channel probability cut-off (default 0.5).
    """
    anno_root = os.path.join(path_to_storage, 'Annos')
    pathlib.Path(anno_root).mkdir(parents=True, exist_ok=True)

    for p_i, p in enumerate(pathlib.Path(path_to_images).glob('**/*.dcm')):
        if p_i % 10 == 1: print('.', end='')
        # Reset per-file state so a failure is never reported with the previous file's data.
        dcm, contours = None, {}
        try:
            dcm = pydicom.dcmread(p)
            suid, sop = dcm.StudyInstanceUID, dcm.SOPInstanceUID
            img = dcm.pixel_array.astype(np.float32)
            h, w = img.shape
            img_aug = augmenter(images=[img])[0]
            h_aug, w_aug = img_aug.shape
            # Offset between the augmented and the original frame for a centered pad/crop,
            # taken from the actual augmented size rather than a hard-coded 256.
            sh_x, sh_y = (w_aug - w) / 2.0, (h_aug - h) / 2.0

            # Segmentation: inference only, so skip autograd bookkeeping.
            with torch.no_grad():
                x = torch.from_numpy(img_aug.reshape(1, 1, h_aug, w_aug)).to(device)
                pred = cont_cnn(x)[0]
            pred = pred.cpu().numpy().squeeze()

            for key, channel in (('lv_endo', 0), ('lv_myo', 1), ('rv_endo', 2)):
                mask = (pred[channel] > threshold).astype(np.uint8)
                poly = _largest_polygon(data_loaders.utils.to_polygon(mask))
                contours[key] = translate(poly, xoff=-sh_x, yoff=-sh_y)
            lv_endo, lv_myo_raw, rv_endo = contours['lv_endo'], contours['lv_myo'], contours['rv_endo']

            # Epicardium: the filled outer boundary of the predicted myocardium. It is optional;
            # when it cannot be built it is simply omitted from the annotation.
            try:
                lv_epi = Polygon(lv_myo_raw.exterior)
            except Exception:
                print(traceback.format_exc())
                lv_epi = None
            # Myocardium: outer boundary with the endocardium cut out as a hole.
            try:
                lv_myo = Polygon(lv_myo_raw.exterior, [lv_endo.exterior])
            except Exception:
                lv_myo = Polygon()
            contours['lv_myo'] = lv_myo

            # Make annotation (imageSize is the original, un-augmented size).
            ph, pw = dcm.PixelSpacing
            pixel_size, image_size = (float(ph), float(pw)), (h, w)
            anno = {'lv_endo': _contour_entry(lv_endo, pixel_size, image_size),
                    'rv_endo': _contour_entry(rv_endo, pixel_size, image_size)}
            if lv_epi is not None and not lv_epi.is_empty:
                anno['lv_epi'] = _contour_entry(lv_epi, pixel_size, image_size)
            if not lv_myo.is_empty:
                anno['lv_myo'] = _contour_entry(lv_myo, pixel_size, image_size, cont_type='MYO')

            study_dir = os.path.join(anno_root, suid)
            pathlib.Path(study_dir).mkdir(parents=True, exist_ok=True)
            with open(os.path.join(study_dir, sop + '.json'), 'w') as f:
                json.dump(anno, f, indent=4)

        except Exception:
            _report_failure(p, dcm, contours)

In [None]:
noaugs = augs.make_noaugs_augmenter()

# Prefer Apple-silicon GPU, then CUDA, then CPU, so the notebook also runs off the original machine.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Repository root; set the CNN4CMR_ROOT environment variable instead of editing this path.
CNN4CMR_ROOT      = os.environ.get('CNN4CMR_ROOT', '/Users/thomas/Documents/GitHub/cnn4cmr')
cont_storage_path = os.path.join(CNN4CMR_ROOT, 'trained_models', 'SAXCINECS', 'FirstCont')
cont_cnn, optimizer = reload_checkpoint(os.path.join(cont_storage_path, 'epoch_110'), 'cont_', device)
cont_cnn = cont_cnn.eval()  # inference mode: layers such as dropout switch to eval behaviour

In [None]:
inp_path = '/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/SAXCS_dataset/Imgs'
out_path = '/Users/thomas/Desktop/outtest'
# Use the device the model was loaded onto so inputs and weights always share a device.
# The function writes JSON files and returns None, so its result is not bound.
annotate_saxcine_images(path_to_images=inp_path,
                        path_to_storage=out_path,
                        augmenter=noaugs,
                        cont_cnn=cont_cnn,
                        device=device)

In [None]:
inp_path = '/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/SAXCINE_dataset/Imgs'
out_path = '/Users/thomas/Desktop/outtest'
# Use the device the model was loaded onto so inputs and weights always share a device.
# The function writes JSON files and returns None, so its result is not bound.
annotate_saxcine_images(path_to_images=inp_path,
                        path_to_storage=out_path,
                        augmenter=noaugs,
                        cont_cnn=cont_cnn,
                        device=device)

In [None]:
inp_path = '/Users/thomas/Desktop/ReliabilityPaper/Evaluation_Data/SAXCINE_Imgs'
out_path = '/Users/thomas/Desktop/outtest_evaldata'
# Use the device the model was loaded onto so inputs and weights always share a device.
# The function writes JSON files and returns None, so its result is not bound.
annotate_saxcine_images(path_to_images=inp_path,
                        path_to_storage=out_path,
                        augmenter=noaugs,
                        cont_cnn=cont_cnn,
                        device=device)