<a href="https://colab.research.google.com/github/srirambandi/compsci685/blob/main/train_and_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab Cell – clone your repo at the top of the notebook
!git clone https://github.com/srirambandi/compsci685.git
%cd compsci685

Cloning into 'compsci685'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (101/101), done.[K
remote: Compressing objects: 100% (73/73), done.[K
remote: Total 101 (delta 41), reused 75 (delta 20), pack-reused 0 (from 0)[K
Receiving objects: 100% (101/101), 1.76 MiB | 4.59 MiB/s, done.
Resolving deltas: 100% (41/41), done.
/content/compsci685


In [2]:
!pip install torch transformers datasets sympy tqdm scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
import sys
sys.path.insert(0, "gen_dataset")
sys.path.insert(0, "gen_dataset/src")

In [4]:
import os
import math
import time

import pandas as pd
import numpy as np
import sympy as sp
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from utils import prefix_to_sympy, verify_solution, OPERATORS

In [6]:
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/compsci685/checkpoints_baseline"
os.makedirs(SAVE_DIR, exist_ok=True)

Mounted at /content/drive


In [7]:
DATA_PATH = "gen_dataset/data/fin_dataset.csv"
cols = ["id", "equ_str", "equ_prefix", "sol_str", "sol_prefix"]

df = pd.read_csv(DATA_PATH, header=0, names=cols)
print(f"Total examples: {len(df)}")
df.head(5)

Total examples: 33874


Unnamed: 0,id,equ_str,equ_prefix,sol_str,sol_prefix
0,0,x^4*e^(x) + 4*x^3*e^(x) + y' - 1 = 0,add add add -1 mul pow x 4 exp x mul mul 4 pow...,c - x^4*e^(x) + x,add add c x mul mul -1 pow x 4 exp x
1,1,-2*y*ln(12) + y' = 0,add mul mul -2 y log 12 y',12^(c + 2*x),pow 12 add c mul 2 x
2,2,2*x + 9*y' - 9 = 0,add add -9 mul 2 x mul 9 y',c - x^2/9 + x,add add c x mul div -1 9 pow x 2
3,3,2*x^2 + x*y' - y = 0,add add mul -1 y mul 2 pow x 2 mul x y',c*x + x*(4 - 2*x) + x,add add x mul c x mul x add 4 mul -2 x
4,4,54*x*e^(y/9) + y' = 0,add mul mul 54 x exp mul div 1 9 y y',-9*ln(c + 3*x^2),mul -9 log add c mul 3 pow x 2


In [8]:
# stats about our dataset
df["equ_len"] = df["equ_prefix"].str.split().apply(len)
df["sol_len"] = df["sol_prefix"].str.split().apply(len)

df[["equ_len", "sol_len"]].describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
equ_len,33874.0,13.184684,7.748491,1.0,8.0,11.0,15.0,115.0
sol_len,33874.0,10.013107,2.727231,2.0,8.0,10.0,11.0,28.0


In [9]:
# let's split the data into train, test and val like in the paper: https://github.com/facebookresearch/SymbolicMathematics/blob/main/split_data.py

N = len(df)
m = int(0.1 * N)       # let's take 10% as validation data set size for now.
assert 2 * m < N, "Pick smaller m!"

alpha = math.log(N - 0.5) / math.log(2 * m)

raw_idxs = [int(i**alpha) for i in range(1, 2*m + 1)]
val_idxs  = set(raw_idxs[::2])
test_idxs = set(raw_idxs[1::2])

all_idxs   = set(range(N))
train_idxs = all_idxs - val_idxs - test_idxs

# slice and reset index
train_df = df.iloc[sorted(train_idxs)].reset_index(drop=True)
val_df   = df.iloc[sorted(val_idxs)].reset_index(drop=True)
test_df  = df.iloc[sorted(test_idxs)].reset_index(drop=True)

print(f"Split sizes => train: {len(train_df)}, valid: {len(val_df)}, test: {len(test_df)}")

Split sizes => train: 27100, valid: 3387, test: 3387


In [10]:
os.makedirs("splits", exist_ok=True)

train_df.to_csv("splits/train.csv", index=False)
val_df.to_csv("splits/valid.csv", index=False)
test_df.to_csv("splits/test.csv", index=False)

print("Saved splits into ./splits/")

Saved splits into ./splits/


In [41]:
# special tokens
SPECIAL = {"<pad>":0,"<bos>":1,"<eos>":2}

# collect tokens from training
tokens = {t for seq in train_df.equ_prefix.str.split() for t in seq}
tokens |= {t for seq in train_df.sol_prefix.str.split() for t in seq}
word2idx = {w:i+len(SPECIAL) for i,w in enumerate(sorted(tokens))}
word2idx.update(SPECIAL)
idx2word = {i:w for w,i in word2idx.items()}

PAD, BOS, EOS = word2idx["<pad>"], word2idx["<bos>"], word2idx["<eos>"]
VOCAB_SIZE = len(word2idx)

class ODEDataset(Dataset):
    def __init__(self, df, src_col, tgt_col, max_len=128):
        self.src = df[src_col].str.split().tolist()
        self.tgt = df[tgt_col].str.split().tolist()
        self.max_len = max_len
    def __len__(self):
        return len(self.src)
    def __getitem__(self, i):
        try:
          src = [BOS] + [word2idx[t] for t in self.src[i]] + [EOS]
          tgt = [BOS] + [word2idx[t] for t in self.tgt[i]] + [EOS]
        except:
          src = [BOS, EOS]
          tgt = [BOS, EOS]

        def pad(x):
            x = x[:self.max_len]
            return x + [PAD]*(self.max_len-len(x))
        return torch.tensor(pad(src)), len(src), torch.tensor(pad(tgt)), len(tgt)

def collate(batch):
    srcs, slens, tgts, tlens = zip(*batch)
    return (torch.stack(srcs), torch.tensor(slens)), (torch.stack(tgts), torch.tensor(tlens))

BATCH=32
train_loader = DataLoader(ODEDataset(train_df,"equ_prefix","sol_prefix"), batch_size=BATCH, shuffle=True, collate_fn=collate)
val_loader   = DataLoader(ODEDataset(val_df,  "equ_prefix","sol_prefix"), batch_size=BATCH, shuffle=False, collate_fn=collate)
test_loader  = DataLoader(ODEDataset(test_df, "equ_prefix","sol_prefix"), batch_size=BATCH, shuffle=False, collate_fn=collate)


In [42]:
# TODO: check if we should update this!! - a smaller model than the one in original Deep Learning for Symbolic Mathematcs paper
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=4,
                 num_encoder_layers=4,
                 num_decoder_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 max_len=512):
        super().__init__()
        # positional embeddings: (max_len, d_model)
        self.pos_enc = nn.Parameter(torch.zeros(max_len, d_model))
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=PAD)
        self.transformer = nn.Transformer(
            d_model, nhead,
            num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout,
            batch_first=True
        )
        self.generator = nn.Linear(d_model, VOCAB_SIZE)

    def forward(self, src, tgt):
        # src: (S, B), tgt: (T, B)
        B, S = src.shape
        B2, T = tgt.shape
        assert B == B2

        # (B,S,d_model) + (1,S,d_model) => broadcast to (B,S,d)
        src_emb = self.embedding(src) + self.pos_enc[:S].unsqueeze(0)
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T].unsqueeze(0)


        out = self.transformer(
            src_emb, tgt_emb,
            src_key_padding_mask=src == PAD,
            tgt_key_padding_mask=tgt == PAD,
            memory_key_padding_mask=src == PAD,
            tgt_mask=self.transformer.generate_square_subsequent_mask(T).to(src.device)
        )
        return self.generator(out)  # (B, T, V) expected here

    def encode(self, src):
        B, S = src.shape
        src_emb = self.embedding(src) + self.pos_enc[:S].unsqueeze(0)
        return self.transformer.encoder(
            src_emb,
            src_key_padding_mask=src == PAD
        )

    def decode(self, tgt, memory):
        B, T = tgt.shape
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T].unsqueeze(0)
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=self.transformer.generate_square_subsequent_mask(T).to(tgt.device),
            tgt_key_padding_mask=tgt == PAD
        )


In [52]:
# training and eval funcs go hereeeee
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2SeqTransformer().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.01)
loss_fn   = nn.CrossEntropyLoss(ignore_index=PAD)

def train_epoch():
    model.train()
    total_loss = 0
    for (src, slen), (tgt, tlen) in tqdm(train_loader):
        src, tgt = src.to(device), tgt.to(device)
        # input to decoder is all but last token
        out = model(src, tgt[:,:-1])
        # compute loss against next tokens
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE), tgt[:,1:].reshape(-1))
        print(out.reshape(-1, VOCAB_SIZE).shape, tgt[:,1:].reshape(-1).shape)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss = 0
    for (src, slen), (tgt, tlen) in loader:
        src, tgt = src.to(device), tgt.to(device)
        out = model(src, tgt[:,:-1])
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE), tgt[:,1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

def greedy_decode(src, max_len=64):
    src = src.to(device)
    memory = model.encode(src)
    ys = torch.full((src.size(0),1), BOS, device=device, dtype=torch.long)
    for i in range(max_len-1):
        out = model.decode(ys, memory)
        prob = model.generator(out[:,-1,:])
        next_word = prob.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)
        if (next_word==EOS).all(): break
    return ys.cpu().tolist()


In [84]:
@torch.no_grad()
def loss_tester(loader):
    total_loss = 0
    for (src, slen), (tgt, tlen) in loader:
        src, tgt = src.to(device), tgt.to(device)
        out = torch.nn.functional.one_hot(tgt[:, 1:], VOCAB_SIZE)
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE).float() * 1e10, tgt[:,1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

In [53]:
# training loop
EPOCHS = 100

best_model_state = None
best_val_loss = float("inf")
for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss = train_epoch()
    val_loss   = evaluate(val_loader)
    print(f"Epoch {epoch} | train loss {train_loss:.4f} | val loss {val_loss:.4f} | {time.time()-t0:.1f}s")
    torch.save(model.state_dict(), SAVE_DIR + f"/epoch{epoch}.pt")
    if best_val_loss > val_loss:
      best_model_state = model.state_dict()
      best_val_loss = val_loss


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 1 | train loss 1.8210 | val loss 1.4737 | 32.4s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 2 | train loss 1.5094 | val loss 1.4436 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 3 | train loss 1.5034 | val loss 1.4596 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 4 | train loss 1.5100 | val loss 1.4526 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 5 | train loss 1.5179 | val loss 1.4646 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 6 | train loss 1.5294 | val loss 1.4737 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 7 | train loss 1.5426 | val loss 1.4906 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 8 | train loss 1.5560 | val loss 1.5044 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 9 | train loss 1.5687 | val loss 1.5131 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 10 | train loss 1.5718 | val loss 1.5119 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 11 | train loss 1.5711 | val loss 1.4926 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 12 | train loss 1.5526 | val loss 1.4718 | 31.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 13 | train loss 1.5366 | val loss 1.4545 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 14 | train loss 1.5194 | val loss 1.4472 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 15 | train loss 1.5053 | val loss 1.4253 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 16 | train loss 1.4890 | val loss 1.4050 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 17 | train loss 1.4731 | val loss 1.3903 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 18 | train loss 1.4601 | val loss 1.3811 | 31.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 19 | train loss 1.4479 | val loss 1.3636 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 20 | train loss 1.4341 | val loss 1.3549 | 33.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 21 | train loss 1.4221 | val loss 1.3461 | 32.8s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 22 | train loss 1.4108 | val loss 1.3274 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 23 | train loss 1.4002 | val loss 1.3175 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 24 | train loss 1.3896 | val loss 1.3081 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 25 | train loss 1.3791 | val loss 1.2983 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 26 | train loss 1.3659 | val loss 1.3014 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 27 | train loss 1.3565 | val loss 1.2726 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 28 | train loss 1.3448 | val loss 1.2587 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 29 | train loss 1.3333 | val loss 1.2554 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 30 | train loss 1.3239 | val loss 1.2415 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 31 | train loss 1.3124 | val loss 1.2300 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 32 | train loss 1.3042 | val loss 1.2261 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 33 | train loss 1.2961 | val loss 1.2119 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 34 | train loss 1.2881 | val loss 1.2139 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 35 | train loss 1.2813 | val loss 1.1965 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 36 | train loss 1.2731 | val loss 1.1878 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 37 | train loss 1.2649 | val loss 1.1805 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 38 | train loss 1.2597 | val loss 1.1933 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 39 | train loss 1.2526 | val loss 1.1718 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 40 | train loss 1.2484 | val loss 1.1611 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 41 | train loss 1.2422 | val loss 1.1604 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 42 | train loss 1.2379 | val loss 1.1469 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 43 | train loss 1.2323 | val loss 1.1516 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 44 | train loss 1.2276 | val loss 1.1463 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 45 | train loss 1.2220 | val loss 1.1401 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 46 | train loss 1.2186 | val loss 1.1463 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 47 | train loss 1.2141 | val loss 1.1272 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 48 | train loss 1.2109 | val loss 1.1271 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 49 | train loss 1.2059 | val loss 1.1317 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 50 | train loss 1.2024 | val loss 1.1163 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 51 | train loss 1.1987 | val loss 1.1137 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 52 | train loss 1.1959 | val loss 1.1191 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 53 | train loss 1.1936 | val loss 1.1082 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 54 | train loss 1.1900 | val loss 1.1055 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 55 | train loss 1.1873 | val loss 1.1032 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 56 | train loss 1.1847 | val loss 1.1023 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 57 | train loss 1.1814 | val loss 1.0966 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 58 | train loss 1.1787 | val loss 1.0969 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 59 | train loss 1.1761 | val loss 1.0956 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 60 | train loss 1.1732 | val loss 1.0917 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 61 | train loss 1.1723 | val loss 1.0849 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 62 | train loss 1.1703 | val loss 1.0821 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 63 | train loss 1.1667 | val loss 1.0827 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 64 | train loss 1.1652 | val loss 1.0859 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 65 | train loss 1.1632 | val loss 1.0819 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 66 | train loss 1.1604 | val loss 1.0768 | 32.3s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 67 | train loss 1.1583 | val loss 1.0807 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 68 | train loss 1.1562 | val loss 1.0743 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 69 | train loss 1.1542 | val loss 1.0609 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 70 | train loss 1.1532 | val loss 1.0768 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 71 | train loss 1.1511 | val loss 1.0641 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 72 | train loss 1.1512 | val loss 1.0594 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 73 | train loss 1.1479 | val loss 1.0623 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 74 | train loss 1.1452 | val loss 1.0513 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 75 | train loss 1.1466 | val loss 1.0570 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 76 | train loss 1.1420 | val loss 1.0597 | 32.4s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 77 | train loss 1.1431 | val loss 1.0525 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 78 | train loss 1.1402 | val loss 1.0524 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 79 | train loss 1.1383 | val loss 1.0465 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 80 | train loss 1.1377 | val loss 1.0487 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 81 | train loss 1.1367 | val loss 1.0594 | 31.9s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 82 | train loss 1.1362 | val loss 1.0441 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 83 | train loss 1.1338 | val loss 1.0469 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 84 | train loss 1.1341 | val loss 1.0524 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 85 | train loss 1.1333 | val loss 1.0368 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 86 | train loss 1.1306 | val loss 1.0517 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 87 | train loss 1.1297 | val loss 1.0377 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 88 | train loss 1.1293 | val loss 1.0509 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 89 | train loss 1.1290 | val loss 1.0379 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 90 | train loss 1.1276 | val loss 1.0361 | 32.2s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 91 | train loss 1.1276 | val loss 1.0335 | 32.0s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 92 | train loss 1.1251 | val loss 1.0336 | 32.1s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 93 | train loss 1.1248 | val loss 1.0342 | 32.4s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 94 | train loss 1.1239 | val loss 1.0465 | 32.4s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 95 | train loss 1.1250 | val loss 1.0323 | 32.5s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 96 | train loss 1.1226 | val loss 1.0382 | 32.6s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 97 | train loss 1.1215 | val loss 1.0329 | 32.5s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 98 | train loss 1.1211 | val loss 1.0450 | 32.5s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 99 | train loss 1.1220 | val loss 1.0249 | 32.4s


  0%|          | 0/847 [00:00<?, ?it/s]

Epoch 100 | train loss 1.1195 | val loss 1.0275 | 32.6s


In [85]:
test_loss = evaluate(test_loader)
print("Test loss", test_loss)
loss_test = loss_tester(test_loader)
print("Best possible loss", loss_test)



Test loss 1.0139896706590112
Best possible loss 0.0


In [72]:
#  testing here

model.eval()
n_correct = 0
total = 0
x = sp.Symbol('x')
rng = np.random.default_rng()

for (src, slen), (tgt, tlen) in tqdm(test_loader):
    src, tgt = src.to(device), tgt.to(device)
    hyps = greedy_decode(src)       # list of B lists of token IDs - all hypostheses
    truths = tgt.tolist()           # list of B lists - truths

    for hyp_ids, true_ids in zip(hyps, truths):
        # find first EOS and remove everything after it
        try:
          first_eos = hyp_ids.index(EOS)
        except:
          first_eos = len(hyp_ids) # if no EOS, don't strip

        # strip special tokens
        hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
        true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

        # parse numbers
        hyp_tok = [np.int32(int(tok)) if tok.lstrip("-").isdigit() else tok for tok in hyp_tok]
        true_tok = [np.int32(int(tok)) if tok.lstrip("-").isdigit() else tok for tok in true_tok]

        # convert to Sympy and check
        try:
          hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
        except:
          pass # if model output isn't real equation, skip
        true_expr = prefix_to_sympy(true_tok, OPERATORS)
        if verify_solution(hyp_expr, true_expr, rng):
            n_correct += 1


        total += 1
    if total > 200:
       break

acc = 100 * n_correct / len(test_loader)
print(f"Greedy semantic accuracy: {acc:.2f}%")
print("gt", true_tok)
print("model", hyp_tok)


  0%|          | 0/106 [00:00<?, ?it/s]

Greedy semantic accuracy: 0.00%
gt ['mul', 'div', np.int32(-5), 'x', 'add', np.int32(-2), 'mul', 'c', 'x']
model ['add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', np.int32(3), 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add']


In [90]:
# write beam search here
import torch.nn.functional as F
from collections import namedtuple

BeamHyp = namedtuple("BeamHyp", ["score", "tokens"])

def beam_search(src_batch, beam_size=5, length_penalty=1.0, max_len=128):
    """
    src_batch: LongTensor (B, S) - batch first as in the main model too
    returns: list of B best token ID lists
    """
    model.eval()
    B, S = src_batch.shape
    src_batch = src_batch.to(device)
    memory = model.encode(src_batch)

    # initialize beams per example
    beams = [[BeamHyp(0.0, [BOS])] for _ in range(B)]

    for _ in range(max_len):
        all_beams = [[] for _ in range(B)]
        for b in range(B):
            for hyp in beams[b]:
                tokens = hyp.tokens
                # prepare decoder input: (1, t)
                tgt_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
                dec = model.decode(tgt_input, memory[b:b+1])     # (1, t, D)
                # project last step to vocab & log‐softmax
                logits = model.generator(dec[:, -1, :])          # (1, V)
                logp   = F.log_softmax(logits, dim=-1).squeeze(0) # (V,)

                topv, topi = logp.topk(beam_size)
                for score, idx in zip(topv.tolist(), topi.tolist()):
                    all_beams[b].append(BeamHyp(hyp.score + score, tokens + [idx]))

            # prune back to beam_size
            all_beams[b].sort(
                key=lambda h: h.score / (len(h.tokens) ** length_penalty),
                reverse=True
            )
            beams[b] = all_beams[b][:beam_size]

    # extract all beams, and count as correct if at least one beam from hypothesis is correct
    results = []
    for b in range(B):
        b_beams = beams[b]
        out_beams = [a.tokens for a in b_beams]
        results.append(out_beams)
    return results

In [91]:


# now evaluate with beam=10
model.eval()
n_correct = 0
total = 0
for (src, slen), (tgt, tlen) in tqdm(test_loader):
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=10, length_penalty=1.0, max_len=128)
    truths = tgt.tolist()

    for beam, true_ids in zip(hyps, truths):
        for hyp_ids in beam: # if one item from beams is correct, then count sample as correct
          # find first EOS and remove everything after it
          try:
            first_eos = hyp_ids.index(EOS)
          except:
            first_eos = len(hyp_ids) # if no EOS, don't strip

          # strip special tokens
          hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
          true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

          # parse numbers
          hyp_tok = [np.int32(int(tok)) if tok.lstrip("-").isdigit() else tok for tok in hyp_tok]
          true_tok = [np.int32(int(tok)) if tok.lstrip("-").isdigit() else tok for tok in true_tok]

          # convert to Sympy and check
          try:
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
          except:
            pass # if model output isn't real equation, skip
          true_expr = prefix_to_sympy(true_tok, OPERATORS)
          if verify_solution(hyp_expr, true_expr, rng):
              n_correct += 1
              continue # rest of beam is correct
        total += 1

    if total > 100:
        break

acc = 100 * n_correct / total
print(f"Beam-10 semantic accuracy: {acc:.2f}%")
print("model", hyp_tok)
print("gt", true_tok)


  0%|          | 0/106 [00:00<?, ?it/s]



Beam-10 semantic accuracy: 0.00%
model ['add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'c', 'x', 'add', np.int32(3), 'mul', 'mul', np.int32(3), 'x', 'add', np.int32(3), 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'c', 'x', np.int32(3), 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', np.int32(3), 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', np.int32(3), 'c', 'x', 'add', 'add', 'add', np.int32(3), 'c', 'x', 'add', np.int32(3), 'c', 'x', 'add', np.int32(3), 'c', 'x', 'add', 'add', np.int32(3), 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', np.int32(3), 'c', 'x', 'add', np.int32(3), 'c', 'x', 'add', np.int32(3), 'c', 'x', 'add', 'add', 'add', np.int32(3), 'c', 'x', 'add', 'add', 'add', 'add', 'add', np.int32(3), 'c']
gt ['add', 'add', np.int32(4), 'mul', np.int32(49)

In [57]:
# now evaluate with beam=3
model.eval()
n_correct = 0
total = 0
for (src, slen), (tgt, tlen) in tqdm(test_loader):
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=3, length_penalty=1.0, max_len=128)
    truths = tgt.tolist()

    for hyp_ids, true_ids in zip(hyps, truths):
        try:
            hyp_tok = [idx2word[i] for i in hyp_ids  if i not in (PAD,BOS,EOS)]
            true_tok = [idx2word[i] for i in true_ids  if i not in (PAD,BOS,EOS)]
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
            x = sp.Symbol('x')
            if verify_solution(sp.diff(hyp_expr, x), hyp_expr, x):
                n_correct += 1
        except:
            pass
        total += 1
    if total > 200:
      break

acc = 100 * n_correct / total
print(f"Beam-3 semantic accuracy: {acc:.2f}%")
print("model", hyp_tok)
print("gt", true_tok)


  0%|          | 0/106 [00:00<?, ?it/s]



Beam-3 semantic accuracy: 0.00%
model ['add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'c', 'x', 'add', 'add', '3', 'x', 'pow', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', '3', 'x', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'c', 'x', '3', 'x', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', '3', 'c', 'x', '3', 'c', 'x', '3', 'c', 'x', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add']
gt ['add', 'add', 'c', 'mul', '-3', 'x', 'mul', '-3', 'cos', 'x']
