# Installation
You do not have to follow our installation instructions if you have roughly equivalent setups / environments already.

We will use Conda and Pip to help us install packages for this homework. If you do not have Miniconda or Anaconda, you can install Miniconda from here https://docs.conda.io/en/latest/miniconda.html.

```
conda create --name exercise2 python=3.7
conda activate exercise2

pip install jupyter pandas
```

Go to https://pytorch.org/ to install PyTorch if you don't have it already

To install the Hugging Face `transformers` library, run
```
pip install transformers
```

Follow the instructions from https://docs.dgl.ai/en/0.4.x/install/ to install Deep Graph Library (DGL).

Spin up jupyter notebook with
```
jupyter notebook
```

# Exercise
Our exercise is an implementation of the paper [Graph-to-Tree Learning for Solving Math Word Problems](https://www.aclweb.org/anthology/2020.acl-main.362.pdf), which solves math word problems in the MAWPS dataset. Please run `demo.ipynb` for some visualizations of the overall pipeline. We recommend that you read the original paper as well if necessary.

## Provided Components
1. We provide the entire input and output processing pipeline for you, as described in `demo.ipynb`.
2. We provide a fully implemented custom implementation of the transformer in the `TransformerBlock` class.
3. We provide a partially implemented graph convolutional network in the `GCN` class.
4. We provide a fully implemented tree decoding network in `TreeDecoder`. The tree decoding logic in `train` and `predict` is fully implemented.
5. If `use_t5 = 'small'`, `setup` will load in a pretrained `t5-small` into the variable `t5_model`.

## Tasks
Your tasks are
1. Use Deep Graph Library (DGL) to complete the graph convolution network. Keep `use_t5 = None` for this part. The baseline performance obtained by the TA is 0.74 validation value accuracy. Report validation accuracies for at least 5 sets of hyperparameters.
2. Use Hugging Face `transformers` library to replace the custom transformer base model with a pre-trained `t5-small` model. Set `use_t5 = 'small'` for this part. The baseline performance obtained by the TA is 0.78 validation value accuracy. Report validation accuracies for at least 5 sets of hyperparameters.
3. Change any part of the code (e.g. hyperparameter, architecture, training data, etc.) to optimize the performance of at least one of the transformers (custom or T5) to be better than the baseline performance.

Note that for parts 1 and 2, you should not change any provided component at all. For part 3, you may change any part of the code.

For each of these parts, please run all the cells until and including the "Training Loop" cell below (they are all definition cells except the actual training loop cell). When you run the training loop without completing some required part of the exercise, the code will throw a `NotImplementedError`; you should fill out the required code to fix this error, then run the training loop again.

In [1]:
import os
import sys
import itertools
from copy import copy


from tqdm import tqdm, trange


import numpy as np


import pandas as pd


import torch
import torch.nn as nn
import torch.nn.functional as F


import dgl
from dgl.nn import GraphConv


import transformers


from util import setup, check_match, sub_nP, evaluate_prefix_expression

Using backend: pytorch


In [2]:
# check versions
print(sys.version)
print(pd.__version__)
print(np.__version__)
print(torch.__version__)
print(dgl.__version__)
print(transformers.__version__)

3.7.5 (default, Nov  7 2019, 10:50:52) 
[GCC 8.3.0]
1.1.4
1.19.2
1.6.0+cpu
0.5.2
3.5.0


# Converting Inputs to Torch Tensors

In [3]:
def tensorize_data(data):
    """
    Collect tensors to build the input data for the model
    """
    for d in data:
        # Indices of the in_tokens in the in_vocab
        d['in_idxs'] = torch.tensor([in_vocab.token2idx.get(x, in_vocab.unk) for x in d['in_tokens']])
        d['n_in'] = n_in = len(d['in_idxs'])
        d['n_nP'] = n_nP = len(d['nP'])
        # True if the position in the input has a quantity
        d['nP_in_mask'] = mask = torch.zeros(n_in, dtype=torch.bool)
        mask[d['nP_positions']] = True
        if 'out_tokens' in d:
            # Indices of the out_tokens in the out_vocab
            d['out_idxs'] = torch.tensor([out_vocab.token2idx.get(x, out_vocab.unk) for x in d['out_tokens']])
            d['n_out'] = len(d['out_idxs'])
            # A mask where the first n_nP elements are True
            d['nP_out_mask'] = mask = torch.zeros(n_max_nP, dtype=torch.bool)
            mask[:n_nP] = True
        # Graph edges for constructing the DGL graph later
        d['qcomp_edges'] = get_quantity_comparison_edges(d)
        d['qcell_edges'] = get_quantity_cell_edges(d)

def get_quantity_comparison_edges(d):
    """
    Fill out an adjacency matrix representing quantity comparisons, then convert to list of edges
    """
    quants = [float(x) for x in d['nP']]
    quant_positions = d['nP_positions']
    assert max(quant_positions) < d['n_in']
    adj_matrix = torch.eye(d['n_in'], dtype=np.bool)
    for x, x_pos in zip(quants, quant_positions):
        for y, y_pos in zip(quants, quant_positions):
            adj_matrix[x_pos, y_pos] |= x > y
    """
    Convert the adjacency matrix of the directed graph into a tuple of (src_edges, dst_edges), which
    is the input format of dgl.graph (see https://docs.dgl.ai/generated/dgl.graph.html).
    Hint: check out the 'nonzero' function
    """
    ### Your code here (done) ###
    # This takes care of potentially bidirectional edges
    tmp = np.nonzero(adj_matrix)
    return tmp[0], tmp[1]

def get_quantity_cell_edges(d):
    """
    Fill out an adjacency matrix representing the quantity cell graph, then convert to list of edges
    """
    in_idxs = d['in_idxs']
    quant_positions = d['nP_positions']
    quant_cell_positions = d['quant_cell_positions']
    assert max(quant_cell_positions) < d['n_in']
    word_cells = set(quant_cell_positions) - set(quant_positions)
    adj_matrix = torch.eye(d['n_in'], dtype=torch.bool)
    for w_pos in word_cells:
        for q_pos in quant_positions:
            if abs(w_pos - q_pos) < 4:
                adj_matrix[w_pos, q_pos] = adj_matrix[q_pos, w_pos] = True
    pos_idxs = in_idxs[quant_cell_positions]
    for idx1, pos1 in zip(pos_idxs, quant_cell_positions):
        for idx2, pos2 in zip(pos_idxs, quant_cell_positions):
            if idx1 == idx2:
                adj_matrix[pos1, pos2] = adj_matrix[pos2, pos1] = True
    """
    Convert the adjacency matrix of the directed graph into a tuple of (src_edges, dst_edges), which
    is the input format of dgl.graph (see https://docs.dgl.ai/generated/dgl.graph.html).
    Hint: check out the 'nonzero' function
    """
    ### Your code here (done) ###
    # This takes care of potentially bidirectional edges
    tmp = np.nonzero(adj_matrix)
    return tmp[0], tmp[1]

# Model

In [4]:
class TransformerAttention(nn.Module):
    """
    Used in Transformer Block, implements the dot-product attention
    """
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(n_hid, n_head * (n_k * 2 + n_v))
        self.out = nn.Linear(n_head * n_v, n_hid)

    def forward(self, x, mask=None):
        n_batch, n_batch_max_in, n_hid = x.shape
        q_k_v = self.qkv(x).view(n_batch, n_batch_max_in, n_head, 2 * n_k + n_v).transpose(1, 2)
        q, k, v = q_k_v.split([n_k, n_k, n_v], dim=-1)

        q = q.reshape(n_batch * n_head, n_batch_max_in, n_k)
        k = k.reshape_as(q).transpose(1, 2)
        qk = q.bmm(k) / np.sqrt(n_k)

        if mask is not None:
            qk = qk.view(n_batch, n_head, n_batch_max_in, n_batch_max_in).transpose(1, 2)
            qk[~mask] = -np.inf
            qk = qk.transpose(1, 2).view(n_batch * n_head, n_batch_max_in, n_batch_max_in)
        qk = qk.softmax(dim=-1)
        v = v.reshape(n_batch * n_head, n_batch_max_in, n_v)
        qkv = qk.bmm(v).view(
            n_batch, n_head, n_batch_max_in, n_v
        ).transpose(1, 2).reshape(n_batch, n_batch_max_in, n_head * n_v)
        out = self.out(qkv)
        return x + out

class TransformerBlock(nn.Module):
    """
    Custom Transformer
    """
    def __init__(self):
        super().__init__()
        self.attn = TransformerAttention()
        n_inner = n_hid * 4
        self.inner = nn.Sequential(
            nn.Linear(n_hid, n_inner),
            nn.ReLU(inplace=True),
            nn.Linear(n_inner, n_hid)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(x, mask=mask)
        return x + self.inner(x)
    
class GCNBranch(nn.Module):
    def __init__(self, n_hid_in, n_hid_out, dropout=0.3):
        super().__init__()
        """
        Define a branch of the graph convolution with
        1. GraphConv from n_hid_in to n_hid_in
        2. ReLU
        3. Dropout
        4. GraphConv from n_hid_in to n_hid_out
        
        Note: your should call dgl.nn.GraphConv with allow_zero_in_degree=True
        """
        ### Your code here (done) ###
        self._dropout_p = dropout
        self._gc1 = GraphConv(n_hid_in, n_hid_in, allow_zero_in_degree=True)
        self._gc2 = GraphConv(n_hid_in, n_hid_out, allow_zero_in_degree=True)

    def forward(self, x, graph):
        """
        Forward pass of your defined branch above
        """
        ### Your code here (done) ###
        x = self._gc1(graph, x)
        x = F.relu(x)
        x = F.dropout(x, p=self._dropout_p, training=self.training)
        x = self._gc2(graph, x)
        return x

class GCN(nn.Module):
    """
    A graph convolution network with multiple graph convolution branches
    """
    def __init__(self, n_head=4, dropout=0.3):
        super().__init__()
        self.branches = nn.ModuleList(GCNBranch(n_hid, n_hid // n_head, dropout) for _ in range(n_head))

        self.feed_forward = nn.Sequential(
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(n_hid, n_hid)
        )
        self.layer_norm = nn.LayerNorm(n_hid)

    def forward(self, h, gt_graph, attr_graph):
        x = h.reshape(-1, n_hid)
        graphs = [gt_graph, gt_graph, attr_graph, attr_graph]
        x = torch.cat([branch(x, g) for branch, g in zip(self.branches, graphs)], dim=-1).view_as(h)
        x = h + self.layer_norm(x)
        return x + self.feed_forward(x)

class Gate(nn.Module):
    """
    Activation gate used a few times in the TreeDecoder
    """
    def __init__(self, n_in, n_out):
        super(Gate, self).__init__()
        self.t = nn.Linear(n_in, n_out)
        self.s = nn.Linear(n_in, n_out)

    def forward(self, x):
        return self.t(x).tanh() * self.s(x).sigmoid()

class TreeDecoder(nn.Module):
    """
    Defines parameters and methods for decoding into an expression. Used in train and predict
    """
    def __init__(self, dropout=0.5):
        super().__init__()
        drop = nn.Dropout(dropout)
        self.constant_embedding = nn.Parameter(torch.randn(1, out_vocab.n_constants, n_hid))

        self.qp_gate = nn.Sequential(drop, Gate(n_hid, n_hid))
        self.gts_right = nn.Sequential(drop, Gate(2 * n_hid, n_hid))

        self.attn_fc = nn.Sequential(drop,
            nn.Linear(2 * n_hid, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1)
        )
        self.quant_fc = nn.Sequential(drop,
            nn.Linear(n_hid * 3, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1, bias=False)
        )
        self.op_fc = nn.Sequential(drop, nn.Linear(n_hid * 2, out_vocab.n_ops))

        self.op_embedding = nn.Embedding(out_vocab.n_ops + 1, n_hid, padding_idx=out_vocab.n_ops)
        self.gts_left = nn.Sequential(drop, Gate(n_hid * 2 + n_hid, n_hid))
        self.gts_left_qp = nn.Sequential(drop, Gate(n_hid * 2 + n_hid, n_hid), self.qp_gate)

        self.subtree_gate = nn.Sequential(drop, Gate(n_hid * 2 + n_hid, n_hid))

    def gts_attention(self, q, zbar, in_mask=None):
        """
        Corresponds roughly to the GTS-Attention function defined by the paper
        """
        attn_score = self.attn_fc(
            torch.cat([q.unsqueeze(1).expand_as(zbar), zbar], dim=2)
        ).squeeze(2)
        if in_mask is not None:
            attn_score[~in_mask] = -np.inf
        attn = attn_score.softmax(dim=1)
        return (attn.unsqueeze(1) @ zbar).squeeze(1) # (n_batch, n_hid)

    def gts_predict(self, qp_Gc, quant_embed, nP_out_mask=None):
        """
        Corresponds roughly to the GTS-Predict functions defined by the paper
        """
        quant_score = self.quant_fc(
            torch.cat([qp_Gc.unsqueeze(1).expand(-1, quant_embed.size(1), -1), quant_embed], dim=2)
        ).squeeze(2)
        op_score = self.op_fc(qp_Gc)
        pred_score = torch.cat((op_score, quant_score), dim=1)
        if nP_out_mask is not None:
            pred_score[:, out_vocab.base_nP:][~nP_out_mask] = -np.inf
        return pred_score

    def merge_subtree(self, op, tl, yr):
        """
        Corresponds to part of the GTS-Subtree function defined by the paper
        """
        return self.subtree_gate(torch.cat((op, tl, yr), dim=-1))

class Model(nn.Module):
    """
    Overall model containing all the neural network parameters and methods
    1. The base seq2seq model is in self.transformer_layers if use_t5=None else self.t5_encoder
    2. The graph convolution network is in self.gcn
    3. The tree decoder is in self.decoder
    """
    def __init__(self, dropout=0.5):
        super().__init__()
        drop = nn.Dropout(dropout)

        if use_t5:
            """
            Use t5_model.encoder as the encoder for this model. Note that unlike the custom transformer,
            you don't need to use an external input or positional embedding for the T5 transformer 
            (i.e. don't define self.in_embed or self.pos_emb) since it already defines them internally
            
            You may specify layer weights to freeze during finetuning by modifying the freeze_layers
            global variable.
            """
            ### Your code here (done) ###
            
            # should depend on n_hid... I think d_model is the best fit for this param
            t5_config = transformers.T5Config(d_model=n_hid)
            t5_model = transformers.T5Model(config=t5_config).from_pretrained(f"t5-{use_t5}")
            self.t5_encoder = t5_model.get_encoder()
            
            for i_layer, block in enumerate(self.t5_encoder.block):
                if i_layer in freeze_layers:
                    for param in block.parameters():
                        param.requires_grad = False
        else:
            # Input embedding for custom transformer
            self.in_embed = nn.Sequential(nn.Embedding(in_vocab.n, n_hid, padding_idx=in_vocab.pad), drop)
            # Positional embedding for custom transformer
            self.pos_embed = nn.Embedding(1 + n_max_in, n_hid) # Use the first position as global vector
            self.transformer_layers = nn.ModuleList(TransformerBlock() for _ in range(n_layers))

        self.gcn = GCN()

        self.decoder = TreeDecoder()

        if not use_t5:
            self.apply(self.init_weight)

    def init_weight(self, m):
        if type(m) in [nn.Embedding]:
            nn.init.normal_(m.weight, 0, 0.1)

    def encode(self, in_idxs, n_in, gt_graph, attr_graph, in_mask=None):
        in_idxs_pad = F.pad(in_idxs, (1, 0), value=in_vocab.pad)
        if use_t5:
            """
            Use your T5 encoder to encoder the input indices. Note that you do NOT need to use an
            input embedding or positional embedding (e.g. self.in_embed or self.pos_embed) for T5,
            since it already defines the embeddings internally
            """
            ### Your code here (done) ###
            # Please be right...
            h = self.t5_encoder.forward(input_ids=in_idxs_pad, attention_mask=in_mask)[0]
            
        else:
            x = self.in_embed(in_idxs_pad) # (n_batch, n_batch_max_in, n_hid)
            h = x + self.pos_embed(torch.arange(x.size(1), device=x.device))
            for layer in self.transformer_layers:
                h = layer(h, mask=in_mask)
        zg, h = h[:, 0], h[:, 1:]
        zbar = self.gcn(h, gt_graph, attr_graph)
        return zbar, zg

# Training a Batch

In [5]:
class Node:
    """
    Node for tree traversal during training
    """
    def __init__(self, up):
        self.up = up
        self.is_root = up is None
        self.left = self.right = None
        self.ql = self.tl = self.op = None

def train(batch, model, opt):
    """
    Compute the loss on a batch of inputs, and take a step with the optimizer
    """
    n_batch = len(batch)

    n_in = [d['n_in'] for d in batch]
    pad = lambda x, value: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=value)
    in_idxs = pad([d['in_idxs'] for d in batch], in_vocab.pad).to(device)
    in_mask = pad([torch.ones(n, dtype=torch.bool) for n in n_in], False).to(device)
    nP_in_mask = pad([d['nP_in_mask'] for d in batch], False).to(device)
    nP_out_mask = torch.stack([d['nP_out_mask'] for d in batch]).to(device)
    
    qcomp_graph, qcell_graph = [], []
    for d in batch:
        """
        Create qcomp_graph and qcell_graph from d['qcomp_edges'] and d['qcell_edges'] by calling dgl.graph
        (see https://docs.dgl.ai/generated/dgl.graph.html)

        Note that num_nodes needs to be set to the maximum input length in this batch
        """
        ### Your code here (done) ###
        qcomp_graph_i = dgl.graph(d["qcomp_edges"], num_nodes = max(n_in))
        qcell_graph_i = dgl.graph(d["qcell_edges"], num_nodes = max(n_in))
        
        qcomp_graph.append(qcomp_graph_i)
        qcell_graph.append(qcell_graph_i)
    qcomp_graph = dgl.batch(qcomp_graph)
    qcell_graph = dgl.batch(qcell_graph)
    
    label = pad([d['out_idxs'] for d in batch], out_vocab.pad)
    nP_candidates = [d['nP_candidates'] for d in batch]

    zbar, qroot = model.encode(in_idxs, n_in, qcomp_graph, qcell_graph, in_mask=None)
    z_nP = zbar.new_zeros((n_batch, n_max_nP, n_hid))
    z_nP[nP_out_mask] = zbar[nP_in_mask]

    decoder = model.decoder

    n_quant = out_vocab.n_constants + n_max_nP
    # (n_batch, n_quant, n_hid)
    quant_embed = torch.cat([decoder.constant_embedding.expand(n_batch, -1, -1), z_nP], dim=1)
    

    nodes = np.array([Node(None) for _ in range(n_batch)])
    op_min, op_max = out_vocab.base_op, out_vocab.base_op + out_vocab.n_ops
    quant_min, quant_max = out_vocab.base_quant, out_vocab.base_quant + n_quant

    # Initialize root node vector according to zg (the global context)
    qp = decoder.qp_gate(qroot)
    scores = []
    for i, label_i in enumerate(label.T): # Iterate over the output positions
        Gc = decoder.gts_attention(qp, zbar, in_mask)
        qp_Gc = torch.cat([qp, Gc], dim=1) # (n_batch, 2 * n_hid)

        score = decoder.gts_predict(qp_Gc, quant_embed, nP_out_mask)
        scores.append(score)

        # Whether the label is an operator
        is_op = (op_min <= label_i) & (label_i < op_max)
        # Whether the label is a quantity
        is_quant = ((quant_min <= label_i) & (label_i < quant_max)) | (label_i == out_vocab.unk)

        op_embed = decoder.op_embedding((label_i[is_op] - out_vocab.base_op).to(device))
        qp_Gc_op = torch.cat([qp_Gc[is_op], op_embed], dim=1)

        is_left = np.zeros(n_batch, dtype=np.bool)
        qleft_qp = decoder.gts_left_qp(qp_Gc_op)
        qleft = decoder.gts_left(qp_Gc_op)
        for j, ql, op in zip(is_op.nonzero(as_tuple=True)[0], qleft, op_embed):
            node = nodes[j]
            nodes[j] = node.left = Node(node)
            node.op = op
            node.ql = ql
            is_left[j] = True

        is_right = np.zeros(n_batch, dtype=np.bool)
        nP_score = score[:, out_vocab.base_nP:].detach().cpu()
        ql_tl = []
        for j in is_quant.nonzero(as_tuple=True)[0]:
            if label_i[j] == out_vocab.unk:
                candidates = nP_candidates[j][i]
                label_i[j] = out_vocab.base_nP + candidates[nP_score[j, candidates].argmax()]

            node = nodes[j]
            pnode = node.up
            t = quant_embed[j, label_i[j] - out_vocab.base_quant]
            while pnode and pnode.right is node:
                # merge operator, left subtree, and right child
                t = decoder.merge_subtree(pnode.op, pnode.tl, t)
                node, pnode = pnode, pnode.up # backtrack to parent node
            if pnode is None: # Finished traversing tree of j
                continue
            # Now pnode.left is node. t is the tl representing the left subtree of pnode
            pnode.tl = t
            ql_tl.append(torch.cat([pnode.ql, pnode.tl])) # For computing qright
            nodes[j] = pnode.right = Node(pnode)
            is_right[j] = True

        qp = torch.zeros((n_batch, n_hid), device=device)
        qp[is_left] = qleft_qp
        if ql_tl:
            qp[is_right] = decoder.gts_right(torch.stack(ql_tl))

    label = label.to(device).view(-1)
    scores = torch.stack(scores, dim=1).view(-1, out_vocab.n_ops + n_quant)
    loss = F.cross_entropy(scores, label, ignore_index=out_vocab.pad)

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

# Prediction (for Evaluation)

In [6]:
class BeamNode(Node):
    """
    Node for beam search during evaluation
    """
    def __init__(self, up, prev, qp, token=None):
        super().__init__(up)
        self.prev = prev
        self.qp = qp
        self.token = token

    def trace_tokens(self, *last_token):
        if self.prev is None:
            return list(last_token)
        tokens = self.prev.trace_tokens()
        tokens.append(self.token)
        tokens.extend(last_token)
        return tokens

def predict(d, model, beam_size=5, n_max_out=45):
    """
    Predict the idxs corresponding to an expression given the inputs. Leverages beam search to maximize
    prediction probability
    """
    in_idxs = d['in_idxs'].unsqueeze(0).to(device=device)
    """
    Create qcomp_graph and qcell_graph from d['qcomp_edges'] and d['qcell_edges'] by calling dgl.graph
    (see https://docs.dgl.ai/generated/dgl.graph.html)
    """
    ### Your code here (done) ###
    qcomp_graph = dgl.graph(d["qcomp_edges"], num_nodes = d["n_in"])
    qcell_graph = dgl.graph(d["qcell_edges"], num_nodes = d["n_in"])

    zbar, qroot = model.encode(in_idxs, [d['n_in']], qcomp_graph, qcell_graph)
    z_nP = zbar[:, d['nP_positions']]

    decoder = model.decoder

    quant_embed = torch.cat([decoder.constant_embedding, z_nP], dim=1) # (1, n_quant, n_hid)
    op_min, op_max = out_vocab.base_op, out_vocab.base_op + out_vocab.n_ops

    best_done_beam = (-np.inf, None, None)
    beams = [(0, BeamNode(up=None, prev=None, qp=decoder.qp_gate(qroot)))]
    for _ in range(n_max_out):
        new_beams = []
        for logp_prev, node in beams:
            Gc = decoder.gts_attention(node.qp, zbar)
            qp_Gc = torch.cat([node.qp, Gc], dim=1) # (2 * n_hid,)

            log_prob = decoder.gts_predict(qp_Gc, quant_embed).log_softmax(dim=1)
            top_logps, top_tokens = log_prob.topk(beam_size, dim=1)
            for logp_token_, out_token_ in zip(top_logps.unbind(dim=1), top_tokens.unbind(dim=1)):
                out_token = out_token_.item()
                logp = logp_prev + logp_token_.item()
                if op_min <= out_token < op_max:
                    op_embed = decoder.op_embedding(out_token_)
                    qp_Gc_op = torch.cat([qp_Gc, op_embed], dim=1)
                    prev_node = copy(node)
                    next_node = prev_node.left = BeamNode(
                        up=prev_node, prev=prev_node,
                        qp=decoder.gts_left_qp(qp_Gc_op),
                        token=out_token
                    )
                    prev_node.op = op_embed
                    prev_node.ql = decoder.gts_left(qp_Gc_op)
                else:
                    pnode, prev_node = node.up, node
                    t = quant_embed[:, out_token - out_vocab.base_quant]
                    while pnode and pnode.tl is not None:
                        t = decoder.merge_subtree(pnode.op, pnode.tl, t)
                        node, pnode = pnode, pnode.up
                    if pnode is None:
                        best_done_beam = max(best_done_beam, (logp, prev_node, out_token))
                        continue
                    pnode = copy(pnode)
                    pnode.tl = t
                    next_node = pnode.right = BeamNode(
                        up=pnode, prev=prev_node,
                        qp=decoder.gts_right(torch.cat([pnode.ql, pnode.tl], dim=1)),
                        token=out_token
                    )
                new_beams.append((logp, next_node))
        beams = sorted(new_beams, key=lambda x: x[0], reverse=True)[:beam_size]
        done_logp, done_node, done_last_token = best_done_beam
        if not len(beams) or done_logp >= beams[0][0]:
            break
    return done_node.trace_tokens(done_last_token)

# Training Loop

We provide the training loop below. When you change the hyperparameters, make sure you keep track of which hyperparameters you were using, because you'll need those parameters again during prediction (see next section). Note that if you make multiple runs with the same `use_t5` value, the saved models will be overwritten, so make sure to copy your the `model_save_dir` somewhere else if you want to save it.

Currently, the custom transformer gives us the highest performance out of the bunch. There is some degree of overfitting here, but it seems to be more drastic for `t5-small` and `t5-base`. More regularization is required if we decide to move forward with `t5-small` (training `t5-base` is too time consuming to really be viable). 

## Training custom transformer

In [7]:
# Define parameters for model
curr_run_name = "second-run"

# Value should be None, 'small', or 'base'
use_t5 = None
# use_t5 = "small"

# IMPORTANT NOTE: if you change some of these hyperparameters during training,
# you will also need to change them during prediction (see next section)

n_max_in = 100
n_epochs = 150
n_batch = 64
learning_rate = 1e-3
if use_t5:
    # T5 hyperparameters
    freeze_layers = []
    weight_decay = 1e-5
    # Do not modify unless you want to try t5-large
    n_hid = dict(small=512, base=768)[use_t5] 
else:
    # Custom transformer hyperparameters
    n_layers = 3
    n_hid = 512
    n_k = n_v = 64
    n_head = 8
    weight_decay = 0
    
# For evaluation/prediction
saved_model_name = "model-best.pth"

# Defining what to do
TRAIN=True
EVALUATION=False
PREDICTION=False

# Defining some useful variables and doing some useful tasks for later
model_save_dir = f'models/{use_t5 or "custom"}-{curr_run_name}'
model_save_path = f'models/{use_t5 or "custom"}-{curr_run_name}/{saved_model_name}'
predictions_save_path = f'models/{use_t5 or "custom"}-{curr_run_name}/predictions.csv'

os.makedirs(model_save_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training loop

if TRAIN:
    
    # Data
    train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
    tensorize_data(itertools.chain(train_data, val_data))

    # Model
    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
    model.to(device)

    # Train
    epoch = 0
    best_acc = 0
    while epoch < n_epochs:
        
        # Train for an epoch
        print('Epoch:', epoch + 1)
        model.train()
        losses = []
        for start in trange(0, len(train_data), n_batch):
            batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
            loss = train(batch, model, opt)
            losses.append(loss)
        scheduler.step()
        print(f'Training loss: {np.mean(losses):.3g}')

        # Evaluate after every epoch of training
        model.eval()
        value_match, equation_match = [], []
        with torch.no_grad():
            for d in tqdm(val_data):
                # This method is not equiped to handle equations with quadratics
                if d['is_quadratic']: 
                    val_match = eq_match = False
                else:
                    pred = predict(d, model)
                    d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                    val_match, eq_match = check_match(pred, d)
                value_match.append(val_match)
                equation_match.append(eq_match)
        curr_expr_acc = np.mean(equation_match)
        curr_value_acc = np.mean(value_match)
        print(f'Validation expression accuracy: {curr_expr_acc:.3g}')
        print(f'Validation value accuracy: {curr_value_acc:.3g}')
        
        # Save if best
        if curr_value_acc > best_acc:
            best_acc = curr_value_acc
            print(">>>>> Found best model so far <<<<<")
            torch.save(model.state_dict(), os.path.join(model_save_dir, 'model-best.pth'))
            
        print()
        epoch += 1        

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  return bound(*args, **kwds)
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch: 1


100%|██████████| 30/30 [02:06<00:00,  4.21s/it]
  0%|          | 1/213 [00:00<00:23,  9.14it/s]

Training loss: 2.58


100%|██████████| 213/213 [00:28<00:00,  7.43it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.0188
Validation value accuracy: 0.0188
>>>>> Found best model so far <<<<<

Epoch: 2


100%|██████████| 30/30 [01:57<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:31,  6.70it/s]

Training loss: 1.5


100%|██████████| 213/213 [00:31<00:00,  6.69it/s]


Validation expression accuracy: 0.117
Validation value accuracy: 0.122
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 3


100%|██████████| 30/30 [02:10<00:00,  4.36s/it]
  0%|          | 1/213 [00:00<00:29,  7.30it/s]

Training loss: 1.28


100%|██████████| 213/213 [00:24<00:00,  8.53it/s]


Validation expression accuracy: 0.23
Validation value accuracy: 0.235
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 4


100%|██████████| 30/30 [02:05<00:00,  4.19s/it]
  0%|          | 1/213 [00:00<00:33,  6.30it/s]

Training loss: 1.03


100%|██████████| 213/213 [00:27<00:00,  7.69it/s]


Validation expression accuracy: 0.272
Validation value accuracy: 0.272
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 5


100%|██████████| 30/30 [01:52<00:00,  3.73s/it]
  0%|          | 1/213 [00:00<00:29,  7.07it/s]

Training loss: 0.89


100%|██████████| 213/213 [00:20<00:00, 10.23it/s]


Validation expression accuracy: 0.385
Validation value accuracy: 0.39
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 6


100%|██████████| 30/30 [01:50<00:00,  3.69s/it]
  0%|          | 1/213 [00:00<00:31,  6.73it/s]

Training loss: 0.792


  return fn(arg1, arg2), end
  return fn(arg1, arg2), end
100%|██████████| 213/213 [00:26<00:00,  8.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.347
Validation value accuracy: 0.366

Epoch: 7


100%|██████████| 30/30 [01:50<00:00,  3.69s/it]
  0%|          | 1/213 [00:00<00:29,  7.14it/s]

Training loss: 0.75


100%|██████████| 213/213 [00:26<00:00,  7.97it/s]


Validation expression accuracy: 0.521
Validation value accuracy: 0.531
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 8


100%|██████████| 30/30 [01:54<00:00,  3.82s/it]
  0%|          | 1/213 [00:00<00:30,  6.92it/s]

Training loss: 0.65


100%|██████████| 213/213 [00:24<00:00,  8.60it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.521
Validation value accuracy: 0.531

Epoch: 9


100%|██████████| 30/30 [01:53<00:00,  3.80s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.598


100%|██████████| 213/213 [00:24<00:00,  8.75it/s]


Validation expression accuracy: 0.54
Validation value accuracy: 0.54
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 10


100%|██████████| 30/30 [01:56<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:28,  7.32it/s]

Training loss: 0.533


100%|██████████| 213/213 [00:20<00:00, 10.48it/s]


Validation expression accuracy: 0.615
Validation value accuracy: 0.62
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 11


100%|██████████| 30/30 [01:49<00:00,  3.64s/it]
  0%|          | 1/213 [00:00<00:28,  7.48it/s]

Training loss: 0.495


100%|██████████| 213/213 [00:22<00:00,  9.41it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.587
Validation value accuracy: 0.596

Epoch: 12


100%|██████████| 30/30 [01:51<00:00,  3.71s/it]
  0%|          | 1/213 [00:00<00:26,  8.15it/s]

Training loss: 0.445


100%|██████████| 213/213 [00:18<00:00, 11.72it/s]


Validation expression accuracy: 0.676
Validation value accuracy: 0.685
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 13


100%|██████████| 30/30 [01:59<00:00,  4.00s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.399


100%|██████████| 213/213 [00:24<00:00,  8.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.667
Validation value accuracy: 0.676

Epoch: 14


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:27,  7.76it/s]

Training loss: 0.374


100%|██████████| 213/213 [00:21<00:00,  9.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.662

Epoch: 15


100%|██████████| 30/30 [01:53<00:00,  3.79s/it]
  0%|          | 1/213 [00:00<00:31,  6.65it/s]

Training loss: 0.374


100%|██████████| 213/213 [00:21<00:00,  9.86it/s]


Validation expression accuracy: 0.676
Validation value accuracy: 0.69
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 16


100%|██████████| 30/30 [02:00<00:00,  4.01s/it]
  0%|          | 1/213 [00:00<00:32,  6.61it/s]

Training loss: 0.339


100%|██████████| 213/213 [00:24<00:00,  8.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.657

Epoch: 17


100%|██████████| 30/30 [01:59<00:00,  3.98s/it]
  0%|          | 1/213 [00:00<00:32,  6.44it/s]

Training loss: 0.332


100%|██████████| 213/213 [00:28<00:00,  7.49it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.634
Validation value accuracy: 0.643

Epoch: 18


100%|██████████| 30/30 [02:02<00:00,  4.09s/it]
  0%|          | 1/213 [00:00<00:29,  7.14it/s]

Training loss: 0.321


100%|██████████| 213/213 [00:20<00:00, 10.49it/s]


Validation expression accuracy: 0.676
Validation value accuracy: 0.695
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 19


100%|██████████| 30/30 [01:49<00:00,  3.64s/it]
  0%|          | 1/213 [00:00<00:30,  7.06it/s]

Training loss: 0.297


100%|██████████| 213/213 [00:20<00:00, 10.39it/s]


Validation expression accuracy: 0.681
Validation value accuracy: 0.7
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 20


100%|██████████| 30/30 [01:49<00:00,  3.64s/it]
  0%|          | 1/213 [00:00<00:30,  6.95it/s]

Training loss: 0.282


100%|██████████| 213/213 [00:20<00:00, 10.46it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.69

Epoch: 21


100%|██████████| 30/30 [01:48<00:00,  3.62s/it]
  0%|          | 1/213 [00:00<00:22,  9.25it/s]

Training loss: 0.286


100%|██████████| 213/213 [00:18<00:00, 11.62it/s]


Validation expression accuracy: 0.69
Validation value accuracy: 0.704
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 22


100%|██████████| 30/30 [01:46<00:00,  3.54s/it]
  0%|          | 1/213 [00:00<00:26,  7.87it/s]

Training loss: 0.26


100%|██████████| 213/213 [00:18<00:00, 11.60it/s]


Validation expression accuracy: 0.69
Validation value accuracy: 0.709
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 23


100%|██████████| 30/30 [01:46<00:00,  3.56s/it]
  0%|          | 1/213 [00:00<00:25,  8.41it/s]

Training loss: 0.244


100%|██████████| 213/213 [00:19<00:00, 10.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 24


100%|██████████| 30/30 [01:50<00:00,  3.67s/it]
  0%|          | 1/213 [00:00<00:29,  7.15it/s]

Training loss: 0.248


100%|██████████| 213/213 [00:18<00:00, 11.28it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.704

Epoch: 25


100%|██████████| 30/30 [01:49<00:00,  3.66s/it]
  0%|          | 1/213 [00:00<00:28,  7.52it/s]

Training loss: 0.24


100%|██████████| 213/213 [00:19<00:00, 11.06it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 26


100%|██████████| 30/30 [01:50<00:00,  3.68s/it]
  0%|          | 1/213 [00:00<00:25,  8.47it/s]

Training loss: 0.225


100%|██████████| 213/213 [00:17<00:00, 11.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.685
Validation value accuracy: 0.704

Epoch: 27


100%|██████████| 30/30 [01:46<00:00,  3.55s/it]
  0%|          | 1/213 [00:00<00:25,  8.29it/s]

Training loss: 0.231


100%|██████████| 213/213 [00:17<00:00, 12.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.662
Validation value accuracy: 0.681

Epoch: 28


100%|██████████| 30/30 [01:46<00:00,  3.56s/it]
  0%|          | 1/213 [00:00<00:29,  7.13it/s]

Training loss: 0.227


100%|██████████| 213/213 [00:18<00:00, 11.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.681
Validation value accuracy: 0.695

Epoch: 29


100%|██████████| 30/30 [01:47<00:00,  3.57s/it]
  0%|          | 1/213 [00:00<00:27,  7.77it/s]

Training loss: 0.222


100%|██████████| 213/213 [00:19<00:00, 10.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.69

Epoch: 30


100%|██████████| 30/30 [01:47<00:00,  3.59s/it]
  0%|          | 1/213 [00:00<00:28,  7.38it/s]

Training loss: 0.201


100%|██████████| 213/213 [00:21<00:00,  9.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.685
Validation value accuracy: 0.704

Epoch: 31


100%|██████████| 30/30 [01:46<00:00,  3.56s/it]
  0%|          | 1/213 [00:00<00:25,  8.32it/s]

Training loss: 0.202


100%|██████████| 213/213 [00:18<00:00, 11.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 32


100%|██████████| 30/30 [01:47<00:00,  3.58s/it]
  0%|          | 1/213 [00:00<00:25,  8.27it/s]

Training loss: 0.184


100%|██████████| 213/213 [00:19<00:00, 11.06it/s]


Validation expression accuracy: 0.7
Validation value accuracy: 0.718
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 33


100%|██████████| 30/30 [01:47<00:00,  3.57s/it]
  0%|          | 1/213 [00:00<00:24,  8.54it/s]

Training loss: 0.171


100%|██████████| 213/213 [00:17<00:00, 12.16it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.69
Validation value accuracy: 0.709

Epoch: 34


100%|██████████| 30/30 [01:47<00:00,  3.58s/it]
  0%|          | 1/213 [00:00<00:26,  8.02it/s]

Training loss: 0.17


100%|██████████| 213/213 [00:18<00:00, 11.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.69
Validation value accuracy: 0.714

Epoch: 35


100%|██████████| 30/30 [01:49<00:00,  3.64s/it]
  0%|          | 1/213 [00:00<00:23,  9.02it/s]

Training loss: 0.152


100%|██████████| 213/213 [00:16<00:00, 12.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.657
Validation value accuracy: 0.671

Epoch: 36


100%|██████████| 30/30 [01:48<00:00,  3.61s/it]
  0%|          | 1/213 [00:00<00:24,  8.60it/s]

Training loss: 0.164


100%|██████████| 213/213 [00:17<00:00, 12.19it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.7
Validation value accuracy: 0.714

Epoch: 37


100%|██████████| 30/30 [01:48<00:00,  3.63s/it]
  0%|          | 1/213 [00:00<00:31,  6.65it/s]

Training loss: 0.161


100%|██████████| 213/213 [00:18<00:00, 11.64it/s]


Validation expression accuracy: 0.718
Validation value accuracy: 0.732
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 38


100%|██████████| 30/30 [01:49<00:00,  3.66s/it]
  0%|          | 1/213 [00:00<00:26,  8.06it/s]

Training loss: 0.15


100%|██████████| 213/213 [00:16<00:00, 12.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.7
Validation value accuracy: 0.714

Epoch: 39


100%|██████████| 30/30 [01:50<00:00,  3.67s/it]
  0%|          | 1/213 [00:00<00:27,  7.63it/s]

Training loss: 0.153


100%|██████████| 213/213 [00:18<00:00, 11.36it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.704
Validation value accuracy: 0.723

Epoch: 40


100%|██████████| 30/30 [01:50<00:00,  3.68s/it]
  0%|          | 1/213 [00:00<00:26,  7.87it/s]

Training loss: 0.155


100%|██████████| 213/213 [00:19<00:00, 10.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 41


100%|██████████| 30/30 [01:52<00:00,  3.76s/it]
  0%|          | 1/213 [00:00<00:27,  7.77it/s]

Training loss: 0.154


100%|██████████| 213/213 [00:18<00:00, 11.40it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 42


100%|██████████| 30/30 [01:50<00:00,  3.67s/it]
  0%|          | 1/213 [00:00<00:25,  8.30it/s]

Training loss: 0.156


100%|██████████| 213/213 [00:17<00:00, 12.16it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.681
Validation value accuracy: 0.7

Epoch: 43


100%|██████████| 30/30 [01:49<00:00,  3.66s/it]
  0%|          | 1/213 [00:00<00:26,  7.98it/s]

Training loss: 0.142


100%|██████████| 213/213 [00:18<00:00, 11.28it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.732

Epoch: 44


100%|██████████| 30/30 [01:51<00:00,  3.72s/it]
  0%|          | 1/213 [00:00<00:27,  7.79it/s]

Training loss: 0.143


100%|██████████| 213/213 [00:17<00:00, 12.36it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.714

Epoch: 45


100%|██████████| 30/30 [01:52<00:00,  3.74s/it]
  0%|          | 1/213 [00:00<00:26,  7.88it/s]

Training loss: 0.113


 91%|█████████ | 193/213 [00:16<00:01, 10.13it/s]

Malformed expression ['*', '-', '1', '45']


100%|██████████| 213/213 [00:17<00:00, 11.86it/s]


Validation expression accuracy: 0.723
Validation value accuracy: 0.737
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 46


100%|██████████| 30/30 [01:51<00:00,  3.70s/it]
  0%|          | 1/213 [00:00<00:24,  8.72it/s]

Training loss: 0.118


100%|██████████| 213/213 [00:17<00:00, 12.46it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.723

Epoch: 47


100%|██████████| 30/30 [01:52<00:00,  3.75s/it]
  0%|          | 1/213 [00:00<00:27,  7.71it/s]

Training loss: 0.136


100%|██████████| 213/213 [00:16<00:00, 12.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.7
Validation value accuracy: 0.718

Epoch: 48


100%|██████████| 30/30 [01:54<00:00,  3.82s/it]
  0%|          | 1/213 [00:00<00:24,  8.66it/s]

Training loss: 0.122


100%|██████████| 213/213 [00:16<00:00, 12.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.704
Validation value accuracy: 0.718

Epoch: 49


100%|██████████| 30/30 [01:52<00:00,  3.74s/it]
  0%|          | 1/213 [00:00<00:24,  8.72it/s]

Training loss: 0.124


100%|██████████| 213/213 [00:16<00:00, 13.07it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.714

Epoch: 50


100%|██████████| 30/30 [01:53<00:00,  3.78s/it]
  0%|          | 1/213 [00:00<00:24,  8.73it/s]

Training loss: 0.117


100%|██████████| 213/213 [00:15<00:00, 13.32it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.718

Epoch: 51


100%|██████████| 30/30 [01:57<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:25,  8.17it/s]

Training loss: 0.117


100%|██████████| 213/213 [00:16<00:00, 12.53it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.695

Epoch: 52


100%|██████████| 30/30 [01:52<00:00,  3.74s/it]
  0%|          | 1/213 [00:00<00:21,  9.82it/s]

Training loss: 0.106


100%|██████████| 213/213 [00:17<00:00, 12.28it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.732

Epoch: 53


100%|██████████| 30/30 [01:54<00:00,  3.82s/it]
  0%|          | 1/213 [00:00<00:25,  8.26it/s]

Training loss: 0.116


100%|██████████| 213/213 [00:18<00:00, 11.36it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.671
Validation value accuracy: 0.7

Epoch: 54


100%|██████████| 30/30 [01:56<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:24,  8.82it/s]

Training loss: 0.132


100%|██████████| 213/213 [00:17<00:00, 12.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.704
Validation value accuracy: 0.718

Epoch: 55


100%|██████████| 30/30 [01:55<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:23,  9.17it/s]

Training loss: 0.109


100%|██████████| 213/213 [00:18<00:00, 11.30it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.685
Validation value accuracy: 0.7

Epoch: 56


100%|██████████| 30/30 [02:00<00:00,  4.01s/it]
  0%|          | 1/213 [00:00<00:24,  8.58it/s]

Training loss: 0.113


100%|██████████| 213/213 [00:17<00:00, 12.34it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.7
Validation value accuracy: 0.714

Epoch: 57


100%|██████████| 30/30 [01:58<00:00,  3.96s/it]
  0%|          | 1/213 [00:00<00:23,  9.22it/s]

Training loss: 0.104


100%|██████████| 213/213 [00:16<00:00, 12.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.714

Epoch: 58


100%|██████████| 30/30 [02:00<00:00,  4.01s/it]
  0%|          | 1/213 [00:00<00:26,  7.95it/s]

Training loss: 0.0894


100%|██████████| 213/213 [00:17<00:00, 11.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 59


100%|██████████| 30/30 [01:59<00:00,  4.00s/it]
  0%|          | 1/213 [00:00<00:22,  9.32it/s]

Training loss: 0.0841


100%|██████████| 213/213 [00:16<00:00, 12.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.723

Epoch: 60


100%|██████████| 30/30 [01:59<00:00,  3.97s/it]
  0%|          | 1/213 [00:00<00:23,  9.02it/s]

Training loss: 0.0755


100%|██████████| 213/213 [00:17<00:00, 12.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.723

Epoch: 61


100%|██████████| 30/30 [01:59<00:00,  3.99s/it]
  0%|          | 1/213 [00:00<00:22,  9.24it/s]

Training loss: 0.0662


100%|██████████| 213/213 [00:17<00:00, 12.48it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.728

Epoch: 62


100%|██████████| 30/30 [01:56<00:00,  3.88s/it]
  0%|          | 1/213 [00:00<00:23,  9.03it/s]

Training loss: 0.0595


100%|██████████| 213/213 [00:17<00:00, 12.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.704
Validation value accuracy: 0.723

Epoch: 63


100%|██████████| 30/30 [01:55<00:00,  3.85s/it]
  0%|          | 1/213 [00:00<00:24,  8.72it/s]

Training loss: 0.0685


100%|██████████| 213/213 [00:17<00:00, 12.34it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.732

Epoch: 64


100%|██████████| 30/30 [01:56<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:26,  8.14it/s]

Training loss: 0.0615


100%|██████████| 213/213 [00:16<00:00, 13.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.723

Epoch: 65


100%|██████████| 30/30 [01:58<00:00,  3.96s/it]
  0%|          | 1/213 [00:00<00:24,  8.82it/s]

Training loss: 0.0615


100%|██████████| 213/213 [00:17<00:00, 12.10it/s]


Validation expression accuracy: 0.718
Validation value accuracy: 0.742
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 66


100%|██████████| 30/30 [01:58<00:00,  3.94s/it]
  0%|          | 1/213 [00:00<00:23,  9.03it/s]

Training loss: 0.0622


100%|██████████| 213/213 [00:17<00:00, 12.43it/s]


Validation expression accuracy: 0.732
Validation value accuracy: 0.746
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 67


100%|██████████| 30/30 [02:00<00:00,  4.01s/it]
  0%|          | 1/213 [00:00<00:21,  9.64it/s]

Training loss: 0.0659


100%|██████████| 213/213 [00:16<00:00, 12.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.704
Validation value accuracy: 0.723

Epoch: 68


100%|██████████| 30/30 [01:59<00:00,  3.97s/it]
  0%|          | 1/213 [00:00<00:24,  8.50it/s]

Training loss: 0.0593


100%|██████████| 213/213 [00:16<00:00, 12.58it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.732

Epoch: 69


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.14it/s]

Training loss: 0.0466


100%|██████████| 213/213 [00:17<00:00, 12.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.746

Epoch: 70


100%|██████████| 30/30 [01:56<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:25,  8.18it/s]

Training loss: 0.0466


100%|██████████| 213/213 [00:17<00:00, 12.47it/s]


Validation expression accuracy: 0.737
Validation value accuracy: 0.756
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 71


100%|██████████| 30/30 [01:54<00:00,  3.82s/it]
  0%|          | 1/213 [00:00<00:25,  8.36it/s]

Training loss: 0.0469


100%|██████████| 213/213 [00:16<00:00, 12.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.746

Epoch: 72


100%|██████████| 30/30 [01:54<00:00,  3.82s/it]
  0%|          | 1/213 [00:00<00:26,  8.13it/s]

Training loss: 0.0539


100%|██████████| 213/213 [00:16<00:00, 13.26it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.728

Epoch: 73


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:22,  9.38it/s]

Training loss: 0.0444


100%|██████████| 213/213 [00:15<00:00, 13.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 74


100%|██████████| 30/30 [01:57<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:24,  8.54it/s]

Training loss: 0.0418


100%|██████████| 213/213 [00:16<00:00, 13.14it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.751

Epoch: 75


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]
  0%|          | 1/213 [00:00<00:23,  9.18it/s]

Training loss: 0.0392


100%|██████████| 213/213 [00:16<00:00, 13.30it/s]


Validation expression accuracy: 0.742
Validation value accuracy: 0.761
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 76


100%|██████████| 30/30 [01:58<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:26,  7.92it/s]

Training loss: 0.0394


100%|██████████| 213/213 [00:16<00:00, 12.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.742

Epoch: 77


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:24,  8.70it/s]

Training loss: 0.041


100%|██████████| 213/213 [00:16<00:00, 12.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.737

Epoch: 78


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:26,  8.11it/s]

Training loss: 0.0385


100%|██████████| 213/213 [00:16<00:00, 13.29it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.737

Epoch: 79


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:24,  8.59it/s]

Training loss: 0.0334


100%|██████████| 213/213 [00:16<00:00, 13.23it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.756

Epoch: 80


100%|██████████| 30/30 [01:57<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:24,  8.51it/s]

Training loss: 0.0364


100%|██████████| 213/213 [00:16<00:00, 13.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.746

Epoch: 81


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]
  0%|          | 1/213 [00:00<00:25,  8.43it/s]

Training loss: 0.0303


100%|██████████| 213/213 [00:16<00:00, 12.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.742

Epoch: 82


100%|██████████| 30/30 [01:55<00:00,  3.85s/it]
  0%|          | 1/213 [00:00<00:25,  8.36it/s]

Training loss: 0.0283


100%|██████████| 213/213 [00:17<00:00, 11.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.732

Epoch: 83


100%|██████████| 30/30 [01:53<00:00,  3.79s/it]
  0%|          | 1/213 [00:00<00:24,  8.74it/s]

Training loss: 0.026


100%|██████████| 213/213 [00:18<00:00, 11.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.737

Epoch: 84


100%|██████████| 30/30 [01:53<00:00,  3.78s/it]
  0%|          | 1/213 [00:00<00:23,  9.05it/s]

Training loss: 0.0249


100%|██████████| 213/213 [00:16<00:00, 12.55it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 85


100%|██████████| 30/30 [01:54<00:00,  3.81s/it]
  0%|          | 1/213 [00:00<00:23,  9.04it/s]

Training loss: 0.0272


100%|██████████| 213/213 [00:18<00:00, 11.58it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.746

Epoch: 86


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]
  0%|          | 1/213 [00:00<00:26,  8.01it/s]

Training loss: 0.0279


100%|██████████| 213/213 [00:16<00:00, 12.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.742

Epoch: 87


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]
  0%|          | 1/213 [00:00<00:24,  8.53it/s]

Training loss: 0.0219


100%|██████████| 213/213 [00:16<00:00, 12.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 88


100%|██████████| 30/30 [01:53<00:00,  3.78s/it]
  0%|          | 1/213 [00:00<00:27,  7.80it/s]

Training loss: 0.0219


100%|██████████| 213/213 [00:17<00:00, 12.46it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 89


100%|██████████| 30/30 [01:54<00:00,  3.81s/it]
  0%|          | 1/213 [00:00<00:24,  8.58it/s]

Training loss: 0.0198


100%|██████████| 213/213 [00:16<00:00, 12.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.746

Epoch: 90


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]
  0%|          | 1/213 [00:00<00:25,  8.41it/s]

Training loss: 0.0205


100%|██████████| 213/213 [00:16<00:00, 12.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 91


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]
  0%|          | 1/213 [00:00<00:24,  8.53it/s]

Training loss: 0.0217


100%|██████████| 213/213 [00:17<00:00, 12.24it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 92


100%|██████████| 30/30 [01:53<00:00,  3.78s/it]
  0%|          | 1/213 [00:00<00:23,  8.84it/s]

Training loss: 0.0213


100%|██████████| 213/213 [00:16<00:00, 12.60it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.751

Epoch: 93


100%|██████████| 30/30 [01:54<00:00,  3.81s/it]
  0%|          | 1/213 [00:00<00:26,  8.03it/s]

Training loss: 0.0208


100%|██████████| 213/213 [00:17<00:00, 12.16it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.732

Epoch: 94


100%|██████████| 30/30 [01:54<00:00,  3.81s/it]
  0%|          | 1/213 [00:00<00:24,  8.71it/s]

Training loss: 0.0176


100%|██████████| 213/213 [00:17<00:00, 12.23it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 95


100%|██████████| 30/30 [01:53<00:00,  3.78s/it]
  0%|          | 1/213 [00:00<00:24,  8.73it/s]

Training loss: 0.0179


100%|██████████| 213/213 [00:16<00:00, 12.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.737

Epoch: 96


100%|██████████| 30/30 [01:54<00:00,  3.81s/it]
  0%|          | 1/213 [00:00<00:24,  8.66it/s]

Training loss: 0.0246


100%|██████████| 213/213 [00:16<00:00, 12.59it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.737

Epoch: 97


100%|██████████| 30/30 [01:56<00:00,  3.88s/it]
  0%|          | 1/213 [00:00<00:24,  8.53it/s]

Training loss: 0.0157


100%|██████████| 213/213 [00:17<00:00, 12.25it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.746

Epoch: 98


100%|██████████| 30/30 [01:55<00:00,  3.84s/it]
  0%|          | 1/213 [00:00<00:27,  7.77it/s]

Training loss: 0.0146


100%|██████████| 213/213 [00:16<00:00, 12.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 99


100%|██████████| 30/30 [01:55<00:00,  3.84s/it]
  0%|          | 1/213 [00:00<00:25,  8.42it/s]

Training loss: 0.0145


100%|██████████| 213/213 [00:17<00:00, 12.38it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.751

Epoch: 100


100%|██████████| 30/30 [01:55<00:00,  3.86s/it]
  0%|          | 1/213 [00:00<00:26,  7.89it/s]

Training loss: 0.0171


100%|██████████| 213/213 [00:17<00:00, 12.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.751

Epoch: 101


100%|██████████| 30/30 [01:55<00:00,  3.85s/it]
  0%|          | 1/213 [00:00<00:26,  8.15it/s]

Training loss: 0.0162


100%|██████████| 213/213 [00:17<00:00, 11.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 102


100%|██████████| 30/30 [01:54<00:00,  3.81s/it]
  0%|          | 1/213 [00:00<00:23,  8.91it/s]

Training loss: 0.014


100%|██████████| 213/213 [00:16<00:00, 12.58it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 103


100%|██████████| 30/30 [01:55<00:00,  3.85s/it]
  0%|          | 1/213 [00:00<00:24,  8.50it/s]

Training loss: 0.0126


100%|██████████| 213/213 [00:17<00:00, 12.34it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.751

Epoch: 104


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:27,  7.75it/s]

Training loss: 0.0125


100%|██████████| 213/213 [00:17<00:00, 12.33it/s]


Validation expression accuracy: 0.737
Validation value accuracy: 0.765
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 105


100%|██████████| 30/30 [01:57<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:24,  8.67it/s]

Training loss: 0.0163


100%|██████████| 213/213 [00:16<00:00, 12.79it/s]


Validation expression accuracy: 0.746
Validation value accuracy: 0.77
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 106


100%|██████████| 30/30 [01:56<00:00,  3.88s/it]
  0%|          | 1/213 [00:00<00:24,  8.51it/s]

Training loss: 0.0166


100%|██████████| 213/213 [00:17<00:00, 12.43it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.746
Validation value accuracy: 0.77

Epoch: 107


100%|██████████| 30/30 [01:55<00:00,  3.86s/it]
  0%|          | 1/213 [00:00<00:25,  8.17it/s]

Training loss: 0.014


100%|██████████| 213/213 [00:17<00:00, 11.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.77

Epoch: 108


100%|██████████| 30/30 [01:55<00:00,  3.86s/it]
  0%|          | 1/213 [00:00<00:25,  8.45it/s]

Training loss: 0.0116


100%|██████████| 213/213 [00:17<00:00, 12.31it/s]


Validation expression accuracy: 0.742
Validation value accuracy: 0.775
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 109


100%|██████████| 30/30 [01:55<00:00,  3.85s/it]
  0%|          | 1/213 [00:00<00:29,  7.10it/s]

Training loss: 0.0115


100%|██████████| 213/213 [00:17<00:00, 12.14it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.746
Validation value accuracy: 0.775

Epoch: 110


100%|██████████| 30/30 [01:55<00:00,  3.85s/it]
  0%|          | 1/213 [00:00<00:25,  8.38it/s]

Training loss: 0.0111


100%|██████████| 213/213 [00:17<00:00, 12.32it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.746
Validation value accuracy: 0.775

Epoch: 111


100%|██████████| 30/30 [01:55<00:00,  3.86s/it]
  0%|          | 1/213 [00:00<00:28,  7.39it/s]

Training loss: 0.0121


100%|██████████| 213/213 [00:17<00:00, 12.31it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.761

Epoch: 112


100%|██████████| 30/30 [01:56<00:00,  3.88s/it]
  0%|          | 1/213 [00:00<00:26,  8.11it/s]

Training loss: 0.0137


100%|██████████| 213/213 [00:17<00:00, 11.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.737

Epoch: 113


100%|██████████| 30/30 [01:56<00:00,  3.88s/it]
  0%|          | 1/213 [00:00<00:25,  8.36it/s]

Training loss: 0.0107


100%|██████████| 213/213 [00:17<00:00, 12.13it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.742

Epoch: 114


100%|██████████| 30/30 [01:56<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:25,  8.26it/s]

Training loss: 0.0109


100%|██████████| 213/213 [00:17<00:00, 12.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.751

Epoch: 115


100%|██████████| 30/30 [01:57<00:00,  3.93s/it]
  0%|          | 1/213 [00:00<00:25,  8.27it/s]

Training loss: 0.00994


100%|██████████| 213/213 [00:16<00:00, 12.53it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.746
Validation value accuracy: 0.77

Epoch: 116


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:25,  8.22it/s]

Training loss: 0.00818


100%|██████████| 213/213 [00:16<00:00, 12.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 117


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.08it/s]

Training loss: 0.00784


100%|██████████| 213/213 [00:17<00:00, 12.46it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 118


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:30,  6.87it/s]

Training loss: 0.0102


100%|██████████| 213/213 [00:17<00:00, 12.26it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 119


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.05it/s]

Training loss: 0.00895


100%|██████████| 213/213 [00:17<00:00, 12.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 120


100%|██████████| 30/30 [01:56<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:26,  8.10it/s]

Training loss: 0.00672


100%|██████████| 213/213 [00:17<00:00, 12.42it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.746
Validation value accuracy: 0.77

Epoch: 121


100%|██████████| 30/30 [01:56<00:00,  3.88s/it]
  0%|          | 1/213 [00:00<00:26,  8.07it/s]

Training loss: 0.0075


100%|██████████| 213/213 [00:17<00:00, 12.37it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.751
Validation value accuracy: 0.775

Epoch: 122


100%|██████████| 30/30 [01:56<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:27,  7.57it/s]

Training loss: 0.00744


100%|██████████| 213/213 [00:17<00:00, 12.46it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 123


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:22,  9.56it/s]

Training loss: 0.00722


100%|██████████| 213/213 [00:17<00:00, 12.44it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.746

Epoch: 124


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.02it/s]

Training loss: 0.0074


100%|██████████| 213/213 [00:16<00:00, 12.53it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 125


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.03it/s]

Training loss: 0.00789


100%|██████████| 213/213 [00:17<00:00, 12.40it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.756

Epoch: 126


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  7.99it/s]

Training loss: 0.00597


100%|██████████| 213/213 [00:17<00:00, 12.48it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 127


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.14it/s]

Training loss: 0.00872


100%|██████████| 213/213 [00:17<00:00, 12.39it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 128


100%|██████████| 30/30 [01:56<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:24,  8.51it/s]

Training loss: 0.00897


100%|██████████| 213/213 [00:17<00:00, 12.35it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 129


100%|██████████| 30/30 [01:56<00:00,  3.87s/it]
  0%|          | 1/213 [00:00<00:24,  8.58it/s]

Training loss: 0.00763


100%|██████████| 213/213 [00:17<00:00, 12.32it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 130


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:24,  8.54it/s]

Training loss: 0.00701


100%|██████████| 213/213 [00:17<00:00, 12.15it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 131


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:28,  7.50it/s]

Training loss: 0.00708


100%|██████████| 213/213 [00:17<00:00, 12.23it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 132


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  8.06it/s]

Training loss: 0.00818


100%|██████████| 213/213 [00:17<00:00, 12.37it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 133


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:25,  8.40it/s]

Training loss: 0.0083


100%|██████████| 213/213 [00:17<00:00, 12.39it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 134


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:25,  8.20it/s]

Training loss: 0.00779


100%|██████████| 213/213 [00:17<00:00, 12.38it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 135


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:26,  8.07it/s]

Training loss: 0.00741


100%|██████████| 213/213 [00:17<00:00, 12.37it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 136


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:25,  8.38it/s]

Training loss: 0.00761


100%|██████████| 213/213 [00:17<00:00, 12.37it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 137


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:27,  7.64it/s]

Training loss: 0.00679


100%|██████████| 213/213 [00:17<00:00, 12.35it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 138


100%|██████████| 30/30 [01:56<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:24,  8.70it/s]

Training loss: 0.00564


100%|██████████| 213/213 [00:17<00:00, 12.34it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.742
Validation value accuracy: 0.765

Epoch: 139


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:24,  8.75it/s]

Training loss: 0.00749


100%|██████████| 213/213 [00:17<00:00, 12.29it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 140


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:21,  9.86it/s]

Training loss: 0.00855


100%|██████████| 213/213 [00:17<00:00, 12.25it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 141


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:24,  8.56it/s]

Training loss: 0.00602


100%|██████████| 213/213 [00:17<00:00, 12.25it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 142


100%|██████████| 30/30 [01:57<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:26,  8.11it/s]

Training loss: 0.00783


100%|██████████| 213/213 [00:17<00:00, 12.19it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 143


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:26,  8.04it/s]

Training loss: 0.00686


100%|██████████| 213/213 [00:17<00:00, 12.29it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 144


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:26,  7.95it/s]

Training loss: 0.00738


100%|██████████| 213/213 [00:17<00:00, 12.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 145


100%|██████████| 30/30 [01:57<00:00,  3.92s/it]
  0%|          | 1/213 [00:00<00:24,  8.64it/s]

Training loss: 0.00696


100%|██████████| 213/213 [00:17<00:00, 12.28it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 146


100%|██████████| 30/30 [01:57<00:00,  3.91s/it]
  0%|          | 1/213 [00:00<00:25,  8.31it/s]

Training loss: 0.00583


100%|██████████| 213/213 [00:17<00:00, 12.25it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 147


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:25,  8.33it/s]

Training loss: 0.00839


100%|██████████| 213/213 [00:17<00:00, 12.16it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 148


100%|██████████| 30/30 [01:56<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:26,  7.89it/s]

Training loss: 0.00714


100%|██████████| 213/213 [00:17<00:00, 12.22it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 149


100%|██████████| 30/30 [01:56<00:00,  3.89s/it]
  0%|          | 1/213 [00:00<00:25,  8.20it/s]

Training loss: 0.00682


100%|██████████| 213/213 [00:17<00:00, 12.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761

Epoch: 150


100%|██████████| 30/30 [01:56<00:00,  3.90s/it]
  0%|          | 1/213 [00:00<00:25,  8.34it/s]

Training loss: 0.00678


100%|██████████| 213/213 [00:17<00:00, 12.34it/s]

Validation expression accuracy: 0.737
Validation value accuracy: 0.761






## Training `t5-small`

In [8]:
# Define parameters for model
curr_run_name = "second-run"

# Value should be None, 'small', or 'base'
# use_t5 = None
use_t5 = "small"

# IMPORTANT NOTE: if you change some of these hyperparameters during training,
# you will also need to change them during prediction (see next section)

n_max_in = 100
n_epochs = 50
n_batch = 64
learning_rate = 1e-3
if use_t5:
    # T5 hyperparameters
    freeze_layers = []
    weight_decay = 1e-5
    # Do not modify unless you want to try t5-large
    n_hid = dict(small=512, base=768)[use_t5] 
else:
    # Custom transformer hyperparameters
    n_layers = 3
    n_hid = 512
    n_k = n_v = 64
    n_head = 8
    weight_decay = 0
    
# For evaluation/prediction
saved_model_name = "model-best.pth"

# Defining what to do
TRAIN=True
EVALUATION=False
PREDICTION=False

# Defining some useful variables and doing some useful tasks for later
model_save_dir = f'models/{use_t5 or "custom"}-{curr_run_name}'
model_save_path = f'models/{use_t5 or "custom"}-{curr_run_name}/{saved_model_name}'
predictions_save_path = f'models/{use_t5 or "custom"}-{curr_run_name}/predictions.csv'

os.makedirs(model_save_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training loop

if TRAIN:
    
    # Data
    train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
    tensorize_data(itertools.chain(train_data, val_data))

    # Model
    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
    model.to(device)

    # Train
    epoch = 0
    best_acc = 0
    while epoch < n_epochs:
        
        # Train for an epoch
        print('Epoch:', epoch + 1)
        model.train()
        losses = []
        for start in trange(0, len(train_data), n_batch):
            batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
            loss = train(batch, model, opt)
            losses.append(loss)
        scheduler.step()
        print(f'Training loss: {np.mean(losses):.3g}')

        # Evaluate after every epoch of training
        model.eval()
        value_match, equation_match = [], []
        with torch.no_grad():
            for d in tqdm(val_data):
                # This method is not equiped to handle equations with quadratics
                if d['is_quadratic']: 
                    val_match = eq_match = False
                else:
                    pred = predict(d, model)
                    d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                    val_match, eq_match = check_match(pred, d)
                value_match.append(val_match)
                equation_match.append(eq_match)
        curr_expr_acc = np.mean(equation_match)
        curr_value_acc = np.mean(value_match)
        print(f'Validation expression accuracy: {curr_expr_acc:.3g}')
        print(f'Validation value accuracy: {curr_value_acc:.3g}')
        
        # Save if best
        if curr_value_acc > best_acc:
            best_acc = curr_value_acc
            print(">>>>> Found best model so far <<<<<")
            torch.save(model.state_dict(), os.path.join(model_save_dir, 'model-best.pth'))
            
        print()
        epoch += 1        

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch: 1


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:24,  8.54it/s]

Training loss: 1.77


100%|██████████| 213/213 [00:17<00:00, 11.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.00469
Validation value accuracy: 0.0188
>>>>> Found best model so far <<<<<

Epoch: 2


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:26,  7.96it/s]

Training loss: 1.48


100%|██████████| 213/213 [00:24<00:00,  8.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.00469
Validation value accuracy: 0.0141

Epoch: 3


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:27,  7.58it/s]

Training loss: 1.37


100%|██████████| 213/213 [00:24<00:00,  8.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.00939
Validation value accuracy: 0.0141

Epoch: 4


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:25,  8.28it/s]

Training loss: 1.22


100%|██████████| 213/213 [00:23<00:00,  9.12it/s]


Validation expression accuracy: 0.169
Validation value accuracy: 0.188
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 5


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:28,  7.38it/s]

Training loss: 0.987


100%|██████████| 213/213 [00:25<00:00,  8.31it/s]


Validation expression accuracy: 0.272
Validation value accuracy: 0.282
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 6


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:27,  7.77it/s]

Training loss: 0.825


100%|██████████| 213/213 [00:23<00:00,  9.00it/s]


Validation expression accuracy: 0.474
Validation value accuracy: 0.479
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 7


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:31,  6.74it/s]

Training loss: 0.73


100%|██████████| 213/213 [00:28<00:00,  7.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.408
Validation value accuracy: 0.427

Epoch: 8


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:33,  6.35it/s]

Training loss: 0.627


100%|██████████| 213/213 [00:29<00:00,  7.23it/s]


Validation expression accuracy: 0.507
Validation value accuracy: 0.516
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 9


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:28,  7.35it/s]

Training loss: 0.556


100%|██████████| 213/213 [00:23<00:00,  9.10it/s]


Validation expression accuracy: 0.624
Validation value accuracy: 0.638
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 10


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:30,  6.99it/s]

Training loss: 0.499


100%|██████████| 213/213 [00:23<00:00,  8.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.601
Validation value accuracy: 0.61

Epoch: 11


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
  0%|          | 1/213 [00:00<00:27,  7.71it/s]

Training loss: 0.438


100%|██████████| 213/213 [00:22<00:00,  9.46it/s]


Validation expression accuracy: 0.657
Validation value accuracy: 0.671
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 12


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:28,  7.37it/s]

Training loss: 0.386


100%|██████████| 213/213 [00:22<00:00,  9.48it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.606
Validation value accuracy: 0.615

Epoch: 13


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:28,  7.55it/s]

Training loss: 0.372


100%|██████████| 213/213 [00:26<00:00,  8.18it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.573
Validation value accuracy: 0.582

Epoch: 14


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:30,  6.91it/s]

Training loss: 0.361


100%|██████████| 213/213 [00:24<00:00,  8.59it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.596
Validation value accuracy: 0.606

Epoch: 15


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:32,  6.55it/s]

Training loss: 0.32


100%|██████████| 213/213 [00:25<00:00,  8.24it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.657

Epoch: 16


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:29,  7.15it/s]

Training loss: 0.3


100%|██████████| 213/213 [00:24<00:00,  8.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.653
Validation value accuracy: 0.662

Epoch: 17


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:26,  7.95it/s]

Training loss: 0.277


100%|██████████| 213/213 [00:22<00:00,  9.32it/s]


Validation expression accuracy: 0.695
Validation value accuracy: 0.709
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 18


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:26,  7.99it/s]

Training loss: 0.258


100%|██████████| 213/213 [00:22<00:00,  9.61it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.685
Validation value accuracy: 0.7

Epoch: 19


100%|██████████| 30/30 [03:14<00:00,  6.48s/it]
  0%|          | 1/213 [00:00<00:27,  7.74it/s]

Training loss: 0.248


100%|██████████| 213/213 [00:22<00:00,  9.55it/s]


Validation expression accuracy: 0.723
Validation value accuracy: 0.732
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 20


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:29,  7.14it/s]

Training loss: 0.234


100%|██████████| 213/213 [00:22<00:00,  9.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.7
Validation value accuracy: 0.709

Epoch: 21


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
  0%|          | 1/213 [00:00<00:28,  7.37it/s]

Training loss: 0.23


100%|██████████| 213/213 [00:25<00:00,  8.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.657

Epoch: 22


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:24,  8.54it/s]

Training loss: 0.218


100%|██████████| 213/213 [00:22<00:00,  9.29it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.624
Validation value accuracy: 0.634

Epoch: 23


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:26,  7.91it/s]

Training loss: 0.203


100%|██████████| 213/213 [00:23<00:00,  9.24it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.667

Epoch: 24


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:28,  7.35it/s]

Training loss: 0.177


100%|██████████| 213/213 [00:22<00:00,  9.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.667
Validation value accuracy: 0.685

Epoch: 25


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:26,  8.07it/s]

Training loss: 0.17


100%|██████████| 213/213 [00:22<00:00,  9.29it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.69
Validation value accuracy: 0.709

Epoch: 26


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:27,  7.76it/s]

Training loss: 0.158


100%|██████████| 213/213 [00:23<00:00,  9.17it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.662
Validation value accuracy: 0.685

Epoch: 27


100%|██████████| 30/30 [03:12<00:00,  6.40s/it]
  0%|          | 1/213 [00:00<00:26,  8.08it/s]

Training loss: 0.155


100%|██████████| 213/213 [00:24<00:00,  8.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.638
Validation value accuracy: 0.662

Epoch: 28


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:27,  7.80it/s]

Training loss: 0.148


100%|██████████| 213/213 [00:24<00:00,  8.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.653
Validation value accuracy: 0.676

Epoch: 29


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:25,  8.35it/s]

Training loss: 0.146


100%|██████████| 213/213 [00:21<00:00,  9.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.681
Validation value accuracy: 0.709

Epoch: 30


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:27,  7.68it/s]

Training loss: 0.14


100%|██████████| 213/213 [00:24<00:00,  8.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.662
Validation value accuracy: 0.685

Epoch: 31


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:26,  7.94it/s]

Training loss: 0.138


 87%|████████▋ | 185/213 [00:20<00:03,  8.66it/s]

Malformed expression ['/', '*', '4', '-', '18', '4']


100%|██████████| 213/213 [00:23<00:00,  9.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.7
Validation value accuracy: 0.723

Epoch: 32


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:25,  8.19it/s]

Training loss: 0.124


100%|██████████| 213/213 [00:23<00:00,  9.18it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.667
Validation value accuracy: 0.69

Epoch: 33


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:25,  8.31it/s]

Training loss: 0.119


100%|██████████| 213/213 [00:23<00:00,  9.13it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.624
Validation value accuracy: 0.643

Epoch: 34


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:27,  7.79it/s]

Training loss: 0.116


100%|██████████| 213/213 [00:22<00:00,  9.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.7

Epoch: 35


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:26,  7.89it/s]

Training loss: 0.117


 87%|████████▋ | 185/213 [00:18<00:03,  9.10it/s]

Malformed expression ['/', '*', '4', '-', '18', '4']


100%|██████████| 213/213 [00:21<00:00,  9.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.695

Epoch: 36


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:27,  7.77it/s]

Training loss: 0.108


100%|██████████| 213/213 [00:21<00:00,  9.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.676
Validation value accuracy: 0.7

Epoch: 37


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:26,  8.09it/s]

Training loss: 0.102


100%|██████████| 213/213 [00:22<00:00,  9.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.657
Validation value accuracy: 0.676

Epoch: 38


100%|██████████| 30/30 [03:12<00:00,  6.41s/it]
  0%|          | 1/213 [00:00<00:24,  8.70it/s]

Training loss: 0.103


100%|██████████| 213/213 [00:22<00:00,  9.36it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.653
Validation value accuracy: 0.667

Epoch: 39


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
  0%|          | 1/213 [00:00<00:24,  8.69it/s]

Training loss: 0.0981


100%|██████████| 213/213 [00:23<00:00,  8.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.624
Validation value accuracy: 0.653

Epoch: 40


100%|██████████| 30/30 [03:12<00:00,  6.42s/it]
  0%|          | 1/213 [00:00<00:26,  7.87it/s]

Training loss: 0.09


100%|██████████| 213/213 [00:22<00:00,  9.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.667

Epoch: 41


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:24,  8.59it/s]

Training loss: 0.0915


100%|██████████| 213/213 [00:22<00:00,  9.38it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.671

Epoch: 42


100%|██████████| 30/30 [03:13<00:00,  6.45s/it]
  0%|          | 1/213 [00:00<00:26,  7.98it/s]

Training loss: 0.0896


100%|██████████| 213/213 [00:23<00:00,  9.25it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.653
Validation value accuracy: 0.676

Epoch: 43


100%|██████████| 30/30 [03:12<00:00,  6.43s/it]
  0%|          | 1/213 [00:00<00:27,  7.80it/s]

Training loss: 0.0839


100%|██████████| 213/213 [00:23<00:00,  9.20it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.676

Epoch: 44


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
  0%|          | 1/213 [00:00<00:24,  8.75it/s]

Training loss: 0.086


100%|██████████| 213/213 [00:23<00:00,  9.24it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.671

Epoch: 45


100%|██████████| 30/30 [03:13<00:00,  6.45s/it]
  0%|          | 1/213 [00:00<00:24,  8.59it/s]

Training loss: 0.0875


100%|██████████| 213/213 [00:22<00:00,  9.28it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.671

Epoch: 46


100%|██████████| 30/30 [03:13<00:00,  6.45s/it]
  0%|          | 1/213 [00:00<00:27,  7.80it/s]

Training loss: 0.0825


100%|██████████| 213/213 [00:23<00:00,  9.18it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.676

Epoch: 47


100%|██████████| 30/30 [03:13<00:00,  6.46s/it]
  0%|          | 1/213 [00:00<00:24,  8.75it/s]

Training loss: 0.0869


100%|██████████| 213/213 [00:23<00:00,  9.14it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.671

Epoch: 48


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
  0%|          | 1/213 [00:00<00:25,  8.33it/s]

Training loss: 0.0827


100%|██████████| 213/213 [00:23<00:00,  9.19it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.643
Validation value accuracy: 0.671

Epoch: 49


100%|██████████| 30/30 [03:13<00:00,  6.46s/it]
  0%|          | 1/213 [00:00<00:25,  8.47it/s]

Training loss: 0.082


100%|██████████| 213/213 [00:23<00:00,  9.26it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.681

Epoch: 50


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
  0%|          | 1/213 [00:00<00:25,  8.17it/s]

Training loss: 0.0874


100%|██████████| 213/213 [00:22<00:00,  9.26it/s]

Validation expression accuracy: 0.648
Validation value accuracy: 0.681






## Training `t5-base`

In [None]:
# Define parameters for model
curr_run_name = "second-run"

# Value should be None, 'small', or 'base'
# use_t5 = None
use_t5 = "base"

# IMPORTANT NOTE: if you change some of these hyperparameters during training,
# you will also need to change them during prediction (see next section)

n_max_in = 100
n_epochs = 50
n_batch = 64
learning_rate = 1e-3
if use_t5:
    # T5 hyperparameters
    freeze_layers = []
    weight_decay = 1e-5
    # Do not modify unless you want to try t5-large
    n_hid = dict(small=512, base=768)[use_t5] 
else:
    # Custom transformer hyperparameters
    n_layers = 3
    n_hid = 512
    n_k = n_v = 64
    n_head = 8
    weight_decay = 0
    
# For evaluation/prediction
saved_model_name = "model-best.pth"

# Defining what to do
TRAIN=True
EVALUATION=False
PREDICTION=False

# Defining some useful variables and doing some useful tasks for later
model_save_dir = f'models/{use_t5 or "custom"}-{curr_run_name}'
model_save_path = f'models/{use_t5 or "custom"}-{curr_run_name}/{saved_model_name}'
predictions_save_path = f'models/{use_t5 or "custom"}-{curr_run_name}/predictions.csv'

os.makedirs(model_save_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training loop

if TRAIN:
    
    # Data
    train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
    tensorize_data(itertools.chain(train_data, val_data))

    # Model
    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
    model.to(device)

    # Train
    epoch = 0
    best_acc = 0
    while epoch < n_epochs:
        
        # Train for an epoch
        print('Epoch:', epoch + 1)
        model.train()
        losses = []
        for start in trange(0, len(train_data), n_batch):
            batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
            loss = train(batch, model, opt)
            losses.append(loss)
        scheduler.step()
        print(f'Training loss: {np.mean(losses):.3g}')

        # Evaluate after every epoch of training
        model.eval()
        value_match, equation_match = [], []
        with torch.no_grad():
            for d in tqdm(val_data):
                # This method is not equiped to handle equations with quadratics
                if d['is_quadratic']: 
                    val_match = eq_match = False
                else:
                    pred = predict(d, model)
                    d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                    val_match, eq_match = check_match(pred, d)
                value_match.append(val_match)
                equation_match.append(eq_match)
        curr_expr_acc = np.mean(equation_match)
        curr_value_acc = np.mean(value_match)
        print(f'Validation expression accuracy: {curr_expr_acc:.3g}')
        print(f'Validation value accuracy: {curr_value_acc:.3g}')
        
        # Save if best
        if curr_value_acc > best_acc:
            best_acc = curr_value_acc
            print(">>>>> Found best model so far <<<<<")
            torch.save(model.state_dict(), os.path.join(model_save_dir, 'model-best.pth'))
            
        print()
        epoch += 1        

Some weights of T5Model were not initialized from the model checkpoint at t5-base and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-base and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch: 1


100%|██████████| 30/30 [10:09<00:00, 20.33s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 1.83


100%|██████████| 213/213 [00:42<00:00,  5.04it/s]


Validation expression accuracy: 0.00469
Validation value accuracy: 0.00939
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 2


100%|██████████| 30/30 [11:03<00:00, 22.12s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 1.41


100%|██████████| 213/213 [00:48<00:00,  4.43it/s]


Validation expression accuracy: 0.0141
Validation value accuracy: 0.0141
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 3


100%|██████████| 30/30 [09:56<00:00, 19.87s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 1.27


100%|██████████| 213/213 [00:49<00:00,  4.32it/s]


Validation expression accuracy: 0.0188
Validation value accuracy: 0.0282
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 4


100%|██████████| 30/30 [09:52<00:00, 19.76s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 1.1


100%|██████████| 213/213 [00:53<00:00,  3.97it/s]


Validation expression accuracy: 0.211
Validation value accuracy: 0.249
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 5


100%|██████████| 30/30 [09:53<00:00, 19.77s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.873


100%|██████████| 213/213 [01:00<00:00,  3.51it/s]


Validation expression accuracy: 0.305
Validation value accuracy: 0.315
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 6


100%|██████████| 30/30 [09:52<00:00, 19.74s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.748


100%|██████████| 213/213 [01:04<00:00,  3.29it/s]


Validation expression accuracy: 0.394
Validation value accuracy: 0.423
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 7


100%|██████████| 30/30 [09:53<00:00, 19.77s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.646


100%|██████████| 213/213 [00:57<00:00,  3.68it/s]


Validation expression accuracy: 0.535
Validation value accuracy: 0.559
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 8


100%|██████████| 30/30 [09:54<00:00, 19.83s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.587


100%|██████████| 213/213 [00:52<00:00,  4.03it/s]


Validation expression accuracy: 0.601
Validation value accuracy: 0.62
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 9


100%|██████████| 30/30 [09:54<00:00, 19.81s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.505


100%|██████████| 213/213 [00:59<00:00,  3.60it/s]


Validation expression accuracy: 0.62
Validation value accuracy: 0.629
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 10


100%|██████████| 30/30 [09:53<00:00, 19.79s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.447


100%|██████████| 213/213 [00:53<00:00,  4.00it/s]


Validation expression accuracy: 0.648
Validation value accuracy: 0.662
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 11


100%|██████████| 30/30 [09:53<00:00, 19.78s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.396


100%|██████████| 213/213 [00:50<00:00,  4.21it/s]


Validation expression accuracy: 0.662
Validation value accuracy: 0.676
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 12


100%|██████████| 30/30 [09:54<00:00, 19.80s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.37


100%|██████████| 213/213 [00:51<00:00,  4.12it/s]


Validation expression accuracy: 0.676
Validation value accuracy: 0.69
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 13


100%|██████████| 30/30 [09:53<00:00, 19.78s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.336


100%|██████████| 213/213 [00:49<00:00,  4.30it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.638
Validation value accuracy: 0.653

Epoch: 14


100%|██████████| 30/30 [09:55<00:00, 19.84s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.315


100%|██████████| 213/213 [00:46<00:00,  4.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.671
Validation value accuracy: 0.685

Epoch: 15


100%|██████████| 30/30 [09:53<00:00, 19.77s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.298


100%|██████████| 213/213 [00:48<00:00,  4.37it/s]


Validation expression accuracy: 0.676
Validation value accuracy: 0.695
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 16


100%|██████████| 30/30 [09:53<00:00, 19.78s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.285


100%|██████████| 213/213 [00:44<00:00,  4.83it/s]


Validation expression accuracy: 0.7
Validation value accuracy: 0.714
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 17


100%|██████████| 30/30 [09:53<00:00, 19.77s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.262


100%|██████████| 213/213 [00:44<00:00,  4.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.685
Validation value accuracy: 0.704

Epoch: 18


100%|██████████| 30/30 [09:53<00:00, 19.79s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.232


100%|██████████| 213/213 [00:47<00:00,  4.45it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.681
Validation value accuracy: 0.695

Epoch: 19


100%|██████████| 30/30 [09:53<00:00, 19.80s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.219


100%|██████████| 213/213 [00:49<00:00,  4.30it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.714

Epoch: 20


100%|██████████| 30/30 [09:55<00:00, 19.84s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.192


100%|██████████| 213/213 [00:48<00:00,  4.44it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.69
Validation value accuracy: 0.714

Epoch: 21


100%|██████████| 30/30 [09:55<00:00, 19.86s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.183


100%|██████████| 213/213 [00:49<00:00,  4.27it/s]


Validation expression accuracy: 0.7
Validation value accuracy: 0.723
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 22


100%|██████████| 30/30 [09:52<00:00, 19.75s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.169


100%|██████████| 213/213 [00:50<00:00,  4.19it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.704

Epoch: 23


100%|██████████| 30/30 [09:53<00:00, 19.80s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.16


100%|██████████| 213/213 [00:47<00:00,  4.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.695
Validation value accuracy: 0.709

Epoch: 24


100%|██████████| 30/30 [09:53<00:00, 19.77s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.161


100%|██████████| 213/213 [00:46<00:00,  4.56it/s]


Validation expression accuracy: 0.714
Validation value accuracy: 0.737
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 25


100%|██████████| 30/30 [10:43<00:00, 21.44s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.146


100%|██████████| 213/213 [00:49<00:00,  4.32it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.69
Validation value accuracy: 0.704

Epoch: 26


100%|██████████| 30/30 [10:44<00:00, 21.48s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.134


100%|██████████| 213/213 [00:47<00:00,  4.51it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.737

Epoch: 27


100%|██████████| 30/30 [10:09<00:00, 20.31s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.128


100%|██████████| 213/213 [00:47<00:00,  4.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.728

Epoch: 28


100%|██████████| 30/30 [10:28<00:00, 20.95s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.12


100%|██████████| 213/213 [00:54<00:00,  3.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.704
Validation value accuracy: 0.723

Epoch: 29


100%|██████████| 30/30 [12:18<00:00, 24.61s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.111


100%|██████████| 213/213 [00:51<00:00,  4.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.709
Validation value accuracy: 0.728

Epoch: 30


100%|██████████| 30/30 [10:14<00:00, 20.50s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0999


 26%|██▋       | 56/213 [00:12<00:41,  3.79it/s]

Malformed expression ['/', '/', '*', '8', '30', '30']


100%|██████████| 213/213 [00:45<00:00,  4.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.737

Epoch: 31


100%|██████████| 30/30 [10:04<00:00, 20.15s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0962


100%|██████████| 213/213 [00:49<00:00,  4.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.714
Validation value accuracy: 0.732

Epoch: 32


100%|██████████| 30/30 [11:45<00:00, 23.52s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0907


100%|██████████| 213/213 [00:54<00:00,  3.92it/s]


Validation expression accuracy: 0.723
Validation value accuracy: 0.746
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 33


100%|██████████| 30/30 [11:07<00:00, 22.25s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0899


 26%|██▋       | 56/213 [00:12<00:44,  3.50it/s]

Malformed expression ['/', '/', '*', '8', '30', '30']


100%|██████████| 213/213 [00:47<00:00,  4.48it/s]


Validation expression accuracy: 0.728
Validation value accuracy: 0.751
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 34


100%|██████████| 30/30 [10:00<00:00, 20.02s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0838


100%|██████████| 213/213 [00:43<00:00,  4.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 35


100%|██████████| 30/30 [09:58<00:00, 19.93s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0782


100%|██████████| 213/213 [00:43<00:00,  4.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 36


100%|██████████| 30/30 [10:07<00:00, 20.27s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0754


100%|██████████| 213/213 [00:45<00:00,  4.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.737

Epoch: 37


100%|██████████| 30/30 [10:02<00:00, 20.09s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0717


 69%|██████▊   | 146/213 [00:32<00:20,  3.22it/s]

Malformed expression ['/', '-', '+', '+', '9', '7.28', '*', '6.95', '7.28', '-', '7.28', '7.28']


100%|██████████| 213/213 [00:46<00:00,  4.55it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.737

Epoch: 38


100%|██████████| 30/30 [10:00<00:00, 20.00s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0693


100%|██████████| 213/213 [00:52<00:00,  4.09it/s]


Validation expression accuracy: 0.737
Validation value accuracy: 0.756
>>>>> Found best model so far <<<<<


  0%|          | 0/30 [00:00<?, ?it/s]


Epoch: 39


100%|██████████| 30/30 [12:06<00:00, 24.22s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.062


100%|██████████| 213/213 [00:53<00:00,  3.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.732
Validation value accuracy: 0.751

Epoch: 40


100%|██████████| 30/30 [11:24<00:00, 22.83s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0613


 26%|██▋       | 56/213 [00:14<00:54,  2.89it/s]

Malformed expression ['/', '*', '4', '-', '8', '30']


100%|██████████| 213/213 [00:53<00:00,  3.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.746

Epoch: 41


100%|██████████| 30/30 [11:53<00:00, 23.77s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0601


100%|██████████| 213/213 [00:55<00:00,  3.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 42


100%|██████████| 30/30 [11:19<00:00, 22.66s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.06


100%|██████████| 213/213 [00:49<00:00,  4.33it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.723
Validation value accuracy: 0.742

Epoch: 43


100%|██████████| 30/30 [11:02<00:00, 22.08s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0583


100%|██████████| 213/213 [00:49<00:00,  4.35it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.728
Validation value accuracy: 0.746

Epoch: 44


100%|██████████| 30/30 [11:31<00:00, 23.05s/it]
  0%|          | 0/213 [00:00<?, ?it/s]

Training loss: 0.0538


100%|██████████| 213/213 [00:50<00:00,  4.19it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Validation expression accuracy: 0.718
Validation value accuracy: 0.737

Epoch: 45


 10%|█         | 3/30 [01:06<10:27, 23.25s/it]

# Evaluation

In [None]:
# Evaluation

if EVALUATION:

    # Data
    _, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
    tensorize_data(val_data)

    # Model
    model = Model()
    model.load_state_dict(torch.load(model_save_path))
    model.to(device)

    # Evaluation
    model.eval()
    value_match, equation_match = [], []
    with torch.no_grad():
        for d in tqdm(val_data):
            if d['is_quadratic']: # This method is not equiped to handle equations with quadratics
                val_match = eq_match = False
            else:
                pred = predict(d, model)
                d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                val_match, eq_match = check_match(pred, d)
            value_match.append(val_match)
            equation_match.append(eq_match)
    print(f'Validation expression accuracy: {np.mean(equation_match):.3g}')
    print(f'Validation value accuracy: {np.mean(value_match):.3g}')

# Prediction

Once you have trained your model to a satisfactory accuracy, you can load in a checkpoint to predict on the test set. The output is a `'predictions.csv'` file in your directory, and you should submit this directly to the [Kaggle server](https://www.kaggle.com/t/7bf8b542b96f4214b0cca1e4d9b0bb17).

In [None]:
# Prediction

if PREDICTION:
    
    # Data
    test_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5, do_eval=True)
    tensorize_data(test_data)

    # Model
    model = Model()
    model.load_state_dict(torch.load(model_save_path))
    model.to(device)    

    # Prediction
    with torch.no_grad():
        for d in tqdm(test_data): # There's no quadratics in the test_data, fortunately
            pred = predict(d, model)
            d['pred_tokens'] = pred_tokens = [out_vocab.idx2token[idx] for idx in pred]
            d['subbed_tokens'] = subbed_tokens = sub_nP(pred_tokens, d['nP'])
            # Make sure to round to 3 decimals
            d['Predicted'] = round(evaluate_prefix_expression(subbed_tokens), 3)
    predictions = pd.DataFrame(test_data).set_index('Id')
    predictions[['Predicted']].replace([np.inf, -np.inf, np.nan], 0).to_csv(predictions_save_path)
    print(f"Generated predictions at {predictions_save_path}")