# Synthetic SCM: Generation, Interventions, and ATE Benchmarks

**Goal.** Create a controlled 7-node synthetic dataset with the **same topology** as the finance DAG. We will:
- define **structural equations** with noise,
- generate observational samples,
- generate **true interventional** samples with `do(·)`,
- compute ground-truth **ATE**s per edge for benchmarking models.


In [None]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pathlib import Path
import networkx as nx
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random
import yaml

# PyTorch Geometric imports
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader


# CONFIGURATION & PATHS
from pathlib import Path
import sys

# Repository root = parent of the current working directory.
# NOTE(review): assumes the notebook is run from a direct sub-folder of the repo - confirm.
PROJECT_ROOT = Path().resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    # Make the project-local `src` package importable.
    sys.path.insert(0, str(PROJECT_ROOT))


CONFIGS_DIR = PROJECT_ROOT / "configs"

CFG_PATH = CONFIGS_DIR / "best_config.yaml"
with open(CFG_PATH, "r", encoding="utf-8") as f:
    # safe_load returns None for an empty document; fall back to an empty
    # mapping so the cfg.get(..., default) lookups below keep working.
    cfg = yaml.safe_load(f) or {}

seed = int(cfg.get("seed", 42))
# Seed every RNG this notebook draws from (Python, NumPy, PyTorch).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

from src.models import GNN_NCM, BaselineGCN
from src.trainer import CausalTwoPartTrainer 


## 1. Structural Equations

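Most mechanisms are linear in the mean feature of each parent; `BAS` additionally passes its parent sum through a `tanh` to mimic saturation effects in market microstructure: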

In [None]:
# Define graph structure
# Node names in index order; the position in this list is the node id.
nodes = ["Mom", "HML", "OI", "PC", "BAS", "LIQ", "VOL"]
node_mapping = dict(zip(nodes, range(len(nodes))))

# Directed edges (parent, child) of the 7-node DAG.
edges = [
    ('Mom', 'HML'),
    ('Mom', 'PC'),
    ('HML', 'OI'),
    ('OI', 'PC'),
    ('OI', 'BAS'),
    ('PC', 'BAS'),
    ('BAS', 'LIQ'),
    ('BAS', 'VOL'),
]
# Shape [2, num_edges]: row 0 holds source ids, row 1 holds target ids.
edge_index = torch.tensor(
    [[node_mapping[name] for name in column] for column in zip(*edges)],
    dtype=torch.long,
)

# Ground-truth causal laws: child -> {parent: coefficient}.
causal_relationships = {
    'HML': {'Mom': 0.5},
    'OI': {'HML': 0.7},
    'PC': {'Mom': 0.3, 'OI': 0.6},
    'BAS': {'OI': 0.4, 'PC': 0.6},  # the only non-linear mechanism (tanh)
    'LIQ': {'BAS': -0.8},
    'VOL': {'BAS': 0.9},
}
NOISE_STD = 0.1


In [None]:
class SyntheticGraphDataset(Dataset):
    """Synthetic graphs whose node targets follow fixed structural equations.

    Every graph shares the same ``edge_index``. Node features are i.i.d.
    standard-normal draws. The target of each child node is a weighted sum of
    the *mean feature* of its parents (``tanh``-squashed for ``'BAS'``) plus
    Gaussian noise; nodes without an entry in ``causal_relationships`` keep a
    target of 0.

    Both the features and the label noise are drawn once, in ``__init__``,
    from a generator seeded with ``seed``. Indexing the same item twice
    therefore returns identical ``x`` and ``y``, independent of the global
    torch RNG state.

    Args:
        num_graphs: number of graphs in the dataset.
        num_nodes: nodes per graph.
        feature_dim: features per node.
        edge_index: edge tensor attached unchanged to every sample.
        causal_relationships: mapping ``child -> {parent: coefficient}``.
        node_mapping: mapping ``node name -> node index``.
        noise_std: standard deviation of the additive label noise.
        seed: seed of the private generator used for features and noise.
    """

    def __init__(self, num_graphs=50, num_nodes=7, feature_dim=4,
                 edge_index=None, causal_relationships=None,
                 node_mapping=None, noise_std=0.1, seed=0):
        super().__init__()
        self.num_graphs = num_graphs
        self.num_nodes = num_nodes
        self.feature_dim = feature_dim
        self.edge_index = edge_index
        self.causal_relationships = causal_relationships
        self.node_mapping = node_mapping
        self.noise_std = noise_std

        g = torch.Generator().manual_seed(seed)
        self._xs = torch.randn(num_graphs, num_nodes, feature_dim, generator=g)
        # Drawn AFTER the features so that `_xs` is the same as before for a
        # given seed. Pre-drawing the noise makes labels reproducible instead
        # of being re-sampled from the global RNG on every access.
        self._eps = torch.randn(num_graphs, num_nodes, 1, generator=g) * noise_std

    def __len__(self):
        return self.num_graphs

    def __getitem__(self, idx):
        """Return graph ``idx`` as a ``Data`` object with ``x``, ``edge_index``, ``y``."""
        x = self._xs[idx].clone()
        y = torch.zeros(self.num_nodes, 1)

        # A missing relationship table simply yields all-zero targets.
        relationships = self.causal_relationships or {}
        for node_name, parents in relationships.items():
            child_idx = self.node_mapping[node_name]
            parent_effects = 0.0
            for parent_name, coeff in parents.items():
                parent_idx = self.node_mapping[parent_name]
                # Each parent contributes through the mean of its feature row.
                parent_effects += x[parent_idx].mean() * coeff
            if node_name == 'BAS':
                parent_effects = torch.tanh(parent_effects)
            y[child_idx] = parent_effects + self._eps[idx, child_idx]

        return Data(x=x, edge_index=self.edge_index, y=y)


# Observational dataset: 100 graphs, 7 nodes each, one scalar feature per node.
ds = SyntheticGraphDataset(
    100, 7, 1,
    edge_index=edge_index,
    causal_relationships=causal_relationships,
    node_mapping=node_mapping,
    noise_std=NOISE_STD,
    seed=123,
)

print(len(ds))
print(ds[0])  # first synthetic graph


## 2. Training and Evaluating on Observational Dataset

In [None]:
from torch.utils.data import random_split

# Model hyper-parameters (config value, else default), all cast to int.
mcfg = cfg.get("model", {})
best_params = {
    key: int(mcfg.get(key, default))
    for key, default in (("hidden_dim", 32), ("out_dim", 16), ("noise_dim", 4))
}

# Trainer hyper-parameters (config value, else default).
tcfg = cfg.get("training", {})
trainer_kwargs = {
    "epochs_obs": int(tcfg.get("epochs_obs", 40)),
    "epochs_do": int(tcfg.get("epochs_do", 20)),
    "lr": float(tcfg.get("lr", 1e-2)),
    "w_obs": float(tcfg.get("w_obs", 0.2)),
    "w_do": float(tcfg.get("w_do", 1.0)),
    "weight_decay": float(tcfg.get("weight_decay", 1e-4)),
    "clip": float(tcfg.get("clip", 1.0)),
    "neutral": tcfg.get("neutral", "zeros"),
    "delta": float(tcfg.get("delta", 0.1)),
}

# 80/20 train/validation split with a fixed split seed.
n_train = int(0.8 * len(ds))
n_val = len(ds) - n_train
split_rng = torch.Generator().manual_seed(42)
ds_train, ds_val = random_split(ds, [n_train, n_val], generator=split_rng)

train_loader = DataLoader(ds_train, batch_size=1, shuffle=True)  # keep 1 if graphs vary
val_loader = DataLoader(ds_val, batch_size=1, shuffle=False)

models = {}
# One sample batch is enough to read the static graph dimensions.
g0 = next(iter(train_loader))
num_features = g0.num_features
num_edges = g0.edge_index.size(1)
device = g0.x.device


# --- Train all three models ---

# 1. GNN-NCM (per-edge)
print("\n--- Training GNN-NCM (per-edge) ---")
model_per_edge = GNN_NCM(num_features=num_features, num_edges=num_edges, gnn_mode='per_edge', **best_params)
trainer_per_edge = CausalTwoPartTrainer(**trainer_kwargs)

trainer_per_edge.train(model_per_edge, train_loader, val_loader=val_loader)
val_loss_pe = trainer_per_edge.evaluate_obs_mse(model_per_edge, val_loader)
print(f"[per-edge] val_mse={val_loss_pe:.6f}")
models['GNN-NCM (per-edge)'] = model_per_edge

# 2. GNN-NCM (shared) - Ablation
print("\n--- Training GNN-NCM (shared) ---")
model_shared = GNN_NCM(num_features=num_features, num_edges=num_edges, gnn_mode='shared', **best_params)
trainer_shared = CausalTwoPartTrainer(**trainer_kwargs)

trainer_shared.train(model_shared, train_loader, val_loader=val_loader)
# Stored under its own name so the per-edge validation loss is not overwritten.
val_loss_shared = trainer_shared.evaluate_obs_mse(model_shared, val_loader)
print(f"[shared] val_mse={val_loss_shared:.6f}")
models['GNN-NCM (shared)'] = model_shared


# 3. Baseline GCN - A standard, non-causal GNN
print("\n--- Training BaselineGCN ---")
baseline_model = BaselineGCN(num_features=num_features, hidden_dim=best_params['hidden_dim'], out_dim=best_params['out_dim'])
optimizer_baseline = optim.Adam(baseline_model.parameters())
loss_fn_baseline = nn.MSELoss()

# Nothing switches the module to eval mode inside the loop, so setting
# train mode once up front is sufficient.
baseline_model.train()
for ep in range(200):
    loss = None  # stays None if the loader yields no batch
    for g in train_loader:
        g = g.to(device)
        optimizer_baseline.zero_grad()
        pred = baseline_model(g.x, g.edge_index)
        loss = loss_fn_baseline(pred, g.y)
        loss.backward()
        optimizer_baseline.step()
    # The logged value is the loss of the last batch of the epoch.
    if loss is not None and (ep + 1) % 40 == 0:
        print(f"[Baseline] ep {ep+1:03d} loss={loss.item():.4f}")

# quick baseline val: mean over graphs of the per-graph MSE
baseline_model.eval()
with torch.no_grad():
    total = 0.0
    n = 0
    for g in val_loader:
        g = g.to(device)
        total += float(((baseline_model(g.x, g.edge_index) - g.y) ** 2).mean().item())
        n += 1
val_loss_bl = total / max(n, 1)
print(f"[baseline] val_mse={val_loss_bl:.6f}")


models['BaselineGCN'] = baseline_model


### 2.1 Analysis of Causal Fidelity (ATE Recovery)

This is the ultimate test of our model. We will calculate the **true** Average Treatment Effect (ATE) of an intervention directly from our synthetic world's rules. Then, we will ask each model to **estimate** the ATE using its learned mechanisms. The model that gets closest to the true ATE is the one that has best learned the underlying causal structure.

We will test the ATE of `do(BAS = BAS + 1)` on `VOL`.

In [None]:
# Sanity probe: does the per-edge model react at all to do(BAS += 1)?
with torch.no_grad():
    device = g0.x.device
    bas_idx = node_mapping['BAS']
    # Factual prediction.
    y0 = models['GNN-NCM (per-edge)'](g0.x, g0.edge_index)
    # Prediction with the BAS feature row shifted by +1.
    y1 = models['GNN-NCM (per-edge)'].do_intervention(
        g0.x,
        g0.edge_index,
        intervened_nodes=torch.tensor([bas_idx], dtype=torch.long, device=device),
        new_feature_values=(g0.x[bas_idx] + 1.0)[None],
    )
    # Total absolute change over all outputs.
    diff = torch.sum(torch.abs(y1 - y0)).item()
    print("Δ after do(BAS+=1):", diff)  # if this prints 0.0, it means model ignores the intervention

In [None]:
# Simple ATE on BAS (target: VOL)
import torch, pandas as pd, matplotlib.pyplot as plt
import torch.nn.functional as F

device = g0.x.device
bas_idx, vol_idx = node_mapping['BAS'], node_mapping['VOL']
delta = 1.0


def mech_eval(x):
    """Evaluate the noise-free structural equations on a feature matrix.

    For every child in ``causal_relationships`` the target is the weighted sum
    of the mean feature of each parent row of ``x`` (``tanh``-squashed for
    ``'BAS'``). The result has the shape of ``g0.y``; nodes without an entry
    stay at zero.
    """
    out = torch.zeros_like(g0.y)
    for child, parents in causal_relationships.items():
        total = sum(
            (x[node_mapping[parent]].mean() * weight for parent, weight in parents.items()),
            0.0,
        )
        out[node_mapping[child]] = torch.tanh(total) if child == 'BAS' else total
    return out

# Ground truth (noise-free): shift the whole BAS feature row by `delta`
# and re-evaluate the mechanisms.
x_before = g0.x.clone()
x_after = x_before.clone()
x_after[bas_idx, :] += delta
y_before_true = mech_eval(x_before)
y_after_true = mech_eval(x_after)
true_ate = y_after_true - y_before_true

rows = []
for name, net in models.items():
    net.eval()
    with torch.no_grad():
        preds_before = net(g0.x, g0.edge_index)
        if "GNN-NCM" not in name:
            # Non-causal baseline: just feed the shifted features.
            x_obs = g0.x.clone()
            x_obs[bas_idx, :] += delta
            preds_after = net(x_obs, g0.edge_index)
        else:
            # Causal models answer through their do-operator.
            preds_after = net.do_intervention(
                g0.x,
                g0.edge_index,
                intervened_nodes=torch.tensor([bas_idx], dtype=torch.long, device=device),
                new_feature_values=(g0.x[bas_idx] + delta)[None],
            )
        est_ate = preds_after - preds_before

    rows.append({
        'model': name,
        'Estimated ATE (VOL)': est_ate[vol_idx].item(),
        'True ATE (VOL)': true_ate[vol_idx].item(),
    })

ate_df = pd.DataFrame(rows).set_index('model')
print("\n--- ATE Recovery (do(BAS = +1.0) → VOL) ---")
print(ate_df)

ax = ate_df.plot(
    kind='bar',
    rot=0,
    figsize=(10, 6),
    title="ATE Recovery for do(BAS = +1.0) on VOL",
)
ax.grid(axis='y', linestyle='--')
ax.set_ylabel("Average Treatment Effect (ATE)")
ax.set_xlabel("Model Type")
plt.tight_layout()
plt.show()


## 3. Interventional Data

Pick a parent \(p\), **replace** its equation by \(p:=x^\*\), keep all other mechanisms unchanged (graph surgery), then sample downstream. The model is then trained on this interventional dataset with supervised interventional learning.


In [None]:
# helper functions
import torch

def build_sem_mats(causal_relationships, node_mapping, device, dtype):
    """Pack the structural coefficients into a dense weight matrix.

    Args:
        causal_relationships: mapping ``child -> {parent: coefficient}``.
        node_mapping: mapping ``node name -> node index``.
        device: device of the returned matrix.
        dtype: dtype of the returned matrix.

    Returns:
        ``(W, bas_idx)`` where ``W[child, parent]`` holds the coefficient
        (zero elsewhere) and ``bas_idx`` is the index of the ``'BAS'`` node,
        or ``None`` if ``node_mapping`` has no such node.
    """
    n_nodes = len(node_mapping)
    weights = torch.zeros((n_nodes, n_nodes), device=device, dtype=dtype)
    for child_name, parent_weights in causal_relationships.items():
        row = node_mapping[child_name]
        for parent_name, coeff in parent_weights.items():
            weights[row, node_mapping[parent_name]] = float(coeff)
    # 'BAS' is the node that receives the tanh non-linearity downstream.
    return weights, node_mapping.get('BAS')

def sem_eval_vec(v, W, bas_idx=None):
    """Apply the structural equations to per-node scalar inputs.

    Args:
        v: either a vector with one value per node or a 2-D batch of such
            vectors (one per row).
        W: weight matrix with ``W[child, parent]`` coefficients.
        bas_idx: index of the node whose output is passed through ``tanh``;
            ``None`` disables the non-linearity.

    Returns:
        The node outputs with a trailing singleton dimension appended.
    """
    if v.ndim == 2:
        out = v @ W.T
    else:
        out = W @ v
    if bas_idx is not None:
        if out.ndim == 1:
            out[bas_idx] = torch.tanh(out[bas_idx])
        else:
            out[:, bas_idx] = torch.tanh(out[:, bas_idx])
    return out.unsqueeze(-1)


In [None]:
# InterventionalDataset
from torch_geometric.data import Dataset, Data

class InterventionalDataset(Dataset):
    """Interventional samples derived from an observational base dataset.

    For every base graph and every node in ``node_mapping`` the dataset holds
    ``per_graph_per_node`` samples; sample ``k`` shifts the node by
    ``deltas[k % len(deltas)]``. Each sample carries the unmodified base graph
    (``x``, ``edge_index``, ``y``) plus:

    * ``intervene_nodes`` - index of the intervened node,
    * ``new_feature_values`` - that node's feature row shifted by the delta,
    * ``y_do`` - noise-free targets from ``sem_eval_vec`` after shifting the
      node's mean feature by the delta.
    """

    def __init__(self,
                 base_dataset,
                 node_mapping,
                 causal_relationships,
                 deltas=(+1.0, -1.0),
                 per_graph_per_node=2,
                 transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root=None, transform=transform,
                         pre_transform=pre_transform, pre_filter=pre_filter)
        self.base = base_dataset
        self.node_mapping = node_mapping
        self.deltas = list(deltas)
        self.per_graph_per_node = int(per_graph_per_node)

        # The first sample fixes device/dtype of the coefficient matrix.
        first = self.base[0]
        self.W, self.bas_idx = build_sem_mats(
            causal_relationships, node_mapping,
            device=first.x.device, dtype=first.x.dtype,
        )

        n_graphs = len(self.base)
        # Per-graph vector of node-wise mean features, shape [N].
        self._v = [self.base[gi].x.mean(dim=1) for gi in range(n_graphs)]

        # Flat list of (graph id, node id, delta) describing every sample.
        self._index = [
            (gi, int(nidx), float(self.deltas[k % len(self.deltas)]))
            for gi in range(n_graphs)
            for nidx in node_mapping.values()
            for k in range(self.per_graph_per_node)
        ]

        # (graph id, node id, delta) -> (y_do, new feature row)
        self._cache = {}

    def len(self):
        return len(self._index)

    def __len__(self):
        return self.len()

    def get(self, i):
        """Build the ``Data`` object for sample ``i`` (labels are memoised)."""
        gi, nidx, delta = self._index[i]
        g = self.base[gi]
        key = (gi, nidx, delta)

        cached = self._cache.get(key)
        if cached is None:
            shifted = self._v[gi].clone()
            shifted[nidx] = shifted[nidx] + delta
            cached = (
                sem_eval_vec(shifted, self.W, self.bas_idx),
                (g.x[nidx] + delta).unsqueeze(0),
            )
            self._cache[key] = cached
        y_do, new_x = cached

        dev = g.x.device
        return Data(
            x=g.x, edge_index=g.edge_index, y=g.y, num_nodes=g.num_nodes,
            intervene_nodes=torch.tensor([nidx], dtype=torch.long, device=dev),
            new_feature_values=new_x.to(dev),
            y_do=y_do.to(dev),  # supervised do-labels (synthetic ground truth)
        )

    def __getitem__(self, idx):
        sample = self.get(idx)
        return sample if self.transform is None else self.transform(sample)


In [None]:
import torch, torch.nn as nn, torch.optim as optim

class InterventionalTrainer:
    """Two-phase trainer: observational warm-up, then joint obs + do training.

    Phase 1 fits the model on observational targets (``g.y``) only. Phase 2
    creates a fresh optimizer and minimises ``w_obs * MSE_obs + w_do * MSE_do``
    where the interventional term compares ``model.do_intervention(...)``
    against the supervised labels ``g.y_do``.

    Per-epoch averages are appended to ``self.history`` as dicts with the keys
    ``epoch``, ``phase``, ``loss_obs``, ``loss_do`` and ``loss_total``.
    """
    def __init__(self,
                 epochs_obs=30,
                 epochs_do=50,
                 lr=1e-2,
                 w_obs=0.2,
                 w_do=1.0,
                 weight_decay=1e-4,
                 clip=1.0):
        """Store the hyper-parameters.

        Args:
            epochs_obs: number of phase-1 (observational only) epochs.
            epochs_do: number of phase-2 (observational + interventional) epochs.
            lr: AdamW learning rate, used in both phases.
            w_obs: weight of the observational loss in phase 2.
            w_do: weight of the interventional loss in phase 2.
            weight_decay: AdamW weight decay.
            clip: max gradient norm; a falsy value (0) disables clipping.
        """
        self.epochs_obs = int(epochs_obs)
        self.epochs_do  = int(epochs_do)
        self.lr  = float(lr)
        self.w_obs = float(w_obs)
        self.w_do  = float(w_do)
        self.wd  = float(weight_decay)
        self.clip = float(clip)
        self.loss = nn.MSELoss()
        self.history = []

    def train(self, model, obs_loader, do_loader):
        """Run both training phases and return the trained model.

        Args:
            model: module called as ``model(x, edge_index)`` and
                ``model.do_intervention(x, edge_index, intervened_nodes=...,
                new_feature_values=...)``. Batches are moved to the device of
                its first parameter.
            obs_loader: iterable of batches exposing ``x``, ``edge_index``, ``y``.
            do_loader: iterable of batches that additionally expose
                ``intervene_nodes``, ``new_feature_values`` and ``y_do``.

        Returns:
            The same ``model`` object, trained in place.
        """
        # Raises StopIteration if the model has no parameters.
        dev = next(model.parameters()).device
        model = model.to(dev)

        # Phase 1: observational warm-up (obs only)
        opt = optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.wd)
        for ep in range(1, self.epochs_obs + 1):
            model.train()
            obs_sum, n_obs = 0.0, 0
            for g in obs_loader:
                g = g.to(dev)
                pred = model(g.x, g.edge_index)
                l_obs = self.loss(pred, g.y)

                opt.zero_grad()
                l_obs.backward()
                if self.clip: torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip)
                opt.step()

                obs_sum += float(l_obs.detach()); n_obs += 1

            # Mean per-batch loss of this epoch. Note that in this phase
            # "loss_total" is the UNWEIGHTED observational loss.
            m_obs = obs_sum / max(n_obs, 1)
            self.history.append({"epoch": ep, "phase": "obs", "loss_obs": m_obs, "loss_do": None, "loss_total": m_obs})
            if ep % 10 == 0:
                print(f"[obs {ep:03d}] obs={m_obs:.6f}")

        # Phase 2: obs + supervised do (one combined step per batch)
        # reset optimizer to avoid stale momentum from Phase 1
        opt = optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.wd)

        for ep in range(1, self.epochs_do + 1):
            model.train()
            obs_sum, do_sum, n_obs, n_do = 0.0, 0.0, 0, 0

            it_obs = iter(obs_loader)
            it_do  = iter(do_loader)

            # Draw one batch from each loader per step. The loop only stops
            # once BOTH loaders are exhausted, so when their lengths differ the
            # remaining steps optimise a single loss term.
            # NOTE(review): confirm that the shorter loader should not be cycled.
            while True:
                g_obs = next(it_obs, None)
                g_do  = next(it_do,  None)
                if g_obs is None and g_do is None:
                    break

                loss_terms = []
                # observational term
                if g_obs is not None:
                    g_obs = g_obs.to(dev)
                    p_obs = model(g_obs.x, g_obs.edge_index)
                    l_obs = self.loss(p_obs, g_obs.y)
                    loss_terms.append(self.w_obs * l_obs)
                    obs_sum += float(l_obs.detach()); n_obs += 1

                # interventional supervised term
                if g_do is not None:
                    g_do = g_do.to(dev)
                    p_do = model.do_intervention(
                        g_do.x, g_do.edge_index,
                        intervened_nodes=g_do.intervene_nodes,
                        new_feature_values=g_do.new_feature_values
                    )
                    l_do = self.loss(p_do, g_do.y_do)
                    loss_terms.append(self.w_do * l_do)
                    do_sum += float(l_do.detach()); n_do += 1

                # combine and step once
                total = loss_terms[0] if len(loss_terms)==1 else sum(loss_terms)
                opt.zero_grad()
                total.backward()
                if self.clip: torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip)
                opt.step()

            # Epoch means of the unweighted losses; 0.0 if a loader was empty.
            m_obs = obs_sum / max(n_obs, 1) if n_obs else 0.0
            m_do  = do_sum  / max(n_do,  1) if n_do  else 0.0
            # Weighted combination of the epoch means (reporting only).
            total_epoch = (self.w_obs * m_obs) + (self.w_do * m_do if n_do else 0.0)

            # Epoch counter continues after the phase-1 epochs.
            ep_abs = self.epochs_obs + ep
            self.history.append({"epoch": ep_abs, "phase": "do", "loss_obs": m_obs, "loss_do": m_do, "loss_total": total_epoch})
            if ep % 10 == 0:
                print(f"[do  {ep:03d}] total={total_epoch:.6f} (obs={m_obs:.6f}, do={m_do:.6f})")

        return model

    @torch.no_grad()
    def evaluate_obs_mse(self, model, loader):
        """Return the mean per-batch MSE between ``model(x, edge_index)`` and ``y``.

        Puts the model into eval mode. Returns 0.0 for an empty loader.
        """
        model.eval()
        dev = next(model.parameters()).device
        tot, n = 0.0, 0
        for g in loader:
            g = g.to(dev)
            p = model(g.x, g.edge_index)
            tot += self.loss(p, g.y).item(); n += 1
        return tot / max(n,1)

    @torch.no_grad()
    def evaluate_do_mse(self, model, do_loader):
        """Return the mean per-batch MSE between ``do_intervention`` output and ``y_do``.

        Puts the model into eval mode. Returns 0.0 for an empty loader.
        """
        model.eval()
        dev = next(model.parameters()).device
        tot, n = 0.0, 0
        for g in do_loader:
            g = g.to(dev)
            p = model.do_intervention(
                g.x, g.edge_index,
                intervened_nodes=g.intervene_nodes,
                new_feature_values=g.new_feature_values
            )
            tot += self.loss(p, g.y_do).item(); n += 1
        return tot / max(n,1)


In [None]:
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt, pandas as pd, torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train on the TRAIN split only. The validation split (`ds_val` / `val_loader`)
# is what the other models are evaluated on, so fitting this model on the full
# `ds` would leak those graphs and make the comparison unfair.
obs_loader = DataLoader(ds_train, batch_size=1, shuffle=True)

do_dataset = InterventionalDataset(
    base_dataset=ds_train,
    node_mapping=node_mapping,
    causal_relationships=causal_relationships,
    deltas=(+1.0, -1.0),
    per_graph_per_node=2
)
do_loader = DataLoader(do_dataset, batch_size=1, shuffle=True)

# A single graph is enough to read the static dimensions.
g0 = ds[0]
num_features = g0.num_node_features
num_edges    = g0.edge_index.size(1)

mcfg = cfg.get("model", {})
best_params = {
    "hidden_dim": int(mcfg.get("hidden_dim", 32)),
    "out_dim":    int(mcfg.get("out_dim", 16)),
    "noise_dim":  int(mcfg.get("noise_dim", 4)),
}


# Only the keys accepted by InterventionalTrainer.__init__ are passed on.
tcfg = cfg.get("training", {})
trainer_kwargs = dict(
    epochs_obs   = int(tcfg.get("epochs_obs", 40)),
    epochs_do    = int(tcfg.get("epochs_do", 20)),
    lr           = float(tcfg.get("lr", 1e-2)),
    w_obs        = float(tcfg.get("w_obs", 0.2)),
    w_do         = float(tcfg.get("w_do", 1.0)),
    weight_decay = float(tcfg.get("weight_decay", 1e-4)),
    clip         = float(tcfg.get("clip", 1.0)),
)


model = GNN_NCM(num_features=num_features, num_edges=num_edges, gnn_mode='per_edge', **best_params).to(device)
trainer_do = InterventionalTrainer(**trainer_kwargs)
trainer_do.train(model, obs_loader, do_loader)

# losses
df = pd.DataFrame(trainer_do.history)
plt.figure(figsize=(10,5))
if "loss_obs" in df: plt.plot(df["epoch"], df["loss_obs"], label="obs")
if "loss_do"  in df: plt.plot(df["epoch"], df["loss_do"],  label="do")
plt.plot(df["epoch"], df["loss_total"], label="total", linewidth=2)
plt.xlabel("epoch"); plt.ylabel("loss"); plt.title("Interventional curriculum training")
plt.legend(); plt.grid(True, linestyle='--'); plt.tight_layout(); plt.show()

# Training-split metrics, followed by the held-out observational MSE.
print("obs_mse =", trainer_do.evaluate_obs_mse(model, obs_loader))
print("do_mse  =", trainer_do.evaluate_do_mse(model, do_loader))
print("val_obs_mse =", trainer_do.evaluate_obs_mse(model, val_loader))


## 4. Ground-truth ATE Evaluation


In [None]:
# ATE of do(BAS + 1) on VOL for the interventionally trained model
vol_idx, bas_idx = (node_mapping.get(key, 0) for key in ('VOL', 'BAS'))

g = ds[0].to(device)
with torch.no_grad():
    dev = g.x.device

    # Reuse the SEM matrix of the interventional dataset when it has one,
    # otherwise rebuild it from the coefficient table.
    if hasattr(do_dataset, "W"):
        W = do_dataset.W.to(dev, dtype=g.x.dtype)
        bas_sem_idx = getattr(do_dataset, "bas_idx", bas_idx)
    else:
        W, bas_sem_idx = build_sem_mats(causal_relationships, node_mapping, device=dev, dtype=g.x.dtype)

    # Ground truth: evaluate the SEM before and after shifting BAS by +1.
    v0 = g.x.mean(dim=1)
    v1 = v0.clone()
    v1[bas_idx] += 1.0
    y_true0 = sem_eval_vec(v0, W, bas_sem_idx)
    y_true1 = sem_eval_vec(v1, W, bas_sem_idx)
    true_ate = (y_true1 - y_true0)[vol_idx].item()

    # Model estimate: factual prediction vs. prediction under do().
    p0 = model(g.x, g.edge_index)
    new_x_row = g.x[bas_idx] + 1.0
    p1 = model.do_intervention(
        g.x,
        g.edge_index,
        intervened_nodes=torch.tensor([bas_idx], dtype=torch.long, device=dev),
        new_feature_values=new_x_row[None],
    )
    est_ate = (p1 - p0)[vol_idx].item()

df_ate = pd.DataFrame(
    [{"Estimated ATE (VOL)": est_ate, "True ATE (VOL)": true_ate}],
    index=["do(BAS+1)"],
)

# Relabel the row with the model name and align the columns with ate_df.
df_ate_row = (
    df_ate
    .rename(index={"do(BAS+1)": "GNN-NCM (interventional)"})
    .reindex(columns=ate_df.columns)
)

# Add it below the rows already present in ate_df.
ate_df_ext = pd.concat([ate_df, df_ate_row])

print("\n--- ATE Recovery (do(BAS = +1.0) → VOL) — with Interventional GNN-NCM ---")
print(ate_df_ext)

# Plot all models together.
ax = ate_df_ext.plot(
    kind='bar',
    rot=0,
    figsize=(10, 6),
    title="ATE Recovery for do(BAS = +1.0) → VOL",
)
ax.grid(axis='y', linestyle='--')
ax.set_ylabel("Average Treatment Effect (ATE)")
ax.set_xlabel("Model Type")
plt.tight_layout()
plt.show()


In [None]:
# Effect matrix: E[i, j] = change of node j's output under do(node i, +1)

dev = next(model.parameters()).device
g = ds[0].to(dev)

# Pick the SEM matrix: notebook globals first, then the dataset, then rebuild.
if {'W', 'bas_sem_idx'} <= globals().keys():
    Wv, bas_idx_sem = W.to(dev, g.x.dtype), bas_sem_idx
elif hasattr(do_dataset, "W"):
    Wv = do_dataset.W.to(dev, g.x.dtype)
    bas_idx_sem = getattr(do_dataset, "bas_idx", node_mapping.get('BAS', None))
else:
    Wv, bas_idx_sem = build_sem_mats(causal_relationships, node_mapping, device=dev, dtype=g.x.dtype)

N = g.num_nodes
E_true = torch.zeros(N, N, device=dev)
E_est = torch.zeros(N, N, device=dev)

with torch.no_grad():
    # Baselines without any intervention.
    v0 = g.x.mean(dim=1)
    y0_true = sem_eval_vec(v0, Wv, bas_idx_sem).squeeze(-1)  # [N]
    p0 = model(g.x, g.edge_index).squeeze(-1)                # [N]

    for i in range(N):
        # Ground-truth row: shift node i's mean feature by +1.
        v1 = v0.clone()
        v1[i] += 1.0
        y1_true = sem_eval_vec(v1, Wv, bas_idx_sem).squeeze(-1)
        E_true[i] = y1_true - y0_true

        # Estimated row: ask the model for do(node i += 1).
        new_row = g.x[i] + 1.0
        p1 = model.do_intervention(
            g.x,
            g.edge_index,
            intervened_nodes=torch.tensor([i], dtype=torch.long, device=dev),
            new_feature_values=new_row[None],
        ).squeeze(-1)
        E_est[i] = p1 - p0

# metrics
mse_all = ((E_est - E_true) ** 2).mean().item()
corr = torch.corrcoef(torch.stack([E_true.flatten(), E_est.flatten()]))[0, 1].item()
print(f"Effect-matrix MSE={mse_all:.6f} | Pearson r={corr:.4f}")

# heatmaps
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (ax[0], E_true, "True effects (do +1)"),
    (ax[1], E_est, "Estimated effects (do +1)"),
)
images = []
for axis, matrix, title in panels:
    images.append(axis.imshow(matrix.detach().cpu().numpy(), aspect='auto'))
    axis.set_title(title)
for axis, _, _ in panels:
    axis.set_xlabel("target j")
    axis.set_ylabel("intervened i")
for (axis, _, _), image in zip(panels, images):
    fig.colorbar(image, ax=axis, fraction=0.046)
plt.tight_layout()
plt.show()


## 5. Robustness Test

In [None]:
# stress degradation 
import os, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

@torch.no_grad()
def stress_degradation(model, loader, ops):
    """Compare factual MSE with the MSE under single-node feature interventions.

    Args:
        model: callable as ``model(x, edge_index)``. If it has a callable
            ``do_intervention`` attribute, stressed predictions are obtained
            through it; otherwise feature 0 of the node is overwritten and the
            plain forward pass is used.
        loader: re-iterable collection of graphs exposing ``x``,
            ``edge_index`` and ``y``. It is traversed once for the factual
            pass and once per op.
            NOTE(review): ``x[node, 0]`` addressing assumes each item holds a
            single graph (batch size 1) - confirm before batching.
        ops: list of dicts with ``"name"``, ``"node"`` and either
            ``"value_fn"`` (maps the current value of ``x[node, 0]`` to the new
            value) or ``"value_const"``.

    Returns:
        One dict per op with the keys ``stress_name``, ``mse_factual``,
        ``mse_stress`` and ``degradation_ratio`` (stress / factual).
    """
    def _mse(pred, target):
        # Same convention as the validation loops: mean squared error per graph.
        return float(((pred - target) ** 2).mean().item())

    mse_f = float(np.mean([_mse(model(g.x, g.edge_index), g.y) for g in loader]))

    # Decide up front whether the model offers a do-operator. Wrapping the call
    # itself in `except AttributeError` would also swallow genuine errors raised
    # INSIDE do_intervention and silently fall back to the plain forward pass.
    do_fn = getattr(model, "do_intervention", None)
    has_do = callable(do_fn)

    out = []
    for op in ops:
        node = op["node"]
        mses_s = []
        for g in loader:
            x = g.x.clone()
            if "value_fn" in op:
                new_val = op["value_fn"](float(x[node, 0].item()))
            else:
                new_val = float(op["value_const"])

            if has_do:
                # Same calling convention as the other call sites in this file:
                # LongTensor of node ids plus one full feature row per node
                # (shape [1, F]); only feature 0 is replaced.
                new_row = x[node].clone()
                new_row[0] = new_val
                yhat_s = do_fn(
                    x, g.edge_index,
                    intervened_nodes=torch.tensor([node], dtype=torch.long, device=x.device),
                    new_feature_values=new_row.unsqueeze(0),
                )
            else:
                x[node, 0] = new_val
                yhat_s = model(x, g.edge_index)
            # Stressed predictions are scored against the FACTUAL targets.
            mses_s.append(_mse(yhat_s, g.y))

        mse_s = float(np.mean(mses_s))
        out.append({
            "stress_name": op["name"],
            "mse_factual": mse_f,
            "mse_stress": mse_s,
            # The epsilon guards against division by a zero factual MSE.
            "degradation_ratio": float(mse_s / (mse_f + 1e-12)),
        })
    return out

# Stress operations: shift one node's first feature by DELTA.
BAS_IDX = node_mapping["BAS"]
MOM_IDX = node_mapping["Mom"]
DELTA = 0.5  # standardized units

ops = [
    {"name": "do_BAS_plus_delta", "node": BAS_IDX, "value_fn": lambda v: v + DELTA},
    {"name": "do_MOM_plus_delta", "node": MOM_IDX, "value_fn": lambda v: v + DELTA},
]

# All models are switched to eval mode and moved to the CPU for the stress test.
models = {
    label: net.eval().cpu()
    for label, net in [
        ("per_edge", model_per_edge),
        ("shared", model_shared),
        ("gcn", baseline_model),
        ("GNN-NCM interventional", model),
    ]
}



In [None]:
stress = {name: stress_degradation(m, val_loader, ops) for name, m in models.items()}

OUT_SYN = PROJECT_ROOT / "outputs" / "synthetic"
# The output folder is not guaranteed to exist yet at this point of the
# notebook, so create it before the first write.
OUT_SYN.mkdir(parents=True, exist_ok=True)

with open(os.path.join(OUT_SYN, "stress_synthetic.json"), "w", encoding="utf-8") as f:
    json.dump(stress, f, indent=2)
print("saved:", os.path.join(OUT_SYN, "stress_synthetic.json"))

# tidy df for plotting: one row per (model, stress op)
rows = []
for model_name, lst in stress.items():
    for d in lst:
        r = d.copy()
        r["model"] = model_name
        rows.append(r)
df = pd.DataFrame(rows)[["model", "stress_name", "mse_factual", "mse_stress", "degradation_ratio"]]
df.to_csv(os.path.join(OUT_SYN, "stress_results_tidy.csv"), index=False)

# --- Plot 1: degradation ratio under do(BAS = BAS + Δ) ---
plt.figure(figsize=(6,4))
subset = df[df["stress_name"]=="do_BAS_plus_delta"]
plt.bar(subset["model"], subset["degradation_ratio"])
plt.title("Synthetic — do(BAS = BAS + Δ): Degradation Ratio")
plt.xlabel("Model"); plt.ylabel("MSE_stress / MSE_factual")
plt.tight_layout()
plt.savefig(os.path.join(OUT_SYN, "stress_degradation_do_BAS.png"), dpi=150)
plt.show()

# --- Plot 2: degradation ratio under do(Mom = Mom + Δ) ---
plt.figure(figsize=(6,4))
subset2 = df[df["stress_name"]=="do_MOM_plus_delta"]
plt.bar(subset2["model"], subset2["degradation_ratio"])
plt.title("Synthetic — do(Mom = Mom + Δ): Degradation Ratio")
plt.xlabel("Model"); plt.ylabel("MSE_stress / MSE_factual")
plt.tight_layout()
plt.savefig(os.path.join(OUT_SYN, "stress_degradation_do_MOM.png"), dpi=150)
plt.show()

# --- Plot 3: factual MSE by model (context for ratios) ---
plt.figure(figsize=(6,4))
mse_factual = df.groupby("model")["mse_factual"].mean().reset_index()
plt.bar(mse_factual["model"], mse_factual["mse_factual"])
plt.title("Synthetic — Factual MSE (Val split)")
plt.xlabel("Model"); plt.ylabel("MSE (factual)")
plt.tight_layout()
plt.savefig(os.path.join(OUT_SYN, "factual_mse.png"), dpi=150)
plt.show()

df


## 6. Saving the Results

In [None]:
import os, json, pandas as pd, numpy as np


from pathlib import Path
import sys

# Repository root = parent of the current working directory.
PROJECT_ROOT = Path().resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

OUT_SYN = PROJECT_ROOT / "outputs" / "synthetic"
os.makedirs(OUT_SYN, exist_ok=True)

# The index of ate_df_ext holds the model names. Writing with index=False
# would drop them and leave the CSV rows unlabelled, so keep the index and
# give its column an explicit header.
ate_df_ext.to_csv(os.path.join(OUT_SYN, "ate_table_VOL.csv"), index=True, index_label="model")
print("Saved synthetic ATE artifacts to", OUT_SYN)
