In [None]:
# Adapted from Kyubyong's "POS tagging with BERT fine-tuning" notebook:
# https://github.com/Kyubyong/nlp_made_easy/blob/master/Pos-tagging%20with%20Bert%20Fine-tuning.ipynb

In [None]:
import random

import numpy as np
import torch

# Seed every RNG this notebook touches so a fresh Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
!pip install transformers

In [None]:
import os
from string import punctuation
from copy import deepcopy
import pickle
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
from tqdm.notebook import tqdm
from IPython.display import FileLink

from sklearn.metrics import mean_absolute_error

import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel

In [None]:
train_data = pd.read_csv('training_data/training_data.csv')
# Strip the end-of-sentence marker from words. Sentence boundaries are recovered
# from word_id resets instead. regex=False makes this an explicit literal match.
train_data['word'] = train_data['word'].str.replace('<EOS>', '', regex=False)

In [None]:
# Cased XLNet SentencePiece tokenizer (casing is kept: do_lower_case=False).
# NOTE(review): XLNet's own special tokens are '<cls>'/'<sep>'. EyeTrackDataset feeds
# BERT-style '[CLS]'/'[SEP]' strings through convert_tokens_to_ids, which likely map
# to the unknown-token id here. Confirm this is intended.
tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=False)


In [None]:
# Group the word-level rows of train_data into sentences. word_id restarts at each
# new sentence, so a sentence begins at row 0 and at every row whose word_id is
# lower than the previous row's. For each sentence we keep:
#   X - the words
#   Y - the five eye-tracking targets per word
#   W - word lengths, ignoring trailing punctuation
#   I - the (start, stop) row range in train_data
X, Y, W, I = [], [], [], []

all_words = train_data['word'].tolist()
all_word_ids = train_data['word_id'].tolist()
all_targets = train_data[['nFix', 'FFD', 'GPT', 'TRT', 'fixProp']].values.tolist()

sent_starts = [pos for pos in range(len(all_words))
               if pos == 0 or all_word_ids[pos] < all_word_ids[pos - 1]]
sent_stops = sent_starts[1:] + [len(all_words)]

for start, stop in tqdm(list(zip(sent_starts, sent_stops))):
    sentence = all_words[start:stop]
    X.append(sentence)
    Y.append(all_targets[start:stop])
    W.append([len(word.rstrip(punctuation)) for word in sentence])
    I.append((start, stop))

In [None]:
# Cross-validation split over sentences.
# - Folds 1-3 each hold out a contiguous range of sentence indices for validation.
# - 'ALL' trains on every sentence (for test predictions). It keeps fold 3's range
#   only as a monitoring set, which then overlaps the training data.
FOLD = 1  # 1, 2, 3 or 'ALL'

# FOLD -> (first, last-exclusive) validation sentence index; None means "to the end".
FOLD_VAL_RANGES = {1: (400, 533), 2: (533, 667), 3: (667, None), 'ALL': (667, None)}
if FOLD not in FOLD_VAL_RANGES:
    raise ValueError(f"FOLD must be one of {list(FOLD_VAL_RANGES)}, got {FOLD!r}")
val_lo, val_hi = FOLD_VAL_RANGES[FOLD]


def _fold_train_part(seq):
    """Return the training portion of a per-sentence list for the current FOLD."""
    if FOLD == 'ALL':
        return seq
    tail = seq[val_hi:] if val_hi is not None else []
    return seq[:val_lo] + tail


X_train, Y_train, W_train, I_train = (_fold_train_part(s) for s in (X, Y, W, I))
X_val, Y_val, W_val, I_val = (s[val_lo:val_hi] for s in (X, Y, W, I))

In [None]:
class EyeTrackDataset(data.Dataset):
    """Sentence-level dataset of words, eye-tracking targets and word lengths.

    Each input item is (words, targets, word_lengths, idxes), where targets holds
    one 5-value list per word. Every sentence is wrapped in '[CLS]'/'[SEP]'
    sentinels that get an all-zero target and a word length of 0.
    NOTE(review): these are BERT-style sentinel strings; confirm how the XLNet
    tokenizer maps them to ids.
    """

    def __init__(self, sents_targets_lengths_idxes):
        self.sents = []
        self.target_vars = []
        self.word_lengths = []
        self.data_idxes = []
        for words, targets, lengths, span in sents_targets_lengths_idxes:
            self.sents.append(['[CLS]'] + words + ['[SEP]'])
            self.target_vars.append([[0] * 5] + targets + [[0] * 5])
            self.word_lengths.append([0] + lengths + [0])
            self.data_idxes.append(span)

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        """Tokenise sentence `idx` into sub-word ids.

        Returns:
            (joined_words, token_ids, is_heads, targets, seqlen, word_lengths, idxes).
            is_heads is 1 on the first sub-token of each word and 0 elsewhere.
            Non-head tokens get all-zero targets. Every sub-token repeats its
            word's length.
        """
        words = self.sents[idx]
        targets = self.target_vars[idx]
        lengths = self.word_lengths[idx]
        idxes = self.data_idxes[idx]

        x, is_heads, y, w_lens = [], [], [], []
        for word, target, length in zip(words, targets, lengths):
            if word in ("[CLS]", "[SEP]"):
                pieces = [word]
            else:
                pieces = tokenizer.tokenize(word)
            n_extra = len(pieces) - 1

            x += tokenizer.convert_tokens_to_ids(pieces)
            # Only the first piece carries the word's target.
            is_heads += [1] + [0] * n_extra
            y += [target] + [[0] * 5] * n_extra
            w_lens += [length] * len(pieces)

        assert len(w_lens) == len(x) == len(y) == len(is_heads), "len(x)={}, len(y)={}, len(is_heads)={}".format(len(x), len(y), len(is_heads))

        return " ".join(words), x, is_heads, y, len(y), w_lens, idxes


In [None]:
# Wrap each split's per-sentence lists (words, targets, word lengths, row spans) as datasets.
train_dataset = EyeTrackDataset(zip(X_train, Y_train, W_train, I_train))
val_dataset   = EyeTrackDataset(zip(X_val, Y_val, W_val, I_val))

In [None]:
def pad(batch):
    '''Collate dataset samples, right-padding per-token fields to the longest sample.

    Args:
        batch: list of samples as returned by EyeTrackDataset.__getitem__.

    Returns:
        (words, x, is_heads, y, seqlens, w_lens, idxes)
        - x and w_lens are LongTensors padded with 0.
        - y is a FloatTensor padded with all-zero 5-vectors.
        - words, is_heads, seqlens and idxes are left as unpadded lists.
    '''
    def column(i):
        return [sample[i] for sample in batch]

    seqlens = column(4)
    longest = max(seqlens)

    def right_pad(i, filler):
        return [sample[i] + [filler] * (longest - len(sample[i])) for sample in batch]

    token_ids = torch.LongTensor(right_pad(1, 0))
    targets = torch.FloatTensor(right_pad(3, [0] * 5))
    lengths = torch.LongTensor(right_pad(5, 0))

    return column(0), token_ids, column(2), targets, seqlens, lengths, column(6)

In [None]:
# Use the GPU when available; Net stores this and its forward pass moves batches here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class Net(nn.Module):
    """XLNet encoder followed by a small regression head.

    For every input token, the head's input is the encoder's last hidden state
    with the token's word length appended as one extra feature (769 inputs to
    regr1). The head predicts the five eye-tracking targets.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained('xlnet-base-cased')

        # NOTE(review): there is no non-linearity between these layers, so the three
        # Linear layers compose into a single affine map. Confirm this is intended.
        self.regr1 = nn.Linear(769, 256)
        self.regr2 = nn.Linear(256, 256)
        self.regr3 = nn.Linear(256, 5)

        self.device = device

        # Not read anywhere in this class.
        self.return_attention = False

    def freeze_bert(self):
        """Disable gradients for all encoder parameters (only the head trains)."""
        for param in self.bert.parameters():
            param.requires_grad = False

    def randomise_weights(self):
        """Re-initialise every encoder sub-module via the encoder's `_init_weights` hook."""
        for module in self.bert.modules():
            self.bert._init_weights(module)

    def forward(self, x, y, w_lens):
        """Run encoder + regression head.

        Args:
            x: token ids, indexed (batch, seq_len).
            y: targets; moved to the model's device and returned unchanged.
            w_lens: per-token word lengths, same (batch, seq_len) layout as x.

        Returns:
            (y, y_hat, hidden_states)
            - y_hat: (batch, seq_len, 5) predictions.
            - hidden_states: the encoder's last-layer output.
        """
        # Follow the module's current device so inputs stay in sync after .to().
        model_device = next(self.parameters()).device
        x = x.to(model_device)
        y = y.to(model_device)
        l = w_lens.to(model_device)

        # Keep the encoder's dropout mode in step with this module.
        self.bert.train(self.training)
        # Eval never builds a graph; training respects any enclosing torch.no_grad().
        with torch.set_grad_enabled(self.training and torch.is_grad_enabled()):
            # NOTE(review): no attention_mask is passed, so padded positions are attended
            # to and a sentence's outputs may depend on its batch's padding.
            enc = self.bert(x)[0]  # (batch_size, sequence_length, hidden_size)

            # Word lengths are integer; cast to the encoder dtype before concatenating.
            lenc = torch.cat((enc, l.unsqueeze(-1).to(enc.dtype)), dim=2)
            lenc = self.regr1(lenc)
            lenc = self.regr2(lenc)
            y_hat = self.regr3(lenc)

        return y, y_hat, enc

In [None]:
def get_hidden(model, iterator):
    """Collect the encoder's last hidden state for every word in `iterator`.

    One vector is kept per word: the state at the word's first sub-token (its
    "head"). The first and last heads of each sample (the sentence sentinels)
    are skipped.

    Args:
        model: callable as model(x, y, w_lens) -> (y, y_hat, hidden_states), with
            hidden_states indexed [sample, token, :]; must provide .eval().
        iterator: iterable of batches in the layout produced by `pad`.

    Returns:
        dict mapping each word's row index (taken from the sample's (start, stop)
        span) to a (1, hidden_size) numpy array.
    """
    model.eval()
    collected = {}

    with torch.no_grad():
        for batch in iterator:
            words, x, is_heads, y, _seqlens, w_lens, spans = batch
            _, _, hidden_states = model(x, y, w_lens)

            for sample_idx, (heads, span, _sentence) in enumerate(zip(is_heads, spans, words)):
                # Head positions, minus the first and last (the sentinels).
                head_positions = [pos for pos, flag in enumerate(heads) if flag == 1][1:-1]

                sample_states = hidden_states[sample_idx, head_positions, :].cpu().numpy()
                per_word = np.split(sample_states, sample_states.shape[0], axis=0)

                start, stop = span[0], span[1]
                assert len(per_word) == stop - start

                collected.update(zip(range(start, stop), per_word))

    return collected

In [None]:
MAX_SEQ_LEN = 100
DEV_BATCH_SIZE = 32

def eval(model, iterator):
    """Evaluate `model` on `iterator` and compute per-target mean absolute error.

    Only word-head tokens are scored: the first sub-token of every real word,
    excluding the sentinel tokens at either end of each sample. Selecting by
    position (not by non-zero target) keeps words whose true value is 0 in the
    evaluation, and it works for any batch size.

    Args:
        model: callable as model(x, y, w_lens) -> (y, y_hat, hidden) with y_hat
            of shape (batch, tokens, 5); must provide .eval().
        iterator: iterable of batches in the layout produced by `pad`.

    Returns:
        Y: (n_words, 5) float64 array of gold targets, in iterator order.
        Y_hat_raw: (n_sentences, max(MAX_SEQ_LEN, longest sentence), 5) array of
            per-word predictions, zero-filled past each sentence's last word.
        MAEs: dict with 'Overall' plus one entry per target
            (nFix, FFD, GPT, TRT, fixProp).

    Raises:
        ValueError: if the iterator yields no scorable words.
    """
    model.eval()
    target_names = ['nFix', 'FFD', 'GPT', 'TRT', 'fixProp']

    gold_rows, pred_rows = [], []  # one (n_words, 5) array per sentence
    with torch.no_grad():
        for batch in iterator:
            words, x, is_heads, y, seqlens, w_lens, idxes = batch

            _, y_hat, _ = model(x, y, w_lens)  # y_hat: (N, T, 5)
            # One device->host copy per batch instead of one per token.
            y_np = y.detach().cpu().numpy()
            y_hat_np = y_hat.detach().cpu().numpy()

            for sample_idx, head_list in enumerate(is_heads):
                # Word-head positions, ignoring the sentinels at index 0 and len-1.
                head_pos = [pos for pos in range(1, len(head_list) - 1) if head_list[pos] == 1]
                gold_rows.append(y_np[sample_idx, head_pos, :].reshape(-1, 5))
                pred_rows.append(y_hat_np[sample_idx, head_pos, :].reshape(-1, 5))

    if sum(len(rows) for rows in gold_rows) == 0:
        raise ValueError("eval: iterator produced no word-head tokens to score")

    Y = np.concatenate(gold_rows).astype(np.float64)
    Y_hat = np.concatenate(pred_rows).astype(np.float64)

    # Widen past MAX_SEQ_LEN rather than failing on unusually long sentences.
    width = max([MAX_SEQ_LEN] + [len(rows) for rows in pred_rows])
    Y_hat_raw = np.zeros((len(pred_rows), width, 5))
    for sent_idx, preds in enumerate(pred_rows):
        Y_hat_raw[sent_idx, :len(preds), :] = preds

    # Equal-length columns: the overall MAE is the mean of the per-target MAEs.
    abs_err = np.abs(Y - Y_hat)
    MAEs = {'Overall': abs_err.mean()}
    for col, name in enumerate(target_names):
        MAEs[name] = abs_err[:, col].mean()

    return(Y, Y_hat_raw, MAEs)

In [None]:
def train(model, iterator, test_iter, optimizer, criterion, epochs):
    """Fine-tune `model` for `epochs` epochs, validating after each one.

    The loss is computed only on word-head tokens: the first sub-token of every
    real word, excluding the sentinel tokens at either end of each sample.
    Selecting by position (not by non-zero target) keeps words whose true value
    is 0 in the objective, and it works for any batch size.

    Args:
        model: Net-like module; model(x, y, w_lens) -> (y_on_device, y_hat, hidden).
        iterator: training batches in the layout produced by `pad`.
        test_iter: validation iterator passed to `eval` after every epoch.
        optimizer: torch optimizer over the model's parameters.
        criterion: loss taking (prediction, target) tensors of equal shape.
        epochs: number of passes over `iterator`.
    """
    losses = []
    for j in range(epochs):
        model.train()
        loss = None

        for batch in iterator:
            words, x, is_heads, y, seqlens, w_lens, idxes = batch
            optimizer.zero_grad()

            y, y_hat, _ = model(x, y, w_lens)

            # Boolean mask selecting all five targets at each word-head position.
            head_mask = torch.zeros_like(y, dtype=torch.bool)
            for sample_idx, head_list in enumerate(is_heads):
                head_pos = [pos for pos in range(1, len(head_list) - 1) if head_list[pos] == 1]
                if head_pos:
                    head_mask[sample_idx, head_pos, :] = True

            loss = criterion(y_hat[head_mask], y[head_mask])
            loss.backward()
            optimizer.step()

        _, _, MAEs = eval(model, test_iter)
        losses.append(MAEs["Overall"])
        # Loss of the last training batch in this epoch (nan if the iterator was empty).
        last_loss = loss.item() if loss is not None else float('nan')
        print(f'epoch {j}, train loss: {last_loss:.5f}, Avg MAE: {MAEs["Overall"]:.5f}, '
              f'Moving Avg {np.mean(losses[-10:]):.5f}')
        print(f'         nFix: {MAEs["nFix"]:.3f} FFD: {MAEs["FFD"]:.3f} GPT: {MAEs["GPT"]:.3f} '
              f'TRT:{MAEs["TRT"]:.3f} fixProp:{MAEs["fixProp"]:.3f}')

In [None]:
# Instantiate the pretrained XLNet encoder plus regression head.
model = Net()

In [None]:
# Optional ablation: uncomment to re-initialise the encoder's weights
# (Net.randomise_weights) instead of fine-tuning the pretrained ones.
#model.randomise_weights()

In [None]:
# nn.Module.to moves the parameters in place. The trailing semicolon only stops
# Jupyter from echoing the long module repr.
model.to(device);

In [None]:
TRAIN_BATCH_SIZE = 32

# Settings shared by both loaders; `pad` right-pads each batch to its longest sample.
loader_kwargs = dict(num_workers=1, collate_fn=pad)

train_iter = data.DataLoader(dataset=train_dataset, batch_size=TRAIN_BATCH_SIZE,
                             shuffle=True, **loader_kwargs)
# Keep shuffle=False: later cells pair eval() outputs with X_val/Y_val by position.
val_iter = data.DataLoader(dataset=val_dataset, batch_size=DEV_BATCH_SIZE,
                           shuffle=False, **loader_kwargs)


In [None]:
N_EPOCHS = 10
LEARNING_RATE = 1e-5

# Smooth L1 (Huber-style) regression loss; AdamW updates every model parameter.
criterion = nn.SmoothL1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Fine-tune on the training split, printing validation MAEs after every epoch.
train(model, train_iter, val_iter, optimizer, criterion, N_EPOCHS)

In [None]:
# Per-word encoder states for the training and validation splits, keyed by row
# index in train_data.
# NOTE(review): hidden_states_all is not used by any later cell shown here (only
# hidden_states_val is pickled), so the pass over train_iter may be unnecessary.
hidden_states_train = get_hidden(model, train_iter)
hidden_states_val = get_hidden(model, val_iter)

hidden_states_all = {**hidden_states_train, **hidden_states_val}

In [None]:
# Persist this fold's validation-split word embeddings. `with` guarantees the
# pickle is flushed and the file closed before anything reads it.
with open(f'CMCL_XLNET_Fold_{FOLD}_Embeddings.pkl', 'wb') as output:
    pickle.dump(hidden_states_val, output)

In [None]:
# Save the fine-tuned weights for this fold.
torch.save(model.state_dict(), f'XLNET_Fold_{str(FOLD)}.mdl')

In [None]:
test_data = pd.read_csv('test_data/test_data.csv')
# Strip the end-of-sentence marker as for training; regex=False makes it a literal match.
test_data['word'] = test_data['word'].str.replace('<EOS>', '', regex=False)

In [None]:
# Group test rows into sentences the same way as the training data: a sentence
# begins at row 0 and wherever word_id drops below the previous row's. The test
# set has no gold targets, so Y_test holds placeholder all-ones rows.
X_test, Y_test, W_test, I_test = [], [], [], []

test_words = test_data['word'].tolist()
test_word_ids = test_data['word_id'].tolist()

test_starts = [pos for pos in range(len(test_words))
               if pos == 0 or test_word_ids[pos] < test_word_ids[pos - 1]]
test_stops = test_starts[1:] + [len(test_words)]

for start, stop in tqdm(list(zip(test_starts, test_stops))):
    sentence = test_words[start:stop]
    X_test.append(sentence)
    Y_test.append([[1] * 5 for _ in sentence])
    W_test.append([len(word.rstrip(punctuation)) for word in sentence])
    I_test.append((start, stop))

In [None]:
# Test-set dataset and loader; shuffle=False keeps batches in sentence order.
test_dataset = EyeTrackDataset(zip(X_test, Y_test, W_test, I_test))

test_iter = data.DataLoader(dataset=test_dataset,
                             batch_size=DEV_BATCH_SIZE,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=pad)

In [None]:
# Per-word encoder states for the test set, keyed by row index in test_data.
hidden_states_test = get_hidden(model, test_iter)

In [None]:
# Persist the test-set word embeddings. `with` guarantees the pickle is
# flushed and the file closed.
with open('CMCL_XLNET_Test_Embeddings.pkl', 'wb') as output:
    pickle.dump(hidden_states_test, output)

In [None]:
# Re-run validation to get per-sentence predictions (Y_hat_raw) for the plots below.
y, Y_hat_raw, _ = eval(model, val_iter)

In [None]:
def plot_sent(sent, y, y_hat):
  """Plot gold (solid) and predicted (dashed) eye-tracking features for one sentence.

  Args:
    sent: the sentence's words, used as x-axis tick labels.
    y: 2-D array of gold values; columns 0-4 are nFix, FFD, GPT, TRT, fixProp.
    y_hat: 2-D array of predictions with the same column layout.
  """
  target_colours = [('nFix', 'darkblue'), ('FFD', 'darkred'), ('GPT', 'darkorange'),
                    ('TRT', 'darkgreen'), ('fixProp', 'deeppink')]
  positions = list(range(len(sent)))

  fig, ax = plt.subplots()
  ax.set_xticks(positions)
  ax.set_xticklabels(sent, rotation=-90)

  for col, (name, colour) in enumerate(target_colours):
    ax.plot(positions, y[:, col], color=colour, label=name)
    ax.plot(positions, y_hat[:, col], '--', color=colour)

  # One legend column per plotted series (gold + prediction for each target).
  ax.legend(loc='upper right', ncol=2 * len(target_colours))
  plt.tight_layout()
  ax.set_ylim((0, 110))
  ax.set_title('MAE '+str(float("{0:.5f}".format(mean_absolute_error(y, y_hat)))))
  plt.show()

In [None]:
from collections import Counter

# Per-sentence validation MAE, keyed by position in X_val. Y_hat_raw rows are
# paired with X_val/Y_val by position, which relies on val_iter's shuffle=False.
dev_results = {}
dev_store = {}

for idx, (x, y, y_hat) in enumerate(zip(X_val, Y_val, Y_hat_raw)):
  gold = np.array(y)
  # Y_hat_raw rows are zero-padded past the sentence end; keep one row per word.
  pred = y_hat[:gold.shape[0], :]
  dev_results[idx] = mean_absolute_error(gold, pred)
  dev_store[idx] = [x, gold, pred]

In [None]:
# Plot validation sentences from worst to best, i.e. by descending per-sentence MAE.
ranked = sorted(dev_results.items(), key=lambda item: item[1], reverse=True)
for val_idx, score in ranked:
    x, y, y_hat = dev_store[val_idx]
    plot_sent(x, y, y_hat)
    print(' ')