# scGLUE integration with the 0.9-correlation guidance graphs

- Model types: paired, unpaired
- Integration sets: full (RNA + ADT + CyTOF + FACS), trimodal (RNA + ADT + CyTOF), CITE (RNA + ADT)
- Cross-modality edges: weight +1, sign +1

In [None]:
import anndata as ad
import networkx as nx
import scanpy as sc
import scglue
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
import os

In [None]:
# Use scglue's publication plotting defaults, with square 4 x 4 inch figures.
scglue.plot.set_publication_params()
rcParams.update({"figure.figsize": (4, 4)})

In [None]:
def load_data(data_dir="../pp_harm_data"):
    """Load the four preprocessed modalities and tag their feature names.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``rna-``, ``adt-``, ``cytof-`` and
        ``facs-pp-harm-sub.h5ad``. Defaults to the previously hard-coded
        location, so existing calls behave the same.

    Returns
    -------
    tuple
        ``(rna, adt, cytof, facs)`` AnnData objects. The var names of adt,
        cytof and facs get the suffixes ``_adt``, ``_cytof`` and ``_facs``
        respectively; rna var names are left unchanged.
    """
    print("loading data ..\n\n")

    def _read(modality):
        # One file per modality, following the "<modality>-pp-harm-sub.h5ad" pattern.
        return ad.read_h5ad(os.path.join(data_dir, f"{modality}-pp-harm-sub.h5ad"))

    rna = _read("rna")
    adt = _read("adt")
    cytof = _read("cytof")
    facs = _read("facs")

    # Suffix the protein-level modalities, presumably so that a marker measured
    # in several modalities yields distinct feature names.
    adt.var.index = adt.var.index + '_adt'
    cytof.var.index = cytof.var.index + '_cytof'
    facs.var.index = facs.var.index + '_facs'

    print("rna, adt, cytof and facs shapes: ")
    print(rna.shape, adt.shape, cytof.shape, facs.shape, "\n\n")

    return rna, adt, cytof, facs

In [None]:
def configure_rna(rna):
    """Register the RNA AnnData with scglue.

    Uses a "Normal" probabilistic model on all features (no highly-variable
    subset). Cells are matched across datasets by obs_names, and
    ``obsm["X_pca"]`` is used as the input representation (``use_rep``).
    """
    options = dict(
        use_highly_variable=False,
        use_obs_names=True,
        use_rep="X_pca",
    )
    scglue.models.configure_dataset(rna, "Normal", **options)

In [None]:
def configure_adt(adt):
    """Register the ADT AnnData with scglue.

    Uses a "Normal" probabilistic model on all features. Cells are matched
    across datasets by obs_names.
    """
    options = dict(use_highly_variable=False, use_obs_names=True)
    scglue.models.configure_dataset(adt, "Normal", **options)

In [None]:
def configure_cytof(cytof):
    """Register the CyTOF AnnData with scglue ("Normal" model, all features).

    Unlike rna/adt, ``use_obs_names`` is left at scglue's default here.
    """
    prob_model = "Normal"
    scglue.models.configure_dataset(cytof, prob_model, use_highly_variable=False)

In [None]:
def configure_facs(facs):
    """Register the FACS AnnData with scglue ("Normal" model, all features).

    Unlike rna/adt, ``use_obs_names`` is left at scglue's default here.
    """
    prob_model = "Normal"
    scglue.models.configure_dataset(facs, prob_model, use_highly_variable=False)

In [None]:
def configure_datasets(rna, adt, cytof, facs):
    """Configure every modality for scglue, in the order rna, adt, cytof, facs."""
    print("configuring datasets.. \n\n")
    steps = (
        (configure_rna, rna),
        (configure_adt, adt),
        (configure_cytof, cytof),
        (configure_facs, facs),
    )
    for configure, adata in steps:
        configure(adata)

In [None]:
def load_guidance_graph(guidance_path):
    """Read the GraphML guidance graph at ``guidance_path`` and return it."""
    return nx.read_graphml(guidance_path)

In [None]:
def fit_paired_model(rna, adt, cytof, facs, guidance, results_path, integration_type):
    """Train a paired scGLUE model on the modalities chosen by ``integration_type``.

    Parameters
    ----------
    rna, adt, cytof, facs : AnnData
        Datasets expected to be configured already (see ``configure_datasets``).
        Only the ones required by ``integration_type`` are passed to scglue.
    guidance : graph
        Guidance graph, passed through to ``scglue.models.fit_SCGLUE``.
    results_path : str
        Output prefix. Training reports go to ``results_path + "glue_run_report"``,
        so the prefix should end with a path separator.
    integration_type : {"full", "trimodal", "cite"}
        ``"full"`` uses rna, adt, cytof and facs. ``"trimodal"`` uses rna, adt
        and cytof. ``"cite"`` uses rna and adt.

    Returns
    -------
    The model returned by ``fit_SCGLUE`` with ``model=PairedSCGLUEModel``.

    Raises
    ------
    ValueError
        If ``integration_type`` is not one of the values listed above.
    """
    modalities_by_type = {
        "full": ("rna", "adt", "cytof", "facs"),
        "trimodal": ("rna", "adt", "cytof"),
        "cite": ("rna", "adt"),
    }
    if integration_type not in modalities_by_type:
        raise ValueError(
            f"unknown integration_type {integration_type!r}; "
            f"expected one of {sorted(modalities_by_type)}"
        )
    available = {"rna": rna, "adt": adt, "cytof": cytof, "facs": facs}
    # Keep the canonical rna -> adt -> cytof -> facs key order.
    adatas = {name: available[name] for name in modalities_by_type[integration_type]}
    return scglue.models.fit_SCGLUE(
        adatas, guidance,
        fit_kws={"directory": results_path + "glue_run_report"},
        model=scglue.models.PairedSCGLUEModel,
    )

In [None]:
def fit_unpaired_model(rna, adt, cytof, facs, guidance, results_path, integration_type):
    """Train an unpaired scGLUE model on the modalities chosen by ``integration_type``.

    Parameters
    ----------
    rna, adt, cytof, facs : AnnData
        Datasets expected to be configured already (see ``configure_datasets``).
        Only the ones required by ``integration_type`` are passed to scglue.
    guidance : graph
        Guidance graph, passed through to ``scglue.models.fit_SCGLUE``.
    results_path : str
        Output prefix. Training reports go to ``results_path + "glue_run_report"``,
        so the prefix should end with a path separator.
    integration_type : {"full", "trimodal", "cite"}
        ``"full"`` uses rna, adt, cytof and facs. ``"trimodal"`` uses rna, adt
        and cytof. ``"cite"`` uses rna and adt.

    Returns
    -------
    The model returned by ``fit_SCGLUE`` with ``model=SCGLUEModel``.

    Raises
    ------
    ValueError
        If ``integration_type`` is not one of the values listed above.
    """
    modalities_by_type = {
        "full": ("rna", "adt", "cytof", "facs"),
        "trimodal": ("rna", "adt", "cytof"),
        "cite": ("rna", "adt"),
    }
    if integration_type not in modalities_by_type:
        raise ValueError(
            f"unknown integration_type {integration_type!r}; "
            f"expected one of {sorted(modalities_by_type)}"
        )
    available = {"rna": rna, "adt": adt, "cytof": cytof, "facs": facs}
    # Keep the canonical rna -> adt -> cytof -> facs key order.
    adatas = {name: available[name] for name in modalities_by_type[integration_type]}
    return scglue.models.fit_SCGLUE(
        adatas, guidance,
        fit_kws={"directory": results_path + "glue_run_report"},
        model=scglue.models.SCGLUEModel,
    )

In [None]:
def save_fitted_model(glue, results_path):
    """Persist ``glue`` via its ``save`` method to ``results_path + "glue.dill"``."""
    model_file = results_path + "glue.dill"
    glue.save(model_file)

In [None]:
def plot_integration_consistency(rna, adt, cytof, facs, glue, guidance, results_path, integration_type):
    """Plot scglue's integration-consistency curve and save it as a PNG.

    Computes ``scglue.models.integration_consistency`` on the modalities chosen
    by ``integration_type``. It then draws consistency against ``n_meta``, with
    a dashed reference line at 0.05, and saves the figure to
    ``results_path + "scglue_run_lineplot.png"``.

    Parameters
    ----------
    rna, adt, cytof, facs : AnnData
        Datasets. Only those required by ``integration_type`` are used.
    glue :
        Fitted scglue model.
    guidance : graph
        Guidance graph used for training.
    results_path : str
        Output prefix, which should end with a path separator.
    integration_type : {"full", "trimodal", "cite"}
        Selects rna/adt/cytof/facs, rna/adt/cytof, or rna/adt respectively.

    Raises
    ------
    ValueError
        If ``integration_type`` is not one of the values listed above.
    """
    modalities_by_type = {
        "full": ("rna", "adt", "cytof", "facs"),
        "trimodal": ("rna", "adt", "cytof"),
        "cite": ("rna", "adt"),
    }
    if integration_type not in modalities_by_type:
        raise ValueError(
            f"unknown integration_type {integration_type!r}; "
            f"expected one of {sorted(modalities_by_type)}"
        )
    available = {"rna": rna, "adt": adt, "cytof": cytof, "facs": facs}
    adatas = {name: available[name] for name in modalities_by_type[integration_type]}
    dx = scglue.models.integration_consistency(glue, adatas, guidance)

    # Draw on a dedicated figure. Without ax=, sns.lineplot reuses the current
    # pyplot axes, which in a non-inline backend could still hold the plot
    # from a previous loop iteration in main().
    fig, ax = plt.subplots()
    sns.lineplot(x="n_meta", y="consistency", data=dx, ax=ax)
    ax.axhline(y=0.05, c="darkred", ls="--")
    # Save before showing, because some backends (e.g. Jupyter inline) close
    # figures on show().
    fig.savefig(results_path + "scglue_run_lineplot.png")
    plt.show()
    plt.close(fig)

In [None]:
def generate_embeddings(rna, adt, cytof, facs, glue, integration_type):
    """Encode each selected modality with ``glue`` and concatenate the results.

    For every modality chosen by ``integration_type``, this function writes
    ``glue.encode_data(name, adata)`` into ``adata.obsm['X_glue']``, mutating
    the inputs in place. It then returns ``ad.concat`` of those AnnData objects
    in rna, adt, cytof, facs order.

    Parameters
    ----------
    rna, adt, cytof, facs : AnnData
        Datasets. Only those required by ``integration_type`` are encoded.
    glue :
        Fitted scglue model.
    integration_type : {"full", "trimodal", "cite"}
        Selects rna/adt/cytof/facs, rna/adt/cytof, or rna/adt respectively.

    Returns
    -------
    AnnData
        The concatenated AnnData, carrying ``obsm['X_glue']``.

    Raises
    ------
    ValueError
        If ``integration_type`` is not one of the values listed above.
    """
    print("computing embeddings..")
    modalities_by_type = {
        "full": ("rna", "adt", "cytof", "facs"),
        "trimodal": ("rna", "adt", "cytof"),
        "cite": ("rna", "adt"),
    }
    if integration_type not in modalities_by_type:
        raise ValueError(
            f"unknown integration_type {integration_type!r}; "
            f"expected one of {sorted(modalities_by_type)}"
        )
    available = {"rna": rna, "adt": adt, "cytof": cytof, "facs": facs}

    encoded = []
    for name in modalities_by_type[integration_type]:
        adata = available[name]
        adata.obsm['X_glue'] = glue.encode_data(name, adata)
        encoded.append(adata)

    # NOTE(review): ad.concat defaults to an inner join on features. With the
    # per-modality var-name suffixes this likely leaves no shared features,
    # so the useful content is obsm['X_glue']. Also, rna/adt are configured
    # with use_obs_names=True and presumably share obs_names, so `combined`
    # may contain duplicate obs_names. Consider index_unique= if that matters.
    return ad.concat(encoded)

In [None]:
def compute_umap(combined, results_path):
    """Build a UMAP on the GLUE embedding, save the AnnData, and plot two panels.

    Steps:

    1. Build a cosine kNN graph on ``obsm["X_glue"]`` and compute a UMAP.
    2. Write ``combined`` to ``results_path + "combined.h5ad"`` with gzip
       compression.
    3. Save two UMAP figures under ``results_path + "umaps/"``:
       ``cell_type.png`` colored by the annotation columns, and ``domain.png``
       colored by the domain columns.

    Parameters
    ----------
    combined : AnnData
        Must carry ``obsm["X_glue"]`` and the obs columns used for coloring.
        It is modified in place by the neighbors/UMAP computation.
    results_path : str
        Output prefix, which should end with a path separator.

    Returns
    -------
    AnnData
        The same ``combined`` object.
    """
    print("computing neighbours..\n\n")
    sc.pp.neighbors(combined, use_rep="X_glue", metric="cosine")
    print("computing umap..\n\n")
    sc.tl.umap(combined)
    print("writing combined adata with umap in results directory..\n\n")
    combined.write(results_path + "combined.h5ad", compression="gzip")

    print("plotting umaps..\n\n")
    umap_dir = results_path + "umaps/"
    os.makedirs(umap_dir, exist_ok=True)

    panels = (
        (["Annotation_major_subset", "Annotation_cell_type"], "cell_type.png"),
        (["Domain_major", "Domain"], "domain.png"),
    )
    for color_keys, filename in panels:
        # With return_fig=True scanpy returns the Figure itself, so no
        # .get_figure() call is needed.
        fig = sc.pl.umap(combined, color=color_keys, wspace=0.65, return_fig=True)
        # Save before showing, because some backends close figures on show().
        fig.savefig(umap_dir + filename)
        plt.show()
        # Close this specific figure rather than whatever is "current".
        plt.close(fig)

    return combined

In [None]:
def load_glue(results_path):
    """Load the scglue model stored at ``results_path + "glue.dill"``.

    This is the same file name that ``save_fitted_model`` writes.
    """
    print("loading glue..")
    model_file = results_path + "glue.dill"
    return scglue.models.load_model(model_file)

### Chain the steps into a single pipeline (`main`)

In [None]:
def main(model_types=('unpaired', 'paired'),
         integration_types=('cite', 'trimodal', 'full'),
         results_root="../results/scglue/point_nine_corr/",
         guidance_root="guidance_graphs/point_nine_corr/"):
    """Run the scGLUE pipeline for every (model type, integration type) pair.

    Data are loaded and configured once. Then, for each combination, the
    function:

    1. Loads ``guidance_root + "<integration_type>_graph.graphml.gz"``.
    2. Fits a paired or unpaired model.
    3. Saves the model.
    4. Plots integration consistency.
    5. Computes embeddings and a UMAP.

    All outputs go to ``results_root + "<model_type>/<integration_type>/"``.

    Parameters
    ----------
    model_types : iterable of {"paired", "unpaired"}
    integration_types : iterable of {"cite", "trimodal", "full"}
    results_root, guidance_root : str
        Path prefixes, which should end with a separator. The defaults match
        the previously hard-coded locations.

    Raises
    ------
    ValueError
        If any requested model or integration type is unknown. This is checked
        before the (slow) data loading, so typos fail fast instead of silently
        running the unpaired model.
    """
    known_models = ('paired', 'unpaired')
    known_integrations = ('cite', 'trimodal', 'full')
    model_types = list(model_types)
    integration_types = list(integration_types)
    bad_models = [m for m in model_types if m not in known_models]
    bad_integrations = [i for i in integration_types if i not in known_integrations]
    if bad_models or bad_integrations:
        raise ValueError(
            f"unknown model_types {bad_models} / integration_types {bad_integrations}; "
            f"expected model types from {known_models} and integration types from "
            f"{known_integrations}"
        )

    print("analysis starting..\n\n")

    rna, adt, cytof, facs = load_data()
    configure_datasets(rna, adt, cytof, facs)  # configure once for every run below

    for model_type in model_types:
        for integration_type in integration_types:
            results_path = results_root + model_type + '/' + integration_type + '/'
            guidance_path = guidance_root + integration_type + '_graph.graphml.gz'
            print("using results path: {}\nusing guidance graph path: {}\n\n".format(results_path, guidance_path))
            print("using model type: {}\nusing integration_type: {}\n\n".format(model_type, integration_type))

            os.makedirs(results_path, exist_ok=True)
            guidance = load_guidance_graph(guidance_path)

            if model_type == 'paired':
                glue = fit_paired_model(rna, adt, cytof, facs, guidance, results_path, integration_type)
            else:
                glue = fit_unpaired_model(rna, adt, cytof, facs, guidance, results_path, integration_type)

            save_fitted_model(glue, results_path)
            plot_integration_consistency(rna, adt, cytof, facs, glue,
                                         guidance, results_path, integration_type)
            # Overwrites obsm['X_glue'] on the shared inputs on every iteration.
            combined = generate_embeddings(rna, adt, cytof, facs, glue, integration_type)
            compute_umap(combined, results_path)

    print("analysis completed\n\n")

## Run models

In [None]:
main(model_types = ['unpaired'], integration_types=['cite'])

In [None]:
main(model_types = ['unpaired'], integration_types=['trimodal'] )

In [None]:
main(model_types = ['unpaired'], integration_types=['full'] )

In [None]:
main(model_types = ['paired'], integration_types=['cite'] )

In [None]:
main(model_types = ['paired'], integration_types=['trimodal'] )

In [None]:
main(model_types = ['paired'], integration_types=['full'] )

In [None]:
main(model_types = ['unpaired'], integration_types=['cite'])