In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import os
import gc
import anndata as ad
import scarches as sca
import pytorch_lightning as pl
from scarches.models.scpoli import scPoli
from sklearn.metrics import classification_report
import sys
import traceback
import matplotlib.pyplot as plt

In [None]:
import os

# Working directory for the notebook: every relative path used by the cells below
# is resolved against it. Set the PDAC_PROCESSED_DIR environment variable to run
# on another machine without editing this cell; the default is the original location.
os.chdir(os.environ.get('PDAC_PROCESSED_DIR', '/home/aih/shrey.parikh/PDAC/PDAC/processed_datasets/'))

In [None]:
# Read the AnnData object stored in the All_genes folder (path is relative to the
# working directory set above).
concat_h5ad = 'All_genes/Concat_All_Genes_filtered.h5ad'
adata_filtered = sc.read_h5ad(concat_h5ad)

In [None]:
# Build a 'log_norm' layer from the 'raw' layer, then point X back at the raw layer.
raw_counts = adata_filtered.layers['raw']

# Normalise a copy so the 'raw' layer itself is not modified.
adata_filtered.X = raw_counts.copy()
sc.pp.normalize_total(adata_filtered, target_sum=1e4)
sc.pp.log1p(adata_filtered)
adata_filtered.layers['log_norm'] = adata_filtered.X.copy()

# X is reset to the raw layer object (no copy is made here).
adata_filtered.X = raw_counts

In [None]:
# Annotate the Ding snRNA-seq samples

In [None]:
# Sample IDs whose cells are re-labelled as the 'Ding_snRNA-seq' dataset.
ding_sn_map = [
    'HT224P1',
    'HT231P1',
    'HT232P1',
    'HT242P1',
    'HT259P1',
    'HT264P1',
    'HT270P1',
    'HT284P1',
    'HT288P1',
    'HT306P1',
    'HT412P1',
]

# Overwrite Dataset for the listed samples; every other cell keeps its current value.
# Columns are assigned with [] rather than attribute access so the write always
# targets the DataFrame column.
is_ding_sn = adata_filtered.obs['ID'].isin(ding_sn_map)
adata_filtered.obs['Dataset'] = np.where(is_ding_sn, 'Ding_snRNA-seq', adata_filtered.obs['Dataset'])

# batch_covariate is "<Dataset>_<Condition>"; the one combination that would read
# 'Ding_snRNA-seq_snRNA-seq' is collapsed to 'Ding_snRNA-seq'.
adata_filtered.obs['batch_covariate'] = (
    adata_filtered.obs['Dataset'].astype(str) + '_' + adata_filtered.obs['Condition'].astype(str)
)
adata_filtered.obs['batch_covariate'] = adata_filtered.obs['batch_covariate'].replace(
    'Ding_snRNA-seq_snRNA-seq', 'Ding_snRNA-seq'
)

# Cell counts per batch covariate (displayed as the cell output).
adata_filtered.obs.groupby('batch_covariate').size()

In [None]:
# Independent copies of the two subsets used from here on: all cells of the 'Regev'
# dataset, and all cells whose batch covariate is 'Ding_snRNA-seq'.
is_regev = adata_filtered.obs['Dataset'] == 'Regev'
is_ding_snrna = adata_filtered.obs['batch_covariate'] == 'Ding_snRNA-seq'
regev = adata_filtered[is_regev].copy()
ding_sn = adata_filtered[is_ding_snrna].copy()

In [None]:
# adata_filtered.write('All_genes/Concat_All_Genes_filtered.h5ad')
# The regev / ding_sn subsets were taken with .copy(), so the full object is no
# longer needed: drop the name and run a garbage-collection pass.
del adata_filtered
gc.collect()

In [None]:
# Embed the snRNA-seq subset: PCA on the 'log_norm' layer, then neighbour graph,
# UMAP and Leiden clustering, in that order, all with scanpy's default arguments.
sc.pp.pca(ding_sn, layer='log_norm')
for graph_step in (sc.pp.neighbors, sc.tl.umap, sc.tl.leiden):
    graph_step(ding_sn)

In [None]:
# UMAP of the snRNA-seq subset, coloured by the obs column 'ID'.
sc.pl.umap(ding_sn, color='ID')

In [None]:
# Separately stored Regev object; its obsm['X_umap'] is reused in the next cell.
regev_pdac_h5ad = 'All_genes/Regev_PDAC.h5ad'
regev_pdac = sc.read_h5ad(regev_pdac_h5ad)

In [None]:
# Reuse the UMAP coordinates stored in regev_pdac for the Regev subset.
# Rows are matched on cell names whenever possible, so a different cell order (or
# extra cells in regev_pdac) cannot silently attach coordinates to the wrong cells.
if regev.obs_names.equals(regev_pdac.obs_names):
    # Same cells in the same order: a direct copy is correct.
    regev.obsm['X_umap'] = regev_pdac.obsm['X_umap']
elif regev_pdac.obs_names.is_unique and regev.obs_names.isin(regev_pdac.obs_names).all():
    # Every Regev cell exists in regev_pdac: pull its row by name.
    umap_rows = regev_pdac.obs_names.get_indexer(regev.obs_names)
    regev.obsm['X_umap'] = regev_pdac.obsm['X_umap'][umap_rows]
else:
    # Names cannot be matched (e.g. they were renamed upstream): keep the original
    # positional copy, but say so. A row-count mismatch still raises in anndata.
    print(
        'WARNING: regev and regev_pdac cell names do not match; '
        'copying X_umap by position.',
        file=sys.stderr,
    )
    regev.obsm['X_umap'] = regev_pdac.obsm['X_umap']

In [None]:
# Regev cells on the imported UMAP, coloured by the obs columns
# 'Label_Harmonized' and 'ID'.
regev_umap_style = dict(frameon=False, legend_fontsize=5)
sc.pl.umap(regev, color=['Label_Harmonized', 'ID'], **regev_umap_style)

# Best to transfer labels from regev to ding

In [None]:
# Flag highly variable genes in each dataset separately, computed on the 'log_norm'
# layer with scanpy's default settings (Regev first, then the snRNA-seq subset).
for hvg_input in (regev, ding_sn):
    sc.pp.highly_variable_genes(hvg_input, layer='log_norm')

In [None]:
# Genes flagged as highly variable in BOTH datasets.
# The intersection is sorted because the iteration order of a set of strings can
# change between Python sessions; without it the column order of the HVG objects
# (and everything saved or trained from them) would differ from run to run.
regev_hv_genes = set(regev.var_names[regev.var.highly_variable])
ding_sn_hv_genes = set(ding_sn.var_names[ding_sn.var.highly_variable])
common_genes = sorted(regev_hv_genes & ding_sn_hv_genes)

In [None]:
# Restrict both datasets to the shared gene list, as independent copies
# with the same gene order.
regev_hvg, ding_sn_hvg = (
    dataset[:, common_genes].copy() for dataset in (regev, ding_sn)
)

In [None]:
# Persist both HVG-restricted objects in the current working directory.
for hvg_subset, h5ad_name in (
    (regev_hvg, 'regev_hvg.h5ad'),
    (ding_sn_hvg, 'ding_sn_hvg.h5ad'),
):
    hvg_subset.write(h5ad_name)

In [None]:
# Reload the HVG-restricted objects written by the previous cell
# (Regev first, then the snRNA-seq subset).
regev_hvg, ding_sn_hvg = map(sc.read_h5ad, ('regev_hvg.h5ad', 'ding_sn_hvg.h5ad'))

In [None]:
# Settings shared by the scPoli cells below.
# obs columns used as the cell-type label and as the condition (batch) label.
cell_type_key, condition_key = 'Label_Harmonized', 'Dataset'

# Training schedule, used for both the reference and the query model.
n_epochs, pretraining_epochs = 50, 40

# NOTE(review): n_latent is passed to scpoli_model.train(), while the scPoli
# constructor below is called with latent_dim=10. Confirm which latent size is intended.
n_latent = 25

# Passed to the reference model's train() call as early_stopping_kwargs.
early_stopping_kwargs = dict(
    early_stopping_metric="val_prototype_loss",
    mode="min",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [None]:
# Reference (source) = Regev HVG object, query (target) = snRNA-seq HVG object.
# Both are copies, so the *_hvg objects themselves are not modified by this cell.
source_adata, target_adata = regev_hvg.copy(), ding_sn_hvg.copy()

# Store the query's cell-type column as plain strings.
target_adata.obs[cell_type_key] = target_adata.obs[cell_type_key].astype(str)

for dataset in (source_adata, target_adata):
    print(dataset)

# Initialize and train scPoli model
reference_settings = dict(
    condition_keys=condition_key,
    cell_type_keys=cell_type_key,
    embedding_dims=10,
    latent_dim=10,
    recon_loss='nb',
)
scpoli_model = scPoli(adata=source_adata, **reference_settings)

reference_training = dict(
    n_epochs=n_epochs,
    n_latent=n_latent,
    pretraining_epochs=pretraining_epochs,
    early_stopping_kwargs=early_stopping_kwargs,
    eta=5,
)
scpoli_model.train(**reference_training)

In [None]:
# Build the query model from the trained reference. labeled_indices=[] means no
# query cell is passed as labelled.
query_setup = dict(
    adata=target_adata,
    reference_model=scpoli_model,
    labeled_indices=[],
)
scpoli_query = scPoli.load_query_data(**query_setup)

# Same schedule as the reference training, but with eta=10 and without the
# early-stopping settings.
query_training = dict(
    n_epochs=n_epochs,
    pretraining_epochs=pretraining_epochs,
    eta=10,
)
scpoli_query.train(**query_training)

In [None]:
# --- Classify the query cells -------------------------------------------------
# X is cast to float32 before it is handed to the model.
target_adata.X = target_adata.X.astype(np.float32)
results_dict = scpoli_query.classify(target_adata, scale_uncertainties=True)

# --- Latent representation of reference and query -----------------------------
data_latent_source = scpoli_query.get_latent(source_adata, mean=True)
adata_latent_source = sc.AnnData(data_latent_source)
adata_latent_source.obs = source_adata.obs.copy()

data_latent = scpoli_query.get_latent(target_adata, mean=True)
adata_latent = sc.AnnData(data_latent)
adata_latent.obs = target_adata.obs.copy()
adata_latent.obs['cell_type_pred'] = results_dict[cell_type_key]['preds'].tolist()
adata_latent.obs['cell_type_uncert'] = results_dict[cell_type_key]['uncert'].tolist()
# True where the prediction equals the label already stored for the query cell.
adata_latent.obs['classifier_outcome'] = (
    adata_latent.obs['cell_type_pred'] == adata_latent.obs[cell_type_key]
)

# --- Prototypes ---------------------------------------------------------------
# Prototypes are tagged in the condition column so they can be filtered out later.
labeled_prototypes = scpoli_query.get_prototypes_info()
labeled_prototypes.obs[condition_key] = 'labeled prototype'
unlabeled_prototypes = scpoli_query.get_prototypes_info(prototype_set='unlabeled')
unlabeled_prototypes.obs[condition_key] = 'unlabeled prototype'

# --- Joint object: reference ('0'), query ('1'), prototypes ('2', '3') ---------
adata_latent_full = adata_latent_source.concatenate(
    [adata_latent, labeled_prototypes, unlabeled_prototypes],
    batch_key='query',
)

# Reference cells carry no prediction. A single .loc assignment is used instead of
# chained indexing (obs[col][mask] = ...), which may write to a temporary copy.
is_reference = adata_latent_full.obs['query'].isin(['0'])
adata_latent_full.obs.loc[is_reference, 'cell_type_pred'] = np.nan

sc.pp.neighbors(adata_latent_full, n_neighbors=15)
sc.tl.umap(adata_latent_full)

# Final labels: the existing label, replaced by the scPoli prediction where the
# existing label is 'Unknown'. The assignment aligns on the obs index, so the
# names are made unique first.
adata_latent_full.obs['scpoli_labels'] = adata_latent_full.obs[cell_type_key]
adata_latent_full.obs_names_make_unique()
is_unknown = adata_latent_full.obs['scpoli_labels'] == 'Unknown'
adata_latent_full.obs.loc[is_unknown, 'scpoli_labels'] = adata_latent_full.obs['cell_type_pred']

In [None]:
# Remove the prototype rows, which were tagged in the condition column.
# .copy() turns the slice into a standalone AnnData: the object is plotted, written
# to disk and read from in later cells, and modifying a view makes anndata copy it
# implicitly with a warning.
is_prototype = adata_latent_full.obs[condition_key].isin(['unlabeled prototype', 'labeled prototype'])
adata_no_prototype = adata_latent_full[~is_prototype].copy()

sc.pl.umap(
    adata_no_prototype,
    color=['batch_covariate', 'scpoli_labels', 'ID', 'treatment_status', 'response'],
    frameon=False,
    ncols=2,
    wspace=0.5,
)

In [None]:
# Stack the two HVG objects along the cell axis (Regev rows first, then snRNA-seq).
# With an outer join, entries missing from one object are filled with 0.
hvg_parts = [regev_hvg, ding_sn_hvg]
adata_concat = ad.concat(hvg_parts, join='outer', fill_value=0)

In [None]:
# Save the prototype-free latent object (path is relative to the working directory).
latent_h5ad = '../../PDAC_Final/single_nuc_int/regev_ding_latent.h5ad'
adata_no_prototype.write(latent_h5ad)

In [None]:
# Save both trained models together with their AnnData: the query model first,
# then the reference model.
for fitted_model, model_dir in (
    (scpoli_query, '../../PDAC_Final/single_nuc_int/scpoli_query'),
    (scpoli_model, '../../PDAC_Final/single_nuc_int/scpoli_train'),
):
    fitted_model.save(model_dir, save_anndata=True)

In [None]:
# Display the summary of the prototype-free latent object.
adata_no_prototype

In [None]:
# Attach the scPoli latent space, its UMAP and the annotated obs table to the
# concatenated expression object.
# Rows are matched by POSITION: adata_concat was built as [regev_hvg, ding_sn_hvg]
# and the latent object as reference followed by query with the prototypes removed.
# NOTE(review): confirm that nothing reorders either object in between.
if adata_concat.n_obs != adata_no_prototype.n_obs:
    raise ValueError(
        f'Cell count mismatch: adata_concat has {adata_concat.n_obs} cells, '
        f'adata_no_prototype has {adata_no_prototype.n_obs}.'
    )

adata_concat.obsm['X_scpoli'] = adata_no_prototype.X.copy()
# The UMAP coordinates live in .obsm, not in an .obs column.
adata_concat.obsm['X_umap'] = adata_no_prototype.obsm['X_umap'].copy()
# Take the per-cell annotation table (.obs); .obsp holds pairwise cell-by-cell matrices.
# NOTE(review): assigning .obs presumably also adopts its index as the cell names —
# confirm the names coming from the latent object are the ones wanted here.
adata_concat.obs = adata_no_prototype.obs.copy()

#TODO: was adata_concat supposed to be saved again?