Always run the following 3 cells first.

In [None]:
from datasets.sequence_generator import generate_examples

import argparse

import torch.nn as nn
import torch

import harvard_transformer as tr
from torch.autograd import Variable


import numpy as np

In [None]:
##### DATA PARAMETERS #####
# dataset_file is not referenced by the cells shown in this notebook.
dataset_file = "ten_tokens_explicit_singular_data.txt"
# Vocabulary file, read one token per line when the vocabulary is loaded.
vocabulary_file = 'datasets/ten_tokens_explicit.txt'
# Both settings are forwarded unchanged to generate_examples().
operation_type = 'singular'
transition_type = 'explicit'

##### MODEL PARAMETERS #####
num_layers = 2
num_heads = 2
hidden_size = 16 # The hidden size MUST be divisible by the number of heads and even

# Fail fast on an invalid configuration instead of hitting an obscure shape
# error deep inside the transformer code later on.
if num_heads <= 0 or hidden_size % num_heads != 0 or hidden_size % 2 != 0:
    raise ValueError(
        'hidden_size ({}) must be even and divisible by num_heads ({})'.format(
            hidden_size, num_heads))

# Checkpoint file name; it encodes the architecture so that differently sized
# models do not overwrite each other.
model_save = 'transformer_hid_{}_heads_{}_lyrs_{}.mdl'.format(
    hidden_size, num_heads, num_layers)

##### TRAINING PARAMETERS #####
batch_size = 24       # examples per batch
num_batches = 100     # batches generated per training epoch
num_epochs = 100
print_frequency = 10  # evaluate and save the model every this many epochs

In [None]:
def data_generator(vocabulary, batch_size, num_batches, Ks=(2, 4, 5, 7), max_len=20):
    """Yield ``num_batches`` batches of freshly generated examples.

    For every batch, ``batch_size`` values of ``k`` are drawn from ``Ks`` with
    ``np.random.choice``. Each example comes from ``generate_examples`` (using
    the module-level ``transition_type`` and ``operation_type``), which is
    expected to return a line of the form ``"tok tok ... tok;target"``.

    Args:
        vocabulary: List of token strings. A token's id is its index in this
            list; ``len(vocabulary)`` is used as the padding id.
        batch_size: Number of examples per batch.
        num_batches: Number of batches to yield.
        Ks: Candidate values of ``k`` passed to ``generate_examples``.
        max_len: Minimum width the source sequences are padded to.

    Yields:
        ``tr.Batch(srcs, tgts, pad)`` where ``srcs`` is an int64 tensor of
        shape ``(batch_size, width)`` and ``tgts`` is an int64 tensor of shape
        ``(batch_size, 2)`` holding ``[pad, target_id]`` per row. ``width`` is
        ``max_len`` unless a generated sequence is longer, in which case the
        batch is padded to its longest sequence.

    Raises:
        ValueError: If a generated token is not in ``vocabulary`` or the
            generated line does not contain exactly one ``';'``.
    """
    pad = len(vocabulary)
    for _ in range(num_batches):
        selected_ks = np.random.choice(Ks, batch_size)

        srcs = []
        tgts = []
        for lk in selected_ks:
            output = generate_examples(transition_type=transition_type,
                                       operation_type=operation_type,
                                       vocabulary=vocabulary,
                                       k=lk,
                                       num_examples=1).strip()

            training_line, target = output.split(';')
            srcs.append([vocabulary.index(token)
                         for token in training_line.split(' ')])
            # The pad id sits in the first target slot; presumably it doubles
            # as the decoder start symbol (the test cell passes it as such).
            tgts.append([pad, vocabulary.index(target)])

        # Pad every row to one common width. A fixed width would leave rows
        # longer than max_len unpadded and make the batch ragged, which cannot
        # be converted to a tensor.
        width = max([max_len] + [len(row) for row in srcs])
        srcs = [row + [pad] * (width - len(row)) for row in srcs]

        # Explicit int64: NumPy's default integer width is platform dependent.
        srcs = torch.from_numpy(np.array(srcs, dtype=np.int64))
        tgts = torch.from_numpy(np.array(tgts, dtype=np.int64))

        srcs = Variable(srcs, requires_grad=False)
        tgts = Variable(tgts, requires_grad=False)

        yield tr.Batch(srcs, tgts, pad)

def data_sample(vocabulary, Ks=(2, 4, 5, 7), max_len=20):
    """Generate a single random example as a pair of 1-D index tensors.

    One ``k`` is drawn from ``Ks`` with ``np.random.choice`` and passed to
    ``generate_examples`` (together with the module-level ``transition_type``
    and ``operation_type``), which is expected to return a line of the form
    ``"tok tok ... tok;target"``.

    Args:
        vocabulary: List of token strings. A token's id is its index in this
            list; ``len(vocabulary)`` is used as the padding id.
        Ks: Candidate values of ``k`` passed to ``generate_examples``.
        max_len: Length the source sequence is padded to. A longer sequence is
            returned unpadded rather than truncated.

    Returns:
        ``(src, tgt)``: ``src`` is an int64 tensor holding the token ids
        followed by padding, ``tgt`` is the int64 tensor ``[pad, target_id]``.

    Raises:
        ValueError: If a generated token is not in ``vocabulary`` or the
            generated line does not contain exactly one ``';'``.
    """
    pad = len(vocabulary)
    output = generate_examples(transition_type=transition_type,
                               operation_type=operation_type,
                               vocabulary=vocabulary,
                               k=np.random.choice(Ks),
                               num_examples=1).strip()

    training_line, target = output.split(';')
    training_sequence = training_line.split(' ')

    src = [vocabulary.index(token) for token in training_sequence]
    # A negative repeat count yields an empty list, so overlong sequences are
    # simply left as they are.
    src.extend([pad] * (max_len - len(training_sequence)))

    tgt = [pad, vocabulary.index(target)]

    # Explicit int64: NumPy's default integer width is platform dependent.
    src = Variable(torch.from_numpy(np.array(src, dtype=np.int64)), requires_grad=False)
    tgt = Variable(torch.from_numpy(np.array(tgt, dtype=np.int64)), requires_grad=False)
    return src, tgt

# Read the vocabulary: one token per line, surrounding whitespace removed.
# A token's position in this list is its integer id.
with open(vocabulary_file) as vocab_fh:
    vocabulary = [entry.strip() for entry in vocab_fh]

The following cell only loads a model. If the parameters specified previously describe the model you wish to load, run the following cell to load it into memory.

In [None]:
##### LOAD MODEL #####
# Rebuild the architecture from the parameters set above; the vocabulary size
# is len(vocabulary) + 1 because one extra id is reserved for padding.
model = tr.make_transformer(src_vocab=len(vocabulary)+1,
                            tgt_vocab=len(vocabulary)+1,
                            N=num_layers,
                            d_model=hidden_size,
                            d_ff=4*hidden_size,
                            h=num_heads,
                            dropout=0.0)
# map_location='cpu' lets a checkpoint that was written on a GPU machine be
# restored on a CPU-only one; load_state_dict then copies the values into the
# model's own parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(model_save, map_location='cpu'))

The following cell trains a new model using the parameters set above. The model will be saved every `print_frequency` epochs under the name set in `model_save`.

In [None]:
##### TRAIN MODEL #####
# Build a fresh transformer; one id beyond the vocabulary is reserved for padding.
model = tr.make_transformer(
    src_vocab=len(vocabulary) + 1,
    tgt_vocab=len(vocabulary) + 1,
    N=num_layers,
    d_model=hidden_size,
    d_ff=4 * hidden_size,
    h=num_heads,
    dropout=0.0,
)

criterion = nn.CrossEntropyLoss()

# Adam starts at lr=0; the NoamOpt wrapper is what drives the learning rate
# (its arguments 1 and 400 are presumably the scale factor and warm-up steps --
# confirm against harvard_transformer).
adam_optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = tr.NoamOpt(model.src_embed[0].d_model, 1, 400, adam_optimizer)

print('Training Transformer')
for epoch in range(num_epochs):
    # One training epoch over num_batches freshly generated batches.
    model.train()
    train_batches = data_generator(vocabulary, batch_size, num_batches)
    train_loss = tr.SimpleLossCompute(model.generator, criterion, model_opt)
    tr.run_epoch(train_batches, model, train_loss)

    if (epoch + 1) % print_frequency != 0:
        continue

    # Periodic check on a single fresh batch (no optimizer is handed to the
    # loss computation), followed by a checkpoint of the current weights.
    print('Epoch::{}'.format(epoch + 1))
    model.eval()
    eval_batches = data_generator(vocabulary, batch_size, 1)
    eval_loss = tr.SimpleLossCompute(model.generator, criterion, None)
    print(tr.run_epoch(eval_batches, model, eval_loss))
    torch.save(model.state_dict(), model_save)

The following cell tests the model. Be sure to have either a model trained or loaded using one of the previous cells. 

In [None]:
##### TEST MODEL #####
# Measure exact-match accuracy separately for every k from 2 to 10, using
# freshly sampled examples for each k.
num_test_examples = 2000
pad_id = len(vocabulary)  # padding id, also used as the decoder start symbol

model.eval()
# Inference only: disabling autograd avoids building graphs that are never used.
with torch.no_grad():
    for k in range(2, 11):
        correct = 0
        for _ in range(num_test_examples):
            src, tgt = data_sample(vocabulary, Ks=[k])
            src = src.reshape(1, -1)  # add a batch dimension of size 1
            true = tgt[1].item()      # tgt is [pad, target_id]

            # Mask out padding positions in the source.
            src_mask = (src != pad_id).unsqueeze(-2)
            out = tr.greedy_decode(model, src, src_mask, max_len=2, start_symbol=pad_id)

            # Position 0 of the decoded sequence is the start symbol; position 1
            # is the predicted token.
            pred = out[0][1].item()

            if pred == true:
                correct += 1
        print('Dataset L'+str(k)+' Accuracy: '+str(round(correct/num_test_examples*100, 2)) + '%')
