<a href="https://colab.research.google.com/github/srirambandi/compsci685/blob/main/train_and_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab Cell – clone your repo at the top of the notebook
!git clone https://github.com/srirambandi/compsci685.git
%cd compsci685

fatal: destination path 'compsci685' already exists and is not an empty directory.
/content/compsci685


In [2]:
!pip install torch transformers datasets sympy tqdm scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import sys
sys.path.insert(0, "gen_dataset")
sys.path.insert(0, "gen_dataset/src")
sys.path.insert(0, "training/") # treereg

In [3]:
import os
import math
import time

import pandas as pd
import numpy as np
import sympy as sp
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from utils import prefix_to_sympy, verify_solution, OPERATORS

from regularizer.regularizer_main import TreeRegularizer
from parse_tree_adapted import get_parse_dict_for_prefix_list

In [4]:
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/compsci685/checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
DATA_PATH = "gen_dataset/data/fin_dataset.csv"
cols = ["id", "equ_str", "equ_prefix", "sol_str", "sol_prefix"]

df = pd.read_csv(DATA_PATH, header=0, names=cols)
print(f"Total examples: {len(df)}")
df.head(5)

Total examples: 33874


Unnamed: 0,id,equ_str,equ_prefix,sol_str,sol_prefix
0,0,x^4*e^(x) + 4*x^3*e^(x) + y' - 1 = 0,add add add -1 mul pow x 4 exp x mul mul 4 pow...,c - x^4*e^(x) + x,add add c x mul mul -1 pow x 4 exp x
1,1,-2*y*ln(12) + y' = 0,add mul mul -2 y log 12 y',12^(c + 2*x),pow 12 add c mul 2 x
2,2,2*x + 9*y' - 9 = 0,add add -9 mul 2 x mul 9 y',c - x^2/9 + x,add add c x mul div -1 9 pow x 2
3,3,2*x^2 + x*y' - y = 0,add add mul -1 y mul 2 pow x 2 mul x y',c*x + x*(4 - 2*x) + x,add add x mul c x mul x add 4 mul -2 x
4,4,54*x*e^(y/9) + y' = 0,add mul mul 54 x exp mul div 1 9 y y',-9*ln(c + 3*x^2),mul -9 log add c mul 3 pow x 2


In [6]:
# stats about our dataset
df["equ_len"] = df["equ_prefix"].str.split().apply(len)
df["sol_len"] = df["sol_prefix"].str.split().apply(len)

df[["equ_len", "sol_len"]].describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
equ_len,33874.0,13.184684,7.748491,1.0,8.0,11.0,15.0,115.0
sol_len,33874.0,10.013107,2.727231,2.0,8.0,10.0,11.0,28.0


In [7]:
# let's split the data into train, test and val like in the paper: https://github.com/facebookresearch/SymbolicMathematics/blob/main/split_data.py

N = len(df)
m = int(0.1 * N)       # let's take 10% as validation data set size for now.
assert 2 * m < N, "Pick smaller m!"

alpha = math.log(N - 0.5) / math.log(2 * m)

raw_idxs = [int(i**alpha) for i in range(1, 2*m + 1)]
val_idxs  = set(raw_idxs[::2])
test_idxs = set(raw_idxs[1::2])

all_idxs   = set(range(N))
train_idxs = all_idxs - val_idxs - test_idxs

# slice and reset index
train_df = df.iloc[sorted(train_idxs)].reset_index(drop=True)
val_df   = df.iloc[sorted(val_idxs)].reset_index(drop=True)
test_df  = df.iloc[sorted(test_idxs)].reset_index(drop=True)

print(f"Split sizes => train: {len(train_df)}, valid: {len(val_df)}, test: {len(test_df)}")

Split sizes => train: 27100, valid: 3387, test: 3387


In [8]:
os.makedirs("splits", exist_ok=True)

train_df.to_csv("splits/train.csv", index=False)
val_df.to_csv("splits/valid.csv", index=False)
test_df.to_csv("splits/test.csv", index=False)

print("Saved splits into ./splits/")

Saved splits into ./splits/


In [9]:
SPECIAL = {"<pad>":0,"<bos>":1,"<eos>":2}
# Assuming train_df is loaded and available for vocab creation
tokens_for_vocab = {t for seq in train_df.equ_prefix.str.split() for t in seq}
tokens_for_vocab |= {t for seq in train_df.sol_prefix.str.split() for t in seq}
word2idx = {w:i+len(SPECIAL) for i,w in enumerate(sorted(list(tokens_for_vocab)))} # Ensure sorted list for consistent mapping
word2idx.update(SPECIAL)
idx2word = {i:w for w,i in word2idx.items()}

PAD, BOS, EOS = word2idx["<pad>"], word2idx["<bos>"], word2idx["<eos>"]
VOCAB_SIZE = len(word2idx)

class ODEDataset(Dataset):
    def __init__(self, df, src_col, tgt_col, max_len=128, word2idx_map=None, pad_idx=None, bos_idx=None, eos_idx=None):
        self.src_original_token_lists = df[src_col].str.split().tolist()
        self.tgt_original_token_lists = df[tgt_col].str.split().tolist()
        self.max_len = max_len
        self.word2idx = word2idx_map
        self.PAD_IDX = pad_idx
        self.BOS_IDX = bos_idx
        self.EOS_IDX = eos_idx

    def __len__(self):
        return len(self.src_original_token_lists)

    def __getitem__(self, i):
        src_str_tokens = self.src_original_token_lists[i]
        tgt_str_tokens = self.tgt_original_token_lists[i]

        src_ids = [self.BOS_IDX] + [self.word2idx.get(t, self.PAD_IDX) for t in src_str_tokens] + [self.EOS_IDX]
        tgt_ids = [self.BOS_IDX] + [self.word2idx.get(t, self.PAD_IDX) for t in tgt_str_tokens] + [self.EOS_IDX]

        # Pad sequences
        padded_src_ids = src_ids[:self.max_len] + [self.PAD_IDX] * (self.max_len - len(src_ids))
        if len(padded_src_ids) > self.max_len: padded_src_ids = padded_src_ids[:self.max_len] # Ensure fixed length

        padded_tgt_ids = tgt_ids[:self.max_len] + [self.PAD_IDX] * (self.max_len - len(tgt_ids))
        if len(padded_tgt_ids) > self.max_len: padded_tgt_ids = padded_tgt_ids[:self.max_len]


        parse_d = get_parse_dict_for_prefix_list(src_str_tokens)
        if parse_d is None:
            parse_d = {}

        return {
            "input_ids": torch.tensor(padded_src_ids, dtype=torch.long),
            "src_len": min(len(src_ids), self.max_len), # True length before padding (capped at max_len), including BOS/EOS
            "labels": torch.tensor(padded_tgt_ids, dtype=torch.long),
            "tgt_len": min(len(tgt_ids), self.max_len),
            "parses": parse_d,
            "src_content_len": len(src_str_tokens) # Number of actual words/tokens in source
        }

def collate_fn(batch_list):
    input_ids_list = [item['input_ids'] for item in batch_list]
    # Ensure all tensors in input_ids_list have the same length before stacking
    # This should be guaranteed by __getitem__ if max_len is consistent
    max_len_check = input_ids_list[0].size(0)
    assert all(t.size(0) == max_len_check for t in input_ids_list), "Padding error: Tensors in batch have different lengths."

    input_ids = torch.stack(input_ids_list)
    src_lengths = torch.tensor([item['src_len'] for item in batch_list], dtype=torch.long)

    labels_list = [item['labels'] for item in batch_list]
    assert all(t.size(0) == max_len_check for t in labels_list), "Padding error: Label tensors have different lengths."
    labels = torch.stack(labels_list)

    tgt_lengths = torch.tensor([item['tgt_len'] for item in batch_list], dtype=torch.long)

    parses_list = [item['parses'] for item in batch_list]
    src_content_lengths = [item['src_content_len'] for item in batch_list]

    return {
        "input_ids": input_ids,
        "src_len": src_lengths, # This is a tensor of varying lengths
        "labels": labels,
        "tgt_len": tgt_lengths, # This is a tensor of varying lengths
        "parses_batch": parses_list,
        "src_content_lengths_batch": src_content_lengths
    }

BATCH=32
# Pass necessary mappings and indices to ODEDataset
train_loader = DataLoader(ODEDataset(train_df,"equ_prefix","sol_prefix", word2idx_map=word2idx, pad_idx=PAD, bos_idx=BOS, eos_idx=EOS), batch_size=BATCH, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(ODEDataset(val_df,  "equ_prefix","sol_prefix", word2idx_map=word2idx, pad_idx=PAD, bos_idx=BOS, eos_idx=EOS), batch_size=BATCH, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(ODEDataset(test_df, "equ_prefix","sol_prefix", word2idx_map=word2idx, pad_idx=PAD, bos_idx=BOS, eos_idx=EOS), batch_size=BATCH, shuffle=False, collate_fn=collate_fn)


In [10]:
# TODO: check if we should update this!! - a smaller model than the one in original Deep Learning for Symbolic Mathematcs paper
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=4,
                 num_encoder_layers=4,
                 num_decoder_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 max_len=512,
                 output_encoder_states=True):
        super().__init__()
        self.pos_enc = nn.Parameter(torch.zeros(max_len, d_model))
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=PAD)
        self.transformer = nn.Transformer(
            d_model, nhead,
            num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout,
            batch_first=True
        )
        self.generator = nn.Linear(d_model, VOCAB_SIZE)

        self.output_encoder_states = output_encoder_states
        self.num_encoder_layers = num_encoder_layers

    def encode(self, src, output_all_hidden_states_flag=False):
        B, S = src.shape
        src_emb = self.embedding(src) + self.pos_enc[:S].unsqueeze(0)
        src_key_padding_mask = (src == PAD)

        current_input = src_emb
        all_hidden_states_list = []

        for i in range(self.num_encoder_layers):
            current_input = self.transformer.encoder.layers[i](
                current_input,
                src_key_padding_mask=src_key_padding_mask
            )
            if output_all_hidden_states_flag:
                all_hidden_states_list.append(current_input)

        memory = current_input

        if output_all_hidden_states_flag:
            return memory, all_hidden_states_list
        return memory

    def forward(self, src, tgt):
        # src: (B, S), tgt: (B, T)
        B, S_len = src.shape
        B_tgt, T_len = tgt.shape
        assert B == B_tgt

        src_emb = self.embedding(src) + self.pos_enc[:S_len].unsqueeze(0)
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T_len].unsqueeze(0)

        src_padding_mask = (src == PAD)
        tgt_padding_mask = (tgt == PAD)
        memory_padding_mask = src_padding_mask
        tgt_causal_mask = self.transformer.generate_square_subsequent_mask(T_len).to(src.device)

        all_encoder_hidden_states_output = None # Initialize

        if self.output_encoder_states:
            # Encode src and get all hidden states if needed
            encoder_memory, all_encoder_hidden_states_output = self.encode(src, output_all_hidden_states_flag=True)

            # Decode using the obtained encoder_memory
            decoder_output = self.transformer.decoder(
                tgt=tgt_emb,
                memory=encoder_memory,
                tgt_mask=tgt_causal_mask,
                tgt_key_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=memory_padding_mask
            )
        else:
            encoder_memory = self.encode(src, output_all_hidden_states_flag=False) # Just get final memory
            decoder_output = self.transformer.decoder(
                tgt=tgt_emb,
                memory=encoder_memory,
                tgt_mask=tgt_causal_mask,
                tgt_key_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=memory_padding_mask
            )

        output_logits = self.generator(decoder_output)

        if self.output_encoder_states:
            return output_logits, all_encoder_hidden_states_output
        return output_logits

    def decode(self, tgt, memory):
        B, T = tgt.shape
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T].unsqueeze(0)
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=self.transformer.generate_square_subsequent_mask(T).to(tgt.device),
            tgt_key_padding_mask=tgt == PAD
        )


In [11]:
# training and eval funcs go hereeeee
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2SeqTransformer().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.01)
loss_fn   = nn.CrossEntropyLoss(ignore_index=PAD)

tree_regularizer = TreeRegularizer(orth_bidir=True).to(device)
treereg_alpha = 1  # 1, same as in paper
global global_step_counter
global_step_counter = 0

def train_epoch():
    model.train()
    total_loss_epoch = 0
    global global_step_counter
    for batch_data in tqdm(train_loader):
        global_step_counter += 1

        src_tokens = batch_data["input_ids"].to(device)
        # slen_tensor = batch_data["src_len"].to(device) # Tensor of actual lengths for each item
        tgt_tokens = batch_data["labels"].to(device)

        parses_batch_from_loader = batch_data["parses_batch"]
        src_content_lengths_batch = batch_data["src_content_lengths_batch"]

        model_outputs = model(src_tokens, tgt_tokens[:,:-1]) # Decoder input excludes last token

        all_encoder_hidden_states = None
        if model.output_encoder_states:
            output_logits, all_encoder_hidden_states = model_outputs
        else:
            output_logits = model_outputs

        main_loss = loss_fn(output_logits.reshape(-1, VOCAB_SIZE), tgt_tokens[:,1:].reshape(-1)) # Loss against shifted target
        current_batch_total_loss = main_loss

        if model.output_encoder_states and (global_step_counter % 20 == 0) and all_encoder_hidden_states is not None:
            if len(all_encoder_hidden_states) > 1: # Need at least 2 layers for index 1
                hidden_states_for_treereg_all_batch = all_encoder_hidden_states[1] # 2nd layer

                word_boundaries_for_treereg = []
                parses_for_treereg_valid_items = []
                valid_item_indices_in_batch = []

                for i in range(src_tokens.size(0)):
                    num_actual_tokens = src_content_lengths_batch[i]
                    if num_actual_tokens > 0 and parses_batch_from_loader[i]: # Check parse dict is not empty
                        word_boundaries_for_treereg.append([True] * num_actual_tokens)
                        parses_for_treereg_valid_items.append(parses_batch_from_loader[i])
                        valid_item_indices_in_batch.append(i)

                if valid_item_indices_in_batch: # If any items are valid for TreeReg
                    # Filter hidden states for valid items only
                    # Ensure hidden_states_for_treereg_all_batch corresponds to src_tokens sequence length
                    # The hidden states from encoder will have sequence length matching src_tokens (padded length)
                    # However, SCINComputer processes based on actual token sequence length.
                    # We need to handle this carefully. The `build_chart` in SCINComputer takes
                    # hidden_states[idx].squeeze(0). If hidden_states is (B, Seq_padded, Dim),
                    # then hidden_states[idx] is (Seq_padded, Dim).
                    # The SCINComputer's internal indexing (st, en) for spans and word_boundaries
                    # refers to the *actual* tokens, not padded.
                    # So, hidden_states passed to build_chart should be (Num_valid_items, Max_actual_len_among_valid, Dim)
                    # OR build_chart needs to be aware of padding for each item.
                    # The current SCINComputer expects hidden_states[idx] to be for one sentence.
                    # Let's pass the filtered batch of hidden states directly: (Num_valid_items, Seq_padded, Dim)
                    # And `word_boundaries_for_treereg` has the actual lengths. SCINComputer should use those.

                    filtered_hs_for_treereg = hidden_states_for_treereg_all_batch[torch.tensor(valid_item_indices_in_batch, device=device)]

                    if filtered_hs_for_treereg.size(0) > 0:
                        # The hidden states passed to build_chart are (N_valid, S_padded, D)
                        # Word boundaries are List[List[bool]] with actual lengths
                        # Parses are List[dict]
                        charts = tree_regularizer.build_chart(filtered_hs_for_treereg, word_boundaries_for_treereg, None)

                        try:
                            reg_loss_terms, _ = tree_regularizer.get_score(charts, word_boundaries_for_treereg, parses_for_treereg_valid_items, device)
                            valid_reg_losses = [l for l in reg_loss_terms if isinstance(l, torch.Tensor) and l.requires_grad]
                            if valid_reg_losses:
                                tree_reg_component_loss = torch.stack(valid_reg_losses).mean()
                                current_batch_total_loss = current_batch_total_loss + (tree_reg_component_loss * treereg_alpha)
                        except Exception as e:
                            print(f"Error during TreeReg loss calculation (step {global_step_counter}): {e}")
                            # Potentially log more details: e.g., specific item causing error.
                            print(f"Problematic parses: {parses_for_treereg_valid_items}")
                            print(f"Word boundaries: {word_boundaries_for_treereg}")
                            pass

        optimizer.zero_grad()
        current_batch_total_loss.backward()
        optimizer.step()

        total_loss_epoch += current_batch_total_loss.item()

    return total_loss_epoch / len(train_loader)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss = 0
    for batch_data in loader:
        src = batch_data["input_ids"]
        tgt = batch_data["labels"]
        src, tgt = src.to(device), tgt.to(device)
        out = model(src, tgt[:,:-1])
        if model.output_encoder_states:
            out, all_encoder_hidden_states = out
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE), tgt[:,1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

def greedy_decode(src, max_len=64):
    src = src.to(device)
    memory = model.encode(src)
    ys = torch.full((src.size(0),1), BOS, device=device, dtype=torch.long)
    for i in range(max_len-1):
        out = model.decode(ys, memory)
        prob = model.generator(out[:,-1,:])
        next_word = prob.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)
        if (next_word==EOS).all(): break
    return ys.cpu().tolist()


In [12]:
# training loop
EPOCHS = 20
torch.backends.cuda.matmul.allow_tf32 = True # make faster

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss = train_epoch()
    val_loss   = evaluate(val_loader)
    print(f"Epoch {epoch} | train loss {train_loss:.4f} | val loss {val_loss:.4f} | {time.time()-t0:.1f}s")
    torch.save(model.state_dict(), SAVE_DIR+f"/epoch{epoch}.pt")


  0%|          | 0/847 [00:00<?, ?it/s]



Epoch 1 | train loss 1.1433 | val loss 1.5623 | 36.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 2 | train loss 0.5632 | val loss 1.5237 | 36.5s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 3 | train loss 0.4377 | val loss 1.5081 | 36.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 4 | train loss 0.3026 | val loss 1.5105 | 34.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 5 | train loss 0.2046 | val loss 1.5162 | 34.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 6 | train loss 0.0950 | val loss 1.5269 | 34.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 7 | train loss 0.1212 | val loss 1.5354 | 33.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 8 | train loss 0.0790 | val loss 1.5360 | 36.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 9 | train loss 0.0418 | val loss 1.5541 | 35.8s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 10 | train loss 0.0651 | val loss 1.5377 | 35.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 11 | train loss -0.0234 | val loss 1.5296 | 36.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 12 | train loss -0.0651 | val loss 1.5135 | 36.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 13 | train loss -0.0795 | val loss 1.4954 | 36.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 14 | train loss -0.1252 | val loss 1.4763 | 35.8s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 15 | train loss -0.2509 | val loss 1.4665 | 36.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 16 | train loss -0.2280 | val loss 1.4526 | 35.8s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 17 | train loss -0.2443 | val loss 1.4502 | 36.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 18 | train loss -0.3186 | val loss 1.4348 | 35.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 19 | train loss -0.3568 | val loss 1.4241 | 34.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 20 | train loss -0.4044 | val loss 1.4165 | 36.5s


In [13]:
#  testing here

model.eval()
n_correct = 0
total = 0
x = sp.Symbol('x')
rng = np.random.default_rng()

for batch_data in tqdm(test_loader):
    src = batch_data["input_ids"]
    tgt = batch_data["labels"]
    src, tgt = src.to(device), tgt.to(device)
    hyps = greedy_decode(src)       # list of B lists of token IDs - all hypostheses
    truths = tgt.tolist()           # list of B lists - truths

    for hyp_ids, true_ids in zip(hyps, truths):
        # find first EOS and remove everything after it
        try:
          first_eos = hyp_ids.index(EOS)
        except:
          first_eos = len(hyp_ids) # if no EOS, don't strip

        # strip special tokens
        hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
        true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

        # parse numbers
        hyp_tok = [np.int32(int(tok)) if tok.lstrip("-").isdigit() else tok for tok in hyp_tok]
        true_tok = [np.int32(int(tok)) if tok.lstrip("-").isdigit() else tok for tok in true_tok]

        # convert to Sympy and check
        try:
          hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
        except:
          pass # if model output isn't real equation, skip
        true_expr = prefix_to_sympy(true_tok, OPERATORS)
        if verify_solution(hyp_expr, true_expr, rng):
            n_correct += 1

        total += 1
    if total > 200:
        print("model", hyp_tok)
        print("true", true_tok)
        break

acc = 100 * n_correct / len(test_loader)
print(f"Greedy semantic accuracy: {acc:.2f}%")


  0%|          | 0/106 [00:00<?, ?it/s]

model ['add', 'c', 'mul', 'mul', 'div', np.int32(1), np.int32(2), 'x', 'add', 'div', np.int32(1), np.int32(2), 'mul', 'div', np.int32(1), np.int32(2), 'x']
true ['mul', 'div', np.int32(-5), 'x', 'add', np.int32(-2), 'mul', 'c', 'x']
Greedy semantic accuracy: 0.00%


In [32]:
# write beam search here
import torch.nn.functional as F
from collections import namedtuple

BeamHyp = namedtuple("BeamHyp", ["score", "tokens"])

def beam_search(src_batch, beam_size=5, length_penalty=1.0, max_len=128):
    """
    src_batch: LongTensor (B, S) - batch first as in the main model too
    returns: list of B best token ID lists
    """
    model.eval()
    B, S = src_batch.shape
    src_batch = src_batch.to(device)
    memory = model.encode(src_batch)

    # initialize beams per example
    beams = [[BeamHyp(0.0, [BOS])] for _ in range(B)]

    for _ in range(max_len):
        all_beams = [[] for _ in range(B)]
        for b in range(B):
            for hyp in beams[b]:
                tokens = hyp.tokens
                # prepare decoder input: (1, t)
                tgt_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
                dec = model.decode(tgt_input, memory[b:b+1])     # (1, t, D)
                # project last step to vocab & log‐softmax
                logits = model.generator(dec[:, -1, :])          # (1, V)
                logp   = F.log_softmax(logits, dim=-1).squeeze(0) # (V,)

                topv, topi = logp.topk(beam_size)
                for score, idx in zip(topv.tolist(), topi.tolist()):
                    all_beams[b].append(BeamHyp(hyp.score + score, tokens + [idx]))

            # prune back to beam_size
            all_beams[b].sort(
                key=lambda h: h.score / (len(h.tokens) ** length_penalty),
                reverse=True
            )
            beams[b] = all_beams[b][:beam_size]

    # extract best sequence per example
    results = []
    for b in range(B):
        best = max(
            beams[b],
            key=lambda h: h.score / (len(h.tokens) ** length_penalty)
        )
        results.append(best.tokens)
    return results

# now evaluate with beam=10
model.eval()
n_correct = 0
total = 0
for batch_data in tqdm(test_loader):
    src, tgt = batch_data["input_ids"], batch_data["labels"]
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=10, length_penalty=1.0, max_len=64)
    truths = tgt.tolist()

    for hyp_ids, true_ids in zip(hyps, truths):
        try:
            hyp_tok = [idx2word[i] for i in hyp_ids  if i not in (PAD,BOS,EOS)]
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
            x = sp.Symbol('x')
            if verify_solution(sp.diff(hyp_expr, x), hyp_expr, x):
                n_correct += 1
        except:
            pass
        total += 1

    if total > 30:
        print(hyp_tok)
        break


acc = 100 * n_correct / total
print(f"Beam-10 semantic accuracy: {acc:.2f}%")


  0%|          | 0/106 [00:00<?, ?it/s]

['37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37', '37']
Beam-10 semantic accuracy: 0.00%


In [36]:
# write beam search here
import torch.nn.functional as F
from collections import namedtuple

BeamHyp = namedtuple("BeamHyp", ["score", "tokens"])

def beam_search(src_batch, beam_size=5, length_penalty=1.0, max_len=128):
    """
    src_batch: LongTensor (B, S) - batch first as in the main model too
    returns: list of B best token ID lists
    """
    model.eval()
    B, S = src_batch.shape
    src_batch = src_batch.to(device)
    memory = model.encode(src_batch)

    # initialize beams per example
    beams = [[BeamHyp(0.0, [BOS])] for _ in range(B)]

    for _ in range(max_len):
        all_beams = [[] for _ in range(B)]
        for b in range(B):
            for hyp in beams[b]:
                tokens = hyp.tokens
                # prepare decoder input: (1, t)
                tgt_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
                dec = model.decode(tgt_input, memory[b:b+1])     # (1, t, D)
                # project last step to vocab & log‐softmax
                logits = model.generator(dec[:, -1, :])          # (1, V)
                logp   = F.log_softmax(logits, dim=-1).squeeze(0) # (V,)

                topv, topi = logp.topk(beam_size)
                for score, idx in zip(topv.tolist(), topi.tolist()):
                    all_beams[b].append(BeamHyp(hyp.score + score, tokens + [idx]))

            # prune back to beam_size
            all_beams[b].sort(
                key=lambda h: h.score / (len(h.tokens) ** length_penalty),
                reverse=True
            )
            beams[b] = all_beams[b][:beam_size]

    # extract best sequence per example
    results = []
    for b in range(B):
        best = max(
            beams[b],
            key=lambda h: h.score / (len(h.tokens) ** length_penalty)
        )
        results.append(best.tokens)
    return results

# now evaluate with beam=3
model.eval()
n_correct = 0
total = 0
for batch_data in tqdm(test_loader):
    src, tgt = batch_data["input_ids"], batch_data["labels"]
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=3, length_penalty=1.0, max_len=64)
    truths = tgt.tolist()

    for hyp_ids, true_ids in zip(hyps, truths):
        try:
            hyp_tok = [idx2word[i] for i in hyp_ids  if i not in (PAD,BOS,EOS)]
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
            x = sp.Symbol('x')
            if verify_solution(sp.diff(hyp_expr, x), hyp_expr, x):
                n_correct += 1
        except:
            pass
        total += 1

    if total > 50:
        print(hyp_tok)
        break


acc = 100 * n_correct / total
print(f"Beam-3 semantic accuracy: {acc:.2f}%")


  0%|          | 0/106 [00:00<?, ?it/s]

['add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x', 'add', 'c', 'x']
Beam-3 semantic accuracy: 100.00%
