
# Subgraph Link Prediction Heuristics

This notebook computes link prediction heuristics for every prepared subgraph under `data/subgraphs`. For each subgraph we report the common neighbors count, Jaccard coefficient, Adamic–Adar index, and a network proximity z-score for all observed edges. Adjust the helper code if you want to look at non-edge pairs or persist the results to disk.


In [5]:

from pathlib import Path
import math
import random

import networkx as nx
import numpy as np
import pandas as pd
from tqdm.auto import tqdm


In [6]:

# Resolve project paths so the notebook works when launched from the repo root or the notebooks folder.
PROJECT_ROOT = Path.cwd().resolve()
if not (PROJECT_ROOT / 'data').exists():
    PROJECT_ROOT = PROJECT_ROOT.parent.resolve()

SUBGRAPH_ROOT = PROJECT_ROOT / 'data' / 'subgraphs'
if not SUBGRAPH_ROOT.exists():
    raise FileNotFoundError(f'Expected subgraphs under {SUBGRAPH_ROOT}')

subgraph_dirs = sorted(p for p in SUBGRAPH_ROOT.glob('*') if p.is_dir())
print(f"Found {len(subgraph_dirs)} subgraph(s):")
for path in subgraph_dirs:
    print(f" - {path.relative_to(PROJECT_ROOT)}")


Found 6 subgraph(s):
 - data/subgraphs/drug-disease
 - data/subgraphs/drug-disease-gene
 - data/subgraphs/drug-disease-gene_protein
 - data/subgraphs/drug-disease-gene_protein-pathway
 - data/subgraphs/drug-disease-gene_protein-pathway-molecular_function
 - data/subgraphs/drug-disease-pathway


In [7]:

def load_subgraph(subgraph_path: Path):
    """Return node and edge tables for a subgraph directory."""
    nodes = pd.read_csv(subgraph_path / 'node.csv', sep='	', dtype={'node_index': str})
    edges = pd.read_csv(
        subgraph_path / 'edges.csv',
        dtype={'x_index': str, 'y_index': str, 'relation': str, 'display_relation': str},
    )
    return nodes, edges


def build_graph(
    nodes: pd.DataFrame,
    edges: pd.DataFrame,
    *,
    label: str | None = None,
) -> nx.Graph:
    """Construct an undirected graph from node and edge tables."""
    graph = nx.Graph()
    node_desc = f"{label}: nodes" if label else "Nodes"
    for row in tqdm(
        nodes.itertuples(index=False),
        total=len(nodes),
        desc=node_desc,
        leave=False,
    ):
        graph.add_node(row.node_index, node_type=row.node_type, node_name=row.node_name)

    edge_desc = f"{label}: edges" if label else "Edges"
    for row in tqdm(
        edges.itertuples(index=False),
        total=len(edges),
        desc=edge_desc,
        leave=False,
    ):
        graph.add_edge(row.x_index, row.y_index, relation=row.relation)
    return graph


def compute_link_metrics(
    graph: nx.Graph,
    nodes: pd.DataFrame,
    edges: pd.DataFrame,
    *,
    proximity_samples: int = 200,
    seed: int = 42,
    label: str | None = None,
) -> pd.DataFrame:
    """Compute link prediction heuristics for every observed edge."""
    prefix = f"{label}: " if label else ""

    type_map = nodes.set_index('node_index')['node_type'].to_dict()
    name_map = nodes.set_index('node_index')['node_name'].to_dict()
    degree_map = dict(graph.degree())
    nodes_by_type = nodes.groupby('node_type')['node_index'].apply(list).to_dict()

    edge_rows = list(edges.itertuples(index=False))
    edge_pairs = [(row.x_index, row.y_index) for row in edge_rows]

    rng = random.Random(seed)
    shortest_path_cache: dict[str, dict[str, int]] = {}

    def shortest_path_length(u: str, v: str) -> float:
        if u not in shortest_path_cache:
            shortest_path_cache[u] = nx.single_source_shortest_path_length(graph, u)
        return shortest_path_cache[u].get(v, math.inf)

    baseline_cache: dict[tuple[str, str], tuple[float, float] | None] = {}

    def proximity_baseline(type_u: str, type_v: str):
        key = (type_u, type_v)
        if key in baseline_cache:
            return baseline_cache[key]

        candidates_u = nodes_by_type.get(type_u, [])
        candidates_v = nodes_by_type.get(type_v, [])
        if not candidates_u or not candidates_v:
            baseline_cache[key] = None
            baseline_cache[(type_v, type_u)] = None
            return None

        samples: list[float] = []
        seen_pairs: set[tuple[str, str]] = set()
        max_attempts = proximity_samples * 10
        attempts = 0

        progress = None
        if proximity_samples > 0:
            desc = f"{prefix}Baseline {type_u}->{type_v}"
            progress = tqdm(total=proximity_samples, desc=desc, leave=False)

        try:
            while len(samples) < proximity_samples and attempts < max_attempts:
                attempts += 1
                u = rng.choice(candidates_u)
                v = rng.choice(candidates_v)
                if u == v:
                    continue
                pair_key = (u, v) if type_u != type_v else tuple(sorted((u, v)))
                if pair_key in seen_pairs:
                    continue
                seen_pairs.add(pair_key)

                dist = shortest_path_length(u, v)
                if math.isinf(dist):
                    continue
                samples.append(dist)
                if progress is not None:
                    progress.update(1)
        finally:
            if progress is not None:
                progress.close()

        if len(samples) < 2:
            baseline_cache[key] = None
            baseline_cache[(type_v, type_u)] = None
            return None

        mean_dist = float(np.mean(samples))
        std_dist = float(np.std(samples, ddof=1)) if len(samples) > 1 else 0.0
        baseline_cache[key] = (mean_dist, std_dist)
        baseline_cache[(type_v, type_u)] = (mean_dist, std_dist)
        return baseline_cache[key]

    jaccard_scores = {}
    for u, v, score in tqdm(
        nx.jaccard_coefficient(graph, edge_pairs),
        total=len(edge_pairs),
        desc=f"{prefix}Jaccard",
        leave=False,
    ):
        jaccard_scores[(u, v)] = score
        jaccard_scores[(v, u)] = score

    adamic_scores = {}
    for u, v, score in tqdm(
        nx.adamic_adar_index(graph, edge_pairs),
        total=len(edge_pairs),
        desc=f"{prefix}Adamic-Adar",
        leave=False,
    ):
        adamic_scores[(u, v)] = score
        adamic_scores[(v, u)] = score

    records = []
    edge_desc = f"{prefix}Scoring edges"
    for row in tqdm(edge_rows, total=len(edge_rows), desc=edge_desc, leave=False):
        u = row.x_index
        v = row.y_index
        source_type = type_map.get(u)
        target_type = type_map.get(v)

        cn_count = sum(1 for _ in nx.common_neighbors(graph, u, v))

        jaccard_score = jaccard_scores.get((u, v), float('nan'))
        adamic_score = adamic_scores.get((u, v), float('nan'))

        sp_length = shortest_path_length(u, v)
        if math.isinf(sp_length):
            proximity_z = None
            baseline_mean = None
            baseline_std = None
            sp_value = None
        else:
            baseline = proximity_baseline(source_type, target_type)
            if baseline is None:
                proximity_z = None
                baseline_mean = None
                baseline_std = None
            else:
                mean_dist, std_dist = baseline
                if std_dist > 0:
                    proximity_z = (sp_length - mean_dist) / std_dist
                else:
                    proximity_z = 0.0
                baseline_mean = mean_dist
                baseline_std = std_dist
            sp_value = sp_length

        records.append(
            {
                'source_index': u,
                'target_index': v,
                'source_type': source_type,
                'target_type': target_type,
                'source_name': name_map.get(u),
                'target_name': name_map.get(v),
                'relation': row.relation,
                'common_neighbors': cn_count,
                'jaccard_coefficient': jaccard_score,
                'adamic_adar_index': adamic_score,
                'shortest_path_length': sp_value,
                'network_proximity_z': proximity_z,
                'proximity_baseline_mean': baseline_mean,
                'proximity_baseline_std': baseline_std,
                'source_degree': degree_map.get(u),
                'target_degree': degree_map.get(v),
            }
        )

    return pd.DataFrame.from_records(records)


In [8]:

results = {}
for subgraph_path in tqdm(subgraph_dirs, desc='Subgraphs', leave=False):
    label = subgraph_path.name
    nodes_df, edges_df = load_subgraph(subgraph_path)
    graph = build_graph(nodes_df, edges_df, label=label)
    metrics_df = compute_link_metrics(graph, nodes_df, edges_df, label=label)
    results[label] = {
        'graph': graph,
        'nodes': nodes_df,
        'edges': edges_df,
        'metrics': metrics_df,
    }
    tqdm.write(f"Computed metrics for {label}: {len(metrics_df)} edge(s)")


Subgraphs:   0%|          | 0/6 [00:00<?, ?it/s]

drug-disease: nodes:   0%|          | 0/25037 [00:00<?, ?it/s]

drug-disease: edges:   0%|          | 0/2822278 [00:00<?, ?it/s]

drug-disease: Jaccard:   0%|          | 0/2822278 [00:00<?, ?it/s]

drug-disease: Adamic-Adar:   0%|          | 0/2822278 [00:00<?, ?it/s]

drug-disease: Scoring edges:   0%|          | 0/2822278 [00:00<?, ?it/s]

drug-disease: Baseline drug->disease:   0%|          | 0/200 [00:00<?, ?it/s]

drug-disease: Baseline drug->drug:   0%|          | 0/200 [00:00<?, ?it/s]

drug-disease: Baseline disease->disease:   0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:

# Preview the first few rows for one subgraph.
example_key = next(iter(results))
results[example_key]['metrics'].head()


In [None]:

# Optional: uncomment to export the metrics tables to CSV files.
# output_dir = PROJECT_ROOT / 'notebooks' / 'outputs'
# output_dir.mkdir(parents=True, exist_ok=True)
# for name, payload in results.items():
#     payload['metrics'].to_csv(output_dir / f'{name}_link_metrics.csv', index=False)
# print(f'Saved metrics to {output_dir}')


In [1]:
# --- Typed Jaccard baseline for TxGNN-style link prediction ---
# Compute Jaccard scores for held-out drug–disease pairs using train/val/test CSVs.
# Author: GPT-5 (network-analysis project)
# -------------------------------------------------------------------------------

import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.metrics import average_precision_score
from collections import defaultdict
from typing import Dict, List, Tuple, Set
from tqdm import tqdm   # <--- added tqdm import

# --- CONFIG --------------------------------------------------------------------
train_path = "../data/subgraphs/drug-disease/full_graph_42/train.csv"
val_path   = "../data/subgraphs/drug-disease/full_graph_42/valid.csv"
test_path  = "../data/subgraphs/drug-disease/full_graph_42/test.csv"
out_path   = "../data/subgraphs/drug-disease/scores_jaccard.csv"

neighbor_types = ["gene/protein", "phenotype", "exposure", "disease", "pathway", "anatomy"]
neg_per_pos = 1        # negatives per positive
rng = np.random.default_rng(7)

# --- LOAD & NORMALIZE ----------------------------------------------------------
def normalize_edges(df):
    df = df.rename(columns={
        "x_id": "src", "y_id": "dst",
        "x_type": "src_type", "y_type": "dst_type",
        "relation": "relation"
    })
    return df[["src", "dst", "src_type", "dst_type", "relation"]]

train_df = normalize_edges(pd.read_csv(train_path))
val_df   = normalize_edges(pd.read_csv(val_path))
test_df  = normalize_edges(pd.read_csv(test_path))

def make_undirected(df):
    df_rev = df.rename(columns={"src": "dst", "dst": "src",
                                "src_type": "dst_type", "dst_type": "src_type"})
    return pd.concat([df, df_rev], ignore_index=True)

train_undir = make_undirected(train_df)

# --- INDEXING ------------------------------------------------------------------
def build_node_index(train_df):
    nodes = pd.Index(pd.concat([train_df["src"], train_df["dst"]]).unique())
    return {n: i for i, n in enumerate(nodes)}

node_index = build_node_index(train_undir)
n_nodes = len(node_index)

# --- BUILD INCIDENCE MATRICES BY TYPE ------------------------------------------
def build_type_col_index(train_df, node_index, neighbor_types):
    type_to_entities = {t: set() for t in neighbor_types}
    for _, r in tqdm(train_df.iterrows(), total=len(train_df), desc="Indexing types"):
        if r["dst_type"] in type_to_entities:
            type_to_entities[r["dst_type"]].add(r["dst"])
        if r["src_type"] in type_to_entities:
            type_to_entities[r["src_type"]].add(r["src"])
    return {t: {e: j for j, e in enumerate(sorted(ents))}
            for t, ents in type_to_entities.items() if ents}

def build_incidence_by_type(train_df, node_index, type_col_index):
    n_nodes = len(node_index)
    A_types = {}
    for t, col_index in type_col_index.items():
        rows, cols = [], []
        for _, r in tqdm(train_df.iterrows(), total=len(train_df), desc=f"Building {t}"):
            u = node_index.get(r["src"])
            if u is None: continue
            if r["dst_type"] == t and r["dst"] in col_index:
                rows.append(u)
                cols.append(col_index[r["dst"]])
        data = np.ones(len(rows), dtype=np.uint8)
        A_types[t] = csr_matrix((data, (rows, cols)),
                                shape=(n_nodes, len(col_index)),
                                dtype=np.uint8)
    return A_types

type_col_index = build_type_col_index(train_undir, node_index, neighbor_types)
A_types = build_incidence_by_type(train_undir, node_index, type_col_index)
deg_cache = {t: np.asarray(A.sum(axis=1)).ravel() for t, A in A_types.items()}

# --- JACCARD COMPUTATION -------------------------------------------------------
def typed_jaccard_for_pairs(pairs, A_types, deg_cache):
    scores = np.zeros(len(pairs))
    for i, (u, v) in tqdm(enumerate(pairs), total=len(pairs), desc="Jaccard"):
        num = 0.0
        den = 0.0
        for t, A in A_types.items():
            cn = A[u].multiply(A[v]).sum()
            du, dv = deg_cache[t][u], deg_cache[t][v]
            num += cn
            den += (du + dv - cn)
        scores[i] = (num / den) if den > 0 else 0.0
    return scores

# --- POSITIVE / NEGATIVE EDGE EXTRACTION --------------------------------------
def pick_drug_disease_edges(df):
    mask1 = (df["src_type"] == "drug") & (df["dst_type"] == "disease")
    mask2 = (df["src_type"] == "disease") & (df["dst_type"] == "drug")
    sub = df[mask1 | mask2].copy()
    sub["drug"] = np.where(sub["src_type"] == "drug", sub["src"], sub["dst"])
    sub["disease"] = np.where(sub["src_type"] == "drug", sub["dst"], sub["src"])
    return sub[["drug", "disease", "relation"]]

test_dd = pick_drug_disease_edges(test_df)
val_dd  = pick_drug_disease_edges(val_df)

pos_df = pd.concat([
    test_dd.assign(split="test"),
    val_dd.assign(split="val")
])

def map_idx(n): return node_index.get(n, -1)
pos_pairs_idx = [(map_idx(r.drug), map_idx(r.disease)) for r in pos_df.itertuples()]

# Negatives: random per disease
def sample_negatives(drug_ids, disease_ids, existing_edges, k_per_pos):
    negs = []
    for d in tqdm(disease_ids, desc="Sampling negatives"):
        for _ in range(k_per_pos):
            tries = 0
            while tries < 20:
                cand = rng.choice(drug_ids)
                if (cand, d) not in existing_edges and (d, cand) not in existing_edges:
                    negs.append((cand, d))
                    break
                tries += 1
    return negs

all_edges = pd.concat([train_df, val_df, test_df])
edge_set = set(map(tuple, pick_drug_disease_edges(all_edges)[["drug", "disease"]].values))
drugs = sorted(set(pd.concat([train_df[train_df["src_type"]=="drug"]["src"],
                              train_df[train_df["dst_type"]=="drug"]["dst"]]).unique()))
diseases = sorted(pos_df["disease"].unique())
neg_pairs = sample_negatives(drugs, diseases, edge_set, neg_per_pos)

neg_rows = pd.DataFrame(neg_pairs, columns=["drug", "disease"])
neg_rows["relation"] = "NEG"
neg_rows["split"] = "neg"
neg_pairs_idx = [(map_idx(r.drug), map_idx(r.disease)) for r in neg_rows.itertuples()]

# --- COMPUTE SCORES ------------------------------------------------------------
pos_scores = typed_jaccard_for_pairs(pos_pairs_idx, A_types, deg_cache)
neg_scores = typed_jaccard_for_pairs(neg_pairs_idx, A_types, deg_cache)

out_df = pd.concat([
    pos_df.assign(score=pos_scores),
    neg_rows.assign(score=neg_scores)
], ignore_index=True)

# --- EVALUATE ------------------------------------------------------------------
y_true = (out_df["relation"] != "NEG").astype(int)
y_pred = out_df["score"]
auprc = average_precision_score(y_true, y_pred)
print(f"AUPRC (Jaccard baseline): {auprc:.4f}")

# --- SAVE ---------------------------------------------------------------------
out_df.to_csv(out_path, index=False)
print(f"Saved results to {out_path}")


  train_df = normalize_edges(pd.read_csv(train_path))
  val_df   = normalize_edges(pd.read_csv(val_path))
  test_df  = normalize_edges(pd.read_csv(test_path))
Indexing types:   4%|▍         | 606785/15390948 [00:07<03:07, 78782.58it/s]


KeyboardInterrupt: 