In [1]:
import pandas as pd
import toml
import os
from scipy.sparse import load_npz
from niceview.utils.tools import *
from niceview.utils.cell import *

## Data

In [2]:
# Load the notebook settings from `config.toml` (path is relative to the working directory).
# The pipeline below reads config['path']['data'] and config['path']['cache'] from it.
config = toml.load('config.toml')

In [3]:
class AristotleDataset:
    """Resolve file paths for per-sample data files and cache files.

    File names follow the pattern ``<primary_key>-<field>.<extension>``. The
    extension for a field is looked up in the ``data_extension`` or
    ``cache_extension`` mapping given at construction time.
    """

    def __init__(self, data_dir, data_extension, cache_dir, cache_extension, primary_key_list):
        """Remember the directories, extension mappings and accepted keys.

        Args:
            data_dir: directory joined in front of data file names.
            data_extension: mapping from data field name to file extension.
            cache_dir: directory joined in front of cache file names.
            cache_extension: mapping from cache field name to file extension.
            primary_key_list: collection of primary keys considered valid.
        """
        self.primary_key_list = primary_key_list
        self.data_dir, self.data_extension = data_dir, data_extension
        self.cache_dir, self.cache_extension = cache_dir, cache_extension

    def _unparse_filename(self, primary_key, field_name, extension):
        """Return ``'<primary_key>-<field_name>.<extension>'``."""
        stem = '-'.join((primary_key, field_name))
        return '.'.join((stem, extension))

    def _resolve(self, primary_key, field, root_dir, extension_by_field):
        """Validate the key, then build the path of ``field`` under ``root_dir``.

        Raises:
            ValueError: if ``primary_key`` is not in ``primary_key_list``.
            KeyError: if ``field`` has no entry in ``extension_by_field``.
        """
        # The key is checked before the extension lookup, so an unknown key
        # always surfaces as ValueError even when the field is unknown too.
        if primary_key not in self.primary_key_list:
            raise ValueError('Bad input primary key')
        filename = self._unparse_filename(primary_key, field, extension_by_field[field])
        return os.path.join(root_dir, filename)

    def get_data_field(self, primary_key, data_field):
        """Return the path of ``data_field`` for ``primary_key`` inside ``data_dir``.

        Raises:
            ValueError: if ``primary_key`` is not in ``primary_key_list``.
            KeyError: if ``data_field`` is missing from ``data_extension``.
        """
        return self._resolve(primary_key, data_field, self.data_dir, self.data_extension)

    def get_cache_field(self, primary_key, cache_field):
        """Return the path of ``cache_field`` for ``primary_key`` inside ``cache_dir``.

        Raises:
            ValueError: if ``primary_key`` is not in ``primary_key_list``.
            KeyError: if ``cache_field`` is missing from ``cache_extension``.
        """
        return self._resolve(primary_key, cache_field, self.cache_dir, self.cache_extension)

## Processing

In [4]:
"""Solution notebook."""

import os
from localtileserver import TileClient, get_leaflet_tile_layer
import plotly.graph_objects as go
from niceview.utils.cell import get_nuclei_pixels, paint_regions
from niceview.utils.raster import geo_ref_raster, geo_raster_to_meshgrid, index_to_meshgrid_coord
from niceview.utils.tools import txt_to_list, select_col_from_name
import pandas as pd
import numpy as np
from scipy.sparse import load_npz
import cv2

In [5]:
# helper functions
def normalize_array(arr, new_min, new_max):
    """Linearly rescale ``arr`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    Args:
        arr: array-like of numbers (must be non-empty; ``np.min`` raises on empty input).
        new_min: value the smallest element is mapped to.
        new_max: value the largest element is mapped to.

    Returns:
        The rescaled values. When every element of ``arr`` is equal there is no
        range to stretch, so a float array of the same shape filled with
        ``new_min`` is returned instead of the NaNs a 0/0 division would give.
    """
    min_val = np.min(arr)
    max_val = np.max(arr)
    span = max_val - min_val
    if span == 0:
        # Constant input: avoid dividing by zero (which would yield NaN everywhere).
        return np.full_like(arr, new_min, dtype=np.float64)
    normalized_arr = (arr - min_val) / span * (new_max - new_min) + new_min
    return normalized_arr

def mask_filter_relabel(mask_path, matched_regions, labels):
    """Repaint the matched regions of a stored mask with new per-region labels.

    Args:
        mask_path: path to a sparse matrix saved in ``.npz`` form (read with
            ``scipy.sparse.load_npz``); only its shape is used here.
        matched_regions: regions forwarded unchanged to ``paint_regions``.
        labels: per-region values forwarded as ``cell_colors_list``.

    Returns:
        The ``.data`` attribute of the object returned by ``paint_regions``.
    """
    # Only the mask dimensions are needed, and a sparse matrix reports them
    # directly, so there is no need to materialise a dense copy of the mask.
    mask_shape = load_npz(mask_path).shape
    # TODO: increase speed of `paint_regions`
    mask_filtered_relabeled = paint_regions(mask_shape, matched_regions, cell_colors_list=labels)
    return mask_filtered_relabeled.data

def hex_to_rgb(hex_color):
    """Convert a hex colour string such as ``'#d5a567'`` to an ``(r, g, b)`` tuple of ints.

    A single leading ``'#'`` is optional. The first six characters after it are
    read as three two-digit hexadecimal numbers; ``int`` raises ``ValueError``
    if any pair is missing or is not valid hexadecimal.
    """
    # Drop the '#' prefix when there is one.
    digits = hex_color[1:] if hex_color.startswith("#") else hex_color

    # Each channel is one pair of hex digits, taken in red, green, blue order.
    return tuple(int(digits[start:start + 2], 16) for start in (0, 2, 4))

def discrete_cmap_from_hex(id_to_hex_dict):
    """Build an array of RGB rows from a ``{label_id: hex_colour}`` mapping.

    Row ``i`` of the result holds the colour of label id ``i + 1``, so the
    mapping must contain every integer key from 1 to ``len(id_to_hex_dict)``;
    a missing key raises ``KeyError``.
    """
    # Convert every entry first, so a malformed hex string is reported before
    # any missing-id lookup.
    id_to_rgb = {}
    for label_id, hex_code in id_to_hex_dict.items():
        id_to_rgb[label_id] = hex_to_rgb(hex_code)

    ordered_rows = [id_to_rgb[label_id] for label_id in range(1, len(id_to_rgb) + 1)]
    return np.array(ordered_rows)

def apply_custom_cmap(img_gray, cmap):
    """Colour an 8-bit image through a lookup table built from ``cmap``.

    Args:
        img_gray: image passed to ``cv2.LUT``; its pixel values index the table.
        cmap: 2-D array whose row ``i`` (columns 0, 1, 2) is the colour for
            pixel value ``i + 1``.

    Returns:
        The result of ``cv2.LUT``. Pixel value 0, and any value above
        ``len(cmap)``, maps to ``(0, 0, 0)`` because the table starts zeroed.
    """
    # NOTE(review): the table keeps the column order of `cmap` (RGB when it
    # comes from `discrete_cmap_from_hex`), whereas cv2.imwrite treats channel 0
    # as blue -- confirm the intended channel order with the callers.
    table = np.zeros((256, 1, 3), dtype=np.uint8)
    last_entry = len(cmap) + 1
    for channel in range(3):
        # Entry 0 stays black; colours occupy entries 1..len(cmap).
        table[1:last_entry, 0, channel] = cmap[:, channel]
    return cv2.LUT(img_gray, table)

def mask_to_image(mask, cmap):
    """Render a single-channel mask as a 3-channel colour image.

    Args:
        mask: array with an ``astype`` method; it is cast to ``uint8`` before
            colouring and is also used as the single-channel mask for
            ``cv2.bitwise_and``.
        cmap: either an OpenCV colormap id (an ``int`` such as
            ``cv2.COLORMAP_JET``) or a colour table accepted by
            ``apply_custom_cmap``.

    Returns:
        The coloured image. With an OpenCV colormap, pixels where the mask is 0
        are forced to black.
    """
    mask_u8 = mask.astype(np.uint8)
    if isinstance(cmap, int):
        # applyColorMap takes an 8-bit single-channel image directly. The mask
        # is single-channel (bitwise_and below requires that), so a
        # COLOR_BGR2RGB conversion, which needs 3 or 4 channels, must not be
        # applied to it first.
        img_rgb = cv2.applyColorMap(mask_u8, cmap)
        # Colormaps assign a colour to value 0 as well; blank the background.
        img_rgb = cv2.bitwise_and(img_rgb, img_rgb, mask=mask_u8)
    else:
        # The custom lookup table has 3 channels, so replicate the mask to 3 channels.
        img_gray = cv2.cvtColor(mask_u8, cv2.COLOR_GRAY2BGR)
        img_rgb = apply_custom_cmap(img_gray, cmap)
    return img_rgb

def draw_circles(img_shape, centers, diameter, colors, cmap=cv2.COLORMAP_JET, thickness=-1):
    """Draw one coloured circle per centre on a black canvas.

    Args:
        img_shape: sequence whose first two entries are the canvas height and width.
        centers: iterable of 2-element positions passed to ``cv2.circle`` as centres.
        diameter: a single number used for every circle, or one diameter per centre.
        colors: per-circle values (cast to ``uint8``) that are turned into
            colours through ``cmap``.
        cmap: an OpenCV colormap id (``int``) or a colour table accepted by
            ``apply_custom_cmap``.
        thickness: forwarded to ``cv2.circle`` (the default ``-1`` is passed as is).

    Returns:
        The canvas, a float64 array of shape ``(img_shape[0], img_shape[1], 3)``.
    """
    # black background
    canvas = np.zeros((img_shape[0], img_shape[1], 3))

    # color
    values = np.asarray(colors).astype(np.uint8)
    if values.ndim == 3 and values.shape[2] == 3:
        # Input that already has 3 channels keeps the original conversion.
        values = cv2.cvtColor(values, cv2.COLOR_BGR2RGB)
    else:
        # Scalar values are single-channel; COLOR_BGR2RGB rejects those, so
        # replicate them to 3 channels instead.
        values = cv2.cvtColor(values, cv2.COLOR_GRAY2BGR)
    if isinstance(cmap, int):
        # Honour the requested colormap rather than always using COLORMAP_JET.
        palette = cv2.applyColorMap(values, cmap)
    else:
        palette = apply_custom_cmap(values, cmap)
    palette = np.reshape(palette, (-1, 3))

    # set diameter: broadcast any scalar (int, float or NumPy scalar)
    if np.isscalar(diameter):
        diameter = [diameter] * len(centers)

    # draw circles
    for center, d, color in zip(centers, diameter, palette):
        # cv2.circle is given plain Python ints for the colour, centre and radius.
        color = tuple(map(int, color))
        center = tuple(int(coord) for coord in np.round(center))
        radius = int(np.round(d / 2))
        cv2.circle(canvas, center, radius, color, thickness)
    return canvas

# TODO: speed up `blend`
def blend(img_path, mask_path, mask_opacity):
    """Overlay a colour mask image on a background image.

    Where the mask image (converted to grayscale) is non-zero, the output is
    ``mask_opacity * mask + (1 - mask_opacity) * background``; everywhere else
    the background pixel is kept.

    Args:
        img_path: path of the background image, read with ``cv2.imread``.
        mask_path: path of the mask image, read with ``cv2.imread``.
        mask_opacity: weight of the mask in the blended area.

    Returns:
        The blended image as returned by ``cv2.addWeighted``.

    Raises:
        OSError: if either image cannot be read.
        ValueError: if the two images differ in height or width.
    """
    mask_img = cv2.imread(mask_path)
    bkgd_img = cv2.imread(img_path)

    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface later as an unrelated-looking OpenCV error.
    if mask_img is None:
        raise OSError('Could not read mask image: {!r}'.format(mask_path))
    if bkgd_img is None:
        raise OSError('Could not read background image: {!r}'.format(img_path))
    if mask_img.shape[:2] != bkgd_img.shape[:2]:
        raise ValueError(
            'Mask size {} does not match background size {}'.format(
                mask_img.shape[:2], bkgd_img.shape[:2],
            ),
        )

    # blend part of background
    mask = cv2.cvtColor(mask_img, cv2.COLOR_BGR2GRAY)
    bkgd_blend = cv2.bitwise_and(bkgd_img, bkgd_img, mask=mask)

    # non-blend part of background
    inv_mask = (mask == 0).astype(np.uint8)
    bkgd_non_blend = cv2.bitwise_and(bkgd_img, bkgd_img, mask=inv_mask)

    # The two parts cover disjoint pixels, so adding them with weight 1.0 each
    # simply stitches them together.
    mask_overlay = cv2.addWeighted(mask_img, mask_opacity, bkgd_blend, 1.0 - mask_opacity, 0)
    whole_img = cv2.addWeighted(mask_overlay, 1.0, bkgd_non_blend, 1.0, 0)
    return whole_img

In [6]:
# data-specific pipeline

def test_on_cell_analysis():
    """Run the cell-level image pipeline for the hard-coded sample.

    Reads the sample's cell info table, cell-by-gene matrix, gene-name list and
    cell mask, then writes four PNG files through ``cv2.imwrite``:
    ``mask-cell-gene-img``, ``mask-cell-type-img``, ``blend-cell-gene-img`` and
    ``blend-cell-type-img``. Directories come from the module-level ``config``
    (``config['path']['data']`` and ``config['path']['cache']``).
    """
    # ---- dataset description ----
    primary_key_list = ['gt-iz-p9-rep2']
    data_extension = {
        'cell': 'h5ad',
        'cell-gene': 'npz',
        'cell-gene-name': 'txt',
        'cell-info': 'csv',
        'cell-mask': 'npz',
        'wsi-img': 'tiff',
        'spot': 'h5ad',
        'spot-gene': 'npz',
        'spot-gene-name': 'txt',
        'spot-info': 'csv',
    }
    cache_extension = {
        'blend-cell-gene-img': 'png',
        'blend-cell-type-img': 'png',
        'blend-spot-gene-img': 'png',
        'mask-cell-gene-img': 'png',
        'mask-cell-type-img': 'png',
        'circle-spot-gene-img': 'png',
        'gis-blend-cell-gene-img': 'tiff',
        'gis-blend-cell-type-img': 'tiff',
        'gis-blend-spot-gene-img': 'tiff',
    }
    # Cell-type name -> integer label; the ids index `hex_palette` below.
    cell_label_encoder = {
        'pericyte': 1,
        'neuronal receptor cell': 2,
        'native cell': 3,
        'smooth muscle myoblast': 4,
        'immature innate lymphoid cell': 5,
        'cardiac endothelial cell': 6,
        'cardiac muscle myoblast': 7,
        'fibroblast of cardiac tissue': 8,
    }
    hex_palette = {
        1: '#d5a567',
        2: '#be6ca7',
        3: '#91d1e0',
        4: '#8bc060',
        5: '#879dcf',
        6: '#1d8942',
        7: '#253166',
        8: '#853087',
    }

    # ---- run configuration ----
    sample_id = 'gt-iz-p9-rep2'
    selected_cell_gene_name = 'ENSG00000065534'
    selected_spot_gene_name = 'ENSG00000065534'  # not used in this function
    mask_opacity = 0.5

    # ---- pipeline ----
    dataset = AristotleDataset(
        config['path']['data'],
        data_extension,
        config['path']['cache'],
        cache_extension,
        primary_key_list,
    )

    def data_path(field):
        """Path of a data field for the current sample."""
        return dataset.get_data_field(sample_id, field)

    def cache_path(field):
        """Path of a cache field for the current sample."""
        return dataset.get_cache_field(sample_id, field)

    # Per-cell positions and integer cell-type labels.
    cell_info = pd.read_csv(data_path('cell-info'))
    cell_pos = cell_info[['x', 'y']].values
    cell_label = [cell_label_encoder[name] for name in cell_info['label'].values]

    # Values of the selected gene, rescaled to 1..255.
    cell_gene = load_npz(data_path('cell-gene'))
    cell_gene_name = txt_to_list(data_path('cell-gene-name'))
    cell_selected_gene = select_col_from_name(
        cell_gene, cell_gene_name, selected_cell_gene_name,
    )
    cell_selected_gene_norm = normalize_array(cell_selected_gene, 1, 255)

    # Regions of the dense mask matched to the cell positions.
    dense_cell_mask = load_npz(data_path('cell-mask')).tocsr()[:, :].todense()
    cell_matched_region = get_nuclei_pixels(dense_cell_mask, cell_pos)

    # Mask coloured by gene value (OpenCV JET colormap).
    mask_by_cell_selected_gene_norm = mask_filter_relabel(
        data_path('cell-mask'),
        cell_matched_region,
        cell_selected_gene_norm,
    )
    cv2.imwrite(
        cache_path('mask-cell-gene-img'),
        mask_to_image(mask_by_cell_selected_gene_norm, cv2.COLORMAP_JET),
    )

    # Mask coloured by cell type (discrete palette).
    mask_by_cell_label = mask_filter_relabel(
        data_path('cell-mask'),
        cell_matched_region,
        cell_label,
    )
    cv2.imwrite(
        cache_path('mask-cell-type-img'),
        mask_to_image(mask_by_cell_label, discrete_cmap_from_hex(hex_palette)),
    )

    # Blend each mask image (written above, re-read from disk by `blend`) over
    # the whole-slide image.
    for blend_field, mask_field in (
        ('blend-cell-gene-img', 'mask-cell-gene-img'),
        ('blend-cell-type-img', 'mask-cell-type-img'),
    ):
        cv2.imwrite(
            cache_path(blend_field),
            blend(data_path('wsi-img'), cache_path(mask_field), mask_opacity),
        )

# Run the pipeline for the hard-coded sample; it writes the mask and blended PNG images via cv2.imwrite.
test_on_cell_analysis()