In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from knn_cell_type_assigner import weighted_knn_trainer, weighted_knn_transfer

In [4]:
import os
import pathlib
import pickle

import anndata
import torch
import numpy as np
import scanpy as sc
import scarches as sca
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

from scarches.dataset.trvae.data_handling import remove_sparsity
from lataq.models import EMBEDCVAE

# Scanpy figure defaults for this notebook.
sc.settings.set_figure_params(dpi=200, frameon=False)
# NOTE(review): this second call sets dpi=500 right after dpi=200 above and does not
# repeat frameon=False; presumably it supersedes the first call -- confirm which
# settings are actually intended and drop the other call.
sc.set_figure_params(dpi=500)
# Default matplotlib figure size.
plt.rcParams['figure.figsize'] = (5, 5)
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:[rank: 0] Global seed set to 0


In [5]:
# Current user's home directory; used as the base of the os.path.join(...) output paths.
home_dir = os.path.expanduser("~")

In [17]:
# obs column passed to the model as `condition_key`.
condition_key = 'sample'
# obs column(s) with cell-type labels. This is a one-element list (passed as
# `cell_type_keys`); code that needs the column name itself uses cell_type_key[0].
cell_type_key = ['ann_finest_level']

# Input AnnData files: reference data, extended query data, cancer query data.
HLCA_DATA_PATH = os.path.expanduser('~/io/scpoli_repr/hlca_counts_commonvars.h5ad')
HLCA_EXTENDED_DATA_PATH = os.path.expanduser('~/io/scpoli_repr/hlca_extended_commonvars.h5ad')
HLCA_CANCER_DATA_PATH = os.path.expanduser('~/io/scpoli_repr/hlca_cancer_commonvars.h5ad')

# Directory of the trained reference model.
# NOTE(review): this path is under "io/lataq_repr" while every other path in this cell
# is under "io/scpoli_repr" -- confirm this is intentional.
HLCA_LATAQ_SAVE_MODEL = os.path.join(home_dir, "io/lataq_repr/lataq_models/hlca_core_sample/")
# Templates for the per-query outputs; "{}" is filled with a tag such as 'cancer' or 'healthy'.
HLCA_LATAQ_SAVE_FINETUNED_MODEL_format = os.path.join(home_dir, "io/scpoli_repr/lataq_models/HLCA_mapped_model_sample_{}")
HLCA_LATAQ_SAVE_MAPPED_LATENT_format = os.path.join(home_dir, "io/scpoli_repr/lataq_models/HLCA_mapped_model_sample_{}.latent.h5ad")

In [18]:
# Placeholder label written into the cell-type column of the query datasets.
unlabeled_category = "NA"


# Hyperparameters looked up by the training cells via PARAMS[...].
# NOTE(review): 'EPOCHS', 'N_PRE_EPOCHS', 'LATENT_DIM' and 'HIDDEN_LAYERS' are not read
# by any cell visible in this notebook; the model/train cells hard-code their own
# values (e.g. pretraining_epochs=45 and hidden_layer_sizes=[128]*3 for the reference
# model), which differ from the values listed here -- confirm which are authoritative.
PARAMS = {
    'EPOCHS': 50,                                      # total training epochs
    'N_PRE_EPOCHS': 40,                                # epochs of pretraining without landmark loss
    #'DATA_DIR': '../../lataq_reproduce/data',          #DIRECTORY WHERE THE DATA IS STORED
    #'DATA': 'pancreas',                                #DATA USED FOR THE EXPERIMENT
    'EARLY_STOPPING_KWARGS': {                         # passed to train() as early_stopping_kwargs
        "early_stopping_metric": "val_landmark_loss",  # metric monitored for early stopping
        "mode": "min",                                 # whether to look for the min or the max of the metric
        "threshold": 0,
        "patience": 20,
        "reduce_lr": True,
        "lr_patience": 13,
        "lr_factor": 0.1,
    },
    'LABELED_LOSS_METRIC': 'dist',           # passed to train() as labeled_loss_metric
    'UNLABELED_LOSS_METRIC': 'dist',  # passed to train() as unlabeled_loss_metric
    'LATENT_DIM': 50,
    'ALPHA_EPOCH_ANNEAL': 1e3,  # passed to train() as alpha_epoch_anneal
    'CLUSTERING_RES': 2,  # passed to train() as clustering_res
    'HIDDEN_LAYERS': 4,
    'ETA': 1,  # passed as eta when training the reference model; the query runs use eta=0
}

In [19]:
# Load the reference AnnData and cast its expression matrix to float32.
adata = sc.read(HLCA_DATA_PATH)
adata.X = adata.X.astype(np.float32)
adata  # display the object summary as the cell output

AnnData object with n_obs × n_vars = 584884 × 1897
    obs: 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'sample', 'study', 'subject_ID', 'smoking_status', 'BMI', 'condition', 'subject_type', 'sample_type', "3'_or_5'", 'sequencing_platform', 'cell_ranger_version', 'fresh_or_frozen', 'dataset', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_highest_res', 'n_genes', 'size_factors', 'log10_total_counts', 'mito_frac', 'ribo_frac', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'scanvi_label', 'leiden_1', 'leiden_2', 'leiden_3', 'anatomical_region_ccf_score', 'entropy_study_leiden_3', 'entropy_dataset_leiden_3', 'entropy_subject_ID_

In [34]:
# Load the extended dataset and keep only the cells from the 'Meyer_2021' study.
adata_extended = sc.read(HLCA_EXTENDED_DATA_PATH)
adata_extended = adata_extended[adata_extended.obs['study'] == 'Meyer_2021'].copy()
# Mark every query cell as unlabeled in the cell-type column.
adata_extended.obs[cell_type_key[0]] = unlabeled_category
# Cast while still sparse, then densify once. `.toarray()` replaces the `.A`
# shorthand, which newer SciPy releases no longer provide on sparse matrices, and the
# previous second `.astype(np.float32)` on the dense array was a redundant copy.
adata_extended.X = adata_extended.X.astype(np.float32).toarray()
adata_extended  # display the object summary as the cell output

AnnData object with n_obs × n_vars = 128628 × 1897
    obs: 'dataset', 'study', 'original_celltype_ann', 'condition', 'subject_ID', 'sample', 'cells_or_nuclei', 'single_cell_platform', 'sample_type', 'age', 'sex', 'ethnicity', 'BMI', 'smoking_status', 'anatomical_region_level_1', 'anatomical_region_coarse', 'anatomical_region_detailed', 'genome', 'disease', 'ann_finest_level'

In [35]:
# Load the cancer query dataset.
adata_cancer = sc.read(HLCA_CANCER_DATA_PATH)
# Mark every query cell as unlabeled in the cell-type column.
adata_cancer.obs[cell_type_key[0]] = unlabeled_category
# Cast while still sparse, then densify once. `.toarray()` replaces the `.A`
# shorthand, which newer SciPy releases no longer provide on sparse matrices, and the
# previous second `.astype(np.float32)` on the dense array was a redundant copy.
adata_cancer.X = adata_cancer.X.astype(np.float32).toarray()
adata_cancer  # display the object summary as the cell output

AnnData object with n_obs × n_vars = 93575 × 1897
    obs: 'n_genes_detected', 'total_counts', 'cell_from_tumor', 'subject_ID', 'tumor_site', 'original_celltype_ann', 'sample', 'study', 'study_long', 'dataset', 'last_author_PI', 'lung_vs_nasal', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_new', 'scanvi_label', 'ann_finest_level'
    var: 'original_gene_names', 'gene_symbols', 'ensembl'

## Training LATAQ

In [None]:
# Architecture settings of the reference model, kept separate from the data arguments.
# These values are hard-coded here rather than taken from PARAMS.
reference_architecture = dict(
    hidden_layer_sizes=[128, 128, 128],
    latent_dim=50,
    embedding_dim=20,
    inject_condition=['encoder', 'decoder'],
)

# Build the reference model on the reference AnnData.
lataq_model = EMBEDCVAE(
    adata=adata,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    **reference_architecture,
)

In [None]:
# Train the reference model. Loss and early-stopping settings come from PARAMS;
# the epoch counts are hard-coded here (pretraining_epochs=45 differs from
# PARAMS['N_PRE_EPOCHS']).
lataq_model.train(
    n_epochs=50,
    pretraining_epochs=45,
    early_stopping_kwargs=PARAMS['EARLY_STOPPING_KWARGS'],
    alpha_epoch_anneal=PARAMS['ALPHA_EPOCH_ANNEAL'],
    eta=PARAMS['ETA'],
    clustering_res=PARAMS['CLUSTERING_RES'],
    labeled_loss_metric=PARAMS['LABELED_LOSS_METRIC'],
    unlabeled_loss_metric=PARAMS['UNLABELED_LOSS_METRIC'],
    use_stratified_sampling=False,
    # Was `best_reload=False`; the keyword used by the other (executed) train cells in
    # this notebook is `reload_best`, so the misspelled name presumably never reached
    # the intended option.
    reload_best=False,
)

In [None]:
# Persist the trained reference model to HLCA_LATAQ_SAVE_MODEL (called with overwrite=True).
lataq_model.save(HLCA_LATAQ_SAVE_MODEL, overwrite=True)

## Perform surgery on reference model and train on cancer dataset

In [36]:
# Inputs for this section: the saved reference model, the cancer dataset as query,
# and the reference AnnData loaded from HLCA_DATA_PATH.
ref_path = HLCA_LATAQ_SAVE_MODEL
adata_query = adata_cancer
adata_ref = adata

In [13]:
# Reload the saved reference model together with the reference AnnData selected for
# this section (adata_ref, rather than the global `adata` it aliases).
lataq_model = EMBEDCVAE.load(
    os.path.expanduser(ref_path),
    adata_ref,
)
# Use the GPU when one is present; fall back to the CPU instead of raising on
# machines without CUDA. `.to(...)` returns the module, so the cell still displays
# the model summary.
device = "cuda" if torch.cuda.is_available() else "cpu"
lataq_model.model.to(device)

AnnData object with n_obs × n_vars = 584884 × 1897
    obs: 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'sample', 'study', 'subject_ID', 'smoking_status', 'BMI', 'condition', 'subject_type', 'sample_type', "3'_or_5'", 'sequencing_platform', 'cell_ranger_version', 'fresh_or_frozen', 'dataset', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_highest_res', 'n_genes', 'size_factors', 'log10_total_counts', 'mito_frac', 'ribo_frac', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'scanvi_label', 'leiden_1', 'leiden_2', 'leiden_3', 'anatomical_region_ccf_score', 'entropy_study_leiden_3', 'entropy_dataset_leiden_3', 'entropy_subject_ID_

EmbedCVAE(
  (embedding): Embedding(166, 20, max_norm=1)
  (encoder): Encoder(
    (FC): Sequential(
      (L0): CondLayers(
        (expr_L): Linear(in_features=1897, out_features=128, bias=True)
        (cond_L): Linear(in_features=20, out_features=128, bias=False)
      )
      (N0): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
      (A0): ReLU()
      (D0): Dropout(p=0.05, inplace=False)
      (L1): Linear(in_features=128, out_features=128, bias=True)
      (N1): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
      (A1): ReLU()
      (D1): Dropout(p=0.05, inplace=False)
      (L2): Linear(in_features=128, out_features=128, bias=True)
      (N2): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
      (A2): ReLU()
      (D2): Dropout(p=0.05, inplace=False)
    )
    (mean_encoder): Linear(in_features=128, out_features=50, bias=True)
    (log_var_encoder): Linear(in_features=128, out_features=50, bias=True)
  )
  (decoder): Decoder(
    (FirstL): Sequential(
  

In [14]:
# Latent representation of the reference cells (mean=True), wrapped in an AnnData
# that carries a copy of the reference obs table.
# The matrix is cast while still sparse and densified once with `.toarray()`: this
# replaces the `.A` shorthand, which newer SciPy releases no longer provide on sparse
# matrices, and avoids the extra dense copy made by `.A.astype('float32')`.
reference_emb_adata = sc.AnnData(lataq_model.get_latent(
    adata_ref.X.astype(np.float32).toarray(),
    adata_ref.obs[condition_key].values,
    mean=True,
), obs=adata_ref.obs.copy())

In [53]:
%%time

# Create the query model from the saved reference model; labeled_indices is empty.
model = EMBEDCVAE.load_query_data(
    reference_model=ref_path,
    adata=adata_query,
    labeled_indices=[],
)

# Training options looked up from PARAMS.
shared_train_settings = dict(
    early_stopping_kwargs=PARAMS['EARLY_STOPPING_KWARGS'],
    alpha_epoch_anneal=PARAMS['ALPHA_EPOCH_ANNEAL'],
    clustering_res=PARAMS['CLUSTERING_RES'],
    labeled_loss_metric=PARAMS['LABELED_LOSS_METRIC'],
    unlabeled_loss_metric=PARAMS['UNLABELED_LOSS_METRIC'],
)

# Options specific to the query run are given as literals (eta=0, weight_decay=0).
model.train(
    n_epochs=100,
    pretraining_epochs=80,
    eta=0,
    weight_decay=0,
    use_stratified_sampling=False,
    reload_best=False,
    **shared_train_settings,
)

AnnData object with n_obs × n_vars = 93575 × 1897
    obs: 'n_genes_detected', 'total_counts', 'cell_from_tumor', 'subject_ID', 'tumor_site', 'original_celltype_ann', 'sample', 'study', 'study_long', 'dataset', 'last_author_PI', 'lung_vs_nasal', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_new', 'scanvi_label', 'ann_finest_level'
    var: 'original_gene_names', 'gene_symbols', 'ensembl'
Embedding dictionary:
 	Num conditions: 202
 	Embedding dim: 20
Encoder Architecture:
	Input Layer in, out and cond: 1897 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Mean/Var Layer in/out: 128 50
Decoder Architecture:
	First Layer in, out and cond:  50 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Output Layer in/out:  128 1897 

166
The missing labels are: {'NA'}
Therefore integer value of those labels is set to -1
The missing labels are: {'NA'}
Therefore integer value of those labels is set to -

In [54]:
# Latent representation of the query cells (mean=True), paired with a copy of the query obs.
query_latent = model.get_latent(
    adata_query.X,
    adata_query.obs[condition_key].values,
    mean=True,
)
latent_subadata = sc.AnnData(query_latent, obs=adata_query.obs.copy())

# Classify the query cells with the 'dist' metric; threshold is -inf.
subdata_predictions = model.classify(
    x=adata_query.X,
    c=adata_query.obs[condition_key],
    metric='dist',
    get_prob=False,
    threshold=-np.inf
)

# Copy the results for the 'ann_finest_level' key onto the latent AnnData.
finest_level_result = subdata_predictions['ann_finest_level']
for obs_column, result_field in (
    ('pred_ann_class', 'preds'),
    ('pred_ann_uncert', 'probs'),
):
    latent_subadata.obs[obs_column] = finest_level_result[result_field]
latent_subadata.obsm['pred_ann_weighted_distances'] = finest_level_result['weighted_distances']

In [55]:
# Save the query-trained model and the latent AnnData under the 'cancer' tag.
# NOTE(review): unlike the reference-model save, no overwrite=True is passed here --
# confirm the behaviour when this cell is re-run and the target already exists.
model.save(HLCA_LATAQ_SAVE_FINETUNED_MODEL_format.format('cancer'))
latent_subadata.write(HLCA_LATAQ_SAVE_MAPPED_LATENT_format.format('cancer'))

In [56]:
# Reload the latent AnnData from the 'cancer' file written by the save cell.
latent_subadata = sc.read(HLCA_LATAQ_SAVE_MAPPED_LATENT_format.format('cancer'))

In [57]:
%%time

# Fit the weighted k-NN transformer on the reference latent AnnData
# (50 neighbours, labels from 'ann_finest_level').
# The recorded run of this cell took about 43 minutes of wall time.
k_neighbors_transformer = weighted_knn_trainer(
    train_adata=reference_emb_adata,
    train_adata_emb="X", # the embedding is stored in .X
    label_key="ann_finest_level",
    n_neighbors=50,
    )    
# Transfer labels from the reference to the query latent AnnData.
# NOTE(review): the fitted transformer is passed both as the first positional
# argument and as `knn_model=`; verify this against the signature of
# weighted_knn_transfer in knn_cell_type_assigner.
labels, uncert = weighted_knn_transfer(
    k_neighbors_transformer,
    query_adata=latent_subadata,
    query_adata_emb="X", # the embedding is stored in .X
    label_keys="ann_finest_level",
    knn_model=k_neighbors_transformer,
    ref_adata_obs = reference_emb_adata.obs
    )

Weighted KNN with n_neighbors = 50 ... finished!
CPU times: user 50min 34s, sys: 26min 52s, total: 1h 17min 27s
Wall time: 42min 56s


In [61]:
# Store the kNN-transferred labels and uncertainties (cast to float32) in the
# 'pred_ann_class' / 'pred_ann_uncert' obs columns, replacing any values already there.
latent_subadata.obs['pred_ann_class'] = labels["ann_finest_level"].to_numpy()
latent_subadata.obs['pred_ann_uncert'] = uncert["ann_finest_level"].to_numpy().astype(np.float32)

In [62]:
# Write the latent AnnData with the kNN predictions to a separate 'cancer-knn' file.
latent_subadata.write(HLCA_LATAQ_SAVE_MAPPED_LATENT_format.format('cancer-knn'))

## Perform surgery on reference model and train on healthy dataset

In [66]:
# Inputs for this section: the saved reference model, the extended (Meyer_2021)
# dataset as query, and the reference AnnData loaded from HLCA_DATA_PATH.
ref_path = HLCA_LATAQ_SAVE_MODEL
adata_query = adata_extended
adata_ref = adata

In [67]:
# Reload the saved reference model together with the reference AnnData selected for
# this section (adata_ref, rather than the global `adata` it aliases).
lataq_model = EMBEDCVAE.load(
    os.path.expanduser(ref_path),
    adata_ref,
)
# Use the GPU when one is present; fall back to the CPU instead of raising on
# machines without CUDA. `.to(...)` returns the module, so the cell still displays
# the model summary.
device = "cuda" if torch.cuda.is_available() else "cpu"
lataq_model.model.to(device)

AnnData object with n_obs × n_vars = 584884 × 1897
    obs: 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'sample', 'study', 'subject_ID', 'smoking_status', 'BMI', 'condition', 'subject_type', 'sample_type', "3'_or_5'", 'sequencing_platform', 'cell_ranger_version', 'fresh_or_frozen', 'dataset', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_highest_res', 'n_genes', 'size_factors', 'log10_total_counts', 'mito_frac', 'ribo_frac', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'scanvi_label', 'leiden_1', 'leiden_2', 'leiden_3', 'anatomical_region_ccf_score', 'entropy_study_leiden_3', 'entropy_dataset_leiden_3', 'entropy_subject_ID_

EmbedCVAE(
  (embedding): Embedding(166, 20, max_norm=1)
  (encoder): Encoder(
    (FC): Sequential(
      (L0): CondLayers(
        (expr_L): Linear(in_features=1897, out_features=128, bias=True)
        (cond_L): Linear(in_features=20, out_features=128, bias=False)
      )
      (N0): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
      (A0): ReLU()
      (D0): Dropout(p=0.05, inplace=False)
      (L1): Linear(in_features=128, out_features=128, bias=True)
      (N1): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
      (A1): ReLU()
      (D1): Dropout(p=0.05, inplace=False)
      (L2): Linear(in_features=128, out_features=128, bias=True)
      (N2): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
      (A2): ReLU()
      (D2): Dropout(p=0.05, inplace=False)
    )
    (mean_encoder): Linear(in_features=128, out_features=50, bias=True)
    (log_var_encoder): Linear(in_features=128, out_features=50, bias=True)
  )
  (decoder): Decoder(
    (FirstL): Sequential(
  

In [68]:
# Latent representation of the reference cells (mean=True), wrapped in an AnnData
# that carries a copy of the reference obs table.
# The matrix is cast while still sparse and densified once with `.toarray()`: this
# replaces the `.A` shorthand, which newer SciPy releases no longer provide on sparse
# matrices, and avoids the extra dense copy made by `.A.astype('float32')`.
reference_emb_adata = sc.AnnData(lataq_model.get_latent(
    adata_ref.X.astype(np.float32).toarray(),
    adata_ref.obs[condition_key].values,
    mean=True,
), obs=adata_ref.obs.copy())

In [69]:
%%time

# Create the query model from the saved reference model; labeled_indices is empty.
model = EMBEDCVAE.load_query_data(
    reference_model=ref_path,
    adata=adata_query,
    labeled_indices=[],
)

# Training options looked up from PARAMS.
shared_train_settings = dict(
    early_stopping_kwargs=PARAMS['EARLY_STOPPING_KWARGS'],
    alpha_epoch_anneal=PARAMS['ALPHA_EPOCH_ANNEAL'],
    clustering_res=PARAMS['CLUSTERING_RES'],
    labeled_loss_metric=PARAMS['LABELED_LOSS_METRIC'],
    unlabeled_loss_metric=PARAMS['UNLABELED_LOSS_METRIC'],
)

# Options specific to the query run are given as literals (eta=0, weight_decay=0).
model.train(
    n_epochs=100,
    pretraining_epochs=80,
    eta=0,
    weight_decay=0,
    use_stratified_sampling=False,
    reload_best=False,
    **shared_train_settings,
)

AnnData object with n_obs × n_vars = 128628 × 1897
    obs: 'dataset', 'study', 'original_celltype_ann', 'condition', 'subject_ID', 'sample', 'cells_or_nuclei', 'single_cell_platform', 'sample_type', 'age', 'sex', 'ethnicity', 'BMI', 'smoking_status', 'anatomical_region_level_1', 'anatomical_region_coarse', 'anatomical_region_detailed', 'genome', 'disease', 'ann_finest_level'
Embedding dictionary:
 	Num conditions: 225
 	Embedding dim: 20
Encoder Architecture:
	Input Layer in, out and cond: 1897 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Mean/Var Layer in/out: 128 50
Decoder Architecture:
	First Layer in, out and cond:  50 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Output Layer in/out:  128 1897 

166
The missing labels are: {'NA'}
Therefore integer value of those labels is set to -1
The missing labels are: {'NA'}
Therefore integer value of those labels is set to -1
The missing labels are: {'NA'}
Therefore integer value of those

In [70]:
# Latent representation of the query cells (mean=True), paired with a copy of the query obs.
query_latent = model.get_latent(
    adata_query.X,
    adata_query.obs[condition_key].values,
    mean=True,
)
latent_subadata = sc.AnnData(query_latent, obs=adata_query.obs.copy())

# Classify the query cells with the 'dist' metric; threshold is -inf.
subdata_predictions = model.classify(
    x=adata_query.X,
    c=adata_query.obs[condition_key],
    metric='dist',
    get_prob=False,
    threshold=-np.inf
)

# Copy the results for the 'ann_finest_level' key onto the latent AnnData.
finest_level_result = subdata_predictions['ann_finest_level']
for obs_column, result_field in (
    ('pred_ann_class', 'preds'),
    ('pred_ann_uncert', 'probs'),
):
    latent_subadata.obs[obs_column] = finest_level_result[result_field]
latent_subadata.obsm['pred_ann_weighted_distances'] = finest_level_result['weighted_distances']

In [71]:
# Save the query-trained model and the latent AnnData under the 'healthy' tag.
# NOTE(review): unlike the reference-model save, no overwrite=True is passed here --
# confirm the behaviour when this cell is re-run and the target already exists.
model.save(HLCA_LATAQ_SAVE_FINETUNED_MODEL_format.format('healthy'))
latent_subadata.write(HLCA_LATAQ_SAVE_MAPPED_LATENT_format.format('healthy'))

In [None]:
# Reload the latent AnnData from the 'healthy' file written by the save cell.
latent_subadata = sc.read(HLCA_LATAQ_SAVE_MAPPED_LATENT_format.format('healthy'))

In [None]:
%%time

# Fit the weighted k-NN transformer on the reference latent AnnData
# (50 neighbours, labels from 'ann_finest_level').
# The recorded run of this cell took about 1 h 16 min of wall time.
k_neighbors_transformer = weighted_knn_trainer(
    train_adata=reference_emb_adata,
    train_adata_emb="X", # the embedding is stored in .X
    label_key="ann_finest_level",
    n_neighbors=50,
    )    
# Transfer labels from the reference to the query latent AnnData.
# NOTE(review): the fitted transformer is passed both as the first positional
# argument and as `knn_model=`; verify this against the signature of
# weighted_knn_transfer in knn_cell_type_assigner.
labels, uncert = weighted_knn_transfer(
    k_neighbors_transformer,
    query_adata=latent_subadata,
    query_adata_emb="X", # the embedding is stored in .X
    label_keys="ann_finest_level",
    knn_model=k_neighbors_transformer,
    ref_adata_obs = reference_emb_adata.obs
    )

Weighted KNN with n_neighbors = 50 ... finished!
CPU times: user 1h 17min 9s, sys: 46min 23s, total: 2h 3min 32s
Wall time: 1h 16min 28s


In [None]:
# Store the kNN-transferred labels and uncertainties (cast to float32) in the
# 'pred_ann_class' / 'pred_ann_uncert' obs columns, replacing any values already there.
latent_subadata.obs['pred_ann_class'] = labels["ann_finest_level"].to_numpy()
latent_subadata.obs['pred_ann_uncert'] = uncert["ann_finest_level"].to_numpy().astype(np.float32)

In [None]:
# Write the latent AnnData with the kNN predictions to a separate 'healthy-knn' file.
latent_subadata.write(HLCA_LATAQ_SAVE_MAPPED_LATENT_format.format('healthy-knn'))

In [76]:
# Scratch cell: a bare literal that only echoes 1 and does not affect the analysis.
# NOTE(review): looks like a leftover -- consider deleting this cell.
1

1