In [1]:
import os
import sys
from pathlib import Path

sys.path.append(str(Path().cwd().parent))

In [2]:
import numpy as np
from tqdm import tqdm
from dataset.patch_dataset import BrainPatchesDataModule
from models.UNetModule import UNet3
from dataset.roi_extraction import slice_image, reconstruct_patches
from utils import z_score_norm
import SimpleITK as sitk
import torch
from models.EM import ExpectationMaximization
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from brain_segmenter import BrainSegmenter

  from .autonotebook import tqdm as notebook_tqdm


Select GPU index to run the prediction on and path for the checkpoint used to load the model

In [7]:
### CHANGE ONLY THIS 2 LINES ###
DEVICE = 'cuda:2'
CHKP_PATH = Path('/home/user0/misa_vlex/brain_segmentation/outputs/unet3p_augm_focal_64-32_05_synthseg/version_0/checkpoints/epoch=08-valid_dsc_macro_epoch=0.9233.ckpt')
### DON'T CHANGE ANYTHING ELSE ###

bsegm = BrainSegmenter(CHKP_PATH, DEVICE)

d = bsegm.cfg['dataset']['patches']['denoiser']
SEGM_2_CH_NAME = '_seg_resampled_merged' if d == 'synthseg_merged' else '_seg_resampled'

Model loaded


# Get Validation Results

In [8]:
val_path = Path('/home/user0/misa_vlex/brain_segmentation/data/Validation_Set')
val_path_res = val_path/f'unet_results/{bsegm.cfg["exp_name"]}/'
val_path_res.mkdir(exist_ok=True, parents=True)

results = []
for case in tqdm(val_path.iterdir(), total=5):
    if case.is_dir() and 'unet_results' not in case.name:
        
        # read images
        img = sitk.ReadImage(str(case/f'{case.name}.nii.gz'))
        imgnp = sitk.GetArrayFromImage(img)
        gt_seg = sitk.ReadImage(str(case/f'{case.name}_seg.nii.gz'))
        gt_segnp = sitk.GetArrayFromImage(gt_seg)
        
        # load the prior segmentation if needed
        prior_segm = None
        if bsegm.cfg['model']['in_channels'] == 2:
            proc_path = str(case/f'{case.name}.nii.gz').replace('data', 'proc_data')
            proc_path = proc_path.replace('.nii.gz', f'{SEGM_2_CH_NAME}.nii.gz')
            prior_segm = sitk.ReadImage(proc_path)
            prior_segm = sitk.GetArrayFromImage(prior_segm)
        
        # make prediction
        pred_seg, seg_res = bsegm.segment_and_compare(imgnp, gt_segnp,
                                                      ssegm_image=prior_segm)
        seg_res['case'] = case.name
        seg_res['model'] = bsegm.cfg["exp_name"]
        results.append(seg_res)
        
        # save prediction
        pred_seg_itk = sitk.GetImageFromArray(pred_seg)
        pred_seg_itk.CopyInformation(gt_seg)
        pred_seg_itk.SetOrigin(gt_seg.GetOrigin())
        pred_seg_itk.SetDirection(gt_seg.GetDirection())
        pred_seg_itk.SetSpacing(gt_seg.GetSpacing())
        sitk.WriteImage(pred_seg_itk, str(val_path_res/f'{case.name}_seg.nii.gz'))
        
results_df = pd.DataFrame(results)
results_df.to_csv(val_path_res/'results.csv', index=False)

  mean = img[img != 0].mean()
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  img = (img - mean) / std
  img = (img - mean) / std
6it [01:40, 16.82s/it]                       


In [9]:
results_df

Unnamed: 0,CSF,WM,GM,avg_dice,case,model
0,0.885856,0.917282,0.90092,0.901353,IBSR_14,unet3p_augm_focal_64-32_05_synthseg
1,0.887855,0.892263,0.891577,0.890565,IBSR_12,unet3p_augm_focal_64-32_05_synthseg
2,0.834077,0.903032,0.920421,0.885843,IBSR_11,unet3p_augm_focal_64-32_05_synthseg
3,0.839818,0.91042,0.872146,0.874128,IBSR_13,unet3p_augm_focal_64-32_05_synthseg
4,0.890016,0.914865,0.874364,0.893082,IBSR_17,unet3p_augm_focal_64-32_05_synthseg


In [10]:
print(bsegm.cfg['exp_name'])
print(f'Mean DSC: {results_df["avg_dice"].mean()} +- {results_df["avg_dice"].std()}')

unet3p_augm_focal_64-32_05_synthseg
Mean DSC: 0.8889940902331785 +- 0.010035363825905464


In [16]:
print(bsegm.cfg['exp_name'])
print(f'Mean DSC: {results_df["avg_dice"].mean()} +- {results_df["avg_dice"].std()}')


unet3p_augm_focal_128-32_05_synthseg_merged
Mean DSC: 0.9225410451458309 +- 0.010014093129348802


In [11]:
print(bsegm.cfg['exp_name'])
print(f'Mean DSC: {results_df["avg_dice"].mean()} +- {results_df["avg_dice"].std()}')

unet3p_augm_focal_64-32_05_synthseg_merged
Mean DSC: 0.92883864828338 +- 0.00833273425310916


# Make Test Predictions

Be sure to have correct device and checkpoint path set up above

In [6]:
test_path = Path('/home/user0/misa_vlex/brain_segmentation/data/Test_Set')
test_path_res = test_path/f'unet_results/{bsegm.cfg["exp_name"]}/'
test_path_res.mkdir(exist_ok=True, parents=True)

results = []
for case in tqdm(test_path.iterdir(), total=3):
    if case.is_dir() and 'unet_results' not in case.name:
        
        # read images
        img = sitk.ReadImage(str(case/f'{case.name}.nii.gz'))
        imgnp = sitk.GetArrayFromImage(img)
        
        # load the prior segmentation if needed
        prior_segm = None
        if bsegm.cfg['model']['in_channels'] == 2:
            proc_path = str(case/f'{case.name}.nii.gz').replace('data', 'proc_data')
            proc_path = proc_path.replace('.nii.gz', f'{SEGM_2_CH_NAME}.nii.gz')
            prior_segm = sitk.ReadImage(proc_path)
            prior_segm = sitk.GetArrayFromImage(prior_segm)
        
        
        # make prediction
        pred_seg = bsegm.segment(imgnp, progress=False,
                                 ssegm_image=prior_segm)

        # save prediction
        pred_seg_itk = sitk.GetImageFromArray(pred_seg)
        pred_seg_itk.CopyInformation(img)
        pred_seg_itk.SetOrigin(img.GetOrigin())
        pred_seg_itk.SetDirection(img.GetDirection())
        pred_seg_itk.SetSpacing(img.GetSpacing())
        sitk.WriteImage(pred_seg_itk, str(test_path_res/f'{case.name}_seg.nii.gz'))

  mean = img[img != 0].mean()
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
  img = (img - mean) / std
  img = (img - mean) / std
4it [01:00, 15.11s/it]                       
