In [1]:
import os
import matplotlib.pyplot as plt
import scanpy as sc
import torch
import time
import json
import scarches as sca
import numpy as np

In [2]:
# Notebook-wide scanpy plotting defaults: square figures of size (4, 4).
sc.set_figure_params(figsize=(4, 4))

# Set Parameters

In [3]:
# `deep_inject` is passed to SCANVI as `deeply_inject_covariates` and also selects
# the output sub-folder; `n_epochs_surgery` caps each query-surgery training run.
deep_inject, n_epochs_surgery = False, 300
# Cell types removed from the reference, so the reference model never sees them.
leave_out_cell_types = ["Pancreas Alpha"]

In [4]:
# Studies used as query data, integrated one after the other (first, then second).
target_batches = ["Pancreas SS2", "Pancreas CelSeq2"]
# `adata.obs` columns holding the batch (study) and the cell-type annotation.
batch_key, cell_type_key = "study", "cell_type"

In [5]:
# Epoch budgets for the two reference-training phases.
n_epochs_vae, n_epochs_scanvi = 50, 10

# Unsupervised reference phase: early stopping and LR-on-plateau driven by the ELBO.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=10,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

# Semi-supervised reference phase: driven by classification accuracy on the full dataset.
early_stopping_kwargs_scanvi = dict(
    early_stopping_metric="accuracy",
    save_best_state_metric="accuracy",
    on="full_dataset",
    patience=10,
    threshold=0.001,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

# Query surgery: driven by the ELBO on the full dataset.
early_stopping_kwargs_surgery = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    on="full_dataset",
    patience=10,
    threshold=0.001,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

In [None]:
# Results go into a sub-folder named after the covariate-injection mode.
# Paths keep their trailing '/' because later cells build file names by plain
# string concatenation (e.g. f'{dir_path}reference_data.h5ad').
cond_subdir = 'deep_cond' if deep_inject else 'first_cond'
dir_path = os.path.expanduser(f'~/Documents/benchmarking/figure_1/scanvi/{cond_subdir}/')
# Sub-folder for diagnostic plots (training curves, UMAPs).
control_path = f'{dir_path}controlling/'
# makedirs also creates dir_path as an intermediate directory; exist_ok=True
# replaces the racy "check, then create" pattern and keeps re-runs idempotent.
os.makedirs(control_path, exist_ok=True)

# AnnData Handling

In [6]:
# Load the normalized pancreas benchmark and rebuild an AnnData from its `.raw`
# slot (presumably the un-normalized values — confirm); that object is what the
# reference/query splits below are taken from.
h5ad_path = os.path.expanduser('~/Documents/benchmarking_datasets/pancreas_normalized.h5ad')
adata_all = sc.read(h5ad_path)
adata = adata_all.raw.to_adata()
adata

AnnData object with n_obs × n_vars = 15681 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'

In [7]:
# Boolean cell masks: every query study, then each query study on its own.
query = adata.obs[batch_key].isin(target_batches).to_numpy()
query_1 = adata.obs[batch_key].isin(target_batches[:1]).to_numpy()
query_2 = adata.obs[batch_key].isin(target_batches[1:2]).to_numpy()
# Reference = every non-query cell, minus the held-out cell types.
adata_ref_full = adata[~query].copy()
keep_ref = ~adata_ref_full.obs[cell_type_key].isin(leave_out_cell_types).to_numpy()
adata_ref = adata_ref_full[keep_ref].copy()
adata_query_1 = adata[query_1].copy()
adata_query_2 = adata[query_2].copy()

In [8]:
# Inspect the reference AnnData after removing the held-out cell types.
adata_ref

AnnData object with n_obs × n_vars = 7584 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'

In [9]:
# Inspect the first query study (used for the first surgery).
adata_query_1

AnnData object with n_obs × n_vars = 2961 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'

In [10]:
# Inspect the second query study (used for the second surgery).
adata_query_2

AnnData object with n_obs × n_vars = 2426 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'

In [11]:
# Register the reference with scvi-tools: labels from `cell_type_key`,
# batch covariate from `batch_key`.
sca.dataset.setup_anndata(
    adata_ref,
    labels_key=cell_type_key,
    batch_key=batch_key,
)

[34mINFO    [0m Using batches from adata.obs[1m[[0m[32m"study"[0m[1m][0m                                               
[34mINFO    [0m Using labels from adata.obs[1m[[0m[32m"cell_type"[0m[1m][0m                                            
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Successfully registered anndata object containing [1;34m7584[0m cells, [1;34m1000[0m vars, [1;34m3[0m batches, 
         [1;34m7[0m labels, and [1;34m0[0m proteins. Also registered [1;34m0[0m extra categorical covariates and [1;34m0[0m extra
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          


# Create SCANVI model and train

In [12]:
# scANVI reference model. "Unknown" is passed positionally — presumably the
# unlabeled-category name (cf. `vae.unlabeled_category_` used for the queries).
# Normalization: layer norm ("both"), no batch norm; covariates are also encoded.
vae = sca.models.SCANVI(
    adata_ref,
    "Unknown",
    use_cuda=True,
    n_layers=2,
    use_batch_norm="none",
    use_layer_norm="both",
    encode_covariates=True,
    deeply_inject_covariates=deep_inject,
)

In [13]:
# Report how many reference cells scANVI treats as labelled vs. unlabelled.
for title, indices in (("Labelled Indices: ", vae._labeled_indices),
                       ("Unlabelled Indices: ", vae._unlabeled_indices)):
    print(title, len(indices))

Labelled Indices:  7584
Unlabelled Indices:  0


In [14]:
# Train the reference model: an unsupervised phase followed by a semi-supervised
# phase, each with its own early-stopping configuration.
# perf_counter() is the clock meant for measuring durations; unlike time.time()
# it is monotonic, so system clock adjustments cannot skew the benchmark.
ref_time = time.perf_counter()
vae.train(
    n_epochs_unsupervised=n_epochs_vae,
    n_epochs_semisupervised=n_epochs_scanvi,
    unsupervised_trainer_kwargs=dict(early_stopping_kwargs=early_stopping_kwargs),
    semisupervised_trainer_kwargs=dict(metrics_to_monitor=["elbo", "accuracy"],
                                       early_stopping_kwargs=early_stopping_kwargs_scanvi),
    frequency=1
)
ref_time = time.perf_counter() - ref_time

[34mINFO    [0m Training Unsupervised Trainer for [1;34m50[0m epochs.                                        
[34mINFO    [0m Training SemiSupervised Trainer for [1;34m10[0m epochs.                                      
[34mINFO    [0m KL warmup phase exceeds overall training phaseIf your applications rely on the      
         posterior quality, consider training for more epochs or reducing the kl warmup.     
[34mINFO    [0m KL warmup for [1;34m400[0m epochs                                                            
Training...:   0%|          | 0/50 [00:00<?, ?it/s]



Training...: 100%|██████████| 50/50 [00:21<00:00,  2.29it/s]
[34mINFO    [0m Training is still in warming up phase. If your applications rely on the posterior   
         quality, consider training for more epochs or reducing the kl warmup.               
[34mINFO    [0m Training time:  [1;34m15[0m s. [35m/[0m [1;34m50[0m epochs                                                   
[34mINFO    [0m KL warmup phase exceeds overall training phaseIf your applications rely on the      
         posterior quality, consider training for more epochs or reducing the kl warmup.     
[34mINFO    [0m KL warmup for [1;34m400[0m epochs                                                            
Training...: 100%|██████████| 10/10 [00:18<00:00,  1.84s/it]
[34mINFO    [0m Training is still in warming up phase. If your applications rely on the posterior   
         quality, consider training for more epochs or reducing the kl warmup.               
[34mINFO    [0m Training time:  [1;3

# Reference Evaluation

In [15]:
# Per-class probabilities (soft predictions) for every reference cell.
ref_predictions = vae.predict(soft=True, adata=adata_ref)
ref_predictions

Unnamed: 0_level_0,Pancreas Acinar,Pancreas Beta,Pancreas Delta,Pancreas Ductal,Pancreas Endothelial,Pancreas Gamma,Pancreas Stellate
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0-0-0-0-0,1.700872e-02,0.000023,4.826826e-05,1.594801e-05,9.828552e-01,0.000006,4.237988e-05
1-0-0-0-0,9.979936e-01,0.000007,5.485994e-05,5.480702e-05,1.839756e-03,0.000032,1.793347e-05
2-0-0-0-0,9.442510e-01,0.000045,1.931140e-04,4.826397e-05,5.515557e-02,0.000214,9.286892e-05
3-0-0-0-0,9.993677e-01,0.000006,8.114145e-05,3.306050e-05,4.053751e-04,0.000086,2.071284e-05
4-0-0-0-0,5.979191e-03,0.000011,8.354407e-05,9.590203e-05,9.937906e-01,0.000005,3.467494e-05
...,...,...,...,...,...,...,...
631-1-0,4.892427e-07,0.999994,3.708081e-07,2.934800e-07,4.114018e-07,0.000004,2.402937e-07
632-1-0,7.625266e-07,0.999991,4.861529e-07,3.365601e-07,4.915445e-07,0.000007,3.287217e-07
634-1-0,3.732964e-07,0.999995,2.349990e-07,2.835378e-07,4.636388e-07,0.000003,2.086357e-07
635-1-0,5.206258e-07,0.999993,5.490394e-07,2.064139e-07,4.521891e-07,0.000005,2.301208e-07


In [None]:
# Hard label predictions for the reference cells. They are computed once and
# reused, so the stored "predictions" column is exactly what the accuracy below
# is computed from (and inference is not run twice).
ref_predictions = vae.predict(adata_ref)
adata_ref.obsm["X_scANVI"] = vae.get_latent_representation()
adata_ref.obs["predictions"] = ref_predictions
print("Acc: {}".format(np.mean(ref_predictions == adata_ref.obs[cell_type_key])))

In [None]:
# Full-dataset accuracy recorded during training; the first two history entries
# are skipped (presumably to hide the initial transient).
fig, ax = plt.subplots()
ax.plot(vae.trainer.history['accuracy_full_dataset'][2:], label="ACC")
ax.set_title("ACC")
ax.legend()
fig.savefig(f'{control_path}reference_acc.png', bbox_inches='tight')

In [None]:
# Full-dataset ELBO recorded during training, again skipping the first two entries.
fig, ax = plt.subplots()
ax.plot(vae.trainer.history['elbo_full_dataset'][2:], label="ELBO")
ax.set_title("ELBO")
ax.legend()
fig.savefig(f'{control_path}reference_elbo.png', bbox_inches='tight')

In [None]:
# Small AnnData holding only the scANVI latent space plus the annotations to plot.
ref_cropped = sc.AnnData(adata_ref.obsm["X_scANVI"])
for obs_col, src_col in (("celltype", cell_type_key),
                         ("batch", batch_key),
                         ("predictions", "predictions")):
    ref_cropped.obs[obs_col] = adata_ref.obs[src_col].tolist()

In [None]:
# kNN graph -> Leiden clustering -> UMAP on the latent space, then save to disk.
for step in (sc.pp.neighbors, sc.tl.leiden, sc.tl.umap):
    step(ref_cropped)
ref_cropped.write_h5ad(f'{dir_path}reference_data.h5ad')

In [None]:
# Inspect the latent-space AnnData after neighbors / Leiden / UMAP.
ref_cropped

In [None]:
# UMAP of the reference latent space, coloured by study ("batch") and cell type.
# sc.pl.umap opens its own figure, so no plt.figure() beforehand (that only
# produced an extra empty figure); plt.savefig then saves the UMAP figure.
sc.pl.umap(
    ref_cropped,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False
)
plt.savefig(f'{control_path}umap_reference.png', bbox_inches='tight')

In [None]:
# Persist the reference: raw PyTorch weights plus the model's own save directory,
# which the first surgery reloads via `load_query_data`.
ref_path = f'{dir_path}ref_model/'
torch.save(vae.model.state_dict(), f'{dir_path}reference_model_state_dict')
os.makedirs(ref_path, exist_ok=True)
vae.save(ref_path, overwrite=True)

# Run surgery on first query batch

In [None]:
# Keep the true labels for evaluation, then mark every query-1 cell as unlabeled.
query_1_obs = adata_query_1.obs
query_1_obs['orig_cell_types'] = query_1_obs[cell_type_key].copy()
query_1_obs[cell_type_key] = vae.unlabeled_category_

In [None]:
# First surgery: build the query model from the saved reference, dropout frozen.
model_1 = sca.models.SCANVI.load_query_data(adata_query_1, ref_path, freeze_dropout=True)

In [None]:
# Report labelled vs. unlabelled cells seen by the first surgery model.
for title, indices in (("Labelled Indices: ", model_1._labeled_indices),
                       ("Unlabelled Indices: ", model_1._unlabeled_indices)):
    print(title, len(indices))

In [None]:
# First surgery training: only the semi-supervised phase runs, and
# train_base_model=False presumably keeps the reference weights fixed (confirm).
# perf_counter() is monotonic and intended for timing, unlike time.time().
query_1_time = time.perf_counter()
model_1.train(
    n_epochs_semisupervised=n_epochs_surgery,
    train_base_model=False,
    semisupervised_trainer_kwargs=dict(metrics_to_monitor=["elbo"],
                                       weight_decay=0,
                                       early_stopping_kwargs=early_stopping_kwargs_surgery
                                      ),
    frequency=1
)
query_1_time = time.perf_counter() - query_1_time

# Evaluate Surgery on Query 1

In [None]:
# Embed and classify query 1 with the surgery model (default AnnData, presumably
# adata_query_1). Predictions are computed once and reused for the stored column
# and the accuracy, instead of running inference twice.
adata_query_1.obsm["X_scANVI"] = model_1.get_latent_representation()
query_1_predictions = model_1.predict()
adata_query_1.obs["predictions"] = query_1_predictions
# Accuracy against the true labels saved before they were masked.
print("Acc: {}".format(np.mean(query_1_predictions == adata_query_1.obs['orig_cell_types'])))

In [None]:
# Latent-space AnnData for query 1, annotated with the true labels, study and predictions.
q1_cropped = sc.AnnData(adata_query_1.obsm["X_scANVI"])
for obs_col, src_col in (("celltype", 'orig_cell_types'),
                         ("batch", batch_key),
                         ("predictions", "predictions")):
    q1_cropped.obs[obs_col] = adata_query_1.obs[src_col].tolist()

In [None]:
# kNN graph -> Leiden clustering -> UMAP for query 1, then save to disk.
for step in (sc.pp.neighbors, sc.tl.leiden, sc.tl.umap):
    step(q1_cropped)
q1_cropped.write_h5ad(f'{dir_path}query_1_data.h5ad')

In [None]:
# UMAP of query 1, coloured by study ("batch") and true cell type.
# sc.pl.umap opens its own figure, so the former plt.figure() only produced an
# extra empty figure; plt.savefig saves the UMAP figure.
sc.pl.umap(
    q1_cropped,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False
)
plt.savefig(f'{control_path}umap_query_1.png', bbox_inches='tight')

# Evaluate Query 1 Together with the Reference

In [None]:
# Give the reference the same 'orig_cell_types' column as the query so that
# accuracy can be computed on the concatenation.
ref_obs = adata_ref.obs
ref_obs['orig_cell_types'] = ref_obs[cell_type_key].copy()
adata_full_1 = adata_ref.concatenate(adata_query_1)
# Reuse query 1's scvi-tools registration (presumably so model_1 accepts the combined object).
adata_full_1.uns["_scvi"] = adata_query_1.uns["_scvi"]
latent_full_1 = model_1.get_latent_representation(adata=adata_full_1)
adata_full_1.obsm["X_scANVI"] = latent_full_1

In [None]:
# Classify reference + query 1 with the first surgery model. Predict once and
# reuse, so the stored column matches the reported accuracy.
full_1_predictions = model_1.predict(adata_full_1)
adata_full_1.obs["predictions"] = full_1_predictions
print("Acc: {}".format(np.mean(full_1_predictions == adata_full_1.obs['orig_cell_types'])))

In [None]:
# Latent-space AnnData for reference + query 1, with labels, study and predictions.
f1_cropped = sc.AnnData(adata_full_1.obsm["X_scANVI"])
for obs_col, src_col in (("celltype", 'orig_cell_types'),
                         ("batch", batch_key),
                         ("predictions", "predictions")):
    f1_cropped.obs[obs_col] = adata_full_1.obs[src_col].tolist()

In [None]:
# kNN graph -> Leiden clustering -> UMAP for reference + query 1, then save.
for step in (sc.pp.neighbors, sc.tl.leiden, sc.tl.umap):
    step(f1_cropped)
f1_cropped.write_h5ad(f'{dir_path}full_1_data.h5ad')

In [None]:
# UMAP of reference + query 1, coloured by study ("batch") and true cell type.
# sc.pl.umap opens its own figure, so the former plt.figure() only produced an
# extra empty figure; plt.savefig saves the UMAP figure.
sc.pl.umap(
    f1_cropped,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False
)
plt.savefig(f'{control_path}umap_full_1.png', bbox_inches='tight')

In [None]:
# Same embedding, coloured by the model's predicted cell type.
sc.pl.umap(f1_cropped, color=["predictions"], frameon=False, ncols=1, show=False)
plt.savefig(f'{control_path}pred_full_1.png', bbox_inches='tight')

In [None]:
# Persist the first surgery model: raw PyTorch weights plus the model's own save
# directory, which the second surgery reloads via `load_query_data`.
surgery_1_path = f'{dir_path}surg_1_model/'
torch.save(model_1.model.state_dict(), f'{dir_path}surgery_1_model_state_dict')
os.makedirs(surgery_1_path, exist_ok=True)
model_1.save(surgery_1_path, overwrite=True)

# Run surgery on second query batch

In [None]:
# Keep the true labels for evaluation, then mark every query-2 cell as unlabeled.
query_2_obs = adata_query_2.obs
query_2_obs['orig_cell_types'] = query_2_obs[cell_type_key].copy()
query_2_obs[cell_type_key] = model_1.unlabeled_category_
# Second surgery starts from the first surgery model saved on disk (dropout frozen).
model_2 = sca.models.SCANVI.load_query_data(adata_query_2, surgery_1_path, freeze_dropout=True)

In [None]:
# Report labelled vs. unlabelled cells seen by the second surgery model.
for title, indices in (("Labelled Indices: ", model_2._labeled_indices),
                       ("Unlabelled Indices: ", model_2._unlabeled_indices)):
    print(title, len(indices))

In [None]:
# Second surgery training: semi-supervised phase only; train_base_model=False
# presumably keeps the previously trained weights fixed (confirm).
# perf_counter() is monotonic and intended for timing, unlike time.time().
query_2_time = time.perf_counter()
model_2.train(
    n_epochs_semisupervised=n_epochs_surgery,
    train_base_model=False,
    semisupervised_trainer_kwargs=dict(metrics_to_monitor=["elbo"],
                                       weight_decay=0,
                                       early_stopping_kwargs=early_stopping_kwargs_surgery),
    frequency=1
)
query_2_time = time.perf_counter() - query_2_time

# Evaluate Surgery on Query 2

In [None]:
# Embed and classify query 2 with the second surgery model (default AnnData,
# presumably adata_query_2). Predict once and reuse the result.
adata_query_2.obsm["X_scANVI"] = model_2.get_latent_representation()
query_predictions = model_2.predict()
adata_query_2.obs["predictions"] = query_predictions
# Accuracy against the true labels saved before they were masked.
print("Acc: {}".format(np.mean(query_predictions == adata_query_2.obs['orig_cell_types'])))

In [None]:
# Latent-space AnnData for query 2, annotated with the true labels, study and predictions.
q2_cropped = sc.AnnData(adata_query_2.obsm["X_scANVI"])
for obs_col, src_col in (("celltype", 'orig_cell_types'),
                         ("batch", batch_key),
                         ("predictions", "predictions")):
    q2_cropped.obs[obs_col] = adata_query_2.obs[src_col].tolist()

In [None]:
# kNN graph -> Leiden clustering -> UMAP for query 2, then save to disk.
for step in (sc.pp.neighbors, sc.tl.leiden, sc.tl.umap):
    step(q2_cropped)
q2_cropped.write_h5ad(f'{dir_path}query_2_data.h5ad')

In [None]:
# UMAP of query 2, coloured by study ("batch") and true cell type.
# sc.pl.umap opens its own figure, so the former plt.figure() only produced an
# extra empty figure; plt.savefig saves the UMAP figure.
sc.pl.umap(
    q2_cropped,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False
)
plt.savefig(f'{control_path}umap_query_2.png', bbox_inches='tight')

# Evaluate Query 1 and Query 2 Together with the Reference

In [None]:
# Reference + query 1 + query 2, evaluated with the second surgery model.
adata_full_2 = adata_full_1.concatenate(adata_query_2)
# Reuse query 2's scvi-tools registration (presumably so model_2 accepts the combined object).
adata_full_2.uns["_scvi"] = adata_query_2.uns["_scvi"]
adata_full_2.obsm["X_scANVI"] = model_2.get_latent_representation(adata=adata_full_2)
full_predictions = model_2.predict(adata_full_2)
# Store under "predictions" (previously misspelled "predicitions"): the plotting
# cells read "predictions", which otherwise still held the stale model_1 /
# query-2 values carried over from the concatenated inputs.
adata_full_2.obs["predictions"] = full_predictions
print("Acc: {}".format(np.mean(full_predictions == adata_full_2.obs['orig_cell_types'])))

In [None]:
# Latent-space AnnData for reference + both queries, with labels, study and predictions.
f2_cropped = sc.AnnData(adata_full_2.obsm["X_scANVI"])
for obs_col, src_col in (("celltype", 'orig_cell_types'),
                         ("batch", batch_key),
                         ("predictions", "predictions")):
    f2_cropped.obs[obs_col] = adata_full_2.obs[src_col].tolist()

In [None]:
# kNN graph -> Leiden clustering -> UMAP for reference + both queries, then save.
for step in (sc.pp.neighbors, sc.tl.leiden, sc.tl.umap):
    step(f2_cropped)
f2_cropped.write_h5ad(f'{dir_path}full_2_data.h5ad')

In [None]:
# UMAP of reference + both queries, coloured by study ("batch") and true cell type.
# sc.pl.umap opens its own figure, so the former plt.figure() only produced an
# extra empty figure; plt.savefig saves the UMAP figure.
sc.pl.umap(
    f2_cropped,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False
)
plt.savefig(f'{control_path}umap_full_2.png', bbox_inches='tight')

In [None]:
# Same embedding, coloured by the second surgery model's predicted cell type.
# No plt.figure() beforehand: sc.pl.umap opens its own figure, and the extra call
# only produced an empty one (the matching full_1 cell already omits it).
sc.pl.umap(
    f2_cropped,
    color=["predictions"],
    frameon=False,
    ncols=1,
    show=False
)
plt.savefig(f'{control_path}pred_full_2.png', bbox_inches='tight')

In [None]:
# Persist the second surgery model: raw PyTorch weights plus the model's own save directory.
surgery_2_path = f'{dir_path}surg_2_model/'
torch.save(model_2.model.state_dict(), f'{dir_path}surgery_2_model_state_dict')
os.makedirs(surgery_2_path, exist_ok=True)
model_2.save(surgery_2_path, overwrite=True)

In [None]:
# Training durations (seconds) for each stage plus their total, written as JSON.
times = {
    "ref_time": ref_time,
    "query_1_time": query_1_time,
    "query_2_time": query_2_time,
}
times["full_time"] = sum(times.values())
with open(f'{dir_path}results_times.txt', 'w') as out_file:
    json.dump(times, out_file)