# Model Prototyping: Edge-Wise Interventional GNN (GNN-NCM)

**Goal.** Implement the **EdgeWiseGNNLayer** (per-edge MLP mechanisms) and training flows:
1) observational warm-up,
2) interventional regularization (CXGNN-inspired).

We also keep **shared-weights** and **vanilla GNN** baselines for ablation.


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from pathlib import Path

# PyTorch Geometric imports
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Building the GNN Neural Causal Model (GNN-NCM)

Here, we define our core model class. We will include detailed comments explaining how each part corresponds to the concepts in the reference code and the broader SCM framework. The gap between SCMs and GNNs lies mainly in:

1. **Directionality:** Causality is directed (A -> B is not the same as B -> A), but standard layers such as GCNConv are usually applied to symmetrized (undirected) graphs.

2. **Modularity vs. parameter sharing:** An SCM has a unique causal mechanism (f_ij) for each parent-child relationship, while a GNN shares its message-passing weights (ψ) across all edges for scalability.


Therefore, to improve causal fidelity we need to build a GNN layer that is both directional and supports per-edge causal mechanisms, directly addressing the trade-off between common GNNs and SCMs.
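
To make the directionality point concrete, the following minimal sketch (plain PyTorch, not the layer defined in this notebook) sums parent features into each child along a directed edge list, then repeats the aggregation after symmetrizing the edges. The three-node chain and the scalar features are illustrative assumptions.

```python
import torch

# Toy chain 0 -> 1 -> 2 with one scalar feature per node (illustrative values).
x = torch.tensor([[1.0], [10.0], [100.0]])
edge_index = torch.tensor([[0, 1],   # sources (parents)
                           [1, 2]])  # targets (children)

def sum_incoming(x, edge_index):
    """Sum source-node features into each target node."""
    src, dst = edge_index
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src])
    return out

directed = sum_incoming(x, edge_index)

# Symmetrizing adds the reverse of every edge, so children also send to parents.
sym_edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
symmetric = sum_incoming(x, sym_edge_index)

print(directed.squeeze(-1).tolist())   # node 0 has no parents and receives nothing
print(symmetric.squeeze(-1).tolist())  # node 0 now also receives from its child
```

In the directed case the root node aggregates nothing, which is what a causal parent-to-child reading requires; after symmetrization every node hears from its children as well.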

We start by implementing a simple directional message-passing layer.


In [None]:
class DirectedGNNLayer(MessagePassing):
    """
    A minimal directed message passing layer where information flows from
    source nodes (parents) to target nodes (children).
    
    It uses a shared MLP for the message function (ψ) and another for the update function (φ).
    """
    def __init__(self, in_dim, hidden_dim, out_dim):
        # flow='source_to_target' is the key argument that enforces directionality.

        super().__init__(aggr='add', flow='source_to_target')

        # ψ (psi): The message function. 
        # It computes a message based on the parent's features.
        self.msg_mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
        
        # φ (phi): The update function. 
        # It updates a node's representation by combining its original features (x_i) with the aggregated messages from all its parents.
        self.update_mlp = nn.Sequential(
            nn.Linear(in_dim + out_dim, hidden_dim), # Takes concatenated [child, aggregated_messages]
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def message(self, x_j, x_i):
        """
        Defines the message from a source node j (parent) to a target node i (child).
        x_j is the parent's feature tensor, x_i is the child's.
        """
        return self.msg_mlp(x_j)

    def forward(self, x, edge_index):
        """ The main propagation method. """
        # propagate() is provided by MessagePassing: it calls message() for each edge
        # and aggregates the resulting messages at each target node.
        aggregated_messages = self.propagate(edge_index, x=x)
        
        # Combine the node's original state with the messages from its parents.
        updated_embedding = self.update_mlp(torch.cat([x, aggregated_messages], dim=-1))
        return updated_embedding

An SCM's structural equations `x_i := f(pa(x_i), U_i)` can use a different function `f` for each distinct causal relationship. A standard GNN, by contrast, uses one `msg_mlp` for all edges.
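
As a plain-Python illustration of this modularity (a sketch with made-up mechanisms, not the model in this notebook), each edge below carries its own function `f_ij`, and a node's value is the sum of its parents' edge-specific contributions plus its exogenous noise `U_i`. The additive form is an assumption chosen to mirror the `aggr='add'` aggregation used by the layers here.

```python
# Edge-specific mechanisms f_ij, keyed by (parent, child). Values are hypothetical.
mechanisms = {
    ("A", "C"): lambda a: 2.0 * a,   # f_AC
    ("B", "C"): lambda b: b ** 2,    # f_BC: a different functional form
}

def structural_equation(child, values, noise):
    """x_child := sum_j f_{j,child}(x_j) + U_child (additive form assumed)."""
    total = noise[child]
    for (parent, target), f in mechanisms.items():
        if target == child:
            total += f(values[parent])
    return total

values = {"A": 1.0, "B": 3.0}
noise = {"C": 0.5}
values["C"] = structural_equation("C", values, noise)
print(values["C"])  # 2*1 + 3**2 + 0.5
```

A shared-weight GNN would correspond to replacing both entries of `mechanisms` with the same function.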

We will create a GNN layer that can operate in two modes:
1.  **`'shared'`**: Behaves like a normal GNN (efficient but less causally faithful).
2.  **`'per_edge'`**: Instantiates a unique MLP for every edge in the graph (less scalable and more data-hungry, but faithful to the per-mechanism structure of SCMs and NCMs).
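
The scalability cost of the two modes can be estimated with simple arithmetic. The sketch below counts only the message-function parameters of a two-layer MLP (`Linear -> ReLU -> Linear`) and ignores the shared update MLP; the sizes are illustrative assumptions (5 features plus 4 noise dimensions, hidden and output size 32, 25 edges).

```python
def mlp_param_count(in_dim, hidden_dim, out_dim):
    """Weights and biases of Linear(in, hidden) -> ReLU -> Linear(hidden, out)."""
    return (in_dim * hidden_dim + hidden_dim) + (hidden_dim * out_dim + out_dim)

# Assumed sizes, for illustration only.
in_dim, hidden_dim, out_dim, num_edges = 9, 32, 32, 25

shared = mlp_param_count(in_dim, hidden_dim, out_dim)  # one MLP for all edges
per_edge = num_edges * shared                          # one MLP per edge
print(shared, per_edge)
```

The message-function parameter count grows linearly with the number of edges in `'per_edge'` mode, which is why that mode needs more data per graph.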

In [None]:
class EdgeWiseGNNLayer(MessagePassing):
    """
    A GNN layer that resolves the SCM modularity vs. GNN parameter sharing trade-off.
    
    In 'per_edge' mode, it instantiates a unique MLP for each edge, allowing it
    to learn distinct causal mechanisms for each parent-child relationship.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_edges, mode='per_edge'):
        super().__init__(aggr='add', flow='source_to_target')
        
        self.mode = mode
        self.num_edges = num_edges
        self.out_dim = out_dim

        if self.mode == 'per_edge':
            # Create a list of MLPs, one for each edge. This is our f_ij.
            self.edge_mlps = nn.ModuleList([
                nn.Sequential(nn.Linear(in_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, out_dim)) for i in range(num_edges)
            ])
        elif self.mode == 'shared': 
            # In shared mode, we only have one MLP for all edges.
            self.edge_mlps = nn.Sequential(nn.Linear(in_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, out_dim))
        else:
            raise ValueError("Mode must be 'shared' or 'per_edge'")
        
        
        # The update function φ remains shared across all nodes.
        self.update_mlp = nn.Sequential(
            nn.Linear(in_dim + out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x, edge_index, original_edge_ids=None):
        """
        Executes one full message-passing step.

        Args:
            x (Tensor): The input node features.
            edge_index (Tensor): The connectivity for this pass.
            original_edge_ids (Tensor, optional): The original IDs of the edges in edge_index.
                                                   If None, assumes a full graph pass.
        """
        # If the original edge IDs aren't provided, we are doing a standard observational pass.
        if original_edge_ids is None:
            original_edge_ids = torch.arange(self.num_edges, device=x.device)
        
        # The propagate method will call message() and aggregate(), and its output
        # is the aggregated messages for each node.
        aggr_out = self.propagate(edge_index, x=x, original_edge_ids=original_edge_ids)
        
        # The update() method is then called to combine the aggregated messages with
        # the original node features to produce the final node embeddings.
        return self.update(aggr_out, x)


    def message(self, x_j, original_edge_ids):
        """
        Computes the message from parent (x_j) to child.
        
        Args:
            x_j (Tensor): The feature tensor of the source nodes (parents) for each edge.
            original_edge_ids (Tensor): The original index of each edge, which we use
                                        to select the appropriate MLP.
        """

        output_messages = torch.zeros(x_j.size(0), self.out_dim, device=x_j.device)
        if self.mode == 'per_edge':
            for i in range(self.num_edges):
                mask = (original_edge_ids == i)
                if mask.any(): output_messages[mask] = self.edge_mlps[i](x_j[mask])
            return output_messages
        else: 
            return self.edge_mlps(x_j)

    def update(self, aggr_out, x):
        return self.update_mlp(torch.cat([x, aggr_out], dim=-1))


## 2. The GNN-NCM Architecture: A Principled SCM Analogue

First, we define our core model architecture. This consists of two main components:

- EdgeWiseGNNLayer: This is our novel GNN layer that resolves the critical trade-off between SCM modularity and GNN parameter sharing. In 'per_edge' mode, it instantiates a unique MLP for each causal link in the graph, directly modeling the SCM principle that each parent -> child relationship has its own distinct mechanism f_ij.
- GNN_NCM: This is the main model class. It orchestrates the EdgeWiseGNNLayers and explicitly implements other SCM principles, namely the inclusion of exogenous noise U and the do_intervention method for simulating interventions.
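
For readers less familiar with the do-operator, here is a dependency-free sketch on a hypothetical three-variable chain `Z -> X -> Y` with linear mechanisms. It is not the `GNN_NCM` model; it only shows what clamping a node and ignoring its parents means. The coefficients and noise values are assumptions.

```python
def simulate(u, do=None):
    """Evaluate Z -> X -> Y in topological order; `do` maps node name -> clamped value."""
    do = do or {}
    v = {}
    v["Z"] = do.get("Z", u["Z"])
    v["X"] = do.get("X", 0.5 * v["Z"] + u["X"])  # X's own mechanism is ignored under do(X)
    v["Y"] = do.get("Y", 2.0 * v["X"] + u["Y"])
    return v

u = {"Z": 4.0, "X": 1.0, "Y": 0.0}  # fixed exogenous noise for a like-for-like comparison
observed = simulate(u)
intervened = simulate(u, do={"X": 10.0})
print(observed)
print(intervened)
```

Under `do(X = 10)` the value of `Z` no longer influences `X` or `Y`, which is the effect the edge removal in `do_intervention` is meant to reproduce on the graph.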

In [None]:
class GNN_NCM(nn.Module):
    """
    A Graph Neural Network - Neural Causal Model (GNN-NCM).
    
    This model is designed to mimic a Structural Causal Model (SCM). It learns
    from observational data and can then predict outcomes under interventions
    (do-operations) by performing "graph surgery".
    
    This version uses a causally-faithful `EdgeWiseGNNLayer`.
    """
    def __init__(self, num_features, hidden_dim, out_dim, num_edges, noise_dim=4, gnn_mode="per_edge"):
        
        """
        Initializes the GNN-NCM.

        Args:
            num_features (int): The number of input features for each node.
            hidden_dim (int): The dimensionality of the hidden MLPs inside the GNN layers.
            out_dim (int): The output dimensionality of the GNN layers.
            num_edges (int): The total number of edges in the full graph.
                               EdgeWiseGNNLayer needs this to know how many
                               unique MLPs to create.
            noise_dim (int): The dimensionality of the exogenous noise vector 'U'.
            gnn_mode (str): Either 'per_edge' for max causal fidelity or 'shared' for efficiency.
        """


        super().__init__()
        self.noise_dim = noise_dim
        self.num_edges = num_edges
        
        # In an SCM, x_i := f(pa(x_i), U_i). We model this by adding noise to the input.

        input_dim = num_features + noise_dim
        
        # Instead of standard GCNConv we use our layer
        self.conv1 = EdgeWiseGNNLayer(
            in_dim=input_dim, 
            out_dim=hidden_dim, 
            hidden_dim=hidden_dim, 
            num_edges=num_edges, 
            mode=gnn_mode
        )

        self.conv2 = EdgeWiseGNNLayer(
            in_dim=hidden_dim, 
            out_dim=out_dim,   
            hidden_dim=hidden_dim, 
            num_edges=num_edges, 
            mode=gnn_mode
        )

        
        # A final linear layer to produce the output prediction
        self.out = nn.Linear(out_dim, 1)


    def forward(self, x, edge_index):
        """
        Performs the standard OBSERVATIONAL forward pass.
        
        This simulates the system where no intervention has been done.

        Args:
            x (Tensor): Node features of shape [num_nodes, num_features].
            edge_index (Tensor): Graph connectivity in COO format with shape [2, num_edges].
            
        Returns:
            Tensor: The output prediction for each node (e.g., logits for classification).
        """

        # 1. Inject Exogenous Noise (U) 
        noise = torch.randn(x.size(0), self.noise_dim, device=x.device)
        x_with_noise = torch.cat([x, noise], dim=1)
        
        # 2. Message Passing
        h = F.relu(self.conv1(x_with_noise, edge_index))
        h = F.dropout(h, p=0.5, training=self.training)
        h = F.relu(self.conv2(h, edge_index))
        
        # 3. Final Prediction
        return self.out(h)

    def do_intervention(self, x, edge_index, intervened_nodes, new_feature_values):
        """
        Performs an INTERVENTIONAL forward pass, simulating a do-operation.
        
        This is the core of the causal inference capability. In the SCM framework, a
        do-operation `do(X_i = v)` means we replace the mechanism that generates X_i
        with a constant value 'v', severing the influence of its parents.
        
        Args:
            x (Tensor): The ORIGINAL node features before intervention.
            edge_index (Tensor): The ORIGINAL graph structure.
            intervened_nodes (Tensor): A 1D tensor of node indices to intervene on.
            new_feature_values (Tensor): A tensor of new feature values to clamp onto the
                                         intervened nodes. Shape: [num_intervened_nodes, num_features].
        
        Returns:
            Tensor: The post-intervention predictions for all nodes in the graph.
        """
       
        
        # Step 1: Clamp Node Features
        x_intervened = x.clone()
        x_intervened[intervened_nodes] = new_feature_values


        # Step 2: Perform "Graph Mutilation"
        # We sever all causal links (edges) pointing into the intervened node.
        edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool, device=x.device)
        for node_idx in intervened_nodes:
            edge_mask &= (edge_index[1] != node_idx)
        
        intervened_edge_index = edge_index[:, edge_mask]


        # Step 3: Run the forward pass on the surgically-modified graph
        # We cannot simply call forward since EdgeWiseGNNLayer was initialized knowing about num_edges.

        all_edge_ids = torch.arange(self.num_edges, device=x.device)
        intervened_edge_ids = all_edge_ids[edge_mask]
        
        # Manually perform the forward pass on the modified graph structure
        noise = torch.randn(x.size(0), self.noise_dim, device=x.device)
        x_with_noise = torch.cat([x_intervened, noise], dim=1)
        
        # Pass through Layer 1
        h1 = self.conv1(x_with_noise, intervened_edge_index, original_edge_ids=intervened_edge_ids)
        h1 = F.relu(h1)
        
        # Pass through Layer 2
        h2 = self.conv2(h1, intervened_edge_index, original_edge_ids=intervened_edge_ids)
        h2 = F.relu(h2)
        
        return self.out(h2)
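

The per-node loop in Step 2 of `do_intervention` can also be written without a Python loop. The sketch below assumes a PyTorch version that provides `torch.isin` and uses a toy edge list; it is an alternative formulation for illustration, not a change to the class above.

```python
import torch

# Toy graph: 0 -> 2, 1 -> 2, 2 -> 3, 0 -> 3 (each column is one edge).
edge_index = torch.tensor([[0, 1, 2, 0],
                           [2, 2, 3, 3]])
intervened_nodes = torch.tensor([2])

# Keep an edge only if its target is NOT an intervened node.
edge_mask = ~torch.isin(edge_index[1], intervened_nodes)

intervened_edge_index = edge_index[:, edge_mask]
# Original positions of the surviving edges, usable to select per-edge MLPs.
intervened_edge_ids = torch.arange(edge_index.size(1))[edge_mask]

print(intervened_edge_index.tolist())
print(intervened_edge_ids.tolist())
```

Both incoming edges of node 2 are removed, while its outgoing edge keeps its original ID, so the corresponding per-edge mechanism is still used downstream.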


In [None]:
# Example

# Assume you have a graph:
num_nodes = 10
num_features = 5
num_edges = 25 

# Define model dimensions
hidden_dim = 32
out_dim = 16

# Instantiate the model
causal_model = GNN_NCM(
    num_features=num_features,
    hidden_dim=hidden_dim,
    out_dim=out_dim,
    num_edges=num_edges,
    gnn_mode='per_edge' # or 'shared'
)

print(causal_model)

## 3. A Review of Causal Training Strategies

Our GNN-NCM has the architectural capacity for causal reasoning, but that capacity only pays off if training respects causal principles. Since we assume access to observational data only, the training procedure must extract an interventional signal without interventional labels. Let's review the main families of approaches.

* Strategy 1: Supervised Interventional Training (The Ideal Case)

    If we are lucky enough to have a dataset of (pre-state, intervention, post-state) tuples, we can directly supervise the `do_intervention` method. We analyze this case in more detail in the synthetic dataset notebook.


* Strategy 2: Interventional Synthesis 

    This is the approach pioneered by CXGNN and is the most direct fit for our SCM-like model. As detailed below, it uses the model's own interventional predictions as a way to self-supervise and derive a causally-plausible estimate of the observational outcome.
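
The aggregation behind Strategy 2 can be sketched without any tensors. In the toy example below, `do_preds[p][v]` stands in for the model's prediction at node `v` under `do(p)`; the numbers are hard-coded assumptions rather than model outputs. For each child, the predictions under each parent's intervention are averaged into a target, compared to the true label with a squared error, and the result is combined with a placeholder observational loss using the weights `w_obs` and `w_do`.

```python
from statistics import mean

# Directed edges as (parent, child) pairs and true node targets (toy values).
edges = [(0, 2), (1, 2), (2, 3)]
y = [0.0, 0.0, 1.0, 2.0]

# do_preds[p][v]: prediction at node v under do(p). Hard-coded stand-ins
# for what a model's interventional pass would return.
do_preds = {
    0: [0.0, 0.0, 1.0, 1.5],
    1: [0.0, 0.0, 1.5, 1.5],
    2: [0.0, 0.0, 0.0, 2.5],
}

def causal_loss(edges, y, do_preds):
    """Mean over children of (average do(parent) prediction - y)^2."""
    parents = {}
    for p, v in edges:
        parents.setdefault(v, []).append(p)
    sq_errors = []
    for v, ps in parents.items():
        target_v = mean(do_preds[p][v] for p in ps)
        sq_errors.append((target_v - y[v]) ** 2)
    return mean(sq_errors)

w_obs, w_do = 0.2, 1.0  # loss weights; l_obs below is a placeholder value
l_obs, l_do = 0.05, causal_loss(edges, y, do_preds)
total = w_obs * l_obs + w_do * l_do
print(l_do, total)
```

Nodes without parents (here nodes 0 and 1) contribute nothing to the causal term, so only children are regularized.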

In [None]:
import torch, torch.nn as nn, torch.optim as optim

class CausalTwoPartTrainer:
    def __init__(self,
                 epochs_obs=30,
                 epochs_do=150,
                 lr=5e-3,
                 w_obs=0.2,
                 w_do=1.0,
                 weight_decay=1e-4,
                 clip=1.0,
                 neutral='zeros',     # 'zeros' or 'self+delta'
                 delta=0.0):          # scalar or 1D tensor of size F (used when neutral == 'self+delta')
        self.epochs_obs = int(epochs_obs)
        self.epochs_do  = int(epochs_do)
        self.lr  = float(lr)
        self.w_obs = float(w_obs)
        self.w_do  = float(w_do)
        self.wd  = float(weight_decay)
        self.clip = float(clip)

        self.neutral = str(neutral)
        self.delta = delta

        self.loss = nn.MSELoss()
        self.history = []

    
    def train(self, model, loader, val_loader=None):
        dev = next(model.parameters()).device
        model = model.to(dev)

        # Phase 1: observational warm-up (obs only) 
        opt = optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.wd)
        for ep in range(1, self.epochs_obs + 1):
            model.train()
            obs_sum, n_obs = 0.0, 0
            for g in loader:
                g = g.to(dev)
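                pred = model(g.x, g.edge_index)
                y = g.y
                # Align shapes as in evaluate_obs_mse so MSELoss does not broadcast [N, 1] against [N].
                if pred.dim() == 2 and pred.size(-1) == 1: pred = pred.squeeze(-1)
                if y.dim() == 2 and y.size(-1) == 1: y = y.squeeze(-1)
                l_obs = self.loss(pred, y)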

                opt.zero_grad()
                l_obs.backward()
                if self.clip: torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip)
                opt.step()

                obs_sum += float(l_obs.detach()); n_obs += 1

            m_obs = obs_sum / max(n_obs, 1)  # guard against an empty loader
            m_val = self.evaluate_obs_mse(model, val_loader)

            self.history.append({
                "epoch": ep, "phase": "obs",
                "loss_obs": m_obs, "loss_do": None, "loss_total": m_obs,
                "val_obs": m_val
            })
            if ep % 10 == 0:
                msg = f"[obs {ep:03d}] obs={m_obs:.6f}"
                if m_val is not None: msg += f" | val_obs={m_val:.6f}"
                print(msg)


        # Phase 2: obs + causal (one combined step per batch)
        # reset optimizer to avoid stale momentum from Phase 1
        opt = optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.wd)

        
        for ep in range(1, self.epochs_do + 1):
            model.train()
            obs_sum, do_sum, n_obs, n_do = 0.0, 0.0, 0, 0

            for g in loader:
                g = g.to(dev)
                x, edge_index, y = g.x, g.edge_index, g.y

                # observational term
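                p_obs = model(x, edge_index)
                # Align shapes as in evaluate_obs_mse so MSELoss does not broadcast [N, 1] against [N].
                y_obs = y
                if p_obs.dim() == 2 and p_obs.size(-1) == 1: p_obs = p_obs.squeeze(-1)
                if y_obs.dim() == 2 and y_obs.size(-1) == 1: y_obs = y_obs.squeeze(-1)
                l_obs = self.loss(p_obs, y_obs)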
                obs_sum += float(l_obs.detach()); n_obs += 1

                # causal term: do(parent) one at a time, aggregate to each child
                l_cau = self._causal_loss_do_parent_average(model, g, p_obs)
                do_sum += float(l_cau.detach()); n_do += 1

                # combine and step once
                total = (self.w_obs * l_obs) + (self.w_do * l_cau)
                opt.zero_grad()
                total.backward()
                if self.clip: torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip)
                opt.step()

            m_obs = obs_sum / max(n_obs, 1)
            m_do  = do_sum  / max(n_do,  1)
            total_epoch = (self.w_obs * m_obs) + (self.w_do * m_do)

            m_val= self.evaluate_obs_mse(model, val_loader)

            ep_abs = self.epochs_obs + ep
            self.history.append({
                "epoch": ep_abs, "phase": "do",
                "loss_obs": m_obs, "loss_do": m_do, "loss_total": total_epoch,
                "val_obs": m_val
            })
            if ep % 10 == 0:
                msg = f"[do  {ep:03d}] total={total_epoch:.6f} (obs={m_obs:.6f}, do={m_do:.6f})"
                if m_val is not None: msg += f" | val_obs={m_val:.6f}"
                print(msg)


        return model

    def _causal_loss_do_parent_average(self, model, g, p_obs):
        """
        Implements causal loss:
          - For each UNIQUE parent node p (from edges src->dst), compute prediction under do(p)
            where features of p are set to a neutral row ('zeros' or 'self+delta').
          - For each child v, average the do(p)[v] over all parents p of v to form a target.
          - Compare that target to TRUE y[v] (MSE), average over nodes that have parents.

        """
        dev = p_obs.device
        x, edge_index, y = g.x, g.edge_index, g.y
        N, F = x.size(0), x.size(1)


        src, dst = edge_index[0], edge_index[1]
        if src.numel() == 0:
            # no edges -> no parents -> causal term 0 
            return 0.0 * p_obs.sum()

        unique_parents = torch.unique(src)

        # prepare delta row if needed
        if isinstance(self.delta, torch.Tensor):
            delta_row = self.delta.to(device=dev, dtype=x.dtype)
        else:
            delta_row = torch.full((F,), float(self.delta), device=dev, dtype=x.dtype)

        # compute do(parent) predictions once per parent
        p1_map = {}
        for p in unique_parents.tolist():
            p = int(p)
            if self.neutral == 'zeros':
                new_row = torch.zeros(F, device=dev, dtype=x.dtype)
            else:  # 'self+delta'
                new_row = x[p] + delta_row

            if hasattr(model, "do_intervention") and callable(getattr(model, "do_intervention")):
                p1 = model.do_intervention(
                    x, edge_index,
                    intervened_nodes=torch.tensor([p], dtype=torch.long, device=dev),
                    new_feature_values=new_row.unsqueeze(0)
                ) 
            else:
                # fallback: override node p's features and run a forward pass
                x_do = x.clone()
                x_do[p] = new_row
                p1 = model(x_do, edge_index) 
            p1_map[p] = p1

        # aggregate targets per child and compare to TRUE y[v]
        loss_causal = 0.0
        count = 0
        for v in range(N):
            parents_v = src[dst == v]
            if parents_v.numel() == 0:
                continue
            vals = []
            for p in parents_v.tolist():
                if p in p1_map:
                    vals.append(p1_map[p][v])  # prediction at node v under do(p)
            if not vals:
                continue
            target_v = torch.stack(vals, dim=0).mean(dim=0)  # averaged target for node v
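            # Flatten both sides so a [1]-shaped prediction and a scalar label compare without broadcasting.
            loss_causal = loss_causal + self.loss(target_v.reshape(-1), y[v].reshape(-1))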
            count += 1

        return (loss_causal / count) if count > 0 else (0.0 * p_obs.sum())

    @torch.no_grad()
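    def evaluate_obs_mse(self, model, loader):
        # No validation loader supplied: return None so callers can skip validation logging.
        if loader is None:
            return None
        model.eval()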
        dev = next(model.parameters()).device
        tot, n = 0.0, 0
        for g in loader:
            g = g.to(dev)
            p = model(g.x, g.edge_index)  
            y = g.y                        
            if p.dim()==2 and p.size(-1)==1: p = p.squeeze(-1)
            if y.dim()==2 and y.size(-1)==1: y = y.squeeze(-1)


            tot += self.loss(p, y).item(); n += 1
        return tot / max(n, 1)



In [None]:
from pathlib import Path
import sys

PROJECT_ROOT = Path().resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))  

PROCCESSED_DATA_DIR = PROJECT_ROOT / "data" / "processed"
PROCCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)


from torch_geometric.loader import DataLoader
from torch.utils.data import Subset
from src.dataloader import CausalFactorDataset

full_dataset = CausalFactorDataset(root_dir=PROCCESSED_DATA_DIR, target_node="VOL", feature_col=None, drop_self_for_target=True)

# --- use it with a DataLoader and plot ---
import pandas as pd, matplotlib.pyplot as plt

split = int(0.8*len(full_dataset))
train_loader = DataLoader(Subset(full_dataset, range(split)), batch_size=1, shuffle=True)
val_loader   = DataLoader(Subset(full_dataset, range(split, len(full_dataset))), batch_size=1, shuffle=False)

# dims
g0 = next(iter(train_loader))
num_features = g0.num_node_features
num_edges    = g0.edge_index.size(1)

# model
model = GNN_NCM(num_features=num_features, num_edges=num_edges, gnn_mode='per_edge',
                hidden_dim=32, out_dim=16).to(g0.x.device)

# trainer (two-phase, fixed loss weights w_obs / w_do)
trainer = CausalTwoPartTrainer(
    epochs_obs=40, epochs_do=20, w_obs=0.2, w_do=1.0,
    neutral='zeros',           # or 'self+delta'
    delta=0.1                  # scalar or 1D tensor (used only if neutral='self+delta')
)


trainer.train(model, train_loader, val_loader=val_loader)  

# plot
df = pd.DataFrame(trainer.history)

plt.figure(figsize=(10,5))
if "val_obs" in df and df["val_obs"].notna().any():
    plt.plot(df["epoch"], df["val_obs"], label="val_obs", linewidth=3)
if "loss_obs" in df: plt.plot(df["epoch"], df["loss_obs"], label="train_obs", alpha=0.6)
if "loss_total" in df: plt.plot(df["epoch"], df["loss_total"], label="train_total", alpha=0.6)
plt.xlabel("epoch"); plt.ylabel("loss"); plt.title("Training and validation loss")
plt.legend(); plt.grid(True, linestyle='--'); plt.tight_layout(); plt.show()

print("final val_obs =", trainer.evaluate_obs_mse(model, val_loader))


## 4. Hyperparameter Selection

In [None]:
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset
from itertools import product
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import torch
import yaml
import time
import math
import random

def make_loaders(dataset, split=0.8, batch_size=1):
    n = len(dataset)
    s = int(split * n)
    train_idx = list(range(s))
    val_idx   = list(range(s, n))
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(Subset(dataset, val_idx),   batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

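def sample_config(rng, space):
    """Sample one config from a search space of discrete choices (one list per key)."""
    cfg = {}
    for k, spec in space.items():
        if isinstance(spec, list):
            # Index into the list so values keep their native Python types
            # (rng.choice would return NumPy scalars instead).
            cfg[k] = spec[int(rng.integers(len(spec)))]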
        else:
            raise ValueError(f"Unsupported search spec for {k}: {spec}")
    return cfg

def tune_hyperparameters_random(dataset, n_trials=40, seed=42, val_node_key=None):
    # Set seeds for reproducibility
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    train_loader, val_loader = make_loaders(dataset, split=0.8, batch_size=1)
    g0 = dataset[0]
    num_features = g0.num_node_features
    num_edges    = g0.edge_index.size(1)
    device       = g0.x.device


    # define a search space
    search_space = {
        # model
        "hidden_dim": [16, 32, 64],
        "out_dim":    [8, 16, 32],

        # optimizer
        "lr": [1e-4, 5e-3, 1e-3, 1e-2],            
        "weight_decay": [1e-4, 1e-3],  

        # trainer knobs
        "epochs_obs": [20, 40, 60],
        "epochs_do":  [10, 20, 30],
        "w_obs":      [0.1, 0.2, 0.5],
        "w_do":       [0.5, 1.0, 2.0],
        "clip":       [0.5, 1.0, 2.0],
        "neutral":    ["zeros", "self+delta"],
        "delta":      [0.0, 0.1, 1.0],   # only used if neutral == 'self+delta'
    }

    results = []
    best = {"val": float("inf"), "cfg": None, "history": None}

    print(f"Starting random search with {n_trials} trials...")
    for t in tqdm(range(n_trials), desc="Random Search"):
        # Sample a random configuration in each trial
        cfg = sample_config(rng, search_space)

        # build model
        model = GNN_NCM(num_features=num_features, num_edges=num_edges,
                        gnn_mode='per_edge',
                        hidden_dim=int(cfg["hidden_dim"]),
                        out_dim=int(cfg["out_dim"])).to(device)

        # build trainer
        trainer = CausalTwoPartTrainer(
            epochs_obs=int(cfg["epochs_obs"]),
            epochs_do=int(cfg["epochs_do"]),
            lr=float(cfg["lr"]),
            w_obs=float(cfg["w_obs"]),
            w_do=float(cfg["w_do"]),
            weight_decay=float(cfg["weight_decay"]),
            clip=float(cfg["clip"]),
            neutral=cfg["neutral"],
            delta=float(cfg["delta"])
        )

        # train with validation tracking
    
        trainer.train(model, train_loader, val_loader=val_loader)

        # choose best validation obs over epochs
        df = pd.DataFrame(trainer.history)
        if "val_obs" in df and df["val_obs"].notna().any():
            val_best = float(df["val_obs"].min())
        else:
            # fallback: compute once
            val_best = float(trainer.evaluate_obs_mse(model, val_loader))

        results.append({"trial": t, "val_best": val_best, "cfg": cfg})

        if val_best < best["val"]:
            best = {"val": val_best, "cfg": cfg, "history": trainer.history}
            print(f"New best @ trial {t}: val={val_best:.6f} cfg={cfg}")

    # pretty-print best
    print("\n--- Hyperparameter Tuning Complete ---")
    print(f"Best validation loss: {best['val']:.6f}")
    print("Best hyperparameters:")
    for k, v in best["cfg"].items():
        print(f"  {k}: {v}")

    best_cfg = {
        "experiment_name": "GNN_NCM_TwoPart_VOL",
        "seed": seed,
        "device": "cuda",
        "data": {
            "target_node": "VOL",
            "batch_size": 1,
        },
        "model": {
            "name": "GNN_NCM",
            "gnn_mode": "per_edge",
            "hidden_dim": int(best["cfg"]["hidden_dim"]),
            "out_dim": int(best["cfg"]["out_dim"]),
            "noise_dim": 4,
        },
        "training": {
            "type": "CausalTwoPartTrainer",
            "epochs_obs": int(best["cfg"]["epochs_obs"]),
            "epochs_do":  int(best["cfg"]["epochs_do"]),
            "lr": float(best["cfg"]["lr"]),
            "weight_decay": float(best["cfg"]["weight_decay"]),
            "clip": float(best["cfg"]["clip"]),
            "w_obs": float(best["cfg"]["w_obs"]),
            "w_do":  float(best["cfg"]["w_do"]),
            "neutral": str(best["cfg"]["neutral"]),
            "delta": float(best["cfg"]["delta"]),
        },
        "output": {"output_dir": "outputs"},
    }


    return best_cfg, results


### Saving the Parameters

In [None]:
# Run the tuning process
best, results = tune_hyperparameters_random(full_dataset, n_trials=40, seed=42)  


In [None]:
# Save the final config file
# Ensure the configs directory exists

from pathlib import Path
import sys

PROJECT_ROOT = Path().resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))  

CONFIG = PROJECT_ROOT / "configs" / "best_config.yaml"
CONFIG.parent.mkdir(parents=True, exist_ok=True)  # create configs/ if it is missing

with open(CONFIG, 'w') as f:
    yaml.dump(best, f, indent=2, sort_keys=False)