# Run Savercat with highly variable genes

In [None]:
#Import Packages
import random
import os
import numpy as np
import scanpy as sc
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from keras.utils.vis_utils import plot_model
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense, Dropout, Activation, BatchNormalization, LeakyReLU, Lambda
from tensorflow.keras import Model
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, scale
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as pl

In [None]:
# Report the name of the current working directory and the scanpy environment,
# so the run can be traced back to its project folder and package versions.
base_name = os.path.basename(os.getcwd())
print(base_name)
print(sc.__version__)
sc.settings.verbosity = 3  # per scanpy's settings docs: show errors, warnings, info and hints
sc.logging.print_versions()

In [None]:
# import utils functions in utils.py
from utils import *
# import network building functions in network.py
from network import * 
# import cross_validation function in train.py
from train import * 

# 1. Load data

In [None]:
# Read the AnnData object from an .h5ad file (the path is relative to the current
# working directory) and print its summary.
# NOTE(review): the file name suggests a subsample restricted to highly variable
# genes — confirm against how the file was produced.
adata = sc.read_h5ad('../share/data/adata_subsample_hvg.h5ad')
print(adata)

# 2. Savercat preprocess

In [None]:
# Names of the two cell-level labels handed to savercat_preprocess in the next cell.
# NOTE(review): presumably both are column names in adata.obs — confirm in utils.py.
predict_key = 'Cycle' # the name of the cell-level label to be predicted
batch_key = 'patient' # the name of the cell-level label to be adjusted for

In [None]:
# SAVERCAT preprocessing step: pass the label to predict and the label to adjust
# for; the returned object replaces `adata`.
# NOTE(review): later cells read adata.obsm['saver_batch'] and
# adata.obsm['saver_targetL'], which are presumably created here — confirm in utils.py.
adata = savercat_preprocess(adata, predict_key=predict_key, adjust_key=batch_key, scaleB=True)
adata  # notebook display of the preprocessed object

# 3. Build the model

In [None]:
# Architecture for training on highly variable genes: keep enc=(256, 256, 128)
# and dec=(128, 256, 256), and leave the remaining parameters unchanged.
SAVER_net = CVAE(
    x_input_size=adata.n_vars,  # number of genes
    b_input_size=adata.obsm['saver_batch'].shape[1],  # number of batches including lib-size
    lb_input_size=adata.obsm['saver_targetL'].shape[1],  # number of labels to predict
    enc=(256, 256, 128),  # dim of the encoder
    dec=(128, 256, 256),  # dim of the decoder
    latent_k=30,  # dimension of the low-dimensional latent space
)
SAVER_net.build()
SAVER_net.compile_model(pred_weight=1, kl_weight=1)

# 4. Initialize the model

In [None]:
# No need to modify this block.
# Label-guided initialization step: runs on the preprocessed `adata` with
# per-epoch progress output (fit_verbose=1) and keeps the returned value in `loss`.
loss = SAVER_net.model_initialize(adata, fit_verbose=1)

In [None]:
# Save the weights obtained from the initialization step.
# Prefix the file name with the directory where you want to save the file;
# 'weights_init.h5' is the file name. The fine-tuning and cross-validation
# cells below read this same file, so keep the names in sync if you change it.
SAVER_net.model.save_weights('weights_init.h5')

# 5. Fine-tune the model

In [None]:
# Architecture for training on highly variable genes: keep enc=(256, 256, 128)
# and dec=(128, 256, 256), and leave the remaining parameters unchanged.
# The network is constructed exactly as in step 3 (build the model), but it
# starts from the weights saved after initialization and is compiled with
# pred_weight=0.
SAVER_net = CVAE(
    x_input_size=adata.n_vars,
    b_input_size=adata.obsm['saver_batch'].shape[1],
    lb_input_size=adata.obsm['saver_targetL'].shape[1],
    enc=(256, 256, 128),
    dec=(128, 256, 256),
    latent_k=30,
)
SAVER_net.build()
SAVER_net.load_weights('weights_init.h5')  # fill in the weight file you just saved
SAVER_net.compile_model(pred_weight=0., kl_weight=1)

In [None]:
# No need to modify this block.
# Train (fine-tune) the SAVERCAT model, which performs the dimension reduction,
# then save the fine-tuned weights to 'weights_ft.h5'; the shrinkage step loads
# this file again.
loss = SAVER_net.model_finetune(adata, fit_verbose=1)
SAVER_net.model.save_weights('weights_ft.h5')

In [None]:
# Predict the low-dimensional embedding for all the cells and write it to a CSV
# file: one row per cell (indexed like adata.obs), columns saver1..saver<latent_k>.
meta_df_train = adata.obs
predict_inputs = [adata.X, adata.obsm['saver_batch']]
z_train = SAVER_net.extra_models['mean_out'].predict(predict_inputs)
embedding_columns = [f'saver{k}' for k in range(1, SAVER_net.latent_k + 1)]
z_df = pd.DataFrame(z_train, index=meta_df_train.index, columns=embedding_columns)
# Change this path to where you want to save the low-dimensional embeddings
# learned by SAVERCAT.
z_df.to_csv('lowdim_savercat_hvg.csv')

# 6. Cross Validation

In [None]:
# Cross-validation is required by the denoising (shrinkage) step that follows.
# Expect this cell to take several hours to run.
SAVER_net = CVAE(
    x_input_size=adata.n_vars,
    b_input_size=adata.obsm['saver_batch'].shape[1],
    lb_input_size=adata.obsm['saver_targetL'].shape[1],
    enc=(256, 256, 128),
    dec=(128, 256, 256),
    latent_k=30,
)
train_cv(
    adata,
    SAVER_net,
    weights_orig_filename='weights_init.h5',
    cv_genes_file_name='cv_genes_idx.csv',
)

# 7. Shrinkage

In [None]:
# Rebuild the network and restore the weights from step 5 (fine-tune the model).
# It is used in the next cell to predict gene expression and perform denoising;
# the denoised expression is saved to denoise_only_path.
SAVER_net = CVAE(
    x_input_size=adata.n_vars,
    b_input_size=adata.obsm['saver_batch'].shape[1],
    lb_input_size=adata.obsm['saver_targetL'].shape[1],
    enc=(256, 256, 128),
    dec=(128, 256, 256),
    latent_k=30,
)
SAVER_net.build()
SAVER_net.load_weights('weights_ft.h5')  # weights saved in step 5 (fine-tune the model)

In [None]:
# Run the shrinkage (denoising) step with the fine-tuned network.
# cv_genes_file_name is the same file name that was passed to train_cv in the
# cross-validation step.
X_denoise_df = shrinkage(SAVER_net, adata, cv_genes_file_name='cv_genes_idx.csv',
                         denoise_only_path = 'Saver_denoiseonly_mat.csv')
# The denoised count matrix is saved to denoise_only_path and is also returned as X_denoise_df.