In [None]:
import pandas as pd
import pickle
from collections import Counter
import os
from typing import NamedTuple, Sequence, Any, List
import string
from utils import *
import pandas as pd
import spacy
from transformers import LongformerForTokenClassification, AutoTokenizer
import numpy as np
import math
import torch
import torch.nn as nn
import random
from sklearn.metrics import f1_score, classification_report

## Configuration, data loading & splitting, training

In [None]:
class Config:
    """
    Training configuration shared by the Trainer and the evaluation helpers.

    Every attribute is a class attribute evaluated when the class body runs, so merely
    defining this class loads the pretrained Longformer checkpoint and builds its
    optimizer, LR scheduler and tokenizer.
    """
    model_name = 'allenai/longformer-base-4096'
    # 7 argument types (LEAD, POSITION, CLAIM, COUNTERCLAIM, REBUTTAL, EVIDENCE, CS)
    # x {B-, I-} + 'O'
    n_classes = 15
    n_epochs = 3
    lr = 1e-5
    # Tie the classification head size to n_classes instead of repeating the number.
    model = LongformerForTokenClassification.from_pretrained(model_name, num_labels=n_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Multiply the LR by 0.9 when the monitored value (mode="min") has not improved for
    # 10 scheduler.step() calls, wait 10 calls after each reduction, never go below 5e-6.
    # What is monitored, and how often step() is called, is decided by the Trainer.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              factor=0.9,
                                                              mode="min",
                                                              patience=10,
                                                              cooldown=10,
                                                              min_lr=5e-6,
                                                              verbose=True)
    # NOTE: Longformer inputs can be up to 4096 sub-tokens long, so large batches may
    # exhaust GPU memory (an earlier run hit CUDA OOM even at batch size 2). Lower these
    # values if that happens.
    train_batch_size = 32
    dev_batch_size = 32
    test_batch_size = 32
    train_split = 0.8
    # Use the tokenizer that matches the checkpoint. RoBERTa's tokenizer would cap
    # inputs at 512 sub-tokens.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Pre-tokenized train/dev/test splits (the path presumably points at a mounted Google Drive).
DATA_DIR = 'drive/MyDrive/NER_project'


def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_data = _load_pickle(os.path.join(DATA_DIR, 'longformer_train_set.pickle'))
dev_data = _load_pickle(os.path.join(DATA_DIR, 'longformer_dev_set.pickle'))
test_data = _load_pickle(os.path.join(DATA_DIR, 'longformer_test_set.pickle'))

# Pinned host memory only speeds up host-to-GPU copies, so skip it on CPU-only runs.
use_pin_memory = Config.device == 'cuda'
train_set = DataLoader(train_data, batch_size=Config.train_batch_size, shuffle=True,
                       pin_memory=use_pin_memory)
# Aggregate evaluation metrics do not depend on batch order, so keep dev/test order
# fixed for reproducible runs.
dev_set = DataLoader(dev_data, batch_size=Config.dev_batch_size, shuffle=False,
                     pin_memory=use_pin_memory)
test_set = DataLoader(test_data, batch_size=Config.test_batch_size, shuffle=False,
                      pin_memory=use_pin_memory)

In [None]:
# Build the Trainer (not defined in this notebook) from the configuration and the three
# data loaders, then fine-tune.
# NOTE(review): `save_model_patj` looks like a typo for `save_model_path`. It is kept as-is
# because it must match the keyword the Trainer actually declares -- confirm there.
trainer = Trainer(
    Config,
    train_set,
    dev_set,
    test_set,
    save_model_patj='longformer_model.pt',
)
trainer.train()

## Evaluate our model on the test set (token-level)

In [None]:
# Token-level evaluation on the test loader. What is computed and reported is defined by
# Trainer.run_on_dev_or_test, which lives outside this notebook.
trainer.run_on_dev_or_test(dataset='test')

## Get mention-level micro F1 on the test set


In this project, mentions are defined at the sentence level.
For instance, consider an article that consists of only three sentences:

"This is the beginning of the article, this is the evidence of the article, and this is the end of the article."

We define a sentence as a span of text ending in a token that contains one of the following punctuation marks: a period, a question mark, an exclamation point, a colon, a comma, or a semicolon. Fragments of three words or fewer are not counted as sentences on their own; they are merged into the following sentence. Therefore, the example above contains three sentences, and their labels are LEAD, EVIDENCE, and CONCLUDING STATEMENT, respectively.


In [None]:
# Held-out test documents in their original, untokenized form. Mention-level F1 is
# computed from these rather than from the pre-tokenized test loader.
# NOTE: pickle.load can run arbitrary code -- only open files you produced yourself.
with open('original_test_set.pickle', 'rb') as pickled_docs:
    final_test_data = pickle.load(pickled_docs)

## Now we only have a label for every subtoken. How do we get the label for each sentence?
1. Convert the per-subtoken labels into one label per token (a whitespace-delimited word of the article). Each token takes the label of the subtoken covering its middle character, with the B-/I- prefix dropped.
2. If a sentence ends within the first or last 20% of the article's tokens, assign it the most frequent label among its tokens.
3. If a sentence ends within the middle 60% of the article, check the second most frequent label among its tokens. If that label accounts for at least 20% of the sentence's tokens, use it; otherwise, use the most frequent label. Because "EVIDENCE" is so predominant in the training data, this strategy helps prevent the model from predicting too many "EVIDENCE" sentences.

Here is an example. Assume that the following sentence lies in the middle 60% of an article:
This [B-EVIDENCE] is [I-EVIDENCE] a [I-CLAIM] claim [I-CLAIM] of [I-EVIDENCE] an [I-EVIDENCE] article [I-EVIDENCE]

It will be classified as CLAIM. CLAIM is the second most frequent label among the sentence's tokens, and it covers 2 of the 7 tokens (about 29%), which is at least 20%.

In [None]:
def _predict_word_labels(article, config, model, max_length):
    """Predict one entity type per space-separated word of `article`.

    Returns ``(words, labels)``: parallel lists where each label is the predicted BIO tag
    with its ``B-``/``I-`` prefix removed (``'O'`` is kept as is).
    """
    # Label vocabulary of the token-classification head.
    label_to_num = {'O': 0,
                    'B-LEAD': 1,
                    'I-LEAD': 2,
                    'B-POSITION': 3,
                    'I-POSITION': 4,
                    'B-CLAIM': 5,
                    'I-CLAIM': 6,
                    'B-COUNTERCLAIM': 7,
                    'I-COUNTERCLAIM': 8,
                    'B-REBUTTAL': 9,
                    'I-REBUTTAL': 10,
                    'B-EVIDENCE': 11,
                    'I-EVIDENCE': 12,
                    'B-CS': 13,
                    'I-CS': 14}
    num_to_label = {num: label for label, num in label_to_num.items()}

    words, word_labels = [], []
    total_length = len(article)
    processed_length = 0  # characters of `article` already consumed
    remaining = article
    # Longformer accepts 4096 sub-tokens, so most articles fit in one window. With a
    # shorter limit (e.g. RoBERTa, 512) the article is processed window by window and
    # the per-word predictions are concatenated.
    while processed_length < total_length:
        encoding = config.tokenizer(remaining, padding='max_length', truncation=True,
                                    max_length=max_length, return_offsets_mapping=True)
        ids = torch.tensor(encoding['input_ids']).reshape(1, -1).to(config.device)
        mask = torch.tensor(encoding['attention_mask']).reshape(1, -1).to(config.device)
        with torch.no_grad():  # inference only: don't build an autograd graph
            logits = model(input_ids=ids, attention_mask=mask).logits
        pred_ids = torch.argmax(logits.view(-1, model.num_labels), dim=1).cpu().tolist()
        predictions = [num_to_label[k] for k in pred_ids]
        offsets = encoding['offset_mapping']

        # End (exclusive) of the last character covered by a real sub-token. Special and
        # padding tokens have empty (0, 0) spans and are ignored.
        last_token_pos = next((end for _, end in reversed(offsets) if end != 0), 0)
        if last_token_pos == 0:
            break  # nothing tokenizable left (e.g. only whitespace): avoid looping forever
        chunk = remaining[:last_token_pos]
        processed_length += last_token_pos + 1  # +1 skips the space after the chunk
        remaining = article[processed_length:]

        # Map each character position of `chunk` to the prediction of the sub-token that
        # covers it. The span is extended by one so the following space is covered too.
        char_labels = {}
        for (start, end), pred in zip(offsets, predictions):
            if start == end:
                continue  # special/padding token: must not overwrite real positions
            for pos in range(start, end + 1):
                char_labels[pos] = pred

        # One label per word: the label of the sub-token covering its middle character.
        word_start = 0
        for word in chunk.split(' '):
            tag = char_labels.get(word_start + len(word) // 2, 'O')
            word_labels.append(tag if tag == 'O' else tag[2:])
            words.append(word)
            word_start += len(word) + 1
    return words, word_labels


def _label_sentences(words, word_labels):
    """Group words into sentences and assign each sentence one label.

    A sentence ends at a word containing one of . , ; ! ? : -- but only once it has more
    than three words; shorter fragments are merged into the following sentence. Words
    after the last delimiter are dropped.

    Label rule, keyed on the position of the sentence's last word:
    * within the first or last 20% of the article: the most frequent word label;
    * otherwise: the second most frequent label if it covers at least 20% of the
      sentence's words (counteracts the dominance of EVIDENCE), else the most frequent.
    """
    puncts = ('.', ',', ';', '!', '?', ':')
    n_words = len(words)  # whole article, not just the last processed window
    labels, sentences = [], []
    sentence, sentence_labels = [], []
    for i, (word, word_label) in enumerate(zip(words, word_labels)):
        sentence.append(word)
        sentence_labels.append(word_label)
        if len(sentence) <= 3 or not any(p in word for p in puncts):
            continue
        sentences.append(' '.join(sentence))
        counter = Counter(sentence_labels)
        at_edge = i <= n_words * 0.2 or i >= n_words * 0.8
        if (not at_edge and len(counter) >= 2
                and counter.most_common(2)[1][1] >= len(sentence) * 0.2):
            labels.append(counter.most_common(2)[1][0])
        else:
            labels.append(counter.most_common(1)[0][0])
        sentence, sentence_labels = [], []
    return labels, sentences


def inference(test_txt, config, model, max_length):
    """Predict one argument label per sentence of a single article.

    Args:
        test_txt: the article as a list of tokens, joined with single spaces before
            tokenization. Assumed to contain no spaces themselves -- otherwise the
            word/label alignment with the gold tokens shifts.
        config: object exposing ``tokenizer`` (a fast Hugging Face tokenizer supporting
            ``return_offsets_mapping``) and ``device``.
        model: token-classification model returning an object with ``.logits`` and
            exposing ``num_labels``.
        max_length: maximum number of sub-tokens per forward pass. Longer articles are
            processed in consecutive windows.

    Returns:
        ``(labels, sentences)``: parallel lists with one entity type (e.g. ``'CLAIM'``,
        ``'CS'`` or ``'O'``) and one space-joined sentence string per sentence.
    """
    article = ' '.join(test_txt)
    words, word_labels = _predict_word_labels(article, config, model, max_length)
    return _label_sentences(words, word_labels)

def get_correct_labels(final_test_data):
    """Return the gold sentence labels of one annotated document.

    A sentence ends at a token containing one of . , ; ! ? : -- but only once it has more
    than three tokens; shorter fragments are merged into the following sentence. Tokens
    after the last delimiter are dropped. Each sentence takes the label of its last token.

    Args:
        final_test_data: a single AnnotatedDoc (despite the name) exposing ``tokens`` and
            ``mentions``, which ``encode_bio`` turns into one BIO tag per token.

    Returns:
        ``(labels, sentences)``: one upper-cased entity type per sentence and the
        space-joined sentence text. The label is ``'O'``, ``'CS'`` for a concluding
        statement, or otherwise the tag without its prefix.
    """
    bio_tags = encode_bio(final_test_data.tokens, final_test_data.mentions)
    # Collapse B-X into I-X: only the entity type matters for sentence labels (this also
    # lets a sentence ending on 'B-Concluding Statement' map to 'CS').
    word_labels = ['I' + tag[1:] if tag.startswith('B') else tag for tag in bio_tags]

    puncts = ('.', ',', ';', '!', '?', ':')
    labels, sentences = [], []
    sentence, sentence_labels = [], []
    for token, tag in zip(final_test_data.tokens, word_labels):
        sentence.append(token)
        sentence_labels.append(tag)
        if len(sentence) <= 3 or not any(p in token for p in puncts):
            continue
        sentences.append(' '.join(sentence))
        last_tag = sentence_labels[-1].upper()
        if last_tag == 'O':
            labels.append(last_tag)
        elif last_tag == 'I-CONCLUDING STATEMENT':
            labels.append('CS')
        else:
            labels.append(last_tag[2:])
        sentence, sentence_labels = [], []
    return labels, sentences

def get_mention_f1(final_test_data, config, model, max_length=4096):
    """Compute mention-level (sentence-level) micro F1 over held-out documents.

    Prints the micro F1 followed by sklearn's per-class classification report.

    Args:
        final_test_data: sequence of annotated documents exposing ``tokens`` (and
            whatever ``get_correct_labels`` reads).
        config: object exposing ``device`` (and the tokenizer used by ``inference``).
        model: token-classification model used for prediction.
        max_length: maximum number of sub-tokens per forward pass.

    Returns:
        The micro-averaged F1 score.

    Raises:
        ValueError: if predicted and gold sentence counts differ for a document. Such a
            mismatch would silently misalign every later (gold, predicted) pair.
    """
    model.to(config.device)
    model.eval()  # disable dropout so predictions are deterministic
    predicted, actual = [], []
    for doc_idx, doc in enumerate(final_test_data):
        predicted_labels, _ = inference(doc.tokens, config, model, max_length=max_length)
        real_labels, _ = get_correct_labels(doc)
        if len(predicted_labels) != len(real_labels):
            raise ValueError(
                f'document {doc_idx}: {len(predicted_labels)} predicted sentences vs '
                f'{len(real_labels)} gold sentences; sentence splitting is misaligned')
        predicted.extend(predicted_labels)
        actual.extend(real_labels)
    micro_f1 = f1_score(actual, predicted, average='micro')
    print(micro_f1)
    print(classification_report(actual, predicted))
    return micro_f1

In [None]:
# Rebuild the architecture and restore the fine-tuned weights saved during training.
model = LongformerForTokenClassification.from_pretrained('allenai/longformer-base-4096',
                                                         num_labels=Config.n_classes)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
state_dict = torch.load('longformer_model.pt', map_location=Config.device)
model.load_state_dict(state_dict)
get_mention_f1(final_test_data, Config, model)