The following notebook demonstrates how the model solves coding questions.

In [None]:
# Importing relevant packages

from collections import Counter
from copy import copy
import itertools
import os
import re

from tqdm import tqdm, trange

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

!pip install dgl
!pip install dgl-cu100

import dgl
from dgl.nn import GraphConv

!pip install transformers
!pip install spacy

import spacy
# Load the English spaCy pipeline once; `dependency_parse` below calls `nlp`.
# NOTE(review): the 'en' shortcut name presumably only resolves on spaCy 2.x
# installs; newer releases expect a full package name such as
# 'en_core_web_sm' — confirm against the installed spaCy version.
nlp = spacy.load('en')

The data to be worked with has the following attributes:
1.   raw_question: The question as it was written.
2.   inputs: The input arguments of the function to complete.
3.   AST: An Abstract Syntax Tree that represents a possible correct program.

The following functions preprocess the data for easier future use. We use spaCy for dependency parsing.

In [None]:
def tokenize_and_separate_quants(data, n_min_vocab):
    """
    Annotate every example with its input quantities and, when a ground-truth
    expression is available, with the expression's output tokens.

    Each example dict `d` in `data` is mutated in place:
      * d['in_tokens']  - set to an empty list (question tokenization is not
                          performed by this function).
      * d['out_tokens'] - only when 'expression' is in d. Operator characters
                          are kept one per token; every number becomes either
                          a tuple of the input positions holding that value,
                          or its '%g' string when no input matches.
      * d['nP']         - numpy array with the '%g' string of every input.

    data: iterable of example dicts with an 'inputs' entry.
          Assumes the inputs are numeric — TODO confirm against the dataset.
    n_min_vocab: minimum number of occurrences (over all expressions) for an
          unmatched number to be reported as a constant.

    Returns (constants, n_max_inputs): the '%g' strings of the frequent
    constants and the largest number of inputs seen in any example.
    """
    # Raw string: '\d' in a normal literal is an invalid escape sequence.
    pattern = re.compile(r'\d+,\d+|\d*\.\d+|\d+')
    constant_counts = Counter()
    n_max_inputs = 0  # Maximum number of quantities in the input
    for d in data:
        d['in_tokens'] = []
        inputs = d['inputs']
        # Numeric view of the inputs, used to match expression quantities.
        nP = np.asarray(inputs, dtype=float)

        n_max_inputs = max(len(inputs), n_max_inputs)

        if 'expression' in d:
            expression = d['expression']  # Ground truth expression
            # Everything between two numbers is operator text.
            out_ops = pattern.split(expression)
            # The pattern accepts '1,000'; float() does not, so drop commas.
            out_quants = (float(q.replace(',', '')) for q in pattern.findall(expression))
            out_tokens = []
            for op, out_quant in itertools.zip_longest(out_ops, out_quants, fillvalue=None):
                out_tokens.extend(op)  # one token per operator character
                if out_quant is not None:  # The last out_quant is None due to zip_longest
                    equals, = np.nonzero(out_quant == nP)
                    if len(equals) == 0:
                        # Output quantity not found in the input: record it as a constant
                        constant_counts[out_quant] += 1
                        out_tokens.append(f'{out_quant:g}')
                    else:
                        out_tokens.append(tuple(equals))
            d['out_tokens'] = out_tokens

        d['nP'] = np.array([f'{x:g}' for x in nP])
    constants = [f'{n:g}' for n, count in constant_counts.items() if count >= n_min_vocab]
    return constants, n_max_inputs

def dependency_parse(data):
    """
    Run spaCy over one example's raw question and store the parse on it.

    data: a single example dict with a 'raw_question' entry. It is mutated in
          place: 'dependency_parse' receives, for every token, the index of
          its head token, and 'd_tokens' receives the token texts.

    Returns the list stored under data['dependency_parse']; it is later used
    to construct the input graph.
    """
    doc = nlp(data['raw_question'])
    data['dependency_parse'] = [token.head.i for token in doc]
    data['d_tokens'] = [token.text for token in doc]  # Not related to dep parse
    return data['dependency_parse']

# Vocabulary

The following classes deal with the Vocabulary being used.

In [None]:
class Vocabulary:
    """
    Two-way lookup between tokens and integer indices.

    idx2token: the `words` sequence itself (index -> token).
    token2idx: dict mapping each token to its position in `words`.
    pad / unk: indices of the padding and unknown tokens, or None when the
               corresponding token is not part of `words`.
    n:         number of entries in `words`.
    """
    def __init__(self, words, pad='<pad>', unk='<unk>'):
        lookup = {}
        for position, word in enumerate(words):
            lookup[word] = position

        self.idx2token = words
        self.token2idx = lookup
        self.pad = lookup.get(pad)
        self.unk = lookup.get(unk)
        self.n = len(words)

def convert_word_to_bytepair_tokenization(d, t5_tokenizer):
    """
    Re-index one example from spaCy tokens to T5 sub-word tokens, in place.

    Bytepair Tokenization is a common scheme used to reduce vocabulary size in
    large pretrained models such as BERT; this is needed for T5 encoding.

    The spaCy tokens and the T5 tokens are aligned character by character
    (through difflib when the two concatenations differ), which yields for
    every spaCy token the sorted list of T5 token positions it overlaps.

    d: example dict providing 'd_tokens', 'raw_question',
       'quant_cell_positions', 'nP_positions' and 'dependency_parse'.
       The last three are rewritten in T5-token positions and 'in_tokens' is
       replaced by the T5 tokens.
    t5_tokenizer: object exposing `tokenize(str) -> list of str`.
    """
    import difflib
    t5_space = '▁'
    d_tokens = d['d_tokens']
    question = d['raw_question']

    t_tokens = [x.replace(t5_space, '') for x in t5_tokenizer.tokenize(question)]
    t_tokens = [x for x in t_tokens if x]

    t_join = ''.join(t_tokens)
    d_join = ''.join(d_tokens)

    # d2t[i] lists the character offsets of t_join aligned with character i of d_join.
    if t_join == d_join:
        d2t = [[i] for i in range(len(d_join))]
    else:
        i_t = i_d = 0
        d2t = [None] * len(d_join)

        to_add = []  # d_join offsets inserted since the last matching character
        to_sub = []  # t_join offsets deleted since the last matching character
        for diff, _, char in difflib.ndiff(t_join, d_join):
            if diff == '+':
                to_add.append(i_d)
                i_d += 1
            elif diff == '-':
                to_sub.append(i_t)
                i_t += 1
            else:
                # A matching character closes the pending mismatch run: the
                # inserted characters are aligned with the deleted ones.
                for i_d_ in to_add:
                    d2t[i_d_] = to_sub
                to_add = []
                to_sub = []

                d2t[i_d] = [i_t]
                i_t += 1
                i_d += 1
        for i_d_ in to_add:
            d2t[i_d_] = to_sub
        assert i_t == len(t_join) and i_d == len(d_join)

    # T5 token position of every character of t_join.
    char_to_t_pos = [i for i, token in enumerate(t_tokens) for _ in token]

    d_pos_to_t_pos = []
    start = 0
    for dtok in d_tokens:
        stop = start + len(dtok)
        aligned = {char_to_t_pos[i_t] for i_ts in d2t[start:stop] for i_t in i_ts}
        d_pos_to_t_pos.append(sorted(aligned))
        start = stop

    # Convert indices
    d['quant_cell_positions'] = [x for qc_pos in d['quant_cell_positions'] for x in d_pos_to_t_pos[qc_pos]]
    d['nP_positions'] = [d_pos_to_t_pos[nP_pos][0] for nP_pos in d['nP_positions']]

    # Edits the dependency parse to deal with byte-pair tokenization:
    # the first piece of a word points at the first piece of the word's head,
    # every following piece points at the piece before it.
    old_parse = d['dependency_parse']
    new_parse = {}
    for d_pos, t_positions in enumerate(d_pos_to_t_pos):
        for idx, t_idx in enumerate(t_positions):
            if idx:
                new_parse[t_idx] = t_positions[idx - 1]
            else:
                head_t_positions = d_pos_to_t_pos[old_parse[d_pos]]
                # A head with no aligned T5 token degrades to a self-reference.
                new_parse[t_idx] = head_t_positions[0] if head_t_positions else t_idx

    # One entry per T5 token, so positions stay aligned with 'in_tokens';
    # T5 tokens that no spaCy token covers point at themselves.
    d['dependency_parse'] = [new_parse.get(idx, idx) for idx in range(len(t_tokens))]
    d['in_tokens'] = t_tokens

# Data Tensorization

The following code takes the preprocessed data and builds tensors that can be fed into PyTorch models and DGL graph layers.

In [None]:
def construct_input_graph(data):
    """
    Build the edge list of the dependency graph of one processed example.

    data: example dict; data['dependency_parse'][i] is the index of the head
          of token i.

    Returns a (src, dst) tuple of 1-D index tensors, the format accepted by
    dgl.graph(). The edges are a self-loop on every token plus one edge from
    each token to its head.
    """
    heads = data['dependency_parse']
    n = len(heads)  # one graph node per token

    adj_matrix = torch.eye(n, dtype=torch.bool)
    for i, j in enumerate(heads):
        adj_matrix[i, j] = True

    return torch.nonzero(adj_matrix, as_tuple=True)

def tensorize_data(data):
    """
    Attach the torch tensors used by the model to every example, in place.

    Reads the module-level `in_vocab`, `out_vocab` and `n_max_inputs`.

    Written for each example: 'in_idxs', 'n_in', 'n_nP', 'nP_in_mask' and
    'edges'; additionally 'out_idxs', 'n_out' and 'nP_out_mask' when the
    example carries 'out_tokens'.
    """
    for example in data:
        # Indices of the in_tokens in the in_vocab
        in_tokens = example['in_tokens']
        in_idxs = torch.tensor([in_vocab.token2idx.get(tok, in_vocab.unk) for tok in in_tokens])
        example['in_idxs'] = in_idxs

        n_in = len(in_idxs)
        example['n_in'] = n_in
        n_nP = len(example['nP'])
        example['n_nP'] = n_nP

        # True if the position in the input has a quantity
        in_mask = torch.zeros(n_in, dtype=torch.bool)
        example['nP_in_mask'] = in_mask
        in_mask[example['nP_positions']] = True

        if 'out_tokens' in example:
            # Indices of the out_tokens in the out_vocab
            out_tokens = example['out_tokens']
            out_idxs = torch.tensor([out_vocab.token2idx.get(tok, out_vocab.unk) for tok in out_tokens])
            example['out_idxs'] = out_idxs
            example['n_out'] = len(out_idxs)

            # A mask where the first n_nP elements are True
            out_mask = torch.zeros(n_max_inputs, dtype=torch.bool)
            example['nP_out_mask'] = out_mask
            out_mask[:n_nP] = True

        # Graph edges for constructing the DGL graph later
        example['edges'] = construct_input_graph(example)

# Encoder Architecture

The next cells define the architecture of the encoder. First up are the transformer modules:

In [None]:
class TransformerAttention(nn.Module):
    """
    Multi-head dot-product attention used in the Transformer block.

    Layer sizes come from the module-level n_hid, n_head, n_k and n_v.
    The forward pass returns the input plus the projected attention output.
    """
    def __init__(self):
        super().__init__()
        per_head_width = n_k * 2 + n_v
        # A single projection produces queries, keys and values for all heads.
        self.qkv = nn.Linear(n_hid, n_head * per_head_width)
        self.out = nn.Linear(n_head * n_v, n_hid)

    def forward(self, x, mask=None):
        """
        x: 3-D tensor unpacked as (n_batch, sequence length, hidden size).
        mask: optional boolean tensor; where it is False the attention logits
              are set to -inf before the softmax. It indexes the leading
              dimensions of a (n_batch, seq, n_head, seq) view of the logits.
        """
        n_batch, seq_len, _ = x.shape
        flat_batch = n_batch * n_head

        projected = self.qkv(x).view(n_batch, seq_len, n_head, 2 * n_k + n_v)
        q, k, v = projected.transpose(1, 2).split([n_k, n_k, n_v], dim=-1)

        q = q.reshape(flat_batch, seq_len, n_k)
        k_t = k.reshape_as(q).transpose(1, 2)
        logits = q.bmm(k_t) / np.sqrt(n_k)

        if mask is not None:
            # Masking writes in place through a transposed view of the logits.
            per_head = logits.view(n_batch, n_head, seq_len, seq_len).transpose(1, 2)
            per_head[~mask] = -np.inf
            logits = per_head.transpose(1, 2).view(flat_batch, seq_len, seq_len)

        weights = logits.softmax(dim=-1)
        v = v.reshape(flat_batch, seq_len, n_v)
        mixed = weights.bmm(v).view(n_batch, n_head, seq_len, n_v)
        merged = mixed.transpose(1, 2).reshape(n_batch, seq_len, n_head * n_v)
        return x + self.out(merged)

class TransformerBlock(nn.Module):
    """
    Custom Transformer layer: attention followed by a two-layer feed-forward
    network whose inner width is four times n_hid, each wrapped in a residual
    sum.
    """
    def __init__(self):
        super().__init__()
        self.attn = TransformerAttention()
        widened = n_hid * 4
        layers = [
            nn.Linear(n_hid, widened),
            nn.ReLU(inplace=True),
            nn.Linear(widened, n_hid),
        ]
        self.inner = nn.Sequential(*layers)

    def forward(self, x, mask=None):
        # self.attn already returns x + attention output, so the input is
        # added a second time by this sum.
        attended = x + self.attn(x, mask=mask)
        return attended + self.inner(attended)

The next section is for the Graph Convolutional Networks:

In [None]:
class GCNBranch(nn.Module):
    """
    One branch of the graph convolution, made of
    1. GraphConv from n_hid_in to n_hid_in
    2. ReLU
    3. Dropout
    4. GraphConv from n_hid_in to n_hid_out

    Both GraphConv layers are created with allow_zero_in_degree=True.
    """
    def __init__(self, n_hid_in, n_hid_out, dropout=0.3):
        super().__init__()
        conv_options = dict(allow_zero_in_degree=True)
        self.gc1 = GraphConv(n_hid_in, n_hid_in, **conv_options)
        activation = nn.ReLU(inplace=True)
        self.drelu = nn.Sequential(activation, nn.Dropout(dropout))
        self.gc2 = GraphConv(n_hid_in, n_hid_out, **conv_options)

    def forward(self, x, graph):
        """
        Apply the branch to node features `x` on `graph`.
        """
        hidden = self.gc1(graph, x)
        hidden = self.drelu(hidden)
        return self.gc2(graph, hidden)

class GCN(nn.Module):
    """
    A graph convolution network with multiple graph convolution branches.

    n_head: number of branches; each one maps n_hid features to
            n_hid // n_head features, and the branch outputs are concatenated
            back together. The module-level n_hid is read for all sizes.
    dropout: dropout probability used in the branches and the feed-forward.
    """
    def __init__(self, n_head=4, dropout=0.3):
        super().__init__()
        self.branches = nn.ModuleList(GCNBranch(n_hid, n_hid // n_head, dropout) for _ in range(n_head))

        self.feed_forward = nn.Sequential(
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(n_hid, n_hid)
        )
        self.layer_norm = nn.LayerNorm(n_hid)
        self.n_head = n_head

    def forward(self, h, d_graph):
        """
        h: node features whose last dimension is n_hid; flattened to
           (-1, n_hid) for the graph convolutions and restored afterwards.
        d_graph: the graph shared by every branch.
        """
        x = h.reshape(-1, n_hid)
        # Run every branch this instance owns. Sizing the loop from the
        # module-level n_head (a different value than the constructor
        # argument) could silently drop branches.
        x = torch.cat([branch(x, d_graph) for branch in self.branches], dim=-1).view_as(h)
        x = h + self.layer_norm(x)
        return x + self.feed_forward(x)

class Gate(nn.Module):
    """
    Activation gate used a few times in the TreeDecoder: the product of a
    tanh-activated linear map and a sigmoid-activated linear map of the input.
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        self.t = nn.Linear(n_in, n_out)  # tanh path
        self.s = nn.Linear(n_in, n_out)  # sigmoid path

    def forward(self, x):
        candidate = torch.tanh(self.t(x))
        gate = torch.sigmoid(self.s(x))
        return candidate * gate

# Decoder

The next batch of code is for the tree-based decoder:

In [None]:
class TreeDecoder(nn.Module):
    """
    Defines parameters and methods for decoding into an expression. Used in train and predict

    Layer sizes are read from the module-level `n_hid` and from `out_vocab`
    (`n_constants`, `n_ops`, `base_nP`).
    NOTE(review): the `Vocabulary` class visible in this file defines none of
    those three attributes; presumably `out_vocab` is a richer object built
    elsewhere — confirm.
    """
    def __init__(self, dropout=0.5):
        super().__init__()
        # One Dropout module instance is shared by every Sequential below.
        drop = nn.Dropout(dropout)
        
        # Learned embedding per constant; the leading dimension of size 1 is
        # expanded/concatenated along the batch by the callers.
        self.constant_embeddings = nn.Parameter(torch.randn(1, out_vocab.n_constants, n_hid))

        self.qp_gate = nn.Sequential(drop, Gate(n_hid, n_hid))
        self.right = nn.Sequential(drop, Gate(3*n_hid, n_hid)) # Right(q_l, G_c, t_l)
        self.hasmore = nn.Sequential(drop, Gate(3*n_hid, 2)) #HasMore (q_r', G_c, t_r')

        # Scores one (query, encoder state) pair with a single scalar.
        self.attn_fc = nn.Sequential(drop,
            nn.Linear(2 * n_hid, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1)
        )
        # Scores one (qp_Gc, quantity embedding) pair with a single scalar.
        self.quant_fc = nn.Sequential(drop,
            nn.Linear(n_hid * 3, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1, bias=False)
        )
        # One score per operator.
        self.op_fc = nn.Sequential(drop, nn.Linear(n_hid * 2, out_vocab.n_ops))

        # The extra last row (index n_ops) is the padding embedding.
        self.op_embedding = nn.Embedding(out_vocab.n_ops + 1, n_hid, padding_idx=out_vocab.n_ops)
        self.left = nn.Sequential(drop, Gate(3 * n_hid, n_hid)) #Left(q_p, G_c, y)
        # Same shape as `left` but followed by the shared `qp_gate` module
        # (its parameters are the ones registered under `qp_gate`).
        self.left_qp = nn.Sequential(drop, Gate(3 * n_hid, n_hid), self.qp_gate) #Left_qp(q_p, G_c, y)

        # NOTE(review): `train` calls `decoder.merge_subtree`, which is not
        # defined in this class as visible here; `subtree_gate` is presumably
        # meant for it — confirm.
        self.subtree_gate = nn.Sequential(drop, Gate(3 * n_hid, n_hid))
        self.predict_keyword = nn.Sequential(drop, Gate(3 * n_hid, n_hid))
    
    def attention(self, q, zbar, in_mask=None):
        """
        Corresponds roughly to the GTS-Attention function defined by the paper

        q: query vectors, broadcast along dim 1 to the shape of `zbar`.
        zbar: 3-D tensor of encoder states; assumes (n_batch, n_in, n_hid) — TODO confirm.
        in_mask: optional boolean tensor; scores where it is False are set to
                 -inf before the softmax over dim 1.

        Returns the attention-weighted sum of `zbar` over dim 1.
        """
        attn_score = self.attn_fc(
            torch.cat([q.unsqueeze(1).expand_as(zbar), zbar], dim=2)
        ).squeeze(2)
        if in_mask is not None:
            attn_score[~in_mask] = -np.inf
        attn = attn_score.softmax(dim=1)
        return (attn.unsqueeze(1) @ zbar).squeeze(1) # (n_batch, n_hid)

    def predict(self, qp_Gc, quant_embed, nP_out_mask=None):
        """
        Corresponds roughly to the GTS-Predict functions defined by the paper

        qp_Gc: 2-D tensor fed to `op_fc` (width 2 * n_hid) and repeated for
               every quantity before `quant_fc`.
        quant_embed: 3-D tensor of quantity embeddings; dim 1 enumerates the quantities.
        nP_out_mask: optional boolean tensor; where it is False, the scores in
                     the columns from `out_vocab.base_nP` onward are set to -inf.

        Returns the operator scores followed by the quantity scores,
        concatenated along dim 1.
        """
        quant_score = self.quant_fc(
            torch.cat([qp_Gc.unsqueeze(1).expand(-1, quant_embed.size(1), -1), quant_embed], dim=2)
        ).squeeze(2)
        op_score = self.op_fc(qp_Gc)
        pred_score = torch.cat((op_score, quant_score), dim=1)
        if nP_out_mask is not None:
            # Basic slicing returns a view, so this writes into pred_score.
            pred_score[:, out_vocab.base_nP:][~nP_out_mask] = -np.inf
        return pred_score


# Train Function

When training, we train on the sequence of operations needed for the construction of the tree, which was made available through preprocessing.

In [None]:
class Node:
    """
    Node for tree traversal during training

    parent:    the node above this one, or None for the root.
    is_root:   True when the node was created without a parent.
    children:  child nodes, in creation order.
    cur_right: the most recently created child, if any.
    ql, tl, func, keyword: per-node decoder state, all initialised to None.

    `up` and `op` are aliases of `parent` and `func`: the training and
    prediction code in this file reads and writes the node through those
    names as well.
    """
    def __init__(self, parent):
        self.parent = parent
        self.is_root = (parent is None)
        self.children = []
        self.cur_right = None
        self.ql = None
        self.tl = None
        self.func = None
        self.keyword = None

    @property
    def up(self):
        """Alias of `parent`."""
        return self.parent

    @up.setter
    def up(self, value):
        self.parent = value

    @property
    def op(self):
        """Alias of `func`."""
        return self.func

    @op.setter
    def op(self, value):
        self.func = value

def train(batch, model, opt, step=None):
    """
    Encode one batch and run the teacher-forced tree decoder over its labels,
    collecting the prediction scores of every decoding step.

    batch: list of tensorized examples (see tensorize_data) with 'out_idxs'.
    model: object exposing `encode(...)` and a `decoder` (TreeDecoder).
    opt:   the optimizer.
    step:  optional global step counter; accepted because the training loop
           passes one, not used here.

    NOTE(review): this function is an unfinished draft. It never turns
    `scores` (nor hm_out / key_out / right_out) into a loss, never calls
    `opt`, never updates `qp`, and returns None, while the training loop
    expects a loss value. `decoder.merge_subtree` is also not defined on the
    TreeDecoder visible in this file. Those parts need the author's intent
    and are left as they were.
    """
    n_batch = len(batch)
    n_in = [d['n_in'] for d in batch]
    pad = lambda x, value: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=value)
    in_idxs = pad([d['in_idxs'] for d in batch], in_vocab.pad).to(device)
    in_mask = pad([torch.ones(n, dtype=torch.bool) for n in n_in], False).to(device)
    # Masks produced by tensorize_data: positions of the input quantities in
    # the question, and the slots they occupy among the n_max_inputs outputs.
    func_in_mask = pad([d['nP_in_mask'] for d in batch], False).to(device)
    func_out_mask = torch.stack([d['nP_out_mask'] for d in batch]).to(device)

    graph = dgl.batch([dgl.graph(d['edges'], num_nodes=in_idxs.size(1), device=device) for d in batch])

    zbar, qroot = model.encode(in_idxs, n_in, graph, in_mask=None)
    decoder = model.decoder

    qp = decoder.qp_gate(qroot)

    label = pad([d['out_idxs'] for d in batch], out_vocab.pad)

    z_func = zbar.new_zeros((n_batch, n_max_inputs, n_hid))
    # Takes the embeddings of the target words and places them in a z_func matrix
    z_func[func_out_mask] = zbar[func_in_mask]

    n_quant = out_vocab.n_constants + n_max_inputs
    # Embeddings of all the input arguments + constants
    quant_embed = torch.cat([decoder.constant_embeddings.expand(n_batch, -1, -1), z_func], dim=1) # (n_batch, n_quant, n_hid)

    nodes = np.array([Node(None) for _ in range(n_batch)])
    func_min, func_max = out_vocab.base_op, out_vocab.base_op + out_vocab.n_ops
    quant_min, quant_max = out_vocab.base_quant, out_vocab.base_quant + n_quant

    scores = []
    # Integer dtype: these values index the operator embedding below.
    prev_label = torch.zeros(n_batch, dtype=label.dtype)
    for label_i in label.T:
        Gc = decoder.attention(qp, zbar, in_mask)
        qp_Gc = torch.cat([qp, Gc], dim=1)

        # Fix so that this is constructed later
        score = decoder.predict(qp_Gc, quant_embed, func_out_mask)
        scores.append(score)

        # Determine which function to send the input through next.
        # Element-wise & / | are required: Python `and` / `or` raise on
        # tensors holding more than one value.
        to_send_left = (func_min <= prev_label) & (prev_label <= func_max)
        to_send_hm_up = (label_i == 0)
        to_send_hm = ((quant_min <= prev_label) & (prev_label <= quant_max)) | to_send_hm_up
        to_send_key = ((func_min <= label_i) & (label_i <= func_max)) | ((quant_min <= label_i) & (label_i <= quant_max))
        to_send_right = (label_i == 1)

        left_idx = to_send_left.nonzero(as_tuple=True)[0]
        hm_up_idx = to_send_hm_up.nonzero(as_tuple=True)[0]
        hm_idx = to_send_hm.nonzero(as_tuple=True)[0]
        key_idx = to_send_key.nonzero(as_tuple=True)[0]
        right_idx = to_send_right.nonzero(as_tuple=True)[0]

        # Deal with Left() case first
        func_embed = decoder.op_embedding((prev_label[to_send_left] - out_vocab.base_op).to(device))
        qp_Gc_func = torch.cat([qp_Gc[to_send_left], func_embed], dim=1)
        qleft = decoder.left(qp_Gc_func)

        for j, ql, func in zip(left_idx, qleft, func_embed):
            node = nodes[j]
            nodes[j] = Node(node)
            node.children.append(nodes[j])
            node.cur_right = nodes[j]
            node.func = func
            node.ql = ql

        # Next, with HasMore()
        for j in hm_up_idx:
            pnode = nodes[j].parent
            pnode.tl = decoder.merge_subtree(pnode.func, pnode.tl, nodes[j].ql) # Check !!
            nodes[j] = pnode

        # Only the selected rows of Gc are concatenated with the per-node state.
        qr_prime = torch.stack([nodes[j].ql for j in hm_idx])
        tr_prime = torch.stack([nodes[j].tl for j in hm_idx])
        qr_Gc_tr_hm = torch.cat([qr_prime, Gc[to_send_hm], tr_prime], dim=1)
        hm_out = decoder.hasmore(qr_Gc_tr_hm)

        # Next, with PredictKeyword()
        q = torch.stack([nodes[j].ql for j in key_idx])
        y_hat = torch.stack([nodes[j].func for j in key_idx])
        q_Gc_yhat = torch.cat([q, Gc[to_send_key], y_hat], dim=1)
        key_out = decoder.predict_keyword(q_Gc_yhat)

        # Finally, with Right()
        qr_prime = torch.stack([nodes[j].ql for j in right_idx])
        tr_prime = torch.stack([nodes[j].tl for j in right_idx])
        qr_Gc_tr_right = torch.cat([qr_prime, Gc[to_send_right], tr_prime], dim=1)
        right_out = decoder.right(qr_Gc_tr_right)

        for j, qr in zip(right_idx, right_out):
            node = nodes[j]
            nodes[j] = Node(node)
            node.cur_right = nodes[j]
            node.children.append(nodes[j])
            # NOTE(review): mirrors the Left() case, which stores the new
            # child's goal vector on the parent — confirm this is intended.
            node.ql = qr

        prev_label = label_i

# Predict Function

The next function is used for prediction.

In [None]:
class BeamNode(Node):
    """
    Node for beam search during evaluation

    up:    parent node in the expression tree (forwarded to Node).
    prev:  the beam node that precedes this one in the decoding sequence.
    qp:    the vector stored for this node by the decoder.
    token: the output token attached to this node, if any.
    """
    def __init__(self, up, prev, qp, token=None):
        super().__init__(up)
        self.prev, self.qp, self.token = prev, qp, token

    def trace_tokens(self, *last_token):
        """
        Return the tokens along the `prev` chain ending at this node, oldest
        first, followed by any `last_token` values. The first node of the
        chain (the one without `prev`) contributes no token of its own.
        """
        if self.prev is None:
            return list(last_token)
        return [*self.prev.trace_tokens(), self.token, *last_token]

def predict(d, model, beam_size=5, n_max_out=45, mode='max_likelihood'):
    """
    Predict the idxs corresponding to an expression given the inputs. Leverages beam search to maximize
    prediction probability

    d: The piece of data to predict
    model: The trained model
    beam_size: Size associated with beam search
    n_max_out: Cap on the number of nodes in expression tree construction
    mode: either 'max_likelihood' or 'sample'

    Raises ValueError for any other `mode`.

    NOTE(review): this function is an unfinished draft. The beam loop never
    fills `new_beams`, never replaces `beams`, never records a finished beam
    and nothing is returned, while the callers iterate over the result. The
    'sample' mode is a stub. Only the syntax error and the undefined names
    have been repaired here.
    """
    in_idxs = d['in_idxs'].unsqueeze(0).to(device=device)
    graph = dgl.graph(d['edges'], num_nodes=in_idxs.size(1), device=device)
    zbar, qroot = model.encode(in_idxs, [d['n_in']], graph, in_mask=None)
    # NOTE(review): tensorize_data in this file writes 'nP_positions', not
    # 'func_positions' — confirm where this key is produced.
    z_func = zbar[:, d['func_positions']]

    decoder = model.decoder
    # The constant embeddings are a parameter of the decoder.
    quant_embed = torch.cat([decoder.constant_embeddings, z_func], dim=1)
    n_quants = quant_embed.size(1)  # constants + input quantities

    func_min, func_max = out_vocab.base_op, out_vocab.base_op + out_vocab.n_ops
    quant_min, quant_max = out_vocab.base_quant, out_vocab.base_quant + n_quants

    if mode == 'max_likelihood':
        best_done_beam = (-np.inf, None)
        beams = [(0, BeamNode(up=None, prev=None, qp=decoder.qp_gate(qroot)))]
        for _ in range(n_max_out):
            new_beams = []
            for logp_prev, node in beams:
                Gc = decoder.attention(node.qp, zbar)
                qp_Gc = torch.cat([node.qp, Gc], dim=1)
                log_prob = decoder.predict(qp_Gc, quant_embed).log_softmax(dim=1)
                top_logps, top_tokens = log_prob.topk(beam_size, dim=1)
                for logp_token_, out_token_ in zip(top_logps.unbind(dim=1), top_tokens.unbind(dim=1)):
                    out_token = out_token_.item()
                    logp = logp_prev + logp_token_.item()
                    if quant_min <= out_token <= quant_max:
                        # Climb the tree until HasMore() asks for another child.
                        construct = False
                        while not construct:
                            qr_Gc_tr = torch.cat([node.qp, Gc, node.tl], dim=1)
                            hm = decoder.hasmore(qr_Gc_tr)
                            if hm.argmax(dim=1) == 1:
                                construct = True
                            else:
                                node = node.parent
                    elif func_min <= out_token <= func_max:
                        # nn.Embedding needs an index tensor; the offset
                        # mirrors the operator lookup done in train().
                        # NOTE(review): confirm out_token is in label space.
                        op_idx = torch.tensor([out_token - out_vocab.base_op], device=device)
                        func_embed = decoder.op_embedding(op_idx)
                        qp_Gc_func = torch.cat([qp_Gc, func_embed], dim=1)
                        prev_node = copy(node)
                        next_node = prev_node.left = BeamNode(
                            up=prev_node, prev=prev_node,
                            qp=decoder.left_qp(qp_Gc_func),
                            token=out_token
                        )
                        prev_node.func = func_embed
                        prev_node.ql = decoder.left(qp_Gc_func)

    elif mode == 'sample':
        pass
    else:
        raise ValueError("Mode can only be max_likelihood or sample")

# Training Loop

The next set of code is the actual training loop; running this cell will train the model.

In [None]:
# Training script: configures the hyperparameters as module-level names (the
# model classes and train/predict read them as globals), builds the data and
# the model, then trains for n_epochs, validating and checkpointing every
# 10 epochs.
use_t5 = 'small' # Value should be None, 'small', or 'base'
model_save_dir = f'models/{use_t5 or "custom"}'
os.makedirs(model_save_dir, exist_ok=True)

# IMPORTANT NOTE: if you change some of these hyperparameters during training,
# you will also need to change them during prediction (see next section)
n_max_in = 100
n_epochs = 100
n_batch = 64
learning_rate = 1e-3
if use_t5:
    # T5 hyperparameters
    freeze_layers = []
    weight_decay = 1e-5
    n_hid = dict(small=512, base=768)[use_t5] # Do not modify unless you want to try t5-large
    n_k = n_v = 64
    n_head = 8
else:
    # Custom transformer hyperparameters
    n_layers = 3
    n_hid = 512
    n_k = n_v = 64
    n_head = 8
    weight_decay = 0
# Fall back to the CPU instead of failing on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
# tensorize_data and train read this value under the name n_max_inputs.
# NOTE(review): assumes both names denote the same quantity — confirm.
n_max_inputs = n_max_nP
tensorize_data(itertools.chain(train_data, val_data))

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
model.to(device)

epoch = 0
i = 0  # global step counter, incremented once per batch
while epoch < n_epochs:
    print('Epoch:', epoch + 1)
    model.train()
    losses = []
    for start in trange(0, len(train_data), n_batch):
        # Longest question first within the batch
        batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
        loss = train(batch, model, opt, i)
        losses.append(loss)
        i += 1
    scheduler.step()

    print(f'Training loss: {np.mean(losses):.3g}')

    epoch += 1
    if epoch % 10 == 0:
        model.eval()
        value_match, equation_match = [], []
        with torch.no_grad():
            for d in tqdm(val_data):
                pred = predict(d, model)
                d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                val_match, eq_match = check_match(pred, d)
                value_match.append(val_match)
                equation_match.append(eq_match)
        print(f'Validation expression accuracy: {np.mean(equation_match):.3g}')
        print(f'Validation value accuracy: {np.mean(value_match):.3g}')
        # We save the model every 10 epochs, feel free to load in a trained model with
        # model.load_state_dict(torch.load(os.path.join(model_save_dir, f'model-{epoch}.pth')))
        # Note: if you want to restart training from a saved model, you must also save and load the optimizer with
        # torch.save(opt.state_dict(), os.path.join(model_save_dir, f'opt-{epoch}.pth'))
        torch.save(model.state_dict(), os.path.join(model_save_dir, f'model-{epoch}.pth'))
    print()

# Predictions

The next batch of code allows multiple predictions.

In [None]:
# Prediction script: rebuilds the model with the training hyperparameters,
# loads the checkpoint of `eval_epoch`, predicts an expression for every test
# example and collects the results in a DataFrame indexed by 'Id'.
use_t5 = 'small'
eval_epoch = 30
device = 'cpu'

# Make sure your parameter here is the exact same as the parameters you trained with,
# else the model will not load correctly
n_max_in = 100
if use_t5:
    # T5 hyperparameters
    freeze_layers = []
    n_hid = dict(small=512, base=768)[use_t5] # Do not modify unless you want to try t5-large
    n_k = n_v = 64
    n_head = 8
else:
    # Custom transformer hyperparameters
    n_layers = 3
    n_hid = 512
    n_k = n_v = 64
    n_head = 8

test_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5, do_eval=True)
# tensorize_data reads this value under the name n_max_inputs.
# NOTE(review): assumes both names denote the same quantity — confirm.
n_max_inputs = n_max_nP
model = Model()
# map_location lets a checkpoint saved from a GPU run load on this device.
model.load_state_dict(
    torch.load(f'models/{use_t5 or "custom"}/model-{eval_epoch}.pth', map_location=device)
)
model.to(device)
# Inference mode: without this the dropout layers stay active.
model.eval()
tensorize_data(test_data)

with torch.no_grad():
    for d in tqdm(test_data): # There's no quadratics in the test_data, fortunately
        pred = predict(d, model)
        d['pred_tokens'] = pred_tokens = [out_vocab.idx2token[idx] for idx in pred]
        d['subbed_tokens'] = subbed_tokens = sub_nP(pred_tokens, d['nP'])
        d['Predicted'] = round(evaluate_prefix_expression(subbed_tokens), 3) # Make sure to round to 3 decimals

import pandas as pd
predictions = pd.DataFrame(test_data).set_index('Id')