In [None]:
import os
import json
import warnings
import pathlib

import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
import imageio

import nibabel

import pymedphys

In [None]:
# Directory for the decathlon data and the hash file that is passed to
# pymedphys.data_path below.
# NOTE(review): pymedphys._config is a private module; confirm it is stable.
download_root = pymedphys._config.get_config_dir() / 'data' / 'medical-decathlon'
download_root.mkdir(parents=True, exist_ok=True)

hash_path = download_root / 'hashes.json'

In [None]:
# Root of the pymedphys/data GitHub release assets the decathlon files are fetched from.
base_download_url = "https://github.com/pymedphys/data/releases/download"

In [None]:
# All ten Medical Segmentation Decathlon tasks; only the first two are
# currently enabled in `tasks`.
all_tasks = [
    "Task01_BrainTumour",
    "Task02_Heart",
    "Task03_Liver",
    "Task04_Hippocampus",
    "Task05_Prostate",
    "Task06_Lung",
    "Task07_Pancreas",
    "Task08_HepaticVessel",
    "Task09_Spleen",
    "Task10_Colon",
]
tasks = all_tasks[:2]

tasks

In [None]:
def get_metadata(task):
    """Load and parse the ``dataset.json`` file of a decathlon ``task``.

    The file is obtained through ``pymedphys.data_path`` from the
    ``<base_download_url>/<task>/dataset.json`` release asset, with
    ``hash_path`` given as the hash file.

    Returns
    -------
    The JSON-decoded contents of ``dataset.json``.
    """
    metadata_path = pymedphys.data_path(
        f"medical-decathlon/{task}/dataset.json",
        url=f"{base_download_url}/{task}/dataset.json",
        hash_filepath=hash_path,
    )

    with open(metadata_path) as metadata_file:
        return json.load(metadata_file)

In [None]:
def download_task_path(task, path):
    """Fetch one file of a decathlon ``task`` via ``pymedphys.data_path``.

    Parameters
    ----------
    task : str
        Task name, e.g. ``"Task06_Lung"``.
    path : str
        File path as listed in the task's ``dataset.json``, relative to the
        task root and optionally prefixed with ``"./"``.

    Returns
    -------
    Whatever ``pymedphys.data_path`` returns for the file (used as a
    ``pathlib.Path`` by the conversion loop).
    """
    # Strip a leading "./" only when present; blindly dropping two characters
    # would mangle paths that lack it.
    relative_path = path[2:] if path.startswith("./") else path

    # Release assets are stored flat, with '/' encoded as '--os.sep--'
    # (presumably because asset names cannot contain '/').
    url = f"{base_download_url}/{task}/{relative_path.replace('/', '--os.sep--')}"
    full_path = pathlib.Path(f"medical-decathlon/{task}").joinpath(relative_path)

    return pymedphys.data_path(
        full_path,
        url=url,
        hash_filepath=hash_path,
        delete_when_no_hash_found=False,
    )


def get_filepaths_for_task(task):
    """Yield ``(image_path, label_path)`` for every training case of ``task``.

    This is a generator: the metadata is read and each pair of files is
    fetched with ``download_task_path`` only as iteration advances.
    """
    training_cases = get_metadata(task)['training']
    for case in training_cases:
        yield (
            download_task_path(task, case['image']),
            download_task_path(task, case['label']),
        )

In [None]:
# Task converted by the cells below (chosen independently of `tasks`).
task = "Task06_Lung"

In [None]:
# One-shot generator of (image_path, label_path) pairs for `task`; it is
# exhausted after a single full iteration.
gen = get_filepaths_for_task(task=task)

In [None]:
def get_contours_from_mask(x_grid, y_grid, mask, contour_level=0.5):
    """Trace the iso-contours of ``mask`` at ``contour_level``.

    Parameters
    ----------
    x_grid, y_grid : array_like
        Coordinates of the mask's columns and rows, as accepted by
        ``matplotlib.axes.Axes.contour``.
    mask : array_like
        2-D array to contour.
    contour_level : float, optional
        The single level to trace (default 0.5).

    Returns
    -------
    list of numpy.ndarray
        One (N, 2) array of (x, y) vertices per contour line. The list is
        empty when ``contour_level`` does not lie strictly between the
        smallest and largest (non-NaN) values of ``mask``.
    """
    mask = np.asarray(mask)

    # Return early when no contour can exist. This skips creating a figure for
    # every empty slice, and it avoids some older matplotlib versions, which
    # replaced an out-of-range level with the data minimum (emitting only a
    # UserWarning) and so traced spurious contours.
    if mask.size == 0 or not (np.nanmin(mask) < contour_level < np.nanmax(mask)):
        return []

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        fig, ax = plt.subplots()
        try:
            contour_set = ax.contour(x_grid, y_grid, mask, [contour_level])
            # `allsegs[0]` holds the vertex arrays of the single requested
            # level. Use it rather than `ContourSet.collections`, which is
            # deprecated since matplotlib 3.8 and removed in 3.10.
            contours = [np.asarray(segment) for segment in contour_set.allsegs[0]]
        finally:
            # Release the figure even if contouring fails.
            plt.close(fig)

    return contours

In [None]:
# Intensity mapping used by hu_transform: shifted values are clipped at
# max_hu, then divided by hu_scale (4096 / 256 = 16) to fit 256 uint8 levels.
max_hu = 4095
# NOTE(review): min_hu is not referenced by the visible code.
min_hu = 0
hu_scale = (max_hu + 1) / 256

# Integer factor by which each slice is downscaled along both axes.
dimension_downscale = 4

In [None]:
def hu_transform(hu):
    """Map intensity values onto the 0-255 ``uint8`` range.

    The values are shifted up by 1024, clipped above at ``max_hu``, divided by
    ``hu_scale`` and truncated to ``uint8``. With ``max_hu = 4095`` and
    ``hu_scale = 16``, -1024 maps to 0 and every value above 3071 maps to 255.

    Parameters
    ----------
    hu : array_like
        Intensities, presumably CT Hounsfield units (suggested by the +1024
        offset; confirm for non-CT tasks). Scalars are accepted too.

    Returns
    -------
    ``uint8`` values with the same shape as ``hu``.

    Raises
    ------
    ValueError
        If any value is below -1024, since it would map below zero.
    """
    shifted = np.asarray(hu) + 1024
    if np.any(shifted < 0):
        raise ValueError(
            "hu_transform expects values >= -1024, but the minimum is "
            f"{np.min(shifted) - 1024}"
        )

    # Clip the top end so the division below stays within the uint8 range.
    shifted = np.minimum(shifted, max_hu)

    return (shifted / hu_scale).astype(np.uint8)

In [None]:
# Convert each training case of `task` into per-slice PNGs under
# data/<patient>/: `<slice>_image.png` (intensities mapped by hu_transform) and
# `<slice>_mask.png` (label scaled by 255). Both are downscaled by
# `dimension_downscale` along each axis.

# Set to True to plot every slice with its mask contours: first at full
# resolution, then downscaled with the full-resolution contours dashed. This
# checks the downscaling by eye. It draws two figures per slice, so keep it
# off for full runs.
show_debug_plots = False


def _plot_slice(x, y, image, contour_sets):
    """Show ``image`` on the (x, y) pixel grid with contours overlaid.

    ``contour_sets`` is a sequence of ``(contours, line_style)`` pairs. Each
    contour is an (N, 2) array of (x, y) vertices.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    mesh = ax.pcolormesh(x, y, image, shading='nearest')
    fig.colorbar(mesh, ax=ax)
    ax.axis('equal')
    for contours, line_style in contour_sets:
        for contour in contours:
            ax.plot(*contour.T, line_style)
    # Put row 0 at the top, as in the saved PNGs.
    ax.invert_yaxis()
    plt.show()
    plt.close(fig)


# Request a fresh generator: `gen` from the earlier cell is exhausted after one
# pass, so re-running this cell with it would silently process nothing.
for image_path, label_path in get_filepaths_for_task(task):
    image = nibabel.load(image_path).get_fdata()
    label = nibabel.load(label_path).get_fdata()

    # File name up to its first '.', used as the per-patient folder name.
    patient_identifier = image_path.name.split('.')[0]
    patient_directory = f"data/{patient_identifier}"
    # Unlike os.mkdir, makedirs also creates the parent `data/` directory when
    # it does not exist yet.
    os.makedirs(patient_directory, exist_ok=True)

    for i in range(image.shape[2]):
        filename_base = f"{patient_directory}/{i:06d}"

        # Reverse the volume's second axis and transpose, giving a slice of
        # shape (image.shape[1], image.shape[0]).
        current_image = image[:, ::-1, i].T
        current_mask = label[:, ::-1, i].T

        downscale_factors = (dimension_downscale, dimension_downscale)
        shrunk_image = transform.downscale_local_mean(current_image, downscale_factors)
        shrunk_mask = transform.downscale_local_mean(current_mask, downscale_factors)

        # hu_transform clips at max_hu itself, so no clipping is needed here.
        hu_scaled_to_uint8 = hu_transform(shrunk_image)
        # NOTE(review): assumes label values lie in [0, 1]. Larger labels
        # (multi-class tasks) would not fit in uint8 after scaling by 255.
        masks_scaled_to_uint8 = (shrunk_mask * 255).astype(np.uint8)

        imageio.imwrite(f'{filename_base}_image.png', hu_scaled_to_uint8)
        imageio.imwrite(f'{filename_base}_mask.png', masks_scaled_to_uint8)

        if show_debug_plots:
            # Pixel-index coordinates taken from the actual slice shapes. Each
            # downscaled pixel k sits at full-resolution index
            # k * dimension_downscale + dimension_downscale // 2.
            x = np.arange(current_image.shape[1])
            y = np.arange(current_image.shape[0])
            shrunk_x = (
                np.arange(shrunk_image.shape[1]) * dimension_downscale
                + dimension_downscale // 2
            )
            shrunk_y = (
                np.arange(shrunk_image.shape[0]) * dimension_downscale
                + dimension_downscale // 2
            )

            contours = get_contours_from_mask(x, y, current_mask)
            shrunk_contours = get_contours_from_mask(
                shrunk_x, shrunk_y, masks_scaled_to_uint8, contour_level=128)

            _plot_slice(x, y, current_image, [(contours, 'k')])
            _plot_slice(
                shrunk_x, shrunk_y, hu_scaled_to_uint8,
                [(shrunk_contours, 'k'), (contours, 'k--')],
            )