In [1]:
from collections import defaultdict
from urllib import request
import json
import pandas as pd
from math import floor, ceil, log10
import os
from glob import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from transformers import T5Tokenizer
from torch.utils.data import TensorDataset
import random

In [2]:
#Seed setting

seed = 28

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
def parse_conllu_using_pandas(block):
    records = []
    for line in block.splitlines():
        if not line.startswith('#'):
            records.append(line.strip().split('\t'))
    return pd.DataFrame.from_records(
        records,
        columns=['ID', 'FORM', 'TAG', 'Misc1', 'Misc2'])

In [4]:
def tokens_to_labels(df):
    return (
        df.FORM.tolist(),
        df.TAG.tolist()
    )

In [5]:
PREFIX = "https://raw.githubusercontent.com/UniversalNER/"
DATA_URLS = {
    "en_ewt": {
        "train": "UNER_English-EWT/master/en_ewt-ud-train.iob2",
        "dev": "UNER_English-EWT/master/en_ewt-ud-dev.iob2",
        "test": "UNER_English-EWT/master/en_ewt-ud-test.iob2"
    },
    "en_pud": {
        "test": "UNER_English-PUD/master/en_pud-ud-test.iob2"
    }
}

In [6]:
# en_ewt is the main train-dev-test split
# en_pud is the OOD test set
data_dict = defaultdict(dict)
for corpus, split_dict in DATA_URLS.items():
    for split, url_suffix in split_dict.items():
        url = PREFIX + url_suffix
        with request.urlopen(url) as response:
            txt = response.read().decode('utf-8')
            data_frames = map(parse_conllu_using_pandas,
                              txt.strip().split('\n\n'))
            token_label_alignments = list(map(tokens_to_labels,
                                              data_frames))
            data_dict[corpus][split] = token_label_alignments


In [7]:
# Saving the data so that you don't have to redownload it each time.
with open('ner_data_dict.json', 'w', encoding='utf-8') as out:
    json.dump(data_dict, out, indent=2, ensure_ascii=False)

In [None]:
# Each subset of each corpus is a list of tuples where each tuple
# is a list of tokens with a corresponding list of labels.

# Train on data_dict['en_ewt']['train']; validate on data_dict['en_ewt']['dev']
# and test on data_dict['en_ewt']['test'] and data_dict['en_pud']['test']
data_dict['en_ewt']['train'][0], data_dict['en_pud']['test'][1]

In [58]:
#creating functions

#Function to change 7 labels to simple 3 label system.
def simplify_tags_to_3labels(tag_seq):
  simplified_tags = []

  for tag in tag_seq:

    if tag == "O":
      simplified_tags.append("O") #Non-entities

    elif tag == "B-LOC" or tag == "B-PER" or tag == "B-ORG": #All B tags
      simplified_tags.append("B")

    elif tag == "I-LOC" or tag == "I-PER" or tag == "I-ORG": #All I tags
      simplified_tags.append("I")

    else:
      simplified_tags.append("O") #Incase there's unexpected tags

  return simplified_tags


#function to convert the token/tag pairs in dictionaries into spaced out strings
#also can do label simplification is simplified = true
def prepare_data(data_tuples, simplified = False):
  word_inputs = [] #list of token strings
  label_outputs = [] #list of label strings

  for tokens, tags in data_tuples:
    if simplified:
      tags = simplify_tags_to_3labels(tags) #if simple tagset

    word_inputs.append(" ".join(tokens)) #join tokens to string with a space
    label_outputs.append(" ".join(tags))


  return word_inputs, label_outputs



In [33]:
#DATA PREPARATION

#COnvert the train, dev and test splits into token and label strings
prep_train_sentences, prep_train_tags = prepare_data(data_dict["en_ewt"]["train"], simplified = False)

prep_dev_sentences, prep_dev_tags = prepare_data(data_dict["en_ewt"]["dev"], simplified = False)

prep_test_sentences1, prep_test_labels1 = (prepare_data(data_dict["en_ewt"]["test"], simplified=False))

prep_test_sentences2, prep_test_labels2 = (prepare_data(data_dict["en_pud"]["test"], simplified=False))

#Merges the test sets
prep_test_sentences = prep_test_sentences1 + prep_test_sentences2
prep_test_labels = prep_test_labels2 + prep_test_labels1




#TOKENISING THE DATA

#Load T5 tokeniser
tokeniser = T5Tokenizer.from_pretrained("t5-small")

#convert strings to token IDs, apply padding/truncation up 128 tokens
tokenised_train_sentences = tokeniser(prep_train_sentences, padding="max_length", truncation = True, max_length=128, return_tensors="pt")

tokenised_train_labels = tokeniser(prep_train_tags, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

tokenised_dev_sentences = tokeniser(prep_dev_sentences, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

tokenised_dev_labels = tokeniser(prep_dev_tags, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

tokenised_test_sentences = tokeniser(prep_test_sentences, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

tokenised_test_labels = tokeniser(prep_test_labels, padding="max_length", truncation=True, max_length=128, return_tensors="pt")





#PREFIXING INPUT


#Tell t5 that the task is NER to guide it
prep_train_sentences = ["ner: " + s for s in prep_train_sentences]
prep_dev_sentences   = ["ner: " + s for s in prep_dev_sentences]
prep_test_sentences  = ["ner: " + s for s in prep_test_sentences]


In [34]:
#Change padding to -100 in the labels


#clone the original input so can modify it but keep the original
train_labels = tokenised_train_labels.input_ids.clone()

train_size = train_labels.shape[0]
train_sequen_length = train_labels.shape[1]

for i in range(train_size):
  for j in range(train_sequen_length):

    if train_labels[i, j].item() == tokeniser.pad_token_id: #replace the padding token

      train_labels[i, j] =-100


#repeat the same procedure for the dev set
dev_labels = tokenised_dev_labels.input_ids.clone()

dev_size = dev_labels.shape[0]
dev_sequen_length = dev_labels.shape[1]

for i in range(dev_size):
  for j in range(dev_sequen_length):

    if dev_labels[i, j].item() == tokeniser.pad_token_id:

      dev_labels[i,j] = -100



In [35]:
#make tensor dataset because the model expects batches of (input_ids, attention_mask, labels) in training & validation

training_dataset = TensorDataset(tokenised_train_sentences.input_ids, tokenised_train_sentences.attention_mask, train_labels)

dev_dataset = TensorDataset(tokenised_dev_sentences.input_ids, tokenised_dev_sentences.attention_mask, dev_labels)

In [36]:
#Defining the device as cuda

device = "cuda"

In [None]:
##Get the actual T5 model and move it to cuda

from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.to(device)

In [38]:
#shuffling data before batching to the model overfitting to one sequence of data

input_ids = tokenised_train_sentences.input_ids
attention_mask = tokenised_train_sentences.attention_mask

N = input_ids.shape[0]

#generating random permutation for shuffling
randperm = torch.randperm(N)

#apply the random permutation to data
shuffled_input_ids = input_ids[randperm]
shuffled_attention_mask = attention_mask[randperm]
shuffled_labels = train_labels[randperm]

#move data to cuda
shuffled_input_ids = shuffled_input_ids.to(device)
shuffled_attention_mask = shuffled_attention_mask.to(device)
shuffled_labels = shuffled_labels.to(device)


In [39]:
#TRAINING LOOP

#define optimiser
optimiser = torch.optim.AdamW(model.parameters(), lr=5e-5)

batch_size = 8
epochs = 3
num_steps = ceil(N/batch_size)

#loop over epochs
for epoch in range(1, epochs+1):
  model.train() #set model to training

  total_loss = 0


  for step in range(num_steps):

    #loop over batches
    low = step * batch_size
    high = low + batch_size

    batch_input_ids = shuffled_input_ids[low:high].to(device)
    batch_attention_mask = shuffled_attention_mask[low:high].to(device)
    batch_labels = shuffled_labels[low:high].to(device)

    #Forward pass - predicting outputs and loss
    outputs = model(input_ids=batch_input_ids, attention_mask = batch_attention_mask, labels = batch_labels)

    loss = outputs.loss # cross entropy loss

    loss.backward() # backwards pass

    optimiser.step() #update model parameters

    optimiser.zero_grad() #clear gradient

    total_loss += loss.item() #track loss


  avg_loss = total_loss / num_steps # get average loss
  print(avg_loss)

[Epoch 1] avg training loss = 0.4301
[Epoch 2] avg training loss = 0.2470
[Epoch 3] avg training loss = 0.2011


In [40]:
#Validation/Dev Set


#extracting input ids and attention mask
dev_input_ids = tokenised_dev_sentences.input_ids
dev_attention_mask = tokenised_dev_sentences.attention_mask

#moving data to cuda
dev_input_ids = dev_input_ids.to(device)
dev_attention_mask = dev_attention_mask.to(device)
dev_labels = dev_labels.to(device)


N_dev = dev_input_ids.shape[0] #no. of dev examples
num_steps_dev = ceil(N_dev / batch_size) #number of batches

model.eval() #set model to evaluation
total_dev_loss = 0

#no gradient tracking
with torch.no_grad():
  for step in range(num_steps_dev):

    low = step * batch_size
    high = low + batch_size

    batch_input_ids = dev_input_ids[low:high]
    batch_attention_mask = dev_attention_mask[low:high]
    batch_labels = dev_labels[low:high]

    #only forward pass
    outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)

    total_dev_loss += outputs.loss.item()

avg_dev_loss = total_dev_loss / num_steps_dev #compute average loss
print(avg_dev_loss)



0.15620398936993574


In [41]:
#GENERATING PREDICTIONS on DEV SET

all_preds = [] #emppty list for predicted tag sequences
model.eval() #model evaluation mode

#disable gradient computing
with torch.no_grad():
    for step in range(num_steps_dev):
        low  = step * batch_size
        high = min(low + batch_size, N_dev)

        #batch inputs
        batch_input_ids = dev_input_ids[low:high]
        batch_attention_mask = dev_attention_mask[low:high]

        #generate predicted sequences
        generated_ids = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_length=dev_sequen_length
        )

        #decode the generated token IDS into strings and split into tag sequences
        pred_strs= tokeniser.batch_decode(generated_ids, skip_special_tokens=True)
        batch_preds = [s.split() for s in pred_strs]

        all_preds.extend(batch_preds) #append batches

print(len(all_preds), len(prep_dev_tags))  #test that they should match


2001 2001


In [42]:
#Span Matching Functions


#Extract enetity spans from a sequence of BIO tags
#returns the start, end and label for each span
def get_spans(tags, simplified = False):
  spans = []
  start = None
  label = None

  for i, tag in enumerate(tags):
    if tag == 'O':
      #close the span if there is one
      if start is not None:
        spans.append((start, i, label))
        start = None
        label = None

    elif tag.startswith('B'):
      #if prev span is still open then close it
      if start is not None:
        spans.append((start, i, label))

      start = i
      label = '' if simplified else tag[2:] #remove label prefix for simple tagset

    elif tag.startswith('I'):
      continue #contiue getting the span

  if start is not None: #close spans that reach the end of sequence
      spans.append((start, len(tags), label))

  return spans



#Function to change BIO tags by removing the label types (LOC, PER, ORG)
def simplify_bio_sequences(seqs):
    return [[('O' if t=='O' else t[0]) for t in seq] for seq in seqs]




#calculate the span matching accuracy
#span is correct if its start and end and optionally the label match
def span_match_score(gold_seqs, pred_seqs):
  total = 0 #gold spans
  correct = 0 #cirrect pred spans

  for gseq, pseq in zip(gold_seqs, pred_seqs):
        goldspan = set(get_spans(gseq)) #get gold spans
        predspan = set(get_spans(pseq)) #get pred spans
        total   += len(goldspan)
        correct += len(goldspan & predspan) #count the overlaps

  return correct / total if total else 0.0

In [43]:
#Evaluate span matching accuracy on the dev set

#compute labelled span matching score
#need to convert the gold tag strings to list but the preds are already a list
labelled_acc = span_match_score( [gold.split() for gold in prep_dev_tags],
                                   all_preds )

#compute unlabelled span matching scores
#need to simplify the labels in both the gold and preds
unlabelled_acc = span_match_score( simplify_bio_sequences([gold.split() for gold in prep_dev_tags]), simplify_bio_sequences(all_preds) )

print("Labelled span", labelled_acc)
print("Unlabelled span", unlabelled_acc)


Labelled span acc: 0.2484472049689441
Unlabelled span acc: 0.2660455486542443


In [44]:
#aligning the predicted and gold tag sequences to make sure the evaluation is able to be done


def align_preds_to_gold(gold_strs, pred_seqs, pad_tag="O"):
    aligned_preds = []
    aligned_gold  = []
    for gold_s, pred in zip(gold_strs, pred_seqs):
        gold_list = gold_s.split()
        L = len(gold_list)

        # truncate or pad the prediction
        if len(pred) >= L:
            pred_list = pred[:L] #shorten if too long
        else:
            pred_list = pred + [pad_tag] * (L - len(pred)) #pad if too short

        aligned_gold .append(gold_list)
        aligned_preds.append(pred_list)

    return aligned_gold, aligned_preds


In [45]:

#Evaluate classification - full 7 label set

#align the gold and pred sequences
aligned_gold_seqs, aligned_pred_seqs = align_preds_to_gold(prep_dev_tags, all_preds)

# now flatten for the classification report
flat_gold = [tag for seq in aligned_gold_seqs for tag in seq]
flat_pred = [tag for seq in aligned_pred_seqs for tag in seq]

print(len(flat_gold), len(flat_pred))  #check that these match

#define tag order for the report
label_names = ['O','B-LOC','I-LOC','B-ORG','I-ORG','B-PER','I-PER']

#print eval metrics
print(classification_report(
    flat_gold,
    flat_pred,
    labels=label_names,
    zero_division=0
))


25149 25149
              precision    recall  f1-score   support

           O       0.96      0.98      0.97     23653
       B-LOC       0.39      0.20      0.26       399
       I-LOC       0.18      0.08      0.11       148
       B-ORG       0.28      0.06      0.10       224
       I-ORG       0.20      0.13      0.16       186
       B-PER       0.60      0.55      0.57       343
       I-PER       0.46      0.50      0.48       196

    accuracy                           0.94     25149
   macro avg       0.44      0.36      0.38     25149
weighted avg       0.93      0.94      0.93     25149



In [46]:
#TEST DATA

#clone the input IDs so we can modify them and keep originals
test_labels = tokenised_test_labels.input_ids.clone()

test_size = test_labels.shape[0]
test_sequen_length = test_labels.shape[1]

#replace the paddin tokens with -100 so theyre ignored in loss
for i in range(test_size):
  for j in range(test_sequen_length):

    if test_labels[i, j].item() == tokeniser.pad_token_id:

      test_labels[i, j] =-100

#move data to cuda
test_input_ids = tokenised_test_sentences.input_ids.to(device)
test_attention_mask = tokenised_test_sentences.attention_mask.to(device)
test_labels = tokenised_test_labels.to(device)

N_test = test_input_ids.size(0)
num_steps_test = ceil(N_test / batch_size)

all_test_preds = [] #predicted tags
model.eval() #puts model in evaluation mode

#generating predictions for the test data

with torch.no_grad():
    for step in range(num_steps_test):
        low  = step * batch_size
        high = min(low + batch_size, N_test)

        #get a batch
        batch_input_ids= test_input_ids[low:high]
        batch_attention_mask = test_attention_mask[low:high]

        #generating the actual output sequences
        generated_ids = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_length=test_sequen_length
        )

        #decode the token ids to strings and then split into tag lists
        pred_strs    = tokeniser.batch_decode(generated_ids, skip_special_tokens=True)
        batch_preds  = [s.split() for s in pred_strs]


        all_test_preds.extend(batch_preds)

print(len(all_test_preds), len(prep_test_labels))  # they should match


3077 3077


In [47]:
#FINAL EVALUATION FROM TEST DATA

#align the gold and pred sets
aligned_gold_test_seqs, aligned_pred_test_seqs = align_preds_to_gold(prep_test_labels, all_test_preds)

# now flatten into single lists for classification report
flat_gold_test = [tag for seq in aligned_gold_test_seqs for tag in seq]
flat_pred_test = [tag for seq in aligned_pred_test_seqs for tag in seq]

#define the label set
label_names = ['O','B-LOC','I-LOC','B-ORG','I-ORG','B-PER','I-PER']

#print the eval results
print(classification_report(
    flat_gold_test,
    flat_pred_test,
    labels=label_names,
    zero_division=0
))


              precision    recall  f1-score   support

           O       0.93      0.97      0.95     43029
       B-LOC       0.01      0.00      0.01       742
       I-LOC       0.00      0.00      0.00       247
       B-ORG       0.05      0.01      0.01       557
       I-ORG       0.02      0.01      0.01       431
       B-PER       0.08      0.05      0.06       864
       I-PER       0.03      0.02      0.03       403

   micro avg       0.91      0.91      0.91     46273
   macro avg       0.16      0.15      0.15     46273
weighted avg       0.87      0.91      0.89     46273



In [48]:
#SPAN ACCURACY FOR TEST SET

labelled_acc=span_match_score(aligned_gold_test_seqs, aligned_pred_test_seqs)

#simplify the pred and gold sequences first
unlabelled_acc = span_match_score(simplify_bio_sequences(aligned_gold_test_seqs), simplify_bio_sequences(aligned_pred_test_seqs))

print("Labelled span acc:",   labelled_acc)
print("Unlabelled span acc:", unlabelled_acc)


Labelled span acc: 0.0106333795654184
Unlabelled span acc: 0.015256588072122053


In [56]:
#EVALUATE TEST SET - SIMPLFIED LABEL SET

#align the pred and golds
aligned_gold, aligned_pred = align_preds_to_gold(prep_test_labels, all_test_preds)

#convert the bio tags to the simplified tags
gold_simple = [simplify_tags_to_3labels(seq) for seq in aligned_gold]
pred_simple = [simplify_tags_to_3labels(seq) for seq in aligned_pred]

#flatten the lists for classification report
flat_gold_simple = [t for seq in gold_simple for t in seq]
flat_pred_simple = [t for seq in pred_simple for t in seq]

In [59]:
#EVALUATE TEST SET - SIMPLFIED LABEL SET

#align the pred and golds
aligned_gold, aligned_pred = align_preds_to_gold(prep_test_labels, all_test_preds)

#convert the bio tags to the simplified tags
gold_simple = [simplify_tags_to_3labels(seq) for seq in aligned_gold]
pred_simple = [simplify_tags_to_3labels(seq) for seq in aligned_pred]

#flatten the lists for classification report
flat_gold_simple = [t for seq in gold_simple for t in seq]
flat_pred_simple = [t for seq in pred_simple for t in seq]


#print the eval metrics
print(classification_report(
    flat_gold_simple,
    flat_pred_simple,
    labels=['O', 'B', 'I'],
    zero_division=0
))


#compute the span level accuracy.
labelled_simple_score = span_match_score(gold_simple, pred_simple)
unlabelled_simple_score = labelled_simple_score #its the same bc theres no entities

print(f"Simplified labelled span match", labelled_simple_score:.4f)
print(f"Simplified unlabelled span match", unlabelled_simple_score:.4f)



              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           B       0.09      0.03      0.05      2163
           I       0.03      0.02      0.02      1081

   micro avg       0.07      0.03      0.04      3244
   macro avg       0.04      0.02      0.02      3244
weighted avg       0.07      0.03      0.04      3244

Simplified F1 scores (B/I/O):
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           B       0.09      0.03      0.05      2163
           I       0.03      0.02      0.02      1081

   micro avg       0.07      0.03      0.04      3244
   macro avg       0.04      0.02      0.02      3244
weighted avg       0.07      0.03      0.04      3244

Simplified labelled span match:   0.0153
Simplified unlabelled span match: 0.0153
