In [1]:
"""Solution notebook."""

import os
import cv2
import toml
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from scipy.sparse import load_npz
from localtileserver import TileClient, get_leaflet_tile_layer
from niceview.utils.tools import *
from niceview.utils.cell import *
from niceview.utils.dataset import *
from niceview.utils.raster import *

In [2]:
# data-specific rendering pipeline for one sample
def _write_image_if_missing(path, render):
    """Write ``render()`` to ``path`` with OpenCV unless ``path`` already exists.

    Args:
        path: Destination image file; OpenCV picks the encoder from its extension.
        render: Zero-argument callable producing the image array. It is only
            called on a cache miss, so expensive rendering is skipped when the
            file is already present.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    if os.path.exists(path):
        return
    # cv2.imwrite reports most failures by returning False rather than raising;
    # surface them here instead of failing later when the file is read back.
    if not cv2.imwrite(path, render()):
        raise OSError(f'Failed to write image: {path}')


def _geo_ref_if_missing(src_path, dst_path):
    """Georeference ``src_path`` into ``dst_path`` unless ``dst_path`` already exists."""
    if not os.path.exists(dst_path):
        geo_ref_raster(src_path, dst_path)


def test():
    """Build and cache the cell and spot visualisation images for one sample.

    The function reads ``user/config.toml`` for the data and cache directories.
    It then loads the cell mask, cell table, cell/spot gene matrices and spot
    table for ``sample_id``. Finally it writes each of the following, but only
    when the file is not already cached:

    * colour-mapped cell masks for the selected gene and for cell type,
    * a spot-circle image coloured by the selected spot gene,
    * blends of each of those overlays with the WSI, and
    * georeferenced copies of each blend.

    NOTE(review): cache paths come from ``dataset.get_cache_field(sample_id,
    field)`` and appear to depend only on the sample and field. Changing the
    selected gene, the colormap or the opacity therefore reuses previously
    cached images. Regenerating an overlay also does not rebuild its existing
    blend or georeferenced copy. Delete the stale cache files to force a
    rebuild.

    Raises:
        Exception: ``'Bad input sample id'`` if loading or preprocessing the
            sample's inputs fails. The underlying error is chained as
            ``__cause__``.
        OSError: If OpenCV fails to write a cache image.
    """
    config = toml.load('user/config.toml')
    data_extension = {
        'cell': 'h5ad',
        'cell-gene': 'npz',
        'cell-gene-name': 'txt',
        'cell-info': 'csv',
        'cell-mask': 'npz',
        'wsi-img': 'tiff',
        'spot': 'h5ad',
        'spot-gene': 'npz',
        'spot-gene-name': 'txt',
        'spot-info': 'csv',
    }
    cache_extension = {
        'blend-cell-gene-img': 'png',
        'blend-cell-type-img': 'png',
        'blend-spot-gene-img': 'png',
        'mask-cell-gene-img': 'png',
        'mask-cell-type-img': 'png',
        'mask-cell-match-region': 'npy',
        'circle-spot-gene-img': 'png',
        'gis-blend-cell-gene-img': 'tiff',
        'gis-blend-cell-type-img': 'tiff',
        'gis-blend-spot-gene-img': 'tiff',
    }
    cell_label_encoder = {
        'pericyte': 1,
        'neuronal receptor cell': 2,
        'native cell': 3,
        'smooth muscle myoblast': 4,
        'immature innate lymphoid cell': 5,
        'cardiac endothelial cell': 6,
        'cardiac muscle myoblast': 7,
        'fibroblast of cardiac tissue': 8,
    }
    hex_palette = {
        1: '#d5a567',
        2: '#be6ca7',
        3: '#91d1e0',
        4: '#8bc060',
        5: '#879dcf',
        6: '#1d8942',
        7: '#253166',
        8: '#853087',
    }

    # args
    primary_key_list = ['gt-iz-p9-rep2']
    sample_id = 'gt-iz-p9-rep2'
    selected_cell_gene_name = 'ENSG00000065534'
    selected_spot_gene_name = 'ENSG00000065534'
    mask_opacity = 0.5
    colormap = 'jet'
    # Map the colormap name onto its OpenCV constant ('jet' -> cv2.COLORMAP_JET)
    # so the setting above is actually honoured.
    cv2_colormap = getattr(cv2, f'COLORMAP_{colormap.upper()}')

    # dataset
    dataset = AristotleDataset(
        config['path']['data'], data_extension,
        config['path']['cache'], cache_extension,
        primary_key_list,
    )

    def data_path(field):
        """Return the path of input ``field`` for the current sample."""
        return dataset.get_data_field(sample_id, field)

    def cache_path(field):
        """Return the path of cache ``field`` for the current sample."""
        return dataset.get_cache_field(sample_id, field)

    def blend_and_georef(overlay_field, blend_field, gis_field):
        """Blend a cached overlay image with the WSI, then georeference the blend."""
        _write_image_if_missing(
            cache_path(blend_field),
            lambda: blend(data_path('wsi-img'), cache_path(overlay_field), mask_opacity),
        )
        _geo_ref_if_missing(cache_path(blend_field), cache_path(gis_field))

    # analysis: load and preprocess every input up front so a bad sample fails early
    try:
        # Load the mask once: it supplies both the image shape and the dense
        # pixels used to match cell positions to mask regions.
        cell_mask = load_npz(data_path('cell-mask'))
        img_shape = cell_mask.shape

        # cell
        cell_info = pd.read_csv(data_path('cell-info'))
        cell_pos = cell_info[['x', 'y']].values
        # An unknown label raises KeyError (reported via the chained cause below).
        cell_label = [cell_label_encoder[label] for label in cell_info['label'].values]
        cell_gene = load_npz(data_path('cell-gene'))
        cell_gene_name = txt_to_list(data_path('cell-gene-name'))
        cell_selected_gene = select_col_from_name(
            cell_gene, cell_gene_name, selected_cell_gene_name,
        )
        cell_selected_gene_norm = normalize_array(cell_selected_gene, 1, 255)

        region_cache = cache_path('mask-cell-match-region')
        if os.path.exists(region_cache):
            # NOTE(review): presumably an object array (per-cell pixel groups),
            # which np.load only reads with allow_pickle=True. Only point this
            # at cache files written by this pipeline.
            cell_matched_region = np.load(region_cache, allow_pickle=True)
        else:
            cell_matched_region = get_nuclei_pixels(cell_mask.tocsr().todense(), cell_pos)
            np.save(region_cache, cell_matched_region, allow_pickle=True)

        # spot
        spot_info = pd.read_csv(data_path('spot-info'))
        spot_pos = spot_info[['x', 'y']].values
        spot_diameter = spot_info['diameter'].values
        spot_gene = load_npz(data_path('spot-gene'))
        spot_gene_name = txt_to_list(data_path('spot-gene-name'))
        spot_selected_gene = select_col_from_name(
            spot_gene, spot_gene_name, selected_spot_gene_name,
        )
        spot_selected_gene_norm = normalize_array(spot_selected_gene, 1, 255)
    except Exception as exc:
        # Keep the historical message, but chain the real cause so the
        # traceback shows which input actually failed.
        raise Exception('Bad input sample id') from exc

    # cell: colour-mapped masks for the selected gene and for cell type
    # (relabelling runs only on a cache miss).
    _write_image_if_missing(
        cache_path('mask-cell-gene-img'),
        lambda: mask_to_image(
            mask_filter_relabel(
                data_path('cell-mask'), cell_matched_region, cell_selected_gene_norm,
            ),
            cv2_colormap,
        ),
    )
    _write_image_if_missing(
        cache_path('mask-cell-type-img'),
        lambda: mask_to_image(
            mask_filter_relabel(data_path('cell-mask'), cell_matched_region, cell_label),
            discrete_cmap_from_hex(hex_palette),
        ),
    )
    blend_and_georef('mask-cell-gene-img', 'blend-cell-gene-img', 'gis-blend-cell-gene-img')
    blend_and_georef('mask-cell-type-img', 'blend-cell-type-img', 'gis-blend-cell-type-img')

    # spot: circles coloured by the selected spot gene
    _write_image_if_missing(
        cache_path('circle-spot-gene-img'),
        lambda: draw_circles(
            img_shape,
            spot_pos,
            spot_diameter,
            spot_selected_gene_norm,
            cmap=cv2_colormap,
            thickness=-1,
        ),
    )
    blend_and_georef('circle-spot-gene-img', 'blend-spot-gene-img', 'gis-blend-spot-gene-img')
    
test()