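## Solution notebook

Builds cell-level (cell type and selected-gene expression) and spot-level (selected-gene expression) overlay images for one sample, blends each onto the tissue image, and saves the results under `PLOTS_PATH`.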

### Imports, constants, and helper functions

In [1]:
"""Solution notebook."""

import os
from localtileserver import TileClient, get_leaflet_tile_layer
import plotly.graph_objects as go
from niceview.utils.raster import geo_ref_raster, geo_raster_to_meshgrid, index_to_meshgrid_coord
from niceview.utils.tools import txt_to_list, select_col_from_name
from niceview.utils.cell import get_nuclei_pixels, paint_regions
import pandas as pd
import numpy as np
from scipy.sparse import load_npz
import cv2

In [2]:
# configurations
DATA_PATH = '../examples/data/'    # input files are read from here
PLOTS_PATH = '../examples/plots/'  # rendered masks / blends are written here

# One row per cell type, in id order: (cell-type label, display colour).
# Ids are assigned from 1 upwards, so they are contiguous as
# `discrete_cmap_from_hex` requires, and index 0 stays unused
# (`apply_custom_cmap` leaves LUT entry 0 black).
_CELL_TYPE_TABLE = (
    ('pericyte', '#d5a567'),
    ('neuronal receptor cell', '#be6ca7'),
    ('native cell', '#91d1e0'),
    ('smooth muscle myoblast', '#8bc060'),
    ('immature innate lymphoid cell', '#879dcf'),
    ('cardiac endothelial cell', '#1d8942'),
    ('cardiac muscle myoblast', '#253166'),
    ('fibroblast of cardiac tissue', '#853087'),
)

# cell-type label -> integer id
RELATION = {label: type_id for type_id, (label, _color) in enumerate(_CELL_TYPE_TABLE, start=1)}

# integer id -> hex colour
HEX_PALETTE = {type_id: color for type_id, (_label, color) in enumerate(_CELL_TYPE_TABLE, start=1)}

In [3]:
# args: the sample to process and the genes (Ensembl-style ids) to visualise
# NOTE(review): `id` shadows the Python builtin; kept because later cells read it.
id = 'gt-iz-p9-rep2'
spots_selected_gene_name = 'ENSG00000037280'  # gene rendered on spots
cells_selected_gene_name = 'ENSG00000065534'  # gene rendered on cells

In [4]:
# standardize the data: every input file is named '<id>-<filename>'
file = {
    key: '-'.join([id, filename])
    for key, filename in (
        ('cells-gene-names', 'cells-gene-names.txt'),
        ('cells-gene', 'cells-gene.npz'),
        ('cells-info', 'cells-info.csv'),
        ('img', 'img.tiff'),
        ('mask-filtered-relabeled', 'mask-filtered-relabeled.npz'),
        ('mask', 'mask.npz'),
        ('spots-gene-names', 'spots-gene-names.txt'),
        ('spots-gene', 'spots-gene.npz'),
        ('spots-info', 'spots-info.csv'),
    )
}

In [5]:
# intermediate files written to PLOTS_PATH, named '<id>-<filename>' - TODO: to be improved
cache = {
    key: '-'.join([id, filename])
    for key, filename in (
        ('blend-cells-gene', 'blend-cells-gene.png'),
        ('blend-cells-type', 'blend-cells-type.png'),
        ('blend-spots-gene', 'blend-spots-gene.png'),
        ('mask-cells-type', 'mask-cells-type.png'),
        ('mask-cells-gene', 'mask-cells-gene.png'),
        ('spots-gene', 'spots-gene.png'),
    )
}

In [6]:
# helper functions
def normalize_array(arr, new_min, new_max):
    """Linearly rescale ``arr`` so that its values span ``[new_min, new_max]``.

    Args:
        arr: Numeric numpy array (any shape).
        new_min: Value that the minimum of ``arr`` is mapped to.
        new_max: Value that the maximum of ``arr`` is mapped to.

    Returns:
        Float array with the same shape as ``arr``. If every element of ``arr``
        is equal (zero value range), every output element is ``new_min``
        instead of the NaNs a 0/0 division would produce.
    """
    min_val = np.min(arr)
    max_val = np.max(arr)
    value_range = max_val - min_val
    if value_range == 0:
        # Constant input (e.g. a gene with identical expression everywhere):
        # there is no spread to rescale, so map everything to the lower bound.
        return np.zeros_like(arr, dtype=float) + new_min
    return (arr - min_val) / value_range * (new_max - new_min) + new_min

def mask_filter_relabel(mask_path, matched_regions, labels):
    """Repaint matched regions onto a blank mask of the stored mask's size.

    Args:
        mask_path: Path to the sparse ``.npz`` mask; only its shape is used.
        matched_regions: Regions to paint, passed straight to ``paint_regions``.
        labels: Values to paint with, passed as ``cell_colors_list``
            (presumably one value per region -- confirm against ``paint_regions``).

    Returns:
        The ``.data`` attribute of the object returned by ``paint_regions``.
    """
    # Only the mask's dimensions are needed, and the sparse matrix already
    # exposes them; densifying the whole mask just to read `.shape` wastes
    # time and memory on large images.
    mask_shape = load_npz(mask_path).shape
    # TODO: increase speed of `paint_regions`
    mask_filtered_relabeled = paint_regions(mask_shape, matched_regions, cell_colors_list=labels)
    return mask_filtered_relabeled.data

def hex_to_rgb(hex_color):
    """Convert a ``'#RRGGBB'`` or ``'RRGGBB'`` string into an ``(r, g, b)`` tuple of ints.

    Only the first six hex digits after an optional leading ``'#'`` are read.
    """
    digits = hex_color[1:] if hex_color.startswith("#") else hex_color
    # Decode each two-digit channel in turn: red, green, blue.
    return tuple(int(digits[start:start + 2], 16) for start in (0, 2, 4))

def discrete_cmap_from_hex(id_to_hex_dict):
    """Turn an ``{id: '#RRGGBB'}`` palette with ids 1..N into an ``(N, 3)`` RGB array.

    Row ``i - 1`` holds the colour of id ``i``. Every colour is decoded first;
    a ``KeyError`` is then raised if the ids are not exactly 1..N.
    """
    rgb_by_id = dict(zip(id_to_hex_dict.keys(), map(hex_to_rgb, id_to_hex_dict.values())))
    n_colors = len(rgb_by_id)
    return np.array([rgb_by_id[color_id] for color_id in range(1, n_colors + 1)])

def apply_custom_cmap(img_gray, cmap):
    """Colour an 8-bit label image through a discrete palette lookup table.

    Label ``0`` stays black. Label ``k`` (``1 <= k <= len(cmap)``) gets colour
    row ``k - 1`` of ``cmap``. Labels above ``len(cmap)`` also stay black.

    Args:
        img_gray: 3-channel 8-bit image of labels (e.g. a grayscale mask
            expanded with ``cv2.COLOR_GRAY2BGR``), matching the 3-channel LUT.
        cmap: ``(N, 3)`` array-like of RGB rows (as built by
            ``discrete_cmap_from_hex``), with ``N <= 255``.

    Returns:
        3-channel uint8 image in OpenCV's native BGR channel order, so that
        ``cv2.imwrite`` saves the palette colours as given.

    Raises:
        ValueError: If ``cmap`` has more than 255 colours (LUT index 0 is
            reserved for black).
    """
    cmap = np.asarray(cmap)
    n_colors = len(cmap)
    if n_colors > 255:
        raise ValueError(
            f'cmap has {n_colors} colours; at most 255 fit in an 8-bit LUT (index 0 is reserved)'
        )
    lut = np.zeros((256, 1, 3), dtype=np.uint8)
    if n_colors:
        # OpenCV images are BGR (channel 0 = blue, channel 2 = red), while the
        # cmap rows are RGB. Reverse each row so the written image shows the
        # intended colours instead of red/blue-swapped ones.
        lut[1:n_colors + 1, 0, :] = cmap[:, ::-1]
    img_bgr = cv2.LUT(img_gray, lut)
    return img_bgr

def mask_to_image(mask, cmap):
    """Colour a 2-D mask into a 3-channel uint8 image.

    Args:
        mask: Label/intensity mask; cast to uint8 before colouring.
        cmap: Either an OpenCV colormap constant (int, e.g. ``cv2.COLORMAP_JET``)
            or an ``(N, 3)`` colour array handed to ``apply_custom_cmap``.

    Returns:
        The coloured image. With an OpenCV colormap, pixels where the uint8
        mask is 0 are zeroed afterwards.
    """
    mask_u8 = mask.astype(np.uint8)

    if not isinstance(cmap, int):
        # Discrete palette: expand to 3 channels so it matches the 3-channel LUT.
        return apply_custom_cmap(cv2.cvtColor(mask_u8, cv2.COLOR_GRAY2BGR), cmap)

    # TODO: increase speed of the following three lines, as they are "overlapping"
    # NOTE(review): COLOR_BGR2RGB expects a 3/4-channel input; confirm the mask
    # layout this branch actually receives.
    colored = cv2.cvtColor(mask_u8, cv2.COLOR_BGR2RGB)
    colored = cv2.applyColorMap(colored, cmap)
    return cv2.bitwise_and(colored, colored, mask=mask_u8)

def draw_circles(img_shape, centers, diameter, colors, cmap=cv2.COLORMAP_JET, thickness=-1):
    """Draw one coloured circle per centre on a black canvas.

    Args:
        img_shape: Canvas size; only ``img_shape[0]`` (rows) and
            ``img_shape[1]`` (columns) are used.
        centers: Sequence of circle centres; each is rounded to integer pixels
            and passed to ``cv2.circle`` as its ``center`` point.
        diameter: A single diameter (any scalar) used for every circle, or a
            sequence with one diameter per circle.
        colors: Per-circle values (cast to uint8) mapped through ``cmap`` to
            obtain each circle's colour.
        cmap: An OpenCV colormap constant (int) or an ``(N, 3)`` colour array
            handed to ``apply_custom_cmap``.
        thickness: ``cv2.circle`` line thickness; a negative value fills the circle.

    Returns:
        ``(rows, cols, 3)`` float64 canvas with the circles drawn in.
    """
    # black background
    # NOTE(review): float64 uses 8x the memory of uint8 on large images; kept
    # so the returned dtype is unchanged for callers.
    canvas = np.zeros((img_shape[0], img_shape[1], 3))

    # map per-circle values to colours
    colors = cv2.cvtColor(colors.astype(np.uint8), cv2.COLOR_BGR2RGB)
    if isinstance(cmap, int):
        # Honour the requested OpenCV colormap (it used to be hard-coded to JET).
        colors = cv2.applyColorMap(colors, cmap)
    else:
        colors = apply_custom_cmap(colors, cmap)
    colors = np.reshape(colors, (-1, 3))

    # a scalar diameter (int, float or numpy scalar) applies to every circle
    if np.ndim(diameter) == 0:
        diameter = [diameter] * len(centers)

    # draw circles; OpenCV wants plain Python ints for points, radii and colours
    for center, d, color in zip(centers, diameter, colors):
        center_px = tuple(int(v) for v in np.round(center))
        radius = int(np.round(d / 2))
        cv2.circle(canvas, center_px, radius, tuple(int(c) for c in color), thickness)
    return canvas

def _imread_or_raise(path):
    """Read an image with ``cv2.imread``, raising instead of returning ``None``."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'cannot read image: {path}')
    return img

# TODO: speed up `blend`
def blend(img_path, mask_path, mask_opacity):
    """Overlay a coloured mask image onto a background image.

    Where the grayscale of the mask is non-zero, the output is
    ``mask_opacity * mask + (1 - mask_opacity) * background``. Elsewhere it is
    ``background + mask_opacity * mask``, which is simply the background
    wherever the mask pixel is pure black. OpenCV additions saturate at 255.

    Args:
        img_path: Path of the background image.
        mask_path: Path of the coloured mask image; must have the same height
            and width as the background.
        mask_opacity: Weight of the mask inside masked pixels (1.0 = opaque).

    Returns:
        The blended 3-channel uint8 image, in OpenCV's BGR order.

    Raises:
        FileNotFoundError: If either image cannot be read.
        ValueError: If the two images differ in height or width.
    """
    mask_img = _imread_or_raise(mask_path)
    bkgd_img = _imread_or_raise(img_path)
    if mask_img.shape[:2] != bkgd_img.shape[:2]:
        raise ValueError(
            f'mask size {mask_img.shape[:2]} does not match image size {bkgd_img.shape[:2]}'
        )

    # blend part of background
    mask = cv2.cvtColor(mask_img, cv2.COLOR_BGR2GRAY)
    bkgd_blend = cv2.bitwise_and(bkgd_img, bkgd_img, mask=mask)

    # non-blend part of background
    inv_mask = (mask == 0).astype(np.uint8)
    bkgd_non_blend = cv2.bitwise_and(bkgd_img, bkgd_img, mask=inv_mask)

    mask_overlay = cv2.addWeighted(mask_img, mask_opacity, bkgd_blend, 1.0 - mask_opacity, 0)
    whole_img = cv2.addWeighted(mask_overlay, 1.0, bkgd_non_blend, 1.0, 0)
    return whole_img

### Cell analysis

In [7]:
# cells info: per-cell positions and cell-type ids
cells_info = pd.read_csv(os.path.join(DATA_PATH, file['cells-info']))
cells_pos = cells_info[['x', 'y']].values  # (n_cells, 2)
cells_type = cells_info['label'].values
cells_type_id = [RELATION[label] for label in cells_type]  # KeyError on a label missing from RELATION

# expression of the selected gene for each cell
cells_gene = load_npz(os.path.join(DATA_PATH, file['cells-gene']))  # scipy.sparse.csr.csr_matrix
cells_gene_names = txt_to_list(os.path.join(DATA_PATH, file['cells-gene-names']))
cells_selected_gene = select_col_from_name(cells_gene, cells_gene_names, cells_selected_gene_name)
# set min as 1 otherwise mask and background cannot be blended
cells_selected_gene_normalized = normalize_array(cells_selected_gene, 1, 255)

In [8]:
# Match each cell position to its region in the segmentation mask.
# The mask is densified straight from the loaded sparse matrix; no intermediate
# CSR conversion or full-slice copy is needed to get the same dense matrix.
# TODO: increase speed of `get_nuclei_pixels`
cells_matched_regions = get_nuclei_pixels(
    load_npz(os.path.join(DATA_PATH, file['mask'])).todense(),
    cells_pos,
)

In [9]:
# relabel mask by normalized gene expression: matched cell regions are
# repainted with the normalized (1..255) values of the selected gene
mask_by_cells_gene_norm = mask_filter_relabel(
    mask_path=os.path.join(DATA_PATH, file['mask']),
    matched_regions=cells_matched_regions,
    labels=cells_selected_gene_normalized,
)

In [10]:
# relabel mask by cell type: matched cell regions are repainted with the
# integer cell-type ids from RELATION
mask_by_cells_type = mask_filter_relabel(
    mask_path=os.path.join(DATA_PATH, file['mask']),
    matched_regions=cells_matched_regions,
    labels=cells_type_id,
)

In [11]:
# save mask by cells gene
# cv2.imwrite does not create directories and reports failure only via its
# boolean return value, so ensure the folder exists and check the result.
os.makedirs(PLOTS_PATH, exist_ok=True)
mask_cells_gene_path = os.path.join(PLOTS_PATH, cache['mask-cells-gene'])
if not cv2.imwrite(
    mask_cells_gene_path,
    mask_to_image(mask_by_cells_gene_norm, cv2.COLORMAP_JET),
):
    raise OSError(f'failed to write {mask_cells_gene_path}')

In [12]:
# save mask by cells type
# cv2.imwrite returns False instead of raising on failure, so check it.
mask_cells_type_path = os.path.join(PLOTS_PATH, cache['mask-cells-type'])
if not cv2.imwrite(
    mask_cells_type_path,
    mask_to_image(mask_by_cells_type, discrete_cmap_from_hex(HEX_PALETTE)),
):
    raise OSError(f'failed to write {mask_cells_type_path}')

In [13]:
# save blend image by cells type (opaque overlay of the cell-type mask)
# cv2.imwrite returns False instead of raising on failure, so check it.
blend_cells_type_path = os.path.join(PLOTS_PATH, cache['blend-cells-type'])
if not cv2.imwrite(
    blend_cells_type_path,
    blend(
        os.path.join(DATA_PATH, file['img']),
        os.path.join(PLOTS_PATH, cache['mask-cells-type']),
        1.0,
    ),
):
    raise OSError(f'failed to write {blend_cells_type_path}')

In [14]:
# save blend image by cells gene (opaque overlay of the gene-expression mask)
# cv2.imwrite returns False instead of raising on failure, so check it.
blend_cells_gene_path = os.path.join(PLOTS_PATH, cache['blend-cells-gene'])
if not cv2.imwrite(
    blend_cells_gene_path,
    blend(
        os.path.join(DATA_PATH, file['img']),
        os.path.join(PLOTS_PATH, cache['mask-cells-gene']),
        1.0,
    ),
):
    raise OSError(f'failed to write {blend_cells_gene_path}')

### Spot analysis

In [15]:
# spots info: per-spot centres and diameters
spots_info = pd.read_csv(os.path.join(DATA_PATH, file['spots-info']))
spots_pos = spots_info[['x', 'y']].values  # (n_spots, 2)
spots_diameter = spots_info['diameter'].values

# expression of the selected gene for each spot
spots_gene = load_npz(os.path.join(DATA_PATH, file['spots-gene']))  # scipy.sparse.csr.csr_matrix
spots_gene_names = txt_to_list(os.path.join(DATA_PATH, file['spots-gene-names']))
spots_selected_gene = select_col_from_name(spots_gene, spots_gene_names, spots_selected_gene_name)
# set min as 1 otherwise mask and background cannot be blended
spots_selected_gene_normalized = normalize_array(spots_selected_gene, 1, 255)

In [16]:
# get image shape: (rows, cols) of the segmentation mask, used as the canvas
# size for the spot drawing (assumed to match the tissue image)
mask_npz_path = os.path.join(DATA_PATH, file['mask'])
img_shape = load_npz(mask_npz_path).shape

In [17]:
# save spots: circle outlines (thickness 1) coloured by normalized expression
# cv2.imwrite returns False instead of raising on failure, so check it.
spots_gene_path = os.path.join(PLOTS_PATH, cache['spots-gene'])
if not cv2.imwrite(
    spots_gene_path,
    draw_circles(
        img_shape,
        spots_pos,
        spots_diameter,
        spots_selected_gene_normalized,
        cmap=cv2.COLORMAP_JET,
        thickness=1,
    ),
):
    raise OSError(f'failed to write {spots_gene_path}')

In [18]:
# save blend spots (opaque overlay of the spot drawing)
# cv2.imwrite returns False instead of raising on failure, so check it.
blend_spots_gene_path = os.path.join(PLOTS_PATH, cache['blend-spots-gene'])
if not cv2.imwrite(
    blend_spots_gene_path,
    blend(
        os.path.join(DATA_PATH, file['img']),
        os.path.join(PLOTS_PATH, cache['spots-gene']),
        1.0,
    ),
):
    raise OSError(f'failed to write {blend_spots_gene_path}')

### Plotting

In [19]:
# fixed zoom range (control margin of geoplot)