In [1]:
import os
import sys
from pathlib import Path

sys.path.append(str(Path().cwd().parent))

In [2]:
import numpy as np
from tqdm import tqdm
from dataset.patch_dataset import BrainPatchesDataModule
from models.UNetModule import UNet3
from dataset.roi_extraction import slice_image, reconstruct_patches
from utils import z_score_norm
import SimpleITK as sitk
import torch
from models.EM import ExpectationMaximization
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from brain_segmenter import BrainSegmenter

  from .autonotebook import tqdm as notebook_tqdm


Select GPU index to run the prediction on and path for the checkpoint used to load the model

In [12]:
### CHANGE ONLY THIS 2 LINES ###
DEVICE = 'cuda:2'
CHKP_PATH = Path('/home/user0/misa_vlex/brain_segmentation/outputs/unet3p_augm_focal_64-32_05_synthseg_merged/version_0/checkpoints/epoch=94-valid_dsc_macro_epoch=0.9409.ckpt')
ENSEMBLE = True
### DON'T CHANGE ANYTHING ELSE ###


In [None]:
bsegm = BrainSegmenter(CHKP_PATH, DEVICE)

d = bsegm.cfg['dataset']['patches']['denoiser']
SEGM_2_CH_NAME = '_seg_resampled_merged' if d == 'synthseg_merged' else '_seg_resampled'

# Get Validation Results

In [13]:
val_path = Path('/home/user0/misa_vlex/brain_segmentation/data/Validation_Set')
val_path_res = val_path/f'unet_results/{bsegm.cfg["exp_name"]}/'
val_path_res.mkdir(exist_ok=True, parents=True)

results = []
for case in tqdm(val_path.iterdir(), total=5):
    if case.is_dir() and 'unet_results' not in case.name:
        
        # read images
        img = sitk.ReadImage(str(case/f'{case.name}.nii.gz'))
        imgnp = sitk.GetArrayFromImage(img)
        gt_seg = sitk.ReadImage(str(case/f'{case.name}_seg.nii.gz'))
        gt_segnp = sitk.GetArrayFromImage(gt_seg)
        
        # load the prior segmentation if needed
        prior_segm = None
        if bsegm.cfg['model']['in_channels'] == 2:
            proc_path = str(case/f'{case.name}.nii.gz').replace('data', 'proc_data')
            proc_path = proc_path.replace('.nii.gz', f'{SEGM_2_CH_NAME}.nii.gz')
            prior_segm = sitk.ReadImage(proc_path)
            prior_segm = sitk.GetArrayFromImage(prior_segm)
        
        # make prediction
        pred_seg, seg_res = bsegm.segment_and_compare(imgnp, gt_segnp,
                                                      ssegm_image=prior_segm,
                                                      ensemble=ENSEMBLE)
        seg_res['case'] = case.name
        seg_res['model'] = bsegm.cfg["exp_name"]
        results.append(seg_res)
        
        # save prediction
        pred_seg_itk = sitk.GetImageFromArray(pred_seg)
        pred_seg_itk.CopyInformation(gt_seg)
        pred_seg_itk.SetOrigin(gt_seg.GetOrigin())
        pred_seg_itk.SetDirection(gt_seg.GetDirection())
        pred_seg_itk.SetSpacing(gt_seg.GetSpacing())
        sitk.WriteImage(pred_seg_itk, str(val_path_res/f'{case.name}_seg.nii.gz'))
        
results_df = pd.DataFrame(results)
results_df.to_csv(val_path_res/'results.csv', index=False)

  mean = img[img != 0].mean()
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  img = (img - mean) / std
  img = (img - mean) / std
6it [15:17, 152.88s/it]                       


In [14]:
results_df

Unnamed: 0,CSF,WM,GM,avg_dice,case,model
0,0.936548,0.954281,0.950436,0.947088,IBSR_14,unet3p_augm_focal_64-32_05_synthseg_merged
1,0.920043,0.940979,0.948952,0.936658,IBSR_12,unet3p_augm_focal_64-32_05_synthseg_merged
2,0.915939,0.94172,0.95216,0.936606,IBSR_11,unet3p_augm_focal_64-32_05_synthseg_merged
3,0.901876,0.942617,0.927778,0.92409,IBSR_13,unet3p_augm_focal_64-32_05_synthseg_merged
4,0.9434,0.952122,0.933558,0.943027,IBSR_17,unet3p_augm_focal_64-32_05_synthseg_merged


In [15]:
print(bsegm.cfg['exp_name'])
print(f'Mean DSC: {results_df["avg_dice"].mean()} +- {results_df["avg_dice"].std()}')

unet3p_augm_focal_64-32_05_synthseg_merged
Mean DSC: 0.9374939017668182 +- 0.008715164084679166


In [5]:
results_df

Unnamed: 0,CSF,WM,GM,avg_dice,case,model
0,0.930618,0.947166,0.938913,0.938899,IBSR_14,unet3p_augm_focal_64-32_05_synthseg_merged
1,0.923684,0.930343,0.935176,0.929734,IBSR_12,unet3p_augm_focal_64-32_05_synthseg_merged
2,0.908819,0.930332,0.942301,0.92715,IBSR_11,unet3p_augm_focal_64-32_05_synthseg_merged
3,0.888257,0.938044,0.920439,0.91558,IBSR_13,unet3p_augm_focal_64-32_05_synthseg_merged
4,0.941193,0.94426,0.921751,0.935735,IBSR_17,unet3p_augm_focal_64-32_05_synthseg_merged


In [8]:
print(bsegm.cfg['exp_name'])
print(f'Mean DSC: {results_df["avg_dice"].mean()} +- {results_df["avg_dice"].std()}')

unet3p_augm_focal_64-32_05_synthseg_merged
Mean DSC: 0.9294197020993102 +- 0.009034903590657391


# Make Test Predictions

Be sure to have correct device and checkpoint path set up above

In [16]:
test_path = Path('/home/user0/misa_vlex/brain_segmentation/data/Test_Set')
test_path_res = test_path/f'unet_results/{bsegm.cfg["exp_name"]}/'
test_path_res.mkdir(exist_ok=True, parents=True)

results = []
for case in tqdm(test_path.iterdir(), total=3):
    if case.is_dir() and 'unet_results' not in case.name:
        
        # read images
        img = sitk.ReadImage(str(case/f'{case.name}.nii.gz'))
        imgnp = sitk.GetArrayFromImage(img)
        
        # load the prior segmentation if needed
        prior_segm = None
        if bsegm.cfg['model']['in_channels'] == 2:
            proc_path = str(case/f'{case.name}.nii.gz').replace('data', 'proc_data')
            proc_path = proc_path.replace('.nii.gz', f'{SEGM_2_CH_NAME}.nii.gz')
            prior_segm = sitk.ReadImage(proc_path)
            prior_segm = sitk.GetArrayFromImage(prior_segm)
        
        if ENSEMBLE:
            pred_seg = bsegm.segment_ensemble(imgnp, progress=False,
                                              ssegm_image_xyz=prior_segm,
                                              )
        else:
            # make prediction
            pred_seg = bsegm.segment(imgnp, progress=False,
                                    ssegm_image=prior_segm,
                                    )

        # save prediction
        pred_seg_itk = sitk.GetImageFromArray(pred_seg)
        pred_seg_itk.CopyInformation(img)
        pred_seg_itk.SetOrigin(img.GetOrigin())
        pred_seg_itk.SetDirection(img.GetDirection())
        pred_seg_itk.SetSpacing(img.GetSpacing())
        
        if ENSEMBLE:
            sitk.WriteImage(pred_seg_itk, str(test_path_res/f'{case.name}_seg_ens.nii.gz'))
        else:
            sitk.WriteImage(pred_seg_itk, str(test_path_res/f'{case.name}_seg.nii.gz'))

  mean = img[img != 0].mean()
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  img = (img - mean) / std
  img = (img - mean) / std
100%|██████████| 3/3 [06:05<00:00, 115.17s/it]