In [None]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from matplotlib.patches import Rectangle
from skimage.morphology import binary_dilation, square

# Default colormap for every imshow call in this notebook. Set through rcParams
# rather than plt.set_cmap(): set_cmap() calls gci() -> gcf(), which opens a stray
# empty figure in this setup cell.
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.figsize'] = (12, 10)

import subtle.utils.io as suio
import subtle.subtle_preprocess as supre

def show_img(img, title='', axis=False, vmin=None, vmax=None, colorbar=False):
    """Display a 2-D image on the current pyplot axes.

    Parameters
    ----------
    img : array-like
        Image passed to ``plt.imshow``.
    title : str
        Axes title, drawn with font size 15.
    axis : bool
        Whether to draw the axis lines and ticks.
    vmin, vmax : float or None
        Intensity limits for the colormap; ``None`` lets matplotlib autoscale.
        An explicit ``0`` is honoured (a truthiness check would silently drop it).
    colorbar : bool
        Whether to add a colorbar for the image.
    """
    plt.axis('on' if axis else 'off')
    # imshow treats vmin/vmax=None as "autoscale", so the values can be forwarded
    # directly; only None (not 0) means "not given".
    plt.imshow(img, vmin=vmin, vmax=vmax)
    plt.title(title, fontsize=15)
    if colorbar:
        plt.colorbar()

def show_gad_comparison(img_pre, img_low, img_post, vmin=None, vmax=None,
                        titles=('Pre contrast', '10% dosage', 'Full dosage')):
    """Show three images side by side in a 1x3 figure and display it.

    Parameters
    ----------
    img_pre, img_low, img_post : array-like
        Images drawn left to right via ``show_img``.
    vmin, vmax : float or None
        Shared intensity limits forwarded to ``show_img`` for all three panels.
    titles : sequence of 3 str
        Panel titles, left to right. The defaults match the original hard-coded labels.

    Raises
    ------
    ValueError
        If ``titles`` does not contain exactly three entries.
    """
    images = (img_pre, img_low, img_post)
    if len(titles) != len(images):
        raise ValueError('titles must have exactly {} entries'.format(len(images)))

    fig = plt.figure(figsize=(15, 10))
    for position, (img, title) in enumerate(zip(images, titles), start=1):
        # add_subplot makes the new axes current, which show_img draws into.
        fig.add_subplot(1, 3, position)
        show_img(img, title=title, vmin=vmin, vmax=vmax)

    # Must run after the axes exist; on an empty figure tight_layout does nothing.
    fig.tight_layout()
    plt.show()

In [None]:
from preprocess import _mask_npy
from deepbrain import Extractor as BrainExtractor

# Load a preprocessed volume from a hard-coded absolute path (adjust for your machine).
img_npy = np.load('/home/srivathsa/projects/studies/gad/gen_siemens/preprocess/test_571.npy')

# Linear min-max rescale of the whole volume to [0, 1] before brain extraction.
# NOTE(review): presumably deepbrain's Extractor expects inputs in this range — confirm.
img_scale = np.interp(img_npy, (img_npy.min(), img_npy.max()), (0, 1))
ext = BrainExtractor()
segment_probs = ext.run(img_scale)

# Binarize the extractor output at 0.5.
# NOTE(review): assumes `run` returns per-voxel probabilities — TODO confirm.
prob = segment_probs > 0.5

# Index 88 along the first axis; the choice of slice is not explained here — TODO document.
show_img(prob[88])

In [None]:
# FIXME: `mask` is not assigned in any cell visible in this notebook, so this raises
# NameError on a fresh kernel. It presumably came from a deleted cell (perhaps one calling
# the imported `_mask_npy`) and was meant to be compared with `prob[88]` — TODO confirm.
show_img(mask[88])

In [None]:
# Slice shown by the enhancement-mask cells that follow.
slice_idx = 69

# Hard-coded absolute path to a single preprocessed case.
tiantan_npy_path = '/home/srivathsa/projects/studies/gad/tiantan/preprocess/data_256/Brain3H-600437593.npy'
# The saved array holds exactly two arrays along its first axis; unpack them.
data, data_mask = np.load(tiantan_npy_path)

In [None]:
# Difference images for the selected slice: channel 0 is subtracted from channels 1 and 2.
# NOTE(review): presumably 0 = pre-contrast, 1 = low dose, 2 = full dose (matching the
# titles used in show_gad_comparison) — TODO confirm.
zero_low = data_mask[slice_idx, 1] - data_mask[slice_idx, 0]
zero_full = data_mask[slice_idx, 2] - data_mask[slice_idx, 0]

ENH_THRESHOLD = 0.5  # difference level counted as "enhancing"

# Pixels that reach the threshold in the full-dose difference but not in the low-dose one.
# Computed directly: the previous "subtract the two masks, min-max rescale, keep > 0.9"
# trick selected (almost) the whole image whenever no pixel met this condition, because
# rescaling a {-1, 0} or all-zero difference maps the 0s to 1.
enh_region = (zero_full >= ENH_THRESHOLD) & ~(zero_low >= ENH_THRESHOLD)
# Grow the region by one pixel. The footprint is passed positionally because its keyword
# is `selem` in older scikit-image and `footprint` in newer releases.
enh_region = binary_dilation(enh_region, square(3))

# Weight map: pixels in the region are scaled by the full-dose maximum, all others by 1.
mask_diff = np.where(enh_region, zero_full.max(), 1.0).astype(np.float32)

# Apply the weights and drop negative differences. (The old upper bound of
# zero_full.max() was a no-op, except when every value was negative, where it collapsed
# the whole image to that negative value.)
zero_full = np.clip(zero_full * mask_diff, 0, None)

show_img(zero_full)

In [None]:
# Build a small batch of slices around `slice_idx` for supre.enh_mask_smooth.
BATCH_HALF_WIDTH = 4  # batch covers slice indices [slice_idx - 4, slice_idx + 4)
CONTEXT_RADIUS = 3    # each input stacks slices idx-3 .. idx+3 (7 slices)

batch_start = slice_idx - BATCH_HALF_WIDTH
batch_end = slice_idx + BATCH_HALF_WIDTH

# numpy slicing would silently wrap a negative start or truncate past the end, giving
# stacks of the wrong length near the volume edges; fail loudly instead.
if batch_start - CONTEXT_RADIUS < 0 or batch_end + CONTEXT_RADIUS > data_mask.shape[0]:
    raise ValueError(
        'slice_idx={} is too close to the volume edge: needs slices {}..{}, '
        'volume has {}'.format(slice_idx, batch_start - CONTEXT_RADIUS,
                               batch_end + CONTEXT_RADIUS - 1, data_mask.shape[0])
    )

batch_indices = range(batch_start, batch_end)
# Inputs: channels 0-1 of the neighbouring slices, one stack per batch index.
X_batch = np.array([data_mask[idx - CONTEXT_RADIUS:idx + CONTEXT_RADIUS + 1, :2]
                    for idx in batch_indices])
# Targets: channel 2 of each centre slice (with a kept singleton axis from `None`),
# plus another singleton axis inserted right after the batch axis.
Y_batch = np.array([data_mask[idx, None, 2] for idx in batch_indices])[:, None, ...]

print(X_batch.min(), X_batch.max())
# NOTE(review): the meaning of the positional `3` and of `p` is defined by
# supre.enh_mask_smooth — TODO confirm. The result is indexed at 0 on axes 1 and 2.
enh_mask = supre.enh_mask_smooth(X_batch, Y_batch, 3, p=1)[:, 0, 0, ...]
print(enh_mask.min(), enh_mask.max())

In [None]:
# First sample of the p=1 enhancement mask, with a colorbar to read its value range.
show_img(img=enh_mask[0], colorbar=True)

In [None]:
# Same smoothing call as the p=1 cell, with p=2.5; the trailing axes are kept implicitly.
enh_mask_25 = supre.enh_mask_smooth(X_batch, Y_batch, 3, p=2.5)[:, 0, 0]
show_img(img=enh_mask_25[0])

In [None]:
import nibabel as nib

# HCP subject mgh_1001, T2 scan (hard-coded absolute path).
# Alternative input used previously (MPRAGE, same subject):
#   /raid/srivathsa/aae/hcp_data/HCP/mgh_1001/MPRAGE_GradWarped_and_Defaced/2013-01-01_11_25_56.0/S227198/HCP_mgh_1001_MR_MPRAGE_GradWarped_and_Defaced_Br_20140919084711597_S227198_I444246.nii
hcp_t2_path = '/raid/srivathsa/aae/hcp_data/HCP_T2/mgh_1001/T2_GradWarped_and_Defaced/2013-01-01_11_25_56.0/S227199/HCP_mgh_1001_MR_T2_GradWarped_and_Defaced_Br_20140919151202379_S227199_I444362.nii'

# `get_data()` is deprecated since nibabel 3.0 and raises from 5.0 onward.
# `np.asanyarray(img.dataobj)` is its documented replacement and, unlike get_fdata(),
# keeps the stored dtype.
data = np.asanyarray(nib.load(hcp_t2_path).dataobj)
# Move the last axis to the front.
data = data.transpose(2, 0, 1)

