In [1]:
import os
import sys
import pandas as pd
import numpy as np
import glob
import time
import gget
import scipy
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from importlib import reload

import sys
sys.path.append('../../Hypergraph-MT/code/HyMT')
sys.path.append('../../Hypergraph-MT/code/')
import HyMT as hymt

# Load the core

In [2]:
start_time = time.time()

# load the hypergraph:
fpath = "/scratch/indikar_root/indikar1/shared_data/higher_order/transcription_clusters/core_incidence_1000000_protien_coding_only.pkl"
H = pd.read_pickle(fpath)

end_time = time.time()
total_time = end_time - start_time

print(f"{H.shape=}")
print(f"Hypergraph loaded in: {total_time:.2f} seconds")

print("\nFirst 5 rows of the incidence matrix:")
H.head()

H.shape=(17186, 34592)
Hypergraph loaded in: 2.19 seconds

First 5 rows of the incidence matrix:


Unnamed: 0,66953ddf-e76d-4cdf-aaf8-be028a2d7b04,c160a170-5af7-412b-9c03-36dfef017384,3b0686b6-f18f-495d-89f5-8c8b286c2bb1,77bc1796-a0a3-4140-a97b-d4a786d17cb2,4f913e8a-799a-488e-a7bd-6ae9566e5c37,da8a0dfa-5deb-48c3-bf6e-bde5534e0578,a425bdc9-37ea-4020-bc7c-5085fb99a3c7,8ad994fd-c214-46f9-99b0-37c2b3f2946e,2ea6e55f-cc78-418a-b241-f134009153a0,05790af8-be74-4b99-8d1b-49074fa8f81d,...,05fc8d13-3610-4bb0-b173-a908dd526cdd,925b2134-befc-44e1-a9c5-97ca1295c96c,2f5a483f-7f31-4028-b91e-8f8c83a5b922,43cef5b9-05dc-4d4a-b20b-349a91ae2224,49c8ea45-81cb-4f76-a7ad-bd753e8f8c7c,eae8359f-2057-4492-93d1-10437e892f0b,fcf0a060-2833-4ff0-a352-d5e702f27f46,6199d009-7ef7-44f6-b10c-c0ac846f362c,0e705646-2f32-40f9-bad9-fbf5f7ef9d79,9f9a0a4e-630b-406a-bb24-7026948c9787
Smarca2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Sh2d5,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Dyrk1a,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Igf2bp3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Tmem267,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [10]:
(H.sum() > 8).value_counts(normalize=True)

False    0.924289
True     0.075711
Name: proportion, dtype: float64

In [4]:
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

# data set up

In [None]:
def get_hyperedges(H, edge_weights=None, sample_size=1000, max_order=5):
    """
    Extracts hyperedges from an incidence matrix H.
    
    This function samples columns from the incidence matrix, identifies hyperedges 
    (sets of nodes connected by a hyperedge), and optionally filters them based on their order.
    
    Args:
      H: pandas DataFrame, the incidence matrix where rows represent nodes 
         and columns represent hyperedges.
      edge_weights: numpy array, optional weights for each hyperedge. If None, 
                    all hyperedges are assigned a weight of 1.
      sample_size: int, the number of columns (hyperedges) to sample from H.
      max_order: int, the maximum allowed hyperedge order (number of nodes in a hyperedge).
                 If None, no filtering is performed.
    
    Returns:
      tuple: (B, A, hyperedges, edge_idx)
        B: pandas DataFrame, the sampled incidence matrix.
        A: numpy array, hyperedge weights.
        hyperedges: list of tuples, each tuple representing a hyperedge with node indices.
        edge_idx: list of int, indices of the hyperedges that pass the filtering criteria.
    """
    B = H.sample(sample_size, axis=1)  # Randomly sample columns
    B = B[B.sum(axis=1) > 0]  # Keep rows with at least one positive entry

    if edge_weights is None:
        A = np.ones(sample_size)

    hyperedges = B.apply(lambda x: tuple(x[x > 0].index), axis=0).to_numpy()  

    # Hyperedge filtering
    orders = [len(e) for e in hyperedges]
    if max_order:
        edge_idx = [eid for eid, d in enumerate(orders) if 2 <= d <= max_order]
    else:
        edge_idx = [eid for eid, _ in enumerate(orders)]

    return B, A, hyperedges, edge_idx

B, A, hyperedges, edge_idx = get_hyperedges(
    H, 
    edge_weights=None,
    sample_size=1000,
    max_order=7,
)

print(f"Incidence matrix B shape: {B.shape}")
print(f"Hyperedge weight vector A shape: {A.shape}")
print(f"Number of hyperedges found: {len(hyperedges)}")
print(f"Keeping {len(edge_idx)} out of {len(hyperedges)} hyperedges after filtering ({100*(len(edge_idx) / len(hyperedges)):.2f}%).")

# model set up

In [None]:
conf_inf = {
    "seed": 10,
    "constraintU": False,
    "fix_communities": False,
    "fix_w": False,
    "gammaU": 0,
    "gammaW": 0,
    "initialize_u": None,  # Use None for null
    "initialize_w": None,  # Use None for null
    "out_inference": False,
    "plot_loglik": True,
}

# Run the model

In [None]:
K = 50 # total number of communities to detect

start_time = time.time()
model = hymt.model.HyMT(
    verbose=False,
    num_realizations=3,
)
u, w, maxL = model.fit(
    A[edge_idx], 
    hyperedges[edge_idx], 
    B.to_numpy()[:, edge_idx],
    K=K,
    **conf_inf,
)

end_time = time.time()
total_time = end_time - start_time
print(f"\nTime elapsed: {total_time:.2f} seconds")
print()

print(f"---- results ----")
print(f"Membership matrix (u) with {K=}: {u.shape=}")
print(f"Affinity matrix (w): {w.shape=}")
print(f"Maximum log-likelihood value (maxL): {maxL=}")

In [None]:
train = model.train_info
train.head()

In [None]:
break

# rank hyperedges by their log-liklihoods under the trainned model

In [None]:
def predict_hyperedge(hyperedge, u, w, index):
    """Calculates the probability of a hyperedge being non-zero.

    Args:
      hyperedge: The hyperedge.
      u: Membership matrix.
      w: Affinity matrix.
      index: Index for the DataFrame.

    Returns:
      A tuple containing the calculated value (M) and the probability.
    """
    u = pd.DataFrame(u, index=index)
    M = (np.prod(u.loc[np.array(hyperedge)], axis=0) * w[len(hyperedge) - 2]).sum()
    proba = 1.0 - np.exp(-M)
    return M, proba

results = []

for hyperedge in hyperedges[edge_idx]:
    p, M = predict_hyperedge(hyperedge, u, w, B.index)
    results.append({
        'hyperedge' : "-".join(hyperedge),
        'p' : p,
        'M' : M,
    })

results = pd.DataFrame(results)
results = results.sort_values(by='p', ascending=False).reset_index(drop=True)
print(f"{results.shape=}")
results.head(15)

In [None]:
sleep = 2
# alpha = 0.1
n_top = 5
database = 'ontology'

pd.set_option('display.max_colwidth', 50) 

for idx, row in results.head(n_top).iterrows():
    time.sleep(sleep)
    query = row['hyperedge'].split("-")
    edf = gget.enrichr(query, database=database)
    # edf = edf[edf['adj_p_val'] <= alpha]
    edf['n_overlap'] = edf['overlapping_genes'].apply(len)
    edf = edf[edf['n_overlap'] > 1]
    
    if not edf.empty:
        print(f"Hyperedge: {row['hyperedge']}")
        print(edf[['path_name', 'overlapping_genes']].head().to_markdown(index=False, numalign="left", stralign="left")) 
    else:
        print(f"No enrichment: {row['hyperedge']}")

# Sweep values of K

In [None]:
B, A, hyperedges, edge_idx = get_hyperedges(
    H, 
    edge_weights=None,
    sample_size=500,
    max_order=5,
)

print(f"Incidence matrix B shape: {B.shape}")
print(f"Hyperedge weight vector A shape: {A.shape}")
print(f"Number of hyperedges found: {len(hyperedges)}")
print(f"Keeping {len(edge_idx)} out of {len(hyperedges)} hyperedges after filtering ({100*(len(edge_idx) / len(hyperedges)):.2f}%).")

In [None]:
failure_tolerance = 2
max_k = 200
steps = 30

conf_inf = {
    "seed": 10,
    "constraintU": False,
    "fix_communities": False,
    "fix_w": False,
    "gammaU": 0,
    "gammaW": 0,
    "initialize_u": None,  # Use None for null
    "initialize_w": None,  # Use None for null
    "out_inference": False,
    "plot_loglik": False,
}

maximums = {}
results = []

failures = 0
for K in np.linspace(2, max_k, steps).astype(int):
    print(f"---- runnning with {K=}----")    
    model = hymt.model.HyMT(
        verbose=False,
        num_realizations=1,
    )
    try:
        u, w, maxL = model.fit(
            A[edge_idx], 
            hyperedges[edge_idx], 
            B.to_numpy()[:, edge_idx],
            K=K,
            **conf_inf,
        )
    except:
        failures += 1
        if failures < failure_tolerance:
            continue
        else:
            print(f"ITERATION FAILED AT {K=}")
            break

    # get the log likelihood over trainning  
    train = model.train_info.groupby('iter')['loglik'].max()
    train = train.reset_index(drop=False)
    train['K'] = K
    results.append(train)

    # store the maximum L
    maximums[K] = maxL

# full trainning results
results = pd.concat(results)
results.head()

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = (steps // 16), 2.5

sns.lineplot(
    maximums,
    marker="o",
    markersize=3,
    lw=0.75,
    mec='k',
    color='blue',
)

plt.ylabel("max log liklihood")
plt.xlabel("K")

sns.despine()

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 4, 4

sns.lineplot(
    data=results,
    x='iter',
    y='loglik',
    hue='K',
    marker=".",
    lw=0.75,
    palette='Set1'
)

plt.ylabel("log liklihood")
plt.xlabel("iteration")

sns.despine()

In [None]:
break

# ARCHIVE

In [None]:
if bool(args.baselines):
    """ Baseline1: Run the model on the graph obtained by clique expansions (Graph-MT) """
    time_GrMT = time.time()
    if verbose:
        print('\n### Run Graph-MT ###')
    if conf_inf['out_inference']:
        conf_inf['end_file'] = init_end_file + '_GrMT'
    A2, hye2, B2 = tl.extract_input_pairwise(A[hyL2], hye[hyL2], N)  # get the graph by clique expansions
    model2 = hymt.model.HyMT()
    _ = model2.fit(A2, hye2, B2, K=K, **conf_inf)
    if verbose:
        print(f'\nTime elapsed: {np.round(time.time() - time_GrMT, 2)} seconds.')

    """ Baseline2: Run the model on the graph given by the subset of pairwise interactions (Pairs-MT) """
    time_PaMT = time.time()
    if verbose:
        print('\n### Run Pairs-MT ###')
    if conf_inf['out_inference']:
        conf_inf['end_file'] = init_end_file + '_PaMT'
    mask_pairs = np.array([False if len(e) != 2 else True for e in hye])  # keep only the subset of pairs
    if sum(mask_pairs) > 0:
        model3 = hymt.model.HyMT()
        _ = model3.fit(A[mask_pairs], hye[mask_pairs], B[:, mask_pairs], K=K, **conf_inf)
        if verbose:
            print(f'\nTime elapsed: {np.round(time.time() - time_PaMT, 2)} seconds.')