In [None]:
import numpy as np
import scanpy as sc
import trvae
import pandas as pd 
from scipy import stats

In [None]:
# Configure scanpy's figure parameters with a resolution of 200 dpi.
sc.settings.set_figure_params(dpi=200)

In [None]:
# Dataset identifier; used to build the input data path and the model directory.
data_name = "haber"
# Cell type whose cells from the target conditions are withheld from training.
specific_celltype = "Tuft"

# Names of the `.obs` columns holding the cell-type and condition annotations.
cell_type_key, condition_key = "cell_label", "condition"

# Conditions kept from the dataset, control first.
conditions = ["Control", "Hpoly.Day3", "Hpoly.Day10", "Salmonella"]
# Conditions to predict: every listed condition except the control.
target_conditions = [cond for cond in conditions if cond != "Control"]

In [None]:
adata = sc.read(f"./data/{data_name}/{data_name}.h5ad")
adata = adata[adata.obs[condition_key].isin(conditions)]
adata

In [None]:
adata.obs.groupby([cell_type_key, condition_key]).size()

In [None]:
# Seed NumPy's global RNG so the random split can be reproduced after a kernel restart.
# NOTE(review): assumes trvae.utils.train_test_split draws from NumPy's global RNG — confirm.
np.random.seed(2020)
# Split the cells into training and validation subsets (0.80 is presumably the training fraction).
train_adata, valid_adata = trvae.utils.train_test_split(adata, 0.80)

In [None]:
train_adata.shape, valid_adata.shape

In [None]:
net_train_adata = train_adata[~((train_adata.obs[cell_type_key] == specific_celltype) & (train_adata.obs[condition_key].isin(target_conditions)))]
net_valid_adata = valid_adata[~((valid_adata.obs[cell_type_key] == specific_celltype) & (valid_adata.obs[condition_key].isin(target_conditions)))]

In [None]:
net_train_adata.shape, net_valid_adata.shape

In [None]:
net_train_adata.obs.groupby([cell_type_key, condition_key]).size()

In [None]:
# Sizes derived from the training data: the second dimension of the data matrix and
# the number of distinct condition labels present in it.
n_input_features = net_train_adata.shape[1]
n_train_conditions = len(net_train_adata.obs[condition_key].unique())

# Fixed hyperparameters, passed through to trVAEMulti unchanged.
network_hyperparams = dict(
    z_dimension=60,
    mmd_dimension=128,
    alpha=1e-6,
    beta=100,
    eta=100,
    clip_value=100,
    lambda_l1=0.0,
    lambda_l2=0.0,
    learning_rate=0.001,
    model_path=f"./models/trVAEMulti/best/{data_name}-{specific_celltype}/",
    dropout_rate=0.2,
    output_activation='relu',
)

network = trvae.archs.trVAEMulti(x_dimension=n_input_features,
                                 n_conditions=n_train_conditions,
                                 **network_hyperparams)

In [None]:
# Map each condition name to its integer label. Built from `conditions` so the
# encoding cannot drift from the list used to filter the dataset; with the current
# list this yields {'Control': 0, 'Hpoly.Day3': 1, 'Hpoly.Day10': 2, 'Salmonella': 3}.
label_encoder = {condition: index for index, condition in enumerate(conditions)}

In [None]:
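# Optional: uncomment to restore a previously saved model (presumably from `model_path`; confirm) rather than training from scratch.
# network.restore_model()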

In [None]:
# Optimisation settings forwarded to `network.train` as keyword arguments.
training_options = {
    "n_epochs": 10000,
    "batch_size": 512,
    "verbose": 2,
    "early_stop_limit": 750,
    "lr_reducer": 0,
    "shuffle": True,
}

# Train on the filtered training cells, with the filtered validation cells as the second dataset.
network.train(net_train_adata, net_valid_adata, label_encoder, condition_key, **training_options)

In [None]:
train_labels, _ = trvae.tl.label_encoder(net_train_adata, label_encoder, condition_key)
latent_with_true_labels = network.to_latent(net_train_adata, train_labels)
mmd_with_true_labels = network.to_mmd_layer(net_train_adata, train_labels, feed_fake=-1)

In [None]:
sc.pp.neighbors(latent_with_true_labels)
sc.tl.umap(latent_with_true_labels)
sc.pl.umap(latent_with_true_labels, color=[condition_key, cell_type_key],
           show=True,
           wspace=0.15,
           frameon=False)

In [None]:
sc.pp.neighbors(mmd_with_true_labels)
sc.tl.umap(mmd_with_true_labels)
sc.pl.umap(mmd_with_true_labels, color=[condition_key, cell_type_key],
           show=True,
           wspace=0.15,
           frameon=False)

In [None]:
# All training-split cells of the selected cell type, across every condition.
# `.copy()` materialises the subset: slicing an AnnData yields a view, and this
# object is modified afterwards (its `.var` is reassigned and rank_genes_groups
# stores results in it), which on a view triggers an implicit copy with a warning.
cell_type_adata = train_adata[train_adata.obs[cell_type_key] == specific_celltype].copy()

In [None]:
# Replace the gene annotations with an empty frame that keeps only the gene names as index.
# NOTE(review): presumably so that concatenation with predicted cells, which carry
# no `.var` columns, stays clean — confirm.
cell_type_adata.var = pd.DataFrame(index=cell_type_adata.var_names)

In [None]:
# Cell counts per (cell type, condition) within the selected cell type.
cell_type_adata.obs.groupby([cell_type_key, condition_key]).size()

In [None]:
def predict_transition(adata, source_cond, target_cond):
    """Predict the `source_cond` cells under `target_cond` and append them to `adata`.

    Parameters
    ----------
    adata
        AnnData whose `.obs[condition_key]` contains `source_cond`.
    source_cond
        Condition label of the cells to transform. It may be a label of the form
        "<a>_to_<b>"; the part after the last "_to_" is then used to look up the
        encoder label in `label_encoder`.
    target_cond
        Key of `label_encoder` naming the condition to decode into.

    Returns
    -------
    A new AnnData: `adata` concatenated with the predicted cells, which are labelled
    "<source_cond>_to_<target_cond>" in `.obs[condition_key]` and `specific_celltype`
    in `.obs[cell_type_key]`. The input `adata` is not modified.
    """
    source_cells = adata[adata.obs[condition_key] == source_cond]
    n_source = source_cells.shape[0]

    # For a chained label such as "A_to_B" the cells are encoded with B's label.
    encoder_cond = source_cond.split("_to_")[-1]
    encoder_labels = label_encoder[encoder_cond] + np.zeros(n_source)
    decoder_labels = label_encoder[target_cond] + np.zeros(n_source)

    predicted_X = network.predict(source_cells,
                                  encoder_labels=encoder_labels,
                                  decoder_labels=decoder_labels)
    n_predicted = predicted_X.shape[0]

    # Wrap the predictions in an AnnData that shares the source cells' gene names.
    predicted_cells = sc.AnnData(X=predicted_X)
    predicted_cells.obs[condition_key] = [f"{source_cond}_to_{target_cond}"] * n_predicted
    predicted_cells.obs[cell_type_key] = [specific_celltype] * n_predicted
    predicted_cells.var_names = source_cells.var_names

    return adata.concatenate(predicted_cells)

In [None]:
# Transitions to predict, applied in order. predict_transition already returns its
# input plus the new predictions, so its result is carried forward directly;
# concatenating that result onto `recon_adata` again would duplicate every cell
# accumulated so far at each step.
transitions = [
    ("Control", "Hpoly.Day10"),
    ("Control", "Hpoly.Day3"),
    ("Control", "Salmonella"),
    ("Hpoly.Day3", "Hpoly.Day10"),
    # Chained prediction: the "Control_to_Hpoly.Day3" cells predicted above are the source.
    ("Control_to_Hpoly.Day3", "Hpoly.Day10"),
]

recon_adata = cell_type_adata
for source_cond, target_cond in transitions:
    recon_adata = predict_transition(recon_adata, source_cond, target_cond)
recon_adata

In [None]:
recon_adata.obs.groupby([cell_type_key, condition_key]).size()

In [None]:
path = "Control_to_Hpoly.Day10"
from_condition = path.split("_to_")[-2]
to_condition = path.split("_to_")[-1]

In [None]:
pred_adata = recon_adata[recon_adata.obs[condition_key] == path]
pred_adata

In [None]:
ctrl_adata = cell_type_adata[cell_type_adata.obs[condition_key] == from_condition]
ctrl_adata

In [None]:
real_adata = cell_type_adata[cell_type_adata.obs[condition_key] == to_condition]
real_adata

In [None]:
real_adata.X.min(), real_adata.X.max()

In [None]:
pred_adata.X.min(), pred_adata.X.max()

In [None]:
pred_mean = np.mean(pred_adata.X, axis=0)
ctrl_mean = np.mean(ctrl_adata.X, axis=0)
real_mean = np.mean(real_adata.X, axis=0)

In [None]:
pred_var = np.var(pred_adata.X, axis=0)
ctrl_var = np.var(ctrl_adata.X, axis=0)
real_var = np.var(real_adata.X, axis=0)

In [None]:
m, b, r_value_mean, p_value, std_err = stats.linregress(pred_mean, real_mean)
r_value_mean = r_value_mean ** 2
r_value_mean

In [None]:
m, b, r_value_var, p_value, std_err = stats.linregress(pred_var, real_var)
r_value_var = r_value_var ** 2
r_value_var

In [None]:
# Stack the source-condition, predicted and target-condition cells into one AnnData.
adata_pred = ctrl_adata.concatenate(pred_adata, real_adata)
adata_pred

In [None]:
# Condition labels present in the combined object.
adata_pred.obs[condition_key].unique()

In [None]:
# Rank genes for the target condition against the source condition among the
# cells of the selected cell type, requesting 100 genes.
sc.tl.rank_genes_groups(cell_type_adata,
                        groupby=condition_key,
                        groups=[to_condition],
                        reference=from_condition,
                        n_genes=100)

In [None]:
# Names of the ranked genes for the target condition, read from the results stored in `.uns`.
top_genes = cell_type_adata.uns['rank_genes_groups']['names'][to_condition]

In [None]:
trvae.pl.reg_mean_plot(adata_pred,
                         top_100_genes=top_genes,
                         gene_list=top_genes[:10],
                         condition_key=condition_key,
                         axis_keys={'x': path, 'y': to_condition},
                         labels={'x': path, 'y': to_condition},
                         path_to_save=None,
                         legend=False,
                         show=True,
                         x_coeff=1.0,
                         y_coeff=0.0)

In [None]:
trvae.pl.reg_var_plot(adata_pred,
                     top_100_genes=top_genes,
                     gene_list=top_genes[:10],
                     condition_key=condition_key,
                     axis_keys={'x': path, 'y': to_condition},
                     labels={'x': path, 'y': to_condition},
                     path_to_save=None,
                     legend=False,
                     show=True,
                     x_coeff=1.0,
                     y_coeff=0.0)

In [None]:
sc.pl.violin(adata_pred, groupby=condition_key, keys=top_genes[:10], rotation=90)

In [None]:
# Write the combined real + predicted cells to an .h5ad file named after the cell type.
recon_adata.write_h5ad(f"../trVAE_reproducibility/data/reconstructed/trVAE_Haber/{specific_celltype}.h5ad")

In [None]:
# Model directory, built from `data_name` so it stays identical to the `model_path`
# given when the network was constructed instead of hard-coding the dataset name.
# NOTE(review): presumably the location save_model() writes to — confirm.
network.model_to_use = f"./models/trVAEMulti/best/{data_name}-{specific_celltype}/"

In [None]:
# Show the currently configured model directory.
network.model_to_use

In [None]:
# Save the trained model.
network.save_model()