# In [11]:  (notebook cell marker kept from the export)
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import to_undirected

# Attempt imports for econml + LightGBM
# Both packages are optional. `has_econml` records whether the imports
# succeeded; when they fail a hint is printed and the flag is set to False
# instead of letting the ImportError propagate.
try:
    from econml.dml import CausalForestDML
    from lightgbm import LGBMRegressor
    has_econml = True
except ImportError:
    has_econml = False
    print("econml or lightgbm not installed. For CF, run: pip install econml lightgbm")

############################################################################
# 1) DATA GENERATION
############################################################################

def generate_random_graph(n_nodes=10, p=0.2):
    """Sample an Erdos-Renyi G(n, p) random graph.

    Args:
        n_nodes: Number of nodes handed to NetworkX.
        p: Edge probability handed to NetworkX.

    Returns:
        The graph returned by ``networkx.erdos_renyi_graph``.
    """
    graph = nx.erdos_renyi_graph(n=n_nodes, p=p)
    return graph

def compute_alpha_beta_from_graph(G):
    """Derive the structural parameters (alpha, beta) of a graph.

    alpha = (# of triangles) + 2*(# of 4-cycles).
    beta  = 1 + 0.05*(avg_deg) - 0.01*(num_edges).

    A "4-cycle" here is a set of four nodes whose induced subgraph is exactly
    a cycle (four edges, every node of degree 2). Merely checking for four
    edges would also count a triangle with a pendant edge, which contains no
    4-cycle at all.

    Args:
        G: An undirected NetworkX graph. Assumed to be simple (no self-loops
            or parallel edges), as produced by ``generate_random_graph``.

    Returns:
        Tuple ``(alpha_val, beta_val)`` of Python floats.
    """
    from itertools import combinations

    # nx.triangles gives a per-node count, so every triangle is seen once
    # from each of its three corners.
    total_tris = sum(nx.triangles(G).values()) // 3

    # Brute force over all 4-node subsets: fine for the small graphs used here.
    four_cycles = 0
    for quad in combinations(G.nodes(), 4):
        sub = G.subgraph(quad)
        if sub.number_of_edges() != 4:
            continue
        # Four edges on four nodes is either a cycle or a triangle plus a
        # pendant edge; only the cycle has all degrees equal to 2.
        if all(deg == 2 for _, deg in sub.degree()):
            four_cycles += 1
    alpha_val = float(total_tris + 2*four_cycles)

    # simpler structural stats -> beta
    num_edges = G.number_of_edges()
    avg_deg = 2.0*num_edges / max(G.number_of_nodes(), 1)  # guard: empty graph
    beta_val = 1.0 + 0.05*avg_deg - 0.01*num_edges

    return alpha_val, beta_val

def build_population(n=2000, n_nodes=10):
    """Create ``n`` random graphs, each paired with its (alpha, beta).

    Every graph gets its own edge probability drawn uniformly from
    [0.1, 0.3).

    Returns:
        List of ``(G, alpha, beta)`` tuples.
    """
    population = []
    for _ in range(n):
        edge_prob = np.random.uniform(0.1, 0.3)
        graph = generate_random_graph(n_nodes, edge_prob)
        population.append((graph, *compute_alpha_beta_from_graph(graph)))
    return population

def compute_true_ate(pop):
    """Return the mean of the third field (beta) over all entries of ``pop``."""
    betas = [record[2] for record in pop]
    return np.mean(betas)

def sample_dataset(population, sample_size=1000):
    """Draw observational samples from ``population`` without replacement.

    For every selected graph a treatment T ~ Bernoulli(0.5) and standard
    normal noise are drawn, and the outcome is Y = alpha + beta*T + noise.

    Returns:
        List of ``(G, w, y, alpha, beta)`` tuples, one per selected entry.
    """
    chosen = np.random.choice(len(population), size=sample_size, replace=False)

    def simulate(record):
        graph, alpha, beta = record
        # Treatment is drawn before the noise for every record.
        treated = np.random.binomial(1, 0.5)
        eps = np.random.randn()
        return (graph, treated, alpha + beta*treated + eps, alpha, beta)

    return [simulate(population[pos]) for pos in chosen]

############################################################################
# 2) COLLATE + DATASET
############################################################################

class GraphData(Dataset):
    """Minimal ``Dataset`` view over an in-memory list of samples."""

    def __init__(self, data_list):
        # The list is stored by reference; nothing is copied.
        self.data_list = data_list

    def __len__(self):
        """Number of stored samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` unchanged."""
        sample = self.data_list[idx]
        return sample

def adjacency_to_flat(G):
    """Return the adjacency matrix of ``G`` as a flat float32 vector.

    Rows and columns follow the sorted node order, so an n-node graph yields
    a vector of length n*n (100 for the 10-node graphs used in this file).
    """
    ordered_nodes = sorted(G.nodes())
    adjacency = nx.to_numpy_array(G, nodelist=ordered_nodes)
    flat = adjacency.ravel()
    return flat.astype(np.float32)

def adjacency_to_edge_index(G):
    """Return (node_features, edge_idx) for PyG GraphConv usage.

    Node labels are mapped to positions 0..n-1 following the sorted node
    order (the same ordering ``adjacency_to_flat`` uses). The edge indices
    therefore always address valid rows of the feature matrix, even when the
    labels are not the contiguous integers 0..n-1. For graphs labelled
    0..n-1 the result is identical to using the labels directly.

    Args:
        G: Graph exposing ``number_of_nodes()``, ``nodes()`` and ``edges()``.

    Returns:
        x_feat: ``(n, n)`` float32 identity matrix (one-hot node features).
        edge_idx: ``(2, 2*E)`` long tensor; each edge appears once per
            direction (first all ``u -> v``, then all ``v -> u``).
    """
    n = G.number_of_nodes()
    position = {node: i for i, node in enumerate(sorted(G.nodes()))}

    src, dst = [], []
    for u, v in G.edges():
        src.append(position[u])
        dst.append(position[v])

    x_feat = torch.eye(n, dtype=torch.float32)
    # make undirected: append every edge again with its endpoints swapped
    edge_idx = torch.tensor([src + dst, dst + src], dtype=torch.long)
    return x_feat, edge_idx

def collate_batch(batch):
    """Turn a list of ``(G, w, y, alpha, beta)`` samples into batch tensors.

    Returns:
        ``(G_list, W_t, Y_t, A_t, B_t)`` where ``G_list`` is a plain list of
        graphs, ``W_t`` and ``Y_t`` are float32 tensors of shape (batch, 1),
        and ``A_t`` and ``B_t`` are float32 tensors of shape (batch,).
    """
    def as_float(values):
        return torch.tensor(values, dtype=torch.float32)

    graphs, treatments, outcomes, alphas, betas = [], [], [], [], []
    for graph, treatment, outcome, alpha, beta in batch:
        graphs.append(graph)
        treatments.append(treatment)
        outcomes.append(outcome)
        alphas.append(alpha)
        betas.append(beta)

    return (
        graphs,
        as_float(treatments).reshape(-1, 1),
        as_float(outcomes).reshape(-1, 1),
        as_float(alphas),
        as_float(betas),
    )

############################################################################
# 3) MODELS: MLP + GNN
############################################################################

class MLPAlphaBeta(nn.Module):
    """MLP over the flattened adjacency matrix that predicts (alpha, beta)."""

    def __init__(self, input_dim=100, hidden_dim=64):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, G_list):
        """Map a list of graphs to a (batch, 2) tensor of (alpha, beta)."""
        # One flattened float32 adjacency vector per graph.
        rows = [torch.from_numpy(adjacency_to_flat(G)) for G in G_list]
        X_t = torch.stack(rows, dim=0)  # (batch, input_dim)
        return self.net(X_t)

class GCNAlphaBeta(nn.Module):
    """A simple 2-layer GraphConv => (alpha,beta).

    Args:
        hidden_dim: Width of both GraphConv layers.
        in_channels: Size of each node's input feature vector. The features
            built by ``adjacency_to_edge_index`` are one-hot rows of an
            identity matrix, so this must equal the number of nodes per
            graph. Defaults to 10, the graph size used in this file.
    """

    def __init__(self, hidden_dim=32, in_channels=10):
        super().__init__()
        self.conv1 = pyg_nn.GraphConv(in_channels=in_channels, out_channels=hidden_dim)
        self.conv2 = pyg_nn.GraphConv(in_channels=hidden_dim, out_channels=hidden_dim)
        self.head = nn.Linear(hidden_dim, 2)

    def forward(self, G_list):
        """Map a list of graphs to a (batch, 2) tensor of (alpha, beta).

        Graphs are processed one at a time; each is reduced to a single
        vector by averaging its node embeddings.
        """
        outs = []
        for G in G_list:
            x_feat, edge_idx = adjacency_to_edge_index(G)
            # NOTE(review): edge_idx already lists both directions of every
            # edge; the extra to_undirected call is kept from the original.
            edge_idx = to_undirected(edge_idx)
            h = torch.relu(self.conv1(x_feat, edge_idx))
            h = torch.relu(self.conv2(h, edge_idx))
            out_graph = h.mean(dim=0)  # global mean pool => shape (hidden_dim,)
            outs.append(out_graph)
        X_t = torch.stack(outs, dim=0)  # (batch, hidden_dim)
        ab_pred = self.head(X_t)        # (batch,2)
        return ab_pred

############################################################################
# 4) TRAIN + EVALUATION
############################################################################

def dr_ate(alpha_hat, beta_hat, w_arr, y_arr):
    """Doubly-robust style ATE estimate with a fixed propensity of 0.5.

    Args:
        alpha_hat: Predicted control outcomes (array).
        beta_hat: Predicted treatment effects (array).
        w_arr: Treatment indicators (array).
        y_arr: Observed outcomes (array).

    Returns:
        ``(ate, se)``: mean of the per-sample scores and its standard error
        (sample standard deviation with ddof=1 divided by sqrt(n)).
    """
    propensity = 0.5
    pred_control = alpha_hat
    pred_treated = alpha_hat + beta_hat

    # Outcome-model prediction plus inverse-propensity weighted residual,
    # separately for the treated and the control arm.
    treated_part = pred_treated + w_arr*(y_arr - pred_treated)/propensity
    control_part = pred_control + (1 - w_arr)*(y_arr - pred_control)/(1-propensity)
    scores = treated_part - control_part

    n_obs = len(scores)
    return scores.mean(), scores.std(ddof=1)/np.sqrt(n_obs)

def rmse(true_vals, pred_vals):
    """Root-mean-squared error between two equally-shaped value collections.

    Inputs are converted to float arrays first, so plain lists and tuples are
    accepted as well as NumPy arrays (matching ``r2_score``).
    """
    true_arr = np.asarray(true_vals, dtype=float)
    pred_arr = np.asarray(pred_vals, dtype=float)
    return np.sqrt(np.mean((true_arr - pred_arr)**2))

def r2_score(true_vals, pred_vals):
    """Coefficient of determination of ``pred_vals`` against ``true_vals``.

    Returns 0.0 when the total sum of squares is not larger than 1e-8, which
    avoids dividing by (almost) zero for constant targets.
    """
    y_true = np.array(true_vals)
    y_hat = np.array(pred_vals)
    residual_ss = np.sum((y_true - y_hat)**2)
    total_ss = np.sum((y_true - np.mean(y_true))**2)
    if not (total_ss > 1e-8):
        return 0.0
    return 1.0 - residual_ss/total_ss

def evaluate_model(model, loader, device='cpu', model_type="MLP"):
    """Evaluate ``model`` on every batch of ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    ``device`` and ``model_type`` are accepted but not used.

    Returns:
        ``(ate, se, RMSE(y), R2(y), RMSE(a), R2(a), RMSE(b), R2(b))``.
    """
    pred_alpha, pred_beta = [], []
    true_alpha, true_beta = [], []
    treatments, outcomes, pred_y = [], [], []

    model.eval()
    with torch.no_grad():
        for G_list, W_t, Y_t, A_t, B_t in loader:
            ab = model(G_list)  # => (batch,2)
            n_graphs = len(G_list)

            # Pull the batch out as Python floats once instead of per element.
            batch_alpha = ab[:n_graphs, 0].tolist()
            batch_beta = ab[:n_graphs, 1].tolist()
            batch_w = W_t.view(-1).tolist()

            pred_alpha.extend(batch_alpha)
            pred_beta.extend(batch_beta)
            true_alpha.extend(A_t.tolist())
            true_beta.extend(B_t.tolist())
            treatments.extend(batch_w)
            outcomes.extend(Y_t.view(-1).tolist())
            # predicted y = alpha_hat + beta_hat * w
            pred_y.extend(
                a_hat + b_hat*w
                for a_hat, b_hat, w in zip(batch_alpha, batch_beta, batch_w)
            )

    pred_alpha = np.array(pred_alpha)
    pred_beta = np.array(pred_beta)
    true_alpha = np.array(true_alpha)
    true_beta = np.array(true_beta)
    treatments = np.array(treatments)
    outcomes = np.array(outcomes)
    pred_y = np.array(pred_y)

    ate, se = dr_ate(pred_alpha, pred_beta, treatments, outcomes)
    return (
        ate,
        se,
        rmse(outcomes, pred_y),
        r2_score(outcomes, pred_y),
        rmse(true_alpha, pred_alpha),
        r2_score(true_alpha, pred_alpha),
        rmse(true_beta, pred_beta),
        r2_score(true_beta, pred_beta),
    )

def train_one_epoch(model, loader, opt, device='cpu'):
    """Run one optimisation pass over ``loader``.

    For each batch the model predicts (alpha, beta), the outcome is
    reconstructed as ``alpha + beta * W`` and fitted to ``Y`` with a mean
    squared error loss. ``device`` is accepted but not used.
    """
    model.train()
    for G_list, W_t, Y_t, _alpha_true, _beta_true in loader:
        opt.zero_grad()
        pred = model(G_list)
        alpha_hat = pred[:, 0:1]
        beta_hat = pred[:, 1:2]
        # Slices keep the column dimension so shapes line up with (batch, 1).
        loss = nn.functional.mse_loss(alpha_hat + beta_hat*W_t, Y_t)
        loss.backward()
        opt.step()

############################################################################
# 5) CAUSAL FOREST + UTILS
############################################################################

# Import the Causal Forest dependencies when the earlier probe succeeded.
# NOTE(review): the same two names are already imported by the try/except
# near the top of the file, so this block looks redundant.
if has_econml:
    from econml.dml import CausalForestDML
    from lightgbm import LGBMRegressor

def prepare_cf_data(samples):
    """Convert ``(G, w, y, alpha, beta)`` samples into arrays for the forest.

    Returns:
        ``(X, W, Y, B)`` as float32 arrays: flattened adjacency features,
        treatments, outcomes, and the true beta used to score the estimates.
        The true alpha is not needed and is dropped.
    """
    columns = ([], [], [], [])
    for graph, treatment, outcome, _alpha, beta in samples:
        row = (adjacency_to_flat(graph), treatment, outcome, beta)
        for column, value in zip(columns, row):
            column.append(value)

    X_np, W_np, Y_np, B_np = (np.array(col, dtype=np.float32) for col in columns)
    return X_np, W_np, Y_np, B_np

############################################################################
# 6) MAIN
############################################################################

def main():
    """Run the full experiment and print a results table.

    Steps:
      1. Build a population of 2000 random 10-node graphs.
      2. Split it into disjoint test and train pools, so no test graph can
         appear in a training set.
      3. For each training size, train the MLP and the GNN, fit a Causal
         Forest when econml/lightgbm are installed, and evaluate all of them
         on the same held-out test set.
      4. Print one row per (training size, model).

    Raises:
        ValueError: If the largest training size exceeds the train pool.
    """
    # 1) Build full population => 2000
    population = build_population(n=2000, n_nodes=10)
    true_ate = compute_true_ate(population)
    print(f"True ATE in population= {true_ate:.4f}")

    # We'll define train_sizes, sample_test
    train_sizes = [100, 500, 1000, 1500]
    sample_test = 500

    # Hold out the test graphs first. Sampling train and test from the same
    # pool would let test graphs leak into training.
    shuffled = np.random.permutation(len(population))
    test_pool = [population[i] for i in shuffled[:sample_test]]
    train_pool = [population[i] for i in shuffled[sample_test:]]
    if max(train_sizes) > len(train_pool):
        raise ValueError(
            f"largest train size {max(train_sizes)} exceeds the "
            f"{len(train_pool)} graphs left after holding out the test set"
        )

    test_data = sample_dataset(test_pool, sample_size=sample_test)
    test_loader = DataLoader(GraphData(test_data), batch_size=64, shuffle=False,
                             collate_fn=collate_batch)

    # We'll store rows => (N, Model, ATE, SE, RMSE(y), R2(y), RMSE(a), R2(a), RMSE(b), R2(b))
    results = []

    # Hyperparams
    mlp_epochs = 500
    gnn_epochs = 500
    lr = 1e-3

    def fit(model, train_samples, epochs, label):
        """Train ``model`` on ``train_samples`` with Adam and return it."""
        loader = DataLoader(GraphData(train_samples), batch_size=64, shuffle=True,
                            collate_fn=collate_batch)
        opt = optim.Adam(model.parameters(), lr=lr)
        for _ in tqdm(range(epochs), desc=f"{label}({len(train_samples)})"):
            train_one_epoch(model, loader, opt)
        return model

    for N in train_sizes:
        # sample new train_data from the train pool (disjoint from test_data)
        train_data = sample_dataset(train_pool, sample_size=N)

        # (A) MLP
        mlp_model = fit(MLPAlphaBeta(input_dim=100, hidden_dim=64),
                        train_data, mlp_epochs, "MLP")
        mlp_res = evaluate_model(mlp_model, test_loader, model_type="MLP")

        # (B) GNN
        gnn_model = fit(GCNAlphaBeta(hidden_dim=32), train_data, gnn_epochs, "GNN")
        gnn_res = evaluate_model(gnn_model, test_loader, model_type="GNN")

        # (C) CF if installed
        if has_econml:
            X_cf_train, W_cf_train, Y_cf_train, _ = prepare_cf_data(train_data)
            X_cf_test, _, _, B_cf_test = prepare_cf_data(test_data)

            cf_model = CausalForestDML(
                model_y=LGBMRegressor(verbose=-1),
                model_t=LGBMRegressor(verbose=-1),
                n_estimators=400,
                min_samples_leaf=10,
                max_depth=25,
                random_state=42
            )
            cf_model.fit(Y_cf_train, W_cf_train, X=X_cf_train)

            b_hat_cf = cf_model.effect(X_cf_test)  # shape=(test_size,)
            ate_cf = np.mean(b_hat_cf)
            lb_cf, ub_cf = cf_model.effect_interval(X_cf_test)
            # NOTE(review): this averages the per-unit interval half-widths
            # (scaled by 1.96); it is a rough summary, not a standard error
            # of the mean effect.
            se_cf = (ub_cf - lb_cf).mean()/(2*1.96)
            # measure RMSE(b) & R2(b)
            rb_cf = rmse(B_cf_test, b_hat_cf)
            r2b_cf = r2_score(B_cf_test, b_hat_cf)

            # CF doesn't produce alpha or Y => placeholders
            cf_res = (ate_cf, se_cf, None, None, None, None, rb_cf, r2b_cf)
        else:
            cf_res = (np.nan, np.nan, None, None, None, None, None, None)

        # Save
        results.append((N, "MLP", *mlp_res))
        results.append((N, "GNN", *gnn_res))
        results.append((N, "CF", *cf_res))

    print("\n=== RESULTS TABLE ===")
    print("DataN | Model |  ATE_est |  ATE_se |  RMSE(y) | R2(y)   | RMSE(a) | R2(a)   | RMSE(b) | R2(b)")
    for row in results:
        (N, model, ate, se, ry, r2y, ra, r2a, rb, r2b) = row
        if model == "CF":
            # CF => no Y, alpha
            ate_s = f"{ate:8.4f}" if ate is not None else "  --"
            se_s = f"{se:8.4f}" if se is not None else "  --"
            rb_s = f"{rb:7.4f}" if rb is not None else "  --"
            r2b_s = f"{r2b:7.4f}" if r2b is not None else "  --"
            print(f"{N:<5} {model:<3} | {ate_s} | {se_s} |     --   |   --   |   --    |   --   | {rb_s} | {r2b_s}")
        else:
            ate_s = f"{ate:8.4f}"
            se_s = f"{se:8.4f}"
            ry_s = f"{ry:7.4f}"
            r2y_s = f"{r2y:7.4f}"
            ra_s = f"{ra:7.4f}"
            r2a_s = f"{r2a:7.4f}"
            rb_s = f"{rb:7.4f}"
            r2b_s = f"{r2b:7.4f}"
            print(f"{N:<5} {model:<3} | {ate_s} | {se_s} | {ry_s} | {r2y_s} | {ra_s} | {r2a_s} | {rb_s} | {r2b_s}")

    print(f"\nTrue ATE= {true_ate:.4f}")

# Run the experiment only when this file is executed as a script.
if __name__=="__main__":
    main()


# Notebook output captured with the cell: KeyboardInterrupt