# ToAE Diagnostics Prototype (Toy GRU Agent)

This notebook implements a minimal prototype of the ToAE diagnostics described in Sections 4–5.
It is intentionally self-contained and small enough to run on a CPU (a CUDA GPU is used automatically when one is available). It demonstrates:

- a small GRU-based model used as a toy folding system;
- `fold_fn`, `converge`, and `perturb` primitives;
- PSC (Predictive Self-Consistency) probe training and evaluation;
- Folding depth (D) estimation;
- Attractor Stability (AS) estimation via Gaussian perturbations.

Run the notebook from top to bottom. Outputs and example metrics will be saved under `/mnt/data` (the report is written to `/mnt/data/toae_notebook_outputs/report.json`).


In [None]:
# Setup: imports and simple utilities
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os, json, math, random
from tqdm.notebook import tqdm

# Pick the compute device once: CUDA when PyTorch can see one, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


In [None]:
# Define a small GRU-based folding model
class ToyGRUAgent(nn.Module):
    """Toy recurrent "folding" model: linear encoder followed by one GRU cell.

    A fold step encodes the input with ``encoder`` and feeds it, together with
    the current state, through ``gru_cell``. ``readout`` maps a hidden state
    back to input space but is not used by any method of this class.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Size of the hidden state (and of the encoded input).
    """

    def __init__(self, input_dim=8, hidden_dim=64):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.gru_cell = nn.GRUCell(hidden_dim, hidden_dim)
        self.readout = nn.Linear(hidden_dim, input_dim)  # optional decoder

    def initial_state(self, batch_size=1):
        """Return an all-zero state of shape ``(batch_size, hidden_dim)``.

        The state is allocated with the dtype and device of the module's own
        parameters, so it stays compatible with ``fold_step`` after the module
        has been moved (``.to(...)``) or cast (``.double()``), instead of
        depending on a notebook-level ``device`` variable.
        """
        ref = self.encoder.weight
        return torch.zeros(
            batch_size,
            self.gru_cell.hidden_size,
            dtype=ref.dtype,
            device=ref.device,
        )

    def encode_input(self, x):
        """Project ``x`` of shape ``(batch, input_dim)`` to ``(batch, hidden_dim)``."""
        return self.encoder(x)

    def fold_step(self, s, x):
        """Apply one fold: return the GRU cell's next state for state ``s`` and input ``x``.

        ``s`` is ``(batch, hidden_dim)`` and ``x`` is ``(batch, input_dim)``.
        """
        ex = self.encode_input(x)
        s_next = self.gru_cell(ex, s)
        return s_next

    def forward(self, s, x, steps=1):
        """Apply ``steps`` fold steps, feeding the same input ``x`` every time."""
        for _ in range(steps):
            s = self.fold_step(s, x)
        return s


In [None]:
# Synthetic dataset generator: independent Gaussian steps plus a small linear drift
def generate_synthetic_sequences(N=1000, seq_len=10, input_dim=8):
    """Generate ``N`` random float32 sequences of shape ``(seq_len, input_dim)``.

    Every step is drawn independently from a standard normal (via the global
    ``np.random`` state); a deterministic ramp from 0 to 0.05 over the sequence
    is then added to all features so the data has a little temporal structure.
    Steps do not otherwise depend on earlier steps.

    Args:
        N: Number of sequences to generate.
        seq_len: Number of steps per sequence.
        input_dim: Number of features per step.

    Returns:
        A list of ``N`` ``np.float32`` arrays, each ``(seq_len, input_dim)``.
    """
    # The ramp must be float32: np.linspace yields float64, and adding a
    # float64 array to the float32 noise promotes the whole sequence to
    # float64, which the float32 model then rejects.
    trend = (np.linspace(0, 1, seq_len) * 0.05).astype(np.float32)[:, None]
    data = []
    for _ in range(N):
        noise = np.random.randn(seq_len, input_dim).astype(np.float32)
        data.append(noise + trend)
    return data

# Build the dataset that the later cells read from.
data = generate_synthetic_sequences(
    N=1200,
    seq_len=12,
    input_dim=8,
)
print('Generated', len(data), 'sequences of length', len(data[0]))


In [None]:
# Utilities: state extractor and collecting (s_t, s_{t+1}) pairs
def extract_state(model, seq, t, pool='last'):
    """Return a proxy for the model state at step ``t`` as a NumPy array.

    The proxy is a single fold step applied to the zero initial state with
    input ``seq[t]``; earlier steps of ``seq`` are not replayed, so the result
    depends on ``seq[t]`` only.

    Args:
        model: Object providing ``initial_state(batch_size=...)`` and
            ``fold_step(s, x)``.
        seq: Indexable of per-step input vectors, e.g. a ``(seq_len, input_dim)``
            array.
        t: Index of the step whose input is used.
        pool: Currently unused; kept for interface compatibility.

    Returns:
        The resulting state with size-1 axes squeezed out.
    """
    s0 = model.initial_state(batch_size=1)
    # Match the input to the state's dtype/device so float64 NumPy data (NumPy's
    # default) does not trip a dtype mismatch inside the model.
    x = torch.as_tensor(seq[t], dtype=s0.dtype, device=s0.device).unsqueeze(0)
    # Only the value is needed, so skip building an autograd graph.
    with torch.no_grad():
        s = model.fold_step(s0, x)
    return s.detach().cpu().numpy().squeeze()

def collect_state_pairs(model, sequences, max_pairs=2000):
    """Collect consecutive ``(s_t, s_{t+1})`` state pairs from ``sequences``.

    States come from ``extract_state``. Each state is computed once and reused
    as the second element of one pair and the first element of the next, so
    neighbouring pairs share array objects; treat the arrays as read-only.

    Args:
        model: Model forwarded to ``extract_state``.
        sequences: Iterable of sequences; sequences with fewer than two steps
            contribute no pairs.
        max_pairs: Stop as soon as this many pairs have been collected. Values
            ``<= 0`` yield an empty list.

    Returns:
        A list of at most ``max_pairs`` ``(state, next_state)`` tuples.
    """
    pairs = []
    if max_pairs <= 0:
        return pairs
    for seq in sequences:
        if len(seq) < 2:
            continue
        prev_state = extract_state(model, seq, 0)
        for t in range(1, len(seq)):
            next_state = extract_state(model, seq, t)
            pairs.append((prev_state, next_state))
            if len(pairs) >= max_pairs:
                return pairs
            # Carry the state forward instead of recomputing it next iteration.
            prev_state = next_state
    return pairs

# Create an untrained agent on the selected device, then gather the
# (state, next_state) pairs the probe will be trained on.
model = ToyGRUAgent(input_dim=8, hidden_dim=64)
model = model.to(device)
pairs = collect_state_pairs(model, data, max_pairs=1500)
print('Collected', len(pairs), 'state pairs')


In [None]:
# Probe: a small MLP trained to predict s_{t+1} from s_t
class Probe(nn.Module):
    """Two-layer perceptron mapping a ``dim``-sized state to a ``dim``-sized state.

    The layers live in ``self.net`` as Linear -> ReLU -> Linear, all of width
    ``dim``.
    """

    def __init__(self, dim=64):
        super().__init__()
        hidden_layer = nn.Linear(dim, dim)
        output_layer = nn.Linear(dim, dim)
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return the probe's prediction for the batch of states ``x``."""
        return self.net(x)

def train_probe(pairs, epochs=30, batch_size=64, lr=1e-3):
    """Fit a ``Probe`` to predict the second element of each pair from the first.

    The pairs are shuffled and split 80/20 into training and validation sets.
    Training uses Adam on an MSE loss with a fresh mini-batch order each epoch;
    train and validation losses are printed every 10th epoch.

    Args:
        pairs: Sequence of ``(state, next_state)`` array pairs of equal shape.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(probe, (X_val, Y_val))``: the trained probe and the held-out
        validation tensors.
    """
    states = np.stack([pair[0] for pair in pairs])
    next_states = np.stack([pair[1] for pair in pairs])
    num_pairs = len(states)

    # Random 80/20 split over the pair indices.
    shuffled = np.random.permutation(num_pairs)
    cut = int(0.8 * num_pairs)
    train_rows, val_rows = shuffled[:cut], shuffled[cut:]

    def to_device(array):
        return torch.tensor(array, device=device)

    X_train = to_device(states[train_rows])
    Y_train = to_device(next_states[train_rows])
    X_val = to_device(states[val_rows])
    Y_val = to_device(next_states[val_rows])

    probe = Probe(dim=states.shape[1]).to(device)
    opt = optim.Adam(probe.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    num_train = X_train.size(0)

    for epoch in range(1, epochs + 1):
        probe.train()
        order = torch.randperm(num_train)
        epoch_losses = []
        for start in range(0, num_train, batch_size):
            rows = order[start:start + batch_size]
            batch_loss = loss_fn(probe(X_train[rows]), Y_train[rows])
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            epoch_losses.append(batch_loss.item())

        # Report progress on every 10th epoch only.
        if epoch % 10 != 0:
            continue
        probe.eval()
        with torch.no_grad():
            val_loss = loss_fn(probe(X_val), Y_val).item()
        print(f'Probe epoch {epoch} train_loss {np.mean(epoch_losses):.5f} val_loss {val_loss:.5f}')

    return probe, (X_val, Y_val)

# Fit the probe on the collected pairs; the held-out split is kept for PSC.
probe, val_data = train_probe(
    pairs,
    epochs=40,
)


In [None]:
# Compute PSC: 1 - mse / var
def compute_PSC(probe, val_tuple):
    """Compute Predictive Self-Consistency of ``probe`` on held-out pairs.

    PSC is ``1 - mse / var``, where ``mse`` is the probe's mean squared error
    and ``var`` is the variance of the targets around their single global mean,
    pooled over all elements and dimensions.

    Args:
        probe: Callable module with ``eval()`` mapping ``X_val`` to predictions.
        val_tuple: ``(X_val, Y_val)`` tensors of inputs and targets.

    Returns:
        ``(PSC, mse, var)`` as Python floats.
    """
    X_val, Y_val = val_tuple
    probe.eval()
    with torch.no_grad():
        pred = probe(X_val)
        mse = torch.mean((pred - Y_val) ** 2).item()
        # Centre before squaring: E[y^2] - E[y]^2 cancels catastrophically in
        # float32 when the mean is large relative to the spread and can even
        # come out negative. This form is always >= 0.
        var = torch.mean((Y_val - Y_val.mean()) ** 2).item()
    # The small constant guards against division by zero for constant targets.
    PSC = 1.0 - mse / (var + 1e-9)
    return PSC, mse, var

# Score the trained probe on its held-out validation pairs.
PSC, mse, var = compute_PSC(probe=probe, val_tuple=val_data)
print('PSC:', PSC, 'mse', mse, 'var', var)


In [None]:
# Define fold_fn and converge for this GRU model (using same input x repeatedly as a simple demo)
def fold_fn(model, s, x):
    """Apply one gradient-free fold step of ``model`` to state ``s`` with input ``x``.

    Args:
        model: Object providing ``fold_step(s, x)``.
        s: Current state as a tensor, or array-like. A 1-D array-like is given
            a leading batch axis.
        x: Input as a tensor, or array-like. A 1-D array-like is given a
            leading batch axis; an array-like that already has a batch axis is
            used as is.

    Returns:
        The next state returned by ``model.fold_step``.

    Tensors keep their shape. All inputs are cast to the dtype and device of
    the model's first parameter (float32 on CPU if the model exposes none), so
    float64 NumPy data does not cause a dtype mismatch.
    """
    params = model.parameters() if hasattr(model, 'parameters') else ()
    ref = next(iter(params), None)
    dtype = ref.dtype if ref is not None else torch.float32
    target_device = ref.device if ref is not None else torch.device('cpu')

    def prepare(value):
        was_tensor = isinstance(value, torch.Tensor)
        value = torch.as_tensor(value, dtype=dtype, device=target_device)
        # Only raw array-likes get a batch axis, and only when they lack one.
        if not was_tensor and value.dim() == 1:
            value = value.unsqueeze(0)
        return value

    s = prepare(s)
    x = prepare(x)
    with torch.no_grad():
        s_next = model.fold_step(s, x)
    return s_next

def converge(model, s0, x, fold_fn, max_iters=50, eps=1e-4, momentum=0.0):
    """Iterate ``fold_fn`` with a fixed input until the state stops moving.

    Starting from a detached copy of ``s0``, each iteration computes
    ``fold_fn(model, state, x)``; when ``momentum > 0`` the candidate is blended
    with the current state as ``momentum * state + (1 - momentum) * candidate``.
    Iteration stops once the norm of the update drops below ``eps``.

    Returns:
        ``(state, steps)``. On convergence ``steps`` is the number of fold
        evaluations used. If the budget is exhausted, the last state and
        ``max_iters`` are returned.
    """
    current = s0.detach().clone()
    damped = momentum > 0
    iteration = 0
    while iteration < max_iters:
        iteration += 1
        candidate = fold_fn(model, current, x)
        if damped:
            candidate = momentum * current + (1.0 - momentum) * candidate
        update_size = torch.norm(candidate - current).item()
        if update_size < eps:
            return candidate, iteration
        current = candidate
    return current, max_iters

# Sanity check: iterate from the zero state while holding the first input of
# the first sequence fixed, and report how many steps that took.
s0 = model.initial_state(batch_size=1)
x0 = torch.tensor(data[0][0], device=device)[None, :]
s_conv, steps = converge(model, s0, x0, fold_fn)
print('Converged in steps', steps, 'state norm', torch.norm(s_conv).item())


In [None]:
# Compute folding depth D for a set of inputs
def compute_D_for_inputs(model, sequences, fold_fn, Kmax=50, eps=1e-4, max_sequences=200):
    """Estimate folding depth D for the first input of each sequence.

    For each sequence, the model is iterated from its zero initial state with
    ``seq[0]`` held fixed, and the step count reported by ``converge`` is
    recorded. Runs that do not converge within ``Kmax`` iterations are recorded
    as ``Kmax``.

    Args:
        model: Model providing ``initial_state`` and usable by ``fold_fn``.
        sequences: Sliceable collection of sequences.
        fold_fn: Fold function forwarded to ``converge``.
        Kmax: Iteration budget per input.
        eps: Convergence tolerance forwarded to ``converge``.
        max_sequences: Evaluate at most this many leading sequences; ``None``
            evaluates all of them. Defaults to 200.

    Returns:
        A list with one step count per evaluated sequence.
    """
    selected = sequences if max_sequences is None else sequences[:max_sequences]
    Ds = []
    for seq in selected:
        s0 = model.initial_state(batch_size=1)
        # Match the input to the state's dtype/device so float64 NumPy data does
        # not cause a dtype mismatch inside the model.
        x0 = torch.as_tensor(seq[0], dtype=s0.dtype, device=s0.device).unsqueeze(0)
        _, steps = converge(model, s0, x0, fold_fn, max_iters=Kmax, eps=eps)
        Ds.append(steps)
    return Ds

import statistics

# Folding depth over the leading sequences, with a tighter tolerance than the default.
Ds = compute_D_for_inputs(
    model,
    data,
    fold_fn,
    Kmax=50,
    eps=1e-5,
)
print('D mean', statistics.mean(Ds), 'median', statistics.median(Ds))


In [None]:
# Attractor Stability (AS): Gaussian perturbations around an attractor state
def perturb_gaussian(s, sigma=0.01):
    """Return ``s`` plus Gaussian noise scaled relative to the norm of ``s``.

    Each element receives independent standard-normal noise multiplied by
    ``sigma`` and by the norm of the whole tensor ``s``. ``s`` itself is not
    modified.
    """
    unit_noise = torch.randn_like(s)
    scaled_noise = unit_noise * sigma * torch.norm(s)
    return torch.add(s, scaled_noise)

def compute_AS_for_input(model, seq, fold_fn, N=100, sigma=0.02):
    """Estimate Attractor Stability (AS) for the first input of ``seq``.

    The reference state ``a`` is obtained by running ``converge`` from the zero
    initial state with ``seq[0]`` held fixed. ``a`` is then perturbed ``N``
    times with ``perturb_gaussian`` and re-converged; AS is one minus the mean
    relative distance between the re-converged states and ``a``.

    ``converge`` is called with its default iteration budget and tolerance, so
    ``a`` is whatever state it returns, whether or not it actually converged.

    Args:
        model: Model providing ``initial_state`` and usable by ``fold_fn``.
        seq: Sequence whose first step is used as the fixed input.
        fold_fn: Fold function forwarded to ``converge``.
        N: Number of perturbation trials.
        sigma: Noise scale forwarded to ``perturb_gaussian``.

    Returns:
        ``(AS, diffs)``: the stability score and the per-trial relative
        distances.
    """
    s0 = model.initial_state(batch_size=1)
    # Match the input to the state's dtype/device so float64 NumPy data does
    # not cause a dtype mismatch inside the model.
    x0 = torch.as_tensor(seq[0], dtype=s0.dtype, device=s0.device).unsqueeze(0)
    a, _ = converge(model, s0, x0, fold_fn)
    # The reference norm is the same for every trial; the small constant guards
    # against division by zero when the reference state is all zeros.
    ref_norm = torch.norm(a).item() + 1e-9
    diffs = []
    for _ in range(N):
        s_pert = perturb_gaussian(a, sigma=sigma)
        a_i, _ = converge(model, s_pert, x0, fold_fn)
        diffs.append(torch.norm(a_i - a).item() / ref_norm)
    AS = 1.0 - float(np.mean(diffs))
    return AS, diffs

# Stability of the attractor reached from the first sequence's first input.
AS_example, diffs = compute_AS_for_input(
    model,
    data[0],
    fold_fn,
    N=80,
    sigma=0.02,
)
print('AS example', AS_example)


In [None]:
# Save example metrics and small report
import statistics

out = {
    'PSC': float(PSC),
    'D_mean': float(sum(Ds)/len(Ds)),
    # statistics.median averages the two middle values for even-length lists,
    # so the saved value matches the median printed when D was computed;
    # sorted(Ds)[len(Ds)//2] would pick the upper-middle element instead.
    'D_median': float(statistics.median(Ds)),
    'AS_example': float(AS_example)
}
os.makedirs('/mnt/data/toae_notebook_outputs', exist_ok=True)
with open('/mnt/data/toae_notebook_outputs/report.json', 'w', encoding='utf-8') as f:
    json.dump(out, f, indent=2)
print('Saved report to /mnt/data/toae_notebook_outputs/report.json')
