In [1]:
import re
from torchtext.vocab import build_vocab_from_iterator
import numpy as np

import torch
from torch import nn

In [23]:
# given original or target expression, return splitted
def split_exprs(input, output):
    input = input.strip()
    input, var = input.split(")/d")

    # exprs.append(expr[2:].replace(var, 'x'))
    
    input = input[2:]
    # Step 1: Replace `exp`, `cos`, and `sin` with placeholders
    input = input.replace('exp', 'E')
    input = input.replace('cos', 'C')
    input = input.replace('sin', 'S')
    output = output.replace('exp', 'E')
    output = output.replace('cos', 'C')
    output = output.replace('sin', 'S')

    # Step 2: Substitute `e`, `s`, `c` with 'x' (only where they’re standalone)
    input = input.replace(var, 'x')
    output = output.replace(var, 'x')

    # Step 3: Restore the placeholders back to their original words
    input = input.replace('E', 'exp')
    input = input.replace('C', 'cos')
    input = input.replace('S', 'sin')
    output = output.replace('E', 'exp')
    output = output.replace('C', 'cos')
    output = output.replace('S', 'sin')

    return input, output

# read training data line by line
def read_data(file):
    inputs = []
    outputs = []
    with open(file, 'r') as f:
        for i, line in enumerate(f):
            line = line.strip()
            
            input, output = line.split('=')
            input = input.strip()

            # split original and target expression
            input, output = split_exprs(input, output)
            
            inputs.append(input)
            outputs.append(output)

    return inputs, outputs

print(read_data("train.txt")[1][:5])

['72exp^(9x)', '340exp^(17x)+18exp^(x)', '144exp^(12x)+180exp^(20x)', '26x+36x+100x^4+114x^5', '456exp^(19x)*x+24exp^(19x)']


In [24]:
# break up the sequence to list of tokens
def add_space(expr):
    regex = '(exp\^\(|sin\(|cos\(|sin\^|cos\^|\+|\-|\(|\)|\^|\*|x|\d)'
    splitted = re.split(regex, expr)
    splitted = [x for x in splitted if x is not None and x != '']
    return splitted

X = []
y = []
inputs, outputs = read_data("train.txt")
X = [add_space(expr) for expr in inputs]
y = [add_space(expr) for expr in outputs]

input_lens = torch.tensor([len(x) for x in X], dtype=torch.long) #debugging
input = [a+["="]+b for a,b in zip(X, y)]
output = [expr[1:]+["<eos>"] for expr in input]
print(input[:5])
print(output[:5])
print(input_lens[:5])

[['8', 'exp^(', '9', 'x', ')', '=', '7', '2', 'exp^(', '9', 'x', ')'], ['2', '0', 'exp^(', '1', '7', 'x', ')', '+', '1', '8', 'exp^(', 'x', ')', '=', '3', '4', '0', 'exp^(', '1', '7', 'x', ')', '+', '1', '8', 'exp^(', 'x', ')'], ['1', '2', 'exp^(', '1', '2', 'x', ')', '+', '9', 'exp^(', '2', '0', 'x', ')', '=', '1', '4', '4', 'exp^(', '1', '2', 'x', ')', '+', '1', '8', '0', 'exp^(', '2', '0', 'x', ')'], ['1', '3', 'x', '^', '2', '+', '1', '8', 'x', '^', '2', '+', '2', '0', 'x', '^', '5', '+', '1', '9', 'x', '^', '6', '=', '2', '6', 'x', '+', '3', '6', 'x', '+', '1', '0', '0', 'x', '^', '4', '+', '1', '1', '4', 'x', '^', '5'], ['2', '4', 'exp^(', '1', '9', 'x', ')', '*', 'x', '=', '4', '5', '6', 'exp^(', '1', '9', 'x', ')', '*', 'x', '+', '2', '4', 'exp^(', '1', '9', 'x', ')']]
[['exp^(', '9', 'x', ')', '=', '7', '2', 'exp^(', '9', 'x', ')', '<eos>'], ['0', 'exp^(', '1', '7', 'x', ')', '+', '1', '8', 'exp^(', 'x', ')', '=', '3', '4', '0', 'exp^(', '1', '7', 'x', ')', '+', '1', '8', 'exp

In [12]:
from torch.nn.utils.rnn import pad_sequence

vocab = build_vocab_from_iterator(input, specials=["[UNK]", "[PAD]"])
vocab.set_default_index(vocab["[UNK]"])  # Set default index for OOV tokens

print("done building vocab, converting to encoded inputs and outputs")

input = [torch.tensor(vocab(tokens), dtype=torch.long) for tokens in input]
output = [torch.tensor(vocab(tokens), dtype=torch.long) for tokens in output]
# y = torch.tensor([torch.tensor(vocab(tokens), dtype=torch.long) for tokens in y])
print("done converting input and output to encoded")


# Pad the sequences to make them equal length
input = pad_sequence(input, batch_first=True, padding_value=vocab["[PAD]"])
output = pad_sequence(output, batch_first=True, padding_value=vocab["[PAD]"])
seq_len = input.size(1)

# Print vocabulary and integer encoding
print("Vocabulary:", vocab.get_stoi())      # stoi: string-to-integer mapping
print("Encoded Inputs:", input[:5])  # Convert tokens to integer IDs
print("Encoded Outputs:", output[:5])  # Convert tokens to integer IDs

done building vocab, converting to encoded inputs and outputs
done converting input and output to encoded
Vocabulary: {'^': 4, '[UNK]': 0, 's': 34, '3': 14, '[PAD]': 1, '8': 13, 'x': 2, 'i': 33, ')': 5, '1': 3, 'r': 25, '2': 6, 'o': 37, 'exp^(': 7, 'p': 29, '0': 8, '*': 9, 'k': 35, '+': 10, 't': 31, '4': 11, 'v': 41, '6': 12, 'u': 26, '5': 15, '=': 16, 'w': 39, '7': 17, 'y': 28, '9': 18, '(': 19, 'sin^': 20, 'cos^': 21, 'm': 42, '-': 22, 'cos(': 23, 'sin(': 24, 'b': 27, 'z': 30, 'a': 32, 'n': 36, 'c': 38, 'e': 40}
Encoded Inputs: tensor([[13,  7, 18,  2,  5, 16, 17,  6,  7, 18,  2,  5,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 6,  8,  7,  3, 17,  2,  5, 10,  3, 13,  7,  2,  5, 16, 14, 11,  8,  7,
          3, 17,  2,  5, 10,  3, 13,  7,  2,  5,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  

In [6]:
# Determine the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: ', device)

Device:  cpu


In [7]:
# positional encoding
def get_angs(pos, i, d_model):
    return pos / np.power(10000, (2*(i//2) / d_model))

def pos_encodings(seq_len, d_model):
    """
    Should return a positional encoding tensor of shape (1, seq_len, d_model)
    """
    angs = get_angs(np.arange(seq_len)[:, None], np.arange(d_model)[None, :], d_model)
    angs[:, 0::2] = np.sin(angs[:, 0::2])
    angs[:, 1::2] = np.cos(angs[:, 1::2])
    return torch.tensor(angs[None, :, :], dtype=torch.float32)

# test
# print(pos_encodings(32).shape)

In [8]:
# padding mask (batch_size,1,1,seq_len) applied to (batch_size, num_head, seq_len, seq_len)
def create_padding_mask(seq):
    return (seq == vocab['[PAD]'])[:, None, None, :]

def create_combined_mask(seq_len, input_lens, padding_mask):
    """
    seq_len: length of sequence
    input_lens: tensor of shape (batch_size,) with length of input portion
    padding_mask: tensor of shape (batch_size, seq_len) where True indicates padding
    
    Returns: mask of shape (batch_size, seq_len, seq_len) where True indicates masked positions
    """
    batch_size = len(input_lens)
    # device = input_lens.device
    
    # Create position indices
    pos_i = torch.arange(seq_len, device=device)[None, :, None]  # (1, seq_len, 1)
    pos_j = torch.arange(seq_len, device=device)[None, None, :]  # (1, 1, seq_len)
    input_lens = input_lens[:, None, None]  # (batch_size, 1, 1)
    
    # 1. Create pattern mask (input attention and causal output attention)
    # For input positions: can only attend to input positions
    # is_input_query = pos_i < input_lens
    # is_output_key = pos_j >= input_lens
    # cross_mask = is_input_query & is_output_key
    
    # # For output positions: can attend to all previous positions (causal mask)
    # is_output_query = pos_i >= input_lens
    is_future = pos_i < pos_j
    look_ahead_mask = is_future #is_output_query & is_future
    
    pattern_mask = look_ahead_mask# cross_mask | look_ahead_mask  # (batch_size, seq_len, seq_len)
    
    # 2. Create padding mask
    # Key padding mask: mask padded keys for all queries
    padding_mask = padding_mask.squeeze()
    key_padding = padding_mask[:, None, :]  # (batch_size, 1, seq_len)
    key_padding = key_padding.expand(-1, seq_len, -1)  # (batch_size, seq_len, seq_len)
    
    # Query padding mask: mask padded queries for all keys (not used)
    # query_padding = padding_mask[:, :, None]  # (batch_size, seq_len, 1)
    # query_padding = query_padding.expand(-1, -1, seq_len)  # (batch_size, seq_len, seq_len)
    
    # 3. Combine all masks
    final_mask = pattern_mask | key_padding #| query_padding
    
    return final_mask[:, None, :, :]  # (batch_size, 1, seq_len, seq_len)

# Example usage:
batch_size = 2

# Example sequence lengths
input_len = torch.tensor(input_lens[:batch_size])  # First sample: 2 input tokens, Second sample: 3 input tokens

# Example padding mask (1 indicates padding)
# padding_mask = torch.zeros((batch_size, seq_len), dtype=torch.bool)
# padding_mask[0, 4:] = True  # First sample has padding at positions 4,5
# padding_mask[1, 5:] = True  # Second sample has padding at position 5
input_test = input[:batch_size]
padding_mask = create_padding_mask(input_test)

mask = create_combined_mask(seq_len, input_len, padding_mask)

# Print example for first sample
print("Example mask for first sequence (input_len=2, padding at 4,5):")
# np.set_printoptions(threshold=np.inf)
print(mask.int().numpy())
"""
Should print something like:
[[0 0 1 1 1 1]  # First input token: can attend to input only
 [0 0 1 1 1 1]  # Second input token: can attend to input only
 [0 0 0 1 1 1]  # First output token: can attend to previous, not future/padding
 [0 0 0 0 1 1]  # Second output token: can attend to previous, not future/padding
 [1 1 1 1 1 1]  # PAD token: can't attend to anything
 [1 1 1 1 1 1]] # PAD token: can't attend to anything
"""

Example mask for first sequence (input_len=2, padding at 4,5):
[[[[0 1 1 ... 1 1 1]
   [0 0 1 ... 1 1 1]
   [0 0 0 ... 1 1 1]
   ...
   [0 0 0 ... 1 1 1]
   [0 0 0 ... 1 1 1]
   [0 0 0 ... 1 1 1]]]


 [[[0 1 1 ... 1 1 1]
   [0 0 1 ... 1 1 1]
   [0 0 0 ... 1 1 1]
   ...
   [0 0 0 ... 1 1 1]
   [0 0 0 ... 1 1 1]
   [0 0 0 ... 1 1 1]]]]




"\nShould print something like:\n[[0 0 1 1 1 1]  # First input token: can attend to input only\n [0 0 1 1 1 1]  # Second input token: can attend to input only\n [0 0 0 1 1 1]  # First output token: can attend to previous, not future/padding\n [0 0 0 0 1 1]  # Second output token: can attend to previous, not future/padding\n [1 1 1 1 1 1]  # PAD token: can't attend to anything\n [1 1 1 1 1 1]] # PAD token: can't attend to anything\n"

In [9]:
# create masks

# padding mask (batch_size,seq_len, 1) applied to (batch_size, seq_len, vocab_len)
# input seq of the shape (batch_size, seq_len)
# def create_output_padding_mask(seq):
#     return (seq == vocab['[PAD]'])[:,:,None]

# the output mask, of the shape (batch_size, seq_len, vocab_len)
def create_output_mask(seq, input_lens, padding_mask):
    batch_size, seq_len = seq.shape
    
    # Create position indices tensor
    positions = torch.arange(seq_len, device=device)[None, :]  # (1, seq_len)
    positions = positions.expand(batch_size, -1)  # (batch_size, seq_len)
    
    # Create mask using input lengths
    input_lens = torch.tensor(input_lens, device=seq.device)[:, None]  # (batch_size, 1)
    mask = positions < input_lens  # True for positions after input length

    # Add padding mask
    mask = mask | padding_mask.squeeze()  # Combine with padding mask
    
    # Add broadcasting dimensions for attention
    return mask[:,:]  # (batch_size, seq_len, vocab_len)

# Example usage:
padding_mask_test = create_padding_mask(input_test)
mask_test = create_output_mask(input_test, input_lens[:batch_size], padding_mask_test)

print("Example output mask for first sequence (input_len=2, padding at 4,5):")
print(mask_test.squeeze().int().numpy())

Example output mask for first sequence (input_len=2, padding at 4,5):
[[1 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]




In [13]:
# attention mechanism
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        # dimension of each head as well as 
        # dq, dk, dv
        self.d_head = d_model//num_heads 

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, batch_size, x):
        x = x.view(batch_size, self.num_heads, -1, self.d_head)
        # (batch_size, num_head, seq_len, d_head)
        return x 
    
    def forward(self, x, mask):
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)
        # (batch_size, seq_len, d_model)

        batch_size = q.shape[0]
        q = self.split_heads(batch_size, q)
        k = self.split_heads(batch_size, k)
        v = self.split_heads(batch_size, v)
        # (batch_size, num_head, seq_len, d_head)
        
        attention = torch.einsum('ijml,ijnl->ijmn',q,k) / self.d_head**0.5
        # (batch_size, num_head, seq_lenq, seq_lenk)

        if mask is not None:
            # print("unmasked attention:")
            # print(attention[:,1,:,:])
            # print("=="*100)
            attention = attention.masked_fill(mask == 1, float('-inf'))
            # print("masked attention:")
            # for i in range(attention.shape[2]):
            #     print(attention[:,1,i,:])
            # print("=="*100)


        # apply softmax to the dimension corresponding to k
        attention = torch.nn.functional.softmax(attention, dim = -1)
        # print("softmax attention:")
        # for i in range(attention.shape[2]):
        #     print(attention[:,1,i,:])
        # print("=="*100)
        

        out = torch.einsum('ijkl,ijlq->ijkq',attention, v)
        # (batch_size, num_head, seq_lenq, d_head)
        out = out.view(batch_size, -1, self.d_model)
        # (batch_size, seq_lenq, d_model)
        out = self.dense(out)

        return out

# MultiHeadAttention(32,4).forward(torch.randn(2, 10, 32), None).shape


In [14]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate):
        super(DecoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.LN1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.ReLU(),
            nn.Linear(4*d_model, d_model)
        )
        self.LN2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self,x,mask):
        out1 = self.mha(x,mask)
        out1 = self.dropout1(out1)
        out1 = self.LN1(x + out1)

        out2 = self.ffn(out1)
        out2 = self.dropout2(out2)
        out2 = self.LN2(out1 + out2)

        return out2 # (batch_size, seq_len, d_model)

# print(DecoderLayer(32,4,0.1)(torch.randn(2, 10, 32), None).shape)

In [15]:
class Decoder(nn.Module):
    def __init__(self, vocab_len, num_layers, d_model, num_heads, dropout_rate):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_len, d_model)

        self.decoderlayers = nn.ModuleList([DecoderLayer(d_model, num_heads, dropout_rate) for _ in range(num_layers)])

        self.dense = nn.Linear(d_model, vocab_len)

    def create_mask(self, x):
        mask = (x == 0).unsqueeze(1).unsqueeze(2)
        return mask
        
    def forward(self, x, mask):
        # x: (batch_size, seq_len)
        x = self.embedding(x)
        x += pos_encodings(x.shape[1], self.d_model)

        # print("Embedding stats:", 
        #   f"mean={x.mean().item():.4f}, "
        #   f"max={x.max().item():.4f}")
        for i, decoderlayer in enumerate(self.decoderlayers):
            x = decoderlayer(x, mask)
            # print(f"Layer {i} output:", 
            #   f"mean={x.mean().item():.4f}, "
            #   f"max={x.max().item():.4f}")
        x = self.dense(x)
        return x

# vocab_len_test = 100
# batch_size_test = 64
# seq_len_test = 20
# print(Decoder(vocab_len_test, 6, 32, 4, 0.1)(torch.randint(0, vocab_len_test, (batch_size_test, seq_len_test)) , None).shape)

In [16]:
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# create masks
print("Creating masks...")
padding_mask = create_padding_mask(input)
output_mask = create_output_mask(input, input_lens, padding_mask)
combined_mask = create_combined_mask(seq_len, input_lens, padding_mask)
print("Done creating masks, preparing data...")

# train, dev, test split
X_train, X_test, y_train, y_test, output_mask_train, output_mask_test, combined_mask_train, combined_mask_test = train_test_split(input, output, output_mask, combined_mask, test_size=0.05, random_state=1)
X_train, X_dev, y_train, y_dev, output_mask_train, output_mask_dev, combined_mask_train, combined_mask_dev = train_test_split(X_train, y_train, output_mask_train, combined_mask_train, test_size=0.05, random_state=2)


# prepare data
train_loader = DataLoader(TensorDataset(X_train, y_train, combined_mask_train, output_mask_train), batch_size=32, shuffle=False)
dev_loader = DataLoader(TensorDataset(X_dev, y_dev, combined_mask_dev, output_mask_dev), batch_size=X_dev.size(0), shuffle=False)
test_loader = DataLoader(TensorDataset(X_test, y_test, combined_mask_test, output_mask_test), batch_size=X_test.size(0), shuffle=False)

# dev_loader = DataLoader(TensorDataset(X_dev, combined_mask_dev, output_mask_dev), batch_size=X_dev.size(0), shuffle=False)
# test_loader = DataLoader(TensorDataset(X_test, combined_mask_test, output_mask_test), batch_size=X_test.size(0), shuffle=False)
print("Done preparing data")

Creating masks...




Done creating masks, preparing data...
Done preparing data


In [None]:
# loss calculation
def compute_masked_loss(preds, output, output_mask):
    
    # Calculate per-position loss
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    # preds from: (batch_size, seq_len, vocab_size)
    #         to: (batch_size * seq_len, vocab_size)
    # output from: (batch_size, seq_len)
    #         to: (batch_size * seq_len)
    loss = loss_fn(preds.view(-1, preds.size(-1)), 
                  output.view(-1))
    # # Debug prints
    # if torch.isnan(loss).any():
    #     print("NaN in per-position loss!")
    #     print("Max logit value:", torch.max(preds))
    #     print("Min logit value:", torch.min(preds))
    #     print("Any inf in logits:", torch.isinf(preds).any())

    # loss shape from: (batch_size * seq_len)
    #              to: (batch_size, seq_len)
    loss = loss.view(output.shape)
    
    # Inverse mask for loss calculation (1 - mask gives 1 (True) for positions we want to mask out)
    loss = loss * ~output_mask
    final_loss = loss.sum() / (~output_mask).sum()

    if torch.isnan(final_loss):
        print("NaN in final loss!")
        print("Sum of losses:", loss.sum())
        print("Number of valid positions:", (~output_mask).sum())
    
    # Average loss over non-masked positions
    return final_loss

def compute_masked_precision(preds, output, output_mask):
    # Get predictions with highest probability
    preds = preds.argmax(dim=-1)
    
    # Count correct predictions
    correct = (preds == output) & ~output_mask
    
    # Calculate precision
    return correct.sum().float() / (~output_mask).sum()


# test
preds = torch.randn(64, 20, 100)
input_test = torch.randint(0, 100, (64, 20))
output_mask_test = create_output_mask(input_test, input_lens[:64], create_padding_mask(input_test))
compute_masked_loss(preds, input_test, output_mask_test)





tensor(5.0728)

In [None]:
# def train(model, train_loader, dev_loader, num_epochs, patience, lr):
vocab_len = len(vocab)
num_layers = 4 
d_model = 32 
num_heads = 4 
dropout_rate = 0.1

num_epochs=10
patience=2
lr=0.00001

torch.manual_seed(42)
model = Decoder(vocab_len, num_layers, d_model, num_heads, dropout_rate)

model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-4)


best_precision = 0
num_bad_epochs = 0
for epoch in range(num_epochs):
    lowest_loss = np.inf
    running_loss = 0.0
    beta = 0.95
    for batch_id, (X_batch, y_batch, combined_mask_batch, output_mask_batch) in enumerate(train_loader):
        model.train()
        # batch_index = 0
        # if batch_id != batch_index:
        #     continue
        # print(f"running Batch {batch_id}")
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        combined_mask_batch, output_mask_batch = combined_mask_batch.to(device), output_mask_batch.to(device)

        # forward pass
        y_pred = model(X_batch, combined_mask_batch)

        # inspect y_pred
        # print(y_pred.shape) # (batch_size, seq_len, vocab_len)
        # for i in range(y_pred.shape[0]):
        #     for pos in range(y_pred.shape[1]):
        #         for j in range(y_pred.shape[2]):
        #             # if see NaN in any entry of y_pred
        #             if torch.isnan(y_pred[i][pos][j]):
        #                 pass
        #                 print(f"Encountered NaN at item {i}, position {pos}, vocab_pos {j}")
                        
        # apply mask to the output
        loss = compute_masked_loss(y_pred, y_batch, output_mask_batch)

        # Check for NaNs in loss
        if torch.isnan(loss).any():
            print(f"NaN detected in loss at epoch {epoch}, batch {batch_id}")
            # stop training
            raise ValueError("NaN detected in loss")

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # Check for exploding gradients
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN detected in gradients of {name} at epoch {epoch}, batch {batch_id}")
                raise ValueError("NaN detected in gradients")
            if param.grad is not None and torch.isinf(param.grad).any():
                print(f"Inf detected in gradients of {name} at epoch {epoch}, batch {batch_id}")
                raise ValueError("Inf detected in gradients")

        optimizer.step()

        running_loss = beta * running_loss + (1 - beta) * loss.item()

        if batch_id % 1250 == 0:
            # Validation loss
            model.eval()
            with torch.no_grad():
                for (X_dev, y_dev, combined_mask_dev, output_mask_dev) in dev_loader:
                    X_dev = X_dev.to(device)
                    combined_mask_dev, output_mask_dev = combined_mask_dev.to(device), output_mask_dev.to(device)
                    y_pred_dev = model(X_dev, combined_mask_dev)
                    loss_dev = compute_masked_loss(y_pred_dev, y_dev, output_mask_dev)
                    precision_dev = compute_masked_precision(y_pred_dev, y_dev, output_mask_dev)

            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_id+1}/{len(train_loader)}, Loss: {loss.item()}, running loss: {running_loss}, Validation Loss: {loss_dev}, Validation Precision: {precision_dev}')

            if loss_dev < lowest_loss:
                lowest_loss = loss_dev
                torch.save(model.state_dict(), 'best_model.pth')
                num_bad_batch = 0
            else:
                num_bad_batch += 1

        # Early stopping
        if num_bad_epochs >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break   

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss}, Validation Loss: {loss_dev}, Validation Precision: {precision_dev}')
        

Epoch 1/10, Batch 1/28204, Loss: 3.8977017402648926, running loss: 0.1948850870132448, Validation Loss: 3.876202344894409, Validation Precision: 0.033285290002822876
Epoch 1/10, Batch 1251/28204, Loss: 2.904059648513794, running loss: 2.875323553052806, Validation Loss: 2.842707633972168, Validation Precision: 0.2787513732910156
Epoch 1/10, Batch 2501/28204, Loss: 2.527920722961426, running loss: 2.5524608274797025, Validation Loss: 2.5250394344329834, Validation Precision: 0.30743902921676636
Epoch 1/10, Batch 3751/28204, Loss: 2.3542659282684326, running loss: 2.3735301598173173, Validation Loss: 2.3326737880706787, Validation Precision: 0.324112206697464
Epoch 1/10, Batch 5001/28204, Loss: 2.219581127166748, running loss: 2.21100564833224, Validation Loss: 2.1914734840393066, Validation Precision: 0.3495895266532898
Epoch 1/10, Batch 6251/28204, Loss: 2.1205790042877197, running loss: 2.114795488761512, Validation Loss: 2.0874600410461426, Validation Precision: 0.37365368008613586
E

In [None]:
# test on test set
def inference(model, test_loader):
    model.eval()
    with torch.no_grad():
        for (X_test, combined_mask_test, output_mask_test) in test_loader:
            X_test = X_test.to(device)
            combined_mask_test, output_mask_test = combined_mask_test.to(device), output_mask_test.to(device)
            y_pred_test = model(X_test, combined_mask_test)
            loss_test = compute_masked_loss(y_pred_test, X_test, output_mask_test)
            precision_test = compute_masked_precision(y_pred_test, X_test, output_mask_test)

    print(f'Test Loss: {loss_test}, Test Precision: {precision_test}')
 

In [None]:
# use model for inference
def get_derivative(string):
    # convert string to list of tokens
    tokens = add_space(string)
    # convert tokens to tensor
    input_seq = torch.tensor(vocab(tokens), dtype=torch.long).unsqueeze(0).to(device)
    # shape of input_seq: (1, seq_len)
    # create positional encoding
    pos_enc = pos_encodings(input_seq.shape[1], d_model).to(device)
    # shape of pos_enc: (1, seq_len, d_model)


    def inference(model, input_seq):
        model.eval()
        with torch.no_grad():
            input_seq = torch.tensor(vocab(input_seq), dtype=torch.long).unsqueeze(0).to(device)
            preds = model(input_seq, None)[0]
            next_token = preds.argmax(dim=-1).squeeze()
        
        output_seq = input_seq
        if next_token==vocab['<eos>']:
            # turn the output_seq to string
            output_seq = [vocab.itos[i] for i in output_seq]
            return output_seq   
        else:
            output_seq = torch.cat((output_seq, next_token), dim=1)
            return inference(model, output_seq)
