In [None]:
%matplotlib inline

In [None]:
import time
import os
import sys
from collections import defaultdict
import json

import xml.etree.ElementTree as ET
from xml.dom import minidom

import numpy as np
import pandas as pd

import openslide
from PIL import Image
from PIL import ImageDraw

# Convert xml input to a region_id dict
## and display for developer sanity check

In [None]:
"""
from collections import defaultdict
import json

from xml.dom import minidom

import openslide

from PIL import Image
from PIL import ImageDraw
"""
def get_labels_dict(class_labels_id_file_name):
    """ labels_dict = get_labels_dict(class_labels_id_file_name)

    Read a comma-separated class-label file into a dict keyed by label text.

    Args:
        class_labels_id_file_name:  .csv file with columns [Label, ID, Priority]
    Returns:
        labels_dict:                defaultdict of dicts:
                                        {label_text: {'ID': x, 'Priority': y}, ...}
                                    Values are kept as the strings read from the file.
                                    Empty if the file could not be read.
    """
    # allocate dict of dicts (defaultdict kept so callers see the same type as before)
    labels_dict = defaultdict(dict)

    # read the file -- best effort: report the failure and return the empty dict
    try:
        with open(class_labels_id_file_name, 'r') as fh:
            lines = fh.readlines()
    except (OSError, UnicodeError):
        print('failed opening: ', class_labels_id_file_name)
        return labels_dict

    # read the lines into the dict, skipping the header row and rows without an ID
    for line in lines:
        fields = line.strip().split(',')
        if len(fields) < 2 or fields[0] == 'Label':
            continue

        label_id = fields[1]
        # Priority is the third column; a two-column row keeps the previous
        # behaviour of reusing the ID so such files still load.
        priority = fields[2] if len(fields) > 2 else label_id
        labels_dict[fields[0]] = {'ID': label_id, 'Priority': priority}

    return labels_dict


def get_xml_list_of_dicts(xml_file_name):
    """ regions_list = get_xml_list_of_dicts(xml_file_name)

    Parse an annotation xml file and collect one dict per "Region" element.

    Args:
        xml_file_name:  xml file containing "Region" elements with the attributes
                            Id, Text, Type, GeoShape
                        and descendant "Vertex" elements with the attributes
                            X, Y
    Returns:
        regions_list:   list of dicts (one per Region, in document order) with keys:
                            region_Id           (Id attribute, as string)
                            class_label_text    (Text attribute)
                            class_label_Id      (Type attribute, as string)
                            region_geo_shape    (GeoShape attribute)
                            coords              (numpy float array, one [X, Y] row per Vertex)
    """
    dom_tree = minidom.parse(xml_file_name)

    regions_list = []
    for region_node in dom_tree.getElementsByTagName("Region"):
        vertex_nodes = region_node.getElementsByTagName("Vertex")

        # fill the (n_vertices, 2) array; numpy converts the attribute strings to float
        xy = np.zeros((len(vertex_nodes), 2))
        for row, vertex_node in enumerate(vertex_nodes):
            xy[row][0] = vertex_node.attributes['X'].value
            xy[row][1] = vertex_node.attributes['Y'].value

        regions_list.append({
            'region_Id': region_node.getAttribute('Id'),
            'class_label_text': region_node.getAttribute('Text'),
            'class_label_Id': region_node.getAttribute('Type'),
            'region_geo_shape': region_node.getAttribute('GeoShape'),
            'coords': xy,
        })

    return regions_list


def get_region_Id_dict(xml_file_name, class_labels_id_file_name):
    """ region_id_dict = get_region_Id_dict(xml_file_name, class_labels_id_file_name)

    Combine the regions of an annotation xml file with the priorities of a labels csv file.

    Args:
        xml_file_name:              xml file read with get_xml_list_of_dicts
        class_labels_id_file_name:  csv file read with get_labels_dict
    Returns:
        region_id_dict:             defaultdict keyed by int(region_Id), each value a dict with keys:
                                        class_label_text, class_label_Id, Priority (int),
                                        region_geo_shape, coords
    """
    # regions come from the xml file, priorities from the csv file
    parsed_regions = get_xml_list_of_dicts(xml_file_name)
    labels_lookup = get_labels_dict(class_labels_id_file_name)

    region_id_dict = defaultdict(dict)
    for region in parsed_regions:
        label_text = region['class_label_text']
        region_id_dict[int(region['region_Id'])] = {
            'class_label_text': label_text,
            'class_label_Id': region['class_label_Id'],
            # the Priority is looked up by the region's label text and cast to int
            'Priority': int(labels_lookup[label_text]['Priority']),
            'region_geo_shape': region['region_geo_shape'],
            'coords': region['coords'],
        }

    return region_id_dict

def get_region_mask(region_coords, thumbnail_divisor, image_dimensions):
    """ mask_im, img = get_region_mask(region_coords, thumbnail_divisor, image_dimensions)

    Draw one region polygon on a black thumbnail and return it as a boolean mask.

    Args:
        region_coords:      numpy array with one [x, y] row per polygon vertex,
                            in full-size image coordinates
        thumbnail_divisor:  number the coordinates and the image dimensions are divided by
        image_dimensions:   two-element size of the full-size image; element 0 is used as
                            the thumbnail width and element 1 as its height
    Returns:
        mask_im:            numpy boolean array, True where the polygon was drawn
        img:                the Pillow image the polygon was drawn on
    """
    # scale the region coords with the thumbnail_divisor, convert to a list of tuples of type int
    # (builtin int replaces the np.int alias, which no longer exists in current NumPy)
    scaled_xy = (region_coords / thumbnail_divisor).astype(int).tolist()
    xy_list = [(p[0], p[1]) for p in scaled_xy]

    # create a (black) thumbnail of the full-size image scaled with the thumbnail_divisor
    thumbnail_size = tuple((np.array(image_dimensions) / thumbnail_divisor).astype(int))
    # numpy arrays are indexed (row, column), so the two dimensions are swapped
    num_thumb_size = (thumbnail_size[1], thumbnail_size[0])
    img = Image.fromarray(np.zeros(num_thumb_size).astype(np.uint8))

    # make it a Pillow Draw and draw the polygon from the list of (x,y) tuples
    draw = ImageDraw.Draw(img)
    draw.polygon(xy_list, fill="white")

    # create the logical mask for patch selection
    mask_im = np.array(img) > 0

    return mask_im, img

In [None]:
# Test region Id dict function: build the dict and print a summary of every region
zip_tank = '../../DigiPath_MLTK_data/zipTank/wsi_annotation_sample/'
xml_name = os.path.join(zip_tank, 'e39a8d60a56844d695e9579bce8f0335.xml')
c_lab_id_fn = os.path.join(zip_tank, 'class_label_id.csv')

region_id_dict = get_region_Id_dict(xml_name, c_lab_id_fn)
for region_id, region_fields in region_id_dict.items():
    print('\n\tregion_Id', region_id)
    for field_name, field_value in region_fields.items():
        if field_name == 'Priority':
            # Priority is an int
            print('%20s: %i'%(field_name, field_value))
        elif field_name == 'coords':
            # coords is a 2-D array: show its size and types instead of the values
            size_text = '%20s: size = (%i, %i)'%(field_name, field_value.shape[0], field_value.shape[1])
            print(size_text, type(field_value), type(field_value[0,0]))
        else:
            print('%20s: %s'%(field_name, field_value))

In [None]:
# Test Labels Dict function: read the csv file and print one line per label
csv_file_name = 'class_label_id.csv'
c_lab_id_fn = os.path.join(zip_tank, csv_file_name)

lbl_dict = get_labels_dict(c_lab_id_fn)

print('content of csv file: %s as python dict\n'%(csv_file_name))
for label_text in lbl_dict:
    # right-aligned label text followed by its {'ID': ..., 'Priority': ...} dict
    print('%20s'%(label_text), lbl_dict[label_text])

In [None]:
# Test get_xml_list_of_dicts function -- view input file fields

# xml_name = os.path.join(zip_tank, '54742d6c5d704efa8f0814456453573a.xml')
xml_name = os.path.join(zip_tank, 'e39a8d60a56844d695e9579bce8f0335.xml')

lstod = get_xml_list_of_dicts(xml_name)

# the keys of the first region are the reference; an empty file gives no keys
keys_list = list(lstod[0].keys()) if lstod else []
print('\t'.join(keys_list), '\n')

for count, region in enumerate(lstod):
    for k in keys_list:
        if k not in region:
            # k is a string key, so it is formatted with %s (the %i used before
            # would raise TypeError instead of reporting the missing key)
            print('Crash Crash %s not in %i'%(k, count))
        elif k == 'coords':
            # one summary line per region: ids, label, number of vertices, coords type
            print(region['region_Id'],
                  region['class_label_text'],
                  region['class_label_Id'],
                  '\t\t', len(region[k]),
                  '\t', type(region[k]))
        elif k == 'region_geo_shape':
            print(region['region_geo_shape'])
    print()

In [None]:
# Test and display function get_region_mask()
im_dir = '../../DigiPath_MLTK_data/RegistrationDevData'
im_file = 'e39a8d60a56844d695e9579bce8f0335.tiff'
image_file_name = os.path.join(im_dir, im_file)

# only the slide dimensions are needed; close the slide even if reading them fails
os_obj = openslide.OpenSlide(image_file_name)
try:
    image_dimensions = os_obj.dimensions
finally:
    os_obj.close()

# region 8 of the dict built in the region Id dict test cell
region_coords = region_id_dict[8]['coords']
thumbnail_divisor = 200

# Test and display
mask_im, img = get_region_mask(region_coords, thumbnail_divisor, image_dimensions)
display(img)

In [None]:
# Test / demo summing masks

t0 = time.time()
n_regions = 8

# set the divisor before it is used, so the accumulator below and every
# region mask are sized with the same value
thumbnail_divisor = 200

# builtin int replaces the np.int alias, which no longer exists in current NumPy
thumbnail_size = tuple((np.array(image_dimensions) / thumbnail_divisor).astype(int))
# numpy arrays are indexed (row, column), so the two dimensions are swapped
num_thumb_size = (thumbnail_size[1], thumbnail_size[0])

uber_mask = np.zeros(num_thumb_size).astype(np.uint8)

for r_id in range(n_regions):
    # region ids in region_id_dict start at 1
    region_id = r_id + 1
    region_coords = region_id_dict[region_id]['coords']
    print('region_Id', region_id)
    mask_im, img = get_region_mask(region_coords, thumbnail_divisor, image_dimensions)
    # only the masks after the first three are added to the combined mask
    if r_id > 2:
        uber_mask += mask_im
    display(img)


# collapse overlap counts to 0 / 1, then scale to 0 / 255 for display
uber_mask[uber_mask > 0] = 1
uber_mask_im = Image.fromarray(uber_mask * 255)
print('Total time to get & display %i region_masks: %0.3f'%(n_regions, time.time() - t0))
print('\n\nsum of smaller masks')
display(uber_mask_im)