In [None]:
!pip install anndata
!pip install scanpy
!pip install tqdm

In [1]:
import os
from os.path import join

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import torch

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory that holds the example data set (`local.h5ad`) and the `checkpoint/` folder.
# Set the MODEL_INFERENCE_BASE_PATH environment variable to run the notebook on another
# machine; if it is not set, the original location is used.
BASE_PATH = os.environ.get('MODEL_INFERENCE_BASE_PATH', '/home/felixfischer/model_inference')

# Load example data set

In [3]:
# Read the example AnnData object from disk.
h5ad_file = join(BASE_PATH, 'local.h5ad')
adata = anndata.read_h5ad(h5ad_file)
# Keep a random subset of 5000 cells (`adata` is modified in place) so that inference runs faster.
sc.pp.subsample(adata, n_obs=5000)
adata

AnnData object with n_obs × n_vars = 5000 × 24185
    obs: 'donor_id', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'seurat_clusters', 'celltype', 'organism_ontology_term_id', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage'
    var: 'name', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype'
    uns: 'cell_type_ontology_term_id_colors', 'schema_version', 'seurat_clusters_colors', 'title'
    obsm: 'X_pca', 'X_umap'

# Inference steps

## 1. Preprocess data for model inference

Data preparation involves the following steps:

1. Streamline feature space (select genes + arrange them in the same order as model was fitted on). 
    * We should agree on a specific ensembl release
        * Currently, I use version 104 - that's the version cellxgene uses
        * If we have the same ensembl release we can just do string matching and rearrange the matrix
        * If this is not done, the code below does not give the right results
2. Wrap data set into pytorch data loader -> use this to feed data into model

#### 1. Streamline feature space

In [4]:
# Peek at the gene annotation of the data set: the index holds the gene ids,
# the `name` / `feature_name` columns hold gene names.
adata.var.head()

Unnamed: 0_level_0,name,feature_is_filtered,feature_name,feature_reference,feature_biotype
gene_ids,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ENSG00000238009,RP11-34P13.7,False,RP11-34P13.7,NCBITaxon:9606,gene
ENSG00000279457,FO538757.2,False,WASH9P,NCBITaxon:9606,gene
ENSG00000228463,AP006222.2,False,AP006222.1,NCBITaxon:9606,gene
ENSG00000237094,RP4-669L17.10,False,RP4-669L17.4,NCBITaxon:9606,gene
ENSG00000230021,RP5-857K21.4,False,RP11-206L10.17,NCBITaxon:9606,gene


In [5]:
# Load the gene list in the order the model was trained on.
# The file is addressed relative to BASE_PATH, like every other checkpoint file in this
# notebook: an absolute path as the second argument would make `join` drop BASE_PATH.
genes_from_model = pd.read_parquet(join(BASE_PATH, 'checkpoint/var.parquet'))
genes_from_model.head()

Unnamed: 0,feature_id,feature_name
0,ENSG00000186092,OR4F5
1,ENSG00000284733,OR4F29
2,ENSG00000284662,OR4F16
3,ENSG00000187634,SAMD11
4,ENSG00000188976,NOC2L


In [6]:
from cellnet.utils.data_loading import streamline_count_matrix


# Bring the raw count matrix into the feature space of the model (see step 1 above).
# NOTE(review): genes are matched by name (`name` column of the data set vs. `feature_name`
# of the model), and the names come from `adata.var` while the matrix comes from `adata.raw` -
# confirm that both describe the same genes in the same order.
dataset_gene_names = adata.var['name']
model_gene_names = genes_from_model['feature_name']
x_streamlined = streamline_count_matrix(adata.raw.X, dataset_gene_names, model_gene_names)
x_streamlined.shape

(5000, 19331)

#### 2. Wrap into PyTorch data loader

In [7]:
from cellnet.utils.data_loading import dataloader_factory

# Number of cells passed through the model per inference step.
INFERENCE_BATCH_SIZE = 2048
# Wrap the streamlined matrix into a PyTorch data loader for batched inference.
loader = dataloader_factory(x_streamlined, batch_size=INFERENCE_BATCH_SIZE)

## 2. Load weights from checkpoint and initialize model

In [8]:
from collections import OrderedDict
import yaml

In [9]:
# Load the checkpoint onto the CPU.
# NOTE: torch.load relies on pickle - only open checkpoint files from a trusted source.
ckpt = torch.load(join(BASE_PATH, 'checkpoint/model.ckpt'), map_location=torch.device('cpu'))

# Extract the state_dict of the tabnet model from the checkpoint.
# I can do this as well and just send you the updated checkpoint file - I think this would be the best solution
# I just put this here for completeness
#
# Only keys that *start* with the prefix are kept, and the prefix is removed exactly once;
# a substring test plus `str.replace` would also match and rewrite keys that merely
# contain 'classifier.' somewhere in the middle.
CLASSIFIER_PREFIX = 'classifier.'
tabnet_weights = OrderedDict(
    (name[len(CLASSIFIER_PREFIX):], weight)
    for name, weight in ckpt['state_dict'].items()
    if name.startswith(CLASSIFIER_PREFIX)
)


In [10]:
from cellnet.tabnet.tab_network import TabNet


# Read the hparams file stored next to the checkpoint; it describes the model architecture.
# NOTE(review): yaml.full_load can build more than plain data types - only use it on trusted files.
with open(join(BASE_PATH, 'checkpoint/hparams.yaml')) as hparams_file:
    model_params = yaml.full_load(hparams_file.read())


# Hyper-parameters that are passed to TabNet under the same name as in hparams.yaml.
shared_hparam_names = (
    'n_d', 'n_a', 'n_steps', 'gamma', 'n_independent', 'n_shared',
    'epsilon', 'virtual_batch_size', 'momentum', 'mask_type',
)
tabnet_kwargs = {key: model_params[key] for key in shared_hparam_names}

# Initialize the model with the hparams from the hparams.yaml file;
# input and output size are stored there as 'gene_dim' and 'type_dim'.
tabnet = TabNet(
    input_dim=model_params['gene_dim'],
    output_dim=model_params['type_dim'],
    **tabnet_kwargs,
)

# load the trained weights
tabnet.load_state_dict(tabnet_weights)
# switch the model to inference mode (trailing `;` hides the printed module)
tabnet.eval();

## 3. Run model inference

In [11]:
def sf_log1p_norm(x):
    """Scale every row (cell) of `x` to a total of 10000 counts, then apply log(1 + x).

    Rows whose counts sum to zero are divided by 1 instead of 0, which avoids
    a division by zero.
    """
    row_totals = x.sum(dim=1, keepdim=True)
    # rows with a total of zero: use 1 as the divisor
    row_totals = row_totals.masked_fill(row_totals == 0., 1)

    return torch.log1p((10000. / row_totals) * x)


In [12]:
batch_preds = []

# gradients are not needed for inference
with torch.no_grad():
    for batch in tqdm(loader):
        # normalize the counts of this batch before they go into the model
        features = sf_log1p_norm(batch[0]['X'])
        logits, _ = tabnet(features)
        # predicted class = position of the largest logit
        batch_preds.append(logits.argmax(dim=1).numpy())


# join the per-batch results into one flat array of integer class indices
preds = np.hstack(batch_preds)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:47<00:00, 15.70s/it]


In [13]:
# integer class indices as returned by the inference loop (not yet mapped to cell type labels)
preds

array([118, 127,  22, ..., 118, 107, 118])

In [14]:
# The model only outputs integers, each of which stands for a specific cell type.
# This table is used to translate them back into cell type labels.
cell_type_mapping_file = join(BASE_PATH, 'checkpoint/cell_type.parquet')
cell_type_mapping = pd.read_parquet(cell_type_mapping_file)

In [15]:
# Look up the 'label' entry for every predicted class index.
# NOTE(review): `preds` is overwritten with the labels here, so re-running only this cell
# will presumably fail - re-run the inference cell first.
preds = cell_type_mapping.loc[preds, 'label'].to_numpy()
preds

array(['monocyte', 'natural killer cell', 'T cell', ..., 'monocyte',
       'mast cell', 'monocyte'], dtype=object)

In [16]:
# sanity check: one predicted label per cell (5000 cells after subsampling)
preds.shape

(5000,)