In [9]:
import numpy as np
import pandas as pd
import json

import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score

device = torch.device("cuda")

# --- 1. Load & merge your real data ---
TRAIN_DIR = "./data/TRAIN_NEW"

cat_df = pd.read_excel(f"{TRAIN_DIR}/TRAIN_CATEGORICAL_METADATA_new.xlsx")
quant_df = pd.read_excel(f"{TRAIN_DIR}/TRAIN_QUANTITATIVE_METADATA_new.xlsx")
conn_df  = pd.read_csv(f"{TRAIN_DIR}/TRAIN_FUNCTIONAL_CONNECTOME_MATRICES_new_36P_Pearson.csv")
sol_df   = pd.read_excel(f"{TRAIN_DIR}/TRAINING_SOLUTIONS.xlsx")

df = cat_df.merge(quant_df, on="participant_id") \
           .merge(conn_df, on="participant_id") \
           .merge(sol_df[["participant_id", "ADHD_Outcome"]], on="participant_id")

# Tabular features & labels
X = df.drop(columns=["participant_id", "ADHD_Outcome"])
y = df["ADHD_Outcome"].astype(int)

# Train/test split (stratified)
X_train_df, X_test_df, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Preprocess tabular data
pipeline = Pipeline([
    ("imputer", SimpleImputer(strategy="mean")),
    ("scaler", StandardScaler()),
])
X_train_np = pipeline.fit_transform(X_train_df)
X_test_np  = pipeline.transform(X_test_df)

# Recombine into full arrays for XGNN loops
X_all = np.vstack([X_train_np, X_test_np])
y_all = np.hstack([y_train,        y_test       ])

num_patients = X_all.shape[0]
num_meta_feats = X_all.shape[1]

# --- 2. Build per-patient graph Data objects ---
# Reconstruct each patient’s 200×200 connectivity matrix from conn_df (flattened in columns)
n = int((1 + np.sqrt(1 + 8 * (conn_df.shape[1]-1))) / 2)  # infer number of regions
connectome_vals = conn_df.drop(columns="participant_id").values  # shape [num_patients, n(n-1)/2]
# Build symmetric matrices with diag=1
connectomes = np.zeros((num_patients, n, n), dtype=float)
for i in range(num_patients):
    lower = np.tril_indices(n, -1)
    m = np.zeros((n,n))
    m[lower] = connectome_vals[i]
    m = m + m.T
    np.fill_diagonal(m, 1.0)
    connectomes[i] = m

data_list = []
for i in range(num_patients):
    mat = torch.tensor(connectomes[i], dtype=torch.float)
    # build kNN graph on regions for each patient (top-k strongest edges per node)
    k = 10
    edges, weights = [], []
    for u in range(n):
        row = mat[u].clone()
        row[u] = -1  # exclude self
        topk = torch.topk(row, k=k).indices
        for v in topk:
            edges.append([u, v.item()])
            weights.append(mat[u, v].item())
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_attr  = torch.tensor(weights, dtype=torch.float)

    # initial node features: identity matrix (region one-hot)
    x = torch.eye(n)

    # placeholder for XGBoost meta-feature (will broadcast later)
    # we reserve N_boost rounds, but start with zeros
    # we'll stack these onto x later in the loop
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                y=torch.tensor([y_all[i]], dtype=torch.float))
    data_list.append(data)

# Split into DataLoaders
train_count = len(y_train)
train_loader = DataLoader(data_list[:train_count], batch_size=16, shuffle=True)
test_loader  = DataLoader(data_list[train_count:], batch_size=16, shuffle=False)

# --- 3. Define GNN for graph classification ---
class GNN(torch.nn.Module):
    def __init__(self, in_feats, hidden):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin   = torch.nn.Linear(hidden, 1)  # graph-level logit

    def forward(self, data):
        x, ei, ea, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = F.relu(self.conv1(x, ei, edge_weight=ea))
        x = F.relu(self.conv2(x, ei, edge_weight=ea))
        x = global_mean_pool(x, batch)          # [batch_size, hidden]
        return self.lin(x).squeeze(1)            # [batch_size]

model = GNN(in_feats=n + 1, hidden=64).to(device)  
# note: in_feats = n (identity) + 1 (XGB meta-feature)
opt   = torch.optim.Adam(model.parameters(), lr=1e-2)
crit  = torch.nn.BCEWithLogitsLoss()

# --- 4. XGBoost + GNN joint training (XGNN) ---
N_rounds = 5           # total boosting iterations
trees_per_round = 10

xgb_model = None       # will accumulate trees over rounds
meta_feature = np.zeros((num_patients, 1))  # initial XGB output = 0

for rnd in range(N_rounds):
    # --- Append current meta_feature to each graph’s node features ---
    for i, data in enumerate(data_list):
        # broadcast meta_feature[i] to all n nodes as 1-dim column
        m = float(meta_feature[i])
        bf = torch.full((n,1), m, dtype=torch.float)
        data.x = torch.cat([torch.eye(n).to(device), bf.to(device)], dim=1).to(device)

    # (a) Train GNN for a few epochs
    model.train()
    for epoch in range(5):
        for batch in train_loader:
            batch = batch.to(device)
            opt.zero_grad()
            logits = model(batch)                  # [batch_size]
            loss   = crit(logits, batch.y)
            loss.backward()
            opt.step()

    # (b) Compute pseudo‐residuals for each patient
    model.eval()
    # first pass to get probabilities
    with torch.no_grad():
        all_logits = []
        for batch in train_loader:
            batch = batch.to(device)
            all_logits.append(model(batch).to(device))
        logits_cat = torch.cat(all_logits)       # [n_train]
        probs = torch.sigmoid(logits_cat).cpu().numpy()
    y_tr = y_all[:train_count]
    if rnd == 0:
        residuals = (y_tr - probs)             # initial residual = true – pred
    else:
        # compute gradient w.r.t. the last appended feature
        grads = []
        for batch in train_loader:
            batch = batch.to(device)
            # re-create the input features with grad enabled on Xreq
            Xreq = batch.x.clone().detach().requires_grad_(True)
            # forward through the same model
            logits_req = model(Data(
                x=Xreq,
                edge_index=batch.edge_index,
                edge_attr=batch.edge_attr,
                batch=batch.batch
            ))
            loss_req = crit(logits_req, batch.y)
            # zero any old grads
            opt.zero_grad()
            if Xreq.grad is not None:
                Xreq.grad.zero_()
            # backprop – this populates Xreq.grad
            loss_req.backward()
            # pool the gradient on the broadcast channel
            grad_feat = Xreq.grad[:, -1]
            for graph_id in batch.batch.unique():
                mask = (batch.batch == graph_id)
                grads.append(grad_feat[mask].mean().item())

        residuals = -np.array(grads)            # negative gradient

    # (c) Fit/update XGBoost on metadata
    X_meta_tr = X_all[:train_count]           # shape [n_train, num_meta_feats]
    if xgb_model is None:
        xgb_model = XGBRegressor(objective="reg:squarederror",
                                 n_estimators=trees_per_round, learning_rate=0.1)
        xgb_model.fit(X_meta_tr, residuals)
    else:
        n_t = xgb_model.get_params()["n_estimators"] + trees_per_round
        xgb_model.set_params(n_estimators=n_t)
        xgb_model.fit(X_meta_tr, residuals, xgb_model=xgb_model)

    # (d) Predict new meta_feature for *all* patients
    meta_feature = xgb_model.predict(X_all).reshape(-1,1)

    print(f"Round {rnd+1}/{N_rounds} complete")

# --- 5. Final evaluation on held-out graphs ---
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        logits = model(batch).cpu().numpy()
        preds  = (1/(1+np.exp(-logits)) >= 0.5).astype(int)
        all_preds.extend(preds.tolist())
        all_labels.extend(batch.y.cpu().numpy().astype(int).tolist())

acc = accuracy_score(all_labels, all_preds)
f1  = f1_score(all_labels, all_preds)
print(f"\nTest Accuracy: {acc:.4f}")
print(f"Test F1-score: {f1:.4f}")


  m = float(meta_feature[i])


Round 1/5 complete
Round 2/5 complete
Round 3/5 complete
Round 4/5 complete
Round 5/5 complete

Test Accuracy: 0.6626
Test F1-score: 0.7929
