## Config

In [5]:

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import Counter
import re
from tqdm.notebook import tqdm


CONFIG = {
    "data_path": "/content/algebra_train.jsonl",
    "output_path": "graph_transformer_cot.pth",
    "batch_size": 16,
    "epochs": 30,
    "d_model": 128,
    "n_heads": 8,
    "n_layers": 4,
    "lr": 3e-4
}

## Dataloading and format from `algebra_train.jsonl`

In [6]:
def robust_parse_line(line):
    line = line.strip()
    if not line:
        return None


    try:
        parsed = json.loads(line)
    except:
        # fix format issues
        line = re.sub(r"'([^']*)':", r'"\1":', line)
        line = re.sub(r',(\s*[}\]])', r'\1', line)
        line = re.sub(r'([{\s,])(\w+)(?=\s*:)', r'\1"\2"', line)
        try:
            parsed = json.loads(line)
        except:
            return None

    # tuples list dict handling
    if isinstance(parsed, (list, tuple)):
        return {"correct_solution": parsed, "perturbations": []}
    elif isinstance(parsed, dict):
        return parsed
    return None

def safe_extract_steps(data):
    #steps from any structure

    steps = []
    if isinstance(data, (list, tuple)):
        for item in data:
            if isinstance(item, dict) and all(k in item for k in ["st", "ot+1", "st+1"]):
                steps.append(item)
    elif isinstance(data, dict):
        if all(k in data for k in ["st", "ot+1", "st+1"]):
            steps.append(data)
    return steps

def collect_all_texts(jsonl_path):
    texts = []
    skipped = 0

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            ex = robust_parse_line(line)
            if ex is None:
                skipped += 1
                continue

            # extract correct_solution
            correct_steps = safe_extract_steps(ex.get("correct_solution", []))
            for step in correct_steps:
                texts.extend([step["st"], step["ot+1"], step["st+1"]])

            # extract perturbations
            perturbations = ex.get("perturbations", [])
            if isinstance(perturbations, (list, tuple)):
                for pert in perturbations:
                    if isinstance(pert, dict):
                        solution_steps = safe_extract_steps(pert.get("solution", []))
                        for step in solution_steps:

                            texts.extend([step["st"], step["ot+1"], step["st+1"]])

    print(f"collected {len(texts)} texts, skipped {skipped} lines")
    return texts

## Tokenization/padding + encoding

In [7]:
class Vocab:
    def __init__(self, min_freq=1):
        self.min_freq = min_freq
        self.token2id = {"<pad>": 0, "<unk>": 1}
        self.id2token = {0: "<pad>", 1: "<unk>"}

    def build(self, texts):
        counter = Counter()
        for t in texts:
            counter.update(str(t).split())
        for tok, freq in counter.items():
            if freq >= self.min_freq and tok not in self.token2id:
                idx = len(self.token2id)
                self.token2id[tok] = idx
                self.id2token[idx] = tok

    def encode(self, text, max_len):
        toks = str(text).split()
        ids = [self.token2id.get(t, 1) for t in toks][:max_len]
        if len(ids) < max_len:
            ids += [0] * (max_len - len(ids))
        return ids

    def __len__(self):
        return len(self.token2id)

class CoTGraphDataset(Dataset):
    def __init__(self, jsonl_path, vocab, max_state_len=32, max_steps=8):
        self.vocab = vocab
        self.max_state_len = max_state_len
        self.max_steps = max_steps
        self.examples = []
        skipped = 0

        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                ex = robust_parse_line(line)
                if ex is None:
                    skipped += 1
                    continue

                try:
                    self._add_example(ex)
                except:
                    skipped += 1
                    continue

        print(f"dataset: {len(self.examples)} graphs, skipped {skipped} lines")

    def _add_example(self, ex):

        # correct_solution
        correct_steps = safe_extract_steps(ex.get("correct_solution", []))
        if correct_steps:
            self._add_solution_graph(correct_steps, label=0)

        # perturbations

        perturbations = ex.get("perturbations", [])
        if isinstance(perturbations, (list, tuple)):
            for pert in perturbations:
                if isinstance(pert, dict):
                    solution_steps = safe_extract_steps(pert.get("solution", []))
                    if solution_steps:
                        self._add_solution_graph(solution_steps, label=1)

    def _add_solution_graph(self, solution_steps, label):
        steps = solution_steps[:self.max_steps]
        if not steps:
            return

        node_to_id = {}
        node_strings = []
        node_types = []

        def get_node_id(text, node_type):
            text = str(text)
            if text not in node_to_id:
                node_id = len(node_strings)
                node_to_id[text] = node_id
                node_strings.append(text)
                node_types.append(node_type)
            return node_to_id[text]

        edges = []
        for step in steps:
            if isinstance(step, dict) and all(k in step for k in ["st", "ot+1", "st+1"]):
                st_id = get_node_id(step["st"], 0)
                op_id = get_node_id(step["ot+1"], 1)
                st1_id = get_node_id(step["st+1"], 0)
                edges.extend([(st_id, op_id), (op_id, st1_id)])

        if edges:
            self.examples.append({
                "node_strings": node_strings,
                "node_types": node_types,
                "edges": edges,
                "label": label
            })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        node_token_ids = torch.tensor([self.vocab.encode(s, self.max_state_len)
                                     for s in ex["node_strings"]], dtype=torch.long)
        edge_index = torch.tensor(ex["edges"], dtype=torch.long).t() if ex["edges"] else torch.empty((2, 0), dtype=torch.long)
        node_type_ids = torch.tensor(ex["node_types"], dtype=torch.long)
        return {
            "node_token_ids": node_token_ids,
            "edge_index": edge_index,
            "node_type_ids": node_type_ids,
            "label": torch.tensor(ex["label"], dtype=torch.float32)
        }

def collate_graphs(batch):
    return batch


## Graph Transformer

In [8]:
class GraphTransformerCoT(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=8, n_layers=4):
        super().__init__()


        self.d_model = d_model

        #token embedding - semantic
        self.token_embed = nn.Embedding(vocab_size, d_model)
        #positioning of each token
        self.pos_embed = nn.Parameter(torch.randn(32, d_model) * 0.02)
        self.node_type_embed = nn.Embedding(2, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=512,
            dropout=0.15, activation="gelu", batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.graph_pool = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(), nn.Dropout(0.2),
            nn.Linear(d_model, 1)
        )

    def forward_single(self, node_token_ids, node_type_ids, edge_index):
        device = node_token_ids.device
        N, L = node_token_ids.shape
        #each node's token sequence is collapsed into a single vector by averaging --> make each node = one vector representing one step

        token_emb = self.token_embed(node_token_ids) + self.pos_embed[:L, :].unsqueeze(0)
        node_emb = token_emb.mean(dim=1)

        #node type injection - understand that it IS a step, not raw text
        node_emb = node_emb + self.node_type_embed(node_type_ids)
        h = self.transformer(node_emb.unsqueeze(0)).squeeze(0)


 # transformer operates over nodes --> attend to every other node / learn patterns + contradictions
        if edge_index.numel() > 0:

          #for every edge --> src --> tgt --> multiply their embeddings + average across all edges / if no edges, just zero vector ( understand if all steps AGREE with each other)
            src, tgt = edge_index
            edge_context = (h[src] * h[tgt]).mean(dim=0)
        else:
            edge_context = torch.zeros_like(h.mean(dim=0), device=device)

            #whole-graph representation
        global_emb = h.mean(dim=0)

      #final decision head (concatenate local relational edges and global reasoning nodes)
        final_emb = torch.cat([edge_context, global_emb], dim=-1)
        return self.graph_pool(final_emb).squeeze(-1)

    def forward(self, batch):

      #each example is separately + output a vector of scores
        device = next(self.parameters()).device
        return torch.stack([
            self.forward_single(
                ex["node_token_ids"].to(device),
                ex["node_type_ids"].to(device),
                ex["edge_index"].to(device)
            ) for ex in batch
        ])


        # overall this views the structure and logic of invariants, not guessing the answer


## Run + train + val

In [9]:
def main():
    print("vocab load")
    texts = collect_all_texts(CONFIG["data_path"])
    if not texts:
        print("no valid text")
        return

    vocab = Vocab(min_freq=2)
    vocab.build(texts)
    print(f"vocab size: {len(vocab)}")

    print("load dataset")
    dataset = CoTGraphDataset(CONFIG["data_path"], vocab)

    if len(dataset) == 0:
        print("no valid graphs")
        return

    train_size = int(0.9 * len(dataset))
    train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])

    train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True,
                             collate_fn=collate_graphs, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False,
                           collate_fn=collate_graphs, num_workers=0)

    print(f"train: {len(train_ds)}, val: {len(val_ds)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" device: {device}")

    model = GraphTransformerCoT(
        vocab_size=len(vocab), d_model=CONFIG["d_model"],
        n_heads=CONFIG["n_heads"], n_layers=CONFIG["n_layers"]
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"])

    best_val_acc = 0.0
    print("training...")

    for epoch in range(CONFIG["epochs"]):
        model.train()
        train_correct, train_total = 0, 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in train_pbar:
            logits = model(batch)
            labels = torch.stack([ex["label"] for ex in batch]).to(device)
            loss = F.binary_cross_entropy_with_logits(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            preds = (logits.sigmoid() > 0.5).long()
            train_correct += (preds == labels.long()).sum().item()
            train_total += len(batch)

            train_pbar.set_postfix(acc=f"{train_correct/train_total:.3f}")

        # validation
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for batch in val_loader:
                logits = model(batch)
                labels = torch.stack([ex["label"] for ex in batch]).to(device)
                preds = (logits.sigmoid() > 0.5).long()
                val_correct += (preds == labels.long()).sum().item()
                val_total += len(batch)

        train_acc = train_correct / train_total
        val_acc = val_correct / val_total

        print(f"epoch {epoch+1}: train {train_acc:.3f}, validation {val_acc:.3f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "vocab": vocab.token2id,
                "val_acc": val_acc,
                "epoch": epoch
            }, CONFIG["output_path"])
            print(f"model saved: {val_acc:.3f}")

    print(f"---- DONE ---- best val acc: {best_val_acc:.3f}")

if __name__ == "__main__":
    main()


vocab load
collected 1176 texts, skipped 68 lines
vocab size: 130
load dataset
dataset: 266 graphs, skipped 68 lines
train: 239, val: 27
 device: cpu




training...


Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 1: train 0.820, validation 0.926
model saved: 0.926


Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 2: train 0.841, validation 0.926


Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 3: train 0.845, validation 0.926


Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 4: train 0.845, validation 0.926


Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 5: train 0.849, validation 0.926


Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 6: train 0.849, validation 0.926


Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 7: train 0.845, validation 0.926


Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 8: train 0.849, validation 0.926


Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 9: train 0.849, validation 0.926


Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 10: train 0.849, validation 0.926


Epoch 11:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 11: train 0.870, validation 0.926


Epoch 12:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 12: train 0.870, validation 0.926


Epoch 13:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 13: train 0.690, validation 0.444


Epoch 14:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 14: train 0.849, validation 0.926


Epoch 15:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 15: train 0.858, validation 0.926


Epoch 16:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 16: train 0.799, validation 0.926


Epoch 17:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 17: train 0.883, validation 0.926


Epoch 18:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 18: train 0.858, validation 0.778


Epoch 19:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 19: train 0.866, validation 0.778


Epoch 20:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 20: train 0.874, validation 0.852


Epoch 21:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 21: train 0.866, validation 0.815


Epoch 22:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 22: train 0.900, validation 0.926


Epoch 23:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 23: train 0.879, validation 0.926


Epoch 24:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 24: train 0.828, validation 0.926


Epoch 25:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 25: train 0.916, validation 0.926


Epoch 26:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 26: train 0.904, validation 0.926


Epoch 27:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 27: train 0.895, validation 0.889


Epoch 28:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 28: train 0.900, validation 0.815


Epoch 29:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 29: train 0.904, validation 0.778


Epoch 30:   0%|          | 0/15 [00:00<?, ?it/s]

epoch 30: train 0.883, validation 0.815
---- DONE ---- best val acc: 0.926


In [10]:

import torch

model = torch.load('/content/graph_transformer_cot.pth')
print(model)


{'model_state_dict': OrderedDict({'pos_embed': tensor([[ 1.5626e-02,  9.7771e-03,  1.0137e-02,  ...,  3.5140e-03,
         -2.5266e-02, -4.9310e-03],
        [-1.8013e-03, -7.5128e-03,  7.3562e-03,  ...,  2.0521e-03,
         -6.7196e-03, -2.9031e-02],
        [ 1.4919e-02,  1.7845e-02, -4.4561e-02,  ...,  9.3146e-05,
         -2.3671e-02,  3.6148e-02],
        ...,
        [ 2.0777e-02,  2.7618e-02,  5.8039e-02,  ..., -9.2207e-03,
          2.9853e-02,  1.6019e-02],
        [-1.5150e-02,  2.3560e-02,  5.4209e-02,  ...,  1.1381e-02,
          3.0454e-02,  8.2652e-03],
        [-1.4239e-02, -3.0451e-02,  2.4185e-02,  ...,  1.6767e-02,
          2.5174e-02,  3.3133e-04]]), 'token_embed.weight': tensor([[ 0.1306,  0.1934, -1.3739,  ...,  0.1465,  1.1241, -1.0564],
        [ 0.8012, -0.9504,  1.0758,  ...,  0.5880,  0.7948, -1.8650],
        [ 0.5755,  1.0099,  2.3150,  ...,  0.1781, -0.2773, -0.9121],
        ...,
        [ 0.0830,  0.0730, -0.4637,  ..., -0.6043,  0.0817, -0.5489],
     