In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
import networkx as nx
import scglue
import anndata as ad
import os
import sys
sys.path.insert(0, os.path.abspath(".."))
from matching.utils import snn_matching, eot_matching, calc_domainAveraged_FOSCTTM
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Train scGLUE on the CITE-seq BMMC data: split the AnnData object by modality,
# preprocess RNA, build the protein/gene guidance graph, fit the model and
# encode both modalities.
cite = sc.read("/mnt/ps/home/CORP/johnny.xi/sandbox/matching/data/GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad")

# Take explicit copies: both subsets are modified below (X, var_names), so
# working on views of `cite` would only defer the copy to the first write.
adt_ad = cite[:, cite.var.feature_types == "ADT"].copy()
gex_ad = cite[:, cite.var.feature_types == "GEX"].copy()

# RNA preprocessing starts from the raw counts layer; the PCA result is stored
# in obsm["X_pca"] and used as the RNA representation by configure_dataset.
gex_ad.X = gex_ad.layers["counts"].copy()
sc.pp.normalize_total(gex_ad)
sc.pp.log1p(gex_ad)
sc.pp.scale(gex_ad)
sc.tl.pca(gex_ad, n_comps=200, svd_solver="auto")

p = np.array(adt_ad.var_names)
r = np.array(gex_ad.var_names)
# mask[i, j] is True where protein name i is the same as gene name j
# (broadcast comparison of a column vector against a row vector).
mask = p[:, None] == r

# Suffix the variable names so RNA and protein nodes stay distinct in the graph.
rna_vars = [v + "_rna" for v in gex_ad.var_names]
prot_vars = [v + "_prot" for v in adt_ad.var_names]
gex_ad.var_names = rna_vars
adt_ad.var_names = prot_vars

# Labelled adjacency table, kept for inspection.
adj = pd.DataFrame(mask, index=prot_vars, columns=rna_vars)

# Read the matching (protein, gene) pairs straight from the boolean mask.
# This avoids relying on DataFrame.stack() dropping NaN entries, which is not
# the behaviour of the newer stack implementation.
prot_hits, rna_hits = np.nonzero(mask)
diag_edges = [
    (prot_vars[i], rna_vars[j], {"weight": 1.0, "sign": 1})
    for i, j in zip(prot_hits, rna_hits)
]
self_loop_rna = [(g, g, {"weight": 1.0, "sign": 1}) for g in rna_vars]
self_loop_prot = [(g, g, {"weight": 1.0, "sign": 1}) for g in prot_vars]

# Guidance graph: one node per feature, a self-loop on every node, and an edge
# between each protein and the gene that shares its name.
graph = nx.Graph()
graph.add_nodes_from(rna_vars)
graph.add_nodes_from(prot_vars)
graph.add_edges_from(diag_edges)
graph.add_edges_from(self_loop_prot)
graph.add_edges_from(self_loop_rna)

scglue.models.configure_dataset(
    gex_ad,
    "NB",
    use_highly_variable=False,
    use_batch="batch",
    use_layer="counts",
    use_rep="X_pca",
)

scglue.models.configure_dataset(
    adt_ad,
    "Normal",
    use_highly_variable=False,
    use_batch="batch",
    use_layer="counts"
)

glue = scglue.models.fit_SCGLUE(
    {"rna": gex_ad, "adt": adt_ad},
    graph
)

# Per-cell embeddings of each modality from the fitted model.
rna_encodings = glue.encode_data("rna", gex_ad)
adt_encodings = glue.encode_data("adt", adt_ad)


[INFO] fit_SCGLUE: Pretraining SCGLUE model...
[INFO] autodevice: Using GPU 5 as computation device.
[INFO] check_graph: Checking variable coverage...
[INFO] check_graph: Checking edge attributes...
[INFO] check_graph: Checking self-loops...
[INFO] check_graph: Checking graph symmetry...
[INFO] check_graph: All checks passed!
[INFO] SCGLUEModel: Setting `graph_batch_size` = 4868
[INFO] SCGLUEModel: Setting `max_epochs` = 48
[INFO] SCGLUEModel: Setting `patience` = 4
[INFO] SCGLUEModel: Setting `reduce_lr_patience` = 2
[INFO] SCGLUETrainer: Using training directory: "/tmp/GLUETMPq16udor4"
[INFO] SCGLUETrainer: [Epoch 10] train={'g_nll': 0.428, 'g_kl': 0.007, 'g_elbo': 0.435, 'x_rna_nll': 0.314, 'x_rna_kl': 0.006, 'x_rna_elbo': 0.32, 'x_adt_nll': 68.622, 'x_adt_kl': 3.205, 'x_adt_elbo': 71.828, 'dsc_loss': 0.46, 'vae_loss': 72.165, 'gen_loss': 72.142}, val={'g_nll': 0.421, 'g_kl': 0.008, 'g_elbo': 0.429, 'x_rna_nll': 0.315, 'x_rna_kl': 0.006, 'x_rna_elbo': 0.321, 'x_adt_nll': 202.421, 'x

2023-10-31 11:52:28,074 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


[INFO] EarlyStopping: Restoring checkpoint "24"...
[INFO] EarlyStopping: Restoring checkpoint "24"...
[INFO] fit_SCGLUE: Estimating balancing weight...
[INFO] estimate_balancing_weight: Clustering cells...
[INFO] estimate_balancing_weight: Matching clusters...
[INFO] estimate_balancing_weight: Matching array shape = (33, 28)...
[INFO] estimate_balancing_weight: Estimating balancing weight...
[INFO] fit_SCGLUE: Fine-tuning SCGLUE model...
[INFO] check_graph: Checking variable coverage...
[INFO] check_graph: Checking edge attributes...
[INFO] check_graph: Checking self-loops...
[INFO] check_graph: Checking graph symmetry...
[INFO] check_graph: All checks passed!
[INFO] SCGLUEModel: Setting `graph_batch_size` = 4868
[INFO] SCGLUEModel: Setting `align_burnin` = 8
[INFO] SCGLUEModel: Setting `max_epochs` = 48
[INFO] SCGLUEModel: Setting `patience` = 4
[INFO] SCGLUEModel: Setting `reduce_lr_patience` = 2
[INFO] SCGLUETrainer: Using training directory: "/tmp/GLUETMPafsm5r2t"
[INFO] SCGLUETrai

2023-10-31 12:21:17,401 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


[INFO] EarlyStopping: Restoring checkpoint "23"...
[INFO] EarlyStopping: Restoring checkpoint "23"...


In [9]:
from matching.utils import write_to_pickle

# Persist both modality embeddings so the evaluation below can be run without
# refitting the model.
for encodings, out_path in (
    (rna_encodings, "../data/rna_scGLUE_embed.pickle"),
    (adt_encodings, "../data/adt_scGLUE_embed.pickle"),
):
    write_to_pickle(encodings, out_path)

## POST TRAINING

In [6]:
import torch
from matching.utils import read_from_pickle

# Evaluate SNN and EOT matching on the saved scGLUE embeddings, one cell type
# at a time. For each cell type this records the normalised trace of the
# matching matrix and the mean domain-averaged FOSCTTM of the matched proteins.
cite = sc.read("/mnt/ps/home/CORP/johnny.xi/sandbox/matching/data/GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad")
adt_ad = cite[:,cite.var.feature_types == "ADT"]

rna_encodings = read_from_pickle("../data/rna_scGLUE_embed.pickle")
adt_encodings = read_from_pickle("../data/adt_scGLUE_embed.pickle")

cell_types = np.unique(adt_ad.obs.cell_type)

# Per-cell labels and the dense protein matrix are loop-invariant, so compute
# them once. Labels are taken from adt_ad so that this cell does not depend on
# gex_ad from the training cell.
cell_type_labels = adt_ad.obs.cell_type.to_numpy()
adt_dense = adt_ad.X.toarray()

# Running-average placeholders; not updated anywhere in this cell.
knn_trace_avg, eot_trace_avg, knn_foscttm_avg, eot_foscttm_avg = 0, 0, 0, 0

outdict_eot = {}
outdict_snn = {}

# NOTE(review): the recorded run of this cell ended with a CUDA out-of-memory
# error on the 21693-sample cell type; that is not addressed here.
for ct in cell_types:
    idx = np.where(cell_type_labels == ct)[0]
    n_cells = idx.size
    adt_sub = adt_dense[idx]
    # Matching is done in the shared scGLUE embedding space.
    rna_match_sub = rna_encodings[idx]
    adt_match_sub = adt_encodings[idx]
    print(f"Cell type: {ct}, Number of samples: {n_cells}")

    snn_sub = snn_matching(rna_match_sub, adt_match_sub)
    snn_trace = np.trace(snn_sub) / n_cells
    print(f"Cell type: {ct}, kNN trace: {snn_trace}")

    eot_sub = eot_matching(rna_match_sub, adt_match_sub)
    eot_trace = np.trace(eot_sub) / n_cells
    print(f"Cell type: {ct}, EOT trace: {eot_trace}")

    # Apply each matching matrix to the measured proteins and compare the
    # result with the measured proteins themselves.
    snn_match = snn_sub @ adt_sub
    eot_match = eot_sub @ adt_sub
    snn_foscttm = np.array(calc_domainAveraged_FOSCTTM(adt_sub, snn_match)).mean()
    eot_foscttm = np.array(calc_domainAveraged_FOSCTTM(adt_sub, eot_match)).mean()
    torch.cuda.empty_cache()
    print(f"Cell type: {ct}, kNN FOSCTTM: {snn_foscttm}")
    print(f"Cell type: {ct}, EOT FOSCTTM: {eot_foscttm}")

    # Store plain scalars so the values serialise as numbers in the CSV.
    outdict_eot[f"Cell type: {ct} Trace"] = eot_trace
    outdict_eot[f"Cell type: {ct} FOSCTTM"] = eot_foscttm

    outdict_snn[f"Cell type: {ct} Trace"] = snn_trace
    outdict_snn[f"Cell type: {ct} FOSCTTM"] = snn_foscttm


Cell type: B1 B IGKC+, Number of samples: 820
1
1
Cell type: B1 B IGKC+, kNN trace: 0.0012224157955865273
Cell type: B1 B IGKC+, EOT trace: 0.00033913243834565325
Cell type: B1 B IGKC+, kNN FOSCTTM: 0.48659430000893417
Cell type: B1 B IGKC+, EOT FOSCTTM: 0.6276281306769111
Cell type: B1 B IGKC-, Number of samples: 613
1
1
Cell type: B1 B IGKC-, kNN trace: 0.001687596661087689
Cell type: B1 B IGKC-, EOT trace: 0.001025749069637031
Cell type: B1 B IGKC-, kNN FOSCTTM: 0.49286030344709947
Cell type: B1 B IGKC-, EOT FOSCTTM: 0.6106139845824137
Cell type: CD14+ Mono, Number of samples: 21693
1
1
Cell type: CD14+ Mono, kNN trace: 2.3048909786567094e-05


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.75 GiB (GPU 0; 10.91 GiB total capacity; 8.80 GiB already allocated; 1.52 GiB free; 8.84 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Write the SNN metrics (one row per metric name) into the results directory.
os.makedirs("../results", exist_ok=True)
pd.DataFrame.from_dict(data=outdict_snn, orient="index").to_csv("../results/scGLUE_snn.csv", header=False)

In [None]:
# Write the EOT metrics (one row per metric name) to their own file, so they
# do not share a filename with the SNN results.
os.makedirs("../results", exist_ok=True)
pd.DataFrame.from_dict(data=outdict_eot, orient="index").to_csv("../results/scGLUE_eot.csv", header=False)