In [None]:
import json
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt

import imageio
from skimage import transform

from IPython import display

In [None]:
import tensorflow as tf

In [None]:
# Make any changes to pymedphys propagate into the notebook automatically,
# without needing a kernel restart.
# NOTE(review): `reload` is not used by any of the visible cells; the
# autoreload magics below are what provide the live-reload behaviour.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
from pymedphys._experimental.autosegmentation import indexing, filtering, pipeline, mask

In [None]:
# Expansion applied to each structure mask, forwarded as ``expansion=`` to
# ``pipeline.create_dataset``.
# NOTE(review): units (pixels vs mm) are not visible from here — confirm.
mask_expansion: int = 5

In [None]:
# Structures to build masks for; the mask channels follow this order.
structures_to_learn = [
    'eye_left',
    'eye_right',
    'patient',
]

# Selection criteria for the slices used for training, validation and
# testing. Note the first entry is the very same list object as
# `structures_to_learn`, not a copy.
filters = dict(
    study_set_must_have_all_of=structures_to_learn,
    slice_at_least_one_of=['lens_left', 'lens_right', 'eye_left', 'eye_right'],
    slice_must_have=['patient'],
    slice_cannot_have=[],
)

In [None]:
# Apply the selection criteria. Only `ct_uids` is consumed by the visible
# cells; `structure_uids` is kept for inspection.
(structure_uids, ct_uids) = pipeline.get_filtered_uids(
    filters
)

In [None]:
# Build the dataset over the filtered CT UIDs, requesting masks for
# `structures_to_learn` with `mask_expansion` passed as the expansion.
dataset = pipeline.create_dataset(
    ct_uids,
    structures_to_learn,
    expansion=mask_expansion,
)

In [None]:
def _structure_line_style(structure):
    """Return the matplotlib format string used to draw ``structure``.

    Left structures are red and right structures blue; lenses are drawn
    dashed and eyes solid.

    Raises:
        ValueError: if ``structure`` does not end in ``left``/``right`` (checked
            first) or does not start with ``eye``/``lens``.
    """
    if structure.endswith('left'):
        colour = 'r'
    elif structure.endswith('right'):
        colour = 'b'
    else:
        raise ValueError("Expected either left or right")

    if structure.startswith('lens'):
        return colour + '--'
    if structure.startswith('eye'):
        return colour + '-'
    raise ValueError("Expected either eye or lens")


def _plot_mask_contours(ax, x_grid, y_grid, mask_array, fmt):
    """Draw every contour of a single 2D mask channel onto ``ax``."""
    for contour in mask.get_contours_from_mask(x_grid, y_grid, mask_array):
        ax.plot(*contour.T, fmt)


def diagnostic_plotting(x_grid, y_grid, input_array, output_array, structures=None):
    """Plot one CT slice with the contours of its structure masks overlaid.

    Args:
        x_grid, y_grid: coordinates of the image columns and rows (assumed
            1D). TensorFlow eager tensors and NumPy arrays are both accepted.
        input_array: image array; only channel 0 of the last axis is drawn.
        output_array: mask array, expected to have one channel per entry of
            ``structures`` along the last axis.
        structures: structure names matching ``output_array``'s channels.
            Defaults to the notebook-level ``structures_to_learn``. Every entry
            except the last must be a left/right eye or lens; the last channel
            is drawn as a black dashed outline.

    The view is cropped to the extent of the plotted contours, and a colour
    bar is added for the image.

    Raises:
        ValueError: if a non-final structure name is not a left/right eye/lens.
    """
    if structures is None:
        structures = structures_to_learn

    # np.asarray accepts both TensorFlow eager tensors and plain NumPy arrays,
    # unlike calling ``.numpy()`` directly.
    x_grid = np.asarray(x_grid)
    y_grid = np.asarray(y_grid)
    image = np.asarray(input_array)[:, :, 0]
    masks = np.asarray(output_array)

    _, ax = plt.subplots(figsize=(15, 10))

    for i, structure in enumerate(structures[:-1]):
        _plot_mask_contours(
            ax, x_grid, y_grid, masks[:, :, i], _structure_line_style(structure))
    _plot_mask_contours(ax, x_grid, y_grid, masks[:, :, -1], 'k--')

    # Capture the contour-driven limits before drawing the image so the view
    # stays zoomed on the structures rather than expanding to the whole slice.
    ax.axis('equal')
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    mesh = ax.pcolormesh(x_grid, y_grid, image, shading="nearest")
    plt.colorbar(mesh, ax=ax)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

In [None]:
# Take a single example from the dataset for a visual sanity check. Unpacking
# it explicitly (instead of leaking `for`-loop variables) makes the later
# cells' dependence on these names obvious. It also raises a clear error when
# the dataset is empty, rather than letting later cells silently reuse stale
# values from a previous run.
try:
    ct_uid, x_grid, y_grid, input_array, output_array = next(iter(dataset.take(1)))
except StopIteration:
    raise ValueError(
        "The filtered dataset is empty; there is no slice to display."
    ) from None

ct_uid = ct_uid.numpy().decode()

display.display(display.Markdown(f"## {ct_uid}"))
diagnostic_plotting(x_grid, y_grid, input_array, output_array)
plt.show()

In [None]:
# Average non-overlapping 16x16 blocks over the two leading axes; the
# trailing channel axis (factor 1) is left untouched. The same factor of 16
# is assumed when building the coarse coordinate grids.
scaled_input_array, scaled_output_array = (
    tf.convert_to_tensor(transform.downscale_local_mean(array, (16, 16, 1)))
    for array in (input_array, output_array)
)

In [None]:
def block_centre_coordinates(grid, factor=16):
    """Return the coordinate at the centre of each ``factor``-sample block of ``grid``.

    Matches the block layout of ``skimage.transform.downscale_local_mean``:
    output sample ``k`` averages input samples ``factor*k`` through
    ``factor*k + factor - 1``, so its centre lies at fractional index
    ``factor*k + (factor - 1) / 2``. Plain ``grid[factor//2::factor]`` is
    shifted by half an input sample from this. It also has too few entries
    whenever ``len(grid) % factor`` is between 1 and ``factor // 2``.

    When ``len(grid)`` is not a multiple of ``factor`` the final block is
    partial (``downscale_local_mean`` pads it). Its centre is then linearly
    extrapolated from the last grid spacing, so the result always has
    ``ceil(len(grid) / factor)`` entries.

    ``factor`` must match the spatial factor used for the downscaling above.
    """
    grid = np.asarray(grid, dtype=float)
    if grid.size == 0:
        return grid.copy()

    n_blocks = -(-grid.size // factor)  # ceiling division
    centre_index = np.arange(n_blocks) * factor + (factor - 1) / 2
    sample_index = np.arange(grid.size)

    centres = np.interp(centre_index, sample_index, grid)

    # np.interp clamps outside the sampled range; extrapolate instead so a
    # partial final block gets a sensible centre.
    beyond = centre_index > sample_index[-1]
    if grid.size > 1 and beyond.any():
        step = grid[-1] - grid[-2]
        centres[beyond] = grid[-1] + (centre_index[beyond] - sample_index[-1]) * step
    return centres


new_x_grid = tf.convert_to_tensor(block_centre_coordinates(x_grid))
new_y_grid = tf.convert_to_tensor(block_centre_coordinates(y_grid))

In [None]:
# Re-plot the slice at the coarse (downscaled) resolution for comparison.
diagnostic_plotting(
    x_grid=new_x_grid,
    y_grid=new_y_grid,
    input_array=scaled_input_array,
    output_array=scaled_output_array,
)

In [None]:
# Intensity window mapped onto 8 bits for the PNG export.
# 4095 == 2**12 - 1; presumably the largest stored CT value — confirm.
min_hu, max_hu = 0, 4095

In [None]:
# Clamp the downscaled CT values into [min_hu, max_hu]. The original only
# clamped the top, leaving `min_hu` unused: values below it would wrap around
# when cast to uint8 later. np.clip also returns a fresh array, so re-running
# this cell is idempotent.
new_input_array = np.clip(scaled_input_array.numpy()[:, :, 0], min_hu, max_hu)

In [None]:
# Number of input units represented by each of the 256 uint8 grey levels,
# spanning the full [min_hu, max_hu] window. While min_hu is 0 this equals
# (max_hu + 1) / 256. Values must be offset by min_hu before dividing by it.
hu_scale = (max_hu - min_hu + 1) / 256
hu_scale

In [None]:
# Map [min_hu, max_hu] linearly onto 0-255 (truncating, as before). The
# explicit clip guarantees the cast cannot overflow or wrap, even if the input
# was not clamped to the window beforehand.
hu_scaled_to_uint8 = np.clip(
    (new_input_array - min_hu) / hu_scale, 0, 255
).astype(np.uint8)
hu_scaled_to_uint8

In [None]:
# Rescale the downscaled masks to 0-255 via (x + 1) / 2 * 255, truncating.
# NOTE(review): this mapping assumes mask values lie in [-1, 1] — confirm
# against pipeline.create_dataset.
masks_scaled_to_uint8 = (
    (scaled_output_array.numpy() + 1) / 2 * 255
).astype(np.uint8)

In [None]:
# Save the 8-bit CT slice and its masks as PNGs named after the CT UID.
# The paths are relative, so files land in the current working directory.
imageio.imwrite(
    f"{ct_uid}_image.png",
    hu_scaled_to_uint8,
)
imageio.imwrite(
    f"{ct_uid}_mask.png",
    masks_scaled_to_uint8,
)

In [None]:
# Summarise the downscaled CT input instead of dumping the whole tensor.
{
    "shape": tuple(scaled_input_array.shape),
    "dtype": scaled_input_array.dtype.name,
    "min": float(tf.reduce_min(scaled_input_array)),
    "max": float(tf.reduce_max(scaled_input_array)),
}