In [3]:
# joint_edges_only_semantic_fp.py
"""
Graph Completion (APPNP-style) using semantic fingerprints (truth-probe mini-worlds + structural counts)
Node label prediction removed: task is only predicting missing similarities.
"""

import os, random, math, re, hashlib
import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------- CONFIG ----------------------
CSV_PATH           = "/content/30_1000_base.csv"  # <<-- set your file path
SEED               = 7
SUBSET_SIZE        = None       # None = full dataset
OBS_MISSING_FRAC   = 0.5
HOLDOUT_EDGE_FRAC  = 0.10
EMBED_DIM          = 128

# Semantic FP params
M_PROBES           = 256

# Training hyperparams
EPOCHS             = 200
BATCH_EDGES_SIZE   = 40000
LR                 = 1e-3
APPNP_K            = 10
APPNP_ALPHA        = 0.1
EDGE_TEMP          = 1.0
EDGE_LOSS_W        = 1.0
BLOCK_PRED         = 128
DEVICE             = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUT_ARTIFACT       = "joint_edges_artifacts_semantic_fp.npz"

# reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
print("Device:", DEVICE)

# ---------------------- Utilities ----------------------
def load_sim_and_labels(csv_path):
    df = pd.read_csv(csv_path)
    if df.shape[1] < 2:
        raise ValueError("CSV must have at least two columns: label + similarity columns")
    # Labels are ignored
    sim = df.iloc[:, 1:].to_numpy(dtype=np.float32)
    if sim.shape[0] != sim.shape[1]:
        raise ValueError(f"Similarity submatrix must be square, got {sim.shape}")
    formulas = list(df.columns[1:])
    return sim, formulas

def sample_subset(sim, formulas, subset_size=None, seed=SEED):
    N = sim.shape[0]
    if subset_size is None or subset_size >= N:
        idx = np.arange(N, dtype=np.int64)
    else:
        rng = np.random.RandomState(seed)
        idx = rng.choice(N, subset_size, replace=False).astype(np.int64)
    return sim[np.ix_(idx, idx)].copy(), [formulas[i] for i in idx.tolist()], idx

def upper_pairs(N):
    I, J = np.triu_indices(N, k=1)
    return np.stack([I, J], axis=1)

def build_obs_pairs_from_sim(sim, obs_fraction=OBS_MISSING_FRAC, seed=SEED):
    N = sim.shape[0]
    all_pairs = upper_pairs(N)
    total = len(all_pairs)
    m_obs = max(1, int(obs_fraction * total))
    rng = np.random.RandomState(seed)
    sel = rng.choice(total, m_obs, replace=False)
    obs_pairs_all = all_pairs[sel].astype(np.int64)
    obs_vals_all = sim[obs_pairs_all[:,0], obs_pairs_all[:,1]].astype(np.float32)
    obs_set = set(map(tuple, obs_pairs_all.tolist()))
    unobs_pairs = np.array([p for p in all_pairs.tolist() if tuple(p) not in obs_set], dtype=np.int64)
    return obs_pairs_all, obs_vals_all, unobs_pairs

def split_holdout(obs_pairs_all, obs_vals_all, holdout_frac=HOLDOUT_EDGE_FRAC, seed=SEED):
    rng = np.random.RandomState(seed)
    n = len(obs_pairs_all)
    perm = rng.permutation(n)
    m_hold = max(1, int(holdout_frac * n))
    hold_idx = perm[:m_hold]
    train_idx = perm[m_hold:]
    return obs_pairs_all[train_idx], obs_vals_all[train_idx], obs_pairs_all[hold_idx], obs_vals_all[hold_idx]

def build_A_hat_from_obs(N, pairs_idx, weights, edge_temp=EDGE_TEMP, device=DEVICE):
    pairs_idx = np.asarray(pairs_idx, dtype=np.int64)
    weights = np.asarray(weights, dtype=np.float32)
    w = np.clip(weights, 0.0, 1.0)
    if edge_temp != 1.0:
        w = np.exp(edge_temp * w)
    rows = np.concatenate([pairs_idx[:,0], pairs_idx[:,1]])
    cols = np.concatenate([pairs_idx[:,1], pairs_idx[:,0]])
    vals = np.concatenate([w, w]).astype(np.float32)
    rows = np.concatenate([rows, np.arange(N, dtype=np.int64)])
    cols = np.concatenate([cols, np.arange(N, dtype=np.int64)])
    vals = np.concatenate([vals, np.ones(N, dtype=np.float32)])
    A = sp.coo_matrix((vals, (rows, cols)), shape=(N, N)).tocsr()
    d = np.array(A.sum(1)).ravel()
    d_inv_sqrt = 1.0 / np.sqrt(np.maximum(d, 1e-12))
    D_inv_sqrt = sp.diags(d_inv_sqrt)
    A_hat = (D_inv_sqrt @ A @ D_inv_sqrt).tocoo()
    idx = np.vstack([A_hat.row, A_hat.col]).astype(np.int64)
    val = A_hat.data.astype(np.float32)
    i = torch.tensor(idx, dtype=torch.long, device=device)
    v = torch.tensor(val, dtype=torch.float32, device=device)
    return torch.sparse_coo_tensor(i, v, (N, N), device=device)

# ---------------------- Semantic fingerprint ----------------------
OP_MAP = {"→":" IMP ","⇒":" IMP ","=>":" IMP ","->":" IMP ",
          "↔":" IFF ","<=>":" IFF ","<->":" IFF ",
          "⊑":" SUB ",
          "⊓":" AND ","∧":" AND ","&&":" AND ",
          "⊔":" OR  ","∨":" OR  ","||":" OR  ",
          "¬":" NOT ","~":" NOT ","!":" NOT "}
BIN_OPS, UNARY_OPS = {"AND","OR","IMP","IFF","SUB"}, {"NOT"}
TOKEN_RE = re.compile(r"[A-Za-z0-9_]+|[()]")

def norm_text(s):
    s = str(s)
    for k,v in OP_MAP.items(): s = s.replace(k,v)
    return s

def lex(s): return TOKEN_RE.findall(norm_text(s))
def is_atom(t): return t not in BIN_OPS|UNARY_OPS|{"(",")"}

class Parser:
    def __init__(self,toks): self.toks=toks; self.i=0
    def peek(self): return self.toks[self.i] if self.i<len(self.toks) else None
    def pop(self): t=self.peek(); self.i += (1 if t is not None else 0); return t
    PREC = {"IFF":1,"IMP":2,"SUB":2,"OR":3,"AND":4}
    RIGHT = {"IMP","IFF","SUB"}
    def expr(self,minp):
        node=self.unary()
        while True:
            op=self.peek()
            if op in BIN_OPS:
                prec=self.PREC.get(op,0)
                if prec<minp: break
                self.pop()
                nextp = prec if op in self.RIGHT else prec+1
                rhs=self.expr(nextp)
                node=("BIN",op,node,rhs)
            else: break
        return node
    def unary(self):
        t=self.peek()
        if t in UNARY_OPS:
            self.pop(); c=self.unary(); return ("UN",t,c)
        if t=="(":
            self.pop(); n=self.expr(0); assert self.pop()==")","Missing ')'"
            return n
        a=self.pop()
        return ("ATOM", a if a is not None else "x")

def parse_formula(s):
    try: return Parser(lex(s)).expr(0)
    except: return ("ATOM","x")

def atoms_in(node, acc=None):
    if acc is None: acc=set()
    k=node[0]
    if k=="ATOM": acc.add(node[1]); return acc
    if k=="UN": return atoms_in(node[2], acc)
    if k=="BIN": atoms_in(node[2], acc); atoms_in(node[3], acc); return acc
    return acc

def depth(node):
    k=node[0]
    if k=="ATOM": return 1
    if k=="UN": return 1+depth(node[2])
    if k=="BIN": return 1+max(depth(node[2]), depth(node[3]))
    return 1

def _u64_from_str(s: str) -> int:
    h = hashlib.blake2b(s.encode('utf-8'), digest_size=8).digest()
    return int.from_bytes(h, 'big')

def bernoulli_from_name(atom: str, probe_idx: int, p: float, seed: int) -> bool:
    u = (_u64_from_str(f"{atom}|{probe_idx}|{seed}") % (1<<53)) / float(1<<53)
    return u < p

class ProbeEnv:
    def __init__(self, base_env: dict, probe_idx: int, bias_p: float, seed: int):
        self.base = base_env; self.m = probe_idx; self.p = float(bias_p); self.seed = int(seed)
        self.cache = {}
    def get(self, atom: str) -> bool:
        if atom in self.base: return bool(self.base[atom])
        if atom in self.cache: return self.cache[atom]
        v = bernoulli_from_name(atom, self.m, self.p, self.seed)
        self.cache[atom] = v
        return v

def eval_ast(node, env_obj):
    k=node[0]
    if k=="ATOM": return bool(env_obj.get(node[1]))
    if k=="UN":
        _,op,c = node
        v = eval_ast(c, env_obj)
        return (not v)
    if k=="BIN":
        _,op,l,r = node
        a = eval_ast(l, env_obj); b = eval_ast(r, env_obj)
        if op=="AND": return a and b
        if op=="OR":  return a or b
        if op in ("IMP","SUB"): return (not a) or b
        if op=="IFF": return a==b
    return False

def op_counts(toks):
    return toks.count("AND"), toks.count("OR"), toks.count("NOT"), toks.count("IMP")+toks.count("SUB"), toks.count("IFF")

def build_semantic_FP(formulas, M_probes=M_PROBES, seed=SEED):
    N = len(formulas)
    asts = [parse_formula(s) for s in formulas]
    all_atoms = sorted(set().union(*[atoms_in(t) for t in asts]))
    A = len(all_atoms)
    rng = np.random.default_rng(seed)
    biases = np.concatenate([np.full(M_probes//2,0.3), np.full(M_probes - M_probes//2,0.7)])
    assignments = []
    for p in biases:
        vals = rng.random(A) < p
        env = {a: bool(v) for a,v in zip(all_atoms, vals)}
        assignments.append(env)
    T_mat = np.zeros((N, M_probes), dtype=np.float32)
    for i,ast in enumerate(asts):
        for m_i,base_env in enumerate(assignments):
            env_obj = ProbeEnv(base_env, probe_idx=m_i, bias_p=biases[m_i], seed=seed)
            T_mat[i,m_i] = 1.0 if eval_ast(ast, env_obj) else 0.0
    struct_rows=[]
    for s,ast in zip(formulas,asts):
        toks = lex(s)
        ac = len(atoms_in(ast,set()))
        d  = depth(ast)
        c_and, c_or, c_not, c_imp, c_iff = op_counts(toks)
        struct_rows.append([ac,d,c_and,c_or,c_not,c_imp,c_iff,len(toks)])
    STRUCT = np.array(struct_rows,dtype=np.float32)
    FP = np.concatenate([T_mat, STRUCT], axis=1).astype(np.float32)
    return FP

# ---------------------- Model ----------------------
class APPNPEncoder(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, p_drop=0.1):
        super().__init__()
        self.lin1 = nn.Linear(d_in,d_hidden)
        self.lin2 = nn.Linear(d_hidden,d_out)
        self.drop = nn.Dropout(p_drop)
    def forward(self,X,A_hat,K=APPNP_K,alpha=APPNP_ALPHA):
        H0 = F.relu(self.lin1(X)); H0=self.drop(H0)
        H0 = self.lin2(H0)
        H = H0
        for _ in range(K):
            H = (1-alpha)*torch.sparse.mm(A_hat,H)+alpha*H0
        return H

class JointModel(nn.Module):
    def __init__(self,d_in,d_hidden,d_emb):
        super().__init__()
        self.enc = APPNPEncoder(d_in,d_hidden,d_emb)
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.bias  = nn.Parameter(torch.tensor(0.0))
    def forward_embeddings(self,X,A_hat):
        H0 = self.enc(X,A_hat)
        norms = H0.norm(dim=1, keepdim=True).clamp_min(1e-6)
        Z = H0 / norms.clamp_max(2.0)
        return F.normalize(Z, dim=1)
    def decode_pairs(self,Z,pairs_idx):
        i = torch.as_tensor(pairs_idx[:,0],dtype=torch.long,device=Z.device)
        j = torch.as_tensor(pairs_idx[:,1],dtype=torch.long,device=Z.device)
        dp = (Z[i]*Z[j]).sum(1)
        return torch.sigmoid(self.scale*dp+self.bias)

# ---------------------- Metrics ----------------------
# ---------------------- Metrics ----------------------
def evaluate_pairs(truth, pred, clip_r2_for_display=True):
    """
    Returns dictionary of metrics. R2 is computed exactly.
    Optionally returns a capped version for display without altering actual R2.
    """
    mse = float(mean_squared_error(truth, pred))
    mae = float(mean_absolute_error(truth, pred))
    try:
        r2_exact = float(r2_score(truth, pred))  # true R2
        r2_display = r2_exact
        if clip_r2_for_display and r2_display < -10:
            r2_display = -10.0  # only for readability
    except:
        r2_exact = r2_display = float('nan')
    if len(truth) > 1:
        pe = float(pearsonr(truth, pred)[0])
        sp = float(spearmanr(truth, pred)[0])
    else:
        pe = sp = float('nan')
    return {"mse": mse, "mae": mae, "r2": r2_exact, "r2_display": r2_display,
            "pearson": pe, "spearman": sp}

def print_metrics(name, m, display_r2=True):
    r2_val = m['r2']
    print(f"[{name}] MSE={m['mse']:.6f} MAE={m['mae']:.6f} "
          f"R2={r2_val:.3f} Pearson={m['pearson']:.3f} Spearman={m['spearman']:.3f}")


# ---------------------- Training ----------------------
def train_edges_only(FP_in, train_edge_pairs, train_edge_vals, holdout_edge_pairs, holdout_edge_vals,
                     d_hidden=256, d_emb=EMBED_DIM, epochs=EPOCHS, batch_edges_size=BATCH_EDGES_SIZE, lr=LR,
                     edge_loss_w=EDGE_LOSS_W):

    N, d_in = FP_in.shape
    X = torch.tensor(FP_in, dtype=torch.float32, device=DEVICE)
    model = JointModel(d_in, d_hidden, d_emb).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    # Build normalized adjacency
    A_hat = build_A_hat_from_obs(N, train_edge_pairs, train_edge_vals, edge_temp=EDGE_TEMP, device=DEVICE)
    M = len(train_edge_pairs)

    for ep in range(1, epochs + 1):
        model.train()
        Z = model.forward_embeddings(X, A_hat)

        # Edge batch
        if M > 0:
            sample_size = min(batch_edges_size, M)
            idxs = np.random.randint(0, M, size=sample_size)
            batch_edge_pairs = train_edge_pairs[idxs]
            batch_edge_vals = train_edge_vals[idxs].astype(np.float32)
            batch_edge_vals_t = torch.tensor(batch_edge_vals, dtype=torch.float32, device=DEVICE)
            pred_batch = model.decode_pairs(Z, batch_edge_pairs).clamp(0,1)
            loss_edge = mse_loss(pred_batch, batch_edge_vals_t)
        else:
            loss_edge = torch.tensor(0.0, device=DEVICE)

        # Backprop
        opt.zero_grad(set_to_none=True)
        loss_edge.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        opt.step()

        if ep in {1,2,3,5,10,20,50,100,epochs}:
            model.eval()
            with torch.no_grad():
                if len(holdout_edge_pairs) > 0:
                    i = torch.as_tensor(holdout_edge_pairs[:,0], dtype=torch.long, device=DEVICE)
                    j = torch.as_tensor(holdout_edge_pairs[:,1], dtype=torch.long, device=DEVICE)
                    ph = torch.sigmoid(model.scale*(Z[i]*Z[j]).sum(1)+model.bias).clamp(0,1).cpu().numpy()
                    metrics = evaluate_pairs(holdout_edge_vals, ph)
                    print(f"[Epoch {ep}] train_edge_loss={float(loss_edge):.6f}")
                    print_metrics("Holdout edges", metrics)

    # Final embeddings
    model.eval()
    with torch.no_grad():
        Z_final = model.forward_embeddings(X, A_hat).cpu().numpy()
    return model, Z_final

# ---------------------- Pipeline ----------------------
def run_edges_pipeline():
    # 1) Load CSV
    sim_full, formulas_full = load_sim_and_labels(CSV_PATH)
    print("Loaded sim shape", sim_full.shape)

    # 2) Sample subset
    sim, formulas, picked_idx = sample_subset(sim_full, formulas_full, SUBSET_SIZE, SEED)
    N = sim.shape[0]
    print("Using N =", N, "formulas (subset).")

    # 3) Observed edges & unobserved
    obs_pairs_all, obs_vals_all, unobs_pairs = build_obs_pairs_from_sim(sim, OBS_MISSING_FRAC, SEED)
    print("Observed edges:", len(obs_pairs_all), "Unobserved pairs:", len(unobs_pairs))

    # 4) Split train vs holdout
    train_pairs, train_vals, hold_pairs, hold_vals = split_holdout(obs_pairs_all, obs_vals_all, HOLDOUT_EDGE_FRAC, SEED)
    print("Train observed pairs:", len(train_pairs), "Holdout pairs:", len(hold_pairs))

    # 5) Semantic fingerprints
    FP = build_semantic_FP(formulas, M_probes=M_PROBES, seed=SEED)
    print("FP shape:", FP.shape)

    # 6) Train model
    model, Z = train_edges_only(FP, train_pairs, train_vals, hold_pairs, hold_vals,
                                d_hidden=256, d_emb=EMBED_DIM, epochs=EPOCHS, batch_edges_size=BATCH_EDGES_SIZE, lr=LR)

    # 7) Final holdout evaluation
    if len(hold_pairs) > 0:
        Zt = torch.tensor(Z, dtype=torch.float32, device=DEVICE)
        i = torch.as_tensor(hold_pairs[:,0], dtype=torch.long, device=DEVICE)
        j = torch.as_tensor(hold_pairs[:,1], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            ph = torch.sigmoid(model.scale*(Zt[i]*Zt[j]).sum(1)+model.bias).clamp(0,1).cpu().numpy()
        m_final = evaluate_pairs(hold_vals, ph)
        print_metrics("Final Holdout", m_final)

    # 8) Example unobserved predictions
    if len(unobs_pairs) > 0:
        sample_unobs = unobs_pairs[:200]
        ii = torch.as_tensor(sample_unobs[:,0], dtype=torch.long, device=DEVICE)
        jj = torch.as_tensor(sample_unobs[:,1], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            ph_unobs = torch.sigmoid(model.scale*(Zt[ii]*Zt[jj]).sum(1)+model.bias).clamp(0,1).cpu().numpy()
        print("Example predicted unobserved (first 20):", ph_unobs[:20])

    # 9) Save embeddings
    np.savez(OUT_ARTIFACT, Z=Z, picked_idx=picked_idx)
    print("Saved artifacts to", OUT_ARTIFACT)

if __name__=="__main__":
    run_edges_pipeline()


Device: cpu
Loaded sim shape (1000, 1000)
Using N = 1000 formulas (subset).
Observed edges: 249750 Unobserved pairs: 249750
Train observed pairs: 224775 Holdout pairs: 24975
FP shape: (1000, 264)
[Epoch 1] train_edge_loss=0.065019
[Holdout edges] MSE=0.064832 MAE=0.233688 R2=-3.900 Pearson=0.146 Spearman=0.177
[Epoch 2] train_edge_loss=0.062936
[Holdout edges] MSE=0.063093 MAE=0.230358 R2=-3.769 Pearson=0.142 Spearman=0.183
[Epoch 3] train_edge_loss=0.058305
[Holdout edges] MSE=0.058140 MAE=0.220419 R2=-3.395 Pearson=0.144 Spearman=0.181
[Epoch 5] train_edge_loss=0.045988
[Holdout edges] MSE=0.045856 MAE=0.191659 R2=-2.466 Pearson=0.138 Spearman=0.156
[Epoch 10] train_edge_loss=0.031432
[Holdout edges] MSE=0.031338 MAE=0.152474 R2=-1.369 Pearson=0.255 Spearman=0.263
[Epoch 20] train_edge_loss=0.018238
[Holdout edges] MSE=0.017949 MAE=0.103665 R2=-0.357 Pearson=0.336 Spearman=0.320
[Epoch 50] train_edge_loss=0.009705
[Holdout edges] MSE=0.009689 MAE=0.074563 R2=0.268 Pearson=0.559 Spear

In [9]:
! pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0
