In [1]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm import tqdm_notebook as tqdm
import os
import shutil
import json
from pathlib import Path
import pandas as pd

from subtle.subtle_metrics import dice

# Notebook-wide matplotlib defaults.
# The colormap is set through rcParams instead of plt.set_cmap('gray'): set_cmap looks up
# the "current image", which implicitly creates an empty figure and leaves a stray
# "<Figure ... with 0 Axes>" output under this cell.
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.figsize'] = (12, 10)



<Figure size 432x288 with 0 Axes>

### Step 1: Create files for tumor segmentation from model prediction

- For example, `src_data` is a symbolic link to `/mnt/raid/jiang/projects/SubtleGAN/model/MMT/MMTUNetHybrid_GAN_s_c_MMT_epo50_bs24_lrg0.0005_192_opt-adamw/synthetic_images/test/`

In [None]:
# Step 1 layout: for every case under `<base_path>/<src_dirname>` build
#   <base_path>/<dest_dirname>/<case>/                 -> the four contrasts as found in the source
#   <base_path>/<dest_dirname>/<case>_<con>_syn/       -> same, but <con> is taken from `<con>_syn.nii.gz`
# Every destination file is named `<dir>_<con>.nii.gz` (the `_syn` marker is dropped from the file name).
base_path = '/home/srivathsa/projects/studies/gad/mmt_seg/seg'
src_dirname = 'src_data'
dest_dirname = 'cases'
data_path = '{}/{}'.format(base_path, src_dirname)
dest_fname = '{base_path}/{dest_dirname}/{cnum}{syn_sfx_dir}/{cnum}{syn_sfx_dir}_{con}.nii.gz'
src_fname = '{base_path}/{src_dirname}/{cnum}/{fname}.nii.gz'

# Only directories are cases; a stray file next to them would otherwise be treated as a
# case and fail at copy time.
cases = sorted([f.split('/')[-1] for f in glob('{}/*'.format(data_path)) if os.path.isdir(f)])
cons = ['T1', 'T1Gd', 'T2', 'FLAIR']
# '' -> directory without a suffix; '_<con>_syn' -> directory where <con> is the synthetic volume.
syn_sfx = [''] + ['_{}_syn'.format(c) for c in cons]

for syn_sfx_dir in syn_sfx:
    # The list of source file stems depends only on the suffix, so compute it once per suffix.
    if syn_sfx_dir == '':
        conlist = cons
    else:
        # '_T1_syn'.split('_') == ['', 'T1', 'syn'] -> index 1 is the contrast being replaced,
        # and syn_sfx_dir[1:] ('T1_syn') is the source file stem to use in its place.
        syn_con = syn_sfx_dir.split('_')[1]
        conlist = [c if c != syn_con else syn_sfx_dir[1:] for c in cons]

    for cnum in tqdm(cases, total=len(cases)):
        case_dir = '{}/{}/{}{}'.format(base_path, dest_dirname, cnum, syn_sfx_dir)
        # exist_ok=True keeps the cell re-runnable; a plain makedirs raises FileExistsError
        # as soon as the cell is executed a second time.
        os.makedirs(case_dir, exist_ok=True)
        for con in conlist:
            fpath_src = src_fname.format(base_path=base_path, src_dirname=src_dirname, cnum=cnum, fname=con)
            fpath_dest = dest_fname.format(base_path=base_path, dest_dirname=dest_dirname, cnum=cnum,
                                           syn_sfx_dir=syn_sfx_dir, con=con.replace('_syn', ''))
            shutil.copy(fpath_src, fpath_dest)

### Step 2: Create JSON file for segmentation config

In [None]:
# Path of each input volume as written into the config file.
fpath_tmplt = '/workspace/mmt_seg/{dest_dirname}/{cnum}/{cnum}_{con}.nii.gz'

# One entry per directory found under <base_path>/<dest_dirname>.
cases = sorted(p.split('/')[-1] for p in glob('{}/{}/*'.format(base_path, dest_dirname)))
# Order in which the four contrasts are listed for every case in the config.
cons = ['T1Gd', 'T1', 'T2', 'FLAIR']

seg_config = {
    'testing_inference': [
        {'image': [fpath_tmplt.format(dest_dirname=dest_dirname, cnum=cnum, con=con) for con in cons]}
        for cnum in tqdm(cases, total=len(cases))
    ]
}

with open('{}/seg_config_mmg.json'.format(base_path), 'w') as f:
    json.dump(seg_config, f, indent=2)

### Step 3: (from cmdline) - copy the `cases` directory to ngc-ec2 workspace
### Step 4: Run the inference on ngc-ec2 and copy the `pred` folder here (to `/home/srivathsa/projects/studies/gad/mmt_seg`)

### Step 5: Rename the pred directories and files (some weird naming convention in the NGC API)

In [None]:
# Drop the last '_'-separated token from every entry directly under <base_path>/pred.
# NOTE: this is not idempotent - each run removes one more token, so run it exactly once.
bpath = '{}/pred'.format(base_path)
pred_dirs = sorted(p.rsplit('/', 1)[-1] for p in glob('{}/*'.format(bpath)))

for old_name in pred_dirs:
    # Everything before the final underscore ('' when the name has no underscore).
    trimmed_name = old_name.rpartition('_')[0]
    src = '{}/{}'.format(bpath, old_name)
    dst = '{}/{}'.format(bpath, trimmed_name)
    os.rename(src, dst)

In [None]:
# Rename every .nii.gz below <bpath>: keep the last '_' token of the file stem, drop the
# token just before it, and keep everything in front of that.
#   a_b_X_seg.nii.gz -> a_b_seg.nii.gz
# NOTE: not idempotent - a second run removes another token.
nii_files = sorted(glob('{}/**/*.nii.gz'.format(bpath), recursive=True))

for fpath_nii in nii_files:
    parent_dir = Path(fpath_nii).parent.absolute()
    stem_tokens = fpath_nii.split('/')[-1].replace('.nii.gz', '').split('_')
    prefix = '_'.join(stem_tokens[:-2])
    suffix = stem_tokens[-1]
    renamed = '{}_{}.nii.gz'.format(prefix, suffix)

    os.rename(fpath_nii, '{}/{}'.format(parent_dir, renamed))

### Step 6: Compute Dice scores and save them to CSV files in `metrics`

In [None]:
# For every "<con>_syn" variant, compare the masks predicted for `<case>_<con>_syn` against
# the masks predicted for the un-suffixed `<case>` directory, per tumor class, and write one
# CSV per variant to <base_path>/metrics/<con>_syn.csv.
# Only directories are treated as cases; stray files under src_data are ignored.
cases = sorted([d.split('/')[-1] for d in glob('{}/src_data/*'.format(base_path)) if os.path.isdir(d)])
pred_dir = '{}/pred'.format(base_path)
cons = ['T1', 'T1Gd', 'T2', 'FLAIR']
syns = ['{}_syn'.format(c) for c in cons]
tumor_classes = ['ET', 'TC', 'WT']
metrics_dir = '{}/metrics'.format(base_path)

# to_csv does not create missing parent directories, so make sure the output folder exists.
os.makedirs(metrics_dir, exist_ok=True)

for syn in syns:
    metrics_obj = []
    for cnum in tqdm(cases, total=len(cases)):
        case_obj = {'Case': cnum}
        for cls in tumor_classes:
            # Reference mask: segmentation stored under the un-suffixed case directory.
            fpath_gt = '{pdir}/{cnum}/{cnum}_seg{cls}.nii.gz'.format(pdir=pred_dir, cnum=cnum, cls=cls)
            mask_gt = nib.load(fpath_gt).get_fdata()

            # Candidate mask: segmentation stored under the `<case>_<con>_syn` directory.
            fpath_pred = '{pdir}/{cnum}_{syn}/{cnum}_{syn}_seg{cls}.nii.gz'.format(
                pdir=pred_dir, cnum=cnum, cls=cls, syn=syn
            )
            mask_pred = nib.load(fpath_pred).get_fdata()
            case_obj[cls] = dice(mask_gt, mask_pred)
        metrics_obj.append(case_obj)
    # One row per case with columns Case, ET, TC, WT (plus the default pandas index column).
    pd.DataFrame(metrics_obj).to_csv('{}/{}.csv'.format(metrics_dir, syn))

## Signed rank test

In [2]:
from scipy.stats import wilcoxon
# Maps the contrast name used in the metrics file names ('<name>_syn.csv') to the short
# prefix used for the per-case average column ('<prefix>_avg').
# NOTE: earlier cells bind `cons` to a list; from here on it is this dict.
cons = {'T1': 't1', 'T1Gd': 'gd', 'T2': 't2', 'FLAIR': 'fl'}

def compile_dice_scores(dirpath_metrics):
    """Collect the per-case mean Dice score of every contrast into one record per case.

    Reads `<dirpath_metrics>/<contrast>_syn.csv` for each key of `cons`. Each CSV must
    provide the columns `Case`, `ET`, `TC` and `WT`.

    Args:
        dirpath_metrics: directory containing the four `<contrast>_syn.csv` files.

    Returns:
        A list of dicts, one per case in order of first appearance, of the form
        `{'case': <Case>, 't1_avg': ..., 'gd_avg': ..., 't2_avg': ..., 'fl_avg': ...}`,
        where each `*_avg` is the mean of ET, TC and WT for that contrast. A case that is
        missing from one contrast's CSV simply lacks that key.
    """
    # Keyed by case id so each CSV row is merged in O(1) instead of re-scanning the whole
    # result list for every row. Dicts preserve insertion order, so the output order is
    # still "first appearance".
    rows_by_case = {}

    for kw, suf_str in cons.items():
        avg_col = '{}_avg'.format(suf_str)
        mrows = pd.read_csv('{}/{}_syn.csv'.format(dirpath_metrics, kw)).to_dict(orient='records')
        for mrow in mrows:
            df_row = rows_by_case.setdefault(mrow['Case'], {'case': mrow['Case']})
            df_row[avg_col] = np.mean([mrow['ET'], mrow['TC'], mrow['WT']])

    return list(rows_by_case.values())

def get_p_values(bpath, mmt_kw='metrics_mmt', milr_kw='metrics_milr', mmg_kw='metrics_mmgan'):
    """Run one-sided Wilcoxon signed-rank tests of MMT Dice scores against two baselines.

    For every contrast in `cons`, the per-case average Dice of the MMT metrics is tested
    against the MM-GAN and MILR metrics with `alternative='greater'`.

    Args:
        bpath: directory containing the three metrics sub-directories.
        mmt_kw: sub-directory holding the MMT metrics CSVs.
        milr_kw: sub-directory holding the MILR metrics CSVs.
        mmg_kw: sub-directory holding the MM-GAN metrics CSVs.

    Returns:
        A DataFrame with one row per contrast and the columns `Contrast`, `MM-GAN` and
        `MILR`, the last two holding the p-values.

    Raises:
        ValueError: if the three metrics directories do not cover the same set of cases,
            since the paired test is undefined without a one-to-one case match.
    """
    def _scores_by_case(kw):
        # Index the compiled records by case id so samples can be paired explicitly.
        return {row['case']: row for row in compile_dice_scores('{}/{}'.format(bpath, kw))}

    mmt_scores = _scores_by_case(mmt_kw)
    mmg_scores = _scores_by_case(mmg_kw)
    milr_scores = _scores_by_case(milr_kw)

    # The signed-rank test is a paired test: element i of both samples must belong to the
    # same case. Pairing by list position silently gives wrong p-values when the CSVs of
    # the different methods list cases in a different order, so align on the case id.
    if not (set(mmt_scores) == set(mmg_scores) == set(milr_scores)):
        raise ValueError(
            'Metrics directories cover different cases (MMT: {}, MM-GAN: {}, MILR: {}); '
            'cannot pair samples for the signed-rank test'.format(
                len(mmt_scores), len(mmg_scores), len(milr_scores)
            )
        )
    case_order = list(mmt_scores)

    p_matrix = []
    for kw, suf_str in cons.items():
        avg_col = '{}_avg'.format(suf_str)
        mmt_avg = [mmt_scores[c][avg_col] for c in case_order]
        mmg_avg = [mmg_scores[c][avg_col] for c in case_order]
        milr_avg = [milr_scores[c][avg_col] for c in case_order]

        p_matrix.append({
            'Contrast': kw,
            # Index 1 of the wilcoxon result is the p-value.
            'MM-GAN': wilcoxon(mmt_avg, mmg_avg, alternative='greater')[1],
            'MILR': wilcoxon(mmt_avg, milr_avg, alternative='greater')[1]
        })

    return pd.DataFrame(p_matrix)

In [3]:
df_pvals = get_p_values('/home/srivathsa/projects/studies/gad/mmt_seg/seg')

In [5]:
df_pvals.to_dict(orient='records')

[{'Contrast': 'T1',
  'MM-GAN': 8.435659527084613e-11,
  'MILR': 6.66917349824795e-05},
 {'Contrast': 'T1Gd',
  'MM-GAN': 7.374296604127792e-11,
  'MILR': 0.07333786313548717},
 {'Contrast': 'T2',
  'MM-GAN': 2.58419740681613e-12,
  'MILR': 0.0006940102982226807},
 {'Contrast': 'FLAIR',
  'MM-GAN': 1.0255692185770687e-11,
  'MILR': 1.2538994817819324e-08}]

In [None]:
# Swap the first two axes of every volume under cases_mmg and write it back with an
# identity affine, reusing the original header.
# WARNING: files are overwritten in place, and running this cell a second time swaps the
# axes back again.
fpaths_nii = glob('/home/srivathsa/projects/studies/gad/mmt_seg/cases_mmg/**/*.nii.gz', recursive=True)

for fp in tqdm(fpaths_nii, total=len(fpaths_nii)):
    src_img = nib.load(fp)
    swapped = np.transpose(src_img.get_fdata(), (1, 0, 2))
    nib.save(nib.Nifti1Image(swapped, affine=np.eye(4), header=src_img.header), fp)