# Model inference example notebook

In [1]:
from os.path import join

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import torch

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Load example data set

In [None]:
# Download example data
!wget https://datasets.cellxgene.cziscience.com/f8f41e86-e9ed-4de7-a155-836b2f243fd0.h5ad

In [2]:
# read the example data set downloaded above
adata = anndata.read_h5ad('f8f41e86-e9ed-4de7-a155-836b2f243fd0.h5ad')
# subsample to 5000 cells to make inference run faster; random_state=0 is scanpy's
# default, spelled out here to make explicit that the subset is reproducible
sc.pp.subsample(adata, n_obs=5000, random_state=0)
adata

AnnData object with n_obs × n_vars = 5000 × 36263
    obs: 'nCount_RNA', 'nFeature_RNA', 'nCount_HTO', 'nFeature_HTO', 'HTO_maxID', 'HTO_secondID', 'HTO_margin', 'HTO_classification.global', 'sample', 'donor_id', 'CHIP', 'LANE', 'ProjectID', 'MUTATION', 'MUTATION.GROUP', 'sex_ontology_term_id', 'HTOID', 'percent.mt', 'nCount_SCT', 'nFeature_SCT', 'scType_celltype', 'pANN', 'development_stage_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'suspension_type', 'is_primary_data', 'tissue_type', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'disease_ontology_term_id', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length'
    uns: 'citation', 'default_embedding', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_umap'

In [3]:
# use raw counts: the model's normalization (sf_log1p_norm, applied during inference)
# scales each cell to 10000 counts itself, so it must be fed unnormalized counts
if adata.raw is None:
    raise ValueError('expected raw counts in adata.raw, but adata.raw is None')
adata.X = adata.raw.X

## Inference steps

To run this notebook you first need to download all relevant data:
* Minimal store: https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz
    * Includes `var.parquet` file
    * Includes `cell_type.parquet` file under `categorical_lookup/cell_type.parquet`
* scTab checkpoints: https://pklab.med.harvard.edu/felix/data/scTab-checkpoints.tar.gz
    * Includes all the checkpoint files

In [None]:
# Download data
!wget https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz
!wget https://pklab.med.harvard.edu/felix/data/scTab-checkpoints.tar.gz

In [None]:
# extract data
!tar -xzvf merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz
!tar -xzvf scTab-checkpoints.tar.gz

### 1. Preprocess data for model inference

Data preparation involves the following steps:

1. Streamline feature space (select genes and arrange them in the same order the model was fitted on).
    * Currently, the model uses Ensembl version 104 - that's the version CELLxGENE uses
    * If your data uses the same Ensembl release, we can simply match genes by name (string matching) and rearrange the gene matrix accordingly
    * If this step is skipped, the code below does not give the right results
    * If your data uses a different Ensembl version, the output of the code below might not be correct
2. Wrap data set into PyTorch data loader &rarr; use this to feed data into model

#### 1. Streamline feature space

In [4]:
# inspect the query genes: the index holds Ensembl IDs, `feature_name` holds gene names
adata.var.head()

Unnamed: 0,feature_is_filtered,feature_name,feature_reference,feature_biotype,feature_length
ENSG00000243485,False,MIR1302-2HG,NCBITaxon:9606,gene,1021
ENSG00000237613,False,FAM138A,NCBITaxon:9606,gene,1219
ENSG00000186092,False,OR4F5,NCBITaxon:9606,gene,2618
ENSG00000238009,False,ENSG00000238009.6,NCBITaxon:9606,gene,3726
ENSG00000239945,False,ENSG00000239945.1,NCBITaxon:9606,gene,1319


In [5]:
# gene order the model was trained on (`feature_id` = Ensembl ID, `feature_name` = gene name)
model_var_path = 'merlin_cxg_2023_05_15_sf-log1p_minimal/var.parquet'
genes_from_model = pd.read_parquet(model_var_path)
genes_from_model.head()

Unnamed: 0,feature_id,feature_name
0,ENSG00000186092,OR4F5
1,ENSG00000284733,OR4F29
2,ENSG00000284662,OR4F16
3,ENSG00000187634,SAMD11
4,ENSG00000188976,NOC2L


In [7]:
from scipy.sparse import csc_matrix
from cellnet.utils.data_loading import streamline_count_matrix

# subset gene space only to genes used by the model
adata = adata[:, adata.var.feature_name.isin(genes_from_model.feature_name).to_numpy()]
# pass the count matrix in csc_matrix to make column slicing efficient
x_streamlined = streamline_count_matrix(
    csc_matrix(adata.X),
    adata.var.feature_name,  # change this if gene names are stored in different column
    genes_from_model.feature_name
)
# sanity check: one row per cell and one column per gene of the model's gene list;
# a mismatch here means the matrix cannot be fed to the model
assert x_streamlined.shape == (adata.n_obs, len(genes_from_model)), (
    f'unexpected streamlined shape {x_streamlined.shape}; '
    f'expected ({adata.n_obs}, {len(genes_from_model)})'
)
x_streamlined.shape

(5000, 19331)

#### 2. Wrap into PyTorch data loader

In [8]:
from cellnet.utils.data_loading import dataloader_factory

# number of cells per batch fed through the model during inference
batch_size = 2048
# wrap the streamlined count matrix into a PyTorch data loader for batched inference
loader = dataloader_factory(x_streamlined, batch_size=batch_size)

### 2. Load weights from checkpoint and initialize model

In [9]:
from collections import OrderedDict
import yaml

In [10]:
# load checkpoint
# The TabNet model in this notebook is created on the CPU and never moved to a GPU
# (predictions are converted with .numpy()), so the weights are always mapped to the
# CPU. This also works on machines without a GPU.
# NOTE(review): newer PyTorch releases default to torch.load(weights_only=True); if
# loading this (trusted) checkpoint fails with an unpickling error, pass weights_only=False.
ckpt = torch.load(
    'scTab-checkpoints/scTab/run5/val_f1_macro_epoch=41_val_f1_macro=0.847.ckpt',
    map_location=torch.device('cpu'),
)

# extract the state_dict of the tabnet model from the checkpoint: keep only entries
# whose key starts with 'classifier.' and strip exactly that leading prefix
# (a substring match + str.replace could also hit/alter 'classifier.' elsewhere in a key)
prefix = 'classifier.'
tabnet_weights = OrderedDict(
    (name[len(prefix):], weight)
    for name, weight in ckpt['state_dict'].items()
    if name.startswith(prefix)
)


In [12]:
from cellnet.tabnet.tab_network import TabNet


# read the model architecture from the hparams file that comes with the checkpoint
with open('scTab-checkpoints/scTab/run5/hparams.yaml') as hparams_file:
    model_params = yaml.full_load(hparams_file.read())

# input/output sizes are stored under different names in hparams.yaml ...
tabnet_kwargs = {
    'input_dim': model_params['gene_dim'],
    'output_dim': model_params['type_dim'],
}
# ... every remaining TabNet argument has the same name as its hparams key
for key in ('n_d', 'n_a', 'n_steps', 'gamma', 'n_independent', 'n_shared',
            'epsilon', 'virtual_batch_size', 'momentum', 'mask_type'):
    tabnet_kwargs[key] = model_params[key]

# initialize model with hparams from hparams.yaml file
tabnet = TabNet(**tabnet_kwargs)

# load trained weights
tabnet.load_state_dict(tabnet_weights)
# set model to inference mode
tabnet.eval();

### 3. Run model inference

In [13]:
def sf_log1p_norm(x):
    """Scale each cell (row of ``x``) to a total of 10000 counts, then apply log(x+1).

    Rows whose counts sum to zero are divided by one instead of zero, so they stay all-zero.
    """
    row_totals = x.sum(dim=1, keepdim=True)
    # turn zero totals into one so empty cells don't produce NaN/inf
    row_totals = row_totals + (row_totals == 0.)
    return (x * (10000. / row_totals)).log1p()


In [14]:
preds = []
# feed the inputs to whatever device the model's weights are on, so the loop keeps
# working if `tabnet` is moved to a GPU
device = next(tabnet.parameters()).device

with torch.no_grad():
    for batch in tqdm(loader):
        # normalize the raw counts of this batch
        x_input = sf_log1p_norm(batch[0]['X'].to(device))
        logits, _ = tabnet(x_input)
        # .cpu() is needed before .numpy() when the model runs on a GPU (no-op on CPU)
        preds.append(torch.argmax(logits, dim=1).cpu().numpy())

preds = np.hstack(preds)

100%|██████████| 3/3 [00:08<00:00,  2.91s/it]


In [15]:
# integer class ids (argmax over the model's logits), one per cell
preds

array([127,  54, 125, ..., 106,  54,  22])

In [16]:
# the model predicts integer class ids; this lookup table maps every id back to
# its cell type label so the predictions can be translated
cell_type_lookup_path = 'merlin_cxg_2023_05_15_sf-log1p_minimal/categorical_lookup/cell_type.parquet'
cell_type_mapping = pd.read_parquet(cell_type_lookup_path)

In [17]:
# translate integer class ids into cell type labels; store them under a new name so
# `preds` keeps the integer ids and this cell can be re-run without failing
pred_labels = cell_type_mapping.loc[preds, 'label'].to_numpy()
pred_labels

array(['natural killer cell', 'dendritic cell',
       'naive thymus-derived CD8-positive, alpha-beta T cell', ...,
       'macrophage', 'dendritic cell', 'T cell'], dtype=object)

In [18]:
# one prediction per cell of the subsampled data set
preds.shape

(5000,)