# Registration - get pairs of patches from wsi_1, wsi_2 and offset

In [None]:
import time
nb_start_time = time.time()

import os
import tempfile
import sys

from collections import OrderedDict
import argparse

import tensorflow as tf
from tensorflow import io as tf_io

import numpy as np
import pandas as pd
import yaml

from skimage.filters import threshold_otsu
from skimage.color import rgb2lab

from PIL import ImageDraw
from PIL import TiffImagePlugin as tip

import IPython.display as ip_display

import openslide

sys.path.insert(0, '../src/python')
from digipath_toolkit import *

def get_pair(wsi_obj_0, wsi_obj_1, x0, y0, x1, y1, patch_size=(224, 224), image_level=0):
    """ Usage:
    patch_0, patch_1 = get_pair(wsi_obj_0, wsi_obj_1, x0, y0, x1, y1, patch_size, image_level)
            read one patch from each of two slide objects

    Args:
        wsi_obj_0:      object with a read_region(location, level, size) method
        wsi_obj_1:      object with a read_region(location, level, size) method
        x0, y0:         location passed to wsi_obj_0.read_region as (x0, y0)
        x1, y1:         location passed to wsi_obj_1.read_region as (x1, y1)
        patch_size:     size passed unchanged to both read_region calls
        image_level:    level passed unchanged to both read_region calls

    Returns:
        patch_0, patch_1:   the two read_region results, wsi_obj_0 first
    """
    #                   wsi_obj_0 is read before wsi_obj_1, as a list preserves call order
    slides_and_locations = ((wsi_obj_0, (x0, y0)), (wsi_obj_1, (x1, y1)))
    patch_0, patch_1 = [slide.read_region(location, image_level, patch_size)
                        for slide, location in slides_and_locations]
    return patch_0, patch_1

def im_pair_hori(im_0, im_1):
    """ Usage: new_im = im_pair_hori(im_0, im_1)
            combine two PIL images horizontally, with a one pixel gap between them

    Args:
        im_0:       PIL image pasted with its upper left corner at (0, 0)
        im_1:       PIL image pasted with its upper left corner at (im_0 width + 1, 0)

    Returns:
        new_im:     new 'RGB' image; width = im_0 width + im_1 width + 1,
                    height = the larger of the two input heights
    """
    w0 = im_0.size[0]
    w = w0 + im_1.size[0] + 1
    h = max(im_0.size[1], im_1.size[1])

    new_im = tip.Image.new('RGB', (w, h))

    #                   paste by upper left corner only: a 4-tuple box has to match the size of
    #                   the pasted image, so boxes of height h failed for images of unequal height
    new_im.paste(im_0, (0, 0))
    new_im.paste(im_1, (w0 + 1, 0))

    return new_im

def get_patch_pairs_array(run_parameters):
    """ Usage: patch_pair_array = get_patch_pairs_array(run_parameters)
            pair each selected patch location of the fixed slide with the offset location
            in the float slide, keeping the pairs whose corners are inside both images

    Args:
        run_parameters:     dict with keys:
            wsi_fixed:          file name of the fixed slide (patch locations are found here)
            wsi_float:          file name of the float slide
            offset_x, offset_y: subtracted from the fixed location to get the float location
            (plus the keys read by get_patch_location_array_for_image_level_NB)

        Note: run_parameters['wsi_filename'] is overwritten with the fixed slide file name.

    Returns:
        patch_pairs_array:  list of tuples = [  (x0, y0, x1, y1),
                                                (x0, y0, x1, y1), ... , (x0, y0, x1, y1) ]
                            (x0, y0) is the fixed slide location, (x1, y1) the float slide location
    """
    fixed_wsi = run_parameters['wsi_fixed']
    float_wsi = run_parameters['wsi_float']

    offset_x = run_parameters['offset_x']
    offset_y = run_parameters['offset_y']

    #                   the patch locator reads the slide file name from 'wsi_filename'
    run_parameters['wsi_filename'] = fixed_wsi
    patch_location_array = get_patch_location_array_for_image_level_NB(run_parameters)

    #                   image bounds used to reject locations outside either slide
    fixed_levels_dict = get_level_sizes_dict(fixed_wsi)
    fixed_max_width = fixed_levels_dict['image_size'][0]
    fixed_max_height = fixed_levels_dict['image_size'][1]

    float_levels_dict = get_level_sizes_dict(float_wsi)
    float_max_width = float_levels_dict['image_size'][0]
    float_max_height = float_levels_dict['image_size'][1]

    #                   NOTE(review): only the patch corner is tested, not the patch extent, and
    #                   the locations are in the pixel grid of run_parameters['image_level'] while
    #                   'image_size' is presumably the full scale size - confirm for levels > 0
    patch_pair_array = []

    #                   visit every location: the previous index loop incremented before reading
    #                   and so never considered patch_location_array[0]
    for y0, x0 in patch_location_array:
        x1 = x0 - offset_x
        y1 = y0 - offset_y

        inside_fixed = 0 <= x0 < fixed_max_width and 0 <= y0 < fixed_max_height
        inside_float = 0 <= x1 < float_max_width and 0 <= y1 < float_max_height

        if inside_fixed and inside_float:
            patch_pair_array.append((x0, y0, x1, y1))

    return patch_pair_array

def get_patch_location_array_for_image_level_NB(run_parameters):
    """ Usage: patch_location_array = get_patch_location_array_for_image_level_NB(run_parameters)
            list the upper left corners of the patches whose thumbnail mask sum exceeds a threshold

    Args:
        run_parameters:     dict with keys:
            wsi_filename:           file name opened with openslide.OpenSlide
            thumbnail_divisor:      integer divisor, image level pixels -> thumbnail pixels
            patch_select_method:    passed to get_sample_selection_mask with the thumbnail
            patch_height:           patch_length given to get_fence_array for the rows
            patch_width:            patch_length given to get_fence_array for the columns
            threshold:              (optional, default 0) the mask sum must be greater than this
            image_level:            (optional, default 0) index into the slide level_dimensions

    Returns:
        patch_location_array:   list of (row, column) tuples in the pixel grid of image_level,
                                in row-major order (all columns of a row, then the next row)
    """
    #                   initialize an empty return value
    patch_location_array = []

    #                   name the input variables
    wsi_filename = run_parameters['wsi_filename']
    thumbnail_divisor = run_parameters['thumbnail_divisor']
    patch_select_method = run_parameters['patch_select_method']
    patch_height = run_parameters['patch_height']
    patch_width = run_parameters['patch_width']

    #                   set defaults for added parameters
    threshold = run_parameters.get('threshold', 0)
    image_level = run_parameters.get('image_level', 0)

    #                     OpenSlide open                      #
    os_im_obj = openslide.OpenSlide(wsi_filename)
    try:
        #               (width, height) of the selected level
        pixels_width, pixels_height = os_im_obj.level_dimensions[image_level]

        #               get a thumbnail image for the patch select method
        thumbnail_size = (pixels_width // thumbnail_divisor, pixels_height // thumbnail_divisor)
        small_im = os_im_obj.get_thumbnail(thumbnail_size)
    finally:
        #               close the slide even when the level lookup or the thumbnail read fails
        os_im_obj.close()
    #                     OpenSlide close                     #

    #                   get the start, stop locations list for the rows and for the columns
    rows_fence_array = get_fence_array(patch_length=patch_height, overall_length=pixels_height)
    cols_fence_array = get_fence_array(patch_length=patch_width, overall_length=pixels_width)

    #                   get the binary mask as a measure of image region content
    mask_im = get_sample_selection_mask(small_im, patch_select_method)

    #                   iterator for rows:  (top_row, bottom_row, full_scale_row_number)
    it_rows = zip(rows_fence_array[:, 0] // thumbnail_divisor,
                  rows_fence_array[:, 1] // thumbnail_divisor,
                  rows_fence_array[:, 0])

    #                   variables for columns iterator (thumbnail scale left, right; full scale left)
    lft_cols = cols_fence_array[:, 0] // thumbnail_divisor
    rgt_cols = cols_fence_array[:, 1] // thumbnail_divisor
    cols_array = cols_fence_array[:, 0]

    for tmb_row_top, tmb_row_bot, row_n in it_rows:
        #               a zip is exhausted after one pass, so build the columns iterator per row
        for tmb_col_lft, tmb_col_rgt, col_n in zip(lft_cols, rgt_cols, cols_array):

            #           if the sum of the mask elements is larger than the threshold...
            if (mask_im[tmb_row_top:tmb_row_bot, tmb_col_lft:tmb_col_rgt]).sum() > threshold:

                #       add the full scale row and column of the upper left corner to the list
                patch_location_array.append((row_n, col_n))

    return patch_location_array

In [None]:
# directory that the later cells join the offsets csv and the two slide file names to
test_data_dir = '../../DigiPath_MLTK_data/RegistrationDevData/'
# last expression of the cell: the notebook shows the directory listing
os.listdir(test_data_dir)

In [None]:
# csv of registration offsets for the slide pair (read by the next cell)
offset_data_file = os.path.join(test_data_dir, 'wsi_pair_sample.csv')
# read unconditionally: a missing file now fails here, naming the file, where the former
# isfile() guard skipped the read and left offset_df undefined (or stale) for the next line
offset_df = pd.read_csv(offset_data_file)
offset_df

In [None]:
"""
            correctly reversed x, y
"""
# first csv row only: the truth offsets, rounded to whole pixels
offset_x, offset_y = [int(round(offset_df[column_name].iloc[0]))
                      for column_name in ('truth_offset_x', 'truth_offset_y')]

# first csv row only: the auto offsets, rounded to whole pixels
auto_x, auto_y = [int(round(offset_df[column_name].iloc[0]))
                  for column_name in ('auto_offset_x', 'auto_offset_y')]

# last expression of the cell: show the four integer offsets
offset_x, offset_y, auto_x, auto_y

In [None]:
# the fixed slide of the pair
fixed_wsi = os.path.join(test_data_dir, '54742d6c5d704efa8f0814456453573a.tiff')

# level sizes of the fixed slide; 'image_size' holds (width, height) at indices 0 and 1
fixed_levels_dict = get_level_sizes_dict(fixed_wsi)
fixed_max_width, fixed_max_height = (fixed_levels_dict['image_size'][0],
                                     fixed_levels_dict['image_size'][1])

# one line per entry, the key right-justified in a 25 character column
for k, v in fixed_levels_dict.items():
    print(str(k).rjust(25) + ': ' + str(v))

In [None]:
# the float slide of the pair
float_wsi = os.path.join(test_data_dir, 'e39a8d60a56844d695e9579bce8f0335.tiff')

# level sizes of the float slide; 'image_size' holds (width, height) at indices 0 and 1
float_levels_dict = get_level_sizes_dict(float_wsi)
float_max_width, float_max_height = (float_levels_dict['image_size'][0],
                                     float_levels_dict['image_size'][1])

# one line per entry, the key right-justified in a 25 character column
for k, v in float_levels_dict.items():
    print(str(k).rjust(25) + ': ' + str(v))

In [None]:
# parameters read by get_patch_pairs_array and get_patch_location_array_for_image_level_NB
# NOTE(review): 'rgb2lab_threshold' is not read by any code visible in this notebook;
#               presumably it is meant for the toolkit mask selection - confirm
run_parameters = dict(wsi_fixed=fixed_wsi,
                      wsi_float=float_wsi,
                      thumbnail_divisor=20,
                      patch_select_method='threshold_rgb2lab',
                      rgb2lab_threshold=80,
                      image_level=0,
                      patch_height=224,
                      patch_width=224,
                      threshold=0,
                      offset_x=offset_x,
                      offset_y=offset_y)

# the timed call to get_patch_pairs_array(run_parameters) is made in the next cell

In [None]:
# time the pair search plus the display of the first (up to) 10 patch pairs
t0 = time.time()
image_level = 0
run_parameters['image_level'] = image_level
#                       (width, height) order, as given to read_region by get_pair
patch_size = (run_parameters['patch_width'], run_parameters['patch_height'])

patch_pair_array = get_patch_pairs_array(run_parameters)

#                       both slides stay open here; the next cell closes them
os_fixed_obj = openslide.OpenSlide(fixed_wsi)
os_float_obj = openslide.OpenSlide(float_wsi)

#                       slice instead of indexing range(10): fewer than 10 pairs is not an error
for patch_pair in patch_pair_array[:10]:
    x0, y0, x1, y1 = patch_pair
    print(patch_pair)
    patch_0, patch_1 = get_pair(os_fixed_obj, os_float_obj, x0, y0, x1, y1, patch_size, image_level)
    new_im = im_pair_hori(patch_0, patch_1)
    #                   explicit module call: does not rely on the notebook injecting display()
    ip_display.display(new_im)
    print('\n')

print(len(patch_pair_array), '%0.2f s'%(time.time() - t0))

In [None]:
# best effort close of both slides: each close is attempted independently of the other.
# Exception (not a bare except) lets KeyboardInterrupt / SystemExit through, and the
# cause is printed - e.g. NameError when the cell that opens the slides has not been run
try:
    os_fixed_obj.close()
except Exception as close_error:
    print('oopsie oops', '- fixed slide not closed:', repr(close_error))
try:
    os_float_obj.close()
except Exception as close_error:
    print('oopsie oops', '- float slide not closed:', repr(close_error))