<a href="https://colab.research.google.com/github/srirambandi/compsci685/blob/main/train_and_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab Cell – clone your repo at the top of the notebook
!git clone https://github.com/srirambandi/compsci685.git
%cd compsci685

Cloning into 'compsci685'...
remote: Enumerating objects: 153, done.[K
remote: Counting objects: 100% (153/153), done.[K
remote: Compressing objects: 100% (109/109), done.[K
remote: Total 153 (delta 77), reused 103 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (153/153), 25.90 MiB | 3.93 MiB/s, done.
Resolving deltas: 100% (77/77), done.
/content/compsci685


In [None]:
!pip install torch transformers datasets sympy tqdm scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import sys
sys.path.insert(0, "gen_dataset")
sys.path.insert(0, "gen_dataset/src")
sys.path.insert(0, "training/") # treereg

In [None]:
import os
import math
import time

import pandas as pd
import numpy as np
import sympy as sp
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from utils import prefix_to_sympy, verify_solution, OPERATORS

from regularizer.regularizer_main import TreeRegularizer
from parse_tree_adapted import get_parse_dict_for_prefix_list

In [None]:
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/compsci685/checkpoints_treereg_paper_ode1_dataset"
os.makedirs(SAVE_DIR, exist_ok=True)

Mounted at /content/drive


In [None]:
DATA_PATH = "gen_dataset/data/train16_clean.txt"
# data is stored as input_equation\toutput_equation

TEST_SIZE = 1_024 # number of samples to evaluate model equation correctness on
VALID_SIZE = 2_048 # number of samples to get validation loss from
# hardcode these numbers, as 10% of new >600K sample dataset is too much
# evaluation is slow, so only take 1024
# validation is a bit faster, so take 2048
# use those numbers as they are powers of 2 closest to 1k and 2k
# so batch size will evenly divide them during validation

os.makedirs("splits", exist_ok=True)
from tqdm import tqdm
print("Splitting Data")
with open(DATA_PATH, 'r') as reader:
  i = 0
  with open("splits/test.txt", 'w') as test_writer, open("splits/valid.txt", 'w') as valid_writer, open("splits/train.txt", 'w') as train_writer:
    for line in tqdm(reader):
      if i < TEST_SIZE:
        test_writer.write(line + "\n")
      elif i < (TEST_SIZE + VALID_SIZE):
        valid_writer.write(line + "\n")
      else:
        train_writer.write(line + "\n")

      i += 1




Splitting Data


683442it [00:00, 1510063.18it/s]


In [None]:
!wc -l splits/train.txt

1360740 splits/train.txt


In [None]:
# special tokens
SPECIAL = {"<pad>":0,"<bos>":1,"<eos>":2}

# collect tokens from original file:
tokens = set()
with open(DATA_PATH, 'r') as reader:
  for line in reader:
    tokens.update(line.split())
word2idx = {w:i+len(SPECIAL) for i,w in enumerate(sorted(tokens))}
word2idx.update(SPECIAL)
print(word2idx)
idx2word = {i:w for w,i in word2idx.items()}

PAD, BOS, EOS = word2idx["<pad>"], word2idx["<bos>"], word2idx["<eos>"]
VOCAB_SIZE = len(word2idx)

class ODEDataset(Dataset):
    def __init__(self, data_file, max_len=18, word2idx_map=None, pad_idx=None, bos_idx=None, eos_idx=None):
        self.src = []
        self.tgt = []
        with open(data_file, 'r') as reader:
          for line in reader:
            if "\t" in line: # last line doesn't have any data
              src_item, tgt_item = line.split("\t")
              self.src.append(src_item.split())
              self.tgt.append(tgt_item.split())
        self.max_len = max_len
        self.word2idx = word2idx_map
        self.PAD_IDX = pad_idx
        self.BOS_IDX = bos_idx
        self.EOS_IDX = eos_idx

    def __len__(self):
        return len(self.src)

    def __getitem__(self, i):
        src_str_tokens = self.src[i]
        tgt_str_tokens = self.tgt[i]

        src_ids = [self.BOS_IDX] + [self.word2idx.get(t, self.PAD_IDX) for t in src_str_tokens] + [self.EOS_IDX]
        tgt_ids = [self.BOS_IDX] + [self.word2idx.get(t, self.PAD_IDX) for t in tgt_str_tokens] + [self.EOS_IDX]

        # Pad sequences
        padded_src_ids = src_ids[:self.max_len] + [self.PAD_IDX] * (self.max_len - len(src_ids))
        if len(padded_src_ids) > self.max_len: padded_src_ids = padded_src_ids[:self.max_len] # Ensure fixed length

        padded_tgt_ids = tgt_ids[:self.max_len] + [self.PAD_IDX] * (self.max_len - len(tgt_ids))
        if len(padded_tgt_ids) > self.max_len: padded_tgt_ids = padded_tgt_ids[:self.max_len]


        parse_d = get_parse_dict_for_prefix_list(src_str_tokens)
        if parse_d is None:
            parse_d = {}

        return {
            "input_ids": torch.tensor(padded_src_ids, dtype=torch.long),
            "src_len": min(len(src_ids), self.max_len), # True length before padding (capped at max_len), including BOS/EOS
            "labels": torch.tensor(padded_tgt_ids, dtype=torch.long),
            "tgt_len": min(len(tgt_ids), self.max_len),
            "parses": parse_d,
            "src_content_len": len(src_str_tokens) # Number of actual words/tokens in source
        }

def collate_fn(batch_list):
    input_ids_list = [item['input_ids'] for item in batch_list]
    # Ensure all tensors in input_ids_list have the same length before stacking
    # This should be guaranteed by __getitem__ if max_len is consistent
    max_len_check = input_ids_list[0].size(0)
    assert all(t.size(0) == max_len_check for t in input_ids_list), "Padding error: Tensors in batch have different lengths."

    input_ids = torch.stack(input_ids_list)
    src_lengths = torch.tensor([item['src_len'] for item in batch_list], dtype=torch.long)

    labels_list = [item['labels'] for item in batch_list]
    assert all(t.size(0) == max_len_check for t in labels_list), "Padding error: Label tensors have different lengths."
    labels = torch.stack(labels_list)

    tgt_lengths = torch.tensor([item['tgt_len'] for item in batch_list], dtype=torch.long)

    parses_list = [item['parses'] for item in batch_list]
    src_content_lengths = [item['src_content_len'] for item in batch_list]

    return {
        "input_ids": input_ids,
        "src_len": src_lengths, # This is a tensor of varying lengths
        "labels": labels,
        "tgt_len": tgt_lengths, # This is a tensor of varying lengths
        "parses_batch": parses_list,
        "src_content_lengths_batch": src_content_lengths
    }

BATCH=256
# Pass necessary mappings and indices to ODEDataset
train_loader = DataLoader(ODEDataset("splits/train.txt", word2idx_map=word2idx, pad_idx=PAD, bos_idx=BOS, eos_idx=EOS), batch_size=BATCH, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(ODEDataset("splits/valid.txt", word2idx_map=word2idx, pad_idx=PAD, bos_idx=BOS, eos_idx=EOS), batch_size=BATCH, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(ODEDataset("splits/test.txt", word2idx_map=word2idx, pad_idx=PAD, bos_idx=BOS, eos_idx=EOS), batch_size=32, shuffle=False, collate_fn=collate_fn)


{'0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12, 'E': 13, 'INT+': 14, 'INT-': 15, 'abs': 16, 'acos': 17, 'acosh': 18, 'add': 19, 'asin': 20, 'asinh': 21, 'atan': 22, 'atanh': 23, 'c': 24, 'cos': 25, 'cosh': 26, 'div': 27, 'exp': 28, 'log': 29, 'mul': 30, 'pi': 31, 'pow': 32, 'sign': 33, 'sin': 34, 'sinh': 35, 'sqrt': 36, 'tan': 37, 'tanh': 38, 'x': 39, 'y': 40, "y'": 41, '<pad>': 0, '<bos>': 1, '<eos>': 2}


In [None]:
# TODO: check if we should update this!! - a smaller model than the one in original Deep Learning for Symbolic Mathematcs paper
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=4,
                 num_encoder_layers=4,
                 num_decoder_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 max_len=64,
                 output_encoder_states=True):
        super().__init__()
        self.pos_enc = nn.Parameter(torch.zeros(max_len, d_model))
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=PAD)
        self.transformer = nn.Transformer(
            d_model, nhead,
            num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout,
            batch_first=True
        )
        self.generator = nn.Linear(d_model, VOCAB_SIZE)

        self.output_encoder_states = output_encoder_states
        self.num_encoder_layers = num_encoder_layers

    def encode(self, src, output_all_hidden_states_flag=False):
        B, S = src.shape
        src_emb = self.embedding(src) + self.pos_enc[:S].unsqueeze(0)
        src_key_padding_mask = (src == PAD)

        current_input = src_emb
        all_hidden_states_list = []

        for i in range(self.num_encoder_layers):
            current_input = self.transformer.encoder.layers[i](
                current_input,
                src_key_padding_mask=src_key_padding_mask
            )
            if output_all_hidden_states_flag:
                all_hidden_states_list.append(current_input)

        memory = current_input

        if output_all_hidden_states_flag:
            return memory, all_hidden_states_list
        return memory

    def forward(self, src, tgt):
        # src: (B, S), tgt: (B, T)
        B, S_len = src.shape
        B_tgt, T_len = tgt.shape
        assert B == B_tgt

        src_emb = self.embedding(src) + self.pos_enc[:S_len].unsqueeze(0)
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T_len].unsqueeze(0)

        src_padding_mask = (src == PAD)
        tgt_padding_mask = (tgt == PAD)
        memory_padding_mask = src_padding_mask
        tgt_causal_mask = self.transformer.generate_square_subsequent_mask(T_len).to(src.device)

        all_encoder_hidden_states_output = None # Initialize

        if self.output_encoder_states:
            # Encode src and get all hidden states if needed
            encoder_memory, all_encoder_hidden_states_output = self.encode(src, output_all_hidden_states_flag=True)

            # Decode using the obtained encoder_memory
            decoder_output = self.transformer.decoder(
                tgt=tgt_emb,
                memory=encoder_memory,
                tgt_mask=tgt_causal_mask,
                tgt_key_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=memory_padding_mask
            )
        else:
            encoder_memory = self.encode(src, output_all_hidden_states_flag=False) # Just get final memory
            decoder_output = self.transformer.decoder(
                tgt=tgt_emb,
                memory=encoder_memory,
                tgt_mask=tgt_causal_mask,
                tgt_key_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=memory_padding_mask
            )

        output_logits = self.generator(decoder_output)

        if self.output_encoder_states:
            return output_logits, all_encoder_hidden_states_output
        return output_logits

    def decode(self, tgt, memory):
        B, T = tgt.shape
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T].unsqueeze(0)
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=self.transformer.generate_square_subsequent_mask(T).to(tgt.device),
            tgt_key_padding_mask=tgt == PAD
        )


In [None]:
# training and eval funcs go hereeeee
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2SeqTransformer().to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
loss_fn   = nn.CrossEntropyLoss(ignore_index=PAD)

tree_regularizer = TreeRegularizer(orth_bidir=True).to(device)
treereg_alpha = 1  # 1, same as in paper
global global_step_counter
global_step_counter = 0

def train_epoch():
    model.train()
    total_loss_epoch = 0
    total_treereg_loss_epoch = 0
    global global_step_counter
    for batch_data in tqdm(train_loader):
        global_step_counter += 1

        src_tokens = batch_data["input_ids"].to(device)
        # slen_tensor = batch_data["src_len"].to(device) # Tensor of actual lengths for each item
        tgt_tokens = batch_data["labels"].to(device)

        parses_batch_from_loader = batch_data["parses_batch"]
        src_content_lengths_batch = batch_data["src_content_lengths_batch"]

        model_outputs = model(src_tokens, tgt_tokens[:,:-1]) # Decoder input excludes last token

        all_encoder_hidden_states = None
        if model.output_encoder_states:
            output_logits, all_encoder_hidden_states = model_outputs
        else:
            output_logits = model_outputs

        main_loss = loss_fn(output_logits.reshape(-1, VOCAB_SIZE), tgt_tokens[:,1:].reshape(-1)) # Loss against shifted target
        current_batch_total_loss = main_loss

        if model.output_encoder_states and (global_step_counter % 20 == 0) and all_encoder_hidden_states is not None:
            if len(all_encoder_hidden_states) > 1: # Need at least 2 layers for index 1
                hidden_states_for_treereg_all_batch = all_encoder_hidden_states[1] # 2nd layer

                word_boundaries_for_treereg = []
                parses_for_treereg_valid_items = []
                valid_item_indices_in_batch = []

                for i in range(src_tokens.size(0)):
                    num_actual_tokens = src_content_lengths_batch[i]
                    if num_actual_tokens > 0 and parses_batch_from_loader[i]: # Check parse dict is not empty
                        word_boundaries_for_treereg.append([True] * num_actual_tokens)
                        parses_for_treereg_valid_items.append(parses_batch_from_loader[i])
                        valid_item_indices_in_batch.append(i)

                if valid_item_indices_in_batch: # If any items are valid for TreeReg
                    # Filter hidden states for valid items only
                    # Ensure hidden_states_for_treereg_all_batch corresponds to src_tokens sequence length
                    # The hidden states from encoder will have sequence length matching src_tokens (padded length)
                    # However, SCINComputer processes based on actual token sequence length.
                    # We need to handle this carefully. The `build_chart` in SCINComputer takes
                    # hidden_states[idx].squeeze(0). If hidden_states is (B, Seq_padded, Dim),
                    # then hidden_states[idx] is (Seq_padded, Dim).
                    # The SCINComputer's internal indexing (st, en) for spans and word_boundaries
                    # refers to the *actual* tokens, not padded.
                    # So, hidden_states passed to build_chart should be (Num_valid_items, Max_actual_len_among_valid, Dim)
                    # OR build_chart needs to be aware of padding for each item.
                    # The current SCINComputer expects hidden_states[idx] to be for one sentence.
                    # Let's pass the filtered batch of hidden states directly: (Num_valid_items, Seq_padded, Dim)
                    # And `word_boundaries_for_treereg` has the actual lengths. SCINComputer should use those.

                    filtered_hs_for_treereg = hidden_states_for_treereg_all_batch[torch.tensor(valid_item_indices_in_batch, device=device)]

                    if filtered_hs_for_treereg.size(0) > 0:
                        # The hidden states passed to build_chart are (N_valid, S_padded, D)
                        # Word boundaries are List[List[bool]] with actual lengths
                        # Parses are List[dict]
                        charts = tree_regularizer.build_chart(filtered_hs_for_treereg, word_boundaries_for_treereg, None)

                        try:
                            reg_loss_terms, _ = tree_regularizer.get_score(charts, word_boundaries_for_treereg, parses_for_treereg_valid_items, device)
                            valid_reg_losses = [l for l in reg_loss_terms if isinstance(l, torch.Tensor) and l.requires_grad]
                            if valid_reg_losses:
                                tree_reg_component_loss = torch.stack(valid_reg_losses).mean()
                                current_batch_total_loss = current_batch_total_loss + (tree_reg_component_loss * treereg_alpha)
                        except Exception as e:
                            print(f"Error during TreeReg loss calculation (step {global_step_counter}): {e}")
                            # Potentially log more details: e.g., specific item causing error.
                            print(f"Problematic parses: {parses_for_treereg_valid_items}")
                            print(f"Word boundaries: {word_boundaries_for_treereg}")
                            pass

        optimizer.zero_grad()
        current_batch_total_loss.backward()
        optimizer.step()

        total_loss_epoch += main_loss.item()
        total_treereg_loss_epoch += current_batch_total_loss.item()

    return total_loss_epoch / len(train_loader), total_treereg_loss_epoch / len(train_loader)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss = 0
    for batch_data in loader:
        src = batch_data["input_ids"]
        tgt = batch_data["labels"]
        src, tgt = src.to(device), tgt.to(device)
        out = model(src, tgt[:,:-1])
        if model.output_encoder_states:
            out, all_encoder_hidden_states = out
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE), tgt[:,1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

def greedy_decode(src, max_len=18):
    src = src.to(device)
    memory = model.encode(src)
    ys = torch.full((src.size(0),1), BOS, device=device, dtype=torch.long)
    for i in range(max_len-1):
        out = model.decode(ys, memory)
        prob = model.generator(out[:,-1,:])
        next_word = prob.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)
        if (next_word==EOS).all(): break
    return ys.cpu().tolist()


In [19]:
# training loop
EPOCHS = 50
torch.backends.cuda.matmul.allow_tf32 = True # make faster

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss, with_treereg_loss = train_epoch()
    val_loss   = evaluate(val_loader)
    print(f"Epoch {epoch} | train loss (CE) {train_loss:.4f} | val loss {val_loss:.4f} | CE loss + treereg loss {with_treereg_loss:.4f} | {time.time()-t0:.1f}s")
    torch.save(model.state_dict(), SAVE_DIR+f"/epoch{epoch}.pt")


100%|██████████| 2658/2658 [01:52<00:00, 23.58it/s]


Epoch 1 | train loss (CE) 1.3731 | val loss 0.8866 | CE loss + treereg loss 0.6564 | 112.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.68it/s]


Epoch 2 | train loss (CE) 1.0222 | val loss 0.7720 | CE loss + treereg loss 0.0165 | 112.4s


100%|██████████| 2658/2658 [01:52<00:00, 23.73it/s]


Epoch 3 | train loss (CE) 0.9134 | val loss 0.7185 | CE loss + treereg loss -0.2019 | 112.1s


100%|██████████| 2658/2658 [01:51<00:00, 23.75it/s]


Epoch 4 | train loss (CE) 0.8494 | val loss 0.6806 | CE loss + treereg loss -0.3415 | 112.0s


100%|██████████| 2658/2658 [01:51<00:00, 23.74it/s]


Epoch 5 | train loss (CE) 0.8058 | val loss 0.6488 | CE loss + treereg loss -0.4070 | 112.1s


100%|██████████| 2658/2658 [01:52<00:00, 23.59it/s]


Epoch 6 | train loss (CE) 0.7731 | val loss 0.6376 | CE loss + treereg loss -0.4775 | 112.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.57it/s]


Epoch 7 | train loss (CE) 0.7466 | val loss 0.6218 | CE loss + treereg loss -0.5467 | 112.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.61it/s]


Epoch 8 | train loss (CE) 0.7294 | val loss 0.6022 | CE loss + treereg loss -0.6059 | 112.7s


100%|██████████| 2658/2658 [01:51<00:00, 23.74it/s]


Epoch 9 | train loss (CE) 0.7149 | val loss 0.6136 | CE loss + treereg loss -0.6981 | 112.1s


100%|██████████| 2658/2658 [01:53<00:00, 23.37it/s]


Epoch 10 | train loss (CE) 0.7017 | val loss 0.5905 | CE loss + treereg loss -0.7541 | 113.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.72it/s]


Epoch 11 | train loss (CE) 0.6865 | val loss 0.5516 | CE loss + treereg loss -0.7889 | 112.2s


100%|██████████| 2658/2658 [01:51<00:00, 23.74it/s]


Epoch 12 | train loss (CE) 0.6742 | val loss 0.5488 | CE loss + treereg loss -0.8891 | 112.1s


100%|██████████| 2658/2658 [01:51<00:00, 23.74it/s]


Epoch 13 | train loss (CE) 0.6635 | val loss 0.5643 | CE loss + treereg loss -0.9129 | 112.1s


100%|██████████| 2658/2658 [01:52<00:00, 23.69it/s]


Epoch 14 | train loss (CE) 0.6513 | val loss 0.5370 | CE loss + treereg loss -0.9652 | 112.4s


100%|██████████| 2658/2658 [01:52<00:00, 23.61it/s]


Epoch 15 | train loss (CE) 0.6408 | val loss 0.5139 | CE loss + treereg loss -1.0318 | 112.7s


100%|██████████| 2658/2658 [01:52<00:00, 23.62it/s]


Epoch 16 | train loss (CE) 0.6311 | val loss 0.5313 | CE loss + treereg loss -1.0342 | 112.7s


100%|██████████| 2658/2658 [01:52<00:00, 23.66it/s]


Epoch 17 | train loss (CE) 0.6231 | val loss 0.5273 | CE loss + treereg loss -1.1085 | 112.5s


100%|██████████| 2658/2658 [01:52<00:00, 23.62it/s]


Epoch 18 | train loss (CE) 0.6149 | val loss 0.5173 | CE loss + treereg loss -1.1388 | 112.7s


100%|██████████| 2658/2658 [01:52<00:00, 23.55it/s]


Epoch 19 | train loss (CE) 0.6082 | val loss 0.5301 | CE loss + treereg loss -1.2377 | 113.0s


100%|██████████| 2658/2658 [01:53<00:00, 23.39it/s]


Epoch 20 | train loss (CE) 0.6020 | val loss 0.5146 | CE loss + treereg loss -1.2511 | 113.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.64it/s]


Epoch 21 | train loss (CE) 0.5960 | val loss 0.5246 | CE loss + treereg loss -1.2780 | 112.6s


100%|██████████| 2658/2658 [01:51<00:00, 23.77it/s]


Epoch 22 | train loss (CE) 0.5908 | val loss 0.5013 | CE loss + treereg loss -1.3833 | 111.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.68it/s]


Epoch 23 | train loss (CE) 0.5853 | val loss 0.5144 | CE loss + treereg loss -1.4488 | 112.4s


100%|██████████| 2658/2658 [01:51<00:00, 23.79it/s]


Epoch 24 | train loss (CE) 0.5806 | val loss 0.4827 | CE loss + treereg loss -1.4293 | 111.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.60it/s]


Epoch 25 | train loss (CE) 0.5758 | val loss 0.4971 | CE loss + treereg loss -1.4893 | 112.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.68it/s]


Epoch 26 | train loss (CE) 0.5701 | val loss 0.4877 | CE loss + treereg loss -1.5447 | 112.4s


100%|██████████| 2658/2658 [01:53<00:00, 23.50it/s]


Epoch 27 | train loss (CE) 0.5652 | val loss 0.4933 | CE loss + treereg loss -1.6280 | 113.2s


100%|██████████| 2658/2658 [01:52<00:00, 23.56it/s]


Epoch 28 | train loss (CE) 0.5610 | val loss 0.4948 | CE loss + treereg loss -1.6669 | 112.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.65it/s]


Epoch 29 | train loss (CE) 0.5566 | val loss 0.4793 | CE loss + treereg loss -1.7183 | 112.5s


100%|██████████| 2658/2658 [01:53<00:00, 23.45it/s]


Epoch 30 | train loss (CE) 0.5523 | val loss 0.4673 | CE loss + treereg loss -1.7081 | 113.5s


100%|██████████| 2658/2658 [01:52<00:00, 23.71it/s]


Epoch 31 | train loss (CE) 0.5490 | val loss 0.4609 | CE loss + treereg loss -1.7800 | 112.2s


100%|██████████| 2658/2658 [01:52<00:00, 23.57it/s]


Epoch 32 | train loss (CE) 0.5451 | val loss 0.4685 | CE loss + treereg loss -1.8336 | 112.9s


100%|██████████| 2658/2658 [01:52<00:00, 23.54it/s]


Epoch 33 | train loss (CE) 0.5405 | val loss 0.4654 | CE loss + treereg loss -1.9021 | 113.1s


100%|██████████| 2658/2658 [01:52<00:00, 23.60it/s]


Epoch 34 | train loss (CE) 0.5378 | val loss 0.4764 | CE loss + treereg loss -1.8714 | 112.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.61it/s]


Epoch 35 | train loss (CE) 0.5339 | val loss 0.4577 | CE loss + treereg loss -1.9503 | 112.7s


100%|██████████| 2658/2658 [01:52<00:00, 23.62it/s]


Epoch 36 | train loss (CE) 0.5307 | val loss 0.4569 | CE loss + treereg loss -1.9477 | 112.6s


100%|██████████| 2658/2658 [01:52<00:00, 23.72it/s]


Epoch 37 | train loss (CE) 0.5263 | val loss 0.4554 | CE loss + treereg loss -2.0877 | 112.2s


100%|██████████| 2658/2658 [01:52<00:00, 23.70it/s]


Epoch 38 | train loss (CE) 0.5229 | val loss 0.4509 | CE loss + treereg loss -2.1021 | 112.3s


100%|██████████| 2658/2658 [01:52<00:00, 23.59it/s]


Epoch 39 | train loss (CE) 0.5198 | val loss 0.4759 | CE loss + treereg loss -2.1251 | 112.8s


100%|██████████| 2658/2658 [01:54<00:00, 23.30it/s]


Epoch 40 | train loss (CE) 0.5169 | val loss 0.4602 | CE loss + treereg loss -2.1833 | 114.2s


100%|██████████| 2658/2658 [01:52<00:00, 23.60it/s]


Epoch 41 | train loss (CE) 0.5143 | val loss 0.4558 | CE loss + treereg loss -2.1842 | 112.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.64it/s]


Epoch 42 | train loss (CE) 0.5111 | val loss 0.4428 | CE loss + treereg loss -2.2412 | 112.6s


100%|██████████| 2658/2658 [01:52<00:00, 23.63it/s]


Epoch 43 | train loss (CE) 0.5084 | val loss 0.4652 | CE loss + treereg loss -2.3072 | 112.6s


100%|██████████| 2658/2658 [01:52<00:00, 23.59it/s]


Epoch 44 | train loss (CE) 0.5051 | val loss 0.4638 | CE loss + treereg loss -2.3079 | 112.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.69it/s]


Epoch 45 | train loss (CE) 0.5019 | val loss 0.4648 | CE loss + treereg loss -2.3717 | 112.4s


100%|██████████| 2658/2658 [01:52<00:00, 23.65it/s]


Epoch 46 | train loss (CE) 0.4993 | val loss 0.4645 | CE loss + treereg loss -2.3922 | 112.5s


100%|██████████| 2658/2658 [01:52<00:00, 23.62it/s]


Epoch 47 | train loss (CE) 0.4961 | val loss 0.4412 | CE loss + treereg loss -2.4962 | 112.7s


100%|██████████| 2658/2658 [01:52<00:00, 23.69it/s]


Epoch 48 | train loss (CE) 0.4934 | val loss 0.4344 | CE loss + treereg loss -2.4438 | 112.3s


100%|██████████| 2658/2658 [01:53<00:00, 23.39it/s]


Epoch 49 | train loss (CE) 0.4914 | val loss 0.4241 | CE loss + treereg loss -2.5630 | 113.8s


100%|██████████| 2658/2658 [01:52<00:00, 23.69it/s]


Epoch 50 | train loss (CE) 0.4892 | val loss 0.4440 | CE loss + treereg loss -2.5803 | 112.3s


In [20]:
#  testing here

model.eval()
n_correct = 0
total = 0
x = sp.Symbol('x')
rng = np.random.default_rng()

for batch_data in tqdm(test_loader):
    src = batch_data["input_ids"]
    tgt = batch_data["labels"]
    src, tgt = src.to(device), tgt.to(device)
    hyps = greedy_decode(src)       # list of B lists of token IDs - all hypostheses
    truths = tgt.tolist()           # list of B lists - truths

    for hyp_ids, true_ids in zip(hyps, truths):
        # find first EOS and remove everything after it
        try:
          first_eos = hyp_ids.index(EOS)
        except:
          first_eos = len(hyp_ids) # if no EOS, don't strip

        # strip special tokens
        hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
        true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

        # convert to Sympy and check
        try:
          hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
        except:
          total += 1 # if model output isn't real equation, skip
          continue
        true_expr = prefix_to_sympy(true_tok, OPERATORS)
        if verify_solution(hyp_expr, true_expr, rng):
            n_correct += 1

        total += 1

acc = 100 * n_correct / total
print(f"Greedy semantic accuracy: {acc:.2f}%")


100%|██████████| 32/32 [10:59<00:00, 20.60s/it]

Greedy semantic accuracy: 0.20%





In [21]:
# write beam search here
import torch.nn.functional as F
from collections import namedtuple

BeamHyp = namedtuple("BeamHyp", ["score", "tokens"])

def beam_search(src_batch, beam_size=5, length_penalty=1.0, max_len=128):
    """
    src_batch: LongTensor (B, S) - batch first as in the main model too
    returns: list of B best token ID lists
    """
    model.eval()
    B, S = src_batch.shape
    src_batch = src_batch.to(device)
    memory = model.encode(src_batch)

    # initialize beams per example
    beams = [[BeamHyp(0.0, [BOS])] for _ in range(B)]

    for _ in range(max_len):
        all_beams = [[] for _ in range(B)]
        for b in range(B):
            for hyp in beams[b]:
                tokens = hyp.tokens
                # prepare decoder input: (1, t)
                tgt_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
                dec = model.decode(tgt_input, memory[b:b+1])     # (1, t, D)
                # project last step to vocab & log‐softmax
                logits = model.generator(dec[:, -1, :])          # (1, V)
                logp   = F.log_softmax(logits, dim=-1).squeeze(0) # (V,)

                topv, topi = logp.topk(beam_size)
                for score, idx in zip(topv.tolist(), topi.tolist()):
                    all_beams[b].append(BeamHyp(hyp.score + score, tokens + [idx]))

            # prune back to beam_size
            all_beams[b].sort(
                key=lambda h: h.score / (len(h.tokens) ** length_penalty),
                reverse=True
            )
            beams[b] = all_beams[b][:beam_size]

    # extract all beams, and count as correct if at least one beam from hypothesis is correct
    results = []
    for b in range(B):
        b_beams = beams[b]
        out_beams = [a.tokens for a in b_beams]
        results.append(out_beams)
    return results


In [22]:

# now evaluate with beam=10
model.eval()
n_correct = 0
total = 0
for batch_data in tqdm(test_loader):
    src, tgt = batch_data["input_ids"], batch_data["labels"]
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=10, length_penalty=1.0, max_len=18)
    truths = tgt.tolist()

    for beam, true_ids in zip(hyps, truths):
        for hyp_ids in beam: # if one item from beams is correct, then count sample as correct
          # find first EOS and remove everything after it
          try:
            first_eos = hyp_ids.index(EOS)
          except:
            first_eos = len(hyp_ids) # if no EOS, don't strip

          # strip special tokens
          hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
          true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

          # convert to Sympy and check
          try:
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
          except:
            total += 1
            continue # if model output isn't real equation, skip
          true_expr = prefix_to_sympy(true_tok, OPERATORS)
          if verify_solution(hyp_expr, true_expr, rng):
              n_correct += 1
              break # this input's set of beams has a correct answer, skip the rest (any correct = correct)
        total += 1


acc = 100 * n_correct / total
print(f"Beam-10 semantic accuracy: {acc:.2f}%")
print("model", hyp_tok)
print("gt", true_tok)

100%|██████████| 32/32 [2:21:53<00:00, 266.05s/it]

Beam-10 semantic accuracy: 1.05%
model ['mul', 'c', 'pow', 'add', 'INT+', '6', 'mul', 'INT+', '2', 'x', 'INT-', '1']
gt ['mul', 'c', 'pow', 'add', 'INT+', '4', 'mul', 'INT+', '2', 'x', 'INT-', '1']



