In [None]:
import keyring
import getpass
import functools
import itertools
import tempfile
import io

import numpy as np
import matplotlib.pyplot as plt

import imageio
import skimage.transform

import pydicom

import segments

In [None]:
# Make any changes to pymedphys propagate automatically into the notebook
# without needing a kernel restart.
# NOTE(review): `reload` does not appear to be used in the visible cells;
# it may be removable.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys._experimental.autosegmentation import pipeline, mask

In [None]:
# Scale factor for the label masks: the CT grid size is multiplied by it to
# size the uploaded mask, and it is passed to mask.calculate_expanded_mask.
EXPANSION = 4

In [None]:
def get_instance_id(name):
    """Return the pixel value used for structure ``name`` in the label mask.

    The Segments.ai category id is shifted up by one so that the value 0
    stays free for pixels that belong to no structure (the mask is
    zero-initialised before structures are painted onto it).
    """
    category_id = category_id_map[name]
    return 1 + category_id

In [None]:
# The Segments.ai API key is cached in the system keyring. On first use it is
# requested interactively and stored so later sessions are not prompted.
segments_api_key = keyring.get_password('segments-ai', 'api-key')

if not segments_api_key:
    # getpass's default prompt ("Password: ") does not say what is wanted,
    # and pasted keys often carry stray whitespace.
    segments_api_key = getpass.getpass('Segments.ai API key: ').strip()
    if not segments_api_key:
        raise ValueError('No Segments.ai API key was provided.')
    keyring.set_password('segments-ai', 'api-key', segments_api_key)

In [None]:
# Name of a dataset you've created on Segments.ai
dataset_name = 'SimonBiggs/AnimalContours'
client = segments.SegmentsClient(segments_api_key)

In [None]:
# Fetch the dataset description (its 'tasks' entry is read in the next cell)
# and display it.
dataset = client.get_dataset(dataset_name)
dataset

In [None]:
# Find the dataset's 'contouring' task and map each of its category names to
# the category id Segments.ai assigned to it.
contouring_task = next(
    (task for task in dataset['tasks'] if task['name'] == 'contouring'),
    None,
)
if contouring_task is None:
    raise ValueError(f"Dataset {dataset_name!r} has no 'contouring' task")

categories = contouring_task['attributes']['categories']
category_id_map = {
    category['name']: category['id']
    for category in categories
}
# Duplicate names would silently overwrite each other's ids.
assert len(category_id_map) == len(categories), "duplicate category names"

category_id_map

In [None]:
samples = client.get_samples(dataset_name)

# Index the uploaded samples by CT UID.
# NOTE(review): assumes each sample is named "<CT UID>.png" — confirm
# against the upload step.
ct_uid_to_upload_uuid = dict(
    (sample['name'].replace(".png", ""), sample['uuid'])
    for sample in samples
)

In [None]:
# Inspect the structure of a single sample record.
samples[0]

In [None]:
# NOTE(review): exploratory lookup of one hard-coded sample's 'contouring'
# label; the UUID only exists in this particular dataset. Consider removing
# before sharing.
client.get_label('b2b5e02a-e5a3-4cf0-a361-a2349207c930', 'contouring')

In [None]:
# NOTE(review): leftover API exploration — this only displays the
# `client.put` attribute without calling it; remove before sharing.
client.put

In [None]:
# Load the local DICOM dataset index from pymedphys. The names below are
# unpacked positionally, so their order must match the return order of
# pipeline.get_dataset_metadata().
(
    data_path_root,
    structure_set_paths,
    ct_image_paths,
    ct_uid_to_structure_uid,
    structure_uid_to_ct_uids,
    names_map,
    structure_names_by_ct_uid,
    structure_names_by_structure_set_uid,
    uid_to_url,
    hash_path,
) = pipeline.get_dataset_metadata()

In [None]:
def is_mask_a_subset(subset, superset):
    """Return True if every pixel set in ``subset`` is also set in ``superset``.

    Intended for boolean masks of the same shape.
    """
    return np.all(np.logical_and(subset, superset) == subset)

def cmp(x, y, masks):
    """Compare structures ``x`` and ``y`` by nesting of their masks.

    Returns -1 if ``x``'s mask lies inside ``y``'s, 1 if ``y``'s lies inside
    ``x``'s, and 0 if the masks are identical or do not overlap at all.

    Raises:
        ValueError: if the masks partially overlap (they share pixels but
            neither contains the other).
    """
    mask_x = masks[x]
    mask_y = masks[y]

    x_in_y = is_mask_a_subset(mask_x, mask_y)
    y_in_x = is_mask_a_subset(mask_y, mask_x)

    # Identical masks are subsets of each other; report them as equal so the
    # comparison stays antisymmetric.
    if x_in_y and y_in_x:
        return 0
    if x_in_y:
        return -1
    if y_in_x:
        return 1

    # Neither contains the other: only acceptable if they share no pixels.
    if np.any(np.logical_and(mask_x, mask_y)):
        raise ValueError(f"Masks ({x}, {y}) partially overlap")

    return 0

def create_sorting_key(masks):
    """Return a sort key over structure names in which supersets sort last.

    With ``sorted(..., reverse=True)`` every enclosing structure is placed
    before the structures it contains, so painting masks in that order
    leaves nested structures visible on top.

    The key is each mask's pixel count. A strict subset always has fewer set
    pixels than its superset, so this is a total order consistent with the
    nesting relation. (A comparator built directly from ``cmp`` is not: it
    returns 0 for disjoint masks, which makes the relation intransitive and
    lets ``sorted`` place a superset after a structure it contains.)

    Raises:
        ValueError: if any two masks partially overlap, since no painting
            order can then show both structures correctly.
    """
    names = list(masks)

    # Validate every pair up front instead of only the pairs a sort happens
    # to compare.
    for x, y in itertools.combinations(names, 2):
        cmp(x, y, masks)

    areas = {name: int(np.count_nonzero(masks[name])) for name in names}
    return areas.__getitem__

In [None]:
@functools.lru_cache()
def get_dcm_ct_from_uid(ct_uid):
    """Read (and memoise) the CT slice DICOM file for ``ct_uid``.

    The path is looked up in ``ct_image_paths`` and the file is read with
    ``force=True``; its transfer syntax is then set to Implicit VR Little
    Endian.

    The returned dataset is shared between calls through ``lru_cache``, so
    callers should treat it as read-only.
    """
    ct_path = ct_image_paths[ct_uid]
    # pydicom.read_file is deprecated (removed in pydicom 3); dcmread is the
    # supported equivalent.
    dcm_ct = pydicom.dcmread(ct_path, force=True)

    # NOTE(review): presumably the stored files lack a reliable file meta
    # header, hence the forced transfer syntax — confirm.
    dcm_ct.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian

    return dcm_ct

@functools.lru_cache()
def get_dcm_structure_from_uid(structure_set_uid):
    """Read (and memoise) the structure set DICOM file for ``structure_set_uid``.

    The path is looked up in ``structure_set_paths``. Only the
    ROIContourSequence and StructureSetROISequence elements are read
    (``specific_tags``), and ``force=True`` is used.

    The returned dataset is shared between calls through ``lru_cache``, so
    callers should treat it as read-only.
    """
    structure_set_path = structure_set_paths[structure_set_uid]

    # pydicom.read_file is deprecated (removed in pydicom 3); dcmread is the
    # supported equivalent and accepts the same keyword arguments.
    dcm_structure = pydicom.dcmread(
        structure_set_path,
        force=True,
        specific_tags=["ROIContourSequence", "StructureSetROISequence"],
    )

    return dcm_structure

@functools.lru_cache()
def get_contours_by_ct_uid_from_structure_uid(structure_set_uid):
    """Return the structure set's contours keyed by CT UID (memoised).

    ROI names are translated through ``names_map``; ROIs whose name maps to
    None are left out. The ROI-number-to-name mapping and the dataset are
    handed to ``pipeline.get_contours_by_ct_uid``.
    """
    structure_dataset = get_dcm_structure_from_uid(structure_set_uid)

    roi_number_to_name = {}
    for roi in structure_dataset.StructureSetROISequence:
        mapped_name = names_map[roi.ROIName]
        if mapped_name is None:
            # Names mapped to None are structures we do not want to label.
            continue
        roi_number_to_name[roi.ROINumber] = mapped_name

    return pipeline.get_contours_by_ct_uid(structure_dataset, roi_number_to_name)

In [None]:
# Pick the first (CT UID, sample UUID) pair for the interactive checks below.
ct_uid, sample_uuid = tuple(ct_uid_to_upload_uuid.items())[0]

In [None]:
# Show the existing 'contouring' label of the sample selected above.
client.get_label(sample_uuid, 'contouring')

In [None]:
# Display the UUID of the sample selected above.
sample_uuid

In [None]:
# NOTE(review): leftover IPython help lookup (`?`) from exploring the client
# API; remove before sharing the notebook.
client.put?

In [None]:
# Pre-label every sample that is not labelled yet: rasterise the structure
# contours of the sample's CT slice into one categorised mask, upload it as a
# PNG asset and attach it to the sample as its 'contouring' label.
for ct_uid, sample_uuid in ct_uid_to_upload_uuid.items():
    current_label_data = client.get_label(sample_uuid, 'contouring')
    print(ct_uid, sample_uuid)

    try:
        label_status = current_label_data['label_status']
    except KeyError:
        # NOTE(review): presumably an unlabelled sample has no
        # 'label_status' entry; treat it as needing a label.
        pass
    else:
        if label_status in ('PRELABELED', 'REVIEWED', 'LABELED'):
            print('Already labelled. Skipping...')
            continue
        print(label_status)

    structure_uid = ct_uid_to_structure_uid[ct_uid]

    # Presumably fetches both DICOM files locally; the readers below find
    # them through ct_image_paths / structure_set_paths, so the returned
    # paths are not needed here.
    pipeline.download_uid(data_path_root, ct_uid, uid_to_url, hash_path)
    pipeline.download_uid(data_path_root, structure_uid, uid_to_url, hash_path)

    dcm_ct = get_dcm_ct_from_uid(ct_uid)

    # NOTE(review): grid_x, grid_y and ct_img are not used below; this call
    # can likely be dropped — confirm create_input_ct_image has no side
    # effects (e.g. validation) this loop relies on.
    grid_x, grid_y, ct_img = pipeline.create_input_ct_image(dcm_ct)

    contours_by_ct_uid = get_contours_by_ct_uid_from_structure_uid(structure_uid)

    # The label mask is EXPANSION times the CT grid size in each dimension.
    _, _, ct_size = mask.get_grid(dcm_ct)
    ct_size = tuple(np.array(ct_size) * EXPANSION)

    try:
        contours_on_this_slice = contours_by_ct_uid[ct_uid].keys()
    except KeyError as e:
        print(e)
        print("Key Error in contours on slice. Skipping...")
        continue

    masks = {
        structure: mask.calculate_expanded_mask(
            contours_by_ct_uid[ct_uid][structure], dcm_ct, EXPANSION
        )
        for structure in contours_on_this_slice
    }

    # Paint enclosing structures first so nested structures stay visible.
    try:
        mask_assignment_order = sorted(
            contours_on_this_slice,
            key=create_sorting_key(masks),
            reverse=True,
        )
    except ValueError as e:
        print(e)
        print("Partially overlapping contours. Skipping...")
        continue

    objects_map = [
        {
            "id": get_instance_id(name),
            "category_id": category_id_map[name],
        }
        for name in contours_on_this_slice
    ]

    # Pixels covered by no structure keep the value 0.
    categorised_mask = np.zeros(ct_size, dtype=np.uint8)
    for structure_name in mask_assignment_order:
        categorised_mask[masks[structure_name]] = get_instance_id(structure_name)

    png_file = io.BytesIO()
    imageio.imsave(png_file, categorised_mask, format='PNG-PIL', prefer_uint8=True)
    # Writing advances the buffer's cursor; rewind so the upload reads the
    # PNG from the beginning whether it uses read() or getvalue().
    png_file.seek(0)

    sample_name = f"{ct_uid}_mask.png"
    asset = client.upload_asset(png_file, filename=sample_name)
    image_url = asset["url"]

    task_name = "contouring"
    attributes = {
        "segmentation_bitmap": {
            "url": image_url
        },
        "annotations": objects_map
    }

    client.add_label(sample_uuid, task_name, attributes, label_status='REVIEWED')