In [None]:
import os
import matplotlib.pyplot as plt
import scanpy as sc
import torch
import time
import json
import scvi
import numpy as np

In [None]:
# Use small square panels for every scanpy figure drawn in this notebook.
sc.set_figure_params(
    figsize=(4, 4),
)

In [None]:
# Names of the annotation columns used throughout the notebook:
# `cell_type_key` is only used to colour the UMAP, `batch_key` is also
# handed to scVI as the batch covariate.
cell_type_key = "final_annotation"
batch_key = "condition"

In [None]:
# Upper bound on the number of training epochs requested from `vae.train`.
n_epochs_vae = 500

# Early-stopping / learning-rate-schedule options forwarded verbatim to
# `vae.train(early_stopping_kwargs=...)`.
# NOTE(review): the exact meaning of each key is defined by the installed
# scvi-tools trainer -- confirm against that version's documentation.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=10,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

In [None]:
# Output directory for everything this notebook writes (AnnData, model, timings).
# The trailing slash is intentional: later cells build file names by
# appending directly to `dir_path`.
dir_path = os.path.expanduser('~/Documents/benchmarking_results/full_integration/scvi/immune_all_human_fig6/')
# `exist_ok=True` makes the cell idempotent on re-run and avoids the
# check-then-create race of a separate `os.path.exists` test.
os.makedirs(dir_path, exist_ok=True)

In [None]:
# Load the benchmarking dataset from the user's home directory.
dataset_path = os.path.expanduser(
    '~/Documents/benchmarking_datasets/Immune_ALL_human_wo_villani_rqr_normalized_hvg.h5ad'
)
adata_all = sc.read(dataset_path)

# Work on an AnnData rebuilt from the `.raw` slot of the loaded object.
# NOTE(review): presumably `.raw` holds the unnormalised counts scVI expects -- confirm.
adata = adata_all.raw.to_adata()

# Show the object summary as the cell output.
adata

In [None]:
# Register `adata` with scvi-tools, using the `batch_key` column as the batch covariate.
scvi.data.setup_anndata(
    adata,
    batch_key=batch_key,
)

In [None]:
# Build the scVI model on the registered AnnData; training happens in a later cell.
# NOTE(review): `use_cuda` is the pre-0.9 scvi-tools API -- confirm the pinned version.
vae = scvi.model.SCVI(adata, n_layers=2, use_cuda=True)

In [None]:
# Measure the duration of the complete training run in seconds.
# A dedicated start variable is used so that `full_time` is only ever assigned
# a real duration: if training raises or is interrupted, no stale timestamp is
# left behind in `full_time` for the results-writing cell to pick up.
# `time.perf_counter()` is monotonic, so the measurement is not affected by
# system clock adjustments during a long training run.
train_start = time.perf_counter()
vae.train(n_epochs=n_epochs_vae, frequency=1, early_stopping_kwargs=early_stopping_kwargs)
full_time = time.perf_counter() - train_start

In [None]:
# Plot the ELBO recorded by the trainer for both data splits.
# The first two recorded values are skipped via the `[2:]` slice.
training_history = vae.trainer.history
elbo_curves = {
    "train": "elbo_train_set",
    "test": "elbo_test_set",
}
for curve_label, history_key in elbo_curves.items():
    plt.plot(training_history[history_key][2:], label=curve_label)
plt.title("Negative ELBO over training epochs")
plt.legend()

In [None]:
# Store the model's latent representation on the AnnData under "X_scVI"
# so the neighbour graph can be built from it.
latent_representation = vae.get_latent_representation()
adata.obsm["X_scVI"] = latent_representation

In [None]:
# Build the neighbour graph on the scVI latent space, then cluster and embed.
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.leiden(adata)
sc.tl.umap(adata)

# Plot the UMAP coloured by batch and by cell type, one panel per row.
# No explicit `plt.figure()` is opened beforehand: `sc.pl.umap` creates its own
# figure when no `ax` is passed, so a pre-created figure would only show up as
# an extra empty figure in the cell output.
sc.pl.umap(
    adata,
    color=[batch_key, cell_type_key],
    frameon=False,
    ncols=1,
)

In [None]:
# Persist the integrated AnnData (including the "X_scVI" embedding).
# Paths are built with `os.path.join` so they stay correct whether or not
# `dir_path` ends with a path separator.
adata.write_h5ad(filename=os.path.join(dir_path, 'data.h5ad'))

# Save the raw network weights separately from the scvi-tools model directory.
torch.save(vae.model.state_dict(), os.path.join(dir_path, 'model_state_dict'))

# Save the full scvi-tools model into its own sub-directory.
path = os.path.join(dir_path, 'model/')
# `exist_ok=True` keeps the cell idempotent on re-run without a separate existence check.
os.makedirs(path, exist_ok=True)
vae.save(path, overwrite=True)

In [None]:
# Write the measured training duration to a JSON-formatted text file
# in the results directory.
times = {"full_time": full_time}
with open(f'{dir_path}results_times.txt', 'w') as timings_file:
    json.dump(times, timings_file)