In [2]:
# ================== ONE CELL: C-simple + FP→Z + Truth ==================
# If needed (Colab): !pip -q install numpy pandas scipy scikit-learn torch
import os, re, math, random, hashlib, numpy as np, pandas as pd
from numpy.linalg import eigh
from scipy.stats import pearsonr, spearmanr
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score, confusion_matrix, mean_squared_error, mean_absolute_error
import torch, torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp

# ---------------- CONFIG ----------------
CSV_PATH           = "/content/30_1000_base.csv"  # first column y, then N formula columns (S)
SEED               = 7
DIM                = 128               # embedding dim for spectral Z and student
M_PROBES           = 256               # # of semantic mini-worlds (fingerprint)
OBS_FRACTION       = 0.10              # used only if CSV has no missing entries (to simulate partial observation)
HOLDOUT_FRACTION   = 0.10              # fraction of observed pairs kept for leak-free eval of completion
PAIR_EVAL          = 20000             # #pairs to sample for correlation sanity checks
# GNN (C-simple)
EPOCHS_C           = 300
BATCH_EDGES_C      = 40000
LR_C               = 1e-3
APPNP_K            = 10
APPNP_ALPHA        = 0.1
EDGE_TEMP          = 1.0               # keep 1.0 for pure regression; 1.5–2.0 to sharpen propagation later
BLOCK_PRED         = 128               # block size for full kernel prediction
# Student (FP→Z)
EPOCHS_STUDENT     = 40
LR_STUDENT         = 2e-3
PAIR_SAMPLES       = 2048              # pairwise cosine samples per batch
PAIR_LOSS_W        = 0.5               # weight for pairwise loss vs vector MSE
BATCH_STUDENT      = 512
# Truth head
EPOCHS_TRUTH       = 20
LR_TRUTH           = 1e-3
BATCH_TRUTH        = 512
# Saving
SAVE_ARTIFACTS     = True
OUT_DIR            = "./"

# ---------------- Repro & Device ----------------
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------------- Load CSV ----------------
assert os.path.exists(CSV_PATH), f"File not found: {CSV_PATH}"
df = pd.read_csv(CSV_PATH)
N  = df.shape[0]
formulas = list(df.columns[1:])
assert len(formulas) == N, "Expected N rows and N formula columns (1..N)."
y = df.iloc[:, 0].to_numpy().astype(np.float32)
S_raw = df.iloc[:, 1:].to_numpy(dtype=float)  # may include NaN for unobserved pairs

# Ensure symmetry (even if sparse) but keep NaNs where both missing
S = 0.5*(S_raw + S_raw.T)
diag = np.eye(N, dtype=bool)
S[diag] = 1.0
# Clip known entries into [0,1]
mask_known = ~np.isnan(S)
S[mask_known] = np.clip(S[mask_known], 0.0, 1.0)
y_bin = (y > 0.5).astype(np.float32)
print(f"N={N} | y∈[0,1] | S(min,max) among known={np.nanmin(S):.3f}/{np.nanmax(S):.3f}")

# ---------------- Build observed/unobserved/holdout pairs ----------------
def upper_pairs(N):
    # return all (i,j) with i<j
    I, J = np.triu_indices(N, k=1)
    return np.stack([I, J], axis=1)

all_pairs_u = upper_pairs(N)

# If the CSV is dense (no NaNs), simulate a partial observation set
if np.isnan(S).sum() == 0:
    total = len(all_pairs_u)
    m_obs = max(1, int(OBS_FRACTION * total))
    idx = np.random.RandomState(SEED).choice(total, size=m_obs, replace=False)
    obs_pairs_all = all_pairs_u[idx]
    obs_vals_all  = S[obs_pairs_all[:,0], obs_pairs_all[:,1]]
else:
    # Use only pairs with known values (i<j)
    Kmask = (~np.isnan(S)) & (~np.eye(N, dtype=bool))
    I, J = np.where(np.triu(Kmask, k=1))
    obs_pairs_all = np.stack([I, J], axis=1)
    obs_vals_all  = S[I, J].astype(np.float32)

# Split observed into train-observed vs holdout-observed for leak-free completion eval
rng = np.random.RandomState(SEED)
perm = rng.permutation(len(obs_pairs_all))
m_hold = max(1, int(HOLDOUT_FRACTION * len(obs_pairs_all)))
hold_idx = perm[:m_hold]; train_obs_idx = perm[m_hold:]
holdout_pairs = obs_pairs_all[hold_idx]
holdout_true  = obs_vals_all[hold_idx]
obs_pairs     = obs_pairs_all[train_obs_idx]
obs_vals      = obs_vals_all[train_obs_idx]

# Unobserved = all remaining upper pairs not in obs_pairs_all
obs_set = set(map(tuple, obs_pairs_all.tolist()))
unobs_pairs = np.array([p for p in all_pairs_u.tolist() if tuple(p) not in obs_set], dtype=np.int64)

print(f"Observed (train) pairs={len(obs_pairs)} | Holdout (pairs)={len(holdout_pairs)} | Unobserved={len(unobs_pairs)}")

# ---------------- Propositional Parser & Robust FP (unseen atoms handled) ----------------
OP_MAP = {"→":" IMP ", "⇒":" IMP ", "=>":" IMP ", "->":" IMP ",
          "↔":" IFF ", "<=>":" IFF ", "<->":" IFF ",
          "⊑":" SUB ",  # treat as IMP
          "⊓":" AND ", "∧":" AND ", "&&":" AND ",
          "⊔":" OR  ", "∨":" OR  ", "||":" OR  ",
          "¬":" NOT ", "~":" NOT ", "!":" NOT "}
BIN_OPS, UNARY_OPS = {"AND","OR","IMP","IFF","SUB"}, {"NOT"}
TOKEN_RE = re.compile(r"[A-Za-z0-9_]+|[()]")

def norm_text(s):
    s = str(s)
    for k,v in OP_MAP.items(): s = s.replace(k,v)
    return s

def lex(s): return TOKEN_RE.findall(norm_text(s))
def is_atom(t): return t not in BIN_OPS|UNARY_OPS|{"(",")"}

class Parser:
    def __init__(self,toks): self.toks=toks; self.i=0
    def peek(self): return self.toks[self.i] if self.i<len(self.toks) else None
    def pop(self): t=self.peek(); self.i += (1 if t is not None else 0); return t
    def parse(self): return self.expr(0)
    PREC = {"IFF":1,"IMP":2,"SUB":2,"OR":3,"AND":4}
    RIGHT = {"IMP","IFF","SUB"}
    def expr(self,minp):
        node=self.unary()
        while True:
            op=self.peek()
            if op in BIN_OPS:
                prec=self.PREC.get(op,0)
                if prec<minp: break
                self.pop()
                nextp = prec if op in self.RIGHT else prec+1
                rhs=self.expr(nextp)
                node=("BIN",op,node,rhs)
            else: break
        return node
    def unary(self):
        t=self.peek()
        if t in UNARY_OPS:
            self.pop(); c=self.unary(); return ("UN",t,c)
        if t=="(":
            self.pop(); n=self.expr(0); assert self.pop()==")","Missing ')'"
            return n
        a=self.pop()
        return ("ATOM", a if a is not None else "x")

def parse_formula(s):
    try: return Parser(lex(s)).parse()
    except: return ("ATOM","x")

def atoms_in(node, acc=None):
    if acc is None: acc=set()
    k=node[0]
    if k=="ATOM": acc.add(node[1]); return acc
    if k=="UN": return atoms_in(node[2], acc)
    if k=="BIN": atoms_in(node[2], acc); atoms_in(node[3], acc); return acc
    return acc

def depth(node):
    k=node[0]
    if k=="ATOM": return 1
    if k=="UN": return 1+depth(node[2])
    if k=="BIN": return 1+max(depth(node[2]), depth(node[3]))
    return 1

# --- Deterministic hashing for unseen atoms (probe-consistent) ---
def _u64_from_str(s: str) -> int:
    h = hashlib.blake2b(s.encode('utf-8'), digest_size=8).digest()
    return int.from_bytes(h, 'big')

def bernoulli_from_name(atom: str, probe_idx: int, p: float, seed: int) -> bool:
    u = (_u64_from_str(f"{atom}|{probe_idx}|{seed}") % (1<<53)) / float(1<<53)
    return u < p

class ProbeEnv:
    def __init__(self, base_env: dict, probe_idx: int, bias_p: float, seed: int):
        self.base = base_env
        self.m    = probe_idx
        self.p    = float(bias_p)
        self.seed = int(seed)
        self.cache = {}
    def get(self, atom: str) -> bool:
        if atom in self.base: return bool(self.base[atom])
        if atom in self.cache: return self.cache[atom]
        v = bernoulli_from_name(atom, self.m, self.p, self.seed)
        self.cache[atom] = v
        return v

def eval_ast(node, env_obj):
    k=node[0]
    if k=="ATOM": return bool(env_obj.get(node[1]))
    if k=="UN":
        _,op,c = node
        v = eval_ast(c, env_obj)
        return (not v)
    if k=="BIN":
        _,op,l,r = node
        a = eval_ast(l, env_obj); b = eval_ast(r, env_obj)
        if op=="AND": return a and b
        if op=="OR":  return a or b
        if op in ("IMP","SUB"): return (not a) or b
        if op=="IFF": return a==b
    return False

# ---------------- Build semantic fingerprint (FP) ----------------
asts = [parse_formula(s) for s in formulas]
all_atoms = sorted(set().union(*[atoms_in(t) for t in asts]))
A = len(all_atoms)
print(f"#Atoms found: {A}")

# probes (half bias 0.3, half 0.7)
rng = np.random.default_rng(SEED)
biases = np.concatenate([np.full(M_PROBES//2, 0.3), np.full(M_PROBES - M_PROBES//2, 0.7)])
assignments = []
for p in biases:
    vals = rng.random(A) < p
    env = {a: bool(v) for a,v in zip(all_atoms, vals)}  # known atoms only
    assignments.append(env)

# Truth matrix T: N x M_PROBES with unseen atoms handled via ProbeEnv
T_mat = np.zeros((N, M_PROBES), dtype=np.float32)
for i,ast in enumerate(asts):
    for m_i,base_env in enumerate(assignments):
        env_obj = ProbeEnv(base_env, probe_idx=m_i, bias_p=biases[m_i], seed=SEED)
        T_mat[i,m_i] = 1.0 if eval_ast(ast, env_obj) else 0.0

# Structural features
def op_counts(toks):
    return toks.count("AND"), toks.count("OR"), toks.count("NOT"), toks.count("IMP")+toks.count("SUB"), toks.count("IFF")
struct_rows=[]
for s,ast in zip(formulas,asts):
    toks = lex(s)
    ac = len(atoms_in(ast, set()))
    d  = depth(ast)
    c_and, c_or, c_not, c_imp, c_iff = op_counts(toks)
    struct_rows.append([ac, d, c_and, c_or, c_not, c_imp, c_iff, len(toks)])
STRUCT = np.array(struct_rows, dtype=np.float32)

# Final FP
FP = np.concatenate([T_mat, STRUCT], axis=1).astype(np.float32)
print("Fingerprint shape:", FP.shape)

# ---------------- FP preprocessing for models ----------------
# For GNN encoder features (lowered dimension for stability)
sc_fp_gnn = StandardScaler().fit(FP)
FP_std_g  = sc_fp_gnn.transform(FP).astype(np.float32)
pca_gnn   = PCA(n_components=min(256, FP_std_g.shape[1]), whiten=True, random_state=SEED).fit(FP_std_g)
FP_low    = pca_gnn.transform(FP_std_g).astype(np.float32)

# For student (we'll standardize inside the student trainer using train only)
# ---------------- Pairwise metrics ----------------
def evaluate_pairs(truth, pred):
    mse = float(mean_squared_error(truth, pred))
    mae = float(mean_absolute_error(truth, pred))
    if len(truth) > 1:
        pe = float(pearsonr(truth, pred)[0])
        sp = float(spearmanr(truth, pred)[0])
    else:
        pe = sp = float('nan')
    return {"mse":mse, "mae":mae, "pearson":pe, "spearman":sp}

def print_metrics(name, m):
    print(f"[{name}] MSE={m['mse']:.6f}  MAE={m['mae']:.6f}  Pearson={m['pearson']:.3f}  Spearman={m['spearman']:.3f}")

# ---------------- C-simple: APPNP encoder + calibrated dot decoder ----------------
def build_A_hat_from_obs(N, pairs_idx, weights, edge_temp=1.0):
    w = np.clip(weights.astype(np.float32), 0.0, 1.0)
    if edge_temp != 1.0:
        w = np.exp(edge_temp * w)
    rows = np.concatenate([pairs_idx[:,0], pairs_idx[:,1]])
    cols = np.concatenate([pairs_idx[:,1], pairs_idx[:,0]])
    vals = np.concatenate([w, w])
    # self-loops
    rows = np.concatenate([rows, np.arange(N)])
    cols = np.concatenate([cols, np.arange(N)])
    vals = np.concatenate([vals, np.ones(N, dtype=np.float32)])
    A = sp.coo_matrix((vals, (rows, cols)), shape=(N, N)).tocsr()
    d = np.array(A.sum(1)).ravel()
    d_inv_sqrt = 1.0 / np.sqrt(np.maximum(d, 1e-8))
    D_inv_sqrt = sp.diags(d_inv_sqrt)
    A_hat = (D_inv_sqrt @ A @ D_inv_sqrt).tocoo()
    idx = np.vstack([A_hat.row, A_hat.col])
    val = A_hat.data.astype(np.float32)
    i = torch.tensor(idx, dtype=torch.long, device=device)
    v = torch.tensor(val, dtype=torch.float32, device=device)
    return torch.sparse_coo_tensor(i, v, (N, N), device=device)

class APPNPEncoder(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, p_drop=0.1):
        super().__init__()
        self.lin1 = nn.Linear(d_in, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_out)
        self.drop = nn.Dropout(p_drop)
    def forward(self, X, A_hat, K=10, alpha=0.1):
        H0 = F.relu(self.lin1(X))
        H0 = self.drop(H0)
        H0 = self.lin2(H0)
        H  = H0
        for _ in range(K):
            H = (1 - alpha) * torch.sparse.mm(A_hat, H) + alpha * H0
        return H

class GNNCompletion(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.enc = APPNPEncoder(d_in, d_hidden, d_out, p_drop=0.1)
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.bias  = nn.Parameter(torch.tensor(0.0))
    def forward_embeddings(self, X, A_hat, K, alpha):
        Z = self.enc(X, A_hat, K=K, alpha=alpha)
        with torch.no_grad():
            norms = Z.norm(dim=1, keepdim=True).clamp_min(1e-6)
            Z /= norms.clamp_max(2.0)
        return F.normalize(Z, dim=1)
    def decode_pairs(self, Z, pairs_idx):
        i = torch.as_tensor(pairs_idx[:,0], dtype=torch.long, device=Z.device)
        j = torch.as_tensor(pairs_idx[:,1], dtype=torch.long, device=Z.device)
        dp = (Z[i] * Z[j]).sum(1)
        return torch.sigmoid(self.scale * dp + self.bias)

def train_gnn_completion_simple(
    FP_low, obs_pairs_idx, obs_vals,
    holdout_pairs_idx=None, holdout_true=None,
    rank_d=128, d_hidden=128, epochs=60, batch_pairs=40000, lr=1e-3,
    appnp_K=10, appnp_alpha=0.1, edge_temp=1.0
):
    N, d_in = FP_low.shape
    X  = torch.tensor(FP_low, dtype=torch.float32, device=device)
    Ah = build_A_hat_from_obs(N, obs_pairs_idx, obs_vals, edge_temp=edge_temp)
    model = GNNCompletion(d_in, d_hidden, rank_d).to(device)
    opt   = torch.optim.Adam(model.parameters(), lr=lr)

    for ep in range(1, epochs+1):
        model.train()
        Z = model.forward_embeddings(X, Ah, K=appnp_K, alpha=appnp_alpha)
        idx = np.random.randint(0, len(obs_pairs_idx), size=min(batch_pairs, len(obs_pairs_idx)))
        batch = obs_pairs_idx[idx]
        s_t   = torch.as_tensor(obs_vals[idx], dtype=torch.float32, device=device)
        pred = model.decode_pairs(Z, batch).clamp(0,1)
        loss = F.mse_loss(pred, s_t)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        opt.step()

        if ep in {1,2,3,5,10,20,40,epochs} and (holdout_pairs_idx is not None):
            model.eval()
            with torch.no_grad():
                Z_eval = model.forward_embeddings(X, Ah, K=appnp_K, alpha=appnp_alpha)
                i = torch.as_tensor(holdout_pairs_idx[:,0], dtype=torch.long, device=device)
                j = torch.as_tensor(holdout_pairs_idx[:,1], dtype=torch.long, device=device)
                ph = torch.sigmoid(model.scale * (Z_eval[i]*Z_eval[j]).sum(1) + model.bias).clamp(0,1).cpu().numpy()
                m  = evaluate_pairs(holdout_true, ph)
            print(f"[C-simple] ep {ep:02d}  train_loss={float(loss):.6f}")
            print_metrics("C-simple holdout", m)

    model.eval()
    with torch.no_grad():
        Z = model.forward_embeddings(torch.tensor(FP_low, dtype=torch.float32, device=device), Ah, K=appnp_K, alpha=appnp_alpha)
    return model, Z

def predict_full_kernel_from_Z(model, Z, block=128):
    N = Z.shape[0]
    K = np.zeros((N, N), dtype=np.float32)
    for i0 in range(0, N, block):
        i1 = min(N, i0+block)
        Ii = torch.arange(i0, i1, device=Z.device)
        Zi = Z[Ii]
        for j0 in range(0, N, block):
            j1 = min(N, j0+block)
            Jj = torch.arange(j0, j1, device=Z.device)
            Zj = Z[Jj]
            with torch.no_grad():
                dp = Zi @ Zj.T
                pred = torch.sigmoid(model.scale * dp + model.bias).clamp(0,1).cpu().numpy()
            K[i0:i1, j0:j1] = pred
    K = 0.5*(K + K.T); np.fill_diagonal(K, 1.0)
    return K

# ---- Train C-simple and evaluate on HOLDOUT observed pairs only (no leakage) ----
model_C, Z_C, = train_gnn_completion_simple(
    FP_low, obs_pairs, obs_vals,
    holdout_pairs_idx=holdout_pairs, holdout_true=holdout_true,
    rank_d=DIM, d_hidden=DIM, epochs=EPOCHS_C, batch_pairs=BATCH_EDGES_C, lr=LR_C,
    appnp_K=APPNP_K, appnp_alpha=APPNP_ALPHA, edge_temp=EDGE_TEMP
)

S_hat_Csimple = predict_full_kernel_from_Z(model_C, Z_C, block=BLOCK_PRED)
preds_hold = S_hat_Csimple[holdout_pairs[:,0], holdout_pairs[:,1]]
m_Csimple = evaluate_pairs(holdout_true, preds_hold)
print_metrics("C-simple FINAL holdout", m_Csimple)

# ---------------- Spectral Embedding (Teacher Z from completed kernel) ----------------
def spectral_embedding_psd(K: np.ndarray, d: int) -> np.ndarray:
    K = 0.5*(K + K.T)
    vals, vecs = eigh(K)
    idx = np.argsort(vals)[::-1]
    vals, vecs = vals[idx], vecs[:, idx]
    vals = np.clip(vals, 0.0, None)
    keep = min(d, max(1, int((vals > 1e-9).sum())))
    U = vecs[:, :keep]; L = np.sqrt(vals[:keep])
    Z = U * L
    Z = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-12)
    if keep < d:
        Zp = np.zeros((K.shape[0], d), dtype=Z.dtype); Zp[:, :keep] = Z; Z = Zp
    return Z.astype(np.float32)

Z_teacher = spectral_embedding_psd(S_hat_Csimple, DIM)
print("Teacher Z shape:", Z_teacher.shape)
# Teacher sanity on random pairs
m = min(PAIR_EVAL, N*N)
ii = np.random.randint(0, N, size=m); jj = np.random.randint(0, N, size=m)
S_hat = (Z_teacher[ii] * Z_teacher[jj]).sum(1)
S_ij  = S_hat_Csimple[ii, jj]
print("[Teacher sanity] S≈ZZ^T: MSE=%.4f  Pearson=%.3f  Spearman=%.3f" %
      (float(((S_hat - S_ij)**2).mean()), pearsonr(S_ij, S_hat)[0], spearmanr(S_ij, S_hat)[0]))

# ---------------- Train/Val/Test split for TRUTH (labels used ONLY here) ----------------
idx_all = np.arange(N)
train_idx, temp_idx = train_test_split(idx_all, test_size=0.1, random_state=SEED, stratify=y_bin)
val_idx,   test_idx = train_test_split(temp_idx, test_size=2/3, random_state=SEED, stratify=y_bin[temp_idx])
print(f"Split sizes: train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}")

# ---------------- Student: FP -> Z_teacher (align on TRAIN only) ----------------
# Standardize FP for student using TRAIN ONLY
sc_student = StandardScaler().fit(FP[train_idx])
FP_std_all = sc_student.transform(FP).astype(np.float32)

class StudentNet(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(d_in, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, d_out)
        )
    def forward(self, x): return self.f(x)

X_all_t   = torch.tensor(FP_std_all, dtype=torch.float32, device=device)
Z_teach_t = torch.tensor(Z_teacher,  dtype=torch.float32, device=device)
student   = StudentNet(FP_std_all.shape[1], DIM).to(device)
opt_s     = torch.optim.Adam(student.parameters(), lr=LR_STUDENT)

def train_student():
    idxs = np.arange(N)
    for ep in range(1, EPOCHS_STUDENT+1):
        np.random.shuffle(idxs)
        tot=0; n=0
        for i in range(0, N, BATCH_STUDENT):
            sb = idxs[i:i+BATCH_STUDENT]
            xb = X_all_t[sb]
            zt = Z_teach_t[sb]
            zhat = student(xb)

            # vector loss (MSE in embedding space)
            loss_vec = F.mse_loss(zhat, zt)

            # pairwise cosine loss on sampled pairs from this mini-batch
            with torch.no_grad():
                b = len(sb)
                p = min(PAIR_SAMPLES, b*b)
                ii = torch.randint(0, b, (p,), device=device)
                jj = torch.randint(0, b, (p,), device=device)
                t_i = F.normalize(zt,   dim=1)[ii]
                t_j = F.normalize(zt,   dim=1)[jj]
                cos_t = (t_i * t_j).sum(1)

            p_i = F.normalize(zhat, dim=1)[ii]
            p_j = F.normalize(zhat, dim=1)[jj]
            cos_p = (p_i * p_j).sum(1)
            loss_pair = F.mse_loss(cos_p, cos_t)

            loss = loss_vec + PAIR_LOSS_W*loss_pair
            opt_s.zero_grad(); loss.backward(); opt_s.step()
            tot += float(loss.item())*len(sb); n += len(sb)

        if ep in {1,2,3,5,10,20,EPOCHS_STUDENT}:
            with torch.no_grad():
                zhat_all = F.normalize(student(X_all_t), dim=1)
                zt_all   = F.normalize(Z_teach_t, dim=1)
                m = min(PAIR_EVAL, N*N)
                ii = torch.randint(0, N, (m,), device=device)
                jj = torch.randint(0, N, (m,), device=device)
                cos_p = (zhat_all[ii] * zhat_all[jj]).sum(1).cpu().numpy()
                cos_t = (zt_all[ii]   * zt_all[jj]).sum(1).cpu().numpy()
                pr = pearsonr(cos_t, cos_p)[0]; sp = spearmanr(cos_t, cos_p)[0]
            print(f"[Student] ep {ep:02d}  train_loss={tot/max(1,n):.5f}  val_pair_Pearson={pr:.3f}  Spearman={sp:.3f}")

train_student()

# Predicted Z (normalize), then Procrustes align on TRAIN ONLY
with torch.no_grad():
    Z_pred0 = student(X_all_t).cpu().numpy().astype(np.float32)
Z_pred0 /= (np.linalg.norm(Z_pred0, axis=1, keepdims=True) + 1e-12)
# Orthogonal Procrustes using TRAIN ONLY
from scipy.linalg import orthogonal_procrustes
Q, _ = orthogonal_procrustes(Z_pred0[train_idx], Z_teacher[train_idx])
Z_pred = (Z_pred0 @ Q)
Z_pred /= (np.linalg.norm(Z_pred, axis=1, keepdims=True) + 1e-12)

# Geometry sanity
m = min(PAIR_EVAL, N*N)
ii = np.random.randint(0, N, size=m); jj = np.random.randint(0, N, size=m)
cos_pred = np.sum(Z_pred[ii]*Z_pred[jj], axis=1)
cos_teac = np.sum(Z_teacher[ii]*Z_teacher[jj], axis=1)
print("[Aligned student] pairwise cosine Pearson=%.3f  Spearman=%.3f" %
      (pearsonr(cos_teac, cos_pred)[0], spearmanr(cos_teac, cos_pred)[0]))

# ---------------- Truth MLP on Z_pred (train/val/test) ----------------
class TruthMLP(nn.Module):
    def __init__(self, d_in, d_h=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_in, d_h), nn.ReLU(), nn.Linear(d_h, 1))
    def forward(self, x): return self.net(x).squeeze(1)

def train_eval_truth_mlp(name, X, y_bin, train_idx, val_idx, test_idx, epochs=EPOCHS_TRUTH):
    sc = StandardScaler().fit(X[train_idx])
    X_all  = sc.transform(X).astype(np.float32)
    X_all_t= torch.tensor(X_all, dtype=torch.float32, device=device)
    y_all_t= torch.tensor(y_bin, dtype=torch.float32, device=device)

    n_pos = int(y_bin[train_idx].sum()); n_neg = int(len(train_idx) - n_pos)
    pos_weight = torch.tensor(max(1.0, n_neg / max(1, n_pos)), dtype=torch.float32, device=device)
    bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    model = TruthMLP(X.shape[1], 64).to(device)
    opt   = torch.optim.Adam(model.parameters(), lr=LR_TRUTH)

    def iterate(idxs, train=True, batch=BATCH_TRUTH):
        model.train(mode=train); tot=0; n=0
        for i in range(0, len(idxs), batch):
            sb = idxs[i:i+batch]
            xb = X_all_t[sb]; yb = y_all_t[sb]
            loss = bce(model(xb), yb)
            if train: opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            tot += float(loss.item())*len(sb); n += len(sb)
        return tot/max(1,n)

    def probs(idxs):
        model.eval()
        with torch.no_grad():
            return torch.sigmoid(model(X_all_t[idxs])).cpu().numpy()

    for ep in range(1, epochs+1):
        _ = iterate(train_idx, True)
        if ep in {1,2,3,5,10,15,20}:
            pv = probs(val_idx)
            print(f"[{name}] ep {ep:02d}  val_auc={roc_auc_score(y_bin[val_idx], pv):.3f}  val_ap={average_precision_score(y_bin[val_idx], pv):.3f}")

    # tune τ on VAL
    pv_val = probs(val_idx)
    ts = np.linspace(0,1,501); best_t=0.5; best_f1=-1
    for t in ts:
        f1 = f1_score(y_bin[val_idx], (pv_val>=t).astype(int), zero_division=0)
        if f1>best_f1: best_f1=f1; best_t=t

    pv_test = probs(test_idx); yhat=(pv_test>=best_t).astype(int)
    auc = roc_auc_score(y_bin[test_idx], pv_test)
    ap  = average_precision_score(y_bin[test_idx], pv_test)
    f1  = f1_score(y_bin[test_idx], yhat)
    acc = accuracy_score(y_bin[test_idx], yhat)
    cm  = confusion_matrix(y_bin[test_idx], yhat)
    print(f"\n[{name}] τ*={best_t:.3f} | TEST AUC={auc:.3f}  AP={ap:.3f}  F1={f1:.3f}  Acc={acc:.3f}")
    print(cm)
    return {"name":name,"auc":auc,"ap":ap,"f1":f1,"acc":acc,"tau":best_t,"probs_all":None}

# Run the truth head on Z_pred
res_Zp = train_eval_truth_mlp("MLP on Z_pred", Z_pred, y_bin, train_idx, val_idx, test_idx)

# ---------------- Save artifacts (optional) ----------------
if SAVE_ARTIFACTS:
    os.makedirs(OUT_DIR, exist_ok=True)
    pd.DataFrame(S_hat_Csimple, index=formulas, columns=formulas).to_csv(os.path.join(OUT_DIR, "S_hat_Csimple.csv"))
    pd.DataFrame(Z_teacher, index=formulas).to_csv(os.path.join(OUT_DIR, "Z_teacher_from_Csimple.csv"))
    pd.DataFrame(Z_pred,    index=formulas).to_csv(os.path.join(OUT_DIR, "Z_student_pred.csv"))
    pd.DataFrame(FP,        index=formulas).to_csv(os.path.join(OUT_DIR, "semantic_features.csv"))
    print("Saved: S_hat_Csimple.csv, Z_teacher_from_Csimple.csv, Z_student_pred.csv, semantic_features.csv")
# ======================================================================


Device: cuda
N=1000 | y∈[0,1] | S(min,max) among known=0.000/1.000
Observed (train) pairs=44955 | Holdout (pairs)=4995 | Unobserved=449550
#Atoms found: 30
Fingerprint shape: (1000, 264)


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print(f"[C-simple] ep {ep:02d}  train_loss={float(loss):.6f}")


[C-simple] ep 01  train_loss=0.063342
[C-simple holdout] MSE=0.061842  MAE=0.227510  Pearson=0.010  Spearman=0.006
[C-simple] ep 02  train_loss=0.061117
[C-simple holdout] MSE=0.059270  MAE=0.222276  Pearson=0.010  Spearman=0.007
[C-simple] ep 03  train_loss=0.058208
[C-simple holdout] MSE=0.055594  MAE=0.214556  Pearson=0.012  Spearman=0.008
[C-simple] ep 05  train_loss=0.049389
[C-simple holdout] MSE=0.044370  MAE=0.189520  Pearson=0.020  Spearman=0.013
[C-simple] ep 10  train_loss=0.021382
[C-simple holdout] MSE=0.022428  MAE=0.126556  Pearson=0.047  Spearman=0.027
[C-simple] ep 20  train_loss=0.016593
[C-simple holdout] MSE=0.016327  MAE=0.097882  Pearson=0.086  Spearman=0.046
[C-simple] ep 40  train_loss=0.012738
[C-simple holdout] MSE=0.013168  MAE=0.077652  Pearson=0.184  Spearman=0.117
[C-simple] ep 300  train_loss=0.003181
[C-simple holdout] MSE=0.003436  MAE=0.037350  Pearson=0.867  Spearman=0.782
[C-simple FINAL holdout] MSE=0.003436  MAE=0.037350  Pearson=0.867  Spearman=0.