<a href="https://colab.research.google.com/github/srirambandi/compsci685/blob/main/train_and_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab Cell – clone your repo at the top of the notebook
!git clone https://github.com/srirambandi/compsci685.git
%cd compsci685

Cloning into 'compsci685'...
remote: Enumerating objects: 153, done.[K
remote: Counting objects: 100% (153/153), done.[K
remote: Compressing objects: 100% (109/109), done.[K
remote: Total 153 (delta 77), reused 103 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (153/153), 25.90 MiB | 3.05 MiB/s, done.
Resolving deltas: 100% (77/77), done.
/content/compsci685


In [None]:
!pip install torch transformers datasets sympy tqdm scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import sys
sys.path.insert(0, "gen_dataset")
sys.path.insert(0, "gen_dataset/src")

In [None]:
import os
import math
import time

import pandas as pd
import numpy as np
import sympy as sp
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from utils import verify_solution

x = sp.Symbol('x', real=True, nonzero=True)
c = sp.Symbol('c', real=True)
f = sp.Function('f', real=True, nonzero=True)


def prefix_to_sympy(prefix, arity_map):
    def helper(tokens):
        # pop the first token
        token = tokens.pop(0)

        if token == 'x':
            return x
        elif token == 'c':
            return c
        elif token == 'y':
            return f(x)
        elif token == "y'":
            return Derivative(f(x), x)
        

        if token.isdigit():
            return sp.Integer(int(token))

        # mathematical constants
        if token == 'E':
            return sp.E
        if token == 'pi':
            return sp.pi

        elif token in arity_map:
            arity = arity_map[token]
            _args = []
            for _ in range(arity):
                arg = helper(tokens)
                _args.append(arg)

            # basic operators
            if token == 'add':
                return sp.Add(*_args)
            elif token == 'sub':
                return sp.Add(_args[0], sp.Mul(-1, _args[1]))
            elif token == 'mul':
                return sp.Mul(*_args)
            elif token == 'div':
                return sp.Mul(_args[0], sp.Pow(_args[1], -1))
            elif token == 'pow':
                return sp.Pow(*_args)
            elif token == 'sqrt':
                return sp.sqrt(_args[0])
            elif token == 'exp':
                return sp.exp(_args[0])
            elif token == 'log':
                return sp.log(_args[0])
            elif token == 'abs':
                return sp.Abs(_args[0])
            elif token == 'sign':
                return sp.sign(_args[0])
            elif token in ['INT-', 'INT+']:
                return sp.Integer(_args[0] * -1 if token == 'INT-' else _args[0])

            # trig operators
            elif token == 'sin':
                return sp.sin(_args[0])
            elif token == 'cos':
                return sp.cos(_args[0])
            elif token == 'tan':
                return sp.tan(_args[0])

            # inverse trig operators
            elif token == 'asin':
                return sp.asin(_args[0])
            elif token == 'acos':
                return sp.acos(_args[0])
            elif token == 'atan':
                return sp.atan(_args[0])

            # hyperbolic operators
            elif token == 'sinh':
                return sp.sinh(_args[0])
            elif token == 'cosh':
                return sp.cosh(_args[0])
            elif token == 'tanh':
                return sp.tanh(_args[0])

            # inverse hyperbolic operators
            elif token == 'asinh':
                return sp.asinh(_args[0])
            elif token == 'acosh':
                return sp.acosh(_args[0])
            elif token == 'atanh':
                return sp.atanh(_args[0])
            else:
                print(f"Unknown operator: {token}")
        else:
            print(f"Unknown token: {token}")

    tokens = list(prefix)
    return helper(tokens)


OPERATORS = {
    # basic - binary
    'add': 2, 'sub': 2, 'mul': 2, 'div': 2, 'pow': 2,
    # basic - unary
    'sqrt': 1, 'exp': 1, 'log': 1, 'abs': 1, 'sign': 1,
    # trig - unary
    'sin': 1, 'cos': 1, 'tan': 1,
    # inverse trig - unary
    'asin': 1, 'acos': 1, 'atan': 1,
    # hyperbolic - unary
    'sinh': 1, 'cosh': 1, 'tanh': 1,
    # inverse hyperbolic - unary
    'asinh': 1, 'acosh': 1, 'atanh': 1,
    # INT - unary
    'INT+': 1, 'INT-': 1,
}

In [None]:
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/compsci685/checkpoints_baseline_paper_ode1_dataset"
os.makedirs(SAVE_DIR, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
DATA_PATH = "gen_dataset/data/dataset2.txt"
# data is stored as input_equation\toutput_equation

TEST_SIZE = 1_024 # number of samples to evaluate model equation correctness on
VALID_SIZE = 2_048 # number of samples to get validation loss from
# hardcode these numbers, as 10% of new >600K sample dataset is too much
# evaluation is slow, so only take 1024
# validation is a bit faster, so take 2048
# use those numbers as they are powers of 2 closest to 1k and 2k
# so batch size will evenly divide them during validation

os.makedirs("splits", exist_ok=True)
from tqdm import tqdm
print("Splitting Data")
with open(DATA_PATH, 'r') as reader:
  i = 0
  with open("splits/test.txt", 'w') as test_writer, open("splits/valid.txt", 'w') as valid_writer, open("splits/train.txt", 'w') as train_writer:
    for line in tqdm(reader):
      if i < TEST_SIZE:
        test_writer.write(line + "\n")
      elif i < (TEST_SIZE + VALID_SIZE):
        valid_writer.write(line + "\n")
      else:
        train_writer.write(line + "\n")

      i += 1




Splitting Data


683442it [00:00, 1489316.55it/s]


In [None]:
!wc -l splits/train.txt

1360740 splits/train.txt


In [None]:
# special tokens
SPECIAL = {"<pad>":0,"<bos>":1,"<eos>":2}

# collect tokens from original file:
tokens = set()
with open(DATA_PATH, 'r') as reader:
  for line in reader:
    tokens.update(line.split())
word2idx = {w:i+len(SPECIAL) for i,w in enumerate(sorted(tokens))}
word2idx.update(SPECIAL)
print(word2idx)
idx2word = {i:w for w,i in word2idx.items()}

PAD, BOS, EOS = word2idx["<pad>"], word2idx["<bos>"], word2idx["<eos>"]
VOCAB_SIZE = len(word2idx)

class ODEDataset(Dataset):
    def __init__(self, data_file, max_len=18):
        self.src = []
        self.tgt = []
        with open(data_file, 'r') as reader:
          for line in reader:
            if "\t" in line: # last line doesn't have any data
              src_item, tgt_item = line.split("\t")
              self.src.append(src_item.split())
              self.tgt.append(tgt_item.split())
        self.max_len = max_len
    def __len__(self):
        return len(self.src)
    def __getitem__(self, i):
        try:
          src = [BOS] + [word2idx[t] for t in self.src[i]] + [EOS]
          tgt = [BOS] + [word2idx[t] for t in self.tgt[i]] + [EOS]
        except:
          src = [BOS, EOS]
          tgt = [BOS, EOS]

        def pad(x):
            x = x[:self.max_len]
            return x + [PAD]*(self.max_len-len(x))
        return torch.tensor(pad(src)), len(src), torch.tensor(pad(tgt)), len(tgt)

def collate(batch):
    srcs, slens, tgts, tlens = zip(*batch)
    return (torch.stack(srcs), torch.tensor(slens)), (torch.stack(tgts), torch.tensor(tlens))

BATCH=256
train_loader = DataLoader(ODEDataset("splits/train.txt"), batch_size=BATCH, shuffle=True, collate_fn=collate, drop_last=True) # drop last when training because shuffled + multi epochs
val_loader   = DataLoader(ODEDataset("splits/valid.txt"), batch_size=BATCH, shuffle=False, collate_fn=collate)
test_loader  = DataLoader(ODEDataset("splits/test.txt"), batch_size=32, shuffle=False, collate_fn=collate) # smaller batch size to allow testing smaller subsets during development


{'0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12, 'E': 13, 'INT+': 14, 'INT-': 15, 'abs': 16, 'acos': 17, 'acosh': 18, 'add': 19, 'asin': 20, 'asinh': 21, 'atan': 22, 'atanh': 23, 'c': 24, 'cos': 25, 'cosh': 26, 'div': 27, 'exp': 28, 'log': 29, 'mul': 30, 'pi': 31, 'pow': 32, 'sign': 33, 'sin': 34, 'sinh': 35, 'sqrt': 36, 'tan': 37, 'tanh': 38, 'x': 39, 'y': 40, "y'": 41, '<pad>': 0, '<bos>': 1, '<eos>': 2}


In [None]:
# TODO: check if we should update this!! - a smaller model than the one in original Deep Learning for Symbolic Mathematcs paper
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=4,
                 num_encoder_layers=4,
                 num_decoder_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 max_len=64):
        super().__init__()
        # positional embeddings: (max_len, d_model)
        self.pos_enc = nn.Parameter(torch.zeros(max_len, d_model))
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=PAD)
        self.transformer = nn.Transformer(
            d_model, nhead,
            num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout,
            batch_first=True
        )
        self.generator = nn.Linear(d_model, VOCAB_SIZE)

    def forward(self, src, tgt):
        # src: (S, B), tgt: (T, B)
        B, S = src.shape
        B2, T = tgt.shape
        assert B == B2

        # (B,S,d_model) + (1,S,d_model) => broadcast to (B,S,d)
        src_emb = self.embedding(src) + self.pos_enc[:S].unsqueeze(0)
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T].unsqueeze(0)


        out = self.transformer(
            src_emb, tgt_emb,
            src_key_padding_mask=src == PAD,
            tgt_key_padding_mask=tgt == PAD,
            memory_key_padding_mask=src == PAD,
            tgt_mask=self.transformer.generate_square_subsequent_mask(T).to(src.device)
        )
        return self.generator(out)  # (B, T, V) expected here

    def encode(self, src):
        B, S = src.shape
        src_emb = self.embedding(src) + self.pos_enc[:S].unsqueeze(0)
        return self.transformer.encoder(
            src_emb,
            src_key_padding_mask=src == PAD
        )

    def decode(self, tgt, memory):
        B, T = tgt.shape
        tgt_emb = self.embedding(tgt) + self.pos_enc[:T].unsqueeze(0)
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=self.transformer.generate_square_subsequent_mask(T).to(tgt.device),
            tgt_key_padding_mask=tgt == PAD
        )


In [None]:
# training and eval funcs go hereeeee
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2SeqTransformer().to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
loss_fn   = nn.CrossEntropyLoss(ignore_index=PAD)

def train_epoch():
    model.train()
    total_loss = 0
    with tqdm(train_loader) as pbar:
        for (src, slen), (tgt, tlen) in pbar:
            src, tgt = src.to(device), tgt.to(device)
            # input to decoder is all but last token
            out = model(src, tgt[:,:-1])
            # compute loss against next tokens
            loss = loss_fn(out.reshape(-1, VOCAB_SIZE), tgt[:,1:].reshape(-1))
            optimizer.zero_grad(); loss.backward(); optimizer.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss": loss.item()})
    return total_loss / len(train_loader)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss = 0
    for (src, slen), (tgt, tlen) in loader:
        src, tgt = src.to(device), tgt.to(device)
        out = model(src, tgt[:,:-1])
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE), tgt[:,1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

def greedy_decode(src, max_len=18):
    src = src.to(device)
    memory = model.encode(src)
    ys = torch.full((src.size(0),1), BOS, device=device, dtype=torch.long)
    for i in range(max_len-1):
        out = model.decode(ys, memory)
        prob = model.generator(out[:,-1,:])
        next_word = prob.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)
        if (next_word==EOS).all(): break
    return ys.cpu().tolist()


In [None]:
@torch.no_grad()
def loss_tester(loader):
    total_loss = 0
    for (src, slen), (tgt, tlen) in loader:
        src, tgt = src.to(device), tgt.to(device)
        out = torch.nn.functional.one_hot(tgt[:, 1:], VOCAB_SIZE)
        loss = loss_fn(out.reshape(-1, VOCAB_SIZE).float() * 1e10, tgt[:,1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

In [123]:
# training loop
EPOCHS = 50
torch.backends.cuda.matmul.allow_tf32 = True # make faster


best_model_state = None
best_val_loss = float("inf")
for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss = train_epoch()
    val_loss   = evaluate(val_loader)
    print(f"Epoch {epoch} | train loss {train_loss:.4f} | val loss {val_loss:.4f} | {time.time()-t0:.1f}s")
    torch.save(model.state_dict(), SAVE_DIR + f"/epoch{epoch}.pt")
    if best_val_loss > val_loss:
      best_model_state = model.state_dict()
      best_val_loss = val_loss


100%|██████████| 2657/2657 [01:48<00:00, 24.42it/s, loss=0.985]


Epoch 1 | train loss 1.2860 | val loss 0.7722 | 109.0s


100%|██████████| 2657/2657 [01:48<00:00, 24.51it/s, loss=0.873]


Epoch 2 | train loss 0.9222 | val loss 0.6982 | 108.6s


100%|██████████| 2657/2657 [01:47<00:00, 24.61it/s, loss=0.782]


Epoch 3 | train loss 0.8300 | val loss 0.6331 | 108.2s


100%|██████████| 2657/2657 [01:50<00:00, 24.02it/s, loss=0.777]


Epoch 4 | train loss 0.7762 | val loss 0.5925 | 110.8s


100%|██████████| 2657/2657 [01:48<00:00, 24.54it/s, loss=0.702]


Epoch 5 | train loss 0.7360 | val loss 0.5871 | 108.4s


100%|██████████| 2657/2657 [01:48<00:00, 24.54it/s, loss=0.673]


Epoch 6 | train loss 0.7001 | val loss 0.5505 | 108.4s


100%|██████████| 2657/2657 [01:47<00:00, 24.62it/s, loss=0.636]


Epoch 7 | train loss 0.6684 | val loss 0.5339 | 108.1s


100%|██████████| 2657/2657 [01:50<00:00, 24.00it/s, loss=0.645]


Epoch 8 | train loss 0.6412 | val loss 0.5309 | 110.9s


100%|██████████| 2657/2657 [01:48<00:00, 24.47it/s, loss=0.599]


Epoch 9 | train loss 0.6174 | val loss 0.4996 | 108.8s


100%|██████████| 2657/2657 [01:48<00:00, 24.49it/s, loss=0.579]


Epoch 10 | train loss 0.5964 | val loss 0.4795 | 108.6s


100%|██████████| 2657/2657 [01:48<00:00, 24.44it/s, loss=0.534]


Epoch 11 | train loss 0.5766 | val loss 0.4697 | 108.9s


100%|██████████| 2657/2657 [01:50<00:00, 23.95it/s, loss=0.523]


Epoch 12 | train loss 0.5590 | val loss 0.4592 | 111.1s


100%|██████████| 2657/2657 [01:48<00:00, 24.49it/s, loss=0.555]


Epoch 13 | train loss 0.5425 | val loss 0.4341 | 108.6s


100%|██████████| 2657/2657 [01:48<00:00, 24.56it/s, loss=0.521]


Epoch 14 | train loss 0.5272 | val loss 0.4496 | 108.3s


100%|██████████| 2657/2657 [01:51<00:00, 23.93it/s, loss=0.512]


Epoch 15 | train loss 0.5136 | val loss 0.4290 | 111.2s


100%|██████████| 2657/2657 [01:48<00:00, 24.54it/s, loss=0.515]


Epoch 16 | train loss 0.5011 | val loss 0.4288 | 108.4s


100%|██████████| 2657/2657 [01:48<00:00, 24.52it/s, loss=0.448]


Epoch 17 | train loss 0.4897 | val loss 0.4155 | 108.5s


100%|██████████| 2657/2657 [01:48<00:00, 24.53it/s, loss=0.472]


Epoch 18 | train loss 0.4796 | val loss 0.4034 | 108.5s


100%|██████████| 2657/2657 [01:50<00:00, 24.01it/s, loss=0.462]


Epoch 19 | train loss 0.4702 | val loss 0.3933 | 110.8s


100%|██████████| 2657/2657 [01:48<00:00, 24.42it/s, loss=0.467]


Epoch 20 | train loss 0.4613 | val loss 0.3928 | 109.0s


100%|██████████| 2657/2657 [01:48<00:00, 24.40it/s, loss=0.454]


Epoch 21 | train loss 0.4534 | val loss 0.4068 | 109.0s


100%|██████████| 2657/2657 [01:48<00:00, 24.42it/s, loss=0.453]


Epoch 22 | train loss 0.4457 | val loss 0.3865 | 109.0s


100%|██████████| 2657/2657 [01:51<00:00, 23.92it/s, loss=0.442]


Epoch 23 | train loss 0.4392 | val loss 0.3797 | 111.3s


100%|██████████| 2657/2657 [01:49<00:00, 24.36it/s, loss=0.409]


Epoch 24 | train loss 0.4330 | val loss 0.3819 | 109.2s


100%|██████████| 2657/2657 [01:48<00:00, 24.50it/s, loss=0.411]


Epoch 25 | train loss 0.4269 | val loss 0.3876 | 108.6s


100%|██████████| 2657/2657 [01:51<00:00, 23.85it/s, loss=0.376]


Epoch 26 | train loss 0.4214 | val loss 0.3834 | 111.6s


100%|██████████| 2657/2657 [01:49<00:00, 24.37it/s, loss=0.421]


Epoch 27 | train loss 0.4161 | val loss 0.3682 | 109.2s


100%|██████████| 2657/2657 [01:48<00:00, 24.43it/s, loss=0.425]


Epoch 28 | train loss 0.4114 | val loss 0.3991 | 108.9s


100%|██████████| 2657/2657 [01:48<00:00, 24.52it/s, loss=0.393]


Epoch 29 | train loss 0.4068 | val loss 0.3764 | 108.5s


100%|██████████| 2657/2657 [01:50<00:00, 24.01it/s, loss=0.393]


Epoch 30 | train loss 0.4024 | val loss 0.3798 | 110.8s


100%|██████████| 2657/2657 [01:48<00:00, 24.41it/s, loss=0.402]


Epoch 31 | train loss 0.3985 | val loss 0.3705 | 109.0s


100%|██████████| 2657/2657 [01:48<00:00, 24.51it/s, loss=0.405]


Epoch 32 | train loss 0.3946 | val loss 0.3769 | 108.6s


100%|██████████| 2657/2657 [01:48<00:00, 24.49it/s, loss=0.351]


Epoch 33 | train loss 0.3909 | val loss 0.3713 | 108.6s


100%|██████████| 2657/2657 [01:50<00:00, 23.98it/s, loss=0.397]


Epoch 34 | train loss 0.3873 | val loss 0.3660 | 110.9s


100%|██████████| 2657/2657 [01:48<00:00, 24.47it/s, loss=0.404]


Epoch 35 | train loss 0.3844 | val loss 0.3616 | 108.7s


100%|██████████| 2657/2657 [01:48<00:00, 24.53it/s, loss=0.389]


Epoch 36 | train loss 0.3812 | val loss 0.3558 | 108.5s


100%|██████████| 2657/2657 [01:51<00:00, 23.74it/s, loss=0.37]


Epoch 37 | train loss 0.3781 | val loss 0.3775 | 112.1s


100%|██████████| 2657/2657 [01:48<00:00, 24.55it/s, loss=0.396]


Epoch 38 | train loss 0.3755 | val loss 0.3438 | 108.4s


100%|██████████| 2657/2657 [01:48<00:00, 24.55it/s, loss=0.346]


Epoch 39 | train loss 0.3726 | val loss 0.3645 | 108.4s


100%|██████████| 2657/2657 [01:48<00:00, 24.53it/s, loss=0.386]


Epoch 40 | train loss 0.3701 | val loss 0.3645 | 108.5s


100%|██████████| 2657/2657 [01:51<00:00, 23.90it/s, loss=0.363]


Epoch 41 | train loss 0.3677 | val loss 0.3609 | 111.3s


100%|██████████| 2657/2657 [01:48<00:00, 24.39it/s, loss=0.349]


Epoch 42 | train loss 0.3651 | val loss 0.3679 | 109.1s


100%|██████████| 2657/2657 [01:49<00:00, 24.28it/s, loss=0.335]


Epoch 43 | train loss 0.3630 | val loss 0.3555 | 109.6s


100%|██████████| 2657/2657 [01:51<00:00, 23.91it/s, loss=0.382]


Epoch 44 | train loss 0.3607 | val loss 0.3554 | 111.3s


100%|██████████| 2657/2657 [01:53<00:00, 23.48it/s, loss=0.369]


Epoch 45 | train loss 0.3586 | val loss 0.3666 | 113.3s


100%|██████████| 2657/2657 [01:49<00:00, 24.20it/s, loss=0.403]


Epoch 46 | train loss 0.3564 | val loss 0.3772 | 109.9s


100%|██████████| 2657/2657 [01:49<00:00, 24.17it/s, loss=0.319]


Epoch 47 | train loss 0.3547 | val loss 0.3647 | 110.1s


100%|██████████| 2657/2657 [01:54<00:00, 23.16it/s, loss=0.382]


Epoch 48 | train loss 0.3528 | val loss 0.3569 | 114.9s


100%|██████████| 2657/2657 [01:51<00:00, 23.87it/s, loss=0.338]


Epoch 49 | train loss 0.3510 | val loss 0.3664 | 111.5s


100%|██████████| 2657/2657 [01:53<00:00, 23.50it/s, loss=0.352]


Epoch 50 | train loss 0.3493 | val loss 0.3598 | 113.2s


In [124]:
test_loss = evaluate(test_loader)
print("Test loss", test_loss)
loss_test = loss_tester(test_loader)
print("Best possible loss", loss_test)

Test loss 0.6284813797101378
Best possible loss 0.0


In [None]:
# get average test output length
avg_tgt_len = 0
num = 0
for (src, src_lens), (tgt, tgt_lens) in test_loader:
  for tgt_len in tgt_lens:
    avg_tgt_len += tgt_len.item()
    num += 1

avg_tgt_len /= num
print("Average output length", avg_tgt_len)

In [128]:
#  testing here

model.eval()
n_correct = 0
total = 0
x = sp.Symbol('x')
rng = np.random.default_rng()

for (src, slen), (tgt, tlen) in tqdm(test_loader):
    src, tgt = src.to(device), tgt.to(device)
    hyps = greedy_decode(src)       # list of B lists of token IDs - all hypostheses
    truths = tgt.tolist()           # list of B lists - truths

    for hyp_ids, true_ids in zip(hyps, truths):
        # find first EOS and remove everything after it
        try:
          first_eos = hyp_ids.index(EOS)
        except:
          first_eos = len(hyp_ids) # if no EOS, don't strip

        # strip special tokens
        hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
        true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

        # convert to Sympy and check
        try:
          hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
        except:
          total += 1 # if model output isn't real equation, skip
          continue
        true_expr = prefix_to_sympy(true_tok, OPERATORS)
        if verify_solution(hyp_expr, true_expr, rng):
            n_correct += 1


        total += 1

acc = 100 * n_correct / total
print(f"Greedy semantic accuracy: {acc:.2f}%")
print("gt", true_tok)
print("model", hyp_tok)


100%|██████████| 32/32 [11:01<00:00, 20.69s/it]

Greedy semantic accuracy: 0.10%
gt ['mul', 'c', 'pow', 'add', 'INT+', '4', 'mul', 'INT+', '2', 'x', 'INT-', '1']
model ['mul', 'c', 'pow', 'add', 'INT+', '1', '2', 'mul', 'INT+', '4', 'x', 'INT-', '1']





In [None]:
# write beam search here
import torch.nn.functional as F
from collections import namedtuple

BeamHyp = namedtuple("BeamHyp", ["score", "tokens"])

def beam_search(src_batch, beam_size=5, length_penalty=1.0, max_len=128):
    """
    src_batch: LongTensor (B, S) - batch first as in the main model too
    returns: list of B best token ID lists
    """
    model.eval()
    B, S = src_batch.shape
    src_batch = src_batch.to(device)
    memory = model.encode(src_batch)

    # initialize beams per example
    beams = [[BeamHyp(0.0, [BOS])] for _ in range(B)]

    for _ in range(max_len):
        all_beams = [[] for _ in range(B)]
        for b in range(B):
            for hyp in beams[b]:
                tokens = hyp.tokens
                # prepare decoder input: (1, t)
                tgt_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
                dec = model.decode(tgt_input, memory[b:b+1])     # (1, t, D)
                # project last step to vocab & log‐softmax
                logits = model.generator(dec[:, -1, :])          # (1, V)
                logp   = F.log_softmax(logits, dim=-1).squeeze(0) # (V,)

                topv, topi = logp.topk(beam_size)
                for score, idx in zip(topv.tolist(), topi.tolist()):
                    all_beams[b].append(BeamHyp(hyp.score + score, tokens + [idx]))

            # prune back to beam_size
            all_beams[b].sort(
                key=lambda h: h.score / (len(h.tokens) ** length_penalty),
                reverse=True
            )
            beams[b] = all_beams[b][:beam_size]

    # extract all beams, and count as correct if at least one beam from hypothesis is correct
    results = []
    for b in range(B):
        b_beams = beams[b]
        out_beams = [a.tokens for a in b_beams]
        results.append(out_beams)
    return results

In [127]:


# now evaluate with beam=10
model.eval()
n_correct = 0
total = 0
for (src, slen), (tgt, tlen) in tqdm(test_loader):
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=10, length_penalty=1.0, max_len=18)
    truths = tgt.tolist()

    for beam, true_ids in zip(hyps, truths):
        for hyp_ids in beam: # if one item from beams is correct, then count sample as correct
          # find first EOS and remove everything after it
          try:
            first_eos = hyp_ids.index(EOS)
          except:
            first_eos = len(hyp_ids) # if no EOS, don't strip

          # strip special tokens
          hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
          true_tok = [idx2word[i] for i in true_ids if i not in (PAD, BOS, EOS)]

          # convert to Sympy and check
          try:
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
          except:
            continue # if model output isn't real equation, skip this beam
          true_expr = prefix_to_sympy(true_tok, OPERATORS)
          if verify_solution(hyp_expr, true_expr, rng):
              n_correct += 1
              break # found correct answer, can stop looking for this set of beams / input
        total += 1


acc = 100 * n_correct / total
print(f"Beam-10 semantic accuracy: {acc:.2f}%")
print("model", hyp_tok)
print("gt", true_tok)


100%|██████████| 32/32 [2:24:31<00:00, 270.99s/it]

Beam-10 semantic accuracy: 0.78%
model ['mul', 'c', 'pow', 'add', 'INT+', '1', '0', 'mul', 'INT+', '5', 'x', 'INT-', '1']
gt ['mul', 'c', 'pow', 'add', 'INT+', '4', 'mul', 'INT+', '2', 'x', 'INT-', '1']





In [None]:
# now evaluate with beam=3
model.eval()
n_correct = 0
total = 0
for (src, slen), (tgt, tlen) in tqdm(test_loader):
    src, tgt = src.to(device), tgt.to(device)
    hyps   = beam_search(src, beam_size=3, length_penalty=1.0, max_len=16)
    truths = tgt.tolist()

    for hyp_ids, true_ids in zip(hyps, truths):
        try:
            hyp_tok = [idx2word[i] for i in hyp_ids  if i not in (PAD,BOS,EOS)]
            true_tok = [idx2word[i] for i in true_ids  if i not in (PAD,BOS,EOS)]
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
            x = sp.Symbol('x')
            if verify_solution(sp.diff(hyp_expr, x), hyp_expr, x):
                n_correct += 1
        except:
            pass
        total += 1
    if total > 150:
      break

acc = 100 * n_correct / total
print(f"Beam-3 semantic accuracy: {acc:.2f}%")
print("model", hyp_tok)
print("gt", true_tok)


  0%|          | 0/106 [00:00<?, ?it/s]



Beam-3 semantic accuracy: 0.00%
model ['add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'c', 'x', 'add', 'add', '3', 'x', 'pow', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', '3', 'x', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'c', 'x', '3', 'x', 'add', 'add', 'add', 'add', 'c', 'x', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'add', 'c', 'x', '3', 'c', 'x', '3', 'c', 'x', '3', 'c', 'x', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add', 'add', 'add', '3', 'c', 'x', 'add', 'add', 'add']
gt ['add', 'add', 'c', 'mul', '-3', 'x', 'mul', '-3', 'cos', 'x']


In [None]:
# load model
model.load_state_dict(torch.load(SAVE_DIR + "/epoch100.pt"))

<All keys matched successfully>

In [None]:
# test sequence
test_src = [BOS] + [word2idx[t] for t in ["y'"]] + [EOS]
print(test_src)
test_src = torch.LongTensor(test_src).unsqueeze(0)

[1, 209, 2]


In [None]:
# now evaluate with beam=10
model.eval()
test_src = test_src.to(device)
hyps = beam_search(test_src, beam_size=10, length_penalty=1.0, max_len=64)
test_src = test_src.tolist()
for beam in hyps:
    for hyp_ids in beam: # if one item from beams is correct, then count sample as correct
      # find first EOS and remove everything after it
      try:
        first_eos = hyp_ids.index(EOS)
      except:
        first_eos = len(hyp_ids) # if no EOS, don't strip

      # strip special tokens
      hyp_tok  = [idx2word[i] for i in hyp_ids[:first_eos]  if i not in (PAD, BOS, EOS)]
      input_tok = [idx2word[i] for i in test_src[0] if i not in (PAD, BOS, EOS)]
      print("model", hyp_tok)
      print("model input", input_tok)

model ['pow', 'add', 'c', 'x', '-1']
model input ["y'"]
model ['pow', 'mul', 'c', 'x', '-1']
model input ["y'"]
model ['div', 'c', 'x']
model input ["y'"]
model ['exp', 'add', 'c', 'x']
model input ["y'"]
model ['pow', 'sub', 'c', 'x', '-1']
model input ["y'"]
model ['pow', 'add', 'c', 'x']
model input ["y'"]
model ['pow', 'add', 'c', 'x']
model input ["y'"]
model ['div', '4', 'add', 'c', 'x']
model input ["y'"]
model ['div', '4', 'add', 'c', 'x']
model input ["y'"]
model ['div', '-1', 'add', 'c', 'x']
model input ["y'"]
