In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import BertTokenizer, BertConfig, BertModel
import os
import math
import time
from lego_data import make_lego_datasets

def seed_everything(seed: int):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic execution.

    Args:
        seed: integer seed applied to all generators.
    """
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The model is wrapped in nn.DataParallel further down, so seed every
    # visible GPU rather than only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and possibly select different
    # kernels from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


seed_everything(0)

In [None]:
# --- experiment configuration -------------------------------------------

# total number of variables in a chain
n_var = 12
# number of variables that receive supervision during training
n_train_var = 6

# dataset sizes (number of sequences), both scaled by the chain length
n_train = n_var * 10000
n_test = n_var * 1000

# keep this >= 500: smaller batches may make training unstable
batch_size = 1000

# True  -> start from pretrained BERT weights
# False -> train a randomly initialised BERT of the same architecture
use_pretrained_transformer = False

In [None]:
# specify tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# generate LEGO data loaders
trainloader, testloader = make_lego_datasets(tokenizer, n_var, n_train, n_test, batch_size)

# examine an example LEGO sequence: the decoded text, then its labels
seq, label, _ = trainloader.dataset[0]
print(tokenizer.decode(seq))
# tolist() yields plain Python numbers; list(label.numpy()) would print
# NumPy scalar reprs such as np.int64(1) under NumPy >= 2.
print(label.tolist())

In [None]:
# a wrapper on transformer model for token classification
class Model(nn.Module):
    """Token-classification wrapper: a transformer encoder plus a linear head.

    Args:
        base: encoder module; its output must expose ``last_hidden_state``.
        d_model: feature size of ``last_hidden_state`` (input width of the head).
        tgt_vocab: number of logits produced per token (default 1).
    """

    def __init__(self, base, d_model, tgt_vocab=1):
        super().__init__()
        self.base = base
        self.classifier = nn.Linear(d_model, tgt_vocab)

    def forward(self, x, mask=None):
        """Return per-token logits for the token ids ``x``.

        Args:
            x: input passed to ``base`` as its first positional argument.
            mask: optional attention mask. When given it is forwarded to
                ``base`` as ``attention_mask``; it used to be silently ignored.

        Returns:
            The classifier applied to ``base(...).last_hidden_state``.
        """
        # Call base exactly as before when no mask is supplied, so encoders
        # that do not accept an ``attention_mask`` keyword keep working.
        if mask is None:
            h = self.base(x)
        else:
            h = self.base(x, attention_mask=mask)
        return self.classifier(h.last_hidden_state)

    
# Pick the encoder: same BERT architecture either way, only the weights differ.
if not use_pretrained_transformer:
    # randomly initialised weights, architecture taken from the checkpoint config
    config = BertConfig.from_pretrained("bert-base-uncased")
    base = BertModel(config)
else:
    base = BertModel.from_pretrained("bert-base-uncased")

# attach the per-token linear head
model = Model(base, d_model=base.config.hidden_size)

# move to GPU, then wrap for data parallel training
model.cuda()
model = nn.DataParallel(model)

In [None]:
# indices of the variables that contribute to the training loss
train_var_pred = list(range(n_train_var))
# indices of the variables evaluated at test time (all of them)
test_var_pred = list(range(n_var))

def train(print_acc=False):
    """Run one training epoch over ``trainloader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``trainloader``, ``n_var`` and ``train_var_pred``.

    Args:
        print_acc: when True, also print the accuracy of each supervised
            variable after the loss.

    Returns:
        A list of length ``n_var``. Entry ``i`` is the accuracy of variable
        ``i`` averaged over batches; variables not in ``train_var_pred``
        stay at 0.
    """
    total_loss = 0.0
    correct = [0.0] * n_var
    total = 0
    model.train()
    for batch, labels, order in trainloader:
        x = batch.cuda()
        y = labels.cuda()
        inv_order = order.permute(0, 2, 1).cuda()

        optimizer.zero_grad()
        pred = model(x)
        # Keep every 5th token position starting at index 1 and stopping
        # before the last 3 (presumably one position per clause -- confirm
        # against lego_data), then reorder them with the transposed `order`.
        # squeeze(-1) removes only the trailing logit axis; a bare squeeze()
        # would also drop the batch axis for a batch of one sample.
        ordered_pred = torch.bmm(inv_order, pred[:, 1:-3:5, :]).squeeze(-1)

        loss = 0
        for idx in train_var_pred:
            loss += criterion(ordered_pred[:, idx], y[:, idx].float()) / len(train_var_pred)
            correct[idx] += ((ordered_pred[:, idx] > 0).long() == y[:, idx]).float().mean().item()

        loss.backward()
        optimizer.step()

        # Record the finished batch loss once. Accumulating inside the
        # per-variable loop added every partial sum and divided by the
        # variable count a second time, distorting the reported value.
        total_loss += loss.item()
        total += 1

    # avoid ZeroDivisionError when the loader yields no batches
    n_batches = max(total, 1)
    train_acc = [corr / n_batches for corr in correct]
    print("   Train Loss: %f" % (total_loss / n_batches))
    if print_acc:
        for idx in train_var_pred:
            print("     %s: %f" % (idx, train_acc[idx]))

    return train_acc


def test():
    """Evaluate the model on ``testloader`` and print per-variable accuracy.

    Uses the module-level ``model``, ``criterion``, ``testloader``,
    ``n_var`` and ``test_var_pred``. Runs in eval mode with gradients disabled.

    Returns:
        A list of length ``n_var``; entry ``i`` is the accuracy of variable
        ``i`` averaged over batches.
    """
    total_loss = 0.0
    correct = [0.0] * n_var
    total = 0
    model.eval()
    with torch.no_grad():
        for batch, labels, order in testloader:
            x = batch.cuda()
            y = labels.cuda()
            inv_order = order.permute(0, 2, 1).cuda()
            pred = model(x)
            # squeeze(-1) removes only the trailing logit axis; a bare
            # squeeze() would also drop the batch axis for a batch of one
            # sample and break the [:, idx] indexing below.
            ordered_pred = torch.bmm(inv_order, pred[:, 1:-3:5, :]).squeeze(-1)

            for idx in test_var_pred:
                loss = criterion(ordered_pred[:, idx], y[:, idx].float())
                total_loss += loss.item() / len(test_var_pred)
                correct[idx] += ((ordered_pred[:, idx] > 0).long() == y[:, idx]).float().mean().item()

            total += 1

    # avoid ZeroDivisionError when the loader yields no batches
    n_batches = max(total, 1)
    test_acc = [corr / n_batches for corr in correct]
    print("   Test  Loss: %f" % (total_loss / n_batches))
    for idx in test_var_pred:
        print("     %s: %f" % (idx, test_acc[idx]))

    return test_acc

In [None]:
# binary cross-entropy computed directly on the raw logits
criterion = nn.BCEWithLogitsLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# One constant drives both the cosine schedule length and the epoch loop,
# so the two cannot drift apart when the run length is changed.
n_epochs = 200
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)


# test acc is evaluated for each variable, printed in order along the chain
for epoch in range(n_epochs):
    start = time.time()
    print('Epoch %d, lr %f' % (epoch, optimizer.param_groups[0]['lr']))

    train()
    test()
    # advance the learning-rate schedule once per epoch
    scheduler.step()

    print('Time elapsed: %f s' %(time.time() - start))