In [None]:
%reload_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import numpy as np
import keras.backend as K
from keras.models import Model
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm_notebook as tqdm
import yaml
from scipy.ndimage import affine_transform
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import nibabel as nib
import pandas as pd
from skimage.transform import resize
plt.rcParams['figure.figsize'] = (10, 8)
plt.set_cmap('gray')

from voxelmorph.tf.networks import AIRNet, Transform, VxmAffine
from scripts.tf import eval as quant_eval
from voxelmorph.tf.utils import params_to_affine_matrix
import subtle.subtle_metrics as su_metrics

def process_brats_vol(img_vol, pad=False, is_seg=False):
    """Reorient, crop and (optionally) normalise and pad a BRATS volume.

    The last axis of ``img_vol`` is moved to the front and each slice is rotated
    by 270 degrees; then 13 leading and 14 trailing slices are dropped.

    Args:
        img_vol: 3-D array as loaded from the NIfTI file.
        pad: if True, zero-pad both in-plane axes by 8 voxels on each side.
        is_seg: if True, skip intensity normalisation (label maps keep their values).

    Returns:
        The processed volume; intensities are rescaled to [0, 1] unless ``is_seg``.
    """
    reoriented = np.rot90(np.transpose(img_vol, (2, 0, 1)), k=3, axes=(1, 2))
    cropped = reoriented[13:-14]

    if is_seg:
        result = cropped
    else:
        # Mean-normalise, then min-max rescale onto [0, 1].
        scaled = cropped / cropped.mean()
        lo, hi = scaled.min(), scaled.max()
        result = np.interp(scaled, (lo, hi), (0, 1))

    if not pad:
        return result
    return np.pad(result, pad_width=[(0, 0), (8, 8), (8, 8)], mode='constant', constant_values=0)

def _list_checkpoints(ckp_dir):
    """Return ``{epoch: path}`` for every ``<epoch>.h5`` file in ``ckp_dir``.

    Files whose stem is not a plain integer (e.g. ``best.h5``) are ignored.
    """
    ckps = {}
    for fpath in glob(os.path.join(ckp_dir, '*.h5')):
        stem = os.path.splitext(os.path.basename(fpath))[0]
        if stem.isdigit():
            ckps[int(stem)] = fpath
    return ckps


def get_best_ckp(exp_path, exp_id, ckp_interval=5):
    """Pick the checkpoint whose epoch has the lowest logged ``val_loss``.

    Checkpoints are read from ``<exp_path>/<exp_id>/ckps/<epoch>.h5``. Validation
    losses are read from the TensorBoard logs in ``<exp_path>/<exp_id>/logs/tb``.
    Only epochs that are positive multiples of ``ckp_interval`` *and* have a
    checkpoint file on disk are considered. The checkpoint for epoch ``e`` is
    matched to the ``val_loss`` event logged at step ``e - 1``.

    Args:
        exp_path: root directory containing experiment folders.
        exp_id: experiment folder name.
        ckp_interval: epoch spacing of the checkpoints to consider (default 5).

    Returns:
        ``(fpath_ckp, ep_idx)``: path of the selected checkpoint and its epoch.

    Raises:
        FileNotFoundError: no ``<epoch>.h5`` checkpoint exists for the experiment.
        ValueError: no ``val_loss`` event matches an available checkpoint.
    """
    print('Fetching best checkpoint file for {}...'.format(exp_id))
    ckp_dir = '{}/{}/ckps'.format(exp_path, exp_id)
    ckps = _list_checkpoints(ckp_dir)
    if not ckps:
        raise FileNotFoundError('No <epoch>.h5 checkpoints found in {}'.format(ckp_dir))

    # Only look at checkpoints that actually exist on disk, so the returned
    # path is always valid.
    candidate_steps = {epoch - 1 for epoch in ckps if epoch > 0 and epoch % ckp_interval == 0}

    log_path = '{}/{}/logs/tb'.format(exp_path, exp_id)
    num_files = len(glob('{}/*'.format(log_path)))

    tf_size_guidance = {
        'compressedHistograms': 0,
        'images': num_files,
        # 0 keeps every scalar event. A finite size lets EventAccumulator
        # reservoir-sample val_loss, which can silently drop checkpoint steps
        # when training logged more epochs than the last checkpoint.
        'scalars': 0,
        'histograms': 0,
    }

    event_acc = EventAccumulator(log_path, size_guidance=tf_size_guidance)
    event_acc.Reload()

    val_losses = [(v.step, v.value) for v in event_acc.Scalars('val_loss') if v.step in candidate_steps]
    if not val_losses:
        raise ValueError(
            'No val_loss events in {} match the available checkpoints in {}'.format(log_path, ckp_dir)
        )

    # min() keeps the first minimum on ties, like np.argmin.
    best_step, _ = min(val_losses, key=lambda step_value: step_value[1])
    ep_idx = best_step + 1
    return ckps[ep_idx], ep_idx
    
def _config_flag(value):
    """Interpret a config value as a boolean flag, accepting both YAML ``true`` and ``'true'``."""
    return str(value).strip().lower() == 'true'


def fetch_eval_models(exp_path, exp_id, tfm_models=True):
    """Rebuild an experiment's network from its ``config.yaml`` (weights are not loaded).

    The input shape is taken from the ``'data'`` array of the first file listed
    in ``config['img_list']``, dropping that array's leading axis.

    Args:
        exp_path: root directory containing experiment folders.
        exp_id: experiment folder name; ``<exp_path>/<exp_id>/config.yaml`` is read.
        tfm_models: if True, also build two affine ``Transform`` models.

    Returns:
        ``eval_model`` when ``tfm_models`` is False, otherwise
        ``(eval_model, tfm_model_gt, tfm_model_pred)``. ``tfm_model_gt`` uses
        ``shift_center=False``; ``tfm_model_pred`` uses the Transform default.

    Raises:
        KeyError: ``config['arch']`` is not ``'airnet'`` or ``'affine'``.
    """
    exp_dir = os.path.join(exp_path, exp_id)
    # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6.
    # FullLoader is what yaml.load defaulted to on 5.x; fall back to Loader on
    # older releases. The config is a local, trusted file.
    loader = getattr(yaml, 'FullLoader', yaml.Loader)
    with open('{}/config.yaml'.format(exp_dir), 'r') as f:
        config = yaml.load(f, Loader=loader)

    model_dict = {
        'airnet': AIRNet,
        'affine': VxmAffine
    }

    # Use the first non-empty path in the image list to infer the input shape.
    with open(config['img_list'], 'r') as f:
        img_list = [line.strip() for line in f if line.strip()]
    with np.load(img_list[0]) as npz:
        img_shape = npz['data'].shape[1:]

    arch = config['arch']
    model_params = {
        'inshape': img_shape,
        'verbose': 0
    }

    if arch == 'affine':
        model_params['batch_norm'] = _config_flag(config['bnorm'])
        model_params['constraint_params'] = _config_flag(config['constraint_params'])
        model_params['enc_only'] = (config['network_mode'] == 'enc')
    print('MODEL PARAMS', model_params)
    eval_model = model_dict[arch](**model_params)

    if tfm_models:
        tfm_model_gt = Transform(inshape=img_shape, affine=True, shift_center=False)
        tfm_model_pred = Transform(inshape=img_shape, affine=True)
        return eval_model, tfm_model_gt, tfm_model_pred
    return eval_model
        
    
def get_model_pred(model, aff_model, fixed, moving):
    """Run one registration and return the warped image and the predicted affine.

    Args:
        model: wrapper whose ``.model`` is the full registration network.
        aff_model: network that outputs the affine parameters.
        fixed: fixed (reference) volume without batch/channel axes.
        moving: moving volume with the same shape as ``fixed``.

    Returns:
        ``(pred_img, pred_aff)``: the registered volume with the batch and channel
        axes stripped, and the first batch element of ``aff_model``'s output.
    """
    # Both networks take [moving, fixed], each with a leading batch axis and a
    # trailing channel axis.
    batch = [vol[None, ..., None] for vol in (moving, fixed)]

    registered = model.model.predict(batch)
    affine = aff_model.predict(batch)
    return registered[0, ..., 0], affine[0]

def get_aff_model(model, fpath_ckp):
    """Build a Keras model that maps the two network inputs to the affine-matrix layer.

    The output layer is looked up as ``'<model.name>_aff_mtx'``. The inputs are
    taken from the first two layers of ``model.model``. Weights are then
    (re)loaded from ``fpath_ckp`` by layer name.
    """
    net = model.model
    aff_output = net.get_layer('{}_aff_mtx'.format(model.name)).output
    # NOTE(review): assumes layers[0] and layers[1] are the network's input
    # layers, in the same order get_model_pred feeds them.
    first_input = net.layers[0].input
    second_input = net.layers[1].input

    aff_model = Model(inputs=[first_input, second_input], outputs=aff_output)
    aff_model.load_weights(fpath_ckp, by_name=True)
    return aff_model

def get_seg_mask(fpath_ref, cnum, pad_ref):
    """Load ``<fpath_ref>/<cnum>/<cnum>_seg.nii.gz`` as a single binary tumour mask.

    The label volume goes through ``process_brats_vol`` (no intensity
    normalisation; optional padding via ``pad_ref``). Every non-zero label is
    set to 1.
    """
    seg_path = '{}/{}/{}_seg.nii.gz'.format(fpath_ref, cnum, cnum)
    labels = process_brats_vol(nib.load(seg_path).get_fdata(), is_seg=True, pad=pad_ref)
    # Merge all tumour classes into one foreground class.
    return (labels > 0).astype(np.uint8)

def compute_dice(seg_gt, aff_sim, aff_pred, pad_ref, tfm_models):
    """Dice between the original mask and the mask after simulated misalignment and correction.

    ``seg_gt`` is warped by ``aff_sim`` with ``tfm_models[0]``, binarised, then
    warped by ``aff_pred`` with ``tfm_models[1]`` and binarised again. Both
    binarisations use a 0.9 threshold. The result is compared with ``seg_gt``.
    ``pad_ref`` is accepted for interface compatibility and is not used.
    """
    simulate_tfm = tfm_models[0]
    correct_tfm = tfm_models[1]

    def _warp_binary(tfm, mask, aff):
        # Add batch/channel axes for the transform model, then strip them again.
        warped = tfm.predict([mask[None, ..., None], aff[None]])[0, ..., 0]
        return (warped > 0.9).astype(np.uint8)

    seg_sim = _warp_binary(simulate_tfm, seg_gt, aff_sim)
    seg_pred = _warp_binary(correct_tfm, seg_sim, aff_pred)
    return su_metrics.dice(seg_gt, seg_pred)

def compute_incremental_variation(data_path, exp_path, exp_id, fpath_ref, case_num, 
                                  aff_type='translation', aff_axis='x', aff_array=None):
    """Sweep one affine parameter for a single case and record the resulting Dice.

    For each value in ``aff_array``, the post-contrast volume is misaligned by
    setting one affine parameter to that value. The model then registers it back
    to the pre-contrast volume, and the tumour-mask Dice is computed. A
    difference image is saved per step, and all scores are written to
    ``scores.csv`` under
    ``<exp_path>/<exp_id>/eval/plots/<aff_type>_<aff_axis>/<case_num>``.

    Args:
        data_path: directory holding ``<case_num>.npz`` with a 2-element ``'data'`` array (pre, post).
        exp_path, exp_id: experiment location used to rebuild the model and pick its checkpoint.
        fpath_ref: BRATS reference directory containing the segmentation masks.
        case_num: case identifier.
        aff_type: ``'translation'`` (axes ``x``/``y``/``z``) or ``'rotation'`` (axes ``xy``/``xz``/``yz``).
        aff_axis: which parameter of ``aff_type`` to sweep.
        aff_array: values to sweep; defaults to ``np.arange(0, 20, 0.5)`` as float32.

    Raises:
        ValueError: ``aff_type`` is not translation or rotation.
        KeyError: ``aff_axis`` is not valid for ``aff_type``.
    """
    # Validate the sweep definition before any expensive model loading or
    # directory creation.
    if aff_type == 'translation':
        idx_map = {'x': 2, 'y': 1, 'z': 0}
        n_params = 3
    elif aff_type == 'rotation':
        idx_map = {'xy': 5, 'xz': 4, 'yz': 3}
        n_params = 6
    else:
        raise ValueError('Affine type must be translation or rotation')
    param_idx = idx_map[aff_axis]

    t1pre, t1post = np.load('{}/{}.npz'.format(data_path, case_num))['data']

    fpath_ckp, _ = get_best_ckp(exp_path, exp_id)
    # The default gt/pred Transform models are not needed here; nearest-neighbour
    # transforms for the masks are built below instead.
    model = fetch_eval_models(exp_path, exp_id, tfm_models=False)
    model.model.load_weights(fpath_ckp)
    aff_model = get_aff_model(model, fpath_ckp)

    seg_gt = get_seg_mask(fpath_ref, case_num, pad_ref=True)

    vol_shape = tuple(t1pre.shape)
    tfm_models = [
        Transform(inshape=vol_shape, affine=True, interp_method='nearest', shift_center=False),
        Transform(inshape=vol_shape, affine=True, interp_method='nearest', shift_center=True),
    ]

    if aff_array is None:
        aff_array = np.arange(0, 20, 0.5).astype(np.float32)

    dir_plot = os.path.join(exp_path, exp_id, 'eval', 'plots', '{}_{}'.format(aff_type, aff_axis), case_num)
    os.makedirs(dir_plot, exist_ok=True)

    param_label = '{} {}'.format(aff_type.capitalize(), aff_axis.upper())
    aff_shift = np.zeros(n_params)
    metric_list = []
    for i, aval in enumerate(tqdm(aff_array, total=len(aff_array))):
        # Only the swept parameter is non-zero.
        aff_shift[param_idx] = aval
        aff = K.eval(params_to_affine_matrix(aff_shift))

        moving = affine_transform(t1post, aff)
        moving = np.clip(moving, 0, moving.max())

        pred, aff_pred = get_model_pred(model, aff_model, t1pre, moving)
        dice = compute_dice(seg_gt, aff, aff_pred, pad_ref=True, tfm_models=tfm_models)

        # Left: misaligned minus fixed; right: registered minus fixed (slice 64).
        plt.imshow(np.hstack([moving[64] - t1pre[64], pred[64] - t1pre[64]]))
        plt.title('{} = {:.3f}, Dice = {:.3f}'.format(param_label, aval, dice))
        plt.savefig('{}/Img_{}.png'.format(dir_plot, i))
        plt.clf()

        metric_list.append({'aff_val': aval, 'dice': dice})

    pd.DataFrame(metric_list).to_csv('{}/scores.csv'.format(dir_plot))

# (aff_type, axes to sweep, parameter values) for each sweep run by eval_model_variation.
VARIATION_SWEEPS = (
    ('translation', ('x', 'y', 'z'), np.arange(0, 20, 1)),
    ('rotation', ('xy', 'xz', 'yz'), np.arange(0, 15, 0.75)),
)


def eval_model_variation(data_path, exp_path, exp_id, fpath_ref, num_cases=25, seed=None):
    """Run ``compute_incremental_variation`` for a random subset of cases over every sweep.

    Args:
        data_path: directory containing the ``<case>.npz`` files.
        exp_path, exp_id, fpath_ref: forwarded to ``compute_incremental_variation``.
        num_cases: number of cases to sample without replacement. Capped at the
            number of available cases.
        seed: optional seed for reproducible case selection. ``None`` uses
            NumPy's global random state, as before.
    """
    # Sort so that a given seed picks the same cases regardless of the
    # filesystem's directory-listing order.
    all_cases = sorted(f.split('/')[-1].replace('.npz', '') for f in glob('{}/*.npz'.format(data_path)))
    n_pick = min(num_cases, len(all_cases))
    sampler = np.random if seed is None else np.random.RandomState(seed)
    case_nums = sampler.choice(all_cases, size=n_pick, replace=False) if n_pick > 0 else []

    for aff_type, aff_axes, aff_array in VARIATION_SWEEPS:
        for ax in aff_axes:
            print('Computing variation in {} axis for {}...'.format(ax, aff_type))

            for case_num in case_nums:
                compute_incremental_variation(data_path, exp_path, exp_id, fpath_ref,
                                              case_num, aff_type, aff_axis=ax, aff_array=aff_array)
def sag2ax(img_vol):
    """Reorient a volume, crop/pad it by 64 voxels, and clip negative intensities.

    Axes 0 and 2 are swapped and each slice is rotated by 90 degrees. Then 64
    slices are dropped from each end of the new first axis, and the second axis
    is zero-padded by 64 voxels on each side. Values below 0 are clipped to 0.
    """
    swapped = np.transpose(img_vol, (2, 1, 0))
    rotated = np.rot90(swapped, k=1, axes=(1, 2))
    cropped = rotated[64:-64]
    padded = np.pad(cropped, [(0, 0), (64, 64), (0, 0)])
    upper = padded.max()
    return np.clip(padded, 0, upper)

def plot_pred(fixed, moving, reg, fpath=None, ref_prereg=None, sl=None):
    """Plot one slice of a registration result as a 2x3 panel.

    Top row: fixed, moving, registered. Bottom row: moving - fixed,
    registered - fixed, and ref_prereg - fixed (blank if ``ref_prereg`` is None).

    Args:
        fixed, moving, reg: volumes of the same shape, indexed slice-first.
        fpath: if given, the figure is saved there and then closed. Otherwise it
            is left open (e.g. for inline notebook display).
        ref_prereg: optional reference registration to compare against.
        sl: slice index along axis 0; defaults to the middle slice.
    """
    sl_idx = sl if sl is not None else fixed.shape[0] // 2
    fixed_sl = fixed[sl_idx]
    moving_sl = moving[sl_idx]
    reg_sl = reg[sl_idx]
    # Only the displayed slice is needed, so take differences per slice rather
    # than over whole volumes.
    if ref_prereg is not None:
        prereg_diff = ref_prereg[sl_idx] - fixed_sl
    else:
        prereg_diff = np.zeros_like(fixed_sl)

    panels = [
        (fixed_sl, 'Fixed'),
        (moving_sl, 'Moving'),
        (reg_sl, 'Registered'),
        (moving_sl - fixed_sl, 'Diff(fixed, moving)'),
        (reg_sl - fixed_sl, 'Diff(fixed, registered)'),
        (prereg_diff, 'Diff - BRATS prereg'),
    ]

    fig, axs = plt.subplots(2, 3)
    for ax, (img, title) in zip(axs.flat, panels):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)
    fig.tight_layout()

    if fpath:
        fig.savefig(fpath)
        # Close (not just clear) the figure; otherwise every call in a loop
        # leaves another open figure behind.
        plt.close(fig)

In [None]:
# Evaluate the best checkpoint of one experiment on every .npz case in data_path
# and save a fixed / moving / registered comparison figure per case.
exp_path = '/home/srivathsa/projects/studies/gad/vmorph/runs'
exp_id = '20211018_212059-brats'
data_path = '/home/srivathsa/projects/studies/gad/vmorph/data/sford_post_nobet/val'

# Build the network without the extra Transform models, then load the
# checkpoint selected by get_best_ckp (lowest logged val_loss).
model = fetch_eval_models(exp_path, exp_id, tfm_models=False)
fpath_ckp, _ = get_best_ckp(exp_path, exp_id)
model.model.load_weights(fpath_ckp)
aff_model = get_aff_model(model, fpath_ckp)

cases = sorted([f.split('/')[-1].replace('.npz', '') for f in glob('{}/*.npz'.format(data_path))])
# data_path.split('/')[-2] is the dataset directory name ('sford_post_nobet' for the path above).
dirpath_plot = '/home/srivathsa/projects/studies/gad/vmorph/sford_eval/{}_{}'.format(exp_id, data_path.split('/')[-2])
if not os.path.exists(dirpath_plot):
    os.makedirs(dirpath_plot)

for case_num in tqdm(cases, total=len(cases)):
    # The 'data' array unpacks along its first axis into exactly two volumes: (fixed, moving).
    fixed, moving = np.load('{}/{}.npz'.format(data_path, case_num))['data']
    
    # Optional (disabled): clip negative intensities before prediction.
#     fixed = np.clip(fixed, 0, fixed.max())
#     moving = np.clip(moving, 0, moving.max())
    
    pred, aff = get_model_pred(model, aff_model, fixed, moving)
    fpath_plot = '{}/{}.png'.format(dirpath_plot, case_num)
    plot_pred(fixed, moving, pred, fpath_plot)

In [None]:
# Quick visual check of one case: slice 64 of data[0] and data[1], side by side.
# NOTE(review): this reads from sford_post, not the sford_post_nobet set evaluated above — confirm intended.
data = np.load('/home/srivathsa/projects/studies/gad/vmorph/data/sford_post/val/Patient_0149.npz')['data']

plt.imshow(np.hstack([data[0, 64], data[1, 64]]))