In [1]:
%reload_ext autoreload
%autoreload 2

import os
import numpy as np
from glob import glob
from skimage.transform import resize
import sigpy as sp
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

# Make grayscale the default colormap for every later imshow call. The rcParam
# is set directly instead of calling plt.set_cmap('gray'), because set_cmap
# looks up the current image and thereby creates (and displays) an empty figure.
plt.rcParams['image.cmap'] = 'gray'
# Default figure size (inches) for all plots in this notebook.
plt.rcParams['figure.figsize'] = (12, 10)

def process_mpr(vol, plane):
    """Reformat a 3D volume into a stack of 512x512 slices for one plane.

    Args:
        vol: 3D array. Its axes are permuted so that the new slicing axis
            comes first.
        plane: 'sag' selects the (2, 0, 1) permutation and trims 100 slices
            from each end of the new first axis; any other value (callers
            in this notebook pass 'cor') selects (1, 0, 2) and trims 120.

    Returns:
        The reformatted stack with shape (n_slices, 512, 512), as returned by
        skimage.transform.resize.
    """
    if plane == 'sag':
        axes_order, margin = (2, 0, 1), 100
    else:
        axes_order, margin = (1, 0, 2), 120

    # Bring the slicing axis to the front and drop `margin` slices at each end.
    stack = vol.transpose(axes_order)[margin:-margin, ...]
    # Rotate every slice by 180 degrees in-plane.
    stack = np.rot90(stack, k=2, axes=(1, 2))

    n_slices, _, n_cols = stack.shape
    # First bring axis 1 to 128 with sigpy's resize (zero-pad/crop per the
    # sigpy docs -- confirm), then interpolate each slice to 512x512.
    stack = sp.util.resize(stack, [n_slices, 128, n_cols])
    return resize(stack, [n_slices, 512, 512])

def process_case(fpath_data):
    """Load one case and build its reformatted stacks for both planes.

    The array stored at ``fpath_data`` must be 5D; its axes 1 and 2 are
    swapped after loading. Every 3D volume ``data[m, c]`` is then passed
    through ``process_mpr`` once with plane='sag' and once with plane='cor'.

    Args:
        fpath_data: Path of the ``.npy`` file to load.

    Returns:
        Tuple ``(data, data_sag, data_cor)``: the axis-swapped input and two
        float arrays indexed ``[m, c, slice, row, col]`` holding the
        reformatted volumes. The last two are ``None`` when either of the
        first two dimensions of ``data`` is empty.
    """
    data = np.load(fpath_data).transpose(0, 2, 1, 3, 4)
    lead_shape = data.shape[:2]

    # Output buffers are allocated lazily, once the shape produced by
    # process_mpr for the first volume is known.
    data_sag = data_cor = None
    for m, c in np.ndindex(*lead_shape):
        vol = data[m, c]

        vol_sag = process_mpr(vol, plane='sag')
        if data_sag is None:
            data_sag = np.zeros(lead_shape + vol_sag.shape[:3])
        data_sag[m, c] = vol_sag

        vol_cor = process_mpr(vol, plane='cor')
        if data_cor is None:
            data_cor = np.zeros(lead_shape + vol_cor.shape[:3])
        data_cor[m, c] = vol_cor

    return data, data_sag, data_cor

def create_slices(vol, dest_path):
    """Write every index along axis 2 of ``vol`` to its own ``.npy`` file.

    Slice ``i`` (``vol[:, :, i]``) is saved as ``<dest_path>/<i:03d>.npy``;
    existing files with the same name are overwritten.

    Args:
        vol: Array with at least three dimensions.
        dest_path: Output directory; created (with parents) if missing.
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail if
    # the directory appeared between the check and the creation.
    os.makedirs(dest_path, exist_ok=True)

    for sl_idx in range(vol.shape[2]):
        fname = os.path.join(dest_path, '{:03d}.npy'.format(sl_idx))
        np.save(fname, vol[:, :, sl_idx])

<Figure size 432x288 with 0 Axes>

In [5]:
src1 = '/mnt/raid/srivathsa/mra_synth/preprocess/slices'
src2 = '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/slices'

# Expected number of pre-sliced .npy files per case, in the order of PLANES.
PLANES = ('ax', 'sag', 'cor')
EXPECTED_SLICE_COUNTS = (100, 312, 272)

# Case IDs are the entry names directly under each slices root.
cases1 = sorted(os.path.basename(c) for c in glob('{}/*'.format(src1)))
cases2 = sorted(os.path.basename(c) for c in glob('{}/*'.format(src2)))

# Cases present under both roots; the set keeps the membership test linear.
cases2_lookup = set(cases2)
cmn = [c for c in cases1 if c in cases2_lookup]

# Print every case under src1 whose per-plane file counts differ from the
# expected ones, followed by its (ax, sag, cor) counts.
for cnum in cases1:
    ax_files, sag_files, cor_files = (
        len(glob('{}/{}/{}/*.npy'.format(src1, cnum, plane))) for plane in PLANES
    )
    if (ax_files, sag_files, cor_files) != EXPECTED_SLICE_COUNTS:
        print(cnum, ax_files, sag_files, cor_files)

IXI035-IOP-0873 92 312 272
IXI230-IOP-0869 92 312 272
IXI231-IOP-0866 92 312 272
IXI232-IOP-0898 92 312 272
IXI234-IOP-0870 92 312 272
IXI238-IOP-0883 92 312 272


In [None]:
from subtle.utils.slice import build_slice_list, get_num_slices
# Root holding one entry per case of pre-sliced data, and the directory with
# the per-case .npy volumes.
src_path = '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/slices'
data_dir = '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/data'

# Keep only the first 25 case IDs (in sorted order) found under src_path.
cases = sorted(
    entry.rsplit('/', 1)[-1] for entry in glob('{}/*'.format(src_path))
)[:25]
# One '<src_path>/<case>.npy' entry per selected case; how the loader resolves
# these entries is not visible here.
data_list = ['{}/{}.npy'.format(src_path, case_id) for case_id in cases]

# files, indices = build_slice_list(data_list, slice_axis=[0, 2, 3], params={'h5_key': 'all'})

# slice_dict = {
#     data_file: [
#         get_num_slices(data_file, axis=sl_axis, params={'h5_key': 'all'}) 
#         for sl_axis in [0, 2, 3]
#     ]
#     for data_file in data_list
# }

In [None]:
from subtle.data_loaders import PreSlicedMPRLoader, SliceLoader

# Loader over the pre-sliced cases listed in data_list; the keyword semantics
# are defined by subtle.data_loaders.PreSlicedMPRLoader.
data_loader = PreSlicedMPRLoader(
    data_list,
    slice_axis=[0, 2, 3],
    use_enh_mask=True,
    slices_per_input=7,
    enh_pfactor=1.75,
)

In [None]:
# Fetch item 12 from the pre-sliced loader as an (input, target) pair.
X, Y = data_loader[12]

In [None]:
print(X.shape, Y.shape)

# Side-by-side montage for the first item: last-axis indices 3 and 10 of the
# input next to last-axis indices 0 and 1 of the target.
montage_panels = [X[0, ..., 3], X[0, ..., 10], Y[0, ..., 0], Y[0, ..., 1]]
plt.imshow(np.hstack(montage_panels))

In [None]:
# Show last-axis index 1 of the first target item on its own.
plt.imshow(Y[0, ..., 1])

In [None]:
# Debugging aid: call the loader's private context-slice helper on one
# pre-sliced file. What it returns is defined by PreSlicedMPRLoader -- TODO confirm.
data_loader._get_context_slices(
    '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/slices/IXI023-Guys-0699/cor/002.npy'
)

In [None]:
# Load a single saved slice file from a case's 'cor' directory and report the
# shape of the stored array.
slice_file = '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/slices/IXI024-Guys-0705/cor/260.npy'
data = np.load(slice_file)
print(data.shape)

In [None]:
# Same cases as above, but pointing at the per-case volumes under data_dir and
# sliced on the fly by SliceLoader.
dlist = ['{}/{}.npy'.format(data_dir, case_id) for case_id in cases]
sl_loader = SliceLoader(
    dlist,
    slice_axis=[0, 2, 3],
    use_enh_mask=True,
    slices_per_input=7,
    resize=512,
)
# Fetch item 50 as an (input, target) pair; this rebinds X and Y.
X, Y = sl_loader[50]

In [None]:
# Report the shapes of the most recently fetched input/target pair.
print(X.shape, Y.shape)

# Show last-axis index 1 of the first target item.
plt.imshow(Y[0, ..., 1])

In [None]:
# Destination root for the per-case slice directories, and the directory with
# the per-case .npy volumes they are generated from.
# NOTE: this cell rebinds `src_path` and `cases`, which another cell of this
# notebook binds to the slices directory and its first 25 cases.
dest_path = '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/slices'
src_path = '/home/srivathsa/projects/studies/gad/mra_synth/preprocess/data'
# Strip only the trailing extension: str.replace('.npy', '') would also remove
# an '.npy' occurring elsewhere in the name, so the path rebuilt below would
# no longer point at the globbed file.
cases = sorted(
    os.path.splitext(os.path.basename(c))[0] for c in glob('{}/*.npy'.format(src_path))
)

# For every case, write the ax / sag / cor arrays to
# '<dest_path>/<case>/<plane>/<idx>.npy'; tqdm infers the total from the list.
for cnum in tqdm(cases):
    fpath_npy = '{}/{}.npy'.format(src_path, cnum)
    data_ax, data_sag, data_cor = process_case(fpath_npy)
    create_slices(data_ax, '{}/{}/ax'.format(dest_path, cnum))
    create_slices(data_sag, '{}/{}/sag'.format(dest_path, cnum))
    create_slices(data_cor, '{}/{}/cor'.format(dest_path, cnum))