In [None]:
import json
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt

from IPython import display

In [None]:
import tensorflow as tf

In [None]:
# Reload edited modules automatically before each cell is executed, so that
# any changes made to pymedphys are propagated into the notebook without
# needing a kernel restart.
# NOTE(review): `reload` is not referenced in the cells visible here --
# confirm it is still needed.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
!pip install "pymedphys>=0.31.0" --no-deps
!pip install "pydicom>=2"

In [None]:
import pymedphys
from pymedphys._experimental.autosegmentation import indexing, filtering, pipeline, mask

In [None]:
# UID of the single CT that the datasets in this notebook are built from.
ct_uid_to_use = "1.2.840.113704.1.111.3096.1537312918.112198"

In [None]:
# Structures for which masks are created. The order matters: it is the
# order in which the masks are produced.
structures_to_learn = [
    "lens_left",
    "lens_right",
    "eye_left",
    "eye_right",
    "patient",
]

# Slice-selection criteria for the training, validation and testing data.
filters = dict(
    # A study set is only usable when it contains every structure above.
    study_set_must_have_all_of=structures_to_learn,
    # A slice needs at least one eye or lens contour on it ...
    slice_at_least_one_of=["lens_left", "lens_right", "eye_left", "eye_right"],
    # ... and must always include the patient outline.
    slice_must_have=["patient"],
    # No structure currently disqualifies a slice.
    slice_cannot_have=[],
)

In [None]:
# Both datasets are built from the same single CT and the same structure
# list; the only difference between them is the ``expansion`` argument.

# expansion=1: used for the binary-mask contour plots.
binary_mask_dataset = pipeline.create_dataset(
    [ct_uid_to_use],
    structures_to_learn,
    expansion=1,
)

# expansion=5: used for the floating-point-edge contour plots.
floating_point_edge_dataset = pipeline.create_dataset(
    [ct_uid_to_use],
    structures_to_learn,
    expansion=5,
)

In [None]:
def _structure_line_style(structure):
    """Return the matplotlib format string used to outline ``structure``.

    Names ending in ``left`` are drawn in red and names ending in ``right``
    in blue; names starting with ``lens`` are dashed and names starting
    with ``eye`` are solid.

    Raises
    ------
    ValueError
        If the name does not end in ``left``/``right`` or does not start
        with ``lens``/``eye``.
    """
    if structure.endswith('left'):
        colour = 'r'
    elif structure.endswith('right'):
        colour = 'b'
    else:
        raise ValueError("Expected either left or right")

    if structure.startswith('lens'):
        return colour + '--'
    if structure.startswith('eye'):
        return colour + '-'
    raise ValueError("Expected either eye or lens")


def diagnostic_plotting(x_grid, y_grid, input_array, output_array,
                        vmin=900, vmax=1200):
    """Plot one image slice with its structure contours overlaid.

    A new 15x10 figure is created. Contours are extracted from each channel
    of ``output_array`` with ``mask.get_contours_from_mask``. Channel ``i``
    is drawn in the style of ``structures_to_learn[i]`` for every structure
    except the last; the last channel is drawn as a black dashed line. The
    first channel of ``input_array`` is drawn as a filled contour plot with
    a colour bar.

    The view is zoomed to the extent of the structures other than the last
    one. If none of those produced a contour, the view is left at the
    automatic limits instead of being cropped to the limits of an empty
    axes.

    Parameters
    ----------
    x_grid, y_grid
        Objects exposing ``.numpy()``; the resulting arrays are used as the
        plot coordinates and are passed to ``mask.get_contours_from_mask``.
    input_array
        Object exposing ``.numpy()`` that yields an array indexable as
        ``[:, :, 0]``; only that first channel is displayed.
    output_array
        Object exposing ``.numpy()`` that yields an array indexable as
        ``[:, :, i]``, one channel per entry of ``structures_to_learn``.
    vmin, vmax
        Lower and upper bound of the display window. Image values outside
        ``[vmin, vmax]`` are clipped to the nearest bound before plotting.

    Raises
    ------
    ValueError
        If a structure name (other than the last) cannot be mapped to a
        line style.
    """
    fig, ax = plt.subplots(figsize=(15, 10))

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()
    input_array = input_array.numpy()[:, :, 0]
    output_array = output_array.numpy()

    # Draw every structure except the last one first, so that the view can
    # be zoomed to just these structures.
    drew_structure = False
    for i, structure in enumerate(structures_to_learn[0:-1]):
        line_style = _structure_line_style(structure)

        contours = mask.get_contours_from_mask(
            x_grid, y_grid, output_array[:, :, i])
        for contour in contours:
            ax.plot(*contour.T, line_style)
            drew_structure = True

    ax.axis('equal')
    # Record the limits now, before the last structure and the image are
    # added and enlarge the automatic limits.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    contours = mask.get_contours_from_mask(
        x_grid, y_grid, output_array[:, :, -1])
    for contour in contours:
        ax.plot(*contour.T, 'k--')

    # np.clip returns a new array, so the caller's data is not modified.
    windowed = np.clip(input_array, vmin, vmax)

    filled = ax.contourf(x_grid, y_grid, windowed, 50)
    fig.colorbar(filled, ax=ax)

    # Only zoom when there is something to zoom to; limits taken from an
    # empty axes are meaningless and would crop the image out of view.
    if drew_structure:
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

## Contours produced from binary masks

In [None]:
# Render one heading and one diagnostic figure per sample of the dataset.
for ct_uid, x_grid, y_grid, input_array, output_array in binary_mask_dataset:
    # `.numpy()` yields bytes here, which are decoded to text for the heading.
    ct_uid = ct_uid.numpy().decode()
    heading = display.Markdown(f"## {ct_uid}")
    display.display(heading)

    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()

## Contours produced from masks whose edge pixels can take floating-point values

In [None]:
# Render one heading and one diagnostic figure per sample of the dataset.
for ct_uid, x_grid, y_grid, input_array, output_array in floating_point_edge_dataset:
    # `.numpy()` yields bytes here, which are decoded to text for the heading.
    ct_uid = ct_uid.numpy().decode()
    heading = display.Markdown(f"## {ct_uid}")
    display.display(heading)

    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()