In [None]:
%load_ext autoreload
%autoreload 2

import pathlib
import os
import json
import traceback

import matplotlib.pyplot as plt
import numpy as np
import torch
import pydicom
from shapely.geometry import mapping, Polygon
from shapely.affinity import scale, translate
from skimage.transform import resize

from   cnn4cmr import data_loaders 
import cnn4cmr.data_loaders.augmentations as augs
from   cnn4cmr.architectures import Unet, FCN
from   cnn4cmr.data_loaders import utils

In [None]:
def init_unet_config():
    """Return a fresh dict of U-Net constructor hyper-parameters.

    Keys, in order: net_depth, in_ch, st_ch, out_ch, d_rate, max_ch, deep_supervision.
    A new dict is built on every call, so callers may mutate the result freely.
    """
    return dict(
        net_depth=6,
        in_ch=1,
        st_ch=16,
        out_ch=1,
        d_rate=0.2,  # this parameter evolves (presumably tuned by a hyper-parameter search — TODO confirm)
        max_ch=512,
        deep_supervision=True,
    )
    
def load_checkpoint(folder_path, prefix, map_location=None):
    """Load the model and optimizer state dicts stored under `folder_path`.

    Reads `<prefix>model_state_dict.pth` and `<prefix>optim_state_dict.pth`.
    `<prefix>hp_config.pth` is intentionally not read. Its content was never used,
    and loading it made this function fail when that file is missing, or when it
    holds objects `torch.load` refuses to unpickle (e.g. under `weights_only=True`,
    the default in recent PyTorch releases).

    Args:
        folder_path: directory containing the checkpoint files.
        prefix: file-name prefix shared by the files, e.g. 'bbox_'.
        map_location: forwarded to `torch.load`. Use e.g. 'cpu' to open a checkpoint
            saved on a device this machine does not have. None keeps torch's default
            (tensors are restored onto the device they were saved from).

    Returns:
        (model_state_dict, optimizer_state_dict)
    """
    def _load(file_name):
        return torch.load(os.path.join(folder_path, prefix + file_name), map_location=map_location)

    return _load('model_state_dict.pth'), _load('optim_state_dict.pth')

def reload_checkpoint(folder_path, prefix, device):
    """Rebuild a U-Net and its Adam optimizer from a saved checkpoint.

    The same `init_unet_config()` architecture is used for every checkpoint
    (bbox, contour and keypoint networks alike).

    Args:
        folder_path: directory holding the `<prefix>*.pth` files.
        prefix: file-name prefix of the checkpoint, e.g. 'bbox_'.
        device: device the network is moved to (e.g. 'mps').

    Returns:
        (cnn, optimizer) with both state dicts restored.
    """
    net_state, optim_state = load_checkpoint(folder_path, prefix)

    cfg = init_unet_config()
    positional = [cfg[key] for key in ('net_depth', 'in_ch', 'st_ch', 'out_ch')]
    keyword = {key: cfg[key] for key in ('d_rate', 'max_ch', 'deep_supervision')}
    cnn = Unet(*positional, **keyword)
    cnn.load_state_dict(net_state)
    cnn.to(device)

    # lr is only a placeholder: load_state_dict restores the saved param groups.
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
    optimizer.load_state_dict(optim_state)
    return cnn, optimizer

In [None]:
# Side length (pixels) of every network input. The coordinate bookkeeping below
# assumes `augmenter` centre-pads/crops each image to this size — TODO confirm
# against augs.make_noaugs_augmenter.
_NET_INPUT_SIZE = 256


def _predict(cnn, img, device):
    """Run `cnn` on one 2-D image and return its first output as a squeezed numpy array.

    The image is passed as a contiguous float32 (1, 1, H, W) tensor, so its dtype does
    not depend on what `resize`/`augmenter` returned. Inference runs under
    `torch.no_grad()`, so no autograd graph is kept alive.
    """
    h, w = img.shape
    batch = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32).reshape(1, 1, h, w))
    with torch.no_grad():
        # `[0]` keeps the original indexing. With deep supervision this is presumably
        # the full-resolution head — TODO confirm.
        pred = cnn(batch.to(device))[0]
    return pred.cpu().numpy().squeeze()


def _to_original_frame(geom, pad1, zoom, crop_origin, pad2):
    """Map a shapely geometry from network-input coordinates back to the DICOM pixel grid.

    Forward chain (see `_segment_image`): original --pad1--> padded --zoom--> zoomed
    --crop at crop_origin--> crop --pad2--> network input. The inverse steps are
    applied here in reverse order. All pairs are (x, y).
    """
    geom = translate(geom, xoff=crop_origin[0] - pad2[0], yoff=crop_origin[1] - pad2[1])
    geom = scale(geom, xfact=1.0 / zoom[0], yfact=1.0 / zoom[1], origin=(0, 0))
    return translate(geom, xoff=-pad1[0], yoff=-pad1[1])


def _segment_image(pixels, augmenter, bbox_cnn, cont_cnn, kp_cnn, device):
    """Two-stage inference on one 2-D image.

    Stage 1 predicts a bounding-box mask on the padded full image. The box is used to
    zoom into and crop the region of interest. Stage 2 predicts the myocardium mask
    and the reference keypoint on that crop.

    Returns:
        (contour, keypoint) as shapely geometries in the pixel frame of `pixels`.
    Raises:
        ValueError: the bounding-box stage produced an empty box.
    """
    img1 = pixels.astype(np.float32)
    h1, w1 = img1.shape
    pad1 = ((_NET_INPUT_SIZE - w1) / 2.0, (_NET_INPUT_SIZE - h1) / 2.0)
    img2 = augmenter(images=[img1])[0]
    h2, w2 = img2.shape

    # Stage 1: bounding box of the region of interest, in img2 coordinates.
    bbox_mask = (_predict(bbox_cnn, img2, device) > 0.5).astype(np.uint8)
    xmin, xmax, ymin, ymax = data_loaders.utils.get_bbox_from_mask(bbox_mask, scale_f=2.5)
    # The box is requested with scale_f=2.5 (presumably an enlargement), so clamp it to the image.
    xmin, xmax = max(xmin, 0), min(xmax, w2)
    ymin, ymax = max(ymin, 0), min(ymax, h2)
    side = max(xmax - xmin, ymax - ymin)
    if side <= 0:
        raise ValueError(f'empty bounding box predicted: {(xmin, xmax, ymin, ymax)}')

    # Zoom so the longer box side spans the full image height, then crop the box.
    # `resize` needs an integer output shape. The per-axis factors it actually applies
    # (sx, sy) are reused for the crop and for the inverse mapping.
    sc_fact = h2 / side
    zoom_h, zoom_w = int(round(h2 * sc_fact)), int(round(w2 * sc_fact))
    sx, sy = zoom_w / w2, zoom_h / h2
    zoomed = resize(img2, (zoom_h, zoom_w), order=3, mode='constant', preserve_range=True)
    cx0, cy0 = int(xmin * sx), int(ymin * sy)
    img3 = zoomed[cy0:int(ymax * sy), cx0:int(xmax * sx)]
    h3, w3 = img3.shape
    pad2 = ((_NET_INPUT_SIZE - w3) / 2.0, (_NET_INPUT_SIZE - h3) / 2.0)
    img4 = augmenter(images=[img3])[0]

    # Stage 2a: myocardium segmentation. Keep the largest component if several are found.
    contour = data_loaders.utils.to_polygon((_predict(cont_cnn, img4, device) > 0.5).astype(np.uint8))
    if contour.geom_type == 'MultiPolygon':
        contour = max(contour.geoms, key=lambda c: c.area)
    # Stage 2b: reference keypoint.
    # NOTE(review): if to_keypoint signals "not found" with negative coordinates, that
    # sentinel is mapped through the inverse transform below like a real point — confirm.
    kp = data_loaders.utils.to_keypoint(_predict(kp_cnn, img4, device))

    def undo(geom):
        return _to_original_frame(geom, pad1, (sx, sy), (cx0, cy0), pad2)

    return undo(contour), undo(kp)


def _make_annotation(contour, kp, image_size, pixel_size):
    """Build the per-image annotation dict that is written to JSON.

    Always contains 'lv_myo' (myocardium polygon) and 'sax_ref' (keypoint).
    'lv_endo' (first interior ring) and 'lv_epi' (exterior ring) are added only when
    the contour actually has those rings.
    """
    meta = {'subpixelResolution': 1.0, 'pixelSize': pixel_size, 'imageSize': image_size}
    anno = {'lv_myo':  {'cont': mapping(contour), **meta, 'contType': 'MYO'},
            'sax_ref': {'cont': mapping(kp),      **meta, 'contType': 'POINT'}}
    interiors = getattr(contour, 'interiors', ())
    if len(interiors) > 0:
        anno['lv_endo'] = {'cont': mapping(Polygon(interiors[0])), **meta, 'contType': 'FREE'}
    exterior = getattr(contour, 'exterior', None)
    if exterior is not None and not contour.is_empty:
        anno['lv_epi'] = {'cont': mapping(Polygon(exterior)), **meta, 'contType': 'FREE'}
    return anno


def _plot_overlay(pixels, contour=None, kp=None):
    """Return a new figure of `pixels` with the contour rings and keypoint overlaid.

    Geometries must be in the pixel frame of `pixels`. Geometries of unexpected type
    are skipped rather than raising. The caller owns (and must close) the figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    try:
        ax.axis('off')
        h, w = pixels.shape
        ax.imshow(pixels.astype(np.float32), cmap='gray', extent=(0, w, h, 0))
        if getattr(contour, 'geom_type', None) == 'Polygon' and not contour.is_empty:
            ax.plot(*contour.exterior.xy)
            if len(contour.interiors) > 0:
                ax.plot(*contour.interiors[0].xy)
        if getattr(kp, 'geom_type', None) == 'Point' and not kp.is_empty and kp.x > -1 and kp.y > -1:
            ax.plot(kp.x, kp.y, 'rx')
    except Exception:
        plt.close(fig)
        raise
    return fig


def annotate_parametric_mapping_images(path_to_images, path_to_storage,
                                       augmenter,
                                       bbox_cnn, cont_cnn, kp_cnn,
                                       device):
    """Annotate every DICOM found (recursively) under `path_to_images`.

    For each `*.dcm` file this writes:
      * `<path_to_storage>/Annos/<StudyInstanceUID>/<SOPInstanceUID>.json` — the
        annotation (myocardium polygon, endo/epi rings when present, keypoint);
      * `<path_to_storage>/<SOPInstanceUID>.png` — an overlay preview, also shown inline.
    A failing file is reported (path, UIDs if readable, traceback). It is shown
    without saving when its pixels can be read, then skipped; processing continues.

    Args:
        path_to_images: directory searched recursively for `*.dcm` files.
        path_to_storage: output directory; created if missing.
        augmenter: callable `augmenter(images=[img]) -> [img]`. Assumed to centre-pad/crop
            to `_NET_INPUT_SIZE` — TODO confirm.
        bbox_cnn, cont_cnn, kp_cnn: bounding-box, myocardium and keypoint networks,
            expected to live on `device`.
        device: torch device (or device name) the inputs are sent to.

    Returns:
        None. A '.' is printed every ten files as progress.
    """
    annos_root = os.path.join(path_to_storage, 'Annos')
    pathlib.Path(annos_root).mkdir(parents=True, exist_ok=True)

    for p_i, p in enumerate(pathlib.Path(path_to_images).glob('**/*.dcm')):
        if p_i % 10 == 1:
            print('.', end='')
        # Reset per file, so the failure handler never sees values from the previous
        # file and can tell whether this file was read at all.
        dcm = contour = kp = None
        try:
            dcm = pydicom.dcmread(p)
            suid, sop = dcm.StudyInstanceUID, dcm.SOPInstanceUID
            pixels = dcm.pixel_array
            contour, kp = _segment_image(pixels, augmenter, bbox_cnn, cont_cnn, kp_cnn, device)

            ph, pw = dcm.PixelSpacing
            anno = _make_annotation(contour, kp, pixels.shape, (float(ph), float(pw)))
            if 'lv_endo' not in anno:
                print(f'\nNo endocardial ring found for {sop}')
            study_dir = os.path.join(annos_root, suid)
            pathlib.Path(study_dir).mkdir(parents=True, exist_ok=True)
            with open(os.path.join(study_dir, sop + '.json'), 'w') as f:
                json.dump(anno, f, indent=4)

            fig = _plot_overlay(pixels, contour, kp)
            try:
                fig.tight_layout()
                fig.savefig(os.path.join(path_to_storage, sop + '.png'), dpi=300)
                plt.show()
            finally:
                plt.close(fig)  # do not accumulate one open figure per file

        except Exception:
            # PatientName is deliberately not printed, to keep identifiers out of notebook output.
            print('\nFailed for: ', p)
            if dcm is not None:
                print(dcm.get('StudyInstanceUID'), dcm.get('SOPInstanceUID'))
            print(traceback.format_exc())
            if dcm is not None:
                # Best effort: show the image with whatever geometry was produced.
                try:
                    fig = _plot_overlay(dcm.pixel_array, contour, kp)
                    try:
                        plt.show()
                    finally:
                        plt.close(fig)
                except Exception as plot_err:
                    print('Could not plot failed case:', plot_err)
            print()

In [None]:
noaugs = augs.make_noaugs_augmenter()

# Checkpoint folders of the three networks.
bbox_storage_path, cont_storage_path, kp_storage_path = (
    os.path.join('/Users/thomas/Documents/GitHub/cnn4cmr/trained_models/T1', sub)
    for sub in ('Bbox', 'Cont', 'Keypoint'))

# Epoch-115 '*_trainedonall_' checkpoints, loaded on the MPS device and switched to
# inference mode. Only the network (index 0) is kept. Binding the optimizer to `_`
# at notebook top level would stop IPython from updating its `_` output history.
bbox_cnn, cont_cnn, kp_cnn = (
    reload_checkpoint(os.path.join(folder, 'epoch_115'), prefix, 'mps')[0].eval()
    for folder, prefix in ((bbox_storage_path, 'bbox_trainedonall_'),
                           (cont_storage_path, 'cont_trainedonall_'),
                           (kp_storage_path,   'kp_trainedonall_')))

In [None]:
# Earlier run on a single AI-comparison case, kept for reference and disabled by
# default. (As a bare string, this cell used to echo its whole text as output.)
# NOTE(review): these are 'SAX T2' series — check that the networks currently
# loaded match this sequence before enabling.
RUN_AI_COMPARISON_EXAMPLE = False
if RUN_AI_COMPARISON_EXAMPLE:
    # Alternative inputs:
    # inp_path = '/Users/thomas/Desktop/CMR/LazyLuna_Data/Data/AI_Comparison/Imgs/Diseased/DCM/ECSPRESS-094/SAX T2'
    # inp_path = '/Users/thomas/Desktop/CMR/LazyLuna_Data/Data/AI_Comparison/Imgs/Preserved/PSORCOR-051/SAX T2'
    # inp_path = '/Users/thomas/Desktop/CMR/LazyLuna_Data/Data/AI_Comparison/Imgs/Healthy/DZHK_TV_006_Helios/SAX T2'
    inp_path = '/Users/thomas/Desktop/CMR/LazyLuna_Data/Data/AI_Comparison/Imgs/Diseased/DCM/ECSPRESS145/SAX T2'
    out_path = '/Users/thomas/Desktop/outtest'
    annotate_parametric_mapping_images(path_to_images=inp_path,
                                       path_to_storage=out_path,
                                       augmenter=noaugs,
                                       bbox_cnn=bbox_cnn, cont_cnn=cont_cnn, kp_cnn=kp_cnn,
                                       device='mps')

In [None]:
inp_path = '/Users/thomas/Desktop/RH_Mapping_Test/t1_pre'
out_path = '/Users/thomas/Desktop/new_outtest_RH'
# The function annotates eagerly and returns None, so its result is not bound to a name.
annotate_parametric_mapping_images(path_to_images=inp_path,
                                   path_to_storage=out_path,
                                   augmenter=noaugs,
                                   bbox_cnn=bbox_cnn, cont_cnn=cont_cnn, kp_cnn=kp_cnn,
                                   device='mps')

In [None]:
inp_path = '/Users/thomas/Documents/GitHub/cnn4cmr/Datasets/T2_dataset/Imgs'
out_path = '/Users/thomas/Desktop/outtest'
# NOTE(review): this is a T2 dataset, but the networks bound to bbox_cnn/cont_cnn/kp_cnn
# may come from trained_models/T1 checkpoints — confirm this is intended.
# The function annotates eagerly and returns None, so its result is not bound to a name.
annotate_parametric_mapping_images(path_to_images=inp_path,
                                   path_to_storage=out_path,
                                   augmenter=noaugs,
                                   bbox_cnn=bbox_cnn, cont_cnn=cont_cnn, kp_cnn=kp_cnn,
                                   device='mps')

In [None]:
noaugs = augs.make_noaugs_augmenter()

# Checkpoint folders of the three networks.
bbox_storage_path, cont_storage_path, kp_storage_path = (
    os.path.join('/Users/thomas/Documents/GitHub/cnn4cmr/trained_models/T1', sub)
    for sub in ('Bbox', 'Cont', 'Keypoint'))

# bbox: epoch 50 ('bbox_'); contour and keypoint: epoch 100 ('cont_', 'kp_').
# NOTE: this rebinds bbox_cnn / cont_cnn / kp_cnn, so every cell executed after this
# one uses these networks. Only the network (index 0) is kept, because binding `_`
# at notebook top level would stop IPython from updating its `_` output history.
bbox_cnn, cont_cnn, kp_cnn = (
    reload_checkpoint(os.path.join(folder, epoch), prefix, 'mps')[0].eval()
    for folder, epoch, prefix in ((bbox_storage_path, 'epoch_50',  'bbox_'),
                                  (cont_storage_path, 'epoch_100', 'cont_'),
                                  (kp_storage_path,   'epoch_100', 'kp_')))

In [None]:
inp_path = '/Users/thomas/Desktop/ReliabilityPaper/Data/Evaluation_Data/T1/Imgs'
out_path = '/Users/thomas/Desktop/ReliabilityPaper/Data/Evaluation_Data/T1/AI'
# The function annotates eagerly and returns None, so its result is not bound to a name.
annotate_parametric_mapping_images(path_to_images=inp_path,
                                   path_to_storage=out_path,
                                   augmenter=noaugs,
                                   bbox_cnn=bbox_cnn, cont_cnn=cont_cnn, kp_cnn=kp_cnn,
                                   device='mps')