In [None]:
#imports

import sys
import torch
import numpy as np
import os
import time
import matplotlib.pyplot as plt
import SimpleITK as sitk
import pandas as pd
import skimage.color as color
import scipy.ndimage as nd

import utils
import networks
import deep_segmentation as ds
import rotation_alignment as ra
import affine_registration as ar
import deformable_registration as nr

from networks import segmentation_network as sn
from networks import affine_network_attention as an
from networks import affine_network_simple as asimple 
from networks import nonrigid_registration_network as nrn

In [None]:
#defines
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
path = 'path/to/images'
models_path = "path/to/models"
# resample_image() divides the longest padded image side by this value to get its resample factor.
output_max_size = 4096

# Stage switches and per-stage settings; load_models() reads this dict to decide what to load.
dhr_params = dict()
dhr_params['segmentation_mode'] = "deep_segmentation"
dhr_params['initial_rotation'] = True
dhr_params['affine_registration'] = True
dhr_params['nonrigid_registration'] = True

initial_rotation_params = dict()
initial_rotation_params['angle_step'] = 1
dhr_params['initial_rotation_params'] = initial_rotation_params

# Model locations are built with os.path.join so the directory and the file name
# are always separated, whether or not models_path ends with a path separator.
affine_registration_params = dict()
affine_registration_params['model_path'] = os.path.join(models_path, "affine_model_512") # TO DEFINE
affine_registration_params['affine_type'] = "simple"
dhr_params['affine_registration_params'] = affine_registration_params

nonrigid_registration_params = dict() # Params used during training
nonrigid_registration_params['stride'] = 128
nonrigid_registration_params['patch_size'] = (256, 256)
nonrigid_registration_params['number_of_patches'] = 32
nonrigid_registration_params['num_levels'] = 3
nonrigid_registration_params['inner_iterations_per_level'] = [3, 3, 3]
# load_models() appends "_level_<i>" to this prefix, once per level.
nonrigid_registration_params['model_path'] = os.path.join(models_path, "nonrigid_2048") # TO DEFINE
dhr_params['nonrigid_registration_params'] = nonrigid_registration_params

segmentation_params = dict()
segmentation_params['model_path'] = os.path.join(models_path, "segmentation_model_512") # TO DEFINE
dhr_params['segmentation_params'] = segmentation_params
load_masks = False

In [None]:
#load models
def load_models():
    """Load the networks selected by the notebook-level ``dhr_params``.

    Reads ``dhr_params`` and ``device`` from the notebook namespace.

    Returns:
        tuple: ``(seg_model, affine_model, nonrigid_models)``. An entry is
        ``None`` when its stage is switched off in ``dhr_params`` (previously
        this case raised ``UnboundLocalError`` at the ``return``). When
        non-rigid registration is enabled, ``nonrigid_models`` is a list with
        one network per level.

    Raises:
        ValueError: if affine registration is enabled and ``affine_type`` is
            neither ``"attention"`` nor ``"simple"``.
    """
    # Defaults so that every name in the return statement is always bound.
    seg_model = None
    affine_model = None
    nonrigid_models = None

    if dhr_params['segmentation_mode'] == "deep_segmentation":
        seg_model_path = dhr_params['segmentation_params']['model_path']
        seg_model = sn.load_network(device, path=seg_model_path)

    if dhr_params['affine_registration']:
        affine_registration_params = dhr_params['affine_registration_params']
        affine_model_path = affine_registration_params['model_path']
        affine_type = affine_registration_params['affine_type']
        if affine_type == "attention":
            affine_model = an.load_network(device, path=affine_model_path)
        elif affine_type == "simple":
            affine_model = asimple.load_network(device, path=affine_model_path)
        else:
            raise ValueError(
                'Unknown affine_type %r; expected "attention" or "simple".' % (affine_type,)
            )

    if dhr_params['nonrigid_registration']:
        nonrigid_registration_params = dhr_params['nonrigid_registration_params']
        nonrigid_model_path = nonrigid_registration_params['model_path']
        num_levels = nonrigid_registration_params['num_levels']
        # One checkpoint per level, named "<model_path>_level_1", "<model_path>_level_2", ...
        nonrigid_models = [
            nrn.load_network(device, path=nonrigid_model_path + "_level_" + str(level))
            for level in range(1, num_levels + 1)
        ]

    return seg_model, affine_model, nonrigid_models

In [None]:
#read image properties from path
def read_properties_from_path(path_to_png):
    """Parse image metadata out of a '/'-separated path.

    The last four path components are read as
    ``<microns>/<scanner>/<biopsie_id>/<startx>_<starty>.png``; the final four
    characters of the file name are dropped before splitting it on ``_``.

    Returns:
        tuple: ``(biopsie_id: int, scanner: str, microns: float, startx: int, starty: int)``
    """
    components = path_to_png.split('/')
    biopsie_id, scanner, microns = components[-2], components[-3], components[-4]
    file_stem = components[-1][:-4]
    startx, starty = file_stem.split('_')
    return int(biopsie_id), scanner, float(microns), int(startx), int(starty)

In [None]:
#get image as np.array
def load_image(path_to_png):
    """Read the file at ``path_to_png`` with SimpleITK and return it as an array."""
    image = sitk.ReadImage(path_to_png)
    return sitk.GetArrayFromImage(image)

In [None]:
#get relevant coordinates for image
def get_coordinates(*, biopsie_id, phh3_startx, phh3_starty, he_startx, he_starty, microns, max_x_value=None, max_y_value=None, annotations_path='path/to/image/annotations.csv'):
    """Load the landmark coordinates of one biopsy and move them into image space.

    Each coordinate column has the matching start offset subtracted and is then
    divided by ``microns / <p1000 resolution>`` for its axis.

    Args:
        biopsie_id: value matched against the ``biopsie_id`` column.
        phh3_startx, phh3_starty: offsets subtracted from the PHH3 columns.
        he_startx, he_starty: offsets subtracted from the HE columns.
        microns: resolution of the images the coordinates are mapped to.
        max_x_value, max_y_value: optional exclusive upper bounds. When both are
            given, only rows with ``0 <= x < max_x_value`` and
            ``0 <= y < max_y_value`` for both stains are kept. They must be
            given together.
        annotations_path: CSV file holding the coordinates.

    Returns:
        pandas.DataFrame: the rows of the requested biopsy with the columns
        ``biopsie_id`` and the four transformed coordinate columns.

    Raises:
        ValueError: if only one of ``max_x_value`` / ``max_y_value`` is set.
    """
    # Same truthiness rule as before (None and 0 both count as "not set"); validated up front.
    limit_x, limit_y = bool(max_x_value), bool(max_y_value)
    if limit_x != limit_y:
        raise ValueError('Need to set "max_x_value" and "max_y_value".')

    x_scale = microns / 0.121267361111111  # the x resolution of p1000
    y_scale = microns / 0.121323529411765  # the y resolution of p1000

    # csv contains coordinates for several other scanners, we only care about
    # scanner UBE P1000 and PHH3 and HE stained images
    wanted_columns = ['biopsie_id', 'x_ube_p1000_phh3', 'y_ube_p1000_phh3', 'x_ube_p1000_he', 'y_ube_p1000_he']
    elastic_coordinates = pd.read_csv(annotations_path)[wanted_columns]

    # .copy(): the arithmetic below must act on an independent frame, not on a
    # filtered slice of the CSV frame (avoids pandas' SettingWithCopy pitfall).
    sub_df = elastic_coordinates[elastic_coordinates['biopsie_id'] == biopsie_id].copy()

    transforms = (
        ('x_ube_p1000_he', he_startx, x_scale),
        ('y_ube_p1000_he', he_starty, y_scale),
        ('x_ube_p1000_phh3', phh3_startx, x_scale),
        ('y_ube_p1000_phh3', phh3_starty, y_scale),
    )
    for column, start, scale in transforms:
        sub_df[column] = (sub_df[column] - start) / scale

    if limit_x and limit_y:
        bounds = (
            ('x_ube_p1000_he', max_x_value),
            ('y_ube_p1000_he', max_y_value),
            ('x_ube_p1000_phh3', max_x_value),
            ('y_ube_p1000_phh3', max_y_value),
        )
        for column, upper in bounds:
            sub_df = sub_df[(sub_df[column] >= 0) & (sub_df[column] < upper)]

    return sub_df

In [None]:
#plot image
def plot_images(dict_of_img_arrays, figsize=(15,15), add_row=0, add_column=0):
    """Show every image of ``dict_of_img_arrays`` in one figure, titled by its key.

    The subplot grid has ``1 + add_row`` rows and
    ``len(dict_of_img_arrays) + add_column`` columns; images fill it in
    iteration order starting at position 1.
    """
    plt.figure(figsize=figsize)

    position = 0
    for title in dict_of_img_arrays:
        position += 1
        n_rows = 1 + add_row
        n_cols = len(dict_of_img_arrays) + add_column
        plt.subplot(n_rows, n_cols, position)
        plt.title(title)
        plt.imshow(dict_of_img_arrays[title])

In [None]:
#grayscale image
def grayscale_image(img_array):
    """Convert an RGBA image array to grayscale (RGBA -> RGB -> gray)."""
    # The images used are .png files, so they are first converted from RGBA to RGB; DHR expects RGB format.
    rgb_array = color.rgba2rgb(img_array)
    return color.rgb2gray(rgb_array)

In [None]:
#resample image
#DHR expects images to be the same size
def resample_image(source_img, target_img, source_coords, target_coords):
    """Normalize, pad, smooth and resample an image pair together with its landmarks.

    Both images are normalized and inverted (``1 - normalized``), padded as a
    pair, Gaussian-smoothed and resampled by a common factor derived from the
    longest side of the padded source and the notebook-level ``output_max_size``.
    The landmarks go through the matching pad and resample helpers.

    Returns:
        tuple: ``(resampled_source, resampled_target,
        resampled_source_landmarks, resampled_target_landmarks)``
    """
    inverted_source = 1 - utils.normalize(source_img)
    inverted_target = 1 - utils.normalize(target_img)

    source_padded, target_padded = utils.pad_images_np(inverted_source, inverted_target)
    source_points = utils.pad_landmarks(source_coords, inverted_source.shape, source_padded.shape)
    target_points = utils.pad_landmarks(target_coords, inverted_target.shape, target_padded.shape)

    factor = np.max(source_padded.shape) / output_max_size
    sigma = factor / 1.25

    # Smooth before resampling; the same sigma and factor are used for both images.
    blurred = [nd.gaussian_filter(image, sigma) for image in (source_padded, target_padded)]
    source_out, target_out = [utils.resample_image(image, factor) for image in blurred]
    source_points_out, target_points_out = [
        utils.resample_landmarks(points, factor) for points in (source_points, target_points)
    ]

    return source_out, target_out, source_points_out, target_points_out

In [None]:
#img arrays to device
#returns float32 img tensor
def send_to_device(img_array):
    """Turn a NumPy array into a float32 tensor on the notebook-level ``device``."""
    on_device = torch.from_numpy(img_array).to(device)
    return on_device.float()

In [None]:
#segmentation
#returns source and target segmentation masks
def segmentation(source, target, model, device="cpu"):
    """Predict boolean masks for ``source`` and ``target`` with ``model``.

    Both tensors are resampled to a shape computed from the source size and a
    minimum output size of 512, passed through ``model`` as ``(1, 1, H, W)``
    batches, resampled back to their own sizes and thresholded at 0.5.
    Gradient tracking is disabled for the whole computation.

    Returns:
        tuple: ``(source_mask, target_mask)``
    """
    output_min_size = 512
    images = (source, target)
    with torch.no_grad():
        working_shape = utils.calculate_new_shape_min((source.size(0), source.size(1)), output_min_size)
        downsampled = [utils.resample_tensor(image, working_shape, device=device) for image in images]
        # The network takes a (batch, channel, H, W) view; keep only the first channel of the first item.
        predictions = [
            model(small.view(1, 1, small.size(0), small.size(1)))[0, 0, :, :]
            for small in downsampled
        ]
        source_mask, target_mask = [
            utils.resample_tensor(prediction, (image.size(0), image.size(1)), device=device) > 0.5
            for prediction, image in zip(predictions, images)
        ]
    return source_mask, target_mask

In [None]:
#rotation
def rotation(source, source_mask, target, warped_source, displacement_field, initial_rotation_params, device):
    """Run the initial rotation alignment and fold it into the displacement field.

    The alignment is skipped when ``source_mask`` covers at least 99% of the
    source pixels. In that case the inputs are returned unchanged together
    with an all-zero rotation field (zeros are also what the notebook uses for
    its initial displacement field). Previously this branch fell through and
    returned ``None``, which broke the caller's three-way unpacking.

    Args:
        source: original source tensor; it is the one that gets warped.
        source_mask: mask of the source; its sum is compared to the pixel count.
        target: target tensor.
        warped_source: source as warped so far; input to the rotation search.
        displacement_field: displacement field accumulated so far.
        initial_rotation_params: settings forwarded to ``ra.rotation_alignment``.
        device: device forwarded to the helpers.

    Returns:
        tuple: ``(warped_source, displacement_field, rot_displacement_field)``
    """
    mask_covers_image = torch.sum(source_mask) >= 0.99*source.size(0)*source.size(1)
    if mask_covers_image:
        return warped_source, displacement_field, torch.zeros_like(displacement_field)

    rot_displacement_field = ra.rotation_alignment(warped_source, target, initial_rotation_params, device=device)
    displacement_field = utils.compose_displacement_field(displacement_field, rot_displacement_field, device=device, delete_outliers=False)
    # Warp the original source with the composed field instead of re-warping warped_source.
    warped_source = utils.warp_tensor(source, displacement_field, device=device)

    return warped_source, displacement_field, rot_displacement_field


In [None]:
#affine registration
def affine_registration(source, target, warped_source, displacement_field, affine_model, device):
    """Run the affine stage and fold its result into the displacement field.

    ``ar.affine_registration`` is applied to ``warped_source`` and ``target``;
    its field is composed with ``displacement_field`` and the original
    ``source`` is warped with the composed field.

    Returns:
        tuple: ``(warped_source, displacement_field, affine_displacement_field)``
    """
    affine_field = ar.affine_registration(warped_source, target, affine_model, device=device)
    composed_field = utils.compose_displacement_field(
        displacement_field, affine_field, device=device, delete_outliers=False
    )
    rewarped_source = utils.warp_tensor(source, composed_field, device=device)
    return rewarped_source, composed_field, affine_field

In [None]:
#nonrigid registration
def nonrigid_registration(source, target, warped_source, displacement_field, nonrigid_models, nonrigid_registration_params, device):
    """Run the non-rigid stage and fold its result into the displacement field.

    ``nr.nonrigid_registration`` is applied to ``warped_source`` and ``target``;
    its field is composed with ``displacement_field`` and the original
    ``source`` is warped with the composed field.

    Returns:
        tuple: ``(warped_source, displacement_field, nonrigid_displacement_field)``
    """
    nonrigid_field = nr.nonrigid_registration(
        warped_source, target, nonrigid_models, nonrigid_registration_params, device=device
    )
    composed_field = utils.compose_displacement_field(
        displacement_field, nonrigid_field, device=device, delete_outliers=False
    )
    rewarped_source = utils.warp_tensor(source, composed_field, device=device)
    return rewarped_source, composed_field, nonrigid_field

In [None]:
#calc rtre
#from deephistreg utils
def calculate_tre(source_landmarks, target_landmarks):
    """Return the Euclidean distance between each pair of landmarks.

    Column 0 and column 1 of both arrays are used as the two coordinates.
    """
    delta_first = source_landmarks[:, 0] - target_landmarks[:, 0]
    delta_second = source_landmarks[:, 1] - target_landmarks[:, 1]
    squared_distance = np.square(delta_first) + np.square(delta_second)
    return np.sqrt(squared_distance)

def calculate_rtre(source_landmarks, target_landmarks, image_diagonal):
    """Return the per-landmark TRE divided by ``image_diagonal``."""
    return calculate_tre(source_landmarks, target_landmarks) / image_diagonal

In [None]:
def coord_df_to_np(dataframe):
    """Extract the landmark columns of ``dataframe`` as two ``(N, 2)`` arrays.

    Returns:
        tuple: ``(combined_he, combined_phh3)``, each with the x column first
        and the y column second.
    """
    stacked = []
    # PHH3 first, then HE; each pair of columns is copied out and laid side by side.
    for x_column, y_column in (
        ('x_ube_p1000_phh3', 'y_ube_p1000_phh3'),
        ('x_ube_p1000_he', 'y_ube_p1000_he'),
    ):
        xs = dataframe[x_column].to_numpy(copy=True)
        ys = dataframe[y_column].to_numpy(copy=True)
        stacked.append(np.transpose((xs, ys)))

    combined_phh3, combined_he = stacked
    return combined_he, combined_phh3

In [None]:
# End-to-end run: read both images and their landmarks, register HE (source)
# onto PHH3 (target) and report the median rTRE before and after registration.
phh3 = f'{path}/to/phh3/stained/image.png'
he = f'{path}/to/he/stained/image.png'

he_biopsie_id, he_scanner, he_microns, he_startx, he_starty = read_properties_from_path(he)
phh3_biopsie_id, scanner, microns, phh3_startx, phh3_starty = read_properties_from_path(phh3)

# The landmarks are selected by a single biopsy id and rescaled with a single
# microns value, so the two images have to agree on both. Previously the HE
# values were silently overwritten by the PHH3 ones.
if he_biopsie_id != phh3_biopsie_id:
    raise ValueError(f'HE and PHH3 images belong to different biopsies: {he_biopsie_id} != {phh3_biopsie_id}')
if he_microns != microns:
    raise ValueError(f'HE and PHH3 images have different resolutions: {he_microns} != {microns}')

df = get_coordinates(
    biopsie_id=phh3_biopsie_id,
    phh3_startx=phh3_startx,
    phh3_starty=phh3_starty,
    he_startx=he_startx,
    he_starty=he_starty,
    microns=microns,
    max_x_value=999999999999, # The width of the image
    max_y_value=999999999999, # The height of the image
)

source_coords, target_coords = coord_df_to_np(df)

source = load_image(he)
target = load_image(phh3)

source = grayscale_image(source)
target = grayscale_image(target)

source, target, source_coords, target_coords = resample_image(source, target, source_coords, target_coords)

#plt.figure(figsize=(15,15))
#plt.imshow(source)
#plt.scatter(source_coords[:,0],source_coords[:,1], s=3,c='brown')

source_tensor = send_to_device(source)
target_tensor = send_to_device(target)

seg_model, affine_model, nonrigid_models = load_models()

source_mask, target_mask = segmentation(source_tensor, target_tensor, seg_model, device=device)

# Start from the unwarped source and an all-zero displacement field.
warped_source = source_tensor.clone()
displacement_field = torch.zeros(2, source_tensor.size(0), source_tensor.size(1)).to(device)

# Each stage runs only when it is switched on in dhr_params, the same switches load_models() reads.
if dhr_params['initial_rotation']:
    rotation_result = rotation(source_tensor, source_mask, target_tensor, warped_source, displacement_field, initial_rotation_params, device)
    # rotation() can return nothing when it skips the alignment; keep the current state then.
    if rotation_result is not None:
        warped_source, displacement_field, _ = rotation_result

if dhr_params['affine_registration']:
    warped_source, displacement_field, _ = affine_registration(source_tensor, target_tensor, warped_source, displacement_field, affine_model, device)

if dhr_params['nonrigid_registration']:
    warped_source, displacement_field, _ = nonrigid_registration(source_tensor, target_tensor, warped_source, displacement_field, nonrigid_models, nonrigid_registration_params, device)

transformed_source_landmarks = utils.transform_landmarks(source_coords, displacement_field)

image_diagonal = np.sqrt(source_tensor.shape[0]**2 + source_tensor.shape[1]**2)
rtre_initial = calculate_rtre(source_coords, target_coords, image_diagonal)
rtre_final = calculate_rtre(transformed_source_landmarks, target_coords, image_diagonal)
# The reported numbers are medians of the rTRE (TRE divided by the image diagonal), so label them as such.
string_to_save = "Initial median rTRE: " + str(np.median(rtre_initial)) + "\n" + "Resulting median rTRE: " + str(np.median(rtre_final))

print(string_to_save)

In [None]:
# Warped source image with the transformed source landmarks (red) and the target landmarks (purple).
fig, ax = plt.subplots(figsize=(20,20))
ax.imshow(warped_source.cpu().numpy(), cmap='gray')
ax.scatter(transformed_source_landmarks[:,0], transformed_source_landmarks[:,1], s=3, c='red')
ax.scatter(target_coords[:,0], target_coords[:,1], s=3, c='purple')