In [10]:
import os
import imageio
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import nibabel as nib
from tqdm import tqdm_notebook as tqdm
from skimage import color

# Default colormap for single-channel images. Setting the rcParam directly avoids
# plt.set_cmap(), which calls gcf() and thereby creates an empty figure that the
# notebook then renders as "<Figure ... with 0 Axes>".
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.figsize'] = (10, 8)

# Dataset root; set the SFORD_METS_ROOT environment variable to run this notebook
# against a copy of the data stored somewhere other than the original mount point.
data_root = os.environ.get('SFORD_METS_ROOT', '/mnt/datasets/srivathsa/sford_mets')
base_path = os.path.join(data_root, 'png')    # input: per-case PNG slice folders
dest_path = os.path.join(data_root, 'nifti')  # output: per-case NIfTI volumes

def png_to_npy(dirpath, mean_norm=True):
    """Stack every ``*.png`` in ``dirpath`` (sorted by filename) into one array.

    Parameters
    ----------
    dirpath : str
        Directory holding one PNG per slice. A directory whose own name is
        ``seg`` is treated as a segmentation mask.
    mean_norm : bool
        Divide the volume by its mean intensity. Ignored for ``seg`` folders.

    Returns
    -------
    numpy.ndarray
        Slices stacked along axis 0. Mask volumes are linearly rescaled from
        [0, 255] to [0, 1]. Image volumes are mean-normalised when requested and
        the mean is positive; an all-zero or empty volume is returned unscaled.
    """
    fpaths = sorted(glob(os.path.join(dirpath, '*.png')))
    img_vol = np.array([imageio.imread(f) for f in fpaths])

    # Decide from the leaf folder only: a substring test on the whole path would
    # misclassify every modality if any parent directory happened to contain "seg".
    is_seg = os.path.basename(os.path.normpath(dirpath)) == 'seg'
    if is_seg:
        mean_norm = False
        img_vol = np.interp(img_vol, (0, 255), (0, 1))

    if mean_norm and img_vol.size:
        vol_mean = img_vol.mean()
        # A blank volume would otherwise become 0/0 = NaN in every voxel.
        if vol_mean > 0:
            img_vol = img_vol / vol_mean
    return img_vol

def get_rgb(img):
    """Min-max scale a single-channel image to [0, 1] and replicate it 3 times.

    Parameters
    ----------
    img : array_like
        Single-channel image.

    Returns
    -------
    numpy.ndarray
        Float array with the scaled image stacked three times along the third
        axis (``np.dstack``). A constant image (zero intensity range) maps to
        all zeros instead of the NaNs a division by zero would produce.
    """
    img = np.asarray(img, dtype=float)
    intensity_range = np.ptp(img)
    if intensity_range == 0:
        img = np.zeros_like(img)
    else:
        img = (img - img.min()) / intensity_range
    return np.dstack((img, img, img))

def overlay_mask(data, label, r=0.9, g=0.1, b=0.1):
    """Tint a grayscale image with the colours of a label map.

    Label value 1 contributes the red channel (``r``), 2 the green channel
    (``g``) and 4 the blue channel (``b``). In HSV space the label colour's hue
    and saturation replace the image's, while the image's own value channel is
    kept, so pixels with no label (black label colour, saturation 0) stay gray.

    Returns
    -------
    numpy.ndarray
        RGB image produced by ``color.hsv2rgb``.
    """
    base_rgb = get_rgb(data)
    label_rgb = np.dstack([(label == code) * weight
                           for code, weight in ((1, r), (2, g), (4, b))])

    tinted_hsv = color.rgb2hsv(base_rgb)
    mask_hsv = color.rgb2hsv(label_rgb)

    # Channels 0 and 1 are hue and saturation; channel 2 (value) stays from data.
    tinted_hsv[..., :2] = mask_hsv[..., :2]
    return color.hsv2rgb(tinted_hsv)

def process_mets_case(case_num, train=True):
    """Convert one case's PNG slice folders into NIfTI volumes.

    Reads ``{base_path}/{train|test}/{case_num}/{folder}`` for every entry in
    the folder map and writes ``{dest_path}/{train|test}/{case_num}/{name}.nii.gz``
    with an identity affine. For training cases the ``seg`` folder is converted
    too, and the indices of slices containing label 1 are saved to
    ``mets_slice_idx.npy`` in the same output directory.

    Parameters
    ----------
    case_num : str
        Case folder name, e.g. ``'Mets_001'``.
    train : bool
        Selects the ``train`` or ``test`` split; only training cases get a mask.
    """
    split = 'train' if train else 'test'
    src_dir = '{}/{}/{}'.format(base_path, split, case_num)
    dest = '{}/{}/{}'.format(dest_path, split, case_num)
    os.makedirs(dest, exist_ok=True)

    # PNG sub-folder name -> output NIfTI basename.
    dir_map = {
        '0': '3_AX_T1_GE_post',
        '1': '1_AX_T1_SE_pre',
        '2': '2_AX_T1_SE_post',
        '3': '4_AX_T2_FLAIR_post',
    }
    if train:
        dir_map['seg'] = 'mets_seg'

    for folder, out_name in dir_map.items():
        img_vol = png_to_npy('{}/{}'.format(src_dir, folder))
        fpath_save = '{}/{}.nii.gz'.format(dest, out_name)
        nib.save(nib.Nifti1Image(img_vol, affine=np.eye(4)), fpath_save)

        # Key off the source folder rather than a substring of the output path,
        # so a "seg" anywhere in dest_path cannot make image volumes land here.
        if folder == 'seg':
            np.save('{}/mets_slice_idx.npy'.format(dest), get_slice_idx(img_vol))

def get_slice_idx(mask):
    """Return the sorted, unique axis-0 indices of the 3-D ``mask`` that hold a 1.

    Only voxels exactly equal to 1 count; other label values are ignored.
    """
    hit_slices, _rows, _cols = np.where(mask == 1)
    return np.unique(hit_slices)

def viz_mets(case_num, fpath_save=None):
    """Show one mask-bearing slice of a training case in four contrasts.

    Builds a 2x2 mosaic from the first four non-mask NIfTI volumes (sorted by
    filename) in ``{dest_path}/train/{case_num}``, each with ``mets_seg``
    overlaid via ``overlay_mask``. The slice comes from ``mets_slice_idx.npy``:
    the 5th-from-last entry when there are more than six, otherwise the last.

    Parameters
    ----------
    case_num : str
        Training case folder name, e.g. ``'Mets_030'``.
    fpath_save : str, optional
        When given, the figure is saved there and then closed, so a loop over
        many cases does not accumulate open figures. Otherwise it is left open
        for inline display.

    A case whose slice-index file is empty is reported and skipped instead of
    raising ``IndexError`` (which would abort a loop over all cases).
    """
    # Build from the module-level output root instead of re-hard-coding it
    # (the old local variable also shadowed the global ``base_path``).
    case_dir = '{}/train/{}'.format(dest_path, case_num)

    met_slices = np.load('{}/mets_slice_idx.npy'.format(case_dir))
    if len(met_slices) == 0:
        print('{}: no slices contain label 1; skipping plot'.format(case_num))
        return
    sl = met_slices[-5] if len(met_slices) > 6 else met_slices[-1]

    fpaths_data = sorted(f for f in glob('{}/*.nii.gz'.format(case_dir))
                         if 'seg' not in os.path.basename(f))
    data = np.array([nib.load(f).get_fdata() for f in fpaths_data])
    seg = nib.load('{}/mets_seg.nii.gz'.format(case_dir)).get_fdata()

    panels = [overlay_mask(data[i, sl], seg[sl]) for i in range(4)]
    img = np.vstack([np.hstack(panels[:2]), np.hstack(panels[2:])])

    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title('{} - slice {}'.format(case_num, sl))
    ax.axis('off')
    if fpath_save is not None:
        fig.savefig(fpath_save)
        plt.close(fig)

<Figure size 720x576 with 0 Axes>

In [11]:
# Render one overlay mosaic per training case into <dataset root>/plots.
mets_cases = sorted(os.path.basename(f) for f in glob('{}/train/Mets*'.format(base_path)))
plot_path = os.path.join(os.path.dirname(dest_path), 'plots')
os.makedirs(plot_path, exist_ok=True)  # savefig does not create missing folders

for case_num in tqdm(mets_cases):
    viz_mets(case_num, os.path.join(plot_path, '{}.png'.format(case_num)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


  0%|          | 0/105 [00:00<?, ?it/s]

<Figure size 720x576 with 0 Axes>

In [None]:
# Convert every test-split case to NIfTI (test cases have no mask or slice index).
mets_cases = sorted(
    os.path.basename(case_dir)
    for case_dir in glob('{}/test/Mets*'.format(base_path))
)
for case_num in tqdm(mets_cases, total=len(mets_cases)):
    process_mets_case(case_num, train=False)

In [None]:
# Spot-check one training case: mask overlaid on its 3_AX_T1_GE_post volume.
case_id = 'Mets_030'
case_dir = '{}/train/{}'.format(dest_path, case_id)

data1 = nib.load('{}/1_AX_T1_SE_pre.nii.gz'.format(case_dir)).get_fdata()
data2 = nib.load('{}/2_AX_T1_SE_post.nii.gz'.format(case_dir)).get_fdata()
data3 = nib.load('{}/3_AX_T1_GE_post.nii.gz'.format(case_dir)).get_fdata()
data4 = nib.load('{}/4_AX_T2_FLAIR_post.nii.gz'.format(case_dir)).get_fdata()
sl_idxs = np.load('{}/mets_slice_idx.npy'.format(case_dir))
seg = nib.load('{}/mets_seg.nii.gz'.format(case_dir)).get_fdata()
print(sl_idxs)

# Same slice rule as viz_mets; a bare sl_idxs[-5] raises IndexError for cases
# with fewer than five slices containing label 1.
sl = sl_idxs[-5] if len(sl_idxs) > 6 else sl_idxs[-1]
img_ov = overlay_mask(data3[sl], seg[sl])
plt.imshow(img_ov)
plt.title('{} - 3_AX_T1_GE_post, slice {}'.format(case_id, sl));