# Prototyping the GNN Neural Causal Model

The primary objective of this notebook is to develop the core GNN_NCM class. We will take direct inspiration from CXGNN by Berham and then implement our own version using PyTorch Geometric

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from pathlib import Path

# PyTorch Geometric imports
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm


### Building the GNN Neural Causal Model (GNN_NCM)

Here, we define our core model class. We will include detailed comments explaining how each part corresponds to the concepts in the reference code and the broader SCM framework. The gap between SCMs and GNNs lies mainly in:

1- Directionality: Causality is directed (A -> B is not the same as B -> A), but standard layers like GCNConv are inherently undirected (or symmetric).

2- Modularity vs. Parameter Sharing: An SCM has a unique causal mechanism (f_ij) for each parent-child relationship, while a GNN shares its message-passing weights (ψ) across all edges for scalability.


Therefore to improve causal fidelity we need to do build a GNN layer that is both directional and supports per-edge causal mechanisms, directly addressing the trade-off between common GNNs and SCMs. 

We start by implementing a simple directional message-passing layer


In [3]:
class DirectedGNNLayer(MessagePassing):
    """
    A minimal directed message passing layer where information flows from
    source nodes (parents) to target nodes (children).
    
    It uses a shared MLP for the message function (ψ) and another for the update function (φ).
    """
    def __init__(self, in_dim, hidden_dim, out_dim):
        # flow='source_to_target' is the key argument that enforces directionality.

        super().__init__(aggr='add', flow='source_to_target')

        # ψ (psi): The message function. 
        # It computes a message based on the parent's features.
        self.msg_mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
        
        # φ (phi): The update function. 
        # It updates a node's representation by combining its original features (x_i) with the aggregated messages from all its parents.
        self.update_mlp = nn.Sequential(
            nn.Linear(in_dim + out_dim, hidden_dim), # Takes concatenated [child, aggregated_messages]
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def message(self, x_j, x_i):
        """
        Defines the message from a source node j (parent) to a target node i (child).
        x_j is the parent's feature tensor, x_i is the child's.
        """
        return self.msg_mlp(x_j)

    def forward(self, x, edge_index):
        """ The main propagation method. """
        # special to MessagePassing, self.propagate will call message() for each edge and aggregate() for each node.
        aggregated_messages = self.propagate(edge_index, x=x)
        
        # Combine the node's original state with the messages from its parents.
        updated_embedding = self.update_mlp(torch.cat([x, aggregated_messages], dim=-1))
        return updated_embedding

SCM's structural equations `x_i := f(pa(x_i), U_i)` can have a unique function `f` for each distinct causal relationship. A standard GNN uses one `msg_mlp` for all edges.

We will create a GNN layer that can operate in two modes:
1.  **`'shared'`**: Behaves like a normal GNN (efficient but less causally faithful).
2.  **`'per_edge'`**: Instantiates a unique MLP for every single edge in the graph, (less scalable and requires more training data but causally faithful to SCMs and NCMs).

In [4]:
class EdgeWiseGNNLayer(MessagePassing):
    """
    A GNN layer that resolves the SCM modularity vs. GNN parameter sharing trade-off.
    
    In 'per_edge' mode, it instantiates a unique MLP for each edge, allowing it
    to learn distinct causal mechanisms for each parent-child relationship.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_edges, mode='per_edge'):
        super().__init__(aggr='add', flow='source_to_target')
        
        self.mode = mode
        self.num_edges = num_edges
        self.out_dim = out_dim

        if self.mode == 'per_edge':
            # Create a list of MLPs, one for each edge. This is our f_ij.
            # nn.ModuleList is crucial for PyTorch to recognize these as submodules.
            self.edge_mlps = nn.ModuleList([
                nn.Sequential(nn.Linear(in_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, out_dim)) for i in range(num_edges)
            ])
        elif self.mode == 'shared': 
            # In shared mode, we only have one MLP for all edges.
            self.edge_mlps = nn.Sequential(nn.Linear(in_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, out_dim))
        else:
            raise ValueError("Mode must be 'shared' or 'per_edge'")
        
        
        # The update function φ remains shared across all nodes.
        self.update_mlp = nn.Sequential(
            nn.Linear(in_dim + out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x, edge_index, original_edge_ids=None):
        """
        Executes one full message-passing step.

        Args:
            x (Tensor): The input node features.
            edge_index (Tensor): The connectivity for this pass.
            original_edge_ids (Tensor, optional): The original IDs of the edges in edge_index.
                                                   If None, assumes a full graph pass.
        """
        # If the original edge IDs aren't provided, we are doing a standard observational pass.
        if original_edge_ids is None:
            original_edge_ids = torch.arange(self.num_edges, device=x.device)
        
        # The propagate method will call message() and aggregate(), and its output
        # is the aggregated messages for each node.
        aggr_out = self.propagate(edge_index, x=x, original_edge_ids=original_edge_ids)
        
        # The update() method is then called to combine the aggregated messages with
        # the original node features to produce the final node embeddings.
        return self.update(aggr_out, x)


    def message(self, x_j, original_edge_ids):
        """
        Computes the message from parent (x_j) to child.
        
        Args:
            x_j (Tensor): The feature tensor of the source nodes (parents) for each edge.
            edge_ids (Tensor): A tensor containing the index of each edge, which we use
                               to select the appropriate MLP.
        """
        # This is a bit complex, but powerful. We can't apply all MLPs at once.
        # We create a placeholder for the output messages.

        output_messages = torch.zeros(x_j.size(0), self.out_dim, device=x_j.device)
        if self.mode == 'per_edge':
            for i in range(self.num_edges):
                mask = (original_edge_ids == i)
                if mask.any(): output_messages[mask] = self.edge_mlps[i](x_j[mask])
            return output_messages
        else: 
            return self.edge_mlps(x_j)

    def update(self, aggr_out, x):
        return self.update_mlp(torch.cat([x, aggr_out], dim=-1))


## The GNN-NCM Architecture: A Principled SCM Analogue

First, we define our core model architecture. This consists of two main components:

- EdgeWiseGNNLayer: This is our novel GNN layer that resolves the critical trade-off between SCM modularity and GNN parameter sharing. In 'per_edge' mode, it instantiates a unique MLP for each causal link in the graph, directly modeling the SCM principle that each parent -> child relationship has its own distinct mechanism f_ij.
- GNN_NCM: This is the main model class. It orchestrates the EdgeWiseGNNLayers and explicitly implements other SCM principles, namely the inclusion of exogenous noise U and the do_intervention method for simulating interventions.

In [5]:
class GNN_NCM(nn.Module):
    """
    A Graph Neural Network - Neural Causal Model (GNN-NCM).
    
    This model is designed to mimic a Structural Causal Model (SCM). It learns
    from observational data and can then predict outcomes under interventions
    (do-operations) by performing "graph surgery".
    
    This version uses a causally-faithful `EdgeWiseGNNLayer`.
    """
    def __init__(self, num_features, hidden_dim, out_dim, num_edges, noise_dim=4, gnn_mode="per_edge"):
        
        """
        Initializes the GNN-NCM.

        Args:
            num_features (int): The number of input features for each node.
            hidden_dim (int): The dimensionality of the hidden MLPs inside the GNN layers.
            out_dim (int): The output dimensionality of the GNN layers.
            num_edges (int): ### ANNOTATION ### The total number of edges in the full graph.
                               This is a NEW and ESSENTIAL argument for EdgeWiseGNNLayer
                               to know how many unique MLPs to create.
            noise_dim (int): The dimensionality of the exogenous noise vector 'U'.
            gnn_mode (str): Either 'per_edge' for max causal fidelity or 'shared' for efficiency.
        """


        super().__init__()
        self.noise_dim = noise_dim
        self.num_edges = num_edges
        
        # In an SCM, x_i := f(pa(x_i), U_i). We model this by adding noise to the input.

        input_dim = num_features + noise_dim
        
        # Instead of standard GCNConv we use our layer
        self.conv1 = EdgeWiseGNNLayer(
            in_dim=input_dim, 
            out_dim=hidden_dim, 
            hidden_dim=hidden_dim, # Internal MLP size
            num_edges=num_edges, 
            mode=gnn_mode
        )

        self.conv2 = EdgeWiseGNNLayer(
            in_dim=hidden_dim, # Input is the output of the previous layer
            out_dim=out_dim,   # Final GNN output dimension is out_dim
            hidden_dim=hidden_dim, 
            num_edges=num_edges, 
            mode=gnn_mode
        )

        
        # A final linear layer to produce the output prediction (e.g., a node label probability).
        self.out = nn.Linear(out_dim, 1)


    def forward(self, x, edge_index):
        """
        Performs the standard OBSERVATIONAL forward pass.
        
        This simulates the system where no intervention has been done.

        Args:
            x (Tensor): Node features of shape [num_nodes, num_features].
            edge_index (Tensor): Graph connectivity in with shape [2, num_edges].
            
        Returns:
            Tensor: The output prediction for each node (e.g., logits for classification).
        """

        # 1. Inject Exogenous Noise (U) 
        noise = torch.randn(x.size(0), self.noise_dim, device=x.device)
        x_with_noise = torch.cat([x, noise], dim=1)
        
        # 2. Message Passing
        h = F.relu(self.conv1(x_with_noise, edge_index))
        h = F.dropout(h, p=0.5, training=self.training)
        h = F.relu(self.conv2(h, edge_index))
        
        # 3. Final Prediction
        return self.out(h)

    def do_intervention(self, x, edge_index, intervened_nodes, new_feature_values):
        """
        Performs an INTERVENTIONAL forward pass, simulating a do-operation.
        
        This is the core of the causal inference capability. In the SCM framework, a
        do-operation `do(X_i = v)` means we replace the mechanism that generates X_i
        with a constant value 'v', severing the influence of its parents.
        
        Args:
            x (Tensor): The ORIGINAL node features before intervention.
            edge_index (Tensor): The ORIGINAL graph structure.
            intervened_nodes (Tensor): A 1D tensor of node indices to intervene on.
            new_feature_values (Tensor): A tensor of new feature values to clamp onto the
                                         intervened nodes. Shape: [num_intervened_nodes, num_features].
        
        Returns:
            Tensor: The post-intervention predictions for all nodes in the graph.
        """
       
        
        # Step 1: Clamp Node Features
        x_intervened = x.clone()
        x_intervened[intervened_nodes] = new_feature_values


        # Step 2: Perform "Graph Mutilation"
        # We sever all causal links (edges) pointing into the intervened node.
        edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool, device=x.device)
        for node_idx in intervened_nodes:
            edge_mask &= (edge_index[1] != node_idx)
        
        intervened_edge_index = edge_index[:, edge_mask]


        # Step 3: Run the forward pass on the surgically-modified graph
        # We cannot simply call forward since EdgeWiseGNNLayer was initialized knowing about num_edges.

        all_edge_ids = torch.arange(self.num_edges, device=x.device)
        intervened_edge_ids = all_edge_ids[edge_mask]
        
        # Manually perform the forward pass on the modified graph structure
        noise = torch.randn(x.size(0), self.noise_dim, device=x.device)
        x_with_noise = torch.cat([x_intervened, noise], dim=1)
        
        # Pass through Layer 1
        h1 = self.conv1(x_with_noise, intervened_edge_index, original_edge_ids=intervened_edge_ids)
        h1 = F.relu(h1)
        
        # Pass through Layer 2
        h2 = self.conv2(h1, intervened_edge_index, original_edge_ids=intervened_edge_ids)
        h2 = F.relu(h2)
        
        return self.out(h2)


In [6]:
# --- Example Usage ---

# Assume you have a graph:
num_nodes = 10
num_features = 5
num_edges = 25 # You must know this value beforehand

# Define model dimensions
hidden_dim = 32
out_dim = 16

# Instantiate the new, more powerful model
causal_model = GNN_NCM(
    num_features=num_features,
    hidden_dim=hidden_dim,
    out_dim=out_dim,
    num_edges=num_edges,
    gnn_mode='per_edge' # or 'shared'
)

print(causal_model)

GNN_NCM(
  (conv1): EdgeWiseGNNLayer()
  (conv2): EdgeWiseGNNLayer()
  (out): Linear(in_features=16, out_features=1, bias=True)
)


### A Review of Causal Training Strategies

Our GNN-NCM has the architectural capacity for causality, but this is useless unless it is trained in a way that respects causal principles. Since we assume we only have access to observational data, we need a clever training algorithm. Let's review the main families of approaches.

* Strategy 1: Supervised Interventional Training (The Ideal Case)

    As implemented, if we are lucky enough to have a dataset of (pre-state, intervention, post-state) tuples, we can directly supervise the do_intervention method.

* Strategy 2: Interventional Synthesis 

    This is the approach pioneered by CXGNN and is the most direct fit for our SCM-like model. As detailed below, it uses the model's own interventional predictions as a way to self-supervise and derive a causally-plausible estimate of the observational outcome.

In [7]:
def train_with_interventional_data(model, dataset, epochs=50, lr=0.001):
    """
    Trains the GNN-NCM by directly supervising its `do_intervention` method.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    
    print("Starting supervised interventional training...")
    for epoch in range(epochs):
        total_loss = 0
        for pre_data, intervention, post_data_true in dataset:
            optimizer.zero_grad()
            
            # Use the model's causal mechanism to predict the outcome
            post_data_pred_x = model.do_intervention(
                x=pre_data.x,
                edge_index=pre_data.edge_index,
                intervened_nodes=intervention["node_idx"],
                new_feature_values=intervention["new_value"]
            )
            
            # The loss is how far our prediction is from the real outcome
            loss = loss_fn(post_data_pred_x, post_data_true.x)
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:02d}, Average Interventional Loss: {total_loss / len(dataset):.6f}")
    print("Training finished.")


In [8]:
def generate_ideal_interventional_data(num_samples=400):
    """
    Creates a synthetic dataset of (pre-state, intervention, post-state) tuples.
    We define a simple causal law: a child's feature is 0.8 * parent_feature + noise.
    """
    print("Generating ideal interventional dataset...")
    # Define a fixed causal graph structure
    edge_index = torch.tensor([[0, 0, 1], [1, 2, 3]], dtype=torch.long) # 0->1, 0->2, 1->3
    num_nodes = 4
    num_features = 1

    dataset = []
    for _ in range(num_samples):
        # 1. Create a random "pre-intervention" state
        pre_x = torch.randn(num_nodes, num_features)
        
        # 2. Define a random intervention
        intervened_node = np.random.randint(0, num_nodes)
        new_value = torch.randn(1, num_features)
        
        # 3. Calculate the true "post-intervention" state based on our known causal law
        post_x = pre_x.clone()
        post_x[intervened_node] = new_value # Clamp the intervened node's value
        
        # Propagate the true effects
        for src, dst in edge_index.t().tolist():
            if dst != intervened_node: # The effect doesn't apply if the child was intervened on
                post_x[dst] = post_x[src] * 0.8 + torch.randn(1) * 0.1 # The ground truth causal law
        
        intervention_info = {
            "node_idx": torch.tensor([intervened_node]),
            "new_value": new_value
        }
        
        dataset.append((Data(x=pre_x, edge_index=edge_index), intervention_info, Data(x=post_x)))
        
    return dataset, edge_index.size(1)


In [9]:
ideal_dataset, num_edges = generate_ideal_interventional_data(num_samples=500)

# Instantiate the model
# It needs to know the total number of unique edges to build its MLPs
model_ideal = GNN_NCM(
    num_features=1,
    hidden_dim=32,
    out_dim=16,
    num_edges=num_edges,
    gnn_mode='per_edge'
)

# Train the model using the supervised approach
#train_with_interventional_data(model_ideal, ideal_dataset, epochs=70)

# --- Verification ---
# Let's test if the trained model learned the causal law
print("\n--- Verifying the learned causal model ---")
model_ideal.eval()
with torch.no_grad():
    # Take a sample from our dataset
    pre_data, intervention, post_data_true = ideal_dataset[0]
    
    # Get the model's prediction
    post_data_pred = model_ideal.do_intervention(
        pre_data.x, pre_data.edge_index, intervention['node_idx'], intervention['new_value']
    )
    
    print(f"Intervention: Set Node {intervention['node_idx'].item()} to {intervention['new_value'].item():.2f}")
    print("\nNode | True Outcome | Predicted Outcome")
    print("---------------------------------------")
    for i in range(len(post_data_true.x)):
        print(f"  {i}  |    {post_data_true.x[i].item():.2f}    |      {post_data_pred[i].item():.2f}")

Generating ideal interventional dataset...

--- Verifying the learned causal model ---
Intervention: Set Node 2 to -0.13

Node | True Outcome | Predicted Outcome
---------------------------------------
  0  |    -0.32    |      -0.16
  1  |    -0.35    |      -0.15
  2  |    -0.13    |      -0.16
  3  |    -0.25    |      -0.16


## Hybrid Causal Trainer

Our chosen training method is Hybrid Causal Trainer which will focus on Interventional Synthesis Principle capturing the two key actions involved:

1. Intervention: We perform a series of hypothetical do-interventions on a node's parents.
2. Synthesis: We then synthesize these individual, hypothetical outcomes (by averaging them) to reconstruct or explain the factual, observed state of the node.

Here the causal loss component acts as a structural regularizer. Its entire purpose is to optimize the parameters of the causal mechanisms within our model (like the structural assumptions that SCM gives). The process is as follows:

1. For a given node v, we identify its parents, pa(v).
2. We perform a set of "what-if" simulations by calling model.do_intervention(parent) for each parent in pa(v).
3. These calls produce a set of interventional predictions for v.
4. We average these predictions to get the causally_derived_pred.
5. The loss MSE(causally_derived_pred, stable_target) is calculated.
6. Crucially, when loss.backward() is called, the gradient flows backwards through the averaging operation and through each of the do_intervention calls. This means the optimizer is forced to adjust the weights of the per-edge MLPs to make the causally-derived prediction more accurate.

We have two possible targets for our loss function:

* Target A (True Label): The ground truth label is often a "sharp" or "spiky" target (e.g., a binary 0 or 1, or a specific real value). Trying to make the average of multiple noisy outputs match this sharp target creates a chaotic and unstable gradient. The model can be pulled in many conflicting directions at once and may fail to learn.

* Target B (Teacher Prediction): A pre-trained observational GNN (our "teacher") provides a smooth and stable target. Its prediction is a continuous value (e.g., 0.87) that represents a high-quality estimate of the conditional expectation E[Y|X]. Forcing our noisy causal estimate to match this smooth target provides a much gentler, more stable, and more informative gradient, guiding the causal mechanisms to a plausible consensus


The teacher model is not a performance ceiling; it is a stabilizing scaffold. The primary loss_obs (which uses the true labels) is still responsible for pushing the model's overall accuracy, allowing it to surpass the teacher. Use of a teacher in calculating causal loss simply ensures that the model achieves this accuracy is by learning robust, internally consistent mechanisms.

Finally our training will combine two loss functions:

- Observational Loss: A standard supervised loss that compares the model's normal forward() pass prediction directly against the ground truth label (y_true). This anchors the model in reality and ensures it learns to be an accurate predictor.
- Interventional Loss: A structural regularizer with a stable target provided by a pre-trained observational "teacher" model. This forces the model's internal mechanisms to be plausible and consistent.

The final loss, Total Loss = Loss_obs + γ * Loss_causal, pushes the model to be accurate while ensuring its internal reasoning is causal, making it more robust and generalizable.

In [10]:
class TeacherGNN(nn.Module):
    """A standard GNN to act as an observational teacher."""
    def __init__(self, num_features, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        self.out = nn.Linear(out_dim, 1)
    def forward(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        return self.out(h)
        
class HybridCausalTrainer:
    """Trains a GNN-NCM using a hybrid loss for performance and causal consistency."""
    def __init__(self, epochs=200, lr=0.01, gamma=0.5):
        self.epochs = epochs; self.lr = lr; self.gamma = gamma; self.loss_fn = nn.MSELoss()

    def train(self, model_to_train, graph_data):
        optimizer = optim.Adam(model_to_train.parameters(), lr=self.lr)
        
        # Pre-train a simple GCN "teacher" to provide stable observational targets
        teacher_model = TeacherGNN(graph_data.num_features, 16, 8).to(graph_data.x.device)
        optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=0.01)
        print("--- Pre-training Teacher Model ---")
        for _ in range(100):
            out = teacher_model(graph_data.x, graph_data.edge_index)
            loss = F.mse_loss(out, graph_data.y); loss.backward(); optimizer_teacher.step()
        teacher_model.eval()
        print("Teacher model trained.")

        print("\n--- Starting Hybrid Causal Training ---")
        for epoch in range(self.epochs):
            model_to_train.train(); optimizer.zero_grad()
            
            # --- 1. Observational Loss (Direct-to-Label) ---
            obs_preds = model_to_train(graph_data.x, graph_data.edge_index)
            loss_obs = self.loss_fn(obs_preds, graph_data.y)
            
            # --- 2. Causal Consistency Loss (Structural Regularizer) ---
            loss_causal = 0
            with torch.no_grad():
                teacher_preds = teacher_model(graph_data.x, graph_data.edge_index).detach()
            
            num_causal_nodes = 0
            for v_idx in range(graph_data.num_nodes):
                parents = graph_data.edge_index[0][graph_data.edge_index[1] == v_idx]
                if len(parents) > 0:
                    num_causal_nodes += 1
                    interventional_preds = []
                    for parent_idx in parents:
                        pred = model_to_train.do_intervention(
                            graph_data.x, graph_data.edge_index,
                            intervened_nodes=torch.tensor([parent_idx]),
                            new_feature_values=torch.zeros(1, graph_data.num_features) 
                            # we do it with zeros so it's a neutral baseline so 
                            # it represents a no information state
                        )
                        interventional_preds.append(pred[v_idx])
                    
                    causally_derived_pred = torch.stack(interventional_preds).mean(dim=0)
                    loss_causal += self.loss_fn(causally_derived_pred, teacher_preds[v_idx])

            if num_causal_nodes > 0:
                loss_causal /= num_causal_nodes
            
            # --- 3. Combined Loss ---
            total_loss = loss_obs + self.gamma * loss_causal
            total_loss.backward()
            optimizer.step()

            if (epoch + 1) % 20 == 0:
                print(f"Epoch {epoch+1:03d} | Total Loss: {total_loss.item():.4f} "
                      f"(Obs: {loss_obs.item():.4f}, Causal: {loss_causal.item():.4f})")
        print("Training finished.")


### Analysis on Synthetic Data

In [None]:
# Define the causal graph (0->1, 0->2. Node 3 is an independent root node)
x = torch.randn(4, 4)
edge_index = torch.tensor([[0, 0], [1, 2]], dtype=torch.long) # 0->1 and 0->2
y = torch.zeros(4, 1)
# Define ground truth such that y is a function of the MEAN of its parents' features
y[0] = -0.5 + torch.randn(1) * 0.1 # Root node
y[1] = x[0].mean() * 0.8 + 0.1 + torch.randn(1) * 0.1 # Child of node 0
y[2] = torch.sin(x[0].mean()) - 0.2 + torch.randn(1) * 0.1 # Child of node 0
y[3] = 0.5 + torch.randn(1) * 0.1 # Independent root node
full_graph = Data(x=x, edge_index=edge_index, y=y)
num_edges = full_graph.edge_index.size(1)

# Instantiate our GNN-NCM
causal_model = GNN_NCM(num_features=4, hidden_dim=16, out_dim=8, num_edges=num_edges)

# Instantiate and run the hybrid causal trainer
trainer = HybridCausalTrainer(epochs=200, lr=0.01, gamma=0.5)
trainer.train(causal_model, full_graph)

# --- Verification ---
print("\n--- Verifying the Trained Causal Model ---")
causal_model.eval()
with torch.no_grad():
    obs_pred = causal_model(full_graph.x, full_graph.edge_index)
    interv_pred = causal_model.do_intervention(
        full_graph.x, full_graph.edge_index,
        intervened_nodes=torch.tensor([0]), # Intervene on the parent
        new_feature_values=torch.full((1, 4), 10.0) # A strong, unseen shock
    )

    print("\nNode | Observational Pred | Post-Intervention Pred | Change")
    print("---------------------------------------------------------------")
    for i in range(full_graph.num_nodes):
        change = interv_pred[i].item() - obs_pred[i].item()
        print(f"  {i}  |      {obs_pred[i].item():.4f}      |       {interv_pred[i].item():.4f}       | {change:+.4f}")

print("\nVERIFICATION RESULT:")
print("Observe that the intervention on the parent (Node 0) caused a LARGE change in its children (Nodes 1 and 2).")
print("Crucially, the change in the independent root node (Node 3) is now NEAR ZERO.")
print("This confirms the model has learned the correct, directed causal structure.")


ModuleNotFoundError: No module named 'src.trainer'