In [None]:
!pip install anndata
!pip install scanpy
!pip install scikit-learn==1.2.1
!pip install tqdm

In [1]:
import os
import pickle
from os.path import join

import anndata
import numpy as np
import pandas as pd
import torch
import scanpy as sc

from scipy.sparse import csr_matrix
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory holding the example data set and the model checkpoint.
# Set the MODEL_INFERENCE_DIR environment variable to point elsewhere instead of editing this cell.
BASE_PATH = os.environ.get('MODEL_INFERENCE_DIR', '/home/felixfischer/model_inference')

# Load example data set

In [3]:
# Number of cells to keep (subsampling makes inference run faster); set to None to use every cell.
N_OBS = 5000
# Seed for the subsample, so re-running the notebook selects the same cells.
SUBSAMPLE_SEED = 0

adata = anndata.read_h5ad(join(BASE_PATH, 'local.h5ad'))
# sc.pp.subsample modifies adata in place. Only call it when there is something to drop:
# asking for more cells than the data set has is an error.
if N_OBS is not None and adata.n_obs > N_OBS:
    sc.pp.subsample(adata, n_obs=N_OBS, random_state=SUBSAMPLE_SEED)
adata

AnnData object with n_obs × n_vars = 5000 × 24185
    obs: 'donor_id', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'seurat_clusters', 'celltype', 'organism_ontology_term_id', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage'
    var: 'name', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype'
    uns: 'cell_type_ontology_term_id_colors', 'schema_version', 'seurat_clusters_colors', 'title'
    obsm: 'X_pca', 'X_umap'

# Inference steps

## 1. Preprocess data for model inference

Data preparation involves the following steps:

1. Streamline the feature space (select the genes and arrange them in the same order the model was fitted on).
    * We should agree on a specific Ensembl release
        * Currently, I use release 104, which is the release CELLxGENE uses
        * If we use the same Ensembl release, we can simply match gene IDs as strings and rearrange the matrix
        * If this is not done, the code below does not give correct results
2. Normalize the data:
    * Size factor normalization: normalize each cell to a total of 10,000 counts
    * Quantile normalization: apply the quantile normalization that was already fitted during training
    * Zero centering: subtract each gene's mean (the per-gene means were already computed during training). This is done batch-wise in step 3, because centering the whole matrix up front would destroy its sparsity and greatly increase the memory footprint.
3. Wrap the data set into a PyTorch data loader and use it to feed the data into the model.

#### 1. Streamline feature space

In [4]:
# Peek at the first rows of the gene annotation (index = Ensembl gene IDs).
adata.var.head(n=5)

Unnamed: 0_level_0,name,feature_is_filtered,feature_name,feature_reference,feature_biotype
gene_ids,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ENSG00000238009,RP11-34P13.7,False,RP11-34P13.7,NCBITaxon:9606,gene
ENSG00000279457,FO538757.2,False,WASH9P,NCBITaxon:9606,gene
ENSG00000228463,AP006222.2,False,AP006222.1,NCBITaxon:9606,gene
ENSG00000237094,RP4-669L17.10,False,RP4-669L17.4,NCBITaxon:9606,gene
ENSG00000230021,RP5-857K21.4,False,RP11-206L10.17,NCBITaxon:9606,gene


In [5]:
# Load the genes the model was trained on, in training order (index = Ensembl gene IDs).
# The path must be relative: os.path.join discards BASE_PATH when the second argument is absolute,
# so an absolute path here would ignore BASE_PATH entirely.
genes_from_model = pd.read_parquet(join(BASE_PATH, 'checkpoint/var.parquet'))
genes_from_model.head()

Unnamed: 0,gene_names
ENSG00000000003,TSPAN6
ENSG00000000005,TNMD
ENSG00000000419,DPM1
ENSG00000000457,SCYL3
ENSG00000000460,C1orf112


In [6]:
from cellnet.utils.data_loading import streamline_count_matrix


# The count matrix fed to the model is taken from adata.raw.
assert adata.raw is not None, 'adata.raw is missing; expected the count matrix in adata.raw.X'
# Label the columns of adata.raw.X with adata.raw's own gene IDs. adata.var can be a filtered
# subset of adata.raw.var, in which case adata.var.index would not line up with adata.raw.X.
x_streamlined = streamline_count_matrix(adata.raw.X, adata.raw.var_names, genes_from_model.index)
# every cell, with one column per model gene in the model's gene order
assert x_streamlined.shape == (adata.n_obs, len(genes_from_model)), 'unexpected streamlined shape'
x_streamlined.shape

(5000, 19357)

#### 2. Normalize data

In [7]:
# sf_normalize is not called directly here, but it must be importable:
# the pickled preprocessing pipeline needs it to be unpickled.
from cellnet.utils.data_loading import sf_normalize  # noqa: F401

# The fitted preprocessing pipeline is persisted with pickle
# (https://scikit-learn.org/stable/model_persistence.html) and needs scikit-learn 1.2.1 to load.
# TODO: agree on the persistence format; pickle is the current choice.
# NOTE(review): pickle.load can execute arbitrary code -- only load checkpoints from a trusted source.
with open(join(BASE_PATH, 'checkpoint/norm/preproc_pipeline/preproc_pipeline.pickle'), 'rb') as f:
    preproc_pipeline = pickle.load(f)
preproc_pipeline

In [8]:
# Apply the fitted preprocessing pipeline (size-factor + quantile normalization) and store the
# result as a CSR sparse matrix.
# NOTE: this overwrites x_streamlined, so re-running this cell on its own normalizes twice;
# re-run the streamlining cell first.
x_streamlined = preproc_pipeline.transform(x_streamlined)
x_streamlined = csr_matrix(x_streamlined)
x_streamlined

<5000x19357 sparse matrix of type '<class 'numpy.float32'>'
	with 9300693 stored elements in Compressed Sparse Row format>

In [9]:
# Per-gene means for zero centering, saved as a numpy array during training.
# Centering is applied batch-wise during model inference: subtracting the means here would
# break the sparsity structure of the matrix (much higher memory usage).
feature_means = np.load(join(BASE_PATH, 'checkpoint/norm/zero_centering/means.npy'))
# One mean per model gene; catches means taken from a checkpoint with a different gene space.
assert feature_means.shape[-1] == x_streamlined.shape[1], 'feature_means does not match the number of genes'
feature_means.shape

(1, 19357)

#### 3. Wrap into PyTorch data loader

In [10]:
from cellnet.utils.data_loading import dataloader_factory

# Number of cells per inference batch.
BATCH_SIZE = 2048

# PyTorch data loader over the normalized (not yet zero-centered) matrix, used for batched inference.
loader = dataloader_factory(x_streamlined, batch_size=BATCH_SIZE)

## 2. Load weights from checkpoint and initialize model

In [11]:
from collections import OrderedDict
import yaml

In [12]:
# Load the training checkpoint onto the CPU.
ckpt = torch.load(join(BASE_PATH, 'checkpoint/model.ckpt'), map_location=torch.device('cpu'))

# The TabNet weights live under the 'classifier.' prefix of the checkpoint's state_dict.
# Keep only those entries and strip the prefix so the keys match a bare TabNet module.
# Only a leading prefix is stripped; a 'classifier.' substring elsewhere in a key is left alone.
# TODO: ship a checkpoint that already contains only the TabNet state_dict.
WEIGHT_PREFIX = 'classifier.'
tabnet_weights = OrderedDict(
    (name[len(WEIGHT_PREFIX):], weight)
    for name, weight in ckpt['state_dict'].items()
    if name.startswith(WEIGHT_PREFIX)
)


In [13]:
from cellnet.tabnet.tab_network import TabNet


# Read the model architecture from the hparams file stored next to the checkpoint.
# NOTE(review): yaml.full_load can construct some Python objects -- only use trusted hparams files.
with open(join(BASE_PATH, 'checkpoint/hparams.yaml')) as f:
    model_params = yaml.full_load(f.read())


# TabNet constructor argument -> key in hparams.yaml
TABNET_HPARAM_KEYS = {
    'input_dim': 'gene_dim',
    'output_dim': 'type_dim',
    'n_d': 'n_d',
    'n_a': 'n_a',
    'n_steps': 'n_steps',
    'gamma': 'gamma',
    'n_independent': 'n_independent',
    'n_shared': 'n_shared',
    'epsilon': 'epsilon',
    'virtual_batch_size': 'virtual_batch_size',
    'momentum': 'momentum',
    'mask_type': 'mask_type',
}

# Initialize the model with the architecture from hparams.yaml.
tabnet = TabNet(**{arg: model_params[key] for arg, key in TABNET_HPARAM_KEYS.items()})

# Load the trained weights, then switch to inference mode (the trailing ';' hides the module repr).
tabnet.load_state_dict(tabnet_weights)
tabnet.eval();

In [14]:
# Convert the per-gene means to a float32 tensor for subtraction inside the inference loop.
# float32 matches the normalized matrix; a float64 array would otherwise promote the centered
# batch to float64 -- assumes the loader yields float32 batches, TODO confirm.
# torch.as_tensor also accepts a tensor, so re-running this cell does not warn or copy.
feature_means = torch.as_tensor(feature_means, dtype=torch.float32)

## 3. Run model inference

In [15]:
# Batched inference without gradient tracking: zero center each batch with the training means,
# run TabNet, and keep the index of the highest logit per cell.
batch_predictions = []

with torch.no_grad():
    for batch in tqdm(loader):
        centered = batch[0]['X'] - feature_means
        logits, _ = tabnet(centered)
        batch_predictions.append(logits.argmax(dim=1).numpy())


preds = np.hstack(batch_predictions)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:45<00:00, 15.15s/it]


In [16]:
# Predicted class index per cell (argmax over the model's logits).
preds

array([79, 12, 30, ..., 49, 10, 18])

In [17]:
# The model predicts integer class indices; this table maps each index to a cell type
# (its 'label' column), which lets us turn the predictions back into cell type names.
CELL_TYPE_MAPPING_PATH = join(BASE_PATH, 'checkpoint/cell_type.parquet')
cell_type_mapping = pd.read_parquet(CELL_TYPE_MAPPING_PATH)

In [18]:
# Look up the 'label' of every predicted class index in a single .loc call.
# NOTE: this rebinds `preds` from class indices to labels, so re-running this cell on its own
# would fail the lookup; re-run the inference cell first.
preds = cell_type_mapping.loc[preds, 'label'].to_numpy()
preds

array(['elicited macrophage', 'natural killer cell', 'T cell', ...,
       'macrophage', 'mast cell', 'monocyte'], dtype=object)

In [19]:
# One predicted cell type per cell.
np.shape(preds)

(5000,)