In [None]:
%reload_ext autoreload
%autoreload 2

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import os
from glob import glob
from time import time
from time import perf_counter

import numpy as np
import pydicom
import SimpleITK as sitk
from scipy.ndimage import affine_transform

import matplotlib.pyplot as plt
plt.set_cmap('gray')
plt.rcParams['figure.figsize'] = (10, 8)

from subtle.subtle_preprocess import (
    dcm_to_sitk, _get_crop_range, normalize_im, scale_im_enhao, apply_reg_transform, register_im
)
from subtle_utils.dcmutil.pydicom_utils import get_image_orientation
from voxelmorph.tf.networks import VxmAffine, AIRNet, Transform
from tensorflow.keras.models import Model
import tensorflow_addons as tfa

def get_dcm_from_dir(dpath_dcm):
    """Read the first DICOM file found anywhere below ``dpath_dcm``.

    Every regular file under ``dpath_dcm`` (recursive glob) is a candidate.
    The candidates are sorted by path and the lexicographically smallest one
    is read with ``pydicom.dcmread``.

    Raises
    ------
    IndexError
        If no regular file exists below ``dpath_dcm``.
    """
    candidates = glob('{}/**/*'.format(dpath_dcm), recursive=True)
    regular_files = sorted(filter(os.path.isfile, candidates))
    first_file = regular_files[0]
    return pydicom.dcmread(first_file)

def crop_or_pad(img, ref_img):
    """Zero-pad and/or crop ``img`` axis by axis so it takes the shape of ``ref_img``.

    For every axis:

    * ``img`` smaller than ``ref_img``: zero-pad. The padding is split in half,
      and for an odd difference the extra element goes at the end.
    * ``img`` larger by exactly 1: drop the last element.
    * ``img`` larger by more than 1: keep ``[start, size - trim_end)`` where
      ``start, trim_end = _get_crop_range(diff)``.
      NOTE(review): presumably ``start + trim_end == diff`` so the axis ends up
      at the reference size; ``_get_crop_range`` is not visible here, so confirm.

    Parameters
    ----------
    img : np.ndarray
        Array to resize.
    ref_img : array-like with a ``shape`` attribute
        Only ``ref_img.shape`` is used; its values are ignored.

    Returns
    -------
    np.ndarray
        ``img`` itself when the shapes already match, otherwise a new
        C-contiguous array.

    Raises
    ------
    ValueError
        If ``img`` and ``ref_img`` do not have the same number of dimensions.
    """
    ref_shape = tuple(ref_img.shape)
    if img.shape == ref_shape:
        return img
    if img.ndim != len(ref_shape):
        raise ValueError(
            'crop_or_pad: img has {} dims but ref_img has {}'.format(img.ndim, len(ref_shape))
        )

    pad_width = []
    crop_slices = []
    for size, ref_size in zip(img.shape, ref_shape):
        diff = size - ref_size
        if diff <= 0:
            # Pad (a no-op when diff == 0); an odd remainder goes to the end.
            missing = ref_size - size
            before = missing // 2
            pad_width.append((before, missing - before))
            crop_slices.append(slice(None))
        else:
            if diff == 1:
                start, stop = 0, size - 1
            else:
                start, trim_end = _get_crop_range(diff)
                stop = size - trim_end
            pad_width.append((0, 0))
            crop_slices.append(slice(start, stop))

    padded = np.pad(img, pad_width, mode='constant', constant_values=0)
    # An axis is either padded or cropped, never both, so cropping after padding
    # only ever trims the original extent of the oversized axes.
    return np.ascontiguousarray(padded[tuple(crop_slices)])

def intensity_scale(img):
    """Normalize ``img`` with ``normalize_im`` and then min-max rescale it to [0, 1].

    ``normalize_im`` is a project helper called with ``axis=(1, 2)``; its exact
    normalization is not visible here. The rescale linearly maps the global
    minimum of the normalized array to 0 and its global maximum to 1 via
    ``np.interp``.
    """
    normalized = normalize_im(img, axis=(1, 2))
    value_range = (normalized.min(), normalized.max())
    return np.interp(normalized, value_range, (0, 1))

def get_aff_model(model, fpath_ckp):
    """Build a Keras sub-model that outputs the affine-matrix layer of ``model``.

    The sub-model takes the inputs of the first two layers of ``model.model``.
    It returns the output of the layer named ``'<model.name>_aff_mtx'``. Its
    weights are then loaded from ``fpath_ckp`` by layer name.

    Parameters
    ----------
    model : object
        Wrapper exposing ``.name`` and a Keras ``.model`` (e.g. one of the
        voxelmorph network classes).
    fpath_ckp : str
        Path to the weights checkpoint.

    Returns
    -------
    tensorflow.keras.Model
        The affine sub-model with weights loaded.
    """
    net = model.model
    aff_output = net.get_layer('{}_aff_mtx'.format(model.name)).output
    first_layer, second_layer = net.layers[0], net.layers[1]
    aff_model = Model(inputs=[first_layer.input, second_layer.input], outputs=aff_output)
    aff_model.load_weights(fpath_ckp, by_name=True)
    return aff_model

In [None]:
# Alternative case (Tiantan NO30: MPRAGE fixed, T2 moving), kept for reference:
# dpath_fx = '/home/srivathsa/projects/studies/gad/tiantan/data/NO30/3DT1WMPRAGE_SAG_CS4_301'
# dpath_mv = '/home/srivathsa/projects/studies/gad/tiantan/data/NO30/T2W_NEW_TRA_Series0201'

# DICOM series directories of the fixed (reference) and moving volumes.
# NOTE(review): absolute user-specific paths; consider a configurable DATA_DIR.
dpath_fx = '/home/srivathsa/projects/studies/gad/stanford/data/Patient_0125/7_AX_BRAVO_PRE'
dpath_mv = '/home/srivathsa/projects/studies/gad/stanford/data/Patient_0125/10_AX_BRAVO_+C'

# Directory holding the checkpoint files named in mdl_specs.
dpath_ckp = '/home/srivathsa/projects/studies/gad/vmorph/best_ckps'
# Contrast labels of the two series. Only mv_con is read below: it selects the
# mdl_specs entry (network class, input shape, checkpoint) used for registration.
# NOTE(review): both series above are AX BRAVO pre/post-contrast acquisitions
# (BRAVO is typically T1-weighted), yet mv_con = 't2' selects the T2 AIRNet model.
# This looks left over from the commented-out Tiantan T2 case; confirm whether 't1' was intended.
fx_con = 't1'
mv_con = 't2'

# Per-contrast model settings:
#   'inshape'  - network input shape, unpacked below as (mdl_z, mdl_x, mdl_y)
#   'ckp_name' - checkpoint file name under dpath_ckp
#   'model'    - voxelmorph network class to instantiate
mdl_specs = {
    't1': {
        'inshape': (128, 256, 256),
        'ckp_name': '20220413_220425-brats_real_t1.h5',
        'model': VxmAffine
    },
    't2': {
        'inshape': (128, 240, 240),
        'ckp_name': '20211013_004451-brats_t2.h5',
        'model': AIRNet
    },
    'fl': {
        'inshape': (128, 240, 240),
        'ckp_name': '20211012_050145-brats_fl.h5',
        'model': AIRNet
    }
}

spec_obj = mdl_specs[mv_con]
mdl_z, mdl_x, mdl_y = spec_obj['inshape']

# Target voxel spacing of the model grid (SimpleITK spacing units, presumably mm).
mdl_spacing = (1, 1, 1)

# Load both series as SimpleITK images via the project helper dcm_to_sitk.
fx = dcm_to_sitk(dpath_fx)
mv = dcm_to_sitk(dpath_mv)

In [None]:
# Put the moving volume on the fixed volume's voxel grid (size, origin, spacing and
# direction taken from `fx`) with an identity transform. Voxels outside the moving
# volume get 0, and the output keeps the moving image's pixel type.
# NOTE(review): nearest-neighbour interpolation of an intensity image; linear is the
# more usual choice for registration inputs. Confirm this is intended.
mv_rs = sitk.Resample(mv, fx, sitk.Transform(), sitk.sitkNearestNeighbor, 0, mv.GetPixelID())

In [None]:
# Number of slices the fixed volume spans once its slice axis is resampled to the
# model spacing (the in-plane size is set to the model input size instead).
z_out = int(np.ceil(fx.GetSize()[-1] * (fx.GetSpacing()[-1] / mdl_spacing[-1])))

# SimpleITK sizes are ordered (x, y, z) while GetArrayFromImage returns (z, y, x)
# arrays. Arrays shaped (mdl_z, mdl_x, mdl_y) therefore need the sitk size
# (mdl_y, mdl_x, z_out). The two orders coincide only for square in-plane specs.
mdl_grid_size = (mdl_y, mdl_x, z_out)


def resample_to_model_grid(img):
    """Resample a SimpleITK image onto the model grid.

    The grid keeps ``img``'s origin and direction and uses ``mdl_spacing`` and
    ``mdl_grid_size``. Resampling uses nearest-neighbour interpolation, fills
    voxels outside ``img`` with 0, and keeps ``img``'s pixel type.
    """
    return sitk.Resample(img, mdl_grid_size, sitk.Transform(), sitk.sitkNearestNeighbor,
                         img.GetOrigin(), mdl_spacing, img.GetDirection(), 0, img.GetPixelID())


# resample fixed image and moving image (already resampled to fixed space) to model size
fx_rs_mdl = resample_to_model_grid(fx)
mv_rs_mdl = resample_to_model_grid(mv_rs)

# adjust the slice dimension according to model specs (only the shape of mdl_ref is used)
mdl_ref = np.zeros((mdl_z, mdl_x, mdl_y))

fx_np = crop_or_pad(sitk.GetArrayFromImage(fx_rs_mdl), mdl_ref)
mv_np = crop_or_pad(sitk.GetArrayFromImage(mv_rs_mdl), mdl_ref)

# scale intensities, then add batch and channel axes for the network
fx_np = intensity_scale(fx_np)[None, ..., None]
mv_np = intensity_scale(mv_np)[None, ..., None]

expected_shape = (1, mdl_z, mdl_x, mdl_y, 1)
assert fx_np.shape == expected_shape and mv_np.shape == expected_shape, (fx_np.shape, mv_np.shape)

In [None]:
# Constructor kwargs shared by every network class in mdl_specs.
model_config = {
    'inshape': (mdl_z, mdl_x, mdl_y),
    'verbose': 0
}

# VxmAffine takes extra constructor options. Compare the class by identity: substring-
# matching str(class) would also fire for any class whose qualified name merely
# contains 'VxmAffine'.
if spec_obj['model'] is VxmAffine:
    model_config = {
        **model_config,
        'batch_norm': False,
        'constraint_params': False,
        'enc_only': False
    }

In [None]:
# Instantiate the selected network and extract its affine sub-model. The timing covers
# sub-model construction plus checkpoint loading. perf_counter is a monotonic clock
# intended for measuring intervals, unlike wall-clock time().
img_model = spec_obj['model'](**model_config)
fpath_ckp = os.path.join(dpath_ckp, spec_obj['ckp_name'])
t_start = perf_counter()
aff_model = get_aff_model(img_model, fpath_ckp=fpath_ckp)
print('get_aff_model: {:.2f} s'.format(perf_counter() - t_start))

In [None]:
# Display the network class selected via mv_con (sanity check of the spec in use).
spec_obj['model']

In [None]:
# Predict the affine parameters from the [moving, fixed] model-grid volumes.
# [0] takes the first (only) batch element. Timed with the monotonic perf_counter.
t_start = perf_counter()
aff_params = aff_model.predict([mv_np, fx_np])[0]
print('affine prediction: {:.2f} s'.format(perf_counter() - t_start))

In [None]:
# Moving volume on the fixed image's native grid, as a numpy array in the (z, y, x)
# index order returned by GetArrayFromImage.
# NOTE(review): aff_params was predicted from volumes on the model grid (mdl_spacing,
# cropped/padded to the model input shape), but it is applied below on this native-
# resolution grid. Unless the affine convention is resolution-independent, it must be
# converted between the two voxel spaces first. Confirm.
mv_rs_np = sitk.GetArrayFromImage(mv_rs)

In [None]:
# Warp the native-grid moving volume with the predicted affine using a voxelmorph
# Transform network sized to that volume. The input is cast to float32 explicitly
# because the DICOM-derived array presumably has an integer pixel type. The trailing
# [0, ..., 0] strips the batch and channel axes. Timed with the monotonic perf_counter.
tfm = Transform(inshape=mv_rs_np.shape, affine=True)
t_start = perf_counter()
moved = tfm.predict([mv_rs_np[None, ..., None].astype(np.float32), aff_params[None]])[0, ..., 0]
print('affine warp: {:.2f} s'.format(perf_counter() - t_start))