In [None]:
import zipfile
from urllib import request
import pathlib
import collections
import warnings

import numpy as np
import matplotlib.pyplot as plt
import imageio

import ipywidgets

In [None]:
# url = 'https://github.com/pymedphys/data/releases/download/mini-lung/mini-lung-medical-decathlon.zip'
# filename = url.split('/')[-1]

In [None]:
# request.urlretrieve(url, filename)

In [None]:
# Dataset root, resolved relative to the notebook's working directory.
data_path = pathlib.Path('.') / 'data'

In [None]:
# with zipfile.ZipFile(filename, 'r') as zip_ref:
#     zip_ref.extractall(data_path)

In [None]:
# Every `<name>_image.png` slice is paired with the `<name>_mask.png` file that
# sits in the same directory.
IMAGE_SUFFIX = '_image.png'
MASK_SUFFIX = '_mask.png'

image_paths = sorted(data_path.glob(f'**/*{IMAGE_SUFFIX}'))

# The download/extract cells above are commented out, so on a fresh checkout
# `data_path` may be empty. Without this check that only surfaces much later as
# an IndexError inside the interactive viewer.
if not image_paths:
    raise FileNotFoundError(
        f"No '*{IMAGE_SUFFIX}' files found under {data_path.resolve()}; "
        "download and extract the dataset first (see the commented-out cells above)."
    )

# Swap only the trailing suffix. The glob guarantees every name ends with
# IMAGE_SUFFIX, whereas str.replace would also rewrite any earlier occurrence.
mask_paths = [
    image_path.with_name(image_path.name[:-len(IMAGE_SUFFIX)] + MASK_SUFFIX)
    for image_path in image_paths
]

In [None]:
# Group (image, mask) array pairs by patient. The patient label is the name of
# the slice's parent directory. Within a patient, slices keep the order of
# `image_paths`.
# NOTE(review): `image_paths` is sorted lexicographically, so numbered slice
# names only come out in numeric order if they are zero-padded. Confirm.
image_mask_pairs = collections.defaultdict(list)

# On newer imageio releases the top-level `imageio.imread` is a v2
# compatibility shim that warns about a future switch to v3 behaviour. Use the
# explicit v2 entry point when it exists; older imageio falls back to the
# top-level function.
_imread = getattr(imageio, 'v2', imageio).imread

for image_path, mask_path in zip(image_paths, mask_paths):
    patient_label = image_path.parent.name
    image_mask_pairs[patient_label].append(
        (_imread(image_path), _imread(mask_path))
    )

In [None]:
def get_contours_from_mask(mask, contour_level=128):
    """Trace the outline of ``mask`` at the threshold ``contour_level``.

    Parameters
    ----------
    mask : array-like, 2-D
        Mask image. Rows are the y axis and columns are the x axis, matching
        how ``plt.imshow`` displays it.
    contour_level : float, optional
        Value at which the iso-contour is drawn. Defaults to 128.
        NOTE(review): presumably the midpoint of 8-bit (0-255) masks. Confirm.

    Returns
    -------
    list of numpy.ndarray
        One ``(N, 2)`` array of ``(x, y)`` vertices per connected contour
        line, in pixel coordinates. Empty when the mask is empty or never
        reaches ``contour_level``.
    """
    mask = np.asarray(mask)
    # np.max raises on an empty array; there is nothing to contour anyway.
    if mask.size == 0 or np.max(mask) < contour_level:
        return []

    # matplotlib is used purely as a contouring engine. Draw on a throwaway
    # figure and always close it, even if contouring fails, so figures do not
    # accumulate across calls.
    fig, ax = plt.subplots()
    try:
        with warnings.catch_warnings():
            # Suppress UserWarnings raised while contouring. This is presumably
            # matplotlib's "no contour levels found" warning, emitted when the
            # mask never crosses the level.
            warnings.simplefilter("ignore", UserWarning)
            # X spans columns (shape[1]) and Y spans rows (shape[0]). Passing
            # shape[0] for both breaks on non-square masks.
            cs = ax.contour(
                np.arange(mask.shape[1]),
                np.arange(mask.shape[0]),
                mask,
                [contour_level],
            )
        # allsegs[0] holds one vertex array per connected line at the single
        # requested level. It avoids ContourSet.collections, which was
        # deprecated in matplotlib 3.8.
        contours = list(cs.allsegs[0])
    finally:
        plt.close(fig)

    return contours

In [None]:
def display(patient_label, chosen_slice, vmin=0, vmax=100):
    """Plot one slice of a patient's image with its mask outline in red.

    Parameters
    ----------
    patient_label : str
        Key into ``image_mask_pairs``.
    chosen_slice : int
        Index of the slice within that patient's list of pairs.
    vmin, vmax : float, optional
        Intensity window passed to ``imshow``. The defaults keep the
        previously hard-coded 0-100 window.

    NOTE(review): this name shadows IPython's ``display`` helper in the
    notebook namespace; consider renaming once callers are updated.
    """
    image, mask = image_mask_pairs[patient_label][chosen_slice]

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image, vmin=vmin, vmax=vmax)

    for contour in get_contours_from_mask(mask):
        # contour.T splits the (N, 2) vertex array into its x and y columns.
        ax.plot(*contour.T, 'r', lw=3)

    ax.set_title(f'{patient_label}, slice {chosen_slice}')

In [None]:
def view_patient(patient_label):
    """Show an interactive slider over one patient's slices.

    The slider covers indices ``0 .. n - 1``, where ``n`` is the number of
    (image, mask) pairs loaded for ``patient_label``. Moving it re-renders the
    chosen slice via ``display``.
    """
    # .get avoids silently inserting an empty entry into the defaultdict when
    # the label is unknown.
    number_of_slices = len(image_mask_pairs.get(patient_label, []))
    if number_of_slices == 0:
        print(f'No slices loaded for patient {patient_label!r}.')
        return

    def view_slice(chosen_slice):
        display(patient_label, chosen_slice)

    # IntSlider's `max` is inclusive, so the last valid index is n - 1.
    # Using n let the slider select a non-existent slice (IndexError).
    ipywidgets.interact(
        view_slice,
        chosen_slice=ipywidgets.IntSlider(
            min=0, max=number_of_slices - 1, step=1, value=0
        ),
    )

In [None]:
# Dropdown with one entry per loaded patient (in load order). Picking one
# calls view_patient, which builds that patient's slice slider.
ipywidgets.interact(view_patient, patient_label=list(image_mask_pairs));