In [None]:
import pathlib
import json
import shutil
import random
import zipfile

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from IPython import display

import pydicom

In [None]:
# Automatically propagate any changes made to pymedphys into the notebook
# without needing a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
from pymedphys.labs.autosegmentation import pipeline, filtering, indexing, mask, tfrecord
from pymedphys._data import zenodo, download

In [None]:
# Retrieve every file in the Zenodo record via pymedphys, then keep only the
# per-patient ``.npz`` archives.
record_name = "auto-segmentation-eye-lens-patient-npz"
paths = pymedphys.zenodo_data_paths(record_name)
npz_paths = list(filter(lambda candidate: candidate.suffix == ".npz", paths))

In [None]:
# Re-pack each .npz into a same-named, LZMA-compressed .zip beside it
# (mode 'w' overwrites any existing archive).
for npz_file in npz_paths:
    zip_target = npz_file.with_suffix('.zip')
    with zipfile.ZipFile(zip_target, mode='w', compression=zipfile.ZIP_LZMA) as archive:
        archive.write(npz_file, arcname=npz_file.name)

In [None]:
# Local directory of cached .npz files, and the glob selecting the files whose
# name contains the given UID fragment.
NPZ_CACHE_DIR = pathlib.Path.home() / '.data' / 'dicom-ct-and-structures' / 'npz_cache'
NPZ_CACHE_GLOB = '*1.2.840.113704.1.111.3880*.npz'

# Sort so the listing (and anything built from it) does not depend on the
# filesystem's enumeration order.
npz_paths = sorted(NPZ_CACHE_DIR.glob(NPZ_CACHE_GLOB))
npz_paths

In [None]:
# Write each cached .npz into its own LZMA-compressed .zip next to it,
# overwriting any existing archive of the same name.
for path in npz_paths:
    with zipfile.ZipFile(
        file=path.with_suffix('.zip'), mode='w', compression=zipfile.ZIP_LZMA
    ) as zf:
        zf.write(filename=path, arcname=path.name)

In [None]:
def single_dataset_from_zenodo_download(
    record_name, ct_uids, structures_to_learn
):
    """Build a ``tf.data.Dataset`` that yields one example per CT UID.

    Each ``{ct_uid}.npz`` is looked for in the local download directory for
    ``record_name``; if missing, it is requested from Zenodo via
    ``pymedphys.zenodo_data_paths`` the first time the dataset is iterated.

    Parameters
    ----------
    record_name : str
        Zenodo record name; also the sub-directory of the pymedphys data
        directory in which the ``.npz`` files are expected.
    ct_uids : iterable of str
        CT UIDs; each names an ``{ct_uid}.npz`` file. Examples are yielded in
        this order (duplicate UIDs collapse to one entry).
    structures_to_learn : sequence
        Only its length is used here, to fix the last dimension of
        ``output_array``.

    Returns
    -------
    tf.data.Dataset
        Elements ``(ct_uid, x_grid, y_grid, input_array, output_array)`` with
        dtypes ``(string, float64, float64, int32, float64)`` and shapes
        ``((), (512,), (512,), (512, 512, 1),
        (512, 512, len(structures_to_learn)))``.

    Notes
    -----
    The generator raises ``ValueError`` if a downloaded file does not land
    at the predicted path; how TensorFlow surfaces that to the consumer is
    not visible here.
    """
    npz_download_directory = download.get_data_dir().joinpath(record_name)
    npz_paths = {
        ct_uid: npz_download_directory.joinpath(f'{ct_uid}.npz')
        for ct_uid in ct_uids
    }

    def generator():
        for ct_uid, npz_path in npz_paths.items():
            if not npz_path.exists():
                downloaded_path = pymedphys.zenodo_data_paths(
                    record_name,
                    filenames=[npz_path.name]
                )[0]

                if downloaded_path != npz_path:
                    raise ValueError("Expected the downloaded path to match the predicted npz_path")

            # ``np.load`` on an .npz returns an NpzFile that keeps the file
            # handle open until closed; the context manager releases it on
            # every step. Indexing reads each member fully into memory, so
            # the arrays stay valid after the file is closed.
            with np.load(npz_path) as data:
                x_grid = data["x_grid"]
                y_grid = data["y_grid"]
                input_array = data["input_array"]
                output_array = data["output_array"]

            # Append a trailing length-1 axis to match the declared
            # (512, 512, 1) input shape.
            input_array = input_array[:, :, None]

            yield ct_uid, x_grid, y_grid, input_array, output_array

    parameters = (
        (tf.string, tf.float64, tf.float64, tf.int32, tf.float64),
        (
            tf.TensorShape(()),
            tf.TensorShape([512]),
            tf.TensorShape([512]),
            tf.TensorShape([512, 512, 1]),
            tf.TensorShape([512, 512, len(structures_to_learn)]),
        ),
    )

    dataset = tf.data.Dataset.from_generator(generator, *parameters)

    return dataset

In [None]:
def datasets_from_zenodo_download(
    record_name, seed=None
):
    """Create one dataset per training type from a Zenodo record.

    The configuration files ``ct_uids_by_training_type`` and
    ``structures_to_learn`` are fetched from the record and parsed as JSON.
    For every training type, its CT UIDs are shuffled and passed to
    ``single_dataset_from_zenodo_download``.

    Parameters
    ----------
    record_name : str
        Zenodo record name.
    seed : int or None, optional
        Seed for the CT UID shuffle. ``None`` (the default) shuffles with the
        module-level ``random`` functions, so an earlier ``random.seed`` call
        still applies. Pass an int for a reproducible order.

    Returns
    -------
    tuple
        ``(datasets, structures_to_learn)``: ``datasets`` maps each training
        type name to its dataset; ``structures_to_learn`` is the parsed
        configuration value.

    Raises
    ------
    KeyError
        If either configuration file is not among the returned paths.
    """
    filenames_to_download = ['ct_uids_by_training_type.zip', 'structures_to_learn.zip']

    configuration_paths = pymedphys.zenodo_data_paths(
        record_name,
        filenames=filenames_to_download)

    # NOTE(review): the requested filenames end in .zip, yet each returned
    # path is read as JSON and keyed by its stem. This presumably relies on
    # zenodo_data_paths extracting the archives and returning the JSON files
    # inside them. Confirm.
    configurations = {}
    for path in configuration_paths:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            configurations[path.stem] = json.load(f)

    ct_uids_by_training_type = configurations['ct_uids_by_training_type']
    structures_to_learn = configurations['structures_to_learn']

    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle

    datasets = {}
    for training_type, ct_uids in ct_uids_by_training_type.items():
        # Shuffle a copy so the loaded configuration is left untouched.
        shuffled_ct_uids = list(ct_uids)
        shuffle(shuffled_ct_uids)

        datasets[training_type] = single_dataset_from_zenodo_download(
            record_name, shuffled_ct_uids, structures_to_learn
        )

    return datasets, structures_to_learn

In [None]:
# Build the per-training-type datasets, and the list of structures they
# learn, from the Zenodo record.
record_name = "auto-segmentation-eye-lens-patient-npz"
datasets, structures_to_learn = datasets_from_zenodo_download(record_name=record_name)

In [None]:
# Display the dict of datasets, keyed by training type.
datasets

In [None]:
def contour_line_style(structure):
    """Return the matplotlib format string used to draw ``structure``.

    The side sets the colour (name ending in ``'left'`` -> red ``'r'``,
    ``'right'`` -> blue ``'b'``). The structure type sets the line style
    (name starting with ``'lens'`` -> dashed ``'--'``, ``'eye'`` -> solid
    ``'-'``).

    Raises
    ------
    ValueError
        If the name does not end in left/right or start with eye/lens.
    """
    if structure.endswith('left'):
        colour = 'r'
    elif structure.endswith('right'):
        colour = 'b'
    else:
        raise ValueError("Expected either left or right")

    if structure.startswith('lens'):
        return colour + '--'
    if structure.startswith('eye'):
        return colour + '-'
    raise ValueError("Expected either eye or lens")


def diagnostic_plotting(
    x_grid, y_grid, input_array, output_array,
    structures=None, vmin=900, vmax=1200,
):
    """Plot one example: the windowed image with its structure contours.

    Parameters
    ----------
    x_grid, y_grid : tensor
        Grid coordinates; converted with ``.numpy()``.
    input_array : tensor
        Image with a trailing channel axis; channel 0 is plotted.
    output_array : tensor
        Per-structure masks stacked on the last axis, in the same order as
        ``structures``.
    structures : sequence of str, optional
        Structure names. All but the last must be eye/lens names (see
        ``contour_line_style``); the last one is drawn as a black dashed
        line. Defaults to the notebook-level ``structures_to_learn``.
    vmin, vmax : number, optional
        Image intensities are clipped to ``[vmin, vmax]`` before plotting.

    Raises
    ------
    ValueError
        If a non-final structure name is not an eye/lens left/right name.
    """
    if structures is None:
        # Keep existing calls working by falling back to the notebook global.
        structures = structures_to_learn

    plt.figure(figsize=(15, 10))

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()
    input_array = input_array.numpy()[:, :, 0]
    output_array = output_array.numpy()

    for i, structure in enumerate(structures[:-1]):
        line_style = contour_line_style(structure)
        contours = mask.get_contours_from_mask(
            x_grid, y_grid, output_array[:, :, i])
        for contour in contours:
            plt.plot(*contour.T, line_style)

    # Capture the view fitted to the eye/lens contours, so that drawing the
    # final structure and the image below doesn't change the zoom.
    plt.axis('equal')
    ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    contours = mask.get_contours_from_mask(
        x_grid, y_grid, output_array[:, :, -1])
    for contour in contours:
        plt.plot(*contour.T, 'k--')

    # Clamp intensities to the display window so contrast is spent there.
    windowed = np.clip(input_array, vmin, vmax)

    plt.contourf(x_grid, y_grid, windowed, 50)
    plt.colorbar()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

In [None]:
# Plot every training example under a heading with its CT UID.
training_examples = datasets["training"].prefetch(30)
for ct_uid, x_grid, y_grid, input_array, output_array in training_examples:
    ct_uid = ct_uid.numpy().decode()
    display.display(display.Markdown(f"## {ct_uid}"))
    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()

In [None]:
# Plot every validation example under a heading with its CT UID.
for element in datasets["validation"].prefetch(30):
    ct_uid, *tensors = element
    ct_uid = ct_uid.numpy().decode()
    heading = display.Markdown(f"## {ct_uid}")
    display.display(heading)
    diagnostic_plotting(*tensors)
    plt.show()

In [None]:
# Plot every testing example under a heading with its CT UID.
for uid_tensor, *plot_inputs in datasets["testing"].prefetch(30):
    ct_uid = uid_tensor.numpy().decode()
    display.display(display.Markdown(f"## {ct_uid}"))
    diagnostic_plotting(*plot_inputs)
    plt.show()