### This notebook contains functions and example code showing how to use the Ultivue multiplexed data

In [None]:
import os
from skimage import io
import numpy as np
import glob
import matplotlib.pyplot as plt
import torch.nn.functional as F
import cv2
import pandas as pd
from skimage.io import imsave
from skimage.metrics import structural_similarity as ssim

import csv
import xml.etree.ElementTree as ET
import openslide
import slideio
from pathlib import Path
import math
import joblib
from tqdm import tqdm

from aicspylibczi import CziFile
from matplotlib.colors import LinearSegmentedColormap


### Key Information

- Ultivue Data:
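  - We have HE whole-slide images (WSIs) and consecutive tissue sections on which the Ultivue Immuno8 and MDSC panels were performed.
  - Because the HE and Ultivue images come from consecutive sections, they had to be registered: Immuno8 was aligned to HE with the DRMIME algorithm (the resulting transformation matrix is provided), and MDSC was aligned to Immuno8 by template matching.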

- Panel Channels:
  - Immuno8 Panel:
    - `DAPI`, `PD-L1`, `CD68`, `CD8a`, `PD-1`, `DAPI2`, `FoXP3`, `SOX10`, `CD3`, `CD4`
  - MDSC Panel:
    - `DAPI`, `CD11b`, `CD14`, `CD15`, `HLA-DR`
    
- We observed that some red regions in the HE images produce false-positive signal for some markers. We therefore annotated these regions in the HE images with HALO AI and exclude them from all analyses.
    

In [None]:
# ----- settings -----
# Root directory of the downloaded data; change this to your own location.
base_path = '/raid/sonali/ultivue'
path_HE, path_immuno8, path_mdsc, path_exclusion_annotation = (
    os.path.join(base_path, sub_dir)
    for sub_dir in ('HE', 'ultivue_immuno8', 'ultivue_mdsc', 'exclusion_masks')
)

# Registration results: immuno8 is aligned to HE, mdsc is aligned to immuno8.
path_alignment_immuno8 = os.path.join(base_path, 'alignment_immuno8_HE', 'transformation_matrix')
path_alignment_mdsc = os.path.join(base_path, 'alignment_mdsc_immuno8', 'template_matching')

# ----- Ultivue image metadata -----
resolution_IHC = 0.325  # presumably um/px, like resolution_he below — confirm
channel_order_immuno8 = ['DAPI', 'PD-L1', 'CD68', 'CD8a', 'PD-1', 'DAPI2', 'FoXP3', 'SOX10', 'CD3', 'CD4']
channel_order_mdsc = ['DAPI', 'CD11b', 'CD14', 'CD15', 'HLA-DR']

# Markers to display, and their channel positions inside each panel's file
# (markers absent from a panel are skipped for that panel).
markers_plot = ['SOX10', 'CD3', 'CD8a', 'HLA-DR']
channel_indices_immuno8 = [
    channel_order_immuno8.index(marker) for marker in markers_plot if marker in channel_order_immuno8
]
channel_indices_mdsc = [
    channel_order_mdsc.index(marker) for marker in markers_plot if marker in channel_order_mdsc
]

# Channels common to Ultivue and IMC.
channels_imc_subset = ['S100', 'MelanA', 'SOX10', 'CD3', 'CD8a', 'HLA-DR']

# HE resolution in um/px (could also be read from wsi.properties['openslide.mpp-y']).
resolution_he = 0.23

# Pyramid level used for plotting, and the matching downsample factor.
level = 6
downsample = 2 ** level


In [None]:
def _black_to(name, rgb):
    """Return a 256-level colormap that ramps linearly from black to ``rgb``."""
    return LinearSegmentedColormap.from_list(name, [(0, 0, 0), rgb], N=256)


# Custom colormaps, one per display colour.
cm_red = _black_to("CustomRed", (1.0, 0, 0))
cm_green = _black_to("CustomGreen", (0.223, 1.0, 0.196))
cm_magenta = _black_to("CustomMagenta", (1, 0, 1))
cm_darkgreen = _black_to("CustomDarkgreen", (0.173, 0.627, 0.173))
cm_turquoise = _black_to("CustomTurquoise", (0.2, 0.9, 0.9))
cm_orange = _black_to("CustomOrange", (1, 0.5, 0))

# Marker name -> colormap used when that marker is plotted.
channel_colormap_dict = dict([
    ("MelanA", cm_red),
    ("CD3", cm_green),
    ("SOX10", cm_magenta),
    ("HLA-DR", cm_orange),
    ("CD8a", cm_turquoise),
])

In [None]:
# Ultivue signal vs. noise thresholds.
# These were chosen by visual inspection and differ between samples and markers.
# Each entry is an inclusive [lower, upper] intensity window (used with
# cv2.inRange on the equalised 0-255 channel); all upper bounds are 255.
_bound_markers = ('SOX10', 'CD3', 'CD8a', 'HLA-DR')
_lower_bounds = {
    #           SOX10  CD3  CD8a  HLA-DR
    'MACEGEJ': (100,   248, 245,  235),
    'MELIPIT': (100,   235, 222,  245),
    'MIDEKOG': (120,   245, 238,  235),
    'MANOFYB': (120,   210, 249,  180),
}
bounds = {
    sample_id: {marker: [lower, 255] for marker, lower in zip(_bound_markers, lowers)}
    for sample_id, lowers in _lower_bounds.items()
}

In [None]:
# Visualisation switches:
#   exclude_red -- blank out the red regions annotated in HALO before plotting
#   salt_pepper -- median-filter every marker channel to suppress salt-and-pepper noise
exclude_red, salt_pepper = True, True

In [None]:
def get_exclusion_mask(img_annots, f_regions, downsample_factor=16, coutour_thickness=-1):
    """Rasterise the negative (exclusion) regions of a HALO annotation file.

    Parameters
    ----------
    img_annots : np.ndarray
        Image the mask must match. Only its shape is used; the returned mask
        has exactly ``img_annots.shape``.
    f_regions : str or path-like
        Annotation XML file. Every ``Annotation/Regions/Region`` element whose
        ``NegativeROA`` attribute equals 1 is drawn into the mask. Regions
        without that attribute are treated as not excluded.
    downsample_factor : int or float
        Vertex coordinates are divided by this factor so that they match the
        resolution of ``img_annots``.
    coutour_thickness : int
        Thickness passed to ``cv2.drawContours``; ``-1`` fills the region.

    Returns
    -------
    np.ndarray
        uint8 array of shape ``img_annots.shape``: 255 inside the excluded
        regions, 0 elsewhere.
    """
    mask_exclusion = np.zeros(img_annots.shape, np.uint8)
    tree = ET.parse(f_regions)
    for annotation in tree.findall('Annotation'):
        for region in annotation.findall('Regions/Region'):
            # Only regions flagged as negative ROAs are excluded. A missing flag
            # previously crashed with int(None); it now counts as "not excluded".
            if int(region.get('NegativeROA', 0)) != 1:
                continue
            points = np.asarray(
                [[int(float(v.attrib['X'])), int(float(v.attrib['Y']))]
                 for v in region.findall('Vertices/V')]
            )
            if points.size == 0:
                continue  # region without vertices: nothing to draw
            # Scale full-resolution vertex coordinates down to img_annots' resolution;
            # OpenCV contours must be int32.
            contour = (points / downsample_factor).astype(np.int32)
            mask_exclusion = cv2.drawContours(
                mask_exclusion, [contour], 0, (255, 255, 255), coutour_thickness
            )
    return mask_exclusion

def load_afi_IF(afi_file, channels_indices=(0, 1, 2), downsample=10):
    """Load selected channels of an Immuno8 ``.afi`` file as one uint8 stack.

    The file is opened with slideio. Each channel/marker is a separate scene,
    and the file must contain exactly 10 of them.

    Parameters
    ----------
    afi_file : str
        Path to the ``.afi`` file.
    channels_indices : sequence of int
        Scene (channel) indices to load, in output order.
    downsample : int
        The output size is the first requested scene's rect width/height
        integer-divided by this factor. The same size is used for every
        channel so that all of them stack.

    Returns
    -------
    np.ndarray
        uint8 array with the requested channels stacked on the last axis, or
        an empty ``np.array([])`` when ``channels_indices`` is empty.

    Raises
    ------
    ValueError
        If the file does not contain exactly 10 scenes.
    """
    slide = slideio.open_slide(afi_file)
    n_channels = slide.num_scenes  # one scene per channel/marker
    # Raise instead of assert so the check survives `python -O`.
    if n_channels != 10:
        raise ValueError("we expect to have 10 channels in each file")
    if len(channels_indices) == 0:
        return np.array([])

    size = None
    channels = []
    for channel_index in channels_indices:
        scene = slide.get_scene(channel_index)
        if size is None:
            size = (scene.rect[2] // downsample, scene.rect[3] // downsample)
        # Integer-divide by 256 before the uint8 cast; presumably maps 16-bit
        # intensities to 8 bits — confirm the source bit depth.
        channels.append((scene.read_block(scene.rect, size) // 256).astype(np.uint8))
    # One stack instead of repeated np.append (which re-copied the array each time).
    return np.stack(channels, axis=-1)

def load_czi_IF(czi_file, channels_indices=(4,), downsample=10):
    """Load selected channels of an MDSC ``.czi`` mosaic as a uint8 stack.

    The whole mosaic bounding box is read for every requested channel at
    ``scale_factor=1/downsample``.

    Parameters
    ----------
    czi_file : str
        Path to the ``.czi`` file.
    channels_indices : sequence of int
        Channel (``C``) indices to read, in output order.
    downsample : int or float
        Downsampling factor passed to ``read_mosaic`` as ``1/downsample``.

    Returns
    -------
    np.ndarray
        uint8 array with the channels stacked on the last axis and then
        ``squeeze()``-d. Every length-1 axis is dropped, so a single requested
        channel yields a lower-rank array. Returns ``np.array([])`` when
        ``channels_indices`` is empty.
    """
    czi = CziFile(czi_file)
    bbox = czi.get_mosaic_bounding_box()
    region = (bbox.x, bbox.y, bbox.w, bbox.h)
    if len(channels_indices) == 0:
        return np.array([])

    channels = [
        # Integer-divide by 256 before the uint8 cast; presumably maps 16-bit
        # intensities to 8 bits — confirm the source bit depth.
        (czi.read_mosaic(region=region, scale_factor=1 / downsample, C=channel_index) // 256)
        .astype(np.uint8)
        for channel_index in channels_indices
    ]
    # One stack instead of repeated np.append (which re-copied the array each time).
    return np.stack(channels, axis=-1).squeeze()

def fix_size(image_HE, image_multiplex):
    """Zero-pad two images at the bottom/right so their height and width match.

    Both images are padded up to the larger height and the larger width of
    the pair. Any trailing axes (e.g. channels) are left untouched, so the
    images may have different channel counts or ranks.

    Note: the parameter names are historical. The main loop calls this with
    the immuno8 and mdsc images.

    Parameters
    ----------
    image_HE, image_multiplex : np.ndarray
        Arrays whose first two axes are (height, width).

    Returns
    -------
    tuple of np.ndarray
        ``(image_HE, image_multiplex)`` with identical height and width. An
        image that needs no padding is returned as-is.
    """
    target_h = max(image_HE.shape[0], image_multiplex.shape[0])
    target_w = max(image_HE.shape[1], image_multiplex.shape[1])

    def _pad_to_target(image):
        # Pad only after the existing rows/columns; other axes get no padding.
        pad_width = [(0, target_h - image.shape[0]), (0, target_w - image.shape[1])]
        pad_width += [(0, 0)] * (image.ndim - 2)
        if all(after == 0 for _, after in pad_width):
            return image
        # np.pad, not np.lib.pad: the latter is no longer available in NumPy 2.
        return np.pad(image, pad_width, mode='constant')

    return _pad_to_target(image_HE), _pad_to_target(image_multiplex)
                

In [None]:
def _first_match(pattern):
    """Return the first path matching the glob ``pattern``; raise FileNotFoundError if none match."""
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f'no file matches {pattern!r}')
    return matches[0]


def _minmax_to_uint8(image):
    """Min-max scale each channel of an (H, W, C) image to 0-255 and return uint8.

    Channels whose max equals their min map to 0 instead of producing 0/0 NaNs.
    """
    image = image.astype(np.float64)
    low = image.min(axis=(0, 1))
    span = image.max(axis=(0, 1)) - low
    span[span == 0] = 1  # constant channel -> all zeros rather than NaN
    return ((image - low) / span * 255).astype(np.uint8)


# The stacked multiplex image holds the immuno8 markers followed by the mdsc
# markers. Thresholds and plots below index channels by their position in
# markers_plot, so both orders must agree.
_stacked_markers = ([m for m in markers_plot if m in channel_order_immuno8]
                    + [m for m in markers_plot if m in channel_order_mdsc])
if _stacked_markers != markers_plot:
    raise ValueError(
        f'stacked channel order {_stacked_markers} does not match markers_plot {markers_plot}'
    )

for f_sample_immuno8 in glob.glob(path_immuno8 + '/*-Scene-2-stacked' + '/*.afi'):
    sample = os.path.basename(f_sample_immuno8).split('-')[0]
    if sample not in bounds:  # only samples with visually defined thresholds
        continue

    f_sample_mdsc = _first_match(path_mdsc + '/' + sample + '*.czi')
    f_alignment_immuno8 = _first_match(path_alignment_immuno8 + '/' + sample + '-Scene-2.npz')
    f_alignment_mdsc = _first_match(path_alignment_mdsc + '/' + sample + '-Scene-2.npz')
    print(sample)

    # ----- HE <-> immuno8 alignment (file closed after reading) -----
    with np.load(f_alignment_immuno8) as align_immuno8:
        Q_matrix = align_immuno8['transformation_matrix_Q']
        # HE column range, rescaled to the highest resolution.
        range_HE = align_immuno8['range_HE'] * align_immuno8['downsample'][0]
        f_sample_he = align_immuno8['reference_he'].item()
    # Conjugate Q with diag(downsample, downsample, 1) so that it acts on
    # coordinates at the working resolution (presumably Q is defined at full
    # resolution — confirm). Its inverse is the src->dst map for warpPerspective.
    scale_down = np.array([[1 / downsample, 0, 0], [0, 1 / downsample, 0], [0, 0, 1]])
    scale_up = np.array([[downsample, 0, 0], [0, downsample, 0], [0, 0, 1]])
    Q_scaled = np.matmul(scale_down, np.matmul(Q_matrix, scale_up))
    total = np.linalg.inv(Q_scaled)

    # ----- mdsc <-> immuno8 alignment: crop window for the mdsc image -----
    with np.load(f_alignment_mdsc) as align_mdsc:
        range_mdsc = ((align_mdsc['range_mdsc'] * align_mdsc['downsample']) // downsample).astype(int)

    # ----- img IF immuno 8 -----
    image_immuno8 = load_afi_IF(f_sample_immuno8, channel_indices_immuno8, downsample)
    print('image_immuno8: ', image_immuno8.shape)

    # ----- img IF mdsc panel (a single 2-D channel is expected here) -----
    image_mdsc = load_czi_IF(f_sample_mdsc, channel_indices_mdsc, downsample)
    image_mdsc = image_mdsc[range_mdsc[0]:range_mdsc[1], range_mdsc[2]:range_mdsc[3]]
    image_mdsc = np.expand_dims(image_mdsc.astype(np.uint8), axis=-1)
    print('image_mdsc: ', image_mdsc.shape)

    # Pad to a common height/width, then stack the mdsc channel after the immuno8 ones.
    image_immuno8, image_mdsc = fix_size(image_immuno8, image_mdsc)
    image_multiplex = np.concatenate((image_immuno8, image_mdsc), axis=-1)
    print('image_multiplex: ', image_multiplex.shape)

    # ----- img he: read once, slide closed afterwards -----
    # A separate name keeps the global `level` setting from being overwritten.
    with openslide.OpenSlide(f_sample_he) as slide_HE:
        level_HE = slide_HE.get_best_level_for_downsample(downsample)
        region_HE = slide_HE.read_region((0, 0), level_HE, slide_HE.level_dimensions[level_HE])
        image_HE_full = np.array(region_HE.convert("RGB")).astype(np.float32)
    # NOTE(review): the crop divides by `downsample`, which assumes the chosen
    # level's downsample equals `downsample` exactly — confirm for each slide.
    col_start, col_stop = int(range_HE[0] // downsample), int(range_HE[1] // downsample)
    image_HE = image_HE_full[:, col_start:col_stop, :]
    print('image_HE: ', image_HE.shape)

    # ----- aligning image_multiplex with HE -----
    image_multiplex = cv2.warpPerspective(
        image_multiplex, total, (image_HE.shape[1], image_HE.shape[0]), borderValue=(0, 0, 0)
    )
    print('image_multiplex_warp: ', image_multiplex.shape)

    # ----- preprocessing and thresholding -----
    image_multiplex = _minmax_to_uint8(image_multiplex)  # per channel, 0-255
    for j in range(image_multiplex.shape[2]):
        image_multiplex[:, :, j] = cv2.equalizeHist(image_multiplex[:, :, j])

    sample_bounds = bounds[sample]
    for j, marker in enumerate(markers_plot):
        lower, upper = sample_bounds[marker]
        channel = image_multiplex[:, :, j]
        binary_mask = cv2.inRange(channel, lower, upper)
        # Keep intensities only where they fall inside the sample's window.
        image_multiplex[:, :, j] = cv2.bitwise_and(channel, channel, mask=binary_mask)

    # Wherever channel 0 (SOX10) is above its minimum, clear channels 1-3.
    sox10_positive = image_multiplex[:, :, 0] > image_multiplex[:, :, 0].min()
    image_multiplex[sox10_positive, 1:4] = 0

    # ----- exclusion mask (annotated red regions) -----
    if exclude_red:
        f_exclusion = _first_match(path_exclusion_annotation + '/' + sample + '*.annotations')
        mask_exclusion = get_exclusion_mask(
            image_HE_full, f_exclusion, downsample_factor=downsample, coutour_thickness=-1
        )
        # Same column crop as image_HE (removes the tonsil region), collapsed to 2-D.
        mask_exclusion = np.any(mask_exclusion[:, col_start:col_stop, :], axis=-1)
        print('shapes masks: ', mask_exclusion.shape)
        if mask_exclusion.shape == image_multiplex.shape[:2]:
            # Channels 0-2 are masked; channel 3 is intentionally left as-is
            # (NOTE(review): confirm it should stay unmasked).
            image_multiplex[mask_exclusion, :3] = 0
        else:
            # Best effort, as before: report the mismatch and skip the exclusion.
            print('shapes: ', mask_exclusion.shape)

    if salt_pepper:
        for j in range(image_multiplex.shape[2]):
            image_multiplex[:, :, j] = cv2.medianBlur(image_multiplex[:, :, j], 3)

    # ----- plotting GT -----
    fig, axs = plt.subplots(1, len(markers_plot) + 1, figsize=(30, 8))
    axs[0].imshow(image_HE / 255)
    axs[0].set_title('HE', color='red')
    axs[0].axis('off')
    for i, protein_marker in enumerate(markers_plot):
        axs[i + 1].imshow(image_multiplex[:, :, i], cmap=channel_colormap_dict[protein_marker])
        axs[i + 1].set_title('GT ' + protein_marker)
        axs[i + 1].axis('off')
    plt.show()