In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import random
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

from scarches.dataset.trvae.data_handling import remove_sparsity
from scarches.models.scpoli import scPoli

# Notebook-wide plotting and tensor-printing defaults.
# NOTE(review): `sc.set_figure_params` appears to be the same function as
# `sc.settings.set_figure_params`, so the second call presumably overrides dpi=200 with
# dpi=500 and may also reset `frameon` to its default — confirm, and merge the two calls
# into a single one with the intended values.
sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=500)
plt.rcParams['figure.figsize'] = (5, 5)
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:[rank: 0] Global seed set to 0


In [3]:
# Column names handed to scPoli: `condition_key` is passed as its `condition_key` and is also
# used to index `adata.obs`; `cell_type_key` is passed as `cell_type_keys` (hence a list).
condition_key, cell_type_key = 'sample', ['ann_finest_level']

# Create the directory that will hold the trained models (no error if it already exists).
models_dir = Path(os.path.expanduser("~/io/scpoli_repr/scpoli_models/"))
models_dir.mkdir(parents=True, exist_ok=True)

# Path template for per-replicate artifacts; fill in with
# OUTPUT_format.format(replicate=<int>, ext=<suffix such as "" or ".latent.h5ad">).
OUTPUT_format = os.path.expanduser("~/io/scpoli_repr/scpoli_models/hlca_core_sample_{replicate}{ext}")

In [9]:
# Hyperparameters for scPoli training, gathered in a single dict.
# NOTE(review): the training cell hard-codes pretraining_epochs=45 and
# hidden_layer_sizes=[128]*3 instead of reading 'N_PRE_EPOCHS' / 'HIDDEN_LAYERS' from here,
# so the values 40 and 4 below do not appear to be what the runs actually use — confirm
# which set is intended and reconcile.
PARAMS = dict(
    EPOCHS=50,        # total training epochs
    N_PRE_EPOCHS=40,  # epochs of pretraining without landmark loss
    # Keyword arguments forwarded to the trainer's early-stopping logic.
    EARLY_STOPPING_KWARGS=dict(
        early_stopping_metric="val_prototype_loss",  # value monitored for early stopping
        mode="min",                                  # whether to look for a min or a max
        threshold=0,
        patience=20,
        reduce_lr=True,
        lr_patience=13,
        lr_factor=0.1,
    ),
    LABELED_LOSS_METRIC='dist',
    UNLABELED_LOSS_METRIC='dist',
    LATENT_DIM=50,
    ALPHA_EPOCH_ANNEAL=1e3,
    CLUSTERING_RES=2,
    HIDDEN_LAYERS=4,
    ETA=1,
)

In [5]:
# Read the AnnData file from the user's home directory and display its summary
# as the cell output.
counts_path = os.path.expanduser('~/io/scpoli_repr/hlca_counts_commonvars.h5ad')
adata = sc.read(counts_path)
adata

AnnData object with n_obs × n_vars = 584884 × 1897
    obs: 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'sample', 'study', 'subject_ID', 'smoking_status', 'BMI', 'condition', 'subject_type', 'sample_type', "3'_or_5'", 'sequencing_platform', 'cell_ranger_version', 'fresh_or_frozen', 'dataset', 'anatomical_region_level_2', 'anatomical_region_level_3', 'anatomical_region_highest_res', 'age', 'ann_highest_res', 'n_genes', 'size_factors', 'log10_total_counts', 'mito_frac', 'ribo_frac', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'scanvi_label', 'leiden_1', 'leiden_2', 'leiden_3', 'anatomical_region_ccf_score', 'entropy_study_leiden_3', 'entropy_dataset_leiden_3', 'entropy_subject_ID_

In [6]:
# Count the observations (rows of adata.obs) belonging to each value of the 'study' column.
adata.obs.groupby('study').size()

study
Banovich_Kropski_2020     121881
Barbry_Leroy_2020          74486
Jain_Misharin_2021         45557
Krasnow_2020               60977
Lafyatis_Rojas_2019        24180
Meyer_2019                 35522
Misharin_2021              64842
Misharin_Budinger_2018     41216
Nawijn_2021                70401
Seibold_2020               33593
Teichmann_Meyer_2019       12229
dtype: int64

In [7]:
# Cast the data matrix to float32; this rebinds adata.X, so later cells see the float32 version.
adata.X = adata.X.astype(np.float32)

In [None]:
# Train 10 scPoli replicates, each under its own random seed, and for every replicate save
# the trained model and an AnnData holding the latent representation of all cells.
# Replicates whose model AND latent file already exist on disk are skipped, so the cell
# can be re-run to resume an interrupted sweep.
#
# Seeds are capped at 2**32 - 1: random.randint includes its upper bound, and
# np.random.seed only accepts values in [0, 2**32 - 1].
# NOTE(review): the seeds are not written anywhere, so a given replicate cannot be
# re-created from the saved outputs alone — consider persisting them.
seeds = [random.randint(0, 2**32 - 1) for _ in range(10)]
for i, seed in enumerate(seeds):
    print("Replicate ", i)
    # Seed every RNG source used in this notebook (torch, numpy, stdlib random).
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model_output_path = OUTPUT_format.format(replicate=i, ext="")
    latent_output_path = OUTPUT_format.format(replicate=i, ext=".latent.h5ad")
    if os.path.exists(model_output_path) and os.path.exists(latent_output_path):
        # Both artifacts are present: nothing left to do for this replicate.
        print(f"{latent_output_path} exists. Skipping.")
        continue

    # NOTE(review): hidden_layer_sizes is hard-coded here and differs from
    # PARAMS['HIDDEN_LAYERS'] (4); confirm which is intended.
    scpoli_model = scPoli(
        adata=adata,
        condition_key=condition_key,
        cell_type_keys=cell_type_key,
        hidden_layer_sizes=[128]*3,
        latent_dim=PARAMS['LATENT_DIM'],
        embedding_dim=20,
        inject_condition=['encoder', 'decoder']
    )

    # NOTE(review): pretraining_epochs is hard-coded to 45 and differs from
    # PARAMS['N_PRE_EPOCHS'] (40); kept as-is so existing results stay comparable.
    scpoli_model.train(
        n_epochs=PARAMS['EPOCHS'],
        pretraining_epochs=45,
        early_stopping_kwargs=PARAMS['EARLY_STOPPING_KWARGS'],
        alpha_epoch_anneal=PARAMS['ALPHA_EPOCH_ANNEAL'],
        eta=PARAMS['ETA'],
        clustering_res=PARAMS['CLUSTERING_RES'],
        labeled_loss_metric=PARAMS['LABELED_LOSS_METRIC'],
        unlabeled_loss_metric=PARAMS['UNLABELED_LOSS_METRIC'],
        use_stratified_sampling=False,
        best_reload=False
    )
    scpoli_model.save(model_output_path, overwrite=True)

    # Densify the (sparse) data matrix for get_latent. `toarray()` is the explicit form of
    # the `.A` shorthand, and copy=False avoids a second full-size copy when the matrix is
    # already float32.
    data_latent = scpoli_model.get_latent(
        adata.X.toarray().astype('float32', copy=False),
        adata.obs[condition_key].values,
        mean=True,
    )
    # Store the latent matrix together with a copy of the original cell metadata.
    adata_latent = sc.AnnData(data_latent)
    adata_latent.obs = adata.obs.copy()
    adata_latent.write(latent_output_path)

Replicate  0
Embedding dictionary:
 	Num conditions: 166
 	Embedding dim: 20
Encoder Architecture:
	Input Layer in, out and cond: 1897 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Mean/Var Layer in/out: 128 50
Decoder Architecture:
	First Layer in, out and cond:  50 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Output Layer in/out:  128 1897 

Initializing dataloaders
Starting training
 |████████████████████| 100.0%  - val_loss: 608.3992461824 - val_cvae_loss: 598.7960665847 - val_prototype_loss: 9.6031794496 - val_labeled_loss: 9.60317944965
Replicate  1
Embedding dictionary:
 	Num conditions: 166
 	Embedding dim: 20
Encoder Architecture:
	Input Layer in, out and cond: 1897 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Mean/Var Layer in/out: 128 50
Decoder Architecture:
	First Layer in, out and cond:  50 128 20
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Output Layer in/out:  128 1897 
