# Sample run of trVAE on the Haber et al. dataset

In [1]:
import os

# Work from the parent directory so the relative ./data and ./models paths used
# below resolve. The directory the kernel started in is remembered on first run,
# so re-running this cell does not keep climbing one level higher each time.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(NOTEBOOK_START_DIR, os.pardir)))

In [3]:
import numpy as np
import scanpy as sc
import reptrvae
import pandas as pd 
from scipy import stats

Using TensorFlow backend.


In [4]:
# Render scanpy figures at 200 dpi.
sc.settings.set_figure_params(dpi=200)

In [5]:
# --- Experiment configuration ---------------------------------------------
data_name = "haber"
# Cell type whose non-control conditions are withheld from training and predicted.
specific_celltype = "Tuft"
conditions = ["Control", "Hpoly.Day3", "Hpoly.Day10", "Salmonella"]
# Every condition except the leading "Control" entry.
target_conditions = conditions[1:]
# Column names in adata.obs.
cell_type_key = "cell_label"
condition_key = "condition"

In [10]:
# Load the normalized dataset and keep only the conditions studied in this run.
h5ad_path = f"./data/{data_name}/{data_name}_normalized.h5ad"
full_adata = sc.read(h5ad_path)
in_studied_conditions = full_adata.obs[condition_key].isin(conditions)
adata = full_adata[in_studied_conditions]
adata

View of AnnData object with n_obs × n_vars = 9842 × 1000 
    obs: 'batch', 'barcode', 'condition', 'cell_label', 'n_counts'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'

In [11]:
# Number of cells per (cell type, condition) pair in the loaded data.
adata.obs.groupby([cell_type_key, condition_key]).size()

cell_label             condition  
Endocrine              Control        112
                       Hpoly.Day3     117
                       Hpoly.Day10     82
                       Salmonella      69
Enterocyte             Control        424
                       Hpoly.Day3     201
                       Hpoly.Day10    128
                       Salmonella     705
Enterocyte.Progenitor  Control        545
                       Hpoly.Day3     462
                       Hpoly.Day10    586
                       Salmonella     229
Goblet                 Control        216
                       Hpoly.Day3      99
                       Hpoly.Day10    317
                       Salmonella     126
Stem                   Control        670
                       Hpoly.Day3     388
                       Hpoly.Day10    592
                       Salmonella     207
TA                     Control        421
                       Hpoly.Day3     302
                       Hpoly.Day10    353

In [14]:
# Seed NumPy's global RNG so the 80/20 train/validation split is reproducible
# across kernel restarts.
# NOTE(review): assumes reptrvae.utils.train_test_split draws from np.random — confirm.
split_seed = 0
np.random.seed(split_seed)
train_adata, valid_adata = reptrvae.utils.train_test_split(adata, 0.80)

In [15]:
# (n_obs, n_vars) of the training and validation splits.
train_adata.shape, valid_adata.shape

((7873, 1000), (1969, 1000))

In [16]:
def is_held_out(ad):
    """Boolean mask of `specific_celltype` cells observed under any target condition.

    These are the cells the network must never see; they are predicted later.
    """
    of_cell_type = ad.obs[cell_type_key] == specific_celltype
    in_target_condition = ad.obs[condition_key].isin(target_conditions)
    return of_cell_type & in_target_condition


net_train_adata = train_adata[~is_held_out(train_adata)]
net_valid_adata = valid_adata[~is_held_out(valid_adata)]

In [17]:
# (n_obs, n_vars) after removing the held-out cells from both splits.
net_train_adata.shape, net_valid_adata.shape

((7598, 1000), (1895, 1000))

In [18]:
# Per-(cell type, condition) cell counts in the network's training set.
net_train_adata.obs.groupby([cell_type_key, condition_key]).size()

cell_label             condition  
Endocrine              Control         89
                       Hpoly.Day3     103
                       Hpoly.Day10     69
                       Salmonella      58
Enterocyte             Control        318
                       Hpoly.Day3     166
                       Hpoly.Day10     99
                       Salmonella     571
Enterocyte.Progenitor  Control        442
                       Hpoly.Day3     378
                       Hpoly.Day10    451
                       Salmonella     189
Goblet                 Control        170
                       Hpoly.Day3      81
                       Hpoly.Day10    255
                       Salmonella     108
Stem                   Control        530
                       Hpoly.Day3     314
                       Hpoly.Day10    494
                       Salmonella     168
TA                     Control        322
                       Hpoly.Day3     246
                       Hpoly.Day10    284

In [19]:
# Build the trVAE model.
# n_conditions sizes the condition one-hot input (shown as encoder_labels in the
# model summary). It is taken from the full `conditions` list, the same four
# conditions label_encoder encodes, rather than from the conditions that happen
# to survive in the training subset, so the two can never disagree.
# NOTE(review): alpha/beta/eta/clip_value are run-specific hyperparameters;
# their exact roles are defined inside reptrvae.models.trVAE.
network = reptrvae.models.trVAE(
    x_dimension=net_train_adata.shape[1],
    z_dimension=60,
    mmd_dimension=128,
    n_conditions=len(conditions),
    alpha=1e-6,
    beta=100,
    eta=100,
    clip_value=100,
    lambda_l1=0.0,
    lambda_l2=0.0,
    learning_rate=0.001,
    model_path=f"./models/trVAEMulti/best/{data_name}-{specific_celltype}/",
    dropout_rate=0.2,
    output_activation='relu',
)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
data (InputLayer)               (None, 1000)         0                                            
__________________________________________________________________________________________________
encoder_labels (InputLayer)     (None, 4)            0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 1004)         0           data[0][0]                       
                                                                 encoder_labels[0][0]             
_____________________

In [20]:
# Integer code for each condition, in the order listed in `conditions`
# (Control -> 0, Hpoly.Day3 -> 1, Hpoly.Day10 -> 2, Salmonella -> 3).
label_encoder = {condition: code for code, condition in enumerate(conditions)}

In [None]:
# Optional: uncomment to restore a previously saved model (presumably from the
# model_path given to trVAE) instead of retraining in the next cell.
# network.restore_model()

In [None]:
# Training settings for this run, kept together so they are easy to tweak.
train_settings = dict(
    n_epochs=10000,
    batch_size=512,
    verbose=2,
    early_stop_limit=750,
    lr_reducer=0,
    shuffle=True,
)
network.train(
    net_train_adata,
    net_valid_adata,
    label_encoder,
    condition_key,
    **train_settings,
)

In [None]:
# Encode every training cell's condition with the same mapping the network was
# trained with. Computed directly from obs so this cell depends only on names
# imported in this notebook. The result is 1-D float labels, the same form
# predict_transition passes to network.predict.
train_conditions = net_train_adata.obs[condition_key].astype(str)
unknown_conditions = set(train_conditions) - set(label_encoder)
assert not unknown_conditions, f"conditions missing from label_encoder: {sorted(unknown_conditions)}"
train_labels = train_conditions.map(label_encoder).to_numpy(dtype=float)

latent_with_true_labels = network.to_latent(net_train_adata, train_labels)
# NOTE(review): feed_fake=-1 presumably means "decode with the true labels" — confirm.
mmd_with_true_labels = network.to_mmd_layer(net_train_adata, train_labels, feed_fake=-1)

In [None]:
# UMAP of the latent representation, coloured by condition and by cell type.
for embedding_step in (sc.pp.neighbors, sc.tl.umap):
    embedding_step(latent_with_true_labels)
sc.pl.umap(
    latent_with_true_labels,
    color=[condition_key, cell_type_key],
    wspace=0.15,
    frameon=False,
    show=True,
)

In [None]:
# UMAP of the MMD-layer representation, coloured by condition and by cell type.
for embedding_step in (sc.pp.neighbors, sc.tl.umap):
    embedding_step(mmd_with_true_labels)
sc.pl.umap(
    mmd_with_true_labels,
    color=[condition_key, cell_type_key],
    wspace=0.15,
    frameon=False,
    show=True,
)

In [None]:
# All training-split cells of the selected cell type, including the
# target-condition cells that were removed from net_train_adata; they serve as
# ground truth below.
# .copy() materialises the subset. The following cells replace .var and
# rank_genes_groups writes into .uns, which on a view of train_adata anndata
# only allows by implicitly copying it.
cell_type_adata = train_adata[train_adata.obs[cell_type_key] == specific_celltype].copy()

In [None]:
# Replace .var with an empty frame indexed by the same gene names, dropping the
# per-gene annotation columns.
# NOTE(review): presumably done so later concatenation with prediction AnnData
# objects (created without .var columns) does not have to reconcile them — confirm.
cell_type_adata.var = pd.DataFrame(index=cell_type_adata.var_names)

In [None]:
# Cell counts per condition for the selected cell type.
cell_type_adata.obs.groupby([cell_type_key, condition_key]).size()

In [None]:
def predict_transition(adata, source_cond, target_cond):
    """Predict how the ``source_cond`` cells in ``adata`` would look under ``target_cond``.

    Cells whose ``condition_key`` equals ``source_cond`` are passed to
    ``network.predict``. They are encoded with their own condition code and
    decoded with ``target_cond``'s code. ``source_cond`` may itself name an
    earlier prediction such as ``"Control_to_Hpoly.Day3"``. In that case the
    encoder code is taken from the part after the last ``"_to_"``.

    Returns ``adata`` concatenated with the predicted cells. The predicted cells
    are labelled with condition ``source_cond + "_to_" + target_cond`` and cell
    type ``specific_celltype``.
    """
    source_adata = adata[adata.obs[condition_key] == source_cond]
    n_source = source_adata.shape[0]

    encoder_code = label_encoder[source_cond.split("_to_")[-1]]
    decoder_code = label_encoder[target_cond]
    # Float64 label vectors, one entry per source cell.
    source_labels = np.full(n_source, encoder_code, dtype=np.float64)
    target_labels = np.full(n_source, decoder_code, dtype=np.float64)

    pred_target = network.predict(
        source_adata,
        encoder_labels=source_labels,
        decoder_labels=target_labels,
    )

    n_pred = pred_target.shape[0]
    pred_adata = sc.AnnData(X=pred_target)
    pred_adata.obs[condition_key] = [source_cond + "_to_" + target_cond] * n_pred
    pred_adata.obs[cell_type_key] = [specific_celltype] * n_pred
    pred_adata.var_names = source_adata.var_names

    return adata.concatenate(pred_adata)

In [None]:
# Transitions to predict, applied in order. The last one is a two-step
# prediction that starts from the Control -> Hpoly.Day3 cells predicted earlier.
transitions = [
    ("Control", "Hpoly.Day10"),
    ("Control", "Hpoly.Day3"),
    ("Control", "Salmonella"),
    ("Hpoly.Day3", "Hpoly.Day10"),
    ("Control_to_Hpoly.Day3", "Hpoly.Day10"),
]

# predict_transition already returns its input *plus* the new predictions, so its
# result is chained directly. Wrapping it in another recon_adata.concatenate(...)
# would duplicate every cell already in recon_adata at each step.
recon_adata = cell_type_adata
for source_cond, target_cond in transitions:
    recon_adata = predict_transition(recon_adata, source_cond, target_cond)
recon_adata

In [None]:
# Cell counts per condition, including the predicted "<source>_to_<target>" groups.
recon_adata.obs.groupby([cell_type_key, condition_key]).size()

In [None]:
# Prediction to evaluate, written as "<source>_to_<target>"; the last two
# segments give the real source and target conditions to compare against.
path = "Control_to_Hpoly.Day10"
from_condition, to_condition = path.split("_to_")[-2:]

In [None]:
# Predicted cells for the evaluated transition `path`.
pred_adata = recon_adata[recon_adata.obs[condition_key] == path]
pred_adata

In [None]:
# Real cells of the selected cell type in the source condition.
ctrl_adata = cell_type_adata[cell_type_adata.obs[condition_key] == from_condition]
ctrl_adata

In [None]:
# Real cells of the selected cell type in the target condition (ground truth for the prediction).
real_adata = cell_type_adata[cell_type_adata.obs[condition_key] == to_condition]
real_adata

In [None]:
# Value range of the real expression matrix (sanity check against the prediction's range).
real_adata.X.min(), real_adata.X.max()

In [None]:
# Value range of the predicted expression matrix.
pred_adata.X.min(), pred_adata.X.max()

In [None]:
# Per-gene mean expression for predicted, source and real cells.
# X.mean(axis=0) works for both dense and scipy-sparse .X (the latter returns a
# 1 x n_vars matrix); np.asarray(...).ravel() turns either into a plain 1-D
# array, which is what stats.linregress below expects.
pred_mean, ctrl_mean, real_mean = (
    np.asarray(ad.X.mean(axis=0)).ravel()
    for ad in (pred_adata, ctrl_adata, real_adata)
)

In [None]:
def as_dense(ad):
    """Return ``ad.X`` as a dense ndarray (np.var does not accept scipy-sparse input)."""
    matrix = ad.X
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


# Per-gene variance of expression for predicted, source and real cells.
pred_var, ctrl_var, real_var = (
    np.var(as_dense(ad), axis=0)
    for ad in (pred_adata, ctrl_adata, real_adata)
)

In [None]:
# Coefficient of determination (r^2) between predicted and real per-gene means.
# Only rvalue is needed, so the unused slope/intercept/p-value/stderr are not kept.
r_value_mean = stats.linregress(pred_mean, real_mean).rvalue ** 2
r_value_mean

In [None]:
# Coefficient of determination (r^2) between predicted and real per-gene variances.
# Only rvalue is needed, so the unused slope/intercept/p-value/stderr are not kept.
r_value_var = stats.linregress(pred_var, real_var).rvalue ** 2
r_value_var

In [None]:
# Source, predicted and real cells in one object for the comparison plots below.
# NOTE(review): AnnData.concatenate adds a "batch" obs column by default, which
# replaces the existing obs["batch"] column — confirm that is acceptable.
adata_pred = ctrl_adata.concatenate(pred_adata, real_adata)
adata_pred

In [None]:
# Conditions present in the combined comparison object.
adata_pred.obs[condition_key].unique()

In [None]:
# Rank genes for the target condition against the source condition (method left
# at scanpy's default), keeping 100 genes. Results are stored in
# cell_type_adata.uns['rank_genes_groups'].
sc.tl.rank_genes_groups(cell_type_adata,
                        groupby=condition_key,
                        groups=[to_condition],
                        reference=from_condition,
                        n_genes=100)

In [None]:
# Names of the top-ranked genes for the target condition.
top_genes = cell_type_adata.uns['rank_genes_groups']['names'][to_condition]

In [None]:
# Presumably a regression scatter of per-gene means: prediction (x) vs. real
# target cells (y), highlighting the top 10 ranked genes.
# FIXME: `trvae` is never imported in this notebook (only `reptrvae` is), so this
# cell raises NameError on a fresh kernel. Import trvae or switch to the
# equivalent reptrvae plotting helper.
trvae.pl.reg_mean_plot(adata_pred,
                         top_100_genes=top_genes,
                         gene_list=top_genes[:10],
                         condition_key=condition_key,
                         axis_keys={'x': path, 'y': to_condition},
                         labels={'x': path, 'y': to_condition},
                         path_to_save=None,
                         legend=False,
                         show=True,
                         x_coeff=1.0,
                         y_coeff=0.0)

In [None]:
# Presumably a regression scatter of per-gene variances: prediction (x) vs. real
# target cells (y), highlighting the top 10 ranked genes.
# FIXME: `trvae` is never imported in this notebook (only `reptrvae` is), so this
# cell raises NameError on a fresh kernel. Import trvae or switch to the
# equivalent reptrvae plotting helper.
trvae.pl.reg_var_plot(adata_pred,
                     top_100_genes=top_genes,
                     gene_list=top_genes[:10],
                     condition_key=condition_key,
                     axis_keys={'x': path, 'y': to_condition},
                     labels={'x': path, 'y': to_condition},
                     path_to_save=None,
                     legend=False,
                     show=True,
                     x_coeff=1.0,
                     y_coeff=0.0)

In [None]:
# Violin plots of the 10 top-ranked genes, grouped by condition (source, prediction, real).
sc.pl.violin(adata_pred, groupby=condition_key, keys=top_genes[:10], rotation=90)

In [None]:
# Save real + predicted cells of the selected cell type. Make sure the output
# directory exists first so the write does not fail on a fresh checkout.
recon_output_path = f"../trVAE_reproducibility/data/reconstructed/trVAE_Haber/{specific_celltype}.h5ad"
os.makedirs(os.path.dirname(recon_output_path), exist_ok=True)
recon_adata.write_h5ad(recon_output_path)

In [None]:
# Save location, built from the same template as the model_path passed to trVAE
# (data_name instead of a hard-coded "haber"), so the two cannot drift apart.
network.model_to_use = f"./models/trVAEMulti/best/{data_name}-{specific_celltype}/"

In [None]:
# Display the configured save location.
network.model_to_use

In [None]:
# Persist the trained network (presumably to network.model_to_use — confirm).
network.save_model()