In [1]:
import scHash
import anndata as ad

from sklearn.metrics import f1_score, precision_score, recall_score
from statistics import median

# Tutorial: scRNA-seq annotation on six pancreas datasets, with an interpretability demonstration
We demonstrate how scHash encodes cells from six publicly available pancreas datasets into hash codes and uses them to annotate cell types.
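For intuition about what "hash codes" means here, the sketch below binarizes continuous embedding vectors by sign and compares cells by Hamming distance. It is a generic illustration of binary hashing, not scHash's actual encoder. The embedding values and the helpers `to_hash_code` and `hamming` are hypothetical.

```python
# Generic illustration of binary hash codes (not scHash's implementation).
# Hypothetical continuous embeddings for three cells, 8 dimensions each.
embeddings = {
    "cell_a": [0.9, -0.2, 0.4, -0.7, 0.1, 0.3, -0.5, 0.8],
    "cell_b": [0.7, -0.1, 0.5, -0.6, -0.2, 0.2, -0.4, 0.9],
    "cell_c": [-0.8, 0.3, -0.6, 0.5, 0.2, -0.1, 0.6, -0.7],
}

def to_hash_code(vector):
    """Map each coordinate to 1 if it is positive, else 0 (sign binarization)."""
    return tuple(1 if x > 0 else 0 for x in vector)

def hamming(code1, code2):
    """Count the bit positions where two equal-length codes differ."""
    return sum(b1 != b2 for b1, b2 in zip(code1, code2))

codes = {name: to_hash_code(v) for name, v in embeddings.items()}
for name, code in codes.items():
    print(name, "".join(map(str, code)))

print(hamming(codes["cell_a"], codes["cell_b"]))  # -> 1 (similar cells)
print(hamming(codes["cell_a"], codes["cell_c"]))  # -> 7 (dissimilar cells)
```

Cells with similar embeddings end up with codes that differ in only a few bits, so a nearest-neighbour search reduces to cheap bit comparisons.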

The raw data for the first five datasets can be obtained from [Harmony](https://github.com/immunogenomics/harmony2019/tree/master/data/figure5).

The sixth pancreas dataset is publicly available from GEO under accession GSE83139.

For convenience, we compiled the six datasets into a single AnnData object. The processed data can be downloaded [here](https://drive.google.com/file/d/1shc4OYIbq2FwbyGUaYuzizuvzW-giSTs/view?usp=share_link).

## Load data

In [2]:
# path to the combined AnnData file
data_dir = '../../../../share_data/Pancreas_Wang/fivepancreas_wang_raw.h5ad'

# load the AnnData object, which contains all six pancreas datasets,
# and hold one dataset out as the query (test) set
query = 'indrop'
full = ad.read_h5ad(data_dir)

train = full[full.obs.dataset != query]
test = full[full.obs.dataset == query]
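The two boolean masks above implement a leave-one-dataset-out split. To evaluate each dataset as the query in turn, the same split can be generated in a loop. The sketch below uses a plain Python list in place of `full.obs.dataset`; the toy labels and the `leave_one_dataset_out` helper are hypothetical.

```python
# Toy stand-in for full.obs.dataset (hypothetical labels, one per cell).
dataset_labels = ["smartseq", "c1", "indrop", "celseq", "indrop", "wang", "celseq2"]

def leave_one_dataset_out(labels):
    """Yield (query, train_idx, test_idx) for each distinct dataset label."""
    for query in sorted(set(labels)):
        # reference cells: every dataset except the query
        train_idx = [i for i, d in enumerate(labels) if d != query]
        # query cells: only the held-out dataset
        test_idx = [i for i, d in enumerate(labels) if d == query]
        yield query, train_idx, test_idx

for query, train_idx, test_idx in leave_one_dataset_out(dataset_labels):
    print(f"{query}: {len(train_idx)} train cells, {len(test_idx)} test cells")
```

With the real data, each iteration would apply `full[full.obs.dataset != query]` and `full[full.obs.dataset == query]` as in the cell above.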

## Training Model

In [3]:
# set up the training datamodule
datamodule = scHash.setup_training_data(train_data=train, cell_type_key='cell_type', batch_key='dataset')

# set a directory to save the model checkpoints
checkpointPath = '../checkpoint/'

# initialize the scHash model and train it
model = scHash.scHashModel(datamodule)
trainer, best_model_path, training_time = scHash.training(
    model=model, datamodule=datamodule,
    checkpointPath=checkpointPath, max_epochs=50)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | hash_layer | Sequential | 641 K 
------------------------------------------
641 K     Trainable params
0         Non-trainable params
641 K     Total params
2.567     Total estimated model params size (MB)

Epoch: 49, Val_loss_epoch: 0.03


## Test Model

In [4]:
# add the query (test) data to the datamodule
datamodule.setup_test_data(test)

# predict cell types and hash codes for the query cells
pred_labels, hash_codes = scHash.testing(trainer, model, best_model_path, datamodule)

# report the median of the per-cell-type F1 scores
labels_true = test.obs.cell_type
f1_median = round(median(f1_score(labels_true, pred_labels, average=None)), 3)

print(f'F1 Median: {f1_median}')

F1 Median: 0.96
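`f1_score(..., average=None)` returns one F1 score per cell type, and the reported metric is the median of those scores; the median is less sensitive than the mean to a handful of cell types with extreme scores. The stdlib sketch below recomputes per-class F1 on hypothetical toy labels to make the computation explicit. The labels and the `per_class_f1` helper are assumptions for illustration only.

```python
from statistics import median

# Hypothetical toy labels; in the notebook these are test.obs.cell_type and pred_labels.
y_true = ["alpha", "alpha", "beta", "beta", "delta", "gamma"]
y_pred = ["alpha", "beta", "beta", "beta", "delta", "delta"]

def per_class_f1(y_true, y_pred):
    """One F1 score per label, analogous to sklearn's f1_score(average=None)."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = {}
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        # 2TP / (2TP + FP + FN) equals 2PR / (P + R); use 0.0 when undefined
        denom = 2 * tp + fp + fn
        scores[label] = 2 * tp / denom if denom else 0.0
    return scores

scores = per_class_f1(y_true, y_pred)
print(scores)
print(round(median(scores.values()), 3))  # -> 0.667
```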


# Interpretability

scHash also supports interpretability analysis of each query cell. The `compute_cell_composition` function takes `trainer`, `best_model_path`, and `model` as input and, for each query cell, reports the composition (e.g., cell type and dataset) of its K closest cell anchors.
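Conceptually, such a composition can be obtained by ranking anchors by Hamming distance to the query cell's hash code, keeping the K closest, and tallying their labels as percentages. The sketch below illustrates that idea with toy codes. It is an assumption about the general approach, not scHash's `compute_cell_composition` implementation; the anchor codes, labels, and the `knn_composition` helper are hypothetical.

```python
from collections import Counter

def hamming(a, b):
    """Number of differing bits between two equal-length binary codes."""
    return sum(x != y for x, y in zip(a, b))

def knn_composition(query_code, anchor_codes, anchor_labels, k):
    """Percentage of each label among the k anchors closest in Hamming distance."""
    # sorted() is stable, so ties keep the original anchor order
    ranked = sorted(range(len(anchor_codes)),
                    key=lambda i: hamming(query_code, anchor_codes[i]))
    counts = Counter(anchor_labels[i] for i in ranked[:k])
    return {label: 100 * n / k for label, n in counts.items()}

# Hypothetical 6-bit anchor codes with their cell-type and dataset labels.
anchor_codes = [
    (1, 0, 1, 1, 0, 0),
    (1, 0, 1, 1, 0, 1),
    (1, 0, 1, 0, 0, 0),
    (0, 1, 0, 0, 1, 1),
    (0, 1, 0, 1, 1, 1),
]
cell_types = ["alpha", "alpha", "beta", "delta", "delta"]
datasets = ["smartseq", "celseq2", "celseq2", "c1", "wang"]

query = (1, 0, 1, 1, 0, 0)
print(knn_composition(query, anchor_codes, cell_types, k=4))
print(knn_composition(query, anchor_codes, datasets, k=4))
```

Tallying the same neighbours once by cell-type label and once by dataset label mirrors the two tables returned below as `df_celltype` and `df_batch`.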

In [5]:
df_celltype, df_batch = scHash.compute_cell_composition(trainer, best_model_path, model)

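Each row of `df_celltype` is a query cell. The `true label` column holds the query cell's true label, and each remaining column corresponds to one candidate label (labels are shown as integer codes).

Each entry is the share of the cell's first K nearest cell anchors that carry that label. In the output below every row sums to 100, so the entries read as percentages.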

In [6]:
df_celltype

Unnamed: 0,true label,0,1,2,3,4,5,6,7,8,9,10,11
0,0,99,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
0,0,95,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5
0,0,100,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,0,98,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2
0,0,100,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,10,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,98.0,2
0,1,0.0,68.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,32
0,2,0.0,0.0,95.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3
0,2,0.0,0.0,99.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
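One way to read these rows is to take the label with the largest share as a consensus call and treat that share as a rough confidence; rows where it falls below a threshold point to cells with a mixed neighbourhood, such as the row with true label 1 whose anchors split 68/32 between labels 1 and 11. The sketch below applies this to a few rows copied from the table. The 80% threshold and the `summarize` helper are hypothetical, not part of scHash.

```python
# A few rows copied from the df_celltype output above:
# (true label, composition over labels 0..11 in percent).
rows = [
    (0, [99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
    (10, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 2]),
    (1, [0, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32]),
    (2, [0, 0, 95, 2, 0, 0, 0, 0, 0, 0, 0, 3]),
]

PURITY_THRESHOLD = 80  # hypothetical cut-off in percent, not an scHash default

def summarize(composition, threshold=PURITY_THRESHOLD):
    """Return (consensus label, its share, 'clear' or 'mixed')."""
    top_share = max(composition)
    consensus = composition.index(top_share)  # first label with the largest share
    flag = "mixed" if top_share < threshold else "clear"
    return consensus, top_share, flag

for true_label, composition in rows:
    consensus, share, flag = summarize(composition)
    print(f"true={true_label} consensus={consensus} share={share}% ({flag})")
```

On the actual DataFrame, and assuming the column layout shown above, the pandas equivalent would presumably be `df_celltype.drop(columns="true label").idxmax(axis=1)` for the consensus label and `.max(axis=1)` for its share.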


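Similarly, each row of `df_batch` is a query cell, and each entry is the share of its first K nearest cell anchors that come from the corresponding dataset. In the rows shown, the `indrop` column is 0, consistent with `indrop` being the held-out query dataset.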

In [7]:
df_batch

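| smartseq | c1 | wang | indrop | celseq | celseq2 |
|---|---|---|---|---|---|
| 80.0 | 15.0 | 1.0 | 0.0 | 0.0 | 4.0 |
| 86.0 | 14.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 17.0 | 0.0 | 0.0 | 0.0 | 8.0 | 75.0 |
| 88.0 | 7.0 | 0.0 | 0.0 | 0.0 | 5.0 |
| 17.0 | 0.0 | 0.0 | 0.0 | 9.0 | 74.0 |
| ... | ... | ... | ... | ... | ... |
| 9.0 | 0.0 | 0.0 | 0.0 | 14.0 | 77.0 |
| 67.0 | 0.0 | 0.0 | 0.0 | 1.0 | 32.0 |
| 16.0 | 0.0 | 0.0 | 0.0 | 14.0 | 70.0 |
| 18.0 | 0.0 | 0.0 | 0.0 | 3.0 | 79.0 |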
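To condense each `df_batch` row into a single number, one possible summary (not part of scHash) is the Shannon entropy of the reference-dataset shares, scaled to [0, 1]: 0 means all K anchors come from one dataset, and 1 means they are spread evenly across the reference datasets. The sketch below applies it to rows copied from the table; the `mixing_entropy` helper is hypothetical, and excluding the query column `indrop` is an assumption.

```python
import math

DATASETS = ["smartseq", "c1", "wang", "indrop", "celseq", "celseq2"]
QUERY = "indrop"

# A few rows copied from the df_batch output above (percent of the K nearest anchors).
rows = [
    [80, 15, 1, 0, 0, 4],
    [86, 14, 0, 0, 0, 0],
    [17, 0, 0, 0, 8, 75],
]

def mixing_entropy(row, datasets=DATASETS, query=QUERY):
    """Shannon entropy of the reference-dataset shares, divided by its maximum.

    The query dataset's column is dropped, so the maximum is log(number of
    reference datasets).
    """
    shares = [v for d, v in zip(datasets, row) if d != query]
    total = sum(shares)
    probs = [v / total for v in shares if v > 0]  # 0 * log(0) is taken as 0
    entropy = -sum(p * math.log(p) for p in probs)
    return entropy / math.log(len(shares))

for row in rows:
    print(round(mixing_entropy(row), 3))
```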
