A very quick and dirty implementation of a neural parser using BERT base.

In [1]:
!pip install -qq transformers

In [1]:
import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizerFast, BertModel
import torch
import numpy as np
import pandas as pd
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap
from tqdm import tqdm

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_scheduler

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
from transformers import BertTokenizerFast, BertModel
from src.dataset import ConllDataset

class TorchConllDataset(Dataset):
    def __init__(self, dataset_file, tokenizer, max_len) -> None:
        original_file = 'data/english/train/wsj_train.first-1k.conll06'
        self.dataset = ConllDataset(dataset_file)
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
      dataset_sentence = self.dataset[index]

      sentence = dataset_sentence.sentence['form'].to_list()
      arcs = dataset_sentence.get_arcs()

      encoding = self.tokenizer.encode_plus(
        sentence,
        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
        max_length = 512,           # Pad & truncate all sentences.
        is_split_into_words=True,
        padding='max_length',
        pad_to_max_length = True,
        return_attention_mask = True,   # Construct attn. masks.
        return_tensors = 'pt',     # Return pytorch tensors.
      )

      word_ids = torch.asarray([id if id != None else -1 for id in encoding.word_ids()])
      # Padd with -1 to have equal size
      arcs = arcs + [[-1, -1]] * (len(word_ids) - len(arcs))
      arcs = torch.asarray(arcs)

      return {
        'word_ids': word_ids,
        'input_ids': encoding['input_ids'].flatten(),
        'attention_masks': encoding['attention_mask'].flatten(),
        'arcs': arcs,
      }
    
    @classmethod
    def get_dataloader(cls, dataset_file, tokenizer, max_len, batch_size, random=False):
      dataset = TorchConllDataset(dataset_file, tokenizer, max_len)

      return DataLoader(dataset, batch_size=batch_size, shuffle=random)

In [5]:
import src.decoder as decoder

class NeuralDependencyParser(nn.Module):
    def __init__(self, model_name, hidden_size = 256):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.bert.trainable = False
        self.linear_first = nn.Linear(self.bert.config.hidden_size * 2, hidden_size)
        self.linear_second = nn.Linear(hidden_size, 1)
        self.function = nn.ReLU()
        self.softmax = nn.Softmax(dim=0) # Softmax column-wise, since every dependent has only one head
        self.loss = nn.CrossEntropyLoss()

        self.decoder = decoder.CLE_n()

    def tensors_from_words(self, output, word_id):
        new_output = []

        prev_word_id = -1
        for logit, word_id in zip(output, word_id):
            # Mean join tensors that belong to the same word
            if word_id == prev_word_id and len(new_output):
                new_output[-1] = torch.stack([new_output[-1], logit]).mean(dim=0)
            else:
                new_output.append(logit)
            
            prev_word_id = word_id

        # Gets tensors for input tokens (removes special [SEP])
        output = torch.stack(new_output)[:-1]
        return output

    def forward(self, input_ids, attention_masks, word_ids):
        with torch.no_grad():
            logits = self.bert(input_ids=input_ids, attention_mask=attention_masks)[0]

        batch_scores = []

        # Go through the batch and create scores
        for logit, attention_mask, word_id in zip(logits, attention_masks, word_ids):
            # Only get non-padded values
            output = logit[torch.where(attention_mask == 1)]

            # Join subword tensors into one to represent one word
            word_output = self.tensors_from_words(output, word_id)

            arcs = []
            for h, head in enumerate(word_output):
                head_arcs = []
                for d, dep in enumerate(word_output):
                    if h != d and d != 0: # Don't want the same concatinated and the first (meaning root is dep)
                        head_dep_arc = torch.cat((head, dep))
                    else:
                        head_dep_arc = torch.zeros((head.shape[0] * 2)).to(device)

                    head_arcs.append(head_dep_arc)

                arcs.append(torch.stack(head_arcs))


            arcs = torch.stack(arcs).to(device)
            h = self.function(self.linear_first(arcs))
            scores = self.softmax(self.linear_second(h)).flatten(start_dim=1)
            batch_scores.append(scores)

            del output
            del word_output
            del arcs


        return batch_scores

    def loss_fn(self, batch_scores, batch_arcs):
        losses = []
        for scores, arcs in zip(batch_scores, batch_arcs):
            true_scores = torch.zeros(scores.shape).to(device)
            for head, dep in arcs:
                # Since we padded the arcs, when we get [-1, -1], we stop
                if head == -1 and dep == -1:
                    break
                true_scores[head, dep] = 1

            loss = self.loss(scores, true_scores)
            losses.append(loss)

        return torch.stack(losses).mean()


    def decode(self, batch_scores):
        predicted_tree = []
        for scores in batch_scores:
            parser_scores = torch.clone(scores).detach().cpu().numpy()
            parser_scores[:, 0] = -np.Inf
            parser_scores[np.diag_indices_from(parser_scores)] = -np.Inf

            tree = self.decoder.decode(parser_scores)
            predicted_tree.append(tree)

        return predicted_tree

In [9]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
training_file = 'data/english/train/wsj_train.conll06'
dev_file = 'data/english/dev/wsj_dev.conll06.gold'
test_file = 'data/english/dev/wsj_dev.conll06.gold'

train_dataloader = TorchConllDataset.get_dataloader(training_file, tokenizer, 512, 64, random=True)
dev_dataloader = TorchConllDataset.get_dataloader(dev_file, tokenizer, 512, 32)
test_dataloader = TorchConllDataset.get_dataloader(dev_file, tokenizer, 512, 8)

model = NeuralDependencyParser('bert-base-cased', 512)
model = model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
train_dataloader = TorchConllDataset.get_dataloader(training_file, tokenizer, 512, 32, random=True)
dev_dataloader = TorchConllDataset.get_dataloader(dev_file, tokenizer, 512, 32)
test_dataloader = TorchConllDataset.get_dataloader(dev_file, tokenizer, 512, 8)

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 2
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [24]:
import src.evaluation as evaluation
from tqdm import tqdm

postfix_data = {'score': 0, 'loss': 0}

model.train()
with tqdm(range(num_training_steps), total=num_training_steps, postfix=postfix_data) as progress_bar:
    for epoch in range(num_epochs):
        pred_trees = []
        true_trees = []
        for step, batch in enumerate(train_dataloader):

            scores= model(
                input_ids = batch['input_ids'].to(device),
                attention_masks = batch['attention_masks'].to(device),
                word_ids = batch['word_ids'].to(device),
            )

            loss = model.loss_fn(scores, batch['arcs'].to(device))
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

            postfix_data['loss'] = loss.item()
            progress_bar.set_postfix(postfix_data)

        model.eval()
        with tqdm(dev_dataloader, total=len(dev_dataloader)) as dev_progress_bar:
            true_trees = []
            predited_trees = []
            for batch in dev_progress_bar:
                scores= model(
                    input_ids = batch['input_ids'].to(device),
                    attention_masks = batch['attention_masks'].to(device),
                    word_ids = batch['word_ids'].to(device),
                )

                true_arcs = []
                for arcs in batch['arcs']:
                    index = -1
                    for i, (head, dep) in enumerate(arcs):
                        # Since we padded the arcs, when we get [-1, -1], we stop
                        if head == -1 and dep == -1:
                            index = i
                            break

                    true_arcs.append(arcs[:i].tolist())
                predicted_arcs = model.decode(scores)

                true_trees.extend(true_arcs)
                pred_trees.extend(predicted_arcs)
            
            score = evaluation.uas(true_trees, pred_trees)
            postfix_data['score'] = score
            progress_bar.set_postfix(postfix_data)

 38%|███▊      | 721/1880 [09:54<16:28,  1.17it/s, score=0, loss=2.38]

In [23]:
model.eval()
with tqdm(dev_dataloader, total=len(dev_dataloader)) as dev_progress_bar:
    true_trees = []
    predited_trees = []
    for batch in dev_progress_bar:
        scores= model(
            input_ids = batch['input_ids'].to(device),
            attention_masks = batch['attention_masks'].to(device),
            word_ids = batch['word_ids'].to(device),
        )

        true_arcs = []
        for arcs in batch['arcs']:
            index = -1
            for i, (head, dep) in enumerate(arcs):
                # Since we padded the arcs, when we get [-1, -1], we stop
                if head == -1 and dep == -1:
                    index = i
                    break

            true_arcs.append(arcs[:i].tolist())
        predicted_arcs = model.decode(scores)

        true_trees.extend(true_arcs)
        pred_trees.extend(predicted_arcs)
    
    score = evaluation.uas(true_trees, pred_trees)
score

100%|██████████| 34/34 [00:29<00:00,  1.16it/s]


87.78717804098515

In [None]:
model.eval()
dataset = ConllDataset(test_file)
with tqdm(test_dataloader, total=len(test_dataloader)) as progress_bar:
    true_trees = []
    predited_trees = []
    for i, batch in enumerate(progress_bar):
        scores= model(
            input_ids = batch['input_ids'].to(device),
            attention_masks = batch['attention_masks'].to(device),
            word_ids = batch['word_ids'].to(device),
        )

        predicted_arcs = model.decode(scores)
        predicted_tree = sorted(predicted_arcs[0], key=lambda x: x[1])
        dataset.set_arcs(sentence_index=i, arcs=predicted_tree)
        
    
dataset.write(filepath='./results/evaluation-test-en-neural.conll06')