In [None]:
#### INSTRUCTIONS ####

# 1) Update EC_DIR (below) to reflect the location of the challenge files in your Google Drive as required.

    # For example:

        # 1) Navigate to Google Drive: https://drive.google.com/drive/my-drive and log in
        # 2) Create a folder called 'ecc_files' in the root directory (should be called 'MyDrive')
        # 3) Upload the files to the folder
        
# 2) Connect to a runtime (a free CPU runtime works if you have no compute units available for a GPU, but the notebook will take noticeably longer to run)

# 3) Run the rest of the notebook (you will be prompted to grant access to your Google Drive)

EC_DIR = "/content/drive/MyDrive/ecc_files"

In [None]:
############################################################
#  Setup: Install Dependencies and Mount Drive (Colab)
############################################################

# Uses the PyTorch Geometric implementation of Node2Vec to leverage GPU support.
# Its compiled dependencies are installed from pre-built wheels to speed up setup.

import os # noqa: F401
import torch # noqa: F401
from google.colab import drive

# Install PyTorch Geometric dependencies compatible with the current PyTorch version
def install_pytorch_geometric():
    TORCH_VERSION = torch.__version__.split("+")[0]  # noqa: F841
    CUDA_VERSION = torch.version.cuda.replace(".", "")  # noqa: F841
    base_url = "https://data.pyg.org/whl"  # noqa: F841
    
    # Install each dependency from the PyTorch Geometric library
    !pip install torch-scatter -f {base_url}/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
    !pip install torch-sparse -f {base_url}/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
    !pip install torch-cluster -f {base_url}/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
    !pip install torch-spline-conv -f {base_url}/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
    !pip install pyg-lib -f {base_url}/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
    !pip install torch-geometric

install_pytorch_geometric()

# Mount Google Drive
drive.mount("/content/drive")


In [None]:
############################################################
#  Import remaining libraries
############################################################

import pickle
import itertools
import pandas as pd
import numpy as np
from typing import Tuple, List, Any, Optional

import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from torch import nn
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModel

from torch_geometric.nn import Node2Vec
from torch_geometric.data import Data as PyGData
from torch_geometric.utils import from_networkx


In [None]:
############################################################
# 1) Device Settings
############################################################
# Use GPU if available

# Select the compute device once; models and tensors below are moved to DEVICE.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")

############################################################
# 2) Model & Embedding Dimensions
############################################################
# Name of the pre-trained model to use for semantic embeddings
# MODEL_NAME_SEMANTIC = "dmis-lab/biobert-v1.1"
# MODEL_NAME_SEMANTIC = "michiyasunaga/BioLinkBERT-base"
# MODEL_NAME_SEMANTIC = "sultan/BioM-BERT-PubMed-PMC-Large"
MODEL_NAME_SEMANTIC = "bioformers/bioformer-16L"

# Keep only the part after the last "/" so the name can be embedded in file names
# without introducing an extra directory level
MODEL_NAME = MODEL_NAME_SEMANTIC.split("/")[-1]

# Dimensionality of the semantic embeddings
# NOTE(review): in the code shown here this constant only feeds EMBED_DIM_HYBRID and
# the summary print below; it is never checked against the model's real hidden size.
# Confirm it matches the hidden size of MODEL_NAME_SEMANTIC.
EMBED_DIM_SEMANTIC = 768

# Dimensionality of the structural embeddings
EMBED_DIM_STRUCT = 128

# Combined dimensionality for the hybrid embeddings (semantic + structural)
EMBED_DIM_HYBRID = EMBED_DIM_SEMANTIC + EMBED_DIM_STRUCT

############################################################
# 3) Data & I/O Directories
############################################################

# Paths to the CSV files containing nodes, edges and ground truth information
NODES_CSV = os.path.join(EC_DIR, "Nodes.csv")
EDGES_CSV = os.path.join(EC_DIR, "Edges.csv")
GROUND_TRUTH_CSV = os.path.join(EC_DIR, "Ground Truth.csv")

# File path to store or load structural embeddings
# NOTE(review): the file name does not encode the split, the seed or the Node2Vec
# hyperparameters, and main() reuses the file whenever it exists. Delete it after
# changing any of those, otherwise stale embeddings are loaded.
STRUCT_EMB_PATH = os.path.join(EC_DIR, "structural_embeddings.pkl")

# File path to store or load semantic embeddings (one cache file per model name)
SEMANTIC_EMB_PATH = os.path.join(EC_DIR, f"{MODEL_NAME}_semantic_embeddings.pkl")

# File path to store hybrid embeddings (semantic + structural)
HYBRID_EMB_PATH = os.path.join(EC_DIR, f"{MODEL_NAME}_hybrid_embeddings.pkl")

############################################################
# 4) Node2Vec Hyperparameters
############################################################

# NODE2VEC_EMB_DIM:
#   The dimensionality of the learned node embeddings. Larger dimensions can
#   capture more nuanced relationships but require more memory and can risk overfitting.
NODE2VEC_EMB_DIM = EMBED_DIM_STRUCT

# NODE2VEC_WALK_LENGTH:
#   The number of steps in each random walk. A higher walk length lets the model
#   capture more distant structural patterns but increases computational cost.
NODE2VEC_WALK_LENGTH = 20

# NODE2VEC_CONTEXT_SIZE:
#   The "window size" for the skip-gram model: how many nodes of a walk are
#   treated as one context. A larger context includes more neighbourhood
#   information but may dilute the most local signals.
#   NOTE(review): the value is passed straight to Node2Vec's `context_size`; check the
#   library documentation for whether the window is one-sided or two-sided.
NODE2VEC_CONTEXT_SIZE = 10

# NODE2VEC_WALKS_PER_NODE:
#   How many random walks to start from each node. More walks can lead to richer
#   co-occurrence statistics for skip-gram training, but also increase runtime.
NODE2VEC_WALKS_PER_NODE = 20

# NODE2VEC_EPOCHS:
#   The number of epochs (full passes) over all generated random walks during
#   training. Too few may underfit; too many risks overfitting or diminishing returns.
NODE2VEC_EPOCHS = 10

# NODE2VEC_LR:
#   The learning rate for optimising the Node2Vec model (here using SparseAdam).
#   Higher values converge faster but can be unstable; lower values are more stable
#   but slower to converge.
NODE2VEC_LR = 0.01

# NODE2VEC_BATCH_SIZE:
#   The batch size for training on random walk "samples" in skip-gram.
#   Larger batches can be faster on GPUs but need more memory; smaller batches
#   can sometimes generalise better but may take longer to train.
NODE2VEC_BATCH_SIZE = 128


############################################################
# 5) Semantic Embedding Hyperparameters
############################################################
# Batch size for processing texts when generating semantic embeddings
SEMANTIC_BATCH_SIZE = 128

# Maximum sequence length when tokenising texts for semantic embeddings
# (longer texts are truncated by the tokenizer call in SemanticEmbedder)
SEMANTIC_MAX_LENGTH = 128

############################################################
# 6) Classifier Hyperparameters
############################################################
# Neural network (MLP) training: default number of epochs
# (the hyperparameter search passes its own epochs/lr and only reuses the batch size)
TRAIN_CLASSIFIER_EPOCHS = 5

# Neural network (MLP) training: default learning rate
TRAIN_CLASSIFIER_LR = 1e-3

# Neural network (MLP) training: default batch size
TRAIN_CLASSIFIER_BATCH_SIZE = 256

# Logistic Regression: maximum iterations for solver convergence
LOGREG_MAX_ITER = 1000

############################################################
# 7) Splits & Seeds
############################################################
# Proportion of dataset to be used as the test set
TEST_SPLIT = 0.20

# Proportion of the remaining (non-test) data to be used for validation
VAL_SPLIT = 0.20

# Seed for random number generators (reproducibility)
RANDOM_SEED = 42

############################################################
# 8) MLP Hyperparameter Search
############################################################
# The search tries every combination of the three lists below.
# MLP hyperparameter search: possible hidden dimensions
MLP_HIDDEN_DIMS = [128, 256]

# MLP hyperparameter search: possible learning rates
MLP_LRS = [1e-3, 5e-4]

# MLP hyperparameter search: possible epochs
MLP_EPOCH_CHOICES = [5, 10]

# Echo the key settings so they are recorded in the cell output
print("=== Parameter Configuration ===")
print(f"DEVICE: {DEVICE}")
print(f"MODEL_NAME_SEMANTIC: {MODEL_NAME_SEMANTIC}")
print(f"EMBED_DIM_SEMANTIC: {EMBED_DIM_SEMANTIC}")
print(f"EMBED_DIM_STRUCT: {EMBED_DIM_STRUCT}")
print(f"EMBED_DIM_HYBRID: {EMBED_DIM_HYBRID}")
print("================================\n")


############################################################
# Step 1: Data Loading & Preprocessing
############################################################
def load_data(nodes_csv: str, edges_csv: str, gt_csv: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Reads the three input CSV files into dataframes.

    :param nodes_csv: File path to the nodes CSV.
    :param edges_csv: File path to the edges CSV.
    :param gt_csv: File path to the ground truth CSV.
    :return: The nodes, edges and ground truth dataframes, in that order.
    """
    csv_paths = (nodes_csv, edges_csv, gt_csv)
    # Files are read in argument order; each uses pandas' default CSV parsing
    nodes_df, edges_df, gt_df = (pd.read_csv(path) for path in csv_paths)
    return nodes_df, edges_df, gt_df


def build_node_index(nodes_df: pd.DataFrame) -> Tuple[dict, List[Any], List[str]]:
    """
    Builds node index mappings and extracts the text used to embed each node.

    The text for a node is its "name" and "description" joined by a space.
    Missing cells (NaN/None) and absent columns count as empty text, so a node
    without a description is embedded from its name alone rather than from the
    literal string "nan". Nodes with no text at all get "No description".

    :param nodes_df: DataFrame containing node information; must have an "id" column.
    :return: A tuple of (node-to-index dictionary, index-to-node list, list of node texts).
    """
    idx2node = nodes_df["id"].tolist()
    node2idx = {nid: i for i, nid in enumerate(idx2node)}

    def column_as_text(column: str) -> List[str]:
        # One string per row; an absent column or a missing cell yields ""
        if column not in nodes_df.columns:
            return [""] * len(nodes_df)
        return ["" if pd.isna(value) else str(value) for value in nodes_df[column]]

    node_texts = []
    for name, desc in zip(column_as_text("name"), column_as_text("description")):
        text = (name + " " + desc).strip()
        node_texts.append(text if text else "No description")

    return node2idx, idx2node, node_texts


############################################################
# Step 2: Structural Embeddings via PyTorch Geometric Node2Vec
############################################################
def build_graph_pyg(edges_df: pd.DataFrame, node2idx: dict) -> PyGData:
    """
    Converts the edge list into an undirected PyTorch Geometric graph over node indices.

    :param edges_df: DataFrame containing edges information with columns 'subject' and 'object'.
    :param node2idx: A dictionary mapping node IDs to indices.
    :return: A PyGData graph object.
    """
    graph = nx.Graph()
    # Register every index up front so nodes without any edge are still part of the graph
    graph.add_nodes_from(node2idx.values())

    # Edges whose endpoints are not both known nodes are skipped
    known_edges = [
        (node2idx[edge.subject], node2idx[edge.object])
        for edge in edges_df.itertuples(index=False)
        if edge.subject in node2idx and edge.object in node2idx
    ]
    graph.add_edges_from(known_edges)

    return from_networkx(graph)


def generate_node2vec_embeddings_pyg(
    data: PyGData,
    embedding_dim: int = NODE2VEC_EMB_DIM,
    walk_length: int = NODE2VEC_WALK_LENGTH,
    context_size: int = NODE2VEC_CONTEXT_SIZE,
    walks_per_node: int = NODE2VEC_WALKS_PER_NODE,
    epochs: int = NODE2VEC_EPOCHS,
    lr: float = NODE2VEC_LR,
    batch_size: int = NODE2VEC_BATCH_SIZE,
) -> np.ndarray:
    """
    Generates structural embeddings using PyTorch Geometric's Node2Vec.

    The embedding table is sized from ``data.num_nodes`` rather than inferred from
    the edge index, so nodes without any edge still receive an embedding row and
    the result stays row-aligned with the node index.

    :param data: PyGData graph.
    :param embedding_dim: Dimensionality of the embeddings.
    :param walk_length: Length of each random walk.
    :param context_size: Context size for Skip-Gram.
    :param walks_per_node: Number of random walks per node.
    :param epochs: Number of training epochs.
    :param lr: Learning rate for the optimiser.
    :param batch_size: Batch size for training.
    :return: Numpy array of node embeddings with one row per node in ``data``.
    """
    print("Initialising PyTorch Geometric Node2Vec...")
    node2vec = Node2Vec(
        edge_index=data.edge_index,
        embedding_dim=embedding_dim,
        walk_length=walk_length,
        context_size=context_size,
        walks_per_node=walks_per_node,
        num_negative_samples=1,
        # Without this the node count is inferred from the largest index in edge_index,
        # silently dropping trailing nodes that have no edges
        num_nodes=data.num_nodes,
        sparse=True,
    ).to(DEVICE)

    # sparse=True produces sparse gradients, which SparseAdam is designed to consume
    optimiser = torch.optim.SparseAdam(node2vec.parameters(), lr=lr)

    print("Training Node2Vec embeddings...")
    node2vec.train()
    loader = node2vec.loader(batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(epochs):
        # Sum of per-batch losses (not an average) for this epoch
        total_loss = 0.0
        for pos_rw, neg_rw in loader:
            pos_rw, neg_rw = pos_rw.to(DEVICE), neg_rw.to(DEVICE)
            optimiser.zero_grad()
            loss = node2vec.loss(pos_rw, neg_rw)
            loss.backward()
            optimiser.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss:.4f}")

    print("Extracting embeddings...")
    node2vec.eval()
    embeddings = node2vec.embedding.weight.cpu().detach().numpy()
    print("Embeddings generated successfully.")
    return embeddings


############################################################
# Step 3: Semantic Embeddings
############################################################
class SemanticEmbedder:
    """
    Generates semantic embeddings for texts using a pre-trained transformer model.

    The embedding of a text is the final hidden state of its first token.
    """

    def __init__(self, model_name: str) -> None:
        """
        Initialise the semantic embedder.

        Loads the tokenizer and model for ``model_name``, moves the model to
        DEVICE and switches it to evaluation mode.

        :param model_name: Pre-trained model name.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModel.from_pretrained(model_name).to(DEVICE)
        self.model.eval()

    def encode_texts(
        self,
        texts: List[str],
        batch_size: int = SEMANTIC_BATCH_SIZE,
        max_length: int = SEMANTIC_MAX_LENGTH,
    ) -> np.ndarray:
        """
        Encodes a list of texts into semantic embeddings.

        :param texts: List of textual descriptions.
        :param batch_size: Batch size for processing.
        :param max_length: Maximum sequence length for tokenisation.
        :return: Float32 numpy array with one row per text; an empty list yields
            an array of shape (0, hidden_size).
        """
        if len(texts) == 0:
            # np.concatenate cannot handle an empty list of batches
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        # Mixed precision only applies on CUDA; enabling it on CPU just emits a warning
        use_amp = DEVICE == "cuda"

        all_embeddings = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            enc = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(DEVICE)
            with torch.no_grad(), autocast(enabled=use_amp):
                outputs = self.model(**enc)
                # First-token hidden state is used as the text representation
                cls_emb = outputs.last_hidden_state[:, 0, :]
            # Cast so the result is float32 regardless of what autocast produced
            all_embeddings.append(cls_emb.float().cpu().numpy())
        return np.concatenate(all_embeddings, axis=0)


############################################################
# Step 4: Combine Structural & Semantic Embeddings
############################################################
def build_hybrid_embeddings(semantic_emb: np.ndarray, structural_emb: np.ndarray) -> np.ndarray:
    """
    Concatenates semantic and structural embeddings feature-wise.

    Both inputs must describe the same nodes in the same order, i.e. have the
    same number of rows.

    :param semantic_emb: Semantic embeddings as a numpy array.
    :param structural_emb: Structural embeddings as a numpy array.
    :return: Hybrid embeddings as a concatenated numpy array (semantic columns first).
    :raises ValueError: If the two arrays do not have the same number of rows.
    """
    semantic_emb = np.asarray(semantic_emb)
    structural_emb = np.asarray(structural_emb)

    # Fail with a message that names both inputs instead of numpy's generic shape error
    if semantic_emb.shape[:1] != structural_emb.shape[:1]:
        raise ValueError(
            "Cannot build hybrid embeddings: semantic embeddings have shape "
            f"{semantic_emb.shape} but structural embeddings have shape "
            f"{structural_emb.shape}; both must have one row per node."
        )
    return np.concatenate((semantic_emb, structural_emb), axis=1)


############################################################
# Step 5: Prepare Dataset for Link Classification
############################################################
def prepare_dataset(gt_df: pd.DataFrame, node2idx: dict, embeddings: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Prepares features and labels for the link classification task.

    The feature vector of a pair is the source node's embedding followed by the
    target node's embedding. Pairs with an endpoint missing from ``node2idx``
    get an all-zero feature vector.

    :param gt_df: Ground truth dataframe containing 'source', 'target' and label 'y'.
    :param node2idx: A dictionary mapping node IDs to indices.
    :param embeddings: Node embeddings as a 2-D numpy array (one row per node index).
    :return: Tuple of features with shape (len(gt_df), 2 * embeddings.shape[1]) and
        float labels with shape (len(gt_df),).
    """
    pairs = gt_df[["source", "target"]].values
    labels = gt_df["y"].values.astype(float)

    feat_dim = embeddings.shape[1]
    # Pre-allocating keeps the result 2-D for an empty gt_df and leaves rows of
    # unknown pairs at zero; the dtype always follows the embeddings
    X = np.zeros((len(pairs), 2 * feat_dim), dtype=embeddings.dtype)
    for row, (src, tgt) in enumerate(pairs):
        src_idx = node2idx.get(src)
        tgt_idx = node2idx.get(tgt)
        if src_idx is None or tgt_idx is None:
            continue
        X[row, :feat_dim] = embeddings[src_idx]
        X[row, feat_dim:] = embeddings[tgt_idx]
    return X, labels


############################################################
# Step 6: Classification & Evaluation
############################################################
class LinkClassifier(nn.Module):
    """
    Two-layer perceptron that maps a pair feature vector to a single logit.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 128) -> None:
        """
        Build the network: Linear -> ReLU -> Linear(1).

        :param in_dim: Input dimensionality.
        :param hidden_dim: Hidden layer dimensionality.
        """
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the raw (pre-sigmoid) score for each input row.

        :param x: Input tensor.
        :return: Output tensor with one logit per row.
        """
        logits = self.mlp(x)
        return logits


def train_classifier(
    model: nn.Module,
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    epochs: int = TRAIN_CLASSIFIER_EPOCHS,
    lr: float = TRAIN_CLASSIFIER_LR,
    batch_size: int = TRAIN_CLASSIFIER_BATCH_SIZE,
) -> nn.Module:
    """
    Trains the classifier and keeps the weights of the best validation epoch.

    After every epoch the validation ROC AUC is computed; the weights of the
    epoch with the highest AUC are restored into ``model`` before returning.

    :param model: The link classifier model (moved to DEVICE and modified in place).
    :param X_train: Training features.
    :param y_train: Training labels.
    :param X_val: Validation features.
    :param y_val: Validation labels.
    :param epochs: Number of training epochs.
    :param lr: Learning rate.
    :param batch_size: Batch size.
    :return: The trained model (the same object that was passed in).
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(DEVICE)

    X_train_t = torch.tensor(X_train, dtype=torch.float32).to(DEVICE)
    y_train_t = torch.tensor(y_train, dtype=torch.float32).view(-1, 1).to(DEVICE)
    X_val_t = torch.tensor(X_val, dtype=torch.float32).to(DEVICE)

    best_val_auc = 0.0
    best_model_state: Optional[Any] = None

    for epoch in range(epochs):
        model.train()
        # Fresh shuffle of the training rows every epoch
        perm = torch.randperm(X_train_t.size(0))
        # Sum of per-batch losses (not an average) for this epoch
        total_loss = 0.0
        for i in range(0, X_train_t.size(0), batch_size):
            idx = perm[i : i + batch_size]
            xb = X_train_t[idx]
            yb = y_train_t[idx]

            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            val_out = model(X_val_t)
            val_prob = torch.sigmoid(val_out).cpu().numpy().flatten()
            val_auc = roc_auc_score(y_val, val_prob)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {total_loss:.4f} | Val AUC: {val_auc:.4f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # state_dict() returns references to the live parameters, so the tensors
            # must be cloned or later epochs would overwrite this "best" snapshot
            best_model_state = {
                name: tensor.detach().clone() for name, tensor in model.state_dict().items()
            }

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model


def hyperparam_search_mlp(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    device: str = DEVICE,
) -> Tuple[LinkClassifier, Tuple[int, float, int], float]:
    """
    Conducts a grid search over hidden_dim, lr and epochs for the MLP.

    Every combination of MLP_HIDDEN_DIMS, MLP_LRS and MLP_EPOCH_CHOICES is trained
    with train_classifier and scored by validation ROC AUC. The first configuration
    is always accepted as the initial best, so a model is returned even when no
    configuration scores above 0.

    :param X_train: Training features.
    :param y_train: Training labels.
    :param X_val: Validation features.
    :param y_val: Validation labels.
    :param device: Device used for the validation pass and for the returned model.
        NOTE(review): train_classifier always trains on the global DEVICE, so passing
        a different device here is not supported by the current code.
    :return: A tuple (best_model, best_config, best_val_auc) where best_config is (hidden_dim, lr, epochs).
    :raises ValueError: If the hyperparameter grid is empty.
    """
    best_auc = 0.0
    best_config: Optional[Tuple[int, float, int]] = None
    best_model: Optional[LinkClassifier] = None

    for hd, lr_, ep in itertools.product(MLP_HIDDEN_DIMS, MLP_LRS, MLP_EPOCH_CHOICES):
        print(f"\n[Hyperparam Search] Trying hidden_dim={hd}, lr={lr_}, epochs={ep}")
        model = LinkClassifier(in_dim=X_train.shape[1], hidden_dim=hd)
        trained_model = train_classifier(
            model,
            X_train,
            y_train,
            X_val,
            y_val,
            epochs=ep,
            lr=lr_,
            batch_size=TRAIN_CLASSIFIER_BATCH_SIZE,
        )
        # Validation AUC after training
        X_val_t = torch.tensor(X_val, dtype=torch.float32).to(device)
        trained_model.eval()
        with torch.no_grad():
            val_out = trained_model(X_val_t)
            val_prob = torch.sigmoid(val_out).cpu().numpy().flatten()
        auc_val = roc_auc_score(y_val, val_prob)

        # Accept the first configuration unconditionally so a result always exists
        if best_model is None or auc_val > best_auc:
            best_auc = auc_val
            best_config = (hd, lr_, ep)
            # Clone the model state
            best_model = LinkClassifier(in_dim=X_train.shape[1], hidden_dim=hd).to(device)
            best_model.load_state_dict(trained_model.state_dict())

    if best_model is None or best_config is None:
        raise ValueError("MLP hyperparameter grid is empty; no model was trained.")

    print(
        f"\n[Hyperparam Search] Best config: hidden_dim={best_config[0]}, lr={best_config[1]}, "
        f"epochs={best_config[2]} with val AUC={best_auc:.4f}\n"
    )
    return best_model, best_config, best_auc

# =========================================================
#  NEW Helper: remove edges from edges_df for Node2Vec
# =========================================================
def remove_positive_edges_from_graph(
    edges_df: pd.DataFrame,
    gt_subset: pd.DataFrame,
) -> pd.DataFrame:
    """
    Removes the positive pairs of 'gt_subset' (rows with y == 1) from edges_df.

    Matching ignores direction: an edge is dropped when its (subject, object)
    equals a held-out (source, target) or (target, source). This keeps
    validation/test links out of the graph Node2Vec is trained on.

    :param edges_df: The full set of known edges, with columns 'subject' and 'object'.
    :param gt_subset: Ground-truth rows (e.g. the val or test split) with columns
        'source', 'target' and 'y'.
    :return: edges_df without the held-out edges; all columns are preserved, also
        when edges_df has no rows.
    """
    positives = gt_subset.loc[gt_subset["y"] == 1, ["source", "target"]]

    # Store both orientations because the Node2Vec graph is undirected
    held_out = set()
    for src, tgt in positives.itertuples(index=False, name=None):
        held_out.add((src, tgt))
        held_out.add((tgt, src))

    # A boolean array (not a list) is required: an empty list would be read by
    # pandas as a column selection and silently drop every column
    keep = np.fromiter(
        ((edge.subject, edge.object) not in held_out for edge in edges_df.itertuples(index=False)),
        dtype=bool,
        count=len(edges_df),
    )
    return edges_df[keep]


# =========================================================
#  Modified main with no structural leakage
# =========================================================
def main():
    """
    Runs the full link-classification pipeline end to end.

    Steps: load the CSVs, split the ground truth into train/val/test, remove the
    val/test positives from the graph, build (or load cached) structural and
    semantic embeddings, concatenate them, fit a logistic-regression baseline and
    an MLP chosen by grid search, and print test metrics.
    """
    # RANDOM_SEED is otherwise only used by train_test_split; seeding numpy and torch
    # here makes weight initialisation, batch shuffling and Node2Vec sampling repeatable
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    print("Loading data...")
    nodes_df, edges_df, gt_df = load_data(NODES_CSV, EDGES_CSV, GROUND_TRUTH_CSV)

    print("Building node indices and text data...")
    node2idx, idx2node, node_texts = build_node_index(nodes_df)

    # -----------------------------------------------------
    # 1) SPLIT ground truth edges (gt_df) into train/val/test
    #    BEFORE building Node2Vec
    # -----------------------------------------------------
    gt_train, gt_test = train_test_split(
        gt_df, test_size=TEST_SPLIT, random_state=RANDOM_SEED
    )
    # The validation share is taken from what remains after the test split
    gt_train, gt_val = train_test_split(
        gt_train, test_size=VAL_SPLIT, random_state=RANDOM_SEED
    )

    # Report the size and positive rate of each split (splits are not stratified)
    print("\nGround Truth Splits (pos rate might vary):")
    print(f"  gt_train: {len(gt_train)} edges, ~{gt_train['y'].mean()*100:.1f}% positive")
    print(f"  gt_val:   {len(gt_val)} edges, ~{gt_val['y'].mean()*100:.1f}% positive")
    print(f"  gt_test:  {len(gt_test)} edges, ~{gt_test['y'].mean()*100:.1f}% positive")

    # -----------------------------------------------------
    # 2) REMOVE val/test positive edges from edges_df
    #    so Node2Vec doesn't see them
    # -----------------------------------------------------
    print("\nRemoving val/test edges from the Node2Vec graph...")
    edges_filtered = remove_positive_edges_from_graph(edges_df, gt_val)
    edges_filtered = remove_positive_edges_from_graph(edges_filtered, gt_test)
    print(f"Original edges: {len(edges_df)} | Filtered edges: {len(edges_filtered)}")

    # -----------------------------------------------------
    # 3) Build Node2Vec graph (train only) & generate embeddings
    # -----------------------------------------------------
    # NOTE(review): an existing cache file is reused without checking which edges,
    # split or seed produced it; delete it after changing any of those, otherwise
    # embeddings trained on held-out edges may be loaded.
    # NOTE(review): pickle.load executes arbitrary code; only load files you created.
    if os.path.exists(STRUCT_EMB_PATH):
        print(f"Found existing {STRUCT_EMB_PATH}, loading...")
        with open(STRUCT_EMB_PATH, "rb") as f:
            structural_embeddings = pickle.load(f)
    else:
        print("Building train_graph and generating Node2Vec embeddings...")
        pyg_graph = build_graph_pyg(edges_filtered, node2idx)
        structural_embeddings = generate_node2vec_embeddings_pyg(pyg_graph)
        with open(STRUCT_EMB_PATH, "wb") as f:
            pickle.dump(structural_embeddings, f)

    # -----------------------------------------------------
    # 4) Semantic Embeddings (no leakage issue here)
    # -----------------------------------------------------
    # Text embeddings depend only on node text, not on edges, so the split is irrelevant
    if os.path.exists(SEMANTIC_EMB_PATH):
        print(f"Found existing {SEMANTIC_EMB_PATH}, loading...")
        with open(SEMANTIC_EMB_PATH, "rb") as f:
            semantic_embeddings = pickle.load(f)
    else:
        print("No semantic embedding pickle found. Generating embeddings...")
        embedder = SemanticEmbedder(model_name=MODEL_NAME_SEMANTIC)
        semantic_embeddings = embedder.encode_texts(node_texts)
        with open(SEMANTIC_EMB_PATH, "wb") as f:
            pickle.dump(semantic_embeddings, f)

    # -----------------------------------------------------
    # 5) Build final "hybrid" embeddings (semantic+structural)
    # -----------------------------------------------------
    hybrid_embeddings = build_hybrid_embeddings(semantic_embeddings, structural_embeddings)
    # Always rewritten: this file is an output, never read back by this function
    with open(HYBRID_EMB_PATH, "wb") as f:
        pickle.dump(hybrid_embeddings, f)

    # -----------------------------------------------------
    # 6) Prepare train/val/test sets using the split gt_dfs
    #    (Now that embeddings are done)
    # -----------------------------------------------------
    print("\nPreparing train/val/test feature tensors...")
    # Train
    X_train, y_train = prepare_dataset(gt_train, node2idx, hybrid_embeddings)
    # Val
    X_val, y_val = prepare_dataset(gt_val, node2idx, hybrid_embeddings)
    # Test
    X_test, y_test = prepare_dataset(gt_test, node2idx, hybrid_embeddings)

    # Scale features; statistics come from the training set only
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)

    print(f"Train: {X_train.shape}   Val: {X_val.shape}   Test: {X_test.shape}")
    print(f"Pos rate => Train: {y_train.mean():.2f}, Val: {y_val.mean():.2f}, Test: {y_test.mean():.2f}")

    # -----------------------------------------------------
    # 7) Baseline: Logistic Regression
    # -----------------------------------------------------
    print("\nFitting LogisticRegression on TRAIN only...")
    lr_clf = LogisticRegression(max_iter=LOGREG_MAX_ITER)
    lr_clf.fit(X_train, y_train)
    y_prob_lr = lr_clf.predict_proba(X_test)[:, 1]
    auc_lr = roc_auc_score(y_test, y_prob_lr)
    print(f"[LogReg] Test AUC: {auc_lr:.4f}")

    # -----------------------------------------------------
    # 8) MLP: Hyperparam search
    # -----------------------------------------------------
    print("\n=== MLP Hyperparam Search ===")
    best_mlp, best_config, best_val_auc = hyperparam_search_mlp(X_train, y_train, X_val, y_val)

    # Evaluate best MLP on test
    best_mlp.eval()
    X_test_t = torch.tensor(X_test, dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        test_out = best_mlp(X_test_t)
        test_prob = torch.sigmoid(test_out).cpu().numpy().flatten()

    auc_mlp = roc_auc_score(y_test, test_prob)
    aupr_mlp = average_precision_score(y_test, test_prob)
    # F1 uses a fixed 0.5 probability threshold
    preds_mlp = (test_prob > 0.5).astype(int)
    f1_mlp = f1_score(y_test, preds_mlp)

    print("\n=== Best MLP Model (Test Set) ===")
    print(f"Config: hidden_dim={best_config[0]}, lr={best_config[1]}, epochs={best_config[2]}")
    print(f"AUC:   {auc_mlp:.4f}")
    print(f"AUPR:  {aupr_mlp:.4f}")
    print(f"F1:    {f1_mlp:.4f}")


# Entry point: run the whole pipeline when this code executes as the main module
# (which is also the case when the cell is run inside a notebook kernel).
if __name__ == "__main__":
    main()