In [1]:
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

from scarches.dataset.trvae.data_handling import remove_sparsity
from lataq.models import EMBEDCVAE
from lataq.exp_dict import EXPERIMENT_INFO

sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=200)
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

%load_ext autoreload
%autoreload 2

In [2]:
PARAMS = {
    'EPOCHS': 50,                                      #TOTAL TRAINING EPOCHS
    'N_PRE_EPOCHS': 40,                                #EPOCHS OF PRETRAINING WITHOUT LANDMARK LOSS
    #'DATA_DIR': '../../lataq_reproduce/data',         #DIRECTORY WHERE THE DATA IS STORED
    #'DATA': '',                                       #DATA USED FOR THE EXPERIMENT
    'EARLY_STOPPING_KWARGS': {                         #KWARGS FOR EARLY STOPPING
        "early_stopping_metric": "val_landmark_loss",  ####value used for early stopping
        "mode": "min",                                 ####choose if look for min or max
        "threshold": 0,
        "patience": 20,
        "reduce_lr": True,
        "lr_patience": 13,
        "lr_factor": 0.1,
    },
    'LABELED_LOSS_METRIC': 'dist',           
    'UNLABELED_LOSS_METRIC': 'dist',
    'LATENT_DIM': 10,
    'ALPHA_EPOCH_ANNEAL': 1e3,
    'CLUSTERING_RES': 2,
    'HIDDEN_LAYERS': 3,
    'ETA': 10,
}

In [28]:
pd.crosstab(target_adata.obs['patient'], target_adata.obs['cell_type'])

cell_type,B cells,Cytotoxic CD8 T cells,Effector memory CD8 T cells,M2 TAMs,Mast cells,Monocytes,NK,Naive T cells,Naive-memory CD4 T cells,Plasma B cells,...,Recently activated CD4 T cells,Regulatory T cells,SPP1 TAMs,T helper cells,Terminally exhausted CD8 T cells,Th17 cells,Transitional memory CD4 T cells,cDC,mDC,pDC
patient,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
BC_1,278,177,52,112,244,9,49,162,1480,60,...,1558,18,42,6,2,20,32,2,34,1
BC_2,35,852,17,58,49,15,83,408,310,20,...,656,174,30,52,50,113,196,2,11,0
BC_3,98,100,30,114,137,90,51,42,227,25,...,411,79,95,11,22,58,40,12,124,1
BC_4,147,1486,41,181,425,148,721,1083,403,28,...,498,85,82,41,30,173,139,12,57,2
BC_5,434,611,44,394,114,226,156,287,65,17,...,318,505,239,50,109,289,16,33,100,20
BC_6,111,344,18,822,177,44,98,172,44,18,...,403,48,1696,20,65,39,10,19,86,1
BC_7,138,519,15,194,281,123,135,383,69,11,...,304,59,275,24,29,53,15,29,174,0
BC_8,6,113,13,80,67,20,41,44,40,11,...,459,24,168,5,6,15,0,5,39,0


In [30]:
DATA_DIR = '../data'
adata = sc.read(f'{DATA_DIR}/cancer_atlas.h5ad')
remove_ct = ["B cells"]
query = adata.obs['patient'][adata.obs['patient'].str.contains('BC')]
adata.obs['query'] = adata.obs['patient'].isin(query)
adata.obs['query'] = adata.obs['query'].astype('category')
source_adata = adata[~adata.obs['patient'].isin(query)]
source_adata = adata[~adata.obs['cell_type'].isin(remove_ct)]
target_adata = adata[adata.obs['patient'].isin(query)]

In [5]:
adata_tmp = adata.copy()
sc.pp.normalize_total(adata_tmp)
sc.pp.log1p(adata_tmp)
sc.pp.highly_variable_genes(adata_tmp, n_top_genes=4000)
adata.var = adata_tmp.var
adata = adata[:, adata.var['highly_variable']]
source_adata = adata[adata.obs[condition_key].isin(reference)].copy()
target_adata = adata[adata.obs[condition_key].isin(query)].copy()

In [11]:
lataq_model = EMBEDCVAE(
    adata=source_adata,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    hidden_layer_sizes=[128]*int(PARAMS['HIDDEN_LAYERS']),
    latent_dim=PARAMS['LATENT_DIM'],
)

Embedding dictionary:
 	Num conditions: 80
 	Embedding dim: 10
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 0
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 10
	Hidden Layer 1 in/out: 128 128
	Hidden Layer 2 in/out: 128 128
	Output Layer in/out:  128 4000 



In [12]:
lataq_model.train(
    n_epochs=PARAMS['EPOCHS'],
    pretraining_epochs=PARAMS['N_PRE_EPOCHS'],
    early_stopping_kwargs=PARAMS['EARLY_STOPPING_KWARGS'],
    alpha_epoch_anneal=PARAMS['ALPHA_EPOCH_ANNEAL'],
    eta=PARAMS['ETA'],
    clustering_res=PARAMS['CLUSTERING_RES'],
    labeled_loss_metric=PARAMS['LABELED_LOSS_METRIC'],
    unlabeled_loss_metric=PARAMS['UNLABELED_LOSS_METRIC'],
)

 |████████████████████| 100.0%  - val_loss: 525.1931906680 - val_trvae_loss: 516.6475039519 - val_landmark_loss: 8.5456859928 - val_labeled_loss: 0.8545685981
Saving best state of network...
Best State was in Epoch 48


In [13]:
lataq_model.save('tmp/', overwrite=True)

In [14]:
lataq_query = lataq_model.load_query_data(
    adata=target_adata,
    reference_model='tmp/',
    labeled_indices=[],
)

EOFError: Ran out of input

In [None]:
lataq_query.train(
    n_epochs=PARAMS['EPOCHS'],
    early_stopping_kwargs=PARAMS['EARLY_STOPPING_KWARGS'],
    alpha_epoch_anneal=PARAMS['ALPHA_EPOCH_ANNEAL'],
    pretraining_epochs=PARAMS['N_PRE_EPOCHS'],
    clustering_res=PARAMS['CLUSTERING_RES'],
    eta=PARAMS['ETA'],
    labeled_loss_metric=PARAMS['LABELED_LOSS_METRIC'],
    unlabeled_loss_metric=PARAMS['UNLABELED_LOSS_METRIC'],
)

In [9]:
results_dict = lataq_query.classify(
    adata.X, 
    adata.obs[condition_key], 
    metric=PARAMS['LABELED_LOSS_METRIC'],
)
for i in range(len(cell_type_key)):
    preds = results_dict[cell_type_key[i]]['preds']
    probs = results_dict[cell_type_key[i]]['probs']
    classification_df = pd.DataFrame(
        classification_report(
            y_true=adata.obs[cell_type_key[i]], 
            y_pred=preds,
            output_dict=True
        )
    ).transpose()