Model Inference With scTab

Adapted from https://github.com/theislab/scTab/blob/devel/notebooks-tutorials/model_inference.ipynb

In [None]:
from os.path import join
import anndata
import numpy as np
import pandas as pd
import os
import scanpy as sc
import torch
from collections import OrderedDict
import yaml
from tqdm.auto import tqdm

In [None]:
from cellnet.utils.data_loading import dataloader_factory, streamline_count_matrix
from cellnet.tabnet.tab_network import TabNet

In [None]:
from self_supervision.paths import DATA_DIR

## 1. Load datasets

The original tutorial runs inference on a single example dataset.

Here, instead, we load the HLCA, Tabula Sapiens, and PBMC test datasets used in the SSL study.

In [None]:
# Root directory holding the test datasets and the scTab checkpoint / hparams files loaded below.
ckpt_dir = DATA_DIR

In [None]:
# Load the three test datasets stored under ckpt_dir.
# `join` (os.path.join) works whether or not ckpt_dir ends with a path separator and
# also accepts pathlib.Path objects, unlike plain string concatenation.
adata_hlca = anndata.read_h5ad(join(ckpt_dir, 'cellxgene_test_dataset_HLCA_adata.h5ad'))
adata_pbmc = anndata.read_h5ad(join(ckpt_dir, 'cellxgene_test_dataset_PBMC_adata.h5ad'))
adata_tabula_sapiens = anndata.read_h5ad(join(ckpt_dir, 'cellxgene_test_dataset_TabulaSapiens_adata.h5ad'))

In [None]:
# Expression matrices (model input) and integer-coded cell-type labels (ground truth).
x_hlca = adata_hlca.X
x_pbmc = adata_pbmc.X
x_tabula_sapiens = adata_tabula_sapiens.X

y_hlca = adata_hlca.obs['cell_type']
y_pbmc = adata_pbmc.obs['cell_type']
y_tabula_sapiens = adata_tabula_sapiens.obs['cell_type']

# Dataset sizes: cells = rows of X; classes = number of *distinct* labels
# (len(y) would just count the cells a second time).
print(f'HLCA: {x_hlca.shape[0]} cells, {y_hlca.nunique()} classes')
print(f'PBMC: {x_pbmc.shape[0]} cells, {y_pbmc.nunique()} classes')
print(f'Tabula Sapiens: {x_tabula_sapiens.shape[0]} cells, {y_tabula_sapiens.nunique()} classes')

In [None]:
# Build one batched PyTorch data loader per dataset (2048 cells per batch) for inference.
hlca_loader, pbmc_loader, tabula_sapiens_loader = (
    dataloader_factory(expr_matrix, batch_size=2048)
    for expr_matrix in (x_hlca, x_pbmc, x_tabula_sapiens)
)

In [None]:
def correct_labels(y_true: np.ndarray, y_pred: np.ndarray, child_matrix: np.ndarray) -> np.ndarray:
    """
    Update predictions using the cell-type hierarchy.

    If a prediction is a child node of the true label, the prediction is replaced by
    the true label (a more specific but consistent prediction counts as correct).

    E.g: Label='T cell' and prediction='CD8 positive T cell' -> update prediction to 'T cell'

    Parameters
    ----------
    y_true : array-like of int
        Integer-coded true labels; used as row indices into ``child_matrix``.
    y_pred : array-like of int
        Integer-coded predicted labels; used as column indices into ``child_matrix``.
    child_matrix : np.ndarray
        2-D matrix where a non-zero ``child_matrix[t, p]`` marks ``p`` as a child of ``t``.

    Returns
    -------
    np.ndarray
        A new array (the input ``y_pred`` is not modified) in which every prediction
        that is a child of its true label has been replaced by that true label.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` do not have the same shape.
    IndexError
        If any label code is negative (it would otherwise silently wrap around).
    """
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    if y_true_arr.shape != y_pred_arr.shape:
        # zip() would silently truncate and leave the tail uncorrected.
        raise ValueError(
            f'y_true and y_pred must have the same shape, got {y_true_arr.shape} and {y_pred_arr.shape}'
        )

    updated_predictions = y_pred_arr.copy()
    if updated_predictions.size == 0:
        return updated_predictions

    if y_true_arr.min() < 0 or y_pred_arr.min() < 0:
        raise IndexError('label codes must be non-negative indices into child_matrix')

    # One vectorized lookup per cell: is pred a child of true_label?
    # Replaces a Python loop doing a linear membership scan over each label's children.
    is_child = np.asarray(child_matrix)[y_true_arr, y_pred_arr] != 0
    updated_predictions[is_child] = y_true_arr[is_child]
    return updated_predictions

## 2. Load weights from checkpoint and initialize model

In [None]:
# Path to the trained scTab checkpoint (defined once so the path cannot drift between branches).
ckpt_path = join(
    ckpt_dir, 'scTab-checkpoints/scTab/run5/val_f1_macro_epoch=41_val_f1_macro=0.847.ckpt'
)

# Load the checkpoint. Tensors are remapped to CPU when no GPU is available;
# map_location=None keeps torch.load's default placement otherwise.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(
    ckpt_path,
    map_location=None if torch.cuda.is_available() else torch.device('cpu'),
)

# Extract the TabNet weights from the checkpoint's state_dict: keep entries whose key
# contains 'classifier.' and remove that substring so the keys match a bare TabNet module.
tabnet_weights = OrderedDict(
    (name.replace('classifier.', ''), weight)
    for name, weight in ckpt['state_dict'].items()
    if 'classifier.' in name
)

In [None]:


# Load the hparams file of the model to recover its architecture.
# NOTE(review): yaml.full_load can construct some Python-specific objects, not just plain
# data; only use it on trusted files.
with open(join(ckpt_dir, 'scTab-checkpoints/scTab/run5/hparams.yaml')) as f:
    model_params = yaml.full_load(f)


# Initialize the model with the hparams from hparams.yaml. Two keys are renamed:
# 'gene_dim' -> input_dim and 'type_dim' -> output_dim.
tabnet = TabNet(
    input_dim=model_params['gene_dim'],
    output_dim=model_params['type_dim'],
    n_d=model_params['n_d'],
    n_a=model_params['n_a'],
    n_steps=model_params['n_steps'],
    gamma=model_params['gamma'],
    n_independent=model_params['n_independent'],
    n_shared=model_params['n_shared'],
    epsilon=model_params['epsilon'],
    virtual_batch_size=model_params['virtual_batch_size'],
    momentum=model_params['momentum'],
    mask_type=model_params['mask_type'],
)

# Load the trained weights extracted from the checkpoint.
tabnet.load_state_dict(tabnet_weights)
# Set the model to inference mode (the trailing ';' suppresses the notebook repr).
tabnet.eval();

## 3. Run model inference

In [None]:
def predict_cell_types(model, loader) -> np.ndarray:
    """Run ``model`` over every batch of ``loader`` and return the argmax class index per cell.

    Assumes each batch exposes its input as ``batch[0]['X']`` and that ``model(x)``
    returns a ``(logits, <extra>)`` pair -- TODO confirm for other loaders/models.
    Returns an empty int64 array when the loader yields no batches
    (``np.hstack`` raises on an empty list).
    """
    batch_preds = []
    with torch.no_grad():
        for batch in tqdm(loader):
            # No normalization is applied in this cell; the batch input is used as provided.
            x_input = batch[0]['X']
            logits, _ = model(x_input)
            # .cpu() so .numpy() also works if the model/batch lives on a GPU.
            batch_preds.append(torch.argmax(logits, dim=1).cpu().numpy())
    if not batch_preds:
        return np.empty(0, dtype=np.int64)
    return np.hstack(batch_preds)


preds_hlca = predict_cell_types(tabnet, hlca_loader)
preds_pbmc = predict_cell_types(tabnet, pbmc_loader)
preds_tabula_sapiens = predict_cell_types(tabnet, tabula_sapiens_loader)

In [None]:
# The model outputs integer class indices; these lookup tables map each index back to a
# cell-type name (column 'label'). The child matrix encodes the cell-type hierarchy
# used by correct_labels.
# Paths are built with os.path.join throughout (no hand-written '/' separators).
base_path = os.path.join(DATA_DIR, 'merlin_cxg_2023_05_15_sf-log1p')
sctab_path = os.path.join(DATA_DIR, 'merlin_cxg_2023_05_15_sf-log1p_minimal')
cell_type_mapping_ssl = pd.read_parquet(os.path.join(base_path, 'categorical_lookup', 'cell_type.parquet'))
cell_type_mapping_sctab = pd.read_parquet(os.path.join(sctab_path, 'categorical_lookup', 'cell_type.parquet'))
cell_type_hierarchy = np.load(os.path.join(base_path, 'cell_type_hierarchy', 'child_matrix.npy'))

In [None]:
# Hierarchy-aware correction of the integer predictions for each dataset.
# NOTE(review): true codes are decoded with cell_type_mapping_sctab but predictions with
# cell_type_mapping_ssl, while correct_labels compares the raw codes against one hierarchy
# matrix. This is only sound if both lookup tables share the same index -> label order.
# TODO confirm.
y_pred_corr_hlca, y_pred_corr_pbmc, y_pred_corr_tabula_sapiens = (
    correct_labels(true_codes, pred_codes, cell_type_hierarchy)
    for true_codes, pred_codes in (
        (y_hlca, preds_hlca),
        (y_pbmc, preds_pbmc),
        (y_tabula_sapiens, preds_tabula_sapiens),
    )
)

# Decode ground-truth codes into cell-type names.
true_hlca, true_pbmc, true_tabula_sapiens = (
    cell_type_mapping_sctab.loc[codes, 'label'].to_numpy()
    for codes in (y_hlca, y_pbmc, y_tabula_sapiens)
)

# Decode corrected prediction codes into cell-type names.
y_pred_corr_hlca_str, y_pred_corr_pbmc_str, y_pred_corr_tabula_sapiens_str = (
    cell_type_mapping_ssl.loc[codes, 'label'].to_numpy()
    for codes in (y_pred_corr_hlca, y_pred_corr_pbmc, y_pred_corr_tabula_sapiens)
)

In [None]:
from sklearn.metrics import f1_score, classification_report

# HLCA classification report on *string* labels after hierarchy correction.
# Both arguments must use the same label type: true_hlca holds cell-type names, so the
# decoded predictions (y_pred_corr_hlca_str) are used, not the integer codes.
# `labels` restricts the per-class rows to cell types present in the ground truth.
clf_report = pd.DataFrame(classification_report(
    true_hlca,
    y_pred_corr_hlca_str,
    labels=np.unique(true_hlca),
    output_dict=True
)).T
# The last three rows are the summary rows (micro avg/accuracy, macro avg, weighted avg);
# everything before them is per-class.
clf_report_overall = clf_report.iloc[-3:].copy()
clf_report_per_class = clf_report.iloc[:-3].copy()
clf_report_overall

In [None]:
from sklearn.metrics import f1_score

# HLCA baseline: F1 on raw integer codes *before* hierarchy correction and without the
# `labels` restriction used in the Evaluate section.
# The results get their own names so they are not silently overwritten by the corrected
# scores computed later. This cell expects `preds_hlca` to still hold integer class indices.
# NOTE(review): y_hlca codes come from the scTab lookup and preds from the SSL lookup;
# comparing them directly assumes both share the same index order. TODO confirm.
micro_f1_hlca_uncorrected = f1_score(y_hlca, preds_hlca, average='micro')
macro_f1_hlca_uncorrected = f1_score(y_hlca, preds_hlca, average='macro')
micro_f1_hlca_uncorrected, macro_f1_hlca_uncorrected

In [None]:
# Decode the raw (uncorrected) integer predictions into cell-type names.
# Results go into new *_str variables instead of overwriting preds_*. This keeps the
# cell idempotent (re-running it would otherwise look up strings in an integer index)
# and keeps the integer codes intact for cells that use them.
preds_hlca_str = cell_type_mapping_ssl.loc[preds_hlca, 'label'].to_numpy()
preds_pbmc_str = cell_type_mapping_ssl.loc[preds_pbmc, 'label'].to_numpy()
preds_tabula_sapiens_str = cell_type_mapping_ssl.loc[preds_tabula_sapiens, 'label'].to_numpy()

preds_hlca_str

In [None]:
# Decode ground-truth codes into cell-type names (same values as computed right after
# the correct_labels calls; recomputed here so the Evaluate section is self-contained).
true_hlca, true_pbmc, true_tabula_sapiens = (
    cell_type_mapping_sctab.loc[label_codes, 'label'].to_numpy()
    for label_codes in (y_hlca, y_pbmc, y_tabula_sapiens)
)

## 4. Evaluate

In [None]:
from sklearn.metrics import f1_score


def _micro_macro_f1(y_true_str, y_pred_str):
    """Return (micro F1, macro F1), restricted to the labels that occur in ``y_true_str``."""
    observed_labels = np.unique(y_true_str)
    micro = f1_score(y_true_str, y_pred_str, average='micro', labels=observed_labels)
    macro = f1_score(y_true_str, y_pred_str, average='macro', labels=observed_labels)
    return micro, macro


# Corrected-prediction F1 scores for each dataset.
micro_f1_hlca, macro_f1_hlca = _micro_macro_f1(true_hlca, y_pred_corr_hlca_str)
micro_f1_pbmc, macro_f1_pbmc = _micro_macro_f1(true_pbmc, y_pred_corr_pbmc_str)
micro_f1_tabula_sapiens, macro_f1_tabula_sapiens = _micro_macro_f1(
    true_tabula_sapiens, y_pred_corr_tabula_sapiens_str
)

# Print the results
print(f'HLCA - Micro F1: {micro_f1_hlca}, Macro F1: {macro_f1_hlca}')
print(f'PBMC - Micro F1: {micro_f1_pbmc}, Macro F1: {macro_f1_pbmc}')
print(f'Tabula Sapiens - Micro F1: {micro_f1_tabula_sapiens}, Macro F1: {macro_f1_tabula_sapiens}')