In [1]:
"""Solution notebook."""

import os
import cv2
import toml
import json
import pandas as pd
import numpy as np
from scipy.sparse import load_npz
from niceview.utils.tools import *
from niceview.utils.cell import *
from niceview.utils.dataset import *
from niceview.utils.raster import *

In [2]:
# User configuration: read the `data` and `cache` entries of the `[path]` table.
config = toml.load('user/config.toml')
data_path, cache_path = config['path']['data'], config['path']['cache']

In [3]:
# Database metadata: file extensions, the cell-label encoder/colour map and
# the primary keys, all consumed by `test` and `ThorQuery` below.
# JSON text is UTF-8 by specification (RFC 8259); decode it explicitly rather
# than relying on the platform's locale encoding.
with open('./db/db-info.json', 'r', encoding='utf-8') as json_file:
    db_info = json.load(json_file)

data_extension = db_info['data_extension']
cache_extension = db_info['cache_extension']
cell_label_encoder = db_info['cell_label_encoder']
cell_label_cmap = db_info['cell_label_cmap']
primary_key_list = db_info['primary_key_list']

In [4]:
# Per-user run arguments, decoded as UTF-8 like the other JSON inputs.
# NOTE(review): `args` is not referenced by any later cell shown here.
with open('./user/args.json', 'r', encoding='utf-8') as f:
    args = json.load(f)

In [5]:
# data-specific
def test(
    data_path, cache_path,
    data_extension, cache_extension,
    cell_label_encoder, cell_label_cmap,
    primary_key_list,
    sample_id='gt-iz-p9-rep2',
    selected_cell_gene_name='ENSG00000065534',
    selected_spot_gene_name='ENSG00000065534',
    mask_opacity=0.5,
    colormap=None,
):
    """Render, blend and georeference the cell and spot overlays of one sample.

    Each output is written to a dataset cache field and is skipped when that
    cache file already exists, so re-running only fills in what is missing.

    Args:
        data_path, data_extension, cache_path, cache_extension,
        primary_key_list: Forwarded to ``AristotleDataset``.
        cell_label_encoder: Mapping applied to each value of the ``label``
            column of ``cell-info``; the results relabel the cell-type mask.
        cell_label_cmap: Hex colours given to ``discrete_cmap_from_hex`` for
            the cell-type mask image.
        sample_id: Sample to process.
        selected_cell_gene_name: Gene column visualised per cell.
        selected_spot_gene_name: Gene column visualised per spot.
        mask_opacity: Overlay weight passed to ``blend``.
        colormap: OpenCV colormap for both gene overlays; ``None`` selects
            ``cv2.COLORMAP_JET`` (the previously hard-coded value).

    Returns:
        None. Results are the cache fields ``mask-cell-match-region``,
        ``mask-cell-{gene,type}-img``, ``blend-cell-{gene,type}-img``,
        ``gis-blend-cell-{gene,type}-img``, ``circle-spot-gene-img``,
        ``blend-spot-gene-img`` and ``gis-blend-spot-gene-img``.

    Raises:
        Exception: ``'Bad input sample id'`` if loading or preprocessing the
            sample's inputs fails; the underlying error is chained as cause.
        OSError: if ``cv2.imwrite`` reports that an image was not written.
    """
    # Resolved here rather than as a default so defining this function does
    # not touch cv2 at import/definition time.
    if colormap is None:
        colormap = cv2.COLORMAP_JET

    dataset = AristotleDataset(
        data_path, data_extension,
        cache_path, cache_extension,
        primary_key_list,
    )

    def data_file(field):
        # Location of an input field of this sample.
        return dataset.get_data_field(sample_id, field)

    def cache_file(field):
        # Location of a cached output field of this sample.
        return dataset.get_cache_field(sample_id, field)

    def write_image_if_missing(field, render):
        """Write ``render()`` to cache field ``field`` unless it already exists."""
        path = cache_file(field)
        if os.path.exists(path):
            return
        # render is only called on a cache miss; imwrite signals failure
        # by returning False, which was previously ignored.
        if not cv2.imwrite(path, render()):
            raise OSError(f'cv2.imwrite failed to write {path}')

    def georef_if_missing(src_field, dst_field):
        """Georeference cache image ``src_field`` into ``dst_field`` unless it exists."""
        dst = cache_file(dst_field)
        if not os.path.exists(dst):
            geo_ref_raster(cache_file(src_field), dst)

    def blend_with_wsi(overlay_field):
        # Blend a cached overlay image with the sample's whole-slide image.
        return blend(data_file('wsi-img'), cache_file(overlay_field), mask_opacity)

    # analysis: any failure while loading/preprocessing is reported as a bad
    # sample id, but the real error is kept as __cause__ for debugging.
    try:
        # Load the cell mask once; it provides the image shape and, on a
        # cache miss, the dense array for nuclei matching.
        cell_mask = load_npz(data_file('cell-mask'))
        img_shape = cell_mask.shape

        # cell
        cell_info = pd.read_csv(data_file('cell-info'))
        cell_pos = cell_info[['x', 'y']].values
        cell_label = [cell_label_encoder[label] for label in cell_info['label'].values]
        cell_gene = load_npz(data_file('cell-gene'))
        cell_gene_name = txt_to_list(data_file('cell-gene-name'))
        cell_selected_gene = select_col_from_name(
            cell_gene, cell_gene_name, selected_cell_gene_name,
        )
        # presumably rescales expression values into [1, 255] — confirm in normalize_array
        cell_selected_gene_norm = normalize_array(cell_selected_gene, 1, 255)

        match_region_path = cache_file('mask-cell-match-region')
        if os.path.exists(match_region_path):
            # allow_pickle mirrors how np.save writes this cache below; only
            # load cache files produced by this code.
            cell_matched_region = np.load(match_region_path, allow_pickle=True)
        else:
            cell_matched_region = get_nuclei_pixels(
                cell_mask.tocsr().todense(),
                cell_pos,
            )
            np.save(match_region_path, cell_matched_region, allow_pickle=True)
        mask_by_cell_selected_gene_norm = mask_filter_relabel(
            data_file('cell-mask'),
            cell_matched_region,
            cell_selected_gene_norm,
        )

        # spot
        spot_info = pd.read_csv(data_file('spot-info'))
        spot_pos = spot_info[['x', 'y']].values
        spot_diameter = spot_info['diameter'].values
        spot_gene = load_npz(data_file('spot-gene'))
        spot_gene_name = txt_to_list(data_file('spot-gene-name'))
        spot_selected_gene = select_col_from_name(
            spot_gene, spot_gene_name, selected_spot_gene_name,
        )
        spot_selected_gene_norm = normalize_array(spot_selected_gene, 1, 255)
    except Exception as err:
        # Previously a bare `except:` that also trapped KeyboardInterrupt /
        # SystemExit and discarded the original traceback.
        raise Exception('Bad input sample id') from err

    # cell: mask images for the selected gene and for the cell type
    write_image_if_missing(
        'mask-cell-gene-img',
        lambda: mask_to_image(mask_by_cell_selected_gene_norm, colormap),
    )
    write_image_if_missing(
        'mask-cell-type-img',
        lambda: mask_to_image(
            mask_filter_relabel(
                data_file('cell-mask'),
                cell_matched_region,
                cell_label,
            ),
            discrete_cmap_from_hex(cell_label_cmap),
        ),
    )

    # cell: blend both mask images with the WSI, then georeference them
    write_image_if_missing('blend-cell-gene-img', lambda: blend_with_wsi('mask-cell-gene-img'))
    write_image_if_missing('blend-cell-type-img', lambda: blend_with_wsi('mask-cell-type-img'))
    georef_if_missing('blend-cell-gene-img', 'gis-blend-cell-gene-img')
    georef_if_missing('blend-cell-type-img', 'gis-blend-cell-type-img')

    # spot: filled circles coloured by the selected gene, blended and georeferenced
    write_image_if_missing(
        'circle-spot-gene-img',
        lambda: draw_circles(
            img_shape,
            spot_pos,
            spot_diameter,
            spot_selected_gene_norm,
            cmap=colormap,
            thickness=-1,
        ),
    )
    write_image_if_missing('blend-spot-gene-img', lambda: blend_with_wsi('circle-spot-gene-img'))
    georef_if_missing('blend-spot-gene-img', 'gis-blend-spot-gene-img')
    
# test(
#     data_path,
#     cache_path,
#     data_extension,
#     cache_extension,
#     cell_label_encoder,
#     cell_label_cmap,
#     primary_key_list,
# )

In [6]:
from niceview.utils.dataset import ThorQuery

In [7]:
# Query object built from the config and db-info values loaded above.
thor = ThorQuery(
    data_path, cache_path,
    data_extension, cache_extension,
    cell_label_encoder, cell_label_cmap,
    primary_key_list,
)

In [None]:
# Sample used by the THOR query cells below. This cell has no execution count
# in the saved notebook (`In [None]`), yet the following cells depend on it —
# run it before them on a fresh kernel.
sample_id = 'gt-iz-p9-rep2'

In [8]:
# Cell-level GIS query (per the method name) for gene ENSG00000065534.
# NOTE(review): the meaning of the positional `True` is not visible here —
# pass it by keyword once the ThorQuery.cell_gis signature is confirmed.
thor.cell_gis(sample_id, 'ENSG00000065534', True)

In [9]:
# Spot-level GIS query (per the method name) for the same gene as above.
thor.spot_gis(sample_id, 'ENSG00000065534')

In [None]:
# NOTE(review): presumably removes the cached cell outputs generated above —
# confirm in ThorQuery before running. Left unexecuted (`In [None]`) in the
# saved notebook.
thor.empty_cache_cell()