In [1]:
# Reload imported modules automatically before each cell runs, so edits to imported
# packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import random
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

from scarches.dataset.trvae.data_handling import remove_sparsity
from scarches.models.scpoli import scPoli

# Plotting / printing configuration.
# `sc.settings.set_figure_params` and `sc.set_figure_params` are the same function, and each
# call resets every option it is not given to its default. Calling it twice (dpi=200,
# frameon=False, then dpi=500) silently discarded `frameon=False`, so configure scanpy once.
FIGURE_DPI = 500
sc.set_figure_params(dpi=FIGURE_DPI, frameon=False)
# Set after scanpy, whose rc defaults include a figure size that would otherwise win.
plt.rcParams['figure.figsize'] = (5, 5)
# Compact tensor printing: 3 decimals, no scientific notation, 7 items shown per edge.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:[rank: 0] Global seed set to 0


In [3]:
# Experiment configuration: annotation keys and input/output locations.
condition_key = 'sample'              # obs column passed to scPoli as the condition covariate
cell_type_key = ['ann_finest_level']  # obs column(s) holding cell-type labels

# Root directory for all inputs and outputs. Override it with the SCPOLI_REPR_DIR
# environment variable instead of editing a hard-coded home-directory path.
SCPOLI_REPR_DIR = os.path.expanduser(os.environ.get('SCPOLI_REPR_DIR', '~/io/scpoli_repr'))

HLCA_DATA_PATH = os.path.join(SCPOLI_REPR_DIR, 'hlca_counts_commonvars.h5ad')
HLCA_EXTENDED_DATA_PATH = os.path.join(SCPOLI_REPR_DIR, 'hlca_extended_commonvars.h5ad')
HLCA_CANCER_DATA_PATH = os.path.join(SCPOLI_REPR_DIR, 'hlca_cancer_commonvars.h5ad')

Path(SCPOLI_REPR_DIR, 'mappings').mkdir(parents=True, exist_ok=True)

# Path templates filled later with str.format(...). Escape any literal braces in the root
# so they are not mistaken for format fields.
_template_root = SCPOLI_REPR_DIR.replace('{', '{{').replace('}', '}}')
# Fields: replicate, ext.
MODEL_format = os.path.join(_template_root, 'scpoli_models', 'hlca_core_sample_{replicate}{ext}')
# Fields: exp, replicate, ext.
OUTPUT_format = os.path.join(_template_root, 'mappings', 'hlca_{exp}_sample_{replicate}{ext}')

In [29]:
# Hyper-parameters for mapping query data onto the scPoli reference models.
MAPPING_PARAMS = dict(
    EPOCHS=50,        # total training epochs
    N_PRE_EPOCHS=40,  # pre-training epochs run without the landmark loss
    EARLY_STOPPING_KWARGS=dict(
        early_stopping_metric="val_prototype_loss",  # validation metric that is monitored
        mode="min",                                  # whether to look for the metric's min or max
        threshold=0,
        patience=20,
        reduce_lr=True,
        lr_patience=13,
        lr_factor=0.1,
    ),
    LABELED_LOSS_METRIC='dist',
    UNLABELED_LOSS_METRIC='dist',
    LATENT_DIM=50,
    ALPHA_EPOCH_ANNEAL=1e3,
    CLUSTERING_RES=2,
    HIDDEN_LAYERS=4,
    ETA=1,
)

In [4]:
# Load the HLCA core counts and store the expression matrix as float32.
adata = sc.read_h5ad(HLCA_DATA_PATH)
adata.X = adata.X.astype(np.float32)
adata

AnnData object with n_obs × n_vars = 584884 × 1897
    obs: 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'sample', 'study', 'subject_ID', 'smoking_status', 'BMI', 'condition', 'subject_type', 'sample_type', "3'_or_5'", 'sequencing_platform', 'cell_ranger_version', 'fresh_or_frozen', 'dataset', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_highest_res', 'n_genes', 'size_factors', 'log10_total_counts', 'mito_frac', 'ribo_frac', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'scanvi_label', 'leiden_1', 'leiden_2', 'leiden_3', 'anatomical_region_ccf_score', 'entropy_study_leiden_3', 'entropy_dataset_leiden_3', 'entropy_subject_ID_

In [18]:
# Healthy query: keep only the Meyer_2021 study from the extended HLCA, cast to float32,
# and give every cell the placeholder label 'NA' (query cells carry no reference labels).
adata_extended = sc.read_h5ad(HLCA_EXTENDED_DATA_PATH)
adata_extended = adata_extended[adata_extended.obs['study'].eq('Meyer_2021')].copy()
adata_extended.X = adata_extended.X.astype(np.float32)
adata_extended.obs[cell_type_key[0]] = 'NA'
adata_extended

AnnData object with n_obs × n_vars = 128628 × 1897
    obs: 'dataset', 'study', 'original_celltype_ann', 'condition', 'subject_ID', 'sample', 'cells_or_nuclei', 'single_cell_platform', 'sample_type', 'age', 'sex', 'ethnicity', 'BMI', 'smoking_status', 'anatomical_region_level_1', 'anatomical_region_coarse', 'anatomical_region_detailed', 'genome', 'disease', 'ann_finest_level'

In [19]:
# Cancer query: load, cast to float32, and give every cell the placeholder label 'NA'
# (query cells carry no reference labels).
adata_cancer = sc.read_h5ad(HLCA_CANCER_DATA_PATH)
adata_cancer.X = adata_cancer.X.astype(np.float32)
adata_cancer.obs[cell_type_key[0]] = 'NA'
adata_cancer

AnnData object with n_obs × n_vars = 93575 × 1897
    obs: 'n_genes_detected', 'total_counts', 'cell_from_tumor', 'subject_ID', 'tumor_site', 'original_celltype_ann', 'sample', 'study', 'study_long', 'dataset', 'last_author_PI', 'lung_vs_nasal', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'ann_highest_res', 'ann_new', 'scanvi_label', 'ann_finest_level'
    var: 'original_gene_names', 'gene_symbols', 'ensembl'

In [16]:
# Sanity check: the healthy query should contain only Meyer_2021 cells.
adata_extended.obs['study'].value_counts()

Meyer_2021    128628
Name: study, dtype: int64

In [17]:
# Sanity check: which studies make up the cancer query.
adata_cancer.obs['study'].value_counts()

Thienpont_2018    93575
Name: study, dtype: int64

In [14]:
# NOTE(review): this cell (In [14]) was executed before the Meyer_2021 subsetting in In [18].
# The output recorded below (1,647,652 cells, no 'ann_finest_level' column) is therefore stale.
# Restart & Run All to refresh it.
adata_extended

AnnData object with n_obs × n_vars = 1647652 × 1897
    obs: 'dataset', 'study', 'original_celltype_ann', 'condition', 'subject_ID', 'sample', 'cells_or_nuclei', 'single_cell_platform', 'sample_type', 'age', 'sex', 'ethnicity', 'BMI', 'smoking_status', 'anatomical_region_level_1', 'anatomical_region_coarse', 'anatomical_region_detailed', 'genome', 'disease'

In [34]:
# Map both query datasets onto each replicate reference model. For each pair, save the
# latent embedding together with label-transfer predictions. Outputs that already exist
# are skipped, so re-running the cell resumes an interrupted sweep.

N_REPLICATES = 10
# Dedicated generator for the per-replicate seeds. This makes the seed list reproducible
# regardless of how much global `random` state earlier cells consumed.
SEED_BASE = 0
# NOTE(review): query training uses settings that differ from MAPPING_PARAMS
# ('EPOCHS'=50, 'N_PRE_EPOCHS'=40, 'ETA'=1). Values kept as-is; confirm this is intended.
QUERY_N_EPOCHS = 100
QUERY_PRETRAINING_EPOCHS = 80
QUERY_ETA = 0


def _to_dense(matrix):
    """Return ``matrix`` as a dense array.

    Sparse matrices are densified with ``.toarray()``; the ``.A`` shorthand is deprecated in
    recent SciPy and does not exist on plain ndarrays. Anything else goes through ``np.asarray``.
    """
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


# randint's upper bound is inclusive, and np.random.seed rejects values >= 2**32.
seed_rng = random.Random(SEED_BASE)
seeds = [seed_rng.randint(0, 2**32 - 1) for _ in range(N_REPLICATES)]
# Key under which model.classify returns predictions (presumably the reference's cell-type key).
label_key = cell_type_key[0]

for i, seed in enumerate(seeds):
    for ds_name, adata_query in [("healthy", adata_extended), ("cancer", adata_cancer)]:
        print("Replicate ", i, "DS", ds_name)

        ref_path = MODEL_format.format(replicate=i, ext="")
        latent_output_path = OUTPUT_format.format(exp=f"{ds_name}_mapped", replicate=i, ext=".latent.h5ad")

        if os.path.exists(latent_output_path):
            print(f"{latent_output_path} exists. Skipping.")
            continue
        if not os.path.exists(ref_path):
            print(f"Reference {ref_path} not found. Skipping.")
            continue

        # Seed every RNG source right before the model is built and trained.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        model = scPoli.load_query_data(
            adata=adata_query,
            labeled_indices=[],  # no query cell is treated as labeled
            reference_model=ref_path,
        )

        model.train(
            n_epochs=QUERY_N_EPOCHS,
            pretraining_epochs=QUERY_PRETRAINING_EPOCHS,
            early_stopping_kwargs=MAPPING_PARAMS['EARLY_STOPPING_KWARGS'],
            alpha_epoch_anneal=MAPPING_PARAMS['ALPHA_EPOCH_ANNEAL'],
            eta=QUERY_ETA,
            clustering_res=MAPPING_PARAMS['CLUSTERING_RES'],
            labeled_loss_metric=MAPPING_PARAMS['LABELED_LOSS_METRIC'],
            unlabeled_loss_metric=MAPPING_PARAMS['UNLABELED_LOSS_METRIC'],
            weight_decay=0,
            use_stratified_sampling=False,
            reload_best=False,
        )

        # Densify once and reuse the result for both the embedding and the classification,
        # instead of materialising the dense matrix twice.
        query_X = _to_dense(adata_query.X)
        query_conditions = adata_query.obs[condition_key]

        latent_subadata = sc.AnnData(
            model.get_latent(query_X, query_conditions),
            obs=adata_query.obs.copy(),
        )
        subdata_predictions = model.classify(
            x=query_X,
            c=query_conditions,
            get_prob=False,
        )
        del query_X  # release the dense copy before writing

        label_predictions = subdata_predictions[label_key]
        latent_subadata.obs['pred_ann_class'] = label_predictions['preds']
        latent_subadata.obs['pred_ann_uncert'] = label_predictions['uncert']
        latent_subadata.obsm['pred_ann_weighted_distances'] = label_predictions['weighted_distances']
        latent_subadata.uns['seed'] = seed  # provenance: RNG seed used for this replicate

        latent_subadata.write(latent_output_path)

Replicate  0 DS healthy
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_healthy_mapped_sample_0.latent.h5ad exists. Skipping.
Replicate  0 DS cancer
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_cancer_mapped_sample_0.latent.h5ad exists. Skipping.
Replicate  1 DS healthy
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_healthy_mapped_sample_1.latent.h5ad exists. Skipping.
Replicate  1 DS cancer
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_cancer_mapped_sample_1.latent.h5ad exists. Skipping.
Replicate  2 DS healthy
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_healthy_mapped_sample_2.latent.h5ad exists. Skipping.
Replicate  2 DS cancer
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_cancer_mapped_sample_2.latent.h5ad exists. Skipping.
Replicate  3 DS healthy
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hlca_healthy_mapped_sample_3.latent.h5ad exists. Skipping.
Replicate  3 DS cancer
/home/icb/amirali.moinfar/io/scpoli_repr/mappings/hl