In this notebook, we will be plotting the performance of RefCM across all currently available datasets, as well as testing what parameters yield the best performance.

In [1]:
import sys
sys.path.append('../src/')

import os
import json
import time
import torch
import config
import pandas as pd
import numpy as np
import scanpy as sc
import logging
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt

from refcm import RefCM
from typing import List
from anndata import AnnData
from itertools import product
from collections import defaultdict

# benchmark model imports
import scvi
import celltypist
from sklearn.svm import LinearSVC

# config
sns.set_theme()
scvi.settings.seed = 0
torch.set_float32_matmul_precision("high")
sc.set_figure_params(figsize=(6, 6), frameon=False)

# config.start_logging(logging.DEBUG)

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 0


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
class BenchModel:
    id_: str
    
    def __init__(self) -> None:
        pass
    
    def fit(self, ref: AnnData, label_key: str) -> None:
        """
        If necessary, fit the model on the reference dataset.

        Parameters
        ----------
        ref: AnnData
            The reference dataset with raw counts in .X
        label_key: str
            .obs key for reference labels.
        """
        pass
    
    def predict(self, query: AnnData, truth_key: str) -> None:
        """
        Annotates the query dataset using the reference data, and adds
        its predictions as a new column (self.id) under the query's .obs

        Parameters
        ----------
        query: AnnData
            The query dataset with raw counts in .X
        truth_key: str
            The query .obs key with the true labels; used for 
            clustering information for majority voting.
        """
        pass 
    
    def _mv(self, query: AnnData, truth_key: str) -> None:
        """
        Apply majority voting based on previous predictions and 
        ground truth clustering.
        
        Assumes BenchModel .predict() has already been run()

        Parameters
        ----------
        query: AnnData
            The query dataset.
        truth_key: str
            The ground truth labels i.e. clustering.
        """
        query.obs[f'mv-{self.id_}'] = query.obs[self.id_]
        truth_labels = sorted(query.obs[truth_key].unique().tolist())
        for cluster in truth_labels:
            cmask = query.obs[truth_key] == cluster
            mv = query.obs.loc[cmask, self.id_].value_counts().idxmax()
            query.obs.loc[cmask, f'mv-{self.id_}'] = mv
            
    
# TODO for cfmatrix, add option to normalize by row/column
def eval_preds(q: AnnData, ref: AnnData, ref_key: str,  q_truth_key: str, q_pred_key: str) -> None:
    """
    Evaluates the predictions of a given model on a query/reference pair.

    Parameters
    ----------
    q: AnnData
        Query AnnData object.
    ref: AnnData
        Reference AnnData object.
    ref_key: str
        .obs key for labels in reference dataset.
    q_truth_key: str
        .obs key for true labels in query dataset.
    q_pred_key: str
        .obs key for predicted labels in query dataset.

    Returns
    -------
    acc: float
        The % accuracy.
    cfmatrix: List[List[int]]
        The confusion matrix [true query labels] x [reference labels + novel]
    q_labels: List[str]
        The query labels (row/y axis in confusion matrix)
    ref_labels: List[str]
        Reference labels + novel (column/x axis in confusion matrix) 
    """
    
    q_labels = sorted(q.obs[q_truth_key].unique().tolist())
    ref_labels = sorted(ref.obs[ref_key].unique().tolist()) + ["novel"]
    
    novel_labels = list(set(q_labels) - set(ref_labels))
    
    correct = 0
    
    cfmatrix = np.zeros((len(q_labels), len(ref_labels)))
    
    for i in range(len(q_labels)):
        true_mask = q.obs[q_truth_key] == q_labels[i]
        
        for j in range(len(ref_labels)): 
            pred_mask = q.obs[q_pred_key] == ref_labels[j]
            
            if (q_labels[i] == ref_labels[j]) \
                or (q_labels[i] in novel_labels and ref_labels[j] == ""):
                correct += len(q.obs[(true_mask & pred_mask)])
                
            cfmatrix[i, j] = len(q.obs[(true_mask & pred_mask)])
    
    acc = correct / q.X.shape[0]
    return acc, cfmatrix.astype(int).tolist(), q_labels, ref_labels
    
def plot_cfmatrix(cfmatrix: List[List[int]], q_labels: List[str], ref_labels: List[str], 
                  width=750, height=750, show_nums: bool = False, angle_ticks: bool = True) -> None:
    """
    Plots the confusion matrix resulting from the above function.
    """
    fig = px.imshow(cfmatrix, x=ref_labels, y=q_labels, color_continuous_scale='Blues')
    
    if show_nums:
        for i in range(len(q_labels)):
            for j in range(len(ref_labels)):
                fig.add_annotation(
                    x=j,
                    y=i,
                    text=f"{cfmatrix[i][j]}",
                    showarrow=False,
                    # bgcolor="white",
                    # opacity=0.2,
                )
        fig.update_annotations(font=dict(color='white'))
        
    fig.update_layout(width=width, height=height)
    fig.update_xaxes(dtick=1)
    fig.update_yaxes(dtick=1)
    
    fig.update_xaxes(tickangle=-90)
    if angle_ticks:
        fig.update_xaxes(tickangle=-45)
        
    fig.show()

# Benchmarking models, wrapped

## Celltypist

Using the workflow suggested [here](https://pypi.org/project/celltypist/) and [here](https://colab.research.google.com/github/Teichlab/celltypist/blob/main/docs/notebook/celltypist_tutorial_cv.ipynb#scrollTo=assisted-earthquake)

In [4]:
class CellTypist(BenchModel):
    
    def __init__(self) -> None:
        self.id_ = 'celltypist'
        super().__init__()
    
    def fit(self, ref: AnnData, label_key: str) -> None:
        self.ref = ref
        self.label_key = label_key
        
        # pass here because we need to ensure the model trains
        # on genes available to both query and reference
        pass 
    
    
    def predict(self, query: AnnData, truth_key: str) -> None:
        self.q: AnnData = query
        
        # sum & log1p normalization as per official documentation
        q_sums = self.q.X.sum(axis=1).reshape((-1, 1))
        ref_sums = self.ref.X.sum(axis=1).reshape((-1, 1))
        
        self.q.X = np.log1p(1e4 * self.q.X / q_sums)
        self.ref.X = np.log1p(1e4 * self.ref.X / ref_sums)
        
        # skip downsampling sinze dataset sizes are <1e6 
        # use all intersecting genes for higher accuracy TODO test with just hvg intersection etc
        gs = np.intersect1d(self.q.var_names, self.ref.var_names)
        
        # model training
        t_start = time.perf_counter()
        self.model = celltypist.train(self.ref[:, gs], self.label_key, check_expression=False, n_jobs=10, max_iter=100)
        t_elapsed = time.perf_counter() - t_start
        print(f"[*] Training completed: {t_elapsed/60:.1f}mins")
        
        # model prediction
        preds = celltypist.annotate(query, self.model)
        self.q.obs[self.id_] = preds.predicted_labels['predicted_labels']
        
        # cleanup so datasets remain unchanged between executions
        self.q.X = np.expm1(self.q.X) * q_sums / 1e4
        self.ref.X = np.expm1(self.ref.X) * ref_sums / 1e4
        
        self._mv(query, truth_key)
        

In [51]:
# sample run on Allen-Brain datasets 

# load datasets
q = sc.read_h5ad('../data/VISp.h5ad')
ref = sc.read_h5ad('../data/MTG.h5ad')

# run the method
m = CellTypist()
m.fit(ref, 'labels34')
m.predict(q, 'labels34')

# evaluate the predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'celltypist')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

# evaluate the MV predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'mv-celltypist')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

## SCANVI

Using the workflow presented [here](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html#reference-mapping-with-scanvi)

In [5]:
class SCANVI(BenchModel):
    
    def __init__(self) -> None:
        self.id_ = 'SCANVI'
        super().__init__()
    
    def fit(self, ref: AnnData, label_key: str) -> None:
        self.ref = ref
        self.label_key = label_key
        
        # pass here because we need to ensure the model trains
        # on genes available to both query and reference
        pass 
    
    
    def predict(self, query: AnnData, truth_key: str) -> None:
        self.q: AnnData = query
        
        # sum & log1p normalization as per official documentation
        q_sums = self.q.X.sum(axis=1).reshape((-1, 1))
        ref_sums = self.ref.X.sum(axis=1).reshape((-1, 1))
        
        self.q.X = np.log1p(1e4 * self.q.X / q_sums)
        self.ref.X = np.log1p(1e4 * self.ref.X / ref_sums)
        
        # TODO might need to actually use this gs value...
        # gs = np.intersect1d(self.q.var_names, self.ref.var_names)
        
        # model training
        t_start = time.perf_counter()
        scvi.model.SCVI.setup_anndata(ref)
        
        scvi_ref = scvi.model.SCVI(
            ref,
            use_layer_norm="both",
            use_batch_norm="none",
            encode_covariates=True,
            dropout_rate=0.2,
            n_layers=2,
        )
        scvi_ref.train()
        
        scanvi_ref = scvi.model.SCANVI.from_scvi_model(
            scvi_ref,
            unlabeled_category="Unknown",
            labels_key=self.label_key,
        )
        
        scanvi_ref.train(max_epochs=20, n_samples_per_label=100)
        
        # model prediction
        scvi.model.SCANVI.prepare_query_anndata(query, scanvi_ref)
        
        scanvi_query = scvi.model.SCANVI.load_query_data(query, scanvi_ref)
        
        scanvi_query.train(
            max_epochs=100, 
            plan_kwargs={"weight_decay": 0.0},
            check_val_every_n_epoch=10,
        )
        
        query.obs[self.id_] = scanvi_query.predict()
        
        t_elapsed = time.perf_counter() - t_start
        print(f"[*] Training completed: {t_elapsed/60:.1f}mins")
        
        # cleanup so datasets remain unchanged between executions
        self.q.X = np.expm1(self.q.X) * q_sums / 1e4
        self.ref.X = np.expm1(self.ref.X) * ref_sums / 1e4
        
        self._mv(query, truth_key)
        

In [8]:
# sample run on Allen-Brain datasets 

# load datasets
q = sc.read_h5ad('../data/VISp.h5ad')
ref = sc.read_h5ad('../data/MTG.h5ad')

# run the method
m = SCANVI()
m.fit(ref, 'labels34')
m.predict(q, 'labels34')

# evaluate the predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'SCANVI')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

# evaluate the MV predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'mv-SCANVI')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Epoch 20/20: 100%|██████████| 20/20 [03:07<00:00,  9.26s/it, v_num=1, train_loss_step=6.54e+3, train_loss_epoch=6.58e+3]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 20/20: 100%|██████████| 20/20 [03:07<00:00,  9.39s/it, v_num=1, train_loss_step=6.54e+3, train_loss_epoch=6.58e+3]
[34mINFO    [0m Training for [1;36m20[0m epochs.                                                                                   


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Epoch 20/20: 100%|██████████| 20/20 [04:11<00:00, 13.60s/it, v_num=1, train_loss_step=6.22e+3, train_loss_epoch=6.25e+3]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 20/20: 100%|██████████| 20/20 [04:11<00:00, 12.58s/it, v_num=1, train_loss_step=6.22e+3, train_loss_epoch=6.25e+3]
[34mINFO    [0m Found [1;36m100.0[0m% reference vars in query data.                                                                
[34mINFO    [0m Training for [1;36m10[0m epochs.                                                                                   


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Epoch 10/10: 100%|██████████| 10/10 [01:12<00:00,  7.09s/it, v_num=1, train_loss_step=8.28e+3, train_loss_epoch=8.34e+3]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 10/10: 100%|██████████| 10/10 [01:12<00:00,  7.29s/it, v_num=1, train_loss_step=8.28e+3, train_loss_epoch=8.34e+3]
[*] accuracy: 0.2285


[*] accuracy: 0.2367


## SVM

A simple SVM classifier for benchmarking purposes.

In [6]:
class SVM(BenchModel):
    
    def __init__(self) -> None:
        self.id_ = 'SVM'
        super().__init__()
    
    def fit(self, ref: AnnData, label_key: str) -> None:
        self.ref = ref
        self.label_key = label_key
        
        # pass here because we need to ensure the model trains
        # on genes available to both query and reference
        pass 
    
    
    def predict(self, query: AnnData, truth_key: str) -> None:
        self.q: AnnData = query
        
        # sum & log1p normalization as per official documentation
        q_sums = self.q.X.sum(axis=1).reshape((-1, 1))
        ref_sums = self.ref.X.sum(axis=1).reshape((-1, 1))
        
        self.q.X = np.log1p(1e4 * self.q.X / q_sums)
        self.ref.X = np.log1p(1e4 * self.ref.X / ref_sums)
        
        # use all intersecting genes for higher accuracy
        gs = np.intersect1d(self.q.var_names, self.ref.var_names)
        sc.pp.highly_variable_genes(ref, n_top_genes=2000)
        hvg = ref.var["highly_variable"].index.to_list()
        gs = gs if len(np.intersect1d(gs, hvg)) < 1000 else np.intersect1d(gs, hvg)
        
        
        # model training
        t_start = time.perf_counter()
        svm = LinearSVC()
        svm.fit(self.ref[:, gs].X, self.ref.obs[self.label_key])
        
        t_elapsed = time.perf_counter() - t_start
        print(f"[*] Training completed: {t_elapsed/60:.1f}mins")
        
        # model prediction
        self.q.obs[self.id_] = svm.predict(self.q[:, gs].X)
        
        # cleanup so datasets remain unchanged between executions
        self.q.X = np.expm1(self.q.X) * q_sums / 1e4
        self.ref.X = np.expm1(self.ref.X) * ref_sums / 1e4
        
        self._mv(query, truth_key)
        

In [8]:
# sample run on Allen-Brain datasets 

# load datasets
q = sc.read_h5ad('../data/ALM.h5ad')
ref = sc.read_h5ad('../data/MTG.h5ad')

# run the method
m = SVM()
m.fit(ref, 'labels34')
m.predict(q, 'labels34')

# evaluate the predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'SVM')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

# evaluate the MV predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'mv-SVM')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

[*] Training completed: 1.9mins
[*] accuracy: 0.4599


[*] accuracy: 0.4994


## Seurat

We follow the workflow documented [here](https://satijalab.org/seurat/articles/integration_mapping)

In [4]:
import rpy2.robjects as ro
import rpy2.robjects.packages as rpackages
from rpy2.robjects import pandas2ri

pandas2ri.activate()

utils = rpackages.importr('utils')
utils.chooseCRANmirror(ind=1)
utils.install_packages('Seurat')

cffi mode is CFFI_MODE.ANY
R home found: /Library/Frameworks/R.framework/Resources
R library path: 
LD_LIBRARY_PATH: 
Default options to initialize R: rpy2, --quiet, --no-save
R is already initialized. No need to initialize.
R[write to console]: trying URL 'https://cloud.r-project.org/bin/macosx/big-sur-arm64/contrib/4.4/Seurat_5.1.0.tgz'

R[write to console]: Content type 'application/x-gzip'
R[write to console]:  length 4259806 bytes (4.1 MB)

R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: 


The downloaded binary packages are in
	/var/folders/n3/d8yl0jxx2vl98cxyn9bljg6m0000gn/T//RtmpETT0gl/downloaded_packages


<rpy2.rinterface_lib.sexp.NULLType object at 0x337450110> [0]

In [11]:
class Seurat(BenchModel):
    
    def __init__(self) -> None:
        self.id_ = 'Seurat'
        super().__init__()
    
    def fit(self, ref: AnnData, label_key: str) -> None:
        self.ref = ref
        self.label_key = label_key
        pass 
    
    
    def predict(self, query: AnnData, truth_key: str) -> None:
        self.q: AnnData = query
        
        # Pass objects to R
        print("passing objects to R")
        ro.globalenv['q'] = self.q.to_df()
        ro.globalenv['ref'] = self.ref.to_df()
        ro.globalenv['q_labels'] = self.q.obs[truth_key]
        ro.globalenv['ref_labels'] = self.ref.obs[self.label_key]
        
        # model training
        print("running R code")
        t_start = time.perf_counter()
        
        ro.r('''
            library(Seurat)
            
            # Convert data frames to Seurat objects
            q <- CreateSeuratObject(counts = t(as.matrix(q)))
            ref <- CreateSeuratObject(counts = t(as.matrix(ref)))

            # Normalize the data
            q <- NormalizeData(q, verbose = FALSE)
            ref <- NormalizeData(ref, verbose = FALSE)

            # Feature selection
            q <- FindVariableFeatures(q, selection.method = "vst", nfeatures = 2000)
            ref <- FindVariableFeatures(ref, selection.method = "vst", nfeatures = 2000)

            # Find transfer anchors
            anchors <- FindTransferAnchors(reference = ref, query = q, dims = 1:30)

            # Transfer cell type labels from reference to query
            preds <- TransferData(anchorset = anchors, refdata = ref_labels, dims = 1:30)

            # Add predicted cell types to the query Seurat object
            q <- AddMetaData(object = q, metadata = preds)

            # Extract predicted cell types and store them
            preds <- q$predicted.id
        ''')
        
        self.q.obs[self.id_] = ro.globalenv['preds']
        
        t_elapsed = time.perf_counter() - t_start
        print(f"[*] Training completed: {t_elapsed/60:.1f}mins")
        
        self._mv(query, truth_key)
        

In [13]:
# sample run on Allen-Brain datasets 

# load datasets
q = sc.read_h5ad('../data/ALM.h5ad')
ref = sc.read_h5ad('../data/MTG.h5ad')

# run the method
m = Seurat()
m.fit(ref, 'labels34')
m.predict(q, 'labels34')

# evaluate the predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'Seurat')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

# evaluate the MV predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'mv-Seurat')
print(f"[*] accuracy: {acc:.4f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

passing objects to R
running R code


R[write to console]:  Data is of class matrix. Coercing to dgCMatrix.

R[write to console]:  Data is of class matrix. Coercing to dgCMatrix.

R[write to console]: Finding variable features for layer counts

R[write to console]: Calculating gene variances
R[write to console]: 

R[write to console]: 0%   10   20   30   40   50   60   70   80   90   100%

R[write to console]: [----|----|----|----|----|----|----|----|----|----|

R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]: *
R[write to console]

[*] Training completed: 0.8mins
[*] accuracy: 0.4357


[*] accuracy: 0.4669


# Full Benchmarking

In [17]:
pancreas = [
    "pancreas_celseq",
    "pancreas_celseq2",
    "pancreas_fluidigmc1",
    "pancreas_indrop1",
    "pancreas_indrop2",
    "pancreas_indrop3",
    "pancreas_indrop4",
    "pancreas_smarter",
    "pancreas_smartseq2"
    ]
pancreas = {(q, 'celltype'): [(ref, 'celltype') for ref in pancreas] for q in pancreas}

allenbrain = ["ALM", "MTG", "VISp"]
allenbrain = {(q, 'labels34'): [(ref, 'labels34') for ref in allenbrain] for q in allenbrain}

pbmc = [
    "pbmc_10Xv2",
    "pbmc_10Xv3",
    "pbmc_CEL-Seq",
    "pbmc_Drop-Seq",
    "pbmc_inDrop",
    "pbmc_Seq-Well",
    "pbmc_Smart-Seq2"
    ]
pbmc = {(q, 'labels'): [(ref, 'labels') for ref in pbmc] for q in pbmc}

lgn = [
    "LGN_human_intron",
    "LGN_human_exon",
    "LGN_macaque_intron",
    "LGN_macaque_exon",
    "LGN_mouse_intron",
    "LGN_mouse_exon"
    ]
lgn = {(q, 'cluster_labels'): [(ref, 'cluster_labels') for ref in lgn] for q in lgn}

typist = [
    "Blood",
    "Bone_marrow",
    "Heart",
    "Hippocampus",
    "Intestine",
    "Kidney",
    "Liver",
    "Lung",
    "Lymph_node",
    "Pancreas",
    "Skeletal_muscle",
    "Spleen"
    ]
typist = {(q, 'cell_type'): [(ref, 'cell_type') for ref in typist] for q in typist}


In [28]:
# sample run on Allen-Brain datasets 

# load datasets
q = sc.read_h5ad('../data/pancreas_celseq.h5ad')
ref = sc.read_h5ad('../data/pancreas_celseq2.h5ad')

# run the method
rcm = RefCM(cache_load=False, cache_save=False, discovery_threshold=None)
m = rcm.annotate(q, 'x', ref, 'y', 'celltype', 'celltype')

# # evaluate the predictions
# acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'celltype', 'celltype', 'SVM')
# print(f"[*] accuracy: {acc:.4f}")
# plot_cfmatrix(cfmatrix, q_labels, ref_labels)

# # evaluate the MV predictions
# acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'celltype', 'celltype', 'mv-SVM')
# print(f"[*] accuracy: {acc:.4f}")
# plot_cfmatrix(cfmatrix, q_labels, ref_labels)

NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:01


In [34]:
res = defaultdict(dict)

ds = allenbrain

for q_id, q_key in ds:
    q = sc.read_h5ad(f'../data/{q_id}.h5ad')
    for ref_id, ref_key in ds[(q_id, q_key)]:
        if q_id == ref_id:
            continue
        
        ref = sc.read_h5ad(f'../data/{ref_id}.h5ad')
        
        m = CellTypist()
        m.fit(ref, ref_key)
        m.predict(q, q_key)
        
        m = RefCM(cache_load=False, cache_save=False, discovery_threshold=None)
        m.annotate(q, q_id, ref, ref_id, ref_key, q_key)
        
        # m = SVM()
        # m.fit(ref, ref_key)
        # m.predict(q, q_key)
        
        res[q_id][ref_id] = {
            'refcm': eval_preds(q, ref, q_key, ref_key, 'refcm_annot'),
            'celltypist': eval_preds(q, ref, q_key, ref_key, 'celltypist'),
            'mv-celltypist': eval_preds(q, ref, q_key, ref_key, 'mv-celltypist'),
            # 'SVM': eval_preds(q, ref, q_key, ref_key, 'SVM'),
            # 'mv-SVM': eval_preds(q, ref, q_key, ref_key, 'mv-SVM')
        }
        print(res[q_id][ref_id])
    

🍳 Preparing data before training
✂️ 18 non-expressed genes are filtered out
🔬 Input data has 8128 cells and 16006 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 8128 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16006 features used for prediction


[*] Training completed: 0.7mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:17


{'refcm': (1.0, [[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 325, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 224, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1037, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 165, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 703, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 199, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 223, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 

🍳 Preparing data before training
✂️ 2 non-expressed genes are filtered out
🔬 Input data has 14055 cells and 16022 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 8128 cells and 16024 genes
🔗 Matching reference genes in the model


[*] Training completed: 0.9mins


🧬 16022 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:32


{'refcm': (1.0, [[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 325, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 224, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1037, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 165, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 703, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 199, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 223, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 

🍳 Preparing data before training
✂️ 13 non-expressed genes are filtered out
🔬 Input data has 12552 cells and 16011 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 8128 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16011 features used for prediction


[*] Training completed: 1.1mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:26


{'refcm': (1.0, [[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 325, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 224, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1037, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 165, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 703, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 199, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 223, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 

🍳 Preparing data before training
✂️ 18 non-expressed genes are filtered out
🔬 Input data has 8128 cells and 16006 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 14055 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16006 features used for prediction


[*] Training completed: 0.7mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:28


{'refcm': (1.0, [[115, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 3245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 279, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1655, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 156, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1806, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 326, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [

🍳 Preparing data before training
✂️ 2 non-expressed genes are filtered out
🔬 Input data has 14055 cells and 16022 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 14055 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16022 features used for prediction


[*] Training completed: 1.0mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:48


{'refcm': (1.0, [[115, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 3245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 279, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1655, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 156, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1806, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 326, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [

🍳 Preparing data before training
✂️ 13 non-expressed genes are filtered out
🔬 Input data has 12552 cells and 16011 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 14055 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16011 features used for prediction


[*] Training completed: 1.0mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:41


{'refcm': (1.0, [[115, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 3245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 279, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1655, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 156, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1806, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 326, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [

🍳 Preparing data before training
✂️ 18 non-expressed genes are filtered out
🔬 Input data has 8128 cells and 16006 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 12552 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16006 features used for prediction


[*] Training completed: 0.7mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:24


{'refcm': (1.0, [[11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 973, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 206, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1349, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 449, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 167, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 470, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0,

🍳 Preparing data before training
✂️ 2 non-expressed genes are filtered out
🔬 Input data has 14055 cells and 16022 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 12552 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16022 features used for prediction


[*] Training completed: 1.0mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:41


{'refcm': (1.0, [[11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 973, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 206, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1349, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 449, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 167, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 470, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0,

🍳 Preparing data before training
✂️ 13 non-expressed genes are filtered out
🔬 Input data has 12552 cells and 16011 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 12552 cells and 16024 genes
🔗 Matching reference genes in the model
🧬 16011 features used for prediction


[*] Training completed: 1.1mins


⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 00:38


{'refcm': (1.0, [[11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 973, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 206, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1349, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 449, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 167, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 470, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0,

In [47]:
cols = ['refcm', 'celltypist', 'mv-celltypist']
rows = []
table = []

for q in res:
    for ref in res[q]:
        if q == ref:
            continue
        
        rows.append(f"{q} -> {ref}")
        table.append([res[q][ref][col][0] for col in cols])

fig = px.imshow(table, x=cols, y=rows, color_continuous_scale='Blues')
    
for i in range(len(rows)):
    for j in range(len(cols)):
        fig.add_annotation(
            x=j,
            y=i,
            text=f"{table[i][j]:.2f}",
            showarrow=False,
        )
fig.update_annotations(font=dict(color='green'))#, size=20))

fig.update_layout(width=1000, height=700)
fig.update_xaxes(dtick=1)
fig.update_yaxes(dtick=1)

fig.update_xaxes(side='top', tickangle=-45)

fig.show()