<a href="https://colab.research.google.com/github/srirambandi/compsci685/blob/main/train_and_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Colab Cell – clone your repo at the top of the notebook
!git clone https://github.com/srirambandi/compsci685.git
%cd compsci685

fatal: destination path 'compsci685' already exists and is not an empty directory.
/content/compsci685/compsci685


In [4]:
!pip install torch transformers datasets sympy tqdm scikit-learn



In [5]:
import sys

# Prepend the dataset-generation folders to the module search path
# ("gen_dataset/src" ends up first, "gen_dataset" second).
sys.path[:0] = ["gen_dataset/src", "gen_dataset"]

In [7]:
import os
import math
import time

import pandas as pd
import numpy as np
import sympy as sp
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from utils import prefix_to_sympy, verify_solution, OPERATORS

In [8]:
from google.colab import drive

# Mount Google Drive; the checkpoint folder below lives under the mount point.
drive.mount("/content/drive")

# Directory that receives the per-epoch model checkpoints.
SAVE_DIR = "/content/drive/MyDrive/compsci685/checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
# Location of the dataset CSV and the column names applied to it.
DATA_PATH = "gen_dataset/data/fin_dataset.csv"
cols = ["id", "equ_str", "equ_prefix", "sol_str", "sol_prefix"]

# header=0 consumes the header row stored in the file; `cols` replaces it.
df = pd.read_csv(DATA_PATH, names=cols, header=0)
print(f"Total examples: {df.shape[0]}")
df.head()

Total examples: 33874


Unnamed: 0,id,equ_str,equ_prefix,sol_str,sol_prefix
0,0,x^4*e^(x) + 4*x^3*e^(x) + y' - 1 = 0,add add add -1 mul pow x 4 exp x mul mul 4 pow...,c - x^4*e^(x) + x,add add c x mul mul -1 pow x 4 exp x
1,1,-2*y*ln(12) + y' = 0,add mul mul -2 y log 12 y',12^(c + 2*x),pow 12 add c mul 2 x
2,2,2*x + 9*y' - 9 = 0,add add -9 mul 2 x mul 9 y',c - x^2/9 + x,add add c x mul div -1 9 pow x 2
3,3,2*x^2 + x*y' - y = 0,add add mul -1 y mul 2 pow x 2 mul x y',c*x + x*(4 - 2*x) + x,add add x mul c x mul x add 4 mul -2 x
4,4,54*x*e^(y/9) + y' = 0,add mul mul 54 x exp mul div 1 9 y y',-9*ln(c + 3*x^2),mul -9 log add c mul 3 pow x 2


In [10]:
# Dataset statistics: number of whitespace-separated tokens in the prefix form
# of each equation and each solution (stored as new columns on df).
df["equ_len"] = df["equ_prefix"].str.split().map(len)
df["sol_len"] = df["sol_prefix"].str.split().map(len)

df.loc[:, ["equ_len", "sol_len"]].describe().transpose()

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
equ_len,33874.0,13.184684,7.748491,1.0,8.0,11.0,15.0,115.0
sol_len,33874.0,10.013107,2.727231,2.0,8.0,10.0,11.0,28.0


In [11]:
# Train / valid / test split, following the recipe of the paper's repository:
# https://github.com/facebookresearch/SymbolicMathematics/blob/main/split_data.py

N = len(df)
m = int(0.1 * N)       # validation size: 10% of the data for now (test gets the same size)
assert 2 * m < N, "Pick smaller m!"

# (2*m)**alpha ~= N - 0.5, so the 2*m generated indices stay within 1 .. N-1.
alpha = math.log(N - 0.5) / math.log(2 * m)

raw_idxs = [int(pow(i, alpha)) for i in range(1, 2 * m + 1)]
# alternate the generated indices: even positions -> validation, odd positions -> test
val_idxs = set(raw_idxs[0::2])
test_idxs = set(raw_idxs[1::2])

# everything that was not picked above is training data
all_idxs = set(range(N))
train_idxs = all_idxs.difference(val_idxs, test_idxs)

# select the rows of each split (in ascending row order) and renumber them from 0
train_df, val_df, test_df = (
    df.iloc[sorted(picked)].reset_index(drop=True)
    for picked in (train_idxs, val_idxs, test_idxs)
)

print(f"Split sizes => train: {len(train_df)}, valid: {len(val_df)}, test: {len(test_df)}")

Split sizes => train: 27100, valid: 3387, test: 3387


In [12]:
# Write each split to its own CSV file under ./splits/ (row index not written).
os.makedirs("splits", exist_ok=True)

for _frame, _path in (
    (train_df, "splits/train.csv"),
    (val_df, "splits/valid.csv"),
    (test_df, "splits/test.csv"),
):
    _frame.to_csv(_path, index=False)

print("Saved splits into ./splits/")

Saved splits into ./splits/


In [17]:
# special tokens; <unk> stands in for any token that never occurs in the training split
SPECIAL = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}

# collect tokens from training (ids of regular tokens start right after the specials)
tokens = {t for seq in train_df.equ_prefix.str.split() for t in seq}
tokens |= {t for seq in train_df.sol_prefix.str.split() for t in seq}
word2idx = {w: i + len(SPECIAL) for i, w in enumerate(sorted(tokens))}
word2idx.update(SPECIAL)
idx2word = {i: w for w, i in word2idx.items()}

PAD, BOS, EOS = word2idx["<pad>"], word2idx["<bos>"], word2idx["<eos>"]
UNK = word2idx["<unk>"]
VOCAB_SIZE = len(word2idx)


class ODEDataset(Dataset):
    """Token-id view of one dataframe split.

    Args:
        df: dataframe holding whitespace-separated prefix-notation strings.
        src_col: name of the column used as model input.
        tgt_col: name of the column used as model target.
        max_len: fixed length every returned sequence is padded / truncated to.

    Each item is ``(src_ids, src_len, tgt_ids, tgt_len)``: two LongTensors of
    shape ``(max_len,)`` framed as ``<bos> ... <eos>`` and right-padded with
    ``<pad>``, plus the number of non-padding positions in each.
    """

    def __init__(self, df, src_col, tgt_col, max_len=128):
        self.src = df[src_col].str.split().tolist()
        self.tgt = df[tgt_col].str.split().tolist()
        self.max_len = max_len

    def __len__(self):
        return len(self.src)

    def _encode(self, toks):
        """Map tokens to ids, frame with BOS/EOS, then truncate / pad to max_len."""
        # The vocabulary only covers the training split, so validation / test
        # tokens may be missing from it; map those to <unk> instead of raising KeyError.
        ids = [BOS] + [word2idx.get(t, UNK) for t in toks] + [EOS]
        if len(ids) > self.max_len:
            # keep the sequence terminated even when it has to be cut
            ids = ids[: self.max_len - 1] + [EOS]
        length = len(ids)  # reported length matches what is actually in the tensor
        return torch.tensor(ids + [PAD] * (self.max_len - length)), length

    def __getitem__(self, i):
        src, src_len = self._encode(self.src[i])
        tgt, tgt_len = self._encode(self.tgt[i])
        return src, src_len, tgt, tgt_len

def collate(batch):
    """Merge dataset items into batched tensors.

    Each item of ``batch`` is a ``(src, src_len, tgt, tgt_len)`` tuple. The
    result is ``((src_batch, src_lens), (tgt_batch, tgt_lens))``: sequences are
    stacked along a new leading dimension and lengths become 1-D tensors.
    """
    src_seqs, src_lens, tgt_seqs, tgt_lens = zip(*batch)
    src_side = (torch.stack(src_seqs), torch.tensor(src_lens))
    tgt_side = (torch.stack(tgt_seqs), torch.tensor(tgt_lens))
    return src_side, tgt_side

BATCH = 32


def _build_loader(frame, shuffle):
    """Wrap one split in an ODEDataset (equation -> solution) and batch it."""
    dataset = ODEDataset(frame, "equ_prefix", "sol_prefix")
    return DataLoader(dataset, batch_size=BATCH, shuffle=shuffle, collate_fn=collate)


# only the training split is shuffled
train_loader = _build_loader(train_df, shuffle=True)
val_loader = _build_loader(val_df, shuffle=False)
test_loader = _build_loader(test_df, shuffle=False)


In [18]:
# TODO: check whether this should be updated - this is a smaller model than the one in the original "Deep Learning for Symbolic Mathematics" paper
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer over the shared token vocabulary.

    Source and target use the same embedding table and the same learned
    positional table. All token tensors are batch-first: ``(B, S)`` for the
    source and ``(B, T)`` for the target.

    Args:
        d_model: embedding / hidden size.
        nhead: number of attention heads.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        dim_feedforward: width of the feed-forward sub-layers.
        dropout: dropout probability handed to ``nn.Transformer``.
        max_len: number of rows in the positional table.
    """

    def __init__(self,
                 d_model=256,
                 nhead=4,
                 num_encoder_layers=4,
                 num_decoder_layers=4,
                 dim_feedforward=512,
                 dropout=0.1,
                 max_len=512):
        super().__init__()
        # learned positional embeddings: (max_len, d_model), initialised to zero
        self.pos_enc = nn.Parameter(torch.zeros(max_len, d_model))
        self.embedding = nn.Embedding(VOCAB_SIZE, d_model, padding_idx=PAD)
        self.transformer = nn.Transformer(
            d_model, nhead,
            num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout,
            batch_first=True
        )
        self.generator = nn.Linear(d_model, VOCAB_SIZE)

    @staticmethod
    def _causal_mask(size, device):
        """Boolean (size, size) mask; True marks future positions that must not be attended."""
        # A boolean mask has the same type as the boolean key-padding masks used
        # below, which avoids mixing a float attention mask with bool padding masks.
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def _embed(self, ids):
        """Token embedding plus positional embedding: (B, L) -> (B, L, d_model)."""
        # (B, L, d_model) + (1, L, d_model) broadcasts over the batch dimension
        return self.embedding(ids) + self.pos_enc[: ids.size(1)].unsqueeze(0)

    def forward(self, src, tgt):
        """Teacher-forced pass.

        Args:
            src: (B, S) source token ids.
            tgt: (B, T) decoder-input token ids.

        Returns:
            (B, T, VOCAB_SIZE) unnormalised scores for the next token at each position.
        """
        B, S = src.shape
        B2, T = tgt.shape
        assert B == B2

        src_pad = src == PAD
        out = self.transformer(
            self._embed(src), self._embed(tgt),
            tgt_mask=self._causal_mask(T, tgt.device),
            src_key_padding_mask=src_pad,
            tgt_key_padding_mask=tgt == PAD,
            memory_key_padding_mask=src_pad,
        )
        return self.generator(out)

    def encode(self, src):
        """Run only the encoder on (B, S) source ids; returns (B, S, d_model) memory."""
        return self.transformer.encoder(
            self._embed(src),
            src_key_padding_mask=src == PAD
        )

    def decode(self, tgt, memory, memory_key_padding_mask=None):
        """Run only the decoder.

        Args:
            tgt: (B, T) target token ids generated so far.
            memory: encoder output as returned by ``encode``.
            memory_key_padding_mask: optional (B, S) boolean mask, True at padded
                source positions. ``forward`` masks those positions, so pass
                ``src == PAD`` here for decoding to see the same memory as
                training does. ``None`` (default) leaves every memory position visible.

        Returns:
            (B, T, d_model) decoder states (not yet projected to the vocabulary).
        """
        T = tgt.size(1)
        return self.transformer.decoder(
            self._embed(tgt),
            memory,
            tgt_mask=self._causal_mask(T, tgt.device),
            tgt_key_padding_mask=tgt == PAD,
            memory_key_padding_mask=memory_key_padding_mask,
        )


In [19]:
# Training / evaluation setup: device, model, optimiser and loss.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Seq2SeqTransformer()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# targets equal to PAD are ignored by the loss
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

def train_epoch():
    """Run one optimisation pass over ``train_loader``; return the mean per-batch loss."""
    model.train()
    batch_losses = []
    for (src, _src_len), (tgt, _tgt_len) in tqdm(train_loader):
        src = src.to(device)
        tgt = tgt.to(device)
        # teacher forcing: the decoder reads tokens [0, T-1) and predicts tokens [1, T)
        decoder_in, targets = tgt[:, :-1], tgt[:, 1:]
        logits = model(src, decoder_in)
        loss = loss_fn(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_loader)

@torch.no_grad()
def evaluate(loader):
    """Return the mean per-batch teacher-forced loss of ``model`` over ``loader``."""
    model.eval()
    running = 0
    for batch in loader:
        (src, _src_len), (tgt, _tgt_len) = batch
        src = src.to(device)
        tgt = tgt.to(device)
        # same shift as in training: inputs are tgt[:, :-1], labels are tgt[:, 1:]
        logits = model(src, tgt[:, :-1])
        labels = tgt[:, 1:]
        running += loss_fn(logits.reshape(-1, VOCAB_SIZE), labels.reshape(-1)).item()
    return running / len(loader)

@torch.no_grad()
def greedy_decode(src, max_len=128):
    """Greedily decode a batch of source sequences.

    Args:
        src: (B, S) LongTensor of source token ids.
        max_len: maximum length of a generated sequence, including the leading BOS.

    Returns:
        A list of B token-id lists. Each starts with BOS; once a row has produced
        EOS, its remaining positions are filled with PAD.
    """
    model.eval()
    src = src.to(device)
    memory = model.encode(src)
    ys = torch.full((src.size(0), 1), BOS, device=device, dtype=torch.long)
    # per-row flag: has this sequence already emitted EOS?
    finished = torch.zeros(src.size(0), dtype=torch.bool, device=device)
    for _ in range(max_len - 1):
        out = model.decode(ys, memory)
        logits = model.generator(out[:, -1, :])
        next_word = logits.argmax(dim=-1, keepdim=True)           # (B, 1)
        # rows that are already done only receive padding from here on
        next_word = next_word.masked_fill(finished.unsqueeze(1), PAD)
        # grow along the time axis (dim=1); dim=0 would stack onto the batch axis
        ys = torch.cat([ys, next_word], dim=1)
        finished |= next_word.squeeze(1) == EOS
        if finished.all():
            break
    return ys.cpu().tolist()


In [20]:
# training loop: one checkpoint is written to SAVE_DIR after every epoch
EPOCHS = 10

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss = train_epoch()
    val_loss   = evaluate(val_loader)
    print(f"Epoch {epoch} | train loss {train_loss:.4f} | val loss {val_loss:.4f} | {time.time()-t0:.1f}s")
    # SAVE_DIR is a plain string, so `SAVE_DIR / "..."` raises TypeError; join the path instead.
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"epoch{epoch}.pt"))


  0%|          | 0/847 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
#  testing here: greedy decoding + symbolic check of every test example

model.eval()
n_correct = 0
n_failed = 0      # hypotheses that raised while being parsed / verified (counted as wrong)
total = 0
x = sp.Symbol('x')

for (src, slen), (tgt, tlen) in tqdm(test_loader):
    hyps = greedy_decode(src)       # list of B lists of token IDs - all hypotheses

    for hyp_ids in hyps:
        total += 1
        try:
            # anything generated after the first <eos> is not part of the answer
            if EOS in hyp_ids:
                hyp_ids = hyp_ids[:hyp_ids.index(EOS)]
            # strip the remaining special tokens
            hyp_tok = [idx2word[i] for i in hyp_ids if i not in (PAD, BOS, EOS)]

            # convert to Sympy and check
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
            # NOTE(review): only the hypothesis and its own derivative are passed here;
            # the source equation is never used. Confirm against utils.verify_solution
            # that this call really checks the ODE of the example.
            if verify_solution(sp.diff(hyp_expr, x), hyp_expr, x):
                n_correct += 1
        except Exception:
            # best effort: an unparsable / unverifiable hypothesis simply counts as wrong
            n_failed += 1

# `total` is the number of evaluated examples; guard against an empty loader
acc = 100 * n_correct / total if total else 0.0
print(f"Greedy semantic accuracy: {acc:.2f}%")
print(f"({n_failed} of {total} hypotheses raised during parsing/verification)")


In [None]:
# write beam search here
import torch.nn.functional as F
from collections import namedtuple

# One partial hypothesis: cumulative log-probability and the token ids generated so far.
BeamHyp = namedtuple("BeamHyp", ["score", "tokens"])


@torch.no_grad()
def beam_search(src_batch, beam_size=5, length_penalty=1.0, max_len=128):
    """
    Beam-search decoding, run independently for every example of the batch.

    src_batch: LongTensor (B, S) - batch first as in the main model too
    beam_size: number of hypotheses kept per example after each step
    length_penalty: hypotheses are ranked by score / len(tokens) ** length_penalty
    max_len: maximum length of a hypothesis, including the leading BOS
    returns: list of B best token ID lists (each starts with BOS and ends with
             EOS if the hypothesis finished before max_len)
    """
    model.eval()
    B = src_batch.size(0)
    src_batch = src_batch.to(device)
    memory = model.encode(src_batch)

    def normalized(h):
        # length-normalised score used for pruning and for the final pick
        return h.score / (len(h.tokens) ** length_penalty)

    # initialize beams per example
    beams = [[BeamHyp(0.0, [BOS])] for _ in range(B)]

    for _ in range(max_len - 1):
        expanded_any = False
        for b in range(B):
            candidates = []
            for hyp in beams[b]:
                tokens = hyp.tokens
                if tokens[-1] == EOS:
                    # finished hypotheses are carried over unchanged, never extended
                    candidates.append(hyp)
                    continue
                expanded_any = True
                # prepare decoder input: (1, t)
                tgt_input = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
                dec = model.decode(tgt_input, memory[b:b+1])      # (1, t, D)
                # project last step to vocab & log-softmax
                logits = model.generator(dec[:, -1, :])           # (1, V)
                logp   = F.log_softmax(logits, dim=-1).squeeze(0) # (V,)

                # topk cannot return more entries than the vocabulary holds
                topv, topi = logp.topk(min(beam_size, logp.numel()))
                for score, idx in zip(topv.tolist(), topi.tolist()):
                    candidates.append(BeamHyp(hyp.score + score, tokens + [idx]))

            # prune back to beam_size
            candidates.sort(key=normalized, reverse=True)
            beams[b] = candidates[:beam_size]

        # every hypothesis of every example has ended with EOS: nothing left to extend
        if not expanded_any:
            break

    # extract best sequence per example
    return [max(beams[b], key=normalized).tokens for b in range(B)]

# now evaluate with beam=10
model.eval()
n_correct = 0
total = 0
x = sp.Symbol('x')   # built once instead of once per hypothesis
for (src, slen), (tgt, tlen) in tqdm(test_loader):
    hyps = beam_search(src, beam_size=10, length_penalty=1.0, max_len=128)

    for hyp_ids in hyps:
        total += 1
        try:
            # anything generated after the first <eos> is not part of the answer
            if EOS in hyp_ids:
                hyp_ids = hyp_ids[:hyp_ids.index(EOS)]
            hyp_tok = [idx2word[i] for i in hyp_ids if i not in (PAD, BOS, EOS)]
            hyp_expr = prefix_to_sympy(hyp_tok, OPERATORS)
            # NOTE(review): only the hypothesis and its own derivative are passed here;
            # the source equation is never used. Confirm against utils.verify_solution
            # that this call really checks the ODE of the example.
            if verify_solution(sp.diff(hyp_expr, x), hyp_expr, x):
                n_correct += 1
        except Exception:
            # best effort: a hypothesis that cannot be parsed / verified counts as wrong.
            # `except Exception` (not a bare except) keeps Ctrl-C / kernel interrupt working.
            pass

# guard against an empty loader
acc = 100 * n_correct / total if total else 0.0
print(f"Beam-10 semantic accuracy: {acc:.2f}%")
