In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from knn_cell_type_assigner import weighted_knn_trainer, weighted_knn_transfer

In [3]:
import os
import pathlib
import pickle

import anndata
import torch
import numpy as np
import scanpy as sc
import scarches as sca
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

from scarches.dataset.trvae.data_handling import remove_sparsity

# Global display settings for the notebook.
# The original cell called set_figure_params twice (dpi=200, frameon=False,
# then dpi=500). Each call re-applies all of scanpy's defaults, so the first
# call was silently discarded; only the effective call is kept here.
# NOTE(review): if frameless plots were intended, add frameon=False below.
sc.set_figure_params(dpi=500)
# Applied after scanpy's defaults so the 5x5 inch figure size wins.
plt.rcParams['figure.figsize'] = (5, 5)
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
# Expanded home directory of the current user; root for the model paths joined with os.path.join below.
home_dir = os.path.expanduser("~")

In [5]:
# --- .obs column names -------------------------------------------------------
# Batch column handed to setup_anndata as `batch_key`.
condition_key = 'sample'
# One-element list; cells below index it with [0] to get the column name.
cell_type_key = ['ann_finest_level']

# --- Input datasets ----------------------------------------------------------
HLCA_CANCER_DATA_PATH = os.path.expanduser('~/io/scpoli_repr/hlca_cancer_commonvars.h5ad')
HLCA_EXTENDED_DATA_PATH = os.path.expanduser('~/io/scpoli_repr/hlca_extended_commonvars.h5ad')
HLCA_DATA_PATH = os.path.expanduser('~/io/scpoli_repr/hlca_counts_commonvars.h5ad')

# --- Model / output locations ------------------------------------------------
# NOTE(review): the reference model lives under `lataq_repr` while every other
# path uses `scpoli_repr` -- confirm this is intentional.
HLCA_SCARCHES_SAVE_MODEL = os.path.join(home_dir, "io/lataq_repr/scarches_models/HLCA_reference_model_sample/")
# Templates: `{}` is filled with a query name via .format(...) in later cells.
HLCA_SCARCHES_SAVE_MAPPED_LATENT_format = os.path.join(home_dir, "io/scpoli_repr/scarches_models/HLCA_mapped_model_sample_{}.latent.h5ad")
HLCA_SCARCHES_SAVE_FINETUNED_MODEL_format = os.path.join(home_dir, "io/scpoli_repr/scarches_models/HLCA_mapped_model_sample_{}")

In [6]:
# Label value assigned to query cells whose cell type is not known.
unlabeled_category = "unlabeled"

# Epoch budgets for the unsupervised and semi-supervised training phases.
vae_epochs = 500
scanvi_epochs = 200

# Early stopping / LR schedule for the unsupervised phase (monitors ELBO).
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=10,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

# Early stopping / LR schedule for the semi-supervised phase (monitors accuracy).
early_stopping_kwargs_scanvi = dict(
    early_stopping_metric="accuracy",
    save_best_state_metric="accuracy",
    on="full_dataset",
    patience=10,
    threshold=0.001,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

In [7]:
# Read the HLCA reference dataset and store its expression matrix as float32.
adata = sc.read(HLCA_DATA_PATH)
reference_counts_f32 = adata.X.astype(np.float32)
adata.X = reference_counts_f32
adata

AnnData object with n_obs × n_vars = 584884 × 1897
    obs: 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'sample', 'study', 'subject_ID', 'smoking_status', 'BMI', 'condition', 'subject_type', 'sample_type', "3'_or_5'", 'sequencing_platform', 'cell_ranger_version', 'fresh_or_frozen', 'dataset', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_highest_res', 'n_genes', 'size_factors', 'log10_total_counts', 'mito_frac', 'ribo_frac', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'scanvi_label', 'leiden_1', 'leiden_2', 'leiden_3', 'anatomical_region_ccf_score', 'entropy_study_leiden_3', 'entropy_dataset_leiden_3', 'entropy_subject_ID_

In [8]:
# Read the extended HLCA dataset and keep only cells from the Meyer_2021 study.
adata_extended = sc.read(HLCA_EXTENDED_DATA_PATH)
is_meyer_2021 = adata_extended.obs['study'] == 'Meyer_2021'
adata_extended = adata_extended[is_meyer_2021].copy()

# Store the expression matrix as float32.
adata_extended.X = adata_extended.X.astype(np.float32)

# Mark every query cell with the unlabeled category in the cell-type column.
label_column = cell_type_key[0]
adata_extended.obs[label_column] = unlabeled_category
adata_extended

AnnData object with n_obs × n_vars = 128628 × 1897
    obs: 'dataset', 'study', 'original_celltype_ann', 'condition', 'subject_ID', 'sample', 'cells_or_nuclei', 'single_cell_platform', 'sample_type', 'age', 'sex', 'ethnicity', 'BMI', 'smoking_status', 'anatomical_region_level_1', 'anatomical_region_coarse', 'anatomical_region_detailed', 'genome', 'disease', 'ann_finest_level'

In [9]:
# Read the cancer query dataset and store its expression matrix as float32.
adata_cancer = sc.read(HLCA_CANCER_DATA_PATH)
adata_cancer.X = adata_cancer.X.astype(np.float32)

# Mark every query cell with the unlabeled category in the cell-type column.
label_column = cell_type_key[0]
adata_cancer.obs[label_column] = unlabeled_category
adata_cancer

AnnData object with n_obs × n_vars = 93575 × 1897
    obs: 'n_genes_detected', 'total_counts', 'cell_from_tumor', 'subject_ID', 'tumor_site', 'original_celltype_ann', 'sample', 'study', 'study_long', 'dataset', 'last_author_PI', 'lung_vs_nasal', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_new', 'scanvi_label', 'ann_finest_level'
    var: 'original_gene_names', 'gene_symbols', 'ensembl'

## Training scANVI

In [9]:
# `reference_adata` was never defined in this notebook (NameError on a fresh
# kernel). The reference is the HLCA dataset loaded into `adata` above; the
# recorded log (584884 cells, 1897 vars) matches that object.
reference_adata = adata

# `cell_type_key` is a one-element list, but `labels_key` expects the name of a
# single .obs column, so pass its first entry (as the data-loading cells do).
sca.dataset.setup_anndata(
    reference_adata,
    batch_key=condition_key,
    labels_key=cell_type_key[0],
)

[34mINFO    [0m Using batches from adata.obs[1m[[0m[32m"sample"[0m[1m][0m                                              
[34mINFO    [0m Using labels from adata.obs[1m[[0m[32m"ann_finest_level"[0m[1m][0m                                     
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Successfully registered anndata object containing [1;36m584884[0m cells, [1;36m1897[0m vars, [1;36m166[0m      
         batches, [1;36m58[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates  
         and [1;36m0[0m extra continuous covariates.                                                  
[34mINFO    [0m Please do not further modify adata until model is trained.                          


In [None]:
# Architecture and likelihood settings for the reference scANVI model.
scanvi_model_kwargs = dict(
    n_layers=2,
    n_latent=30,  # to allow for capturing more heterogeneity
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
    gene_likelihood="nb",  # because we have UMI data
    use_cuda=True,  # to use GPU
)

# Second positional argument is the label value treated as "unlabeled".
vae = sca.models.SCANVI(reference_adata, unlabeled_category, **scanvi_model_kwargs)


In [None]:
# Report how many cells the model holds in its labelled / unlabelled index lists.
n_labelled = len(vae._labeled_indices)
n_unlabelled = len(vae._unlabeled_indices)
print(f"Labelled Indices:  {n_labelled}")
print(f"Unlabelled Indices:  {n_unlabelled}")

In [None]:
# Trainer options for the two phases; each phase gets its own early-stopping config.
unsupervised_trainer_kwargs = dict(early_stopping_kwargs=early_stopping_kwargs)
semisupervised_trainer_kwargs = dict(
    metrics_to_monitor=["elbo", "accuracy"],
    early_stopping_kwargs=early_stopping_kwargs_scanvi,
)

# Train the reference model: unsupervised phase, then semi-supervised phase.
vae.train(
    n_epochs_unsupervised=vae_epochs,
    n_epochs_semisupervised=scanvi_epochs,
    unsupervised_trainer_kwargs=unsupervised_trainer_kwargs,
    semisupervised_trainer_kwargs=semisupervised_trainer_kwargs,
    frequency=1,
)

In [None]:
# Persist the trained reference model to HLCA_SCARCHES_SAVE_MODEL.
# NOTE(review): overwrite=False presumably makes this fail if the directory already exists -- confirm before re-running.
vae.save(HLCA_SCARCHES_SAVE_MODEL, overwrite=False)

## Perform surgery on reference model and train on cancer dataset

In [10]:
# Inputs for this section: the HLCA reference, the cancer dataset as query,
# and the directory of the saved reference model.
adata_ref = adata
adata_query = adata_cancer
ref_path = HLCA_SCARCHES_SAVE_MODEL

In [11]:
# Load the saved reference scANVI model from HLCA_SCARCHES_SAVE_MODEL, passing the reference AnnData.
vae = sca.models.SCANVI.load(HLCA_SCARCHES_SAVE_MODEL, adata_ref)

[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              


  logger_data_loc


[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'local_l_mean'[0m, [32m'local_l_var'[0m, [32m'labels'[0m[1m][0m     
[34mINFO    [0m Successfully registered anndata object containing [1;36m584884[0m cells, [1;36m1897[0m vars, [1;36m166[0m      
         batches, [1;36m58[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates  
         and [1;36m0[0m extra continuous covariates.                                                  


In [12]:
# Embed the reference cells with the reference model and wrap the latent
# matrix in an AnnData that carries a copy of the reference .obs table.
reference_latent = vae.get_latent_representation(adata_ref)
reference_emb_adata = sc.AnnData(reference_latent, obs=adata_ref.obs.copy())



In [14]:
%%time

surgery_epochs = 500
early_stopping_kwargs_surgery = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

model = sca.models.SCANVI.load_query_data(
    adata_query,
    ref_path,
    freeze_dropout = True,
)

model.train(
    n_epochs_semisupervised=surgery_epochs,
    train_base_model=False,
    semisupervised_trainer_kwargs=dict(
        metrics_to_monitor=["accuracy", "elbo"], 
        weight_decay=0,
        early_stopping_kwargs=early_stopping_kwargs_surgery
    ),
    frequency=1
)

[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'local_l_mean'[0m, [32m'local_l_var'[0m, [32m'labels'[0m[1m][0m     
[34mINFO    [0m Successfully registered anndata object containing [1;36m93575[0m cells, [1;36m1897[0m vars, [1;36m202[0m       
         batches, [1;36m59[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates  
         and [1;36m0[0m extra continuous covariates.                                                  
[34mINFO    [0m Training Unsupervised Trainer for [1;36m85[0m epochs.                                        
[34mINFO    [0m Training SemiSupervised Trainer for [1;36m500[0m epochs.                                     
[34mINFO    [0m KL warmup for [

In [15]:
# Embed the query cells with the fine-tuned model and wrap the latent matrix
# in an AnnData that carries a copy of the query .obs table.
query_latent = model.get_latent_representation(adata_query)
latent_subadata = sc.AnnData(query_latent, obs=adata_query.obs.copy())

In [22]:
# Fill the path templates with the query name, then save the fine-tuned model
# and the query embedding.
cancer_model_dir = HLCA_SCARCHES_SAVE_FINETUNED_MODEL_format.format('cancer')
cancer_latent_path = HLCA_SCARCHES_SAVE_MAPPED_LATENT_format.format('cancer')
model.save(cancer_model_dir)
latent_subadata.write(cancer_latent_path)

In [17]:
# Re-read the cancer query embedding that was written to disk.
cancer_latent_path = HLCA_SCARCHES_SAVE_MAPPED_LATENT_format.format('cancer')
latent_subadata = sc.read(cancer_latent_path)

In [18]:
%%time

# Transfer `ann_finest_level` labels from the reference embedding to the query
# embedding with a weighted kNN model (k = 50). Both helpers come from the
# local `knn_cell_type_assigner` module, which is not shown in this notebook.

# run k-neighbors transformer
k_neighbors_transformer = weighted_knn_trainer(
    train_adata=reference_emb_adata,
    train_adata_emb="X", # location of our joint embedding
    label_key="ann_finest_level",
    n_neighbors=50,
    )    
# perform label transfer
# NOTE(review): the fitted model is passed twice here -- once positionally and
# once as `knn_model=`. Confirm against the signature of weighted_knn_transfer
# in knn_cell_type_assigner which parameter the positional argument binds to.
labels, uncert = weighted_knn_transfer(
    k_neighbors_transformer,
    query_adata=latent_subadata,
    query_adata_emb="X", # location of our joint embedding
    label_keys="ann_finest_level",
    knn_model=k_neighbors_transformer,
    ref_adata_obs = reference_emb_adata.obs
    )

Weighted KNN with n_neighbors = 50 ... finished!
CPU times: user 2h 24min 22s, sys: 3h 38min 21s, total: 6h 2min 44s
Wall time: 27min 50s


In [22]:
# Attach the transferred labels to the query embedding.
transferred_labels = labels["ann_finest_level"]
latent_subadata.obs['pred_ann_class'] = transferred_labels.values

# Attach the matching uncertainties, cast with the builtin `float` dtype.
transferred_uncert = uncert["ann_finest_level"]
latent_subadata.obs['pred_ann_uncert'] = transferred_uncert.to_numpy().astype(float)

In [24]:
# Write the annotated cancer embedding under the 'cancer-knn' name.
cancer_knn_path = HLCA_SCARCHES_SAVE_MAPPED_LATENT_format.format('cancer-knn')
latent_subadata.write(cancer_knn_path)

## Perform surgery on reference model and train on healthy dataset

In [10]:
# Inputs for this section: the HLCA reference, the Meyer_2021 subset of the
# extended dataset as query, and the directory of the saved reference model.
adata_ref = adata
adata_query = adata_extended
ref_path = HLCA_SCARCHES_SAVE_MODEL

In [11]:
# Load the saved reference scANVI model from HLCA_SCARCHES_SAVE_MODEL, passing the reference AnnData.
vae = sca.models.SCANVI.load(HLCA_SCARCHES_SAVE_MODEL, adata_ref)

[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              


  logger_data_loc


[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'local_l_mean'[0m, [32m'local_l_var'[0m, [32m'labels'[0m[1m][0m     
[34mINFO    [0m Successfully registered anndata object containing [1;36m584884[0m cells, [1;36m1897[0m vars, [1;36m166[0m      
         batches, [1;36m58[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates  
         and [1;36m0[0m extra continuous covariates.                                                  


In [12]:
# Embed the reference cells with the reference model and wrap the latent
# matrix in an AnnData that carries a copy of the reference .obs table.
reference_latent = vae.get_latent_representation(adata_ref)
reference_emb_adata = sc.AnnData(reference_latent, obs=adata_ref.obs.copy())



In [13]:
%%time

surgery_epochs = 500
early_stopping_kwargs_surgery = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}

model = sca.models.SCANVI.load_query_data(
    adata_query,
    ref_path,
    freeze_dropout = True,
)

model.train(
    n_epochs_semisupervised=surgery_epochs,
    train_base_model=False,
    semisupervised_trainer_kwargs=dict(
        metrics_to_monitor=["accuracy", "elbo"], 
        weight_decay=0,
        early_stopping_kwargs=early_stopping_kwargs_surgery
    ),
    frequency=1
)

[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'local_l_mean'[0m, [32m'local_l_var'[0m, [32m'labels'[0m[1m][0m     
[34mINFO    [0m Successfully registered anndata object containing [1;36m128628[0m cells, [1;36m1897[0m vars, [1;36m225[0m      
         batches, [1;36m59[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates  
         and [1;36m0[0m extra continuous covariates.                                                  
[34mINFO    [0m Training Unsupervised Trainer for [1;36m62[0m epochs.                                        
[34mINFO    [0m Training SemiSupervised Trainer for [1;36m500[0m epochs.                                     
[34mINFO    [0m KL warmup for [

In [14]:
# Embed the query cells with the fine-tuned model and wrap the latent matrix
# in an AnnData that carries a copy of the query .obs table.
query_latent = model.get_latent_representation(adata_query)
latent_subadata = sc.AnnData(query_latent, obs=adata_query.obs.copy())

In [15]:
# Fill the path templates with the query name, then save the fine-tuned model
# and the query embedding.
healthy_model_dir = HLCA_SCARCHES_SAVE_FINETUNED_MODEL_format.format('healthy')
healthy_latent_path = HLCA_SCARCHES_SAVE_MAPPED_LATENT_format.format('healthy')
model.save(healthy_model_dir)
latent_subadata.write(healthy_latent_path)

In [16]:
# Re-read the healthy query embedding that was written to disk.
healthy_latent_path = HLCA_SCARCHES_SAVE_MAPPED_LATENT_format.format('healthy')
latent_subadata = sc.read(healthy_latent_path)

In [17]:
%%time

# Transfer `ann_finest_level` labels from the reference embedding to the query
# embedding with a weighted kNN model (k = 50). Both helpers come from the
# local `knn_cell_type_assigner` module, which is not shown in this notebook.

# run k-neighbors transformer
k_neighbors_transformer = weighted_knn_trainer(
    train_adata=reference_emb_adata,
    train_adata_emb="X", # location of our joint embedding
    label_key="ann_finest_level",
    n_neighbors=50,
    )    
# perform label transfer
# NOTE(review): the fitted model is passed twice here -- once positionally and
# once as `knn_model=`. Confirm against the signature of weighted_knn_transfer
# in knn_cell_type_assigner which parameter the positional argument binds to.
labels, uncert = weighted_knn_transfer(
    k_neighbors_transformer,
    query_adata=latent_subadata,
    query_adata_emb="X", # location of our joint embedding
    label_keys="ann_finest_level",
    knn_model=k_neighbors_transformer,
    ref_adata_obs = reference_emb_adata.obs
    )

Weighted KNN with n_neighbors = 50 ... finished!
CPU times: user 1h 3min 30s, sys: 33min 11s, total: 1h 36min 41s
Wall time: 54min 18s


In [20]:
# Attach the transferred labels to the query embedding.
transferred_labels = labels["ann_finest_level"]
latent_subadata.obs['pred_ann_class'] = transferred_labels.to_numpy()

# Attach the matching uncertainties, stored as float32.
transferred_uncert = uncert["ann_finest_level"]
latent_subadata.obs['pred_ann_uncert'] = transferred_uncert.to_numpy().astype(np.float32)

In [21]:
# Write the annotated healthy embedding under the 'healthy-knn' name.
healthy_knn_path = HLCA_SCARCHES_SAVE_MAPPED_LATENT_format.format('healthy-knn')
latent_subadata.write(healthy_knn_path)