In [None]:
import numpy as np
import scanpy as sc
import multigrate as mtg
import anndata as ad
import scipy
from scipy import sparse
import os

In [None]:
def load_adata(data_dir="../pp_harm_data"):
    """Load the four preprocessed, harmonised modality files.

    Parameters
    ----------
    data_dir : str, optional
        Directory holding the ``<modality>-pp-harm-sub.h5ad`` files. Defaults
        to the location the notebook previously hard-coded.

    Returns
    -------
    tuple
        ``(rna, adt, cytof, facs)`` as returned by ``anndata.read_h5ad``.
    """
    print("loading data..")
    # Unpacking consumes the generator in order, so files are read
    # rna -> adt -> cytof -> facs exactly as before.
    rna, adt, cytof, facs = (
        ad.read_h5ad(os.path.join(data_dir, f"{modality}-pp-harm-sub.h5ad"))
        for modality in ("rna", "adt", "cytof", "facs")
    )

    return rna, adt, cytof, facs

In [None]:
def split_adata(rna, adt, cytof, facs):
    """Split each protein panel into panel-specific and shared markers.

    The shared markers are the var names present in all of ``adt``,
    ``cytof`` and ``facs``. Every output panel is an independent ``.copy()``.

    Returns
    -------
    tuple
        ``(rna, adt_unique, cytof_unique, facs_unique,
        adt_common, cytof_common, facs_common)``; ``rna`` is passed through.
    """
    shared = adt.var.index.intersection(
        cytof.var.index.intersection(facs.var.index)
    ).tolist()

    panels = (adt, cytof, facs)
    # Markers measured by only this panel (everything minus the shared set).
    only_here = [panel[:, panel.var_names.drop(shared)].copy() for panel in panels]
    # Markers every panel has, in the same column order for all three.
    in_all = [panel[:, shared].copy() for panel in panels]

    return (rna, *only_here, *in_all)


In [None]:
def concatenate_adata(rna, adt_unique, cytof_unique, facs_unique, adt_common, cytof_common, facs_common):
    """Assemble all inputs into one AnnData via multigrate.

    The grid passed to ``organize_multiome_anndatas`` has one row per
    modality and three columns. Column 0 holds ``rna`` and the ADT parts,
    column 1 the FACS parts, and column 2 the CyTOF parts. ``None`` marks a
    (modality, column) cell that has no data.

    Returns
    -------
    Whatever ``mtg.data.organize_multiome_anndatas`` returns.
    """
    print("concatenating data..")

    modality_grid = [
        [rna, None, None],                          # RNA
        [adt_common, facs_common, cytof_common],    # markers shared by all panels
        [adt_unique, None, None],                   # ADT-only markers
        [None, facs_unique, None],                  # FACS-only markers
        [None, None, cytof_unique],                 # CyTOF-only markers
    ]
    # Same 5x3 shape as the grid above; no named layer is selected anywhere.
    layer_grid = [[None] * 3 for _ in modality_grid]

    return mtg.data.organize_multiome_anndatas(adatas=modality_grid, layers=layer_grid)

In [None]:
def setup_combined_adata(combined):
    """Register the combined AnnData with ``MultiVAE``.

    ``'Domain'`` is registered as a categorical covariate.

    Parameters
    ----------
    combined : AnnData
        Output of ``concatenate_adata``. The return value of
        ``setup_anndata`` is not used, so registration is expected to happen
        on ``combined`` itself.
    """
    print("setting up the combined adata..")
    # The original line ended in a stray comma, which built and discarded a
    # one-element tuple; the call itself is unchanged.
    mtg.model.MultiVAE.setup_anndata(combined, categorical_covariate_keys=['Domain'])

In [None]:
def setup_multivae(combined, l_coef, mmd='marginal'):
    """Construct a ``MultiVAE`` on the registered combined AnnData.

    Parameters
    ----------
    combined : AnnData
        Combined data already passed through ``setup_combined_adata``.
    l_coef : float
        Weight for the ``'integ'`` loss term (``loss_coefs={'integ': l_coef}``).
    mmd : str, optional
        Forwarded to ``MultiVAE``'s ``mmd`` argument. Defaults to
        ``'marginal'``, the same default ``main`` uses.

    Returns
    -------
    The ``MultiVAE`` instance, not a tuple.
    """
    print("setting up the model..")
    # Previously the call's closing paren came before ``mmd``. That made the
    # cell an IndentationError, and with a trailing comma ``model`` would have
    # been a tuple; ``mmd`` now goes inside the call.
    model = mtg.model.MultiVAE(
        combined,
        integrate_on='Domain',
        loss_coefs={'integ': l_coef},
        # Five entries, matching the five modality rows that
        # concatenate_adata passes to organize_multiome_anndatas.
        losses=['mse', 'mse', 'mse', 'mse', 'mse'],
        mmd=mmd,
    )
    return model

In [None]:
def model_train(model, lr):
    """Train ``model`` with learning rate ``lr`` and ``use_gpu=True``.

    Parameters
    ----------
    model : MultiVAE
        Model built by ``setup_multivae``.
    lr : float
        Learning rate forwarded to ``model.train``.
    """
    print("training the model..")
    train_kwargs = dict(lr=lr, use_gpu=True)
    model.train(**train_kwargs)

In [None]:
def plot_losses(model, result_path):
    """Have ``model`` plot its losses to ``<result_path>losses.jpg``.

    ``result_path`` is concatenated as-is, so it should end with a path
    separator; ``main`` builds it with a trailing ``/``.
    """
    loss_figure = result_path + "losses.jpg"
    model.plot_losses(loss_figure)

In [None]:
def save_model(model, result_path):
    """Save ``model`` to ``<result_path>multigrate.dill``, overwriting.

    The AnnData is not bundled (``save_anndata=False``); ``main`` writes the
    combined AnnData separately via ``write_combined``.

    NOTE(review): in scvi-tools-style ``save`` methods the first argument is
    a directory, so ``multigrate.dill`` may end up as a folder name. Confirm
    against the installed multigrate version.
    """
    print("saving the model..")
    save_options = {"prefix": None, "overwrite": True, "save_anndata": False}
    model.save(result_path + "multigrate.dill", **save_options)

In [None]:
def get_latent_representation(model):
    """Call ``model.get_latent_representation()`` and discard its return value.

    NOTE(review): because nothing is returned or stored here, this relies on
    the model writing the embedding onto its registered AnnData (e.g. into
    ``obsm``) before ``write_combined`` runs. Confirm against the installed
    multigrate version.
    """
    print("getting latent representation for the combined adata..")
    compute_latent = model.get_latent_representation
    compute_latent()

In [None]:
def write_combined(combined, result_path):
    """Write ``combined`` to ``<result_path>combined.h5ad`` (gzip-compressed).

    ``result_path`` is concatenated as-is and should end with a separator.
    """
    print("writing the combined adata")
    destination = result_path + "combined.h5ad"
    combined.write(destination, compression="gzip")
    print("writing complete")

In [None]:
def main(loss_coefs=(0, 1e1, 1e2, 1e3, 1e4, 1e5),
         lr=0.00005,
         mmd='marginal'):
    """Train one MultiVAE per integration-loss coefficient and save outputs.

    Data is loaded and split once. For each coefficient a combined AnnData is
    built, a model is set up, moved to ``cuda:0`` and trained. The model, its
    loss plot and the combined AnnData are then written to
    ``../results/multigrate/coef_<coef>/``.

    Parameters
    ----------
    loss_coefs : iterable of float
        Values for the ``'integ'`` loss coefficient; one run per value.
        The default is a tuple to avoid a shared mutable default argument.
    lr : float
        Learning rate passed to ``model_train``.
    mmd : str
        Forwarded to ``setup_multivae`` and from there to ``MultiVAE``.
    """
    rna, adt, cytof, facs = load_adata()
    (rna, adt_unique, cytof_unique, facs_unique,
        adt_common, cytof_common, facs_common) = split_adata(rna, adt, cytof, facs)

    for l_coef in loss_coefs:
        # Trailing '/' matters: the save/plot/write helpers concatenate onto it.
        result_path = f'../results/multigrate/coef_{l_coef}/'
        os.makedirs(result_path, exist_ok=True)
        # A fresh combined object per run, so each model is registered on its
        # own AnnData rather than one reused from a previous run.
        combined = concatenate_adata(rna, adt_unique, cytof_unique, facs_unique,
                                     adt_common, cytof_common, facs_common)
        setup_combined_adata(combined)
        # Forward mmd; it was previously accepted here but never used.
        model = setup_multivae(combined, l_coef, mmd)
        model.to_device('cuda:0')
        model_train(model, lr)
        save_model(model, result_path)
        plot_losses(model, result_path)
        get_latent_representation(model)
        write_combined(combined, result_path)

## Set parameters and run

In [None]:
# mmd must be passed by keyword with a value: a bare positional after keyword
# arguments is a SyntaxError. 'marginal' is main's own default.
main(loss_coefs=[100], lr=0.00005, mmd='marginal')