In [52]:
import scanpy as sc
import celltypist
import time
import numpy as np

In [67]:
adata_Elmentaite = sc.read_h5ad("training.h5ad")
adata_Elmentaite

AnnData object with n_obs × n_vars = 3756 × 33538
    obs: 'sampleID', 'patientID', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'TCR_type'
    var: 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

In [68]:
sc.pp.normalize_total(adata_Elmentaite, target_sum = 1e4)
sc.pp.log1p(adata_Elmentaite)

In [69]:
# Sample 500 cells from each cell type for `adata_Elmentaite`.
# All cells from a given cell type will be selected if the cell type size is < 500.
sampled_cell_index = celltypist.samples.downsample_adata(adata_Elmentaite, mode = 'each', n_cells = 500, by = 'sub_cell_type', return_index = True)

In [70]:
# Use `celltypist.train` to quickly train a rough CellTypist model.
# You can also set `mini_batch = True` to enable mini-batch training.
t_start = time.time()
model_fs = celltypist.train(adata_Elmentaite[sampled_cell_index], 'sub_cell_type', n_jobs = 10, max_iter = 5, use_SGD = True)
t_end = time.time()
print(f"Time elapsed: {t_end - t_start} seconds")

🍳 Preparing data before training
✂️ 18295 non-expressed genes are filtered out
🔬 Input data has 1000 cells and 15243 genes
⚖️ Scaling input data
🏋️ Training data using SGD logistic regression
✅ Model training done!


Time elapsed: 0.6984186172485352 seconds


In [71]:
gene_index = np.argpartition(np.abs(model_fs.classifier.coef_), -100, axis = 1)[:, -100:]

In [72]:
gene_index = np.unique(gene_index)

In [73]:
print(f"Number of genes selected: {len(gene_index)}")

Number of genes selected: 100


In [74]:
# Add `check_expression = False` to bypass expression check with only a subset of genes.
t_start = time.time()
model = celltypist.train(adata_Elmentaite[sampled_cell_index, gene_index], 'TCR_type', check_expression = False, n_jobs = 10, max_iter = 100)
t_end = time.time()
print(f"Time elapsed: {(t_end - t_start)/60} minutes")

🍳 Preparing data before training
✂️ 48 non-expressed genes are filtered out
🔬 Input data has 1000 cells and 52 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!


Time elapsed: 0.021984867254892983 minutes


In [75]:
# Save the model.
model.write('model_from_Elmentaite_specific.pkl')

In [76]:
adata_James = sc.read_h5ad('predict.h5ad')
adata_James

AnnData object with n_obs × n_vars = 144162 × 31831
    obs: 'sampleID', 'cellID', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'sub_cell_type', 'major_cell_type'
    var: 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

In [77]:
sc.pp.normalize_total(adata_James, target_sum = 1e4)
sc.pp.log1p(adata_James)

In [78]:
# CellTypist prediction with over-clustering and majority-voting.
t_start = time.time()
predictions = celltypist.annotate(adata_James, model = 'model_from_Elmentaite_specific.pkl', majority_voting = True)
t_end = time.time()
print(f"Time elapsed: {t_end - t_start} seconds")

🔬 Input data has 144162 cells and 31831 genes
🔗 Matching reference genes in the model
🧬 52 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!
👀 Can not detect a neighborhood graph, will construct one before the over-clustering
⛓️ Over-clustering input data with resolution set to 25
🗳️ Majority voting the predictions
✅ Majority voting done!


Time elapsed: 335.8447313308716 seconds


In [79]:
predictions.predicted_labels

Unnamed: 0,predicted_labels,over_clustering,majority_voting
P304_ACGAGCCGTGTGCCTG_1,MANA specific,117,MANA specific
P64_TACTTACCAGGTCTCG_1,MANA specific,307,Viral specific
P481_GTGCATAGTAAATGAC_1,Viral specific,42,MANA specific
P435_CGATCGGGTTATCGGT_1,MANA specific,15,MANA specific
P182_GTGCGGTTCCAACCAA_1,Viral specific,6,MANA specific
...,...,...,...
P469_GGTGTTATCAGGCAAG_1,MANA specific,309,MANA specific
P454_GGGTCTGCAGACGTAG_1,Viral specific,41,Viral specific
P53_ACTGCTCTCCAGATCA_1,Viral specific,155,Viral specific
P44_CTGATAGGTTCGTCTC_1,MANA specific,1,MANA specific
