In [23]:
import nibabel as nib
import torch
import numpy as np
import SimpleITK as sitk

# histogram.py

In [24]:

import numpy as np
import numpy.ma as ma


DEFAULT_CUTOFF = (0.01, 0.99)


# Functions from NiftyNet

def __compute_percentiles(img, mask, cutoff):
    """
    Creates the list of percentile values to be used as landmarks for the
    linear fitting.

    :param img: Image on which to determine the percentiles
    :param mask: Mask to use over the image to constraint to the relevant
    information
    :param cutoff: Values of the minimum and maximum percentiles to use for
    the linear fitting
    :return perc_results: list of percentiles value for the given image over
    the mask
    """
    perc = [cutoff[0],
            0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9,
            cutoff[1]]
    masked_img = ma.masked_array(img, np.logical_not(mask)).compressed()
    perc_results = np.percentile(masked_img, 100 * np.array(perc))
    return perc_results


def __standardise_cutoff(cutoff, type_hist='percentile'):
    """
    Standardises the cutoff values given in the configuration

    :param cutoff:
    :param type_hist: Type of landmark normalisation chosen (median,
    quartile, percentile)
    :return cutoff: cutoff with appropriate adapted values
    """
    cutoff = np.asarray(cutoff)
    if cutoff is None:
        return DEFAULT_CUTOFF
    if len(cutoff) > 2:
        cutoff = np.unique([np.min(cutoff), np.max(cutoff)])
    if len(cutoff) < 2:
        return DEFAULT_CUTOFF
    if cutoff[0] > cutoff[1]:
        cutoff[0], cutoff[1] = cutoff[1], cutoff[0]
    cutoff[0] = max(0., cutoff[0])
    cutoff[1] = min(1., cutoff[1])
    if type_hist == 'quartile':
        cutoff[0] = np.min([cutoff[0], 0.24])
        cutoff[1] = np.max([cutoff[1], 0.76])
    else:
        cutoff[0] = np.min([cutoff[0], 0.09])
        cutoff[1] = np.max([cutoff[1], 0.91])
    return cutoff


def create_standard_range():
    return 0., 100.


def __averaged_mapping(perc_database, s1, s2):
    """
    Map the landmarks of the database to the chosen range
    :param perc_database: perc_database over which to perform the averaging
    :param s1, s2: limits of the mapping range
    :return final_map: the average mapping
    """
    # assuming shape: n_data_points = perc_database.shape[0]
    #                 n_percentiles = perc_database.shape[1]
    slope = (s2 - s1) / (perc_database[:, -1] - perc_database[:, 0])
    slope = np.nan_to_num(slope)
    final_map = slope.dot(perc_database) / perc_database.shape[0]
    intercept = np.mean(s1 - slope * perc_database[:, 0])
    final_map = final_map + intercept
    return final_map


def normalize(data, landmarks, cutoff=DEFAULT_CUTOFF, masking_function=None):
    mapping = landmarks

    img = data
    image_shape = img.shape
    img = img.reshape(-1).astype(np.float32)

    if masking_function is not None:
        mask = masking_function(img)
    else:
        mask = np.ones_like(img, dtype=np.bool)
    mask = mask.reshape(-1)

    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    cutoff = __standardise_cutoff(cutoff)
    perc = __compute_percentiles(img, mask, cutoff)

    # Apply linear histogram standardisation
    range_mapping = mapping[range_to_use]
    range_perc = perc[range_to_use]
    diff_mapping = range_mapping[1:] - range_mapping[:-1]
    diff_perc = range_perc[1:] - range_perc[:-1]

    # handling the case where two landmarks are the same
    # for a given input image. This usually happens when
    # image background is not removed from the image.
    diff_perc[diff_perc == 0] = np.inf

    affine_map = np.zeros([2, len(range_to_use) - 1])
    # compute slopes of the linear models
    affine_map[0] = diff_mapping / diff_perc
    # compute intercepts of the linear models
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    bin_id = np.digitize(img, range_perc[1:-1], right=False)
    lin_img = affine_map[0, bin_id]
    aff_img = affine_map[1, bin_id]
    new_img = lin_img * img + aff_img
    new_img = new_img.reshape(image_shape)

    return new_img

# preprocessing.py

In [25]:
import tempfile
from pathlib import Path
import numpy as np
import nibabel as nib
import SimpleITK as sitk

# From NiftyNet model zoo
LI_LANDMARKS = "4.4408920985e-16 8.06305571158 15.5085721044 18.7007018006 21.5032879029 26.1413278906 29.9862059045 33.8384058795 38.1891334787 40.7217966068 44.0109152758 58.3906435207 100.0"
LI_LANDMARKS = np.array([float(n) for n in LI_LANDMARKS.split()])


def preprocess(data, padding, hist_masking_function=None):
    # data = pad(data, padding)
    data = standardize(data, masking_function=hist_masking_function)
    data = whiten(data)
    data = data.astype(np.float32)
    data = pad(data, padding)  # should I pad at the beginning instead?
    return data


def pad(data, padding):
    # Should I use this value for padding?
    value = data[0, 0, 0]
    return np.pad(data, padding, mode='constant', constant_values=value)


def crop(data, padding):
    p = padding
    return data[p:-p, p:-p, p:-p]


def standardize(data, landmarks=LI_LANDMARKS, masking_function=None):
    return normalize(data, landmarks, masking_function=masking_function)


def whiten(data, masking_function=None):
    if masking_function is None:
        masking_function = mean_plus
    mask_data = masking_function(data)
    values = data[mask_data]
    mean, std = values.mean(), values.std()
    data -= mean
    data /= std
    return data


def mean_plus(data):
    return data > data.mean()


def resample_spacing(nifti, output_spacing, interpolation):
    output_spacing = tuple(output_spacing)
    temp_dir = Path(tempfile.gettempdir()) / '.deepgif'
    temp_dir.mkdir(exist_ok=True)
    temp_path = temp_dir / 'deepgif_resampled.nii'
    temp_path = str(temp_path)

    nifti.to_filename(temp_path)
    image = sitk.ReadImage(temp_path)

    output_spacing = np.array(output_spacing).astype(float)
    output_spacing = tuple(output_spacing)

    reference_spacing = np.array(image.GetSpacing())
    reference_size = np.array(image.GetSize())

    output_size = reference_spacing / output_spacing * reference_size
    output_size = np.round(output_size).astype(np.uint32)
    # tuple(output_size) does not work, see
    # https://github.com/Radiomics/pyradiomics/issues/204
    output_size = output_size.tolist()

    identity = sitk.Transform(3, sitk.sitkIdentity)

    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(interpolation)
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(image.GetOrigin())  # TODO: double-check that this is correct
    resample.SetOutputPixelType(image.GetPixelID())
    resample.SetOutputSpacing(output_spacing)
    resample.SetSize(output_size)
    resample.SetTransform(identity)
    resampled = resample.Execute(image)
    sitk.WriteImage(resampled, temp_path)
    nifti_resampled = nib.load(temp_path)
    return nifti_resampled


def resample_ras_1mm_iso(nifti, interpolation=None):
    if interpolation is None:
        interpolation = sitk.sitkLinear
    nii_ras = nib.as_closest_canonical(nifti)
    spacing = nii_ras.header.get_zooms()[:3]
    one_iso = 1, 1, 1
    if np.allclose(spacing, one_iso):
        return nii_ras
    nii_resampled = resample_spacing(
        nii_ras,
        output_spacing=one_iso,
        interpolation=interpolation,
    )
    return nii_resampled


def resample_to_reference(
        reference_path,
        floating_path,
        result_path,
        interpolation=None,
        default_value=0.0,
        ):
    if interpolation is None:
        interpolation = sitk.sitkNearestNeighbor
    reference = sitk.ReadImage(str(reference_path))
    floating = sitk.ReadImage(str(floating_path))
    transform = sitk.Transform(3, sitk.sitkIdentity)
    resampled = sitk.Resample(
        floating,
        reference,
        transform,
        interpolation,
        default_value,
        floating.GetPixelID(),
    )
    sitk.WriteImage(resampled, str(result_path))

# sampling.py

In [26]:
import numpy as np

from torch.utils.data import Dataset


class GridSampler(Dataset):
    """
    Adapted from NiftyNet
    """
    def __init__(self, data, window_size, border):
        self.array = data
        self.locations = self.grid_spatial_coordinates(
            self.array,
            window_size,
            border,
        )

    def __len__(self):
        return len(self.locations)

    def __getitem__(self, index):
        # Assume 3D
        location = self.locations[index]
        i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
        window = self.array[i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
        window = window[np.newaxis, ...]  # add channels dimension
        sample = dict(
            image=window,
            location=location,
        )
        return sample

    @staticmethod
    def _enumerate_step_points(starting, ending, win_size, step_size):
        starting = max(int(starting), 0)
        ending = max(int(ending), 0)
        win_size = max(int(win_size), 1)
        step_size = max(int(step_size), 1)
        if starting > ending:
            starting, ending = ending, starting
        sampling_point_set = []
        while (starting + win_size) <= ending:
            sampling_point_set.append(starting)
            starting = starting + step_size
        additional_last_point = ending - win_size
        sampling_point_set.append(max(additional_last_point, 0))
        sampling_point_set = np.unique(sampling_point_set).flatten()
        if len(sampling_point_set) == 2:
            sampling_point_set = np.append(
                sampling_point_set, np.round(np.mean(sampling_point_set)))
        _, uniq_idx = np.unique(sampling_point_set, return_index=True)
        return sampling_point_set[np.sort(uniq_idx)]

    @staticmethod
    def grid_spatial_coordinates(array, window_shape, border):
        shape = array.shape
        num_dims = len(shape)
        grid_size = [
            max(win_size - 2 * border, 0)
            for (win_size, border)
            in zip(window_shape, border)
        ]
        steps_along_each_dim = [
            GridSampler._enumerate_step_points(
                starting=0,
                ending=shape[i],
                win_size=window_shape[i],
                step_size=grid_size[i],
            )
            for i in range(num_dims)
        ]
        starting_coords = np.asanyarray(np.meshgrid(*steps_along_each_dim))
        starting_coords = starting_coords.reshape((num_dims, -1)).T
        n_locations = starting_coords.shape[0]
        # prepare the output coordinates matrix
        spatial_coords = np.zeros((n_locations, num_dims * 2), dtype=np.int32)
        spatial_coords[:, :num_dims] = starting_coords
        for idx in range(num_dims):
            spatial_coords[:, num_dims + idx] = (
                starting_coords[:, idx]
                + window_shape[idx]
            )
        max_coordinates = np.max(spatial_coords, axis=0)[num_dims:]
        assert np.all(max_coordinates <= shape[:num_dims]), \
            "window size greater than the spatial coordinates {} : {}".format(
                max_coordinates, shape)
        return spatial_coords


class GridAggregator:
    """
    Adapted from NiftyNet
    """
    def __init__(self, data, window_border):
        self.window_border = window_border
        self.output_array = np.full(
            data.shape,
            fill_value=0,
            dtype=np.uint16,
        )

    @staticmethod
    def crop_batch(windows, location, border=None):
        if not border:
            return windows, location
        location = location.astype(np.int)
        batch_shape = windows.shape
        spatial_shape = batch_shape[2:]  # ignore batch and channels dim
        num_dimensions = 3
        for idx in range(num_dimensions):
            location[:, idx] = location[:, idx] + border[idx]
            location[:, idx + 3] = location[:, idx + 3] - border[idx]
        if np.any(location < 0):
            return windows, location

        cropped_shape = np.max(location[:, 3:6] - location[:, 0:3], axis=0)
        diff = spatial_shape - cropped_shape
        left = np.floor(diff / 2).astype(np.int)
        i_ini, j_ini, k_ini = left
        i_fin, j_fin, k_fin = left + cropped_shape
        if np.any(left < 0):
            raise ValueError
        batch = windows[
            :,  # batch dimension
            :,  # channels dimension
            i_ini:i_fin,
            j_ini:j_fin,
            k_ini:k_fin,
        ]
        return batch, location

    def add_batch(self, windows, locations):
        windows = windows.cpu()
        location_init = np.copy(locations)
        init_ones = np.ones_like(windows)
        windows, _ = self.crop_batch(
            windows, location_init,
            self.window_border,
        )
        location_init = np.copy(locations)
        _, locations = self.crop_batch(
            init_ones,
            location_init,
            self.window_border,
        )
        for window, location in zip(windows, locations):
            window = window.squeeze()
            i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
            self.output_array[i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = window

# inference.py

In [27]:
import torch
import numpy as np
import nibabel as nib
from tqdm import tqdm
from torch.utils.data import DataLoader
# from .sampling import GridSampler, GridAggregator
# from .preprocessing import (
#     crop,
#     preprocess,
#     resample_ras_1mm_iso,
#     resample_to_reference,
#     mean_plus,
# )

def infer(
        input_path,
        output_path,
        batch_size,
        window_cropping,
        volume_padding,
        window_size,
        cuda_device,
        use_niftynet_hist_std=False,
        ):
    # Read image
    nii = nib.load(str(input_path))
    needs_resampling = check_header(nii)
    if needs_resampling:
        nii = resample_ras_1mm_iso(nii)
    data = nii.get_fdata()

    # Preprocessing
    hist_masking_function = mean_plus if use_niftynet_hist_std else None
    preprocessed = preprocess(
        data,
        volume_padding,
        hist_masking_function=hist_masking_function,
    )

    # Inference
    labels = run_inference(
        preprocessed,
        get_model(),
        window_size,
        window_border=window_cropping,
        batch_size=batch_size,
        cuda_device=cuda_device,
    )

    # Postprocessing
    if volume_padding:
        labels = crop(labels, volume_padding)
    nib.Nifti1Image(labels, nii.affine).to_filename(str(output_path))

    # Resample parcellation to original dimensions
    if needs_resampling:
        resample_to_reference(
            reference_path=input_path,
            floating_path=output_path,
            result_path=output_path,
        )


def run_inference(
        data,
        model,
        window_size,
        window_border=0,
        batch_size=2,
        cuda_device=0,
        ):
    success = False
    while not success:
        window_sizes = to_tuple(window_size)
        window_border = to_tuple(window_border)

        sampler = GridSampler(data, window_sizes, window_border)
        aggregator = GridAggregator(data, window_border)
        loader = DataLoader(sampler, batch_size=batch_size)

        device = get_device(cuda_device=cuda_device)
        model.to(device)
        model.eval()

        CHANNELS_DIMENSION = 1

        try:
            with torch.no_grad():
                for batch in tqdm(loader):
                    input_tensor = batch['image'].to(device)
                    locations = batch['location']
                    logits = model(input_tensor)
                    labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
                    outputs = labels
                    aggregator.add_batch(outputs, locations)
            success = True
        except RuntimeError as e:
            print(e)
            print('Window size', window_size, 'is too large.')
            window_size = int(window_size * 0.75)
            print('Trying with smaller window size', window_size)

    return aggregator.output_array


def check_header(nifti_image):
    orientation = ''.join(nib.aff2axcodes(nifti_image.affine))
    spacing = nifti_image.header.get_zooms()[:3]
    one_iso = 1, 1, 1
    is_ras = orientation == 'RAS'
    if not is_ras:
        print(f'Detected orientation: {orientation}. Reorienting to RAS...')
    is_1_iso = np.allclose(spacing, one_iso)
    if not is_1_iso:
        print(f'Detected spacing: {spacing}. Resampling to 1 mm iso...')
    needs_resampling = not is_ras or not is_1_iso
    return needs_resampling


def get_device(cuda_device=0):
    return torch.device(
        'cuda:{}'.format(cuda_device) if torch.cuda.is_available() else 'cpu')


def to_tuple(value):
    try:
        iter(value)
    except TypeError:
        value = 3 * (value,)
    return value


def get_model():
    """
    Using PyTorch Hub as I haven't been able to install the .pth file
    within the pip package
    """
    repo = 'fepegar/highresnet'
    model_name = 'highres3dnet'
    model = torch.hub.load(repo, model_name, pretrained=True)
    return model

# Load data

In [28]:
import tarfile
import tempfile
import urllib.request
from pathlib import Path
from configparser import ConfigParser
import os

def get_data_url_from_model_zoo():
    url = 'https://raw.githubusercontent.com/NifTK/NiftyNetModelZoo/5-reorganising-with-lfs/highres3dnet_brain_parcellation/main.ini'
    with urllib.request.urlopen(url) as response:
        config_string = response.read().decode()
    config = ConfigParser()
    config.read_string(config_string)
    data_url = config['data']['url']
    return data_url


def download_data(data_url):
    # tempdir = Path(tempfile.gettempdir())
    tempdir = Path(os.getcwd())
    download_dir = tempdir / 'downloaded_data'
    download_dir.mkdir(exist_ok=True)
    data_path = download_dir / Path(data_url).name
    print(data_path)
    if not data_path.is_file():
        urllib.request.urlretrieve(data_url, data_path)
    with tarfile.open(data_path, 'r') as tar:
        tar.extractall(download_dir)
    nifti_files = download_dir.glob('**/*.nii.gz')
    return list(nifti_files)[0]


def test_infer():
    image_path = download_data(get_data_url_from_model_zoo())

test_infer()

D:\workspace\dwm\jupyter_basic\ch10_monai\downloaded_data\data.tar.gz


# Load model

In [29]:
repo = 'fepegar/highresnet'
model_name = 'highres3dnet'
print(torch.hub.help(repo, model_name))
"HighRes3DNet by Li et al. 2017 for T1-MRI brain parcellation"
"pretrained (bool): load parameters from pretrained model"
model = torch.hub.load(repo, model_name, pretrained=True)


    HighRes3DNet by Li et al. 2017 for T1-MRI brain parcellation
    pretrained (bool): load parameters from pretrained model
    


Using cache found in C:\Users\Daewoon/.cache\torch\hub\fepegar_highresnet_master
Using cache found in C:\Users\Daewoon/.cache\torch\hub\fepegar_highresnet_master


# RUN

In [40]:
import sys
# import click
import pathlib

input_path111 = "_data/GAAIN/AD01_MR.nii"
input_path211 = "_data/OASIS/OAS1_MR.nii.gz"
input_path311 = "_data/SNU/P01_MR.nii"
input_path312 = "_data/SNU/P01_orig.nii"
input_path321 = "_data/SNU/P02_mr.nii"
input_path322 = "_data/SNU/P02_orig.nii"
input_path331 = "_data/SNU/P03_mr.nii"
input_path332 = "_data/SNU/P03_orig.nii"

input_path341 = "_data/SNU/P04_mr.nii"
input_path342 = "_data/SNU/P04_orig.nii.gz"
input_path361 = "_data/SNU/P06_mr.nii"
input_path362 = "_data/SNU/P06_orig.nii"

input_path = input_path342

output_path = None
if output_path is None:
    input_path = pathlib.Path(input_path)
    input_name = input_path.name
    output_name = input_name.replace('.nii', '_seg.nii')
    output_path = input_path.parent / output_name
infer(
    input_path,
    output_path,
    batch_size = 1,
    window_cropping = 2,
    volume_padding = 10,
    window_size = 128,
    cuda_device = 0,
    use_niftynet_hist_std= False,# hist_niftynet,
)

Detected orientation: LIA. Reorienting to RAS...


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Using cache found in C:\Users\Daewoon/.cache\torch\hub\fepegar_highresnet_master
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
100%|███████████████████████████████████████████████████████████████████████████████| 27/27 [00:16<00:00,  1.60it/s]


In [37]:
torch.cuda.is_available()

True