In this notebook, we will be plotting the performance of RefCM across all currently available datasets, as well as testing what parameters yield the best performance.

In [1]:
# standard library
import json
import logging
import os
import sys
import time
from typing import List, Tuple

# third-party
import numpy as np
import plotly.express as px
import scanpy as sc
from anndata import AnnData

# benchmark model imports
import celltypist

# project-local modules (presumably located in ../src/), so the path must be
# extended before they are imported
sys.path.append('../src/')
import config
from refcm import RefCM


# config.start_logging(logging.DEBUG)

In [2]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are re-imported before each cell execution and edits to project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [13]:
class BenchModel:
    """
    Common interface for the annotation methods benchmarked in this notebook.

    Subclasses set `id_` and override `fit` / `predict`; the base
    implementations intentionally do nothing.
    """

    # identifier of the model; also the .obs column that receives predictions
    id_: str

    def __init__(self) -> None:
        """No state is initialised by the base class."""

    def fit(self, ref: AnnData, label_key: str) -> None:
        """
        Fit the model on the reference dataset, when the method requires it.

        Parameters
        ----------
        ref: AnnData
            The reference dataset with raw counts in .X
        label_key: str
            .obs key for reference labels.
        """

    def predict(self, query: AnnData) -> None:
        """
        Annotate the query dataset using the reference data, storing the
        predictions as a new column (named self.id_) in the query's .obs

        Parameters
        ----------
        query: AnnData
            The query dataset with raw counts in .X
        """
    
    
def eval_preds(q: AnnData, ref: AnnData, ref_key: str,  q_truth_key: str, q_pred_key: str) -> Tuple[float, List[List[int]], List[str], List[str]]:
    """
    Evaluates the predictions of a given model on a query/reference pair.

    A prediction is counted as correct when it equals the true label, or when
    the true label does not exist in the reference and the model predicted
    "novel".

    Parameters
    ----------
    q: AnnData
        Query AnnData object.
    ref: AnnData
        Reference AnnData object.
    ref_key: str
        .obs key for labels in reference dataset.
    q_truth_key: str
        .obs key for true labels in query dataset.
    q_pred_key: str
        .obs key for predicted labels in query dataset.

    Returns
    -------
    acc: float
        Fraction of correctly annotated query cells, in [0, 1]
        (0.0 when the query has no cells).
    cfmatrix: List[List[int]]
        The confusion matrix [true query labels] x [reference labels + novel]
    q_labels: List[str]
        The query labels (row/y axis in confusion matrix)
    ref_labels: List[str]
        Reference labels + novel (column/x axis in confusion matrix)
    """
    novel = "novel"

    q_labels = sorted(q.obs[q_truth_key].unique().tolist())
    ref_labels = sorted(ref.obs[ref_key].unique().tolist()) + [novel]

    # query labels with no counterpart among the reference labels
    novel_labels = set(q_labels) - set(ref_labels)

    truth = q.obs[q_truth_key]
    preds = q.obs[q_pred_key]

    # one boolean mask per column, computed once instead of once per row
    pred_masks = [(preds == ref_label).to_numpy() for ref_label in ref_labels]

    cfmatrix = np.zeros((len(q_labels), len(ref_labels)), dtype=int)
    correct = 0

    for i, q_label in enumerate(q_labels):
        true_mask = (truth == q_label).to_numpy()

        for j, ref_label in enumerate(ref_labels):
            count = int(np.count_nonzero(true_mask & pred_masks[j]))
            cfmatrix[i, j] = count

            # the novel column is named "novel" (see ref_labels above), so
            # that is the value a correct novel prediction has to match
            if q_label == ref_label or (q_label in novel_labels and ref_label == novel):
                correct += count

    n_cells = len(q.obs)
    acc = correct / n_cells if n_cells else 0.0
    return acc, cfmatrix.tolist(), q_labels, ref_labels
    
def plot_cfmatrix(cfmatrix: List[List[int]], q_labels: List[str], ref_labels: List[str], 
                  width=750, height=750, show_nums: bool = False, angle_ticks: bool = True) -> None:
    """
    Displays a confusion matrix (as returned by `eval_preds`) as a heatmap.

    Parameters
    ----------
    cfmatrix: List[List[int]]
        Counts, indexed as [query label][reference label].
    q_labels: List[str]
        Row (y axis) labels.
    ref_labels: List[str]
        Column (x axis) labels.
    width, height:
        Figure size passed to the plotly layout.
    show_nums: bool
        Whether to write each count on top of its cell.
    angle_ticks: bool
        x tick labels are drawn at -45 degrees when True, -90 otherwise.
    """
    heatmap = px.imshow(cfmatrix, x=ref_labels, y=q_labels, color_continuous_scale='Blues')

    if show_nums:
        n_rows, n_cols = len(q_labels), len(ref_labels)
        for row in range(n_rows):
            for col in range(n_cols):
                heatmap.add_annotation(
                    x=col,
                    y=row,
                    text=f"{cfmatrix[row][col]}",
                    showarrow=False,
                )
        heatmap.update_annotations(font=dict(color='white'))

    x_tick_angle = -45 if angle_ticks else -90

    heatmap.update_layout(width=width, height=height)
    heatmap.update_xaxes(dtick=1, tickangle=x_tick_angle)
    heatmap.update_yaxes(dtick=1)

    heatmap.show()

All currently tested datasets, and their associated .obs clustering key:

In [4]:
# Maps each dataset name to the .obs key that holds its cluster labels.
# Only the uncommented entries are active; the commented-out entries are
# datasets that are currently disabled.
DSS = {
    # # celltypist datasets
    "Blood": "cell_type",
    "Bone_marrow": "cell_type",
    # "Heart": "cell_type",
    # # "Hippocampus": "cell_type",
    # "Intestine": "cell_type",
    # "Kidney": "cell_type",
    # "Liver": "cell_type",
    "Lung": "cell_type",
    "Lymph_node": "cell_type",
    "Pancreas": "cell_type",
    "Skeletal_muscle": "cell_type",
    "Spleen": "cell_type",
    
    # # # gut atlas
    # "gut_atlas_tcell": "annotation",
    
    # Allen-Brain datasets
    # "ALM": "labels34",
    # "MTG": "labels34",
    # "VISp": "labels34",
    
    # pbmc datasets
    # "pbmc_10Xv2": "labels",
    # "pbmc_10Xv3": "labels",
    # "pbmc_CEL-Seq": "labels",
    # "pbmc_Drop-Seq": "labels",
    # "pbmc_inDrop": "labels",
    # "pbmc_Seq-Well": "labels",
    # "pbmc_Smart-Seq2": "labels",
    
    # pancreas datasets
    # "pancreas_celseq": "celltype",
    # "pancreas_celseq2": "celltype",
    # "pancreas_fluidigmc1": "celltype",
    # "pancreas_indrop1": "celltype",
    # "pancreas_indrop2": "celltype",
    # "pancreas_indrop3": "celltype",
    # "pancreas_indrop4": "celltype",
    # "pancreas_smarter": "celltype",
    # "pancreas_smartseq2": "celltype",
    
    # # LGN datasets
    # "LGN_human_intron": "cluster_label",
    # "LGN_human_exon": "cluster_label",
    # "LGN_macaque_intron": "cluster_label",
    # "LGN_macaque_exon": "cluster_label",
    # "LGN_mouse_intron": "cluster_label",
    # "LGN_mouse_exon": "cluster_label",
}

# Celltypist

We follow the workflow suggested in the [CellTypist package documentation](https://pypi.org/project/celltypist/) and in the [CellTypist custom-model tutorial notebook](https://colab.research.google.com/github/Teichlab/celltypist/blob/main/docs/notebook/celltypist_tutorial_cv.ipynb#scrollTo=assisted-earthquake).

In [5]:
class CellTypist(BenchModel):
    """
    CellTypist benchmark wrapper.

    A CellTypist model is trained on the reference dataset and used to
    annotate the query; predictions are written to `query.obs[self.id_]`.

    Parameters
    ----------
    n_jobs: int
        Number of parallel jobs passed to `celltypist.train`.
    max_iter: int
        Maximum number of training iterations passed to `celltypist.train`.
    """

    def __init__(self, n_jobs: int = 10, max_iter: int = 100) -> None:
        self.id_ = 'celltypist'
        self.n_jobs = n_jobs
        self.max_iter = max_iter
        super().__init__()

    def fit(self, ref: AnnData, label_key: str) -> None:
        """
        Stores the reference and its label key.

        Training is deferred to `predict`, because the model must be trained
        on genes available to both query and reference.

        Parameters
        ----------
        ref: AnnData
            The reference dataset with raw counts in .X
        label_key: str
            .obs key for reference labels.
        """
        self.ref = ref
        self.label_key = label_key

    def predict(self, query: AnnData) -> None:
        """
        Trains CellTypist on the stored reference and annotates the query,
        writing the predicted labels to `query.obs[self.id_]`.

        The .X matrices of both datasets are normalized only for the duration
        of this call; the original matrices are put back afterwards, even if
        training or annotation raises.

        Parameters
        ----------
        query: AnnData
            The query dataset with raw counts in .X
        """
        self.q: AnnData = query

        # keep handles on the raw matrices so they can be restored exactly,
        # rather than approximately by inverting the normalization
        q_raw = self.q.X
        ref_raw = self.ref.X

        try:
            # sum & log1p normalization as per official documentation
            q_sums = q_raw.sum(axis=1).reshape((-1, 1))
            ref_sums = ref_raw.sum(axis=1).reshape((-1, 1))

            self.q.X = np.log1p(1e4 * q_raw / q_sums)
            self.ref.X = np.log1p(1e4 * ref_raw / ref_sums)

            # skip downsampling since dataset sizes are <1e6
            # use all intersecting genes for higher accuracy
            gs = np.intersect1d(self.q.var_names, self.ref.var_names)

            # model training
            t_start = time.perf_counter()
            self.model = celltypist.train(
                self.ref[:, gs],
                self.label_key,
                check_expression=False,
                n_jobs=self.n_jobs,
                max_iter=self.max_iter,
            )
            t_elapsed = time.perf_counter() - t_start
            print(f"[*] Training completed: {t_elapsed/60:.1f}mins")

            # model prediction
            preds = celltypist.annotate(self.q, self.model)
            self.q.obs[self.id_] = preds.predicted_labels['predicted_labels']
        finally:
            # cleanup so query and ref remain unchanged between executions
            self.q.X = q_raw
            self.ref.X = ref_raw
        

In [16]:
# load datasets
q = sc.read_h5ad('../data/ALM.h5ad')
ref = sc.read_h5ad('../data/MTG.h5ad')

# run the method
m = CellTypist()
m.fit(ref, 'labels34')
m.predict(q)

# evaluate the predictions
acc, cfmatrix, q_labels, ref_labels = eval_preds(q, ref, 'labels34', 'labels34', 'celltypist')
print(f"[*] accuracy: {acc:.2f}")
plot_cfmatrix(cfmatrix, q_labels, ref_labels)

🍳 Preparing data before training
✂️ 2 non-expressed genes are filtered out
🔬 Input data has 14055 cells and 16022 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
🔬 Input data has 8128 cells and 16024 genes
🔗 Matching reference genes in the model


[*] Training completed: 1.0mins


🧬 16022 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!


[*] accuracy: 0.06


# SCANVI

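Specifically, this follows the scArches-style workflow: an SCVI model is trained on the reference first and then converted into a SCANVI model.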


In [5]:
import os
import tempfile

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
# import scrublet as scr
import scvi
import seaborn as sns
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Fix the scvi-tools seed and record the library version used for this run.
scvi.settings.seed = 0
print(f"Last run with scvi-tools version: {scvi.__version__}")

INFO: Seed set to 0
[lightning.fabric.utilities.seed] [INFO    ] : Seed set to 0


Last run with scvi-tools version: 1.1.6.post2


In [7]:
# Plotting and numerical configuration for the SCANVI section.
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
# Temporary directory handle; not used in the cells visible below
# (presumably intended for saving models -- TODO confirm or remove).
save_dir = tempfile.TemporaryDirectory()

# Inline figure rendering options: white figure background, retina format.
%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

[matplotlib.pyplot] [DEBUG   ] : Loaded backend module://matplotlib_inline.backend_inline version unknown.


In [8]:
# load data
q = sc.read_h5ad('../data/pancreas_celseq.h5ad')
ref = sc.read_h5ad('../data/pancreas_inDrop1.h5ad')

# log-normalize as needed for HVG
q.X = np.log1p(q.X)
ref.X = np.log1p(ref.X)

# subsample hvg genes
sc.pp.highly_variable_genes(ref, n_top_genes=2000, subset=True)
q = q[:, ref.var_names].copy()

# return to normal counts for SCVI
q.X = np.expm1(q.X)
ref.X = np.expm1(ref.X)

[h5py._conv      ] [DEBUG   ] : Creating converter from 3 to 5


In [9]:
# number of distinct cell-type labels in (query, reference)
(len(q.obs["celltype"].unique()), len(ref.obs["celltype"].unique()))

(13, 14)

In [None]:
# Architecture settings for the SCVI reference model, taken from the
# scvi-tools scArches tutorial.
# NOTE(review): `arches_params` was not defined anywhere in the visible
# notebook -- confirm these are the intended values.
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)

# The reference must be registered with scvi-tools before a model can be
# constructed on it.
scvi.model.SCVI.setup_anndata(ref)

scvi_model = scvi.model.SCVI(ref, **arches_params)
scvi_model.train()

# SCANVI is not imported at notebook level, so it is accessed via scvi.model.
# The SCVI model was set up without labels, hence labels_key is passed here.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    unlabeled_category="Unknown",
    labels_key="celltype",
)
scanvi_model.train()