In [1]:
import os
import matplotlib.pyplot as plt
import scanpy as sc
import torch
import time
import json
import scvi
import numpy as np

In [2]:
# Default figure size (4x4) for all scanpy plots in this notebook.
sc.set_figure_params(figsize=(4, 4))

# Set parameters

In [3]:
# Cell types removed from the reference split only, so they appear
# exclusively in the query (out-of-distribution setting).
leave_out_cell_types = [
    'NK cells',
    'NKT cells',
]
# Number of epochs for the query ("surgery") training run.
n_epochs_surgery = 300
# Forwarded to SCVI as `deeply_inject_covariates`; also selects the
# "_deep_cond" / "_first_cond" output folder suffix.
deep_inject = False

In [4]:
# obs column holding the batch label; cells whose label is listed in
# `target_batches` form the query, all others the reference.
batch_key = "condition"
target_batches = ["10X"]
# obs column holding the cell-type annotation.
cell_type_key = "final_annotation"

In [5]:
# Maximum number of epochs for the reference model.
n_epochs_vae = 500

# Early-stopping / LR-on-plateau options, passed to `train(...)` for both the
# reference and the surgery run.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=10,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

In [6]:
# Build the output directory from the experiment switches:
#   * version: "ood_2" when more than one cell type is held out, else "ood_1";
#   * suffix:  "deep_cond" when `deep_inject` is True, else "first_cond".
# The trailing '/' matters: later cells append file names with f'{dir_path}...'.
version = "ood_2" if len(leave_out_cell_types) > 1 else "ood_1"
cond_suffix = "deep_cond" if deep_inject else "first_cond"
dir_path = os.path.expanduser(
    f'~/Documents/benchmarking_results/figure_5/scvi/{version}_{cond_suffix}/'
)
# exist_ok=True replaces the separate os.path.exists check (no check-then-create race).
os.makedirs(dir_path, exist_ok=True)

# AnnData handling: load data and split into reference and query

In [9]:
# Load the benchmarking dataset and rebuild an AnnData from its `.raw` slot.
# NOTE(review): scVI is usually trained on raw counts; presumably `.raw` holds
# them for this file — confirm.
dataset_path = os.path.expanduser(
    '~/Documents/benchmarking_datasets/Immune_ALL_human_wo_villani_rqr_normalized_hvg.h5ad'
)
adata_all = sc.read(dataset_path)
adata = adata_all.raw.to_adata()

In [10]:
# Split cells by batch: batches in `target_batches` form the query, the rest
# the reference. Held-out cell types are removed from the reference only,
# so they are seen exclusively in the query.
query = adata.obs[batch_key].isin(target_batches).to_numpy()
adata_ref_full = adata[~query].copy()
adata_ref = adata_ref_full[~adata_ref_full.obs[cell_type_key].isin(leave_out_cell_types)].copy()
adata_query = adata[query].copy()

# Fail here rather than deep inside model setup/training if a split is empty.
assert adata_query.n_obs > 0, f"no cells found for target batches {target_batches}"
assert adata_ref.n_obs > 0, "reference is empty after removing held-out cell types"

  if not is_categorical(df_full[k]):


In [11]:
# Only a cell's last expression is displayed, so print the AnnData summary
# explicitly, then show the cell types kept in the reference.
print(adata_ref)
adata_ref.obs[cell_type_key].unique()

['CD16+ Monocytes', 'HSPCs', 'CD8+ T cells', 'Erythrocytes', 'CD10+ B cells', ..., 'Monocyte-derived dendritic cells', 'Plasma cells', 'Erythroid progenitors', 'Megakaryocyte progenitors', 'CD4+ T cells']
Length: 13
Categories (13, object): ['CD16+ Monocytes', 'HSPCs', 'CD8+ T cells', 'Erythrocytes', ..., 'Plasma cells', 'Erythroid progenitors', 'Megakaryocyte progenitors', 'CD4+ T cells']

In [12]:
# Only a cell's last expression is displayed, so print the AnnData summary
# explicitly, then show the cell types present in the query.
print(adata_query)
adata_query.obs[cell_type_key].unique()

['Monocyte-derived dendritic cells', 'CD14+ Monocytes', 'NK cells', 'CD20+ B cells', 'CD8+ T cells', 'Plasmacytoid dendritic cells', 'CD16+ Monocytes', 'Megakaryocyte progenitors', 'HSPCs', 'Plasma cells']
Categories (10, object): ['Monocyte-derived dendritic cells', 'CD14+ Monocytes', 'NK cells', 'CD20+ B cells', ..., 'CD16+ Monocytes', 'Megakaryocyte progenitors', 'HSPCs', 'Plasma cells']

In [None]:
# Register the reference AnnData with scvi, using `batch_key` as the batch covariate.
scvi.data.setup_anndata(adata_ref, batch_key=batch_key)

# Create SCVI model and train

In [None]:
# Reference scVI model (2 layers). Covariates are also fed to the encoder;
# `deep_inject` is forwarded as `deeply_inject_covariates`. Layer norm is used
# instead of batch norm, together with the observed library size.
vae = scvi.model.SCVI(
    adata_ref,
    use_cuda=True,
    n_layers=2,
    # covariate conditioning
    encode_covariates=True,
    deeply_inject_covariates=deep_inject,
    # normalization
    use_batch_norm="none",
    use_layer_norm="both",
    use_observed_lib_size=True,
)

In [None]:
# Train the reference model and record its wall-clock duration in seconds.
# perf_counter is monotonic, so system clock adjustments cannot skew the measurement.
ref_start = time.perf_counter()
vae.train(n_epochs=n_epochs_vae, frequency=1, early_stopping_kwargs=early_stopping_kwargs)
ref_time = time.perf_counter() - ref_start

# Reference Evaluation

In [None]:
# Reference training curves. The first two recorded values are skipped.
# NOTE(review): presumably because the earliest ELBO values dwarf the rest — confirm.
plt.figure()
plt.plot(vae.trainer.history["elbo_train_set"][2:], label="train")
plt.plot(vae.trainer.history["elbo_test_set"][2:], label="test")
plt.title("Negative ELBO over training epochs")
plt.xlabel("epoch (first 2 skipped)")
plt.ylabel("negative ELBO")
plt.legend();  # trailing ';' hides the Legend object repr in the notebook output

In [None]:
# Latent embedding of the reference cells from the trained reference model.
adata_ref.obsm["X_scVI"] = vae.get_latent_representation()

In [None]:
# Neighbor graph, Leiden clusters and UMAP computed on the reference scVI
# latent space, plotted by batch and by cell type.
sc.pp.neighbors(adata_ref, use_rep="X_scVI")
sc.tl.leiden(adata_ref)
sc.tl.umap(adata_ref)
# No plt.figure() here: sc.pl.umap opens its own figure when no `ax` is
# given, so a preceding plt.figure() only leaves an empty figure behind.
sc.pl.umap(
    adata_ref,
    color=[batch_key, cell_type_key],
    frameon=False,
    ncols=1,
)

In [None]:
# Persist the reference: annotated data, the raw model state_dict, and the
# scvi model folder (`ref_path`), which load_query_data reads below.
adata_ref.write_h5ad(filename=f'{dir_path}reference_data.h5ad')
torch.save(vae.model.state_dict(), f'{dir_path}reference_model_state_dict')
ref_path = f'{dir_path}ref_model/'
# exist_ok=True replaces the separate os.path.exists check (no check-then-create race).
os.makedirs(ref_path, exist_ok=True)
vae.save(ref_path, overwrite=True)

# Run surgery on query batch

In [None]:
# Build the surgery model: load the saved reference from `ref_path` and
# attach the query data. The freeze_* flags are passed through to scvi.
model = scvi.model.SCVI.load_query_data(
    adata_query,
    ref_path,
    freeze_expression=True,
    freeze_batchnorm_encoder=True,
    freeze_batchnorm_decoder=True,
    use_cuda=True,
)

In [None]:
# Train the surgery model on the query and record its wall-clock duration in seconds.
# perf_counter is monotonic, so system clock adjustments cannot skew the measurement.
query_start = time.perf_counter()
model.train(n_epochs=n_epochs_surgery, frequency=1, early_stopping_kwargs=early_stopping_kwargs, weight_decay=0)
query_time = time.perf_counter() - query_start

# Evaluate surgery on the query batch

In [None]:
# Surgery training curves. The first two recorded values are skipped.
# NOTE(review): presumably because the earliest ELBO values dwarf the rest — confirm.
plt.figure()
plt.plot(model.trainer.history["elbo_train_set"][2:], label="train")
plt.plot(model.trainer.history["elbo_test_set"][2:], label="test")
plt.title("Negative ELBO over training epochs")
plt.xlabel("epoch (first 2 skipped)")
plt.ylabel("negative ELBO")
plt.legend();  # trailing ';' hides the Legend object repr in the notebook output

In [None]:
# Latent embedding of the query cells from the surgery model.
adata_query.obsm["X_scVI"] = model.get_latent_representation()

In [None]:
# Neighbor graph, Leiden clusters and UMAP computed on the query scVI
# latent space, plotted by batch and by cell type.
sc.pp.neighbors(adata_query, use_rep="X_scVI")
sc.tl.leiden(adata_query)
sc.tl.umap(adata_query)
# No plt.figure() here: sc.pl.umap opens its own figure when no `ax` is
# given, so a preceding plt.figure() only leaves an empty figure behind.
sc.pl.umap(
    adata_query,
    color=[batch_key, cell_type_key],
    frameon=False,
    ncols=1,
)

In [None]:
# Save the query AnnData, including its scVI embedding and the neighbors/Leiden/UMAP results.
adata_query.write_h5ad(filename=f'{dir_path}query_data.h5ad')

# Evaluate the query mapped onto the reference

In [None]:
# Combine reference and query cells into one AnnData and embed all of them
# with the surgery model.
adata_full = adata_ref.concatenate(adata_query)
# Reuse the query's scvi registration so the surgery model accepts the combined data.
# NOTE(review): this assigns the same dict object (no copy) — confirm nothing mutates it in place.
adata_full.uns["_scvi"] = adata_query.uns["_scvi"]
# Sanity check: batch labels and scvi's `_scvi_batch` values present after concatenation.
print(adata_full.obs[batch_key].unique())
print(adata_full.obs["_scvi_batch"].unique())
adata_full.obsm["X_scVI"] = model.get_latent_representation(adata=adata_full)

In [None]:
# Neighbor graph, Leiden clusters and UMAP computed on the joint
# (reference + query) scVI latent space, plotted by batch and by cell type.
sc.pp.neighbors(adata_full, use_rep="X_scVI")
sc.tl.leiden(adata_full)
sc.tl.umap(adata_full)
# No plt.figure() here: sc.pl.umap opens its own figure when no `ax` is
# given, so a preceding plt.figure() only leaves an empty figure behind.
sc.pl.umap(
    adata_full,
    color=[batch_key, cell_type_key],
    frameon=False,
    ncols=1,
)

In [None]:
# Persist the combined data, the surgery model's raw state_dict, and the
# scvi model folder.
adata_full.write_h5ad(filename=f'{dir_path}full_data.h5ad')
torch.save(model.model.state_dict(), f'{dir_path}surgery_model_state_dict')
surgery_path = f'{dir_path}surg_model/'
# exist_ok=True replaces the separate os.path.exists check (no check-then-create race).
os.makedirs(surgery_path, exist_ok=True)
model.save(surgery_path, overwrite=True)

In [None]:
# Training durations in seconds (reference, surgery, and their sum),
# written to disk as JSON.
times = {
    "ref_time": ref_time,
    "query_time": query_time,
    "full_time": ref_time + query_time,
}
with open(f'{dir_path}results_times.txt', 'w') as filehandle:
    json.dump(times, filehandle)