In [None]:
import numpy as np
import scanpy as sc
import multigrate as mtg
import anndata as ad
import scipy
from scipy import sparse
import os

In [None]:
def load_adata():
    """Read the preprocessed, harmonised RNA and ADT AnnData files.

    Returns:
        tuple: ``(rna, adt)`` as returned by ``anndata.read_h5ad``, in that order.
    """
    print("loading data..")
    # NOTE(review): CyTOF ("cytof-pp-harm-sub.h5ad") and FACS
    # ("facs-pp-harm-sub.h5ad") inputs were previously commented out and are
    # intentionally not loaded here.
    input_files = (
        "../pp_harm_data/rna-pp-harm-sub.h5ad",
        "../pp_harm_data/adt-pp-harm-sub.h5ad",
    )
    rna_adata, adt_adata = (ad.read_h5ad(path) for path in input_files)
    return rna_adata, adt_adata

In [None]:
def concatenate_adata(rna, adt):
    """Combine the RNA and ADT AnnData objects with multigrate.

    Each modality is passed as its own single-dataset group, with ``None`` as
    the layer for both.  NOTE(review): presumably ``None`` means ``.X`` is
    used; confirm against ``organize_multiome_anndatas``.

    Returns:
        The object returned by ``mtg.data.organize_multiome_anndatas``.
    """
    print("concatenating data..")
    modality_groups = [[rna], [adt]]
    modality_layers = [[None], [None]]
    return mtg.data.organize_multiome_anndatas(
        adatas=modality_groups,
        layers=modality_layers,
    )

In [None]:
def setup_combined_adata(combined):
    """Register ``combined`` with ``MultiVAE`` before a model is built on it.

    ``'Domain'`` is passed as the only categorical covariate key
    (presumably a column of ``combined.obs`` — confirm upstream).

    Returns:
        None.
    """
    print("setting up the combined adata..")
    # The original line ended with a stray comma, which wrapped the call in a
    # throw-away tuple expression; it is a plain call statement now.
    mtg.model.MultiVAE.setup_anndata(combined, categorical_covariate_keys=['Domain'])

In [None]:
def setup_multivae(combined, l_coef):
    """Construct a ``MultiVAE`` on ``combined``.

    Args:
        combined: AnnData already registered via ``MultiVAE.setup_anndata``.
        l_coef: Weight for the ``'integ'`` loss term, integrated over ``'Domain'``.

    Returns:
        The new ``mtg.model.MultiVAE`` instance (untrained).
    """
    print("setting up the model..")
    integration_weights = {'integ': l_coef}
    # NOTE(review): presumably one reconstruction loss per modality (RNA, ADT).
    reconstruction_losses = ['mse', 'mse']
    return mtg.model.MultiVAE(
        combined,
        integrate_on='Domain',
        loss_coefs=integration_weights,
        losses=reconstruction_losses,
    )

In [None]:
def model_train(model, lr):
    """Train ``model`` with learning rate ``lr`` and ``use_gpu=True``.

    Returns:
        None; the value returned by ``model.train`` is not propagated.
    """
    print("training the model..")
    train_options = dict(lr=lr, use_gpu=True)
    model.train(**train_options)

In [None]:
def plot_losses(model, result_path):
    """Ask ``model`` to save its loss plot as ``result_path + "losses.jpg"``.

    ``result_path`` is joined by plain string concatenation, so it should end
    with a path separator.
    """
    loss_plot_file = result_path + "losses.jpg"
    model.plot_losses(loss_plot_file)

In [None]:
def save_model(model, result_path):
    """Save ``model`` to ``result_path + "multigrate.dill"``.

    The existing target is overwritten and the AnnData is not saved with the
    model.  ``result_path`` should end with a path separator.
    """
    print("saving the model..")
    save_options = {"prefix": None, "overwrite": True, "save_anndata": False}
    model.save(result_path + "multigrate.dill", **save_options)

In [None]:
def get_latent_representation(model):
    """Compute the latent representation of the data registered with ``model``.

    NOTE(review): ``compute_umap`` reads ``obsm['latent']``, so the model is
    presumably expected to store the embedding on the registered AnnData as a
    side effect — confirm against the multigrate version in use.

    Returns:
        Whatever ``model.get_latent_representation()`` returns.  Previously
        this was silently discarded; returning it lets callers use the value
        directly while remaining compatible with callers that ignore it.
    """
    print("getting latent representation for the combined adata..")
    return model.get_latent_representation()

In [None]:
def compute_umap(combined):
    """Build a neighbour graph on ``obsm['latent']`` and run UMAP on ``combined``.

    Uses scanpy's default in-place behaviour, so results are stored on
    ``combined`` and nothing is returned.
    """
    latent_key = 'latent'
    sc.pp.neighbors(combined, use_rep=latent_key)
    sc.tl.umap(combined)

In [None]:
def write_combined(combined, result_path):
    """Write ``combined`` to ``result_path + "combined.h5ad"`` with gzip compression.

    ``result_path`` should end with a path separator.
    """
    print("writing the combined adata")
    output_file = result_path + "combined.h5ad"
    combined.write(output_file, compression="gzip")
    print("writing complete")

In [None]:
def main(loss_coefs=(1, 50, 100, 150, 200),
         lr=0.0005,
         results_root='../results/multigrate/cite_integration',
         run_umap=False):
    """Train one MultiVAE per integration-loss coefficient and save the outputs.

    For every value in ``loss_coefs``, the same RNA/ADT inputs are combined
    afresh, a new model is built and trained, and then three outputs are
    written to ``<results_root>/coef_<value>/``:
    the saved model, the loss plot, and the combined AnnData.

    Args:
        loss_coefs: Iterable of weights for the ``'integ'`` loss term.  The
            default is a tuple (not a list) so it cannot be mutated across calls.
        lr: Learning rate passed to ``model.train``.
        results_root: Directory under which one ``coef_<value>`` sub-directory
            is created per coefficient.
        run_umap: If True, run ``compute_umap`` on the combined AnnData before
            it is written.  Off by default, matching the original behaviour.
    """
    rna, adt = load_adata()

    for l_coef in loss_coefs:
        # The trailing '' keeps a trailing separator: the save/plot/write
        # helpers append file names to result_path by string concatenation.
        result_path = os.path.join(results_root, f'coef_{l_coef}', '')
        os.makedirs(result_path, exist_ok=True)

        # A freshly set-up combined AnnData and model for every coefficient.
        combined = concatenate_adata(rna, adt)
        setup_combined_adata(combined)
        model = setup_multivae(combined, l_coef)
        model.to_device('cuda:0')
        model_train(model, lr=lr)

        save_model(model, result_path)
        plot_losses(model, result_path)
        get_latent_representation(model)
        if run_umap:
            compute_umap(combined)
        write_combined(combined, result_path)

In [None]:
# Single run: integration-loss weight 100, learning rate 10x below main()'s default.
main(lr=0.00005, loss_coefs=[100])