In [127]:
from pathlib import Path
import time
import pickle

from tqdm import tqdm
import numpy as np
import pandas as pd
import scanpy as sc
import partipy as pt
import plotnine as pn
import matplotlib
import matplotlib.pyplot as plt

from data_utils import load_ms_xenium_data
from const import FIGURE_PATH, OUTPUT_PATH

## set up backend for matplotlib: https://matplotlib.org/stable/users/explain/figure/backends.html
## "Agg" is a non-interactive backend: figures are rendered off-screen and have to be saved to disk
matplotlib.use("Agg")

## set up output directories (created if missing)
figure_dir = Path(FIGURE_PATH) / "ms_bench_xenium"
figure_dir.mkdir(exist_ok=True, parents=True)

output_dir = Path(OUTPUT_PATH) / "ms_bench_xenium"
output_dir.mkdir(exist_ok=True, parents=True)

## setting up the optimization settings: every (init_alg, optim_alg) combination,
## with the "regularized_nnls" optimizer excluded
init_alg_list = pt.const.INIT_ALGS
optim_alg_list = [alg for alg in pt.const.OPTIM_ALGS if alg != "regularized_nnls"]
optim_settings_list = [
    {"init_alg": init_alg, "optim_alg": optim_alg}
    for init_alg in init_alg_list
    for optim_alg in optim_alg_list
]
print(f"{len(optim_settings_list)=}")

## setting up different seeds to test
seed_list = [383329927, 3324115916, 2811363264, 1884968544, 1859786275,
             3687649985, 369133708, 2995172877, 865305066, 404488628,
             2261209995, 4190266092, 3160032368, 3269070126, 3081541439,
             3376120482, 2204291346, 550243861, 3606691181, 1934392872]
## number of seeds actually used from the list above; set to None to use all of them
N_SEEDS = 1
seed_list = seed_list[:N_SEEDS]

## record the wall-clock start time of the run
script_start_time = time.time()
print(f"### Start Time: {script_start_time}")

## downloading the data (or using cached data)
atlas_adata = load_ms_xenium_data()
print(atlas_adata)

## remap the cell type annotation to broader categories
## NOTE(review): these keys look like the "Level3" categories inspected in a later cell,
## while the column mapped below is "Level2" -- confirm which column is intended.
## Source labels missing from this dict become NaN in the new column (reported below).
mapping_dict = {
    "MP/MiGl_1": "Myeloid",
    "MP/MiGl_2": "Myeloid",
    "vascular_MP_1": "Myeloid",
    "vascular_MP_2": "Myeloid",
    "vascular_MP_3": "Myeloid",
    "Vascular_1": "Vascular",
    "Vascular_2": "Vascular",
    "Astro_WM": "Astrocyte",
    "Astro_GM": "Astrocyte",
    "Astro_WM_DA": "Astrocyte",
    "Astro_GM_DA": "Astrocyte",
    "OLG_WM": "Oligo",
    "OLG_WM_DA": "Oligo",
    "OLG_GM": "Oligo",
    "OPC": "OPC",
    "OPC_DA": "OPC",
    "COP": "COP",
    "NFOL/MFOL": "NFOL",
    "Schw": "Schwann",
    "Endo": "Endothelial",
    "Neurons": "Neurons",
    "vascular_T-cell": "T_cell",
    "T-cell": "T_cell",
    "Ependymal": "Ependymal",
    "unknown": "unknown",
}

celltype_column = "celltype"
atlas_adata.obs[celltype_column] = atlas_adata.obs["Level2"].map(mapping_dict)

## report source labels without a mapping instead of silently turning them into NaN
unmapped_labels = sorted(set(atlas_adata.obs["Level2"].dropna().unique()) - set(mapping_dict))
if unmapped_labels:
    print(f"WARNING: 'Level2' labels without a mapping (set to NaN): {unmapped_labels}")

celltype_labels = ["Oligo", "Astrocyte", "Myeloid", "Vascular", "Schwann", "OPC", "Endothelial", "T_cell"]
## every benchmarked celltype must be produced by the mapping, otherwise its subset is empty
missing_targets = sorted(set(celltype_labels) - set(mapping_dict.values()))
assert not missing_targets, f"celltype_labels not produced by mapping_dict: {missing_targets}"
print(atlas_adata.obs.value_counts(celltype_column))

## qc settings: obs columns referenced for QC
qc_columns = ["type_spec", "Level3"]

## candidate numbers of archetypes: 2 up to and including 14
archetypes_to_test = [*range(2, 15)]

## number of archetypes per celltype (one entry per label in celltype_labels)
number_of_archetypes_dict = dict(
    Oligo=5,
    Astrocyte=5,
    Myeloid=5,
    Vascular=5,
    Schwann=5,
    OPC=5,
    Endothelial=5,
    T_cell=5,
)
assert number_of_archetypes_dict.keys() == set(celltype_labels)

## number of principal components per celltype (one entry per label in celltype_labels)
number_of_pcs_dict = dict(
    Oligo=10,
    Astrocyte=10,
    Myeloid=10,
    Vascular=10,
    Schwann=10,
    OPC=10,
    Endothelial=10,
    T_cell=10,
)
assert number_of_pcs_dict.keys() == set(celltype_labels)

## initialize containers to save the benchmarking results
result_list = []
rss_trace_dict = {}

## celltypes processed in this run; currently restricted to the first celltype only --
## iterate over `celltype_labels` instead to run the full benchmark
celltypes_to_run = celltype_labels[:1]

for celltype in celltypes_to_run:

    rss_trace_dict[celltype] = {}

    ## set up plotting directory per celltype
    figure_dir_celltype = figure_dir / celltype
    figure_dir_celltype.mkdir(exist_ok=True, parents=True)

    ## subsetting and preprocessing per celltype
    adata = atlas_adata[atlas_adata.obs[celltype_column] == celltype, :].copy()
    if adata.n_obs == 0:
        # an empty subset (e.g. a label that the remapping never produces) would make
        # the scanpy preprocessing below fail, so report it and move on
        print(f"\n#####\n-> skipping {celltype}: no cells with {celltype_column} == {celltype!r}")
        continue
    print("\n#####\n->", celltype, "\n", adata)
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    # HVG selection runs on the log-normalized data; PCA is restricted to those genes
    sc.pp.highly_variable_genes(adata)
    sc.pp.pca(adata, mask_var="highly_variable")

len(optim_settings_list)=6
### Start Time: 1746187498.1313775
Zip file already exists: data/MS_xenium_data_v5_with_images_tmap.h5ad.zip
Extracted H5AD file already valid: data/MS_xenium_data_v5_with_images_tmap.h5ad
AnnData object with n_obs × n_vars = 660801 × 266
    obs: 'cell_id', 'x_centroid', 'y_centroid', 'transcript_counts', 'control_probe_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'sample_id', 'n_counts', 'batch', 'type', 'type_spec', 'leiden_0.1', 'leiden_0.5', 'leiden_1', 'leiden_1.5', 'leiden_2', 'project', 'rotate', 'flip', 'Level0', 'reclustered_Level0', 'x_rotated_2', 'y_rotated_2', 'sex', 'age', 'y_rotated_mod', 'x_rotated_mod', 'Level1', 'Level2', 'Level3', 'compartment', 'compartment_2', 'compartment_2_colors', 'region_area', 'Level1_5', 'library_id'
    uns: 'Level0_colors', 'Level0_wilcoxon', 'Level1_5_colors', 'Level1_5_wilcoxon', 'Level1_colors', 'Level1_wilcoxon', 'Level2_colors', 'Level2_wilcoxo

In [121]:
## list all obs annotation columns (DataFrame.keys() returns the column Index)
atlas_adata.obs.keys()

Index(['cell_id', 'x_centroid', 'y_centroid', 'transcript_counts',
       'control_probe_counts', 'control_codeword_counts',
       'unassigned_codeword_counts', 'total_counts', 'cell_area',
       'nucleus_area', 'sample_id', 'n_counts', 'batch', 'type', 'type_spec',
       'leiden_0.1', 'leiden_0.5', 'leiden_1', 'leiden_1.5', 'leiden_2',
       'project', 'rotate', 'flip', 'Level0', 'reclustered_Level0',
       'x_rotated_2', 'y_rotated_2', 'sex', 'age', 'y_rotated_mod',
       'x_rotated_mod', 'Level1', 'Level2', 'Level3', 'compartment',
       'compartment_2', 'compartment_2_colors', 'region_area', 'Level1_5',
       'library_id', 'celltype'],
      dtype='object')

In [126]:
## inspect the "Level3" annotation (compare its categories against the mapping_dict keys)
atlas_adata.obs.loc[:, "Level3"]

0             T-cell
1             OLG_WM
2                OPC
3          MP/MiGl_2
4             OLG_WM
             ...    
680110    Vascular_1
680111    Vascular_1
680112    Vascular_2
680113    Vascular_1
680114    Vascular_1
Name: Level3, Length: 660801, dtype: category
Categories (27, object): ['Astro_GM', 'Astro_GM_DA', 'Astro_WM', 'Astro_WM_DA', ..., 'vascular_MP_1', 'vascular_MP_2', 'vascular_MP_3', 'vascular_T-cell']

In [129]:
## display the preprocessed per-celltype subset that the loop above left in `adata`
adata

AnnData object with n_obs × n_vars = 153038 × 266
    obs: 'cell_id', 'x_centroid', 'y_centroid', 'transcript_counts', 'control_probe_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'sample_id', 'n_counts', 'batch', 'type', 'type_spec', 'leiden_0.1', 'leiden_0.5', 'leiden_1', 'leiden_1.5', 'leiden_2', 'project', 'rotate', 'flip', 'Level0', 'reclustered_Level0', 'x_rotated_2', 'y_rotated_2', 'sex', 'age', 'y_rotated_mod', 'x_rotated_mod', 'Level1', 'Level2', 'Level3', 'compartment', 'compartment_2', 'compartment_2_colors', 'region_area', 'Level1_5', 'library_id', 'celltype'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'Level0_colors', 'Level0_wilcoxon', 'Level1_5_colors', 'Level1_5_wilcoxon', 'Level1_colors', 'Level1_wilcoxon', 'Level2_colors', 'Level2_wilcoxon', 'Level3_colors', 'Level3_wilcoxon', 'age_colors', 'compartment_2_colors', 'compartment_colors', 'dendrogram_Level1_5', 'leiden', 