## Abbreviation Detection: Model Training and Inference

## Imports

Please ensure that these libraries are installed

In [None]:
!pip install datasets
!pip install transformers
!pip install wandb
!pip install seqeval

In [None]:
from datasets import load_metric, load_dataset
from transformers import AutoTokenizer
from transformers import RobertaTokenizer, RobertaModel, RobertaTokenizerFast ## If Using RoBERTa
# from transformers import AlbertTokenizer, AlbertModel, AlbertTokenizerFast ## If using ALBERT
# from transformers import BertTokenizer, BertModel, BertTokenizer ## If using BERT
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
import numpy as np
import wandb
import transformers
from transformers import EarlyStoppingCallback
from tqdm.notebook import tqdm
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.metrics import confusion_matrix
from huggingface_hub import notebook_login
import csv
notebook_login()

In [None]:
from datasets import Dataset, DatasetDict

In [None]:
# Authenticate with Weights & Biases; the Trainer configured below reports to it (report_to="wandb").
wandb.login()

In [None]:
%env WANDB_LOG_MODEL='end'
%env WANDB_WATCH=all
# %env TOKENIZERS_PARALLELISM = False ## If you are using multiple GPUs, set this to True
# wandb.init(project="abbDet-roberta-base", entity="<your-wandb-user>")

## Choose the dataset you want to experiment with below

### Convert local BIO files to HuggingFace Dataset

In [None]:
# convert local BIO file to dataframe
def get_seg_wise_df(file="../PLODv1/filtered_data/train_bio.conll"):
    """Read a whitespace-separated BIO file into a DataFrame with one segment per row.

    Each non-blank line must contain at least three columns: token, POS tag and
    BIO tag (further columns are ignored). Blank or whitespace-only lines
    separate segments; runs of blank lines do not produce empty segments, and
    the last segment is kept even if the file has no trailing blank line.

    Args:
        file: path to the BIO/CoNLL file, read as UTF-8.

    Returns:
        pandas.DataFrame with columns ``id`` (segment index as str), ``tokens``
        and ``pos_tags`` (lists of str) and ``ner_tags`` (lists of int ids with
        B-O=0, B-AC=1, I-AC=2, B-LF=3, I-LF=4).

    Raises:
        ValueError: if a line has fewer than three columns or an unknown BIO tag.
    """
    tag2id = {"B-O": 0, "B-AC": 1, "I-AC": 2, "B-LF": 3, "I-LF": 4}

    segments = []  # each segment is a list of (token, pos, tag_id) tuples
    current = []
    with open(file, "r", encoding="utf-8") as fobj:
        for line_no, line in enumerate(fobj, start=1):
            fields = line.split()  # split() also strips "\n" / "\r\n"
            if not fields:
                # Segment boundary; ignore repeated blank lines.
                if current:
                    segments.append(current)
                    current = []
                continue
            if len(fields) < 3:
                raise ValueError(f"{file}:{line_no}: expected 'token POS BIO', got {line!r}")
            tag = fields[2]
            if tag not in tag2id:
                raise ValueError(f"{file}:{line_no}: unknown BIO tag {tag!r}")
            current.append((fields[0], fields[1], tag2id[tag]))
    if current:  # file may not end with a blank line
        segments.append(current)

    d = {
        'id': [str(i) for i in range(len(segments))],
        'tokens': [[tok for tok, _, _ in seg] for seg in segments],
        'pos_tags': [[pos for _, pos, _ in seg] for seg in segments],
        'ner_tags': [[tag_id for _, _, tag_id in seg] for seg in segments],
    }
    return pd.DataFrame(data=d)

In [None]:
# Parse the three PLODv2 (unfiltered) splits from the local BIO files.
split_paths = (
    "../PLODv2/unfiltered_data/train_bio.conll",
    "../PLODv2/unfiltered_data/val_bio.conll",
    "../PLODv2/unfiltered_data/test_bio.conll",
)
train_df, vali_df, test_df = [get_seg_wise_df(file=path) for path in split_paths]

In [None]:
# Wrap each split's DataFrame as a HuggingFace Dataset.
train_ds, vali_ds, test_ds = [Dataset.from_pandas(frame) for frame in (train_df, vali_df, test_df)]

In [None]:
# Declare ner_tags as a sequence of ClassLabel so the tag names are stored in the dataset features.
bio_tag_feature = Sequence(ClassLabel(num_classes=5, names=['B-O', 'B-AC', 'I-AC', 'B-LF', 'I-LF']))
train_ds, vali_ds, test_ds = [split.cast_column("ner_tags", bio_tag_feature) for split in (train_ds, vali_ds, test_ds)]

In [None]:
# Bundle the local splits under the names the rest of the notebook indexes: "train", "validation", "test".
datasets = DatasetDict({"train": train_ds, "validation":vali_ds, "test":test_ds})

### Choose HuggingFace Dataset

In [None]:
#uncomment to load data from huggingface
#datasets = load_dataset("surrey-nlp/PLOD-filtered")

### Checking the validity of some samples

In [None]:
# Display the DatasetDict (splits, features, row counts) as a quick sanity check.
datasets

In [None]:
# datasets['validation'][3]

In [None]:
# Tag names in id order: label_list[i] is the name of ner_tags id i.
ner_tags_feature = datasets["train"].features["ner_tags"]
label_list = ner_tags_feature.feature.names
label_list

In [None]:
def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    ClassLabel columns, and sequences of ClassLabel, are shown with their string
    names instead of integer ids.

    Raises:
        AssertionError: if more rows are requested than the dataset contains.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices directly. The previous retry-until-unique loop
    # with list membership checks slows down sharply as num_examples approaches len(dataset).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [None]:
# show_random_elements(datasets["train"])

In [None]:
# Only "ner" is usable with the locally built dataset: label_list is read from "ner_tags" (the only
# column cast to ClassLabel), "pos_tags" holds raw POS strings, and no "chunk_tags" column is created.
task = "ner"
model_checkpoint = "roberta-large" ## YOUR LANGUAGE MODEL NAME OR FINETUNED MODEL NAME FOR INFERENCING GOES HERE
batch_size = 4  # per-device training batch size passed to TrainingArguments

In [None]:
# add_prefix_space=True: RoBERTa-style byte-level BPE tokenizers need it for pre-split words (is_split_into_words=True).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
# Only fast tokenizers provide word_ids(), which the label alignment below relies on.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
example = datasets["train"][4]  # an arbitrary training sample used for the sanity checks below

In [None]:
# Tokenize the sample's word list into sub-word ids and map them back to token strings for inspection.
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
#print(tokens)

### Tokenization Validity 

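Here we check, on one dataset sample, that the tokenizer's word ids can align the word-level tags with the sub-word tokens. The first comparison below (word-level tags vs. sub-word tokens) will usually differ, because words are split into several sub-words and special tokens are added. After alignment via `word_ids()`, the two counts must match.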

In [None]:
# Word-level tag count vs. sub-word token count; these normally differ (sub-word splits and special tokens).
len(example[f"{task}_tags"]), len(tokenized_input["input_ids"])

In [None]:
# word_ids() maps each sub-word position to the index of its source word (None for special tokens).
print(tokenized_input.word_ids())

In [None]:
# Project the word-level tags onto sub-word positions; special tokens (word id None) get -100.
word_ids = tokenized_input.word_ids()
aligned_labels = [example[f"{task}_tags"][w] if w is not None else -100 for w in word_ids]
print(len(aligned_labels), len(tokenized_input["input_ids"]))

The two numbers printed above must match before you proceed from here.

In [None]:
# True: every sub-word of a word gets that word's tag. False: only the first sub-word is
# labelled and the remaining sub-words get -100 (see tokenize_and_align_labels).
label_all_tokens = True

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word-level tags to sub-word tokens.

    Uses the module-level `tokenizer`, `task` and `label_all_tokens`.

    Args:
        examples: batched dataset columns containing "tokens" and f"{task}_tags".

    Returns:
        The tokenizer's BatchEncoding with an added "labels" list per example.
        Special tokens get -100 so the loss ignores them. A word's first sub-word
        gets the word's tag; later sub-words of the same word get that tag too if
        `label_all_tokens` is true, otherwise -100.
    """
    # For some models, you may need to set max_length to approximately 500.
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    all_label_ids = []
    for batch_idx, word_tags in enumerate(examples[f"{task}_tags"]):
        aligned = []
        last_word = None
        for word in encoded.word_ids(batch_index=batch_idx):
            if word is None:
                aligned.append(-100)  # special token
            elif word == last_word and not label_all_tokens:
                aligned.append(-100)  # continuation sub-word, ignored in the loss
            else:
                aligned.append(word_tags[word])
            last_word = word
        all_label_ids.append(aligned)

    encoded["labels"] = all_label_ids
    return encoded

In [None]:
# Tokenize every split and attach the aligned "labels" column (batched=True passes lists of examples).
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

In [None]:
# Store the tag names in the model config so saved/pushed checkpoints (and pipelines built
# from them) report 'B-AC', 'I-LF', ... instead of generic LABEL_0, LABEL_1, ...
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=dict(enumerate(label_list)),
    label2id={name: idx for idx, name in enumerate(label_list)},
)

In [None]:
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-{task}",
    report_to="wandb",
    # Evaluate every `eval_steps` steps; checkpoints are selected by seqeval F1 (metric_for_best_model) rather than loss.
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy='steps',
    eval_steps=7000,
    save_total_limit=3,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,  # follow the configured batch size rather than a second hard-coded value
    num_train_epochs=6,
    weight_decay=0.001,
    # With load_best_model_at_end, save_steps must be a multiple of eval_steps (35000 = 5 * 7000).
    save_steps=35000,
    push_to_hub=True,
    metric_for_best_model='f1',
    load_best_model_at_end=True,
)

In [None]:
# Pads each batch dynamically, including the "labels" lists (padded with -100 by default).
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
# seqeval scores BIO tag sequences at the entity level (precision/recall/F1) plus token accuracy.
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases (moved to `evaluate`) — confirm the installed version provides it.
metric = load_metric("seqeval")

In [None]:
# Sanity check: scoring the sample's gold tags against themselves should give perfect scores
# (provided the sample contains at least one entity).
# NOTE: this rebinds the global `labels`; the test-prediction cell further down rebinds it again.
labels = [label_list[i] for i in example[f"{task}_tags"]]
metric.compute(predictions=[labels], references=[labels])

In [None]:
def compute_metrics(p):
    """Seqeval scores for the Trainer's evaluation loop.

    Args:
        p: a (logits, label_ids) pair as passed by the Trainer. Logits are
            arg-maxed over axis 2, and positions whose gold id is -100 (special
            tokens / padding) are dropped before scoring.

    Returns:
        dict with the overall "precision", "recall", "f1" and "accuracy" from the
        module-level seqeval `metric`. Ids are mapped to tag names through the
        module-level `label_list`.
    """
    logits, label_ids = p
    pred_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(pred_ids, label_ids):
        scored = [(pid, gid) for pid, gid in zip(pred_row, gold_row) if gid != -100]
        true_predictions.append([label_list[pid] for pid, _ in scored])
        true_labels.append([label_list[gid] for _, gid in scored])

    scores = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
        "f1": scores["overall_f1"],
        "accuracy": scores["overall_accuracy"],
    }

In [None]:
# Trainer with early stopping: training stops after 3 consecutive evaluations without
# improvement in metric_for_best_model (F1, see args).
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)

In [None]:
# Fine-tune; evaluation, checkpointing and W&B logging follow `args`.
trainer.train()

In [None]:
# With load_best_model_at_end=True this scores the best checkpoint's weights.
trainer.evaluate() ## Final Evaluation on the validation set

In [None]:
# Save the final model into the training output directory (derived from model_checkpoint and task).
# A hard-coded name such as 'filt-roberta-large-...' goes stale when the checkpoint or the
# filtered/unfiltered data variant changes.
trainer.save_model(args.output_dir)

In [None]:
# Upload the trained model to the Hugging Face Hub repo associated with args.output_dir (push_to_hub=True).
trainer.push_to_hub()

## Obtaining predictions on the test set

In [None]:
# Predict on the test split. `labels` holds the gold ids padded with -100;
# `predictions` becomes the arg-max tag id at every position.
predictions, labels, _ = trainer.predict(tokenized_datasets["test"])
predictions = np.argmax(predictions, axis=2)

In [None]:
# Remove ignored index (special tokens)
# Same filtering as compute_metrics: keep positions whose gold id is not -100 and map ids to tag names.
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

# Full seqeval report for the test set (per-entity-type and overall scores).
results = metric.compute(predictions=true_predictions, references=true_labels)
results

In [None]:
# Inspect the predicted tags for one test sentence.
print(true_predictions[2])

In [None]:
# save predictions to file
# Writes one "subword_token predicted_tag" line for every scored (gold label != -100) position of each test sentence.
test_input_ids = tokenized_datasets["test"]["input_ids"]  # fetch each column once instead of once per sentence
test_label_ids = tokenized_datasets["test"]["labels"]
with open("roberta-large_unfiltered_preds.txt", "w", encoding="utf-8") as tok_preds:
    for ids, gold_ids, sent_preds in zip(test_input_ids, test_label_ids, true_predictions):
        toks = tokenizer.convert_ids_to_tokens(ids)
        # Pair tokens with predictions via the -100 mask. The old "skip only the first token"
        # offset is valid only when label_all_tokens=True.
        scored_toks = [tok for tok, gold in zip(toks, gold_ids) if gold != -100]
        for tok, tag in zip(scored_toks, sent_preds):
            tok_preds.write(tok + " " + tag + "\n")

In [None]:
# save true labels to file
# Writes one "subword_token gold_tag" line for every scored (gold label != -100) position of each test sentence.
test_input_ids = tokenized_datasets["test"]["input_ids"]  # fetch each column once instead of once per sentence
test_label_ids = tokenized_datasets["test"]["labels"]
with open("true_labels.txt", "w", encoding="utf-8") as lab_file:
    for ids, gold_ids, sent_labels in zip(test_input_ids, test_label_ids, true_labels):
        toks = tokenizer.convert_ids_to_tokens(ids)
        # Pair tokens with labels via the -100 mask. The old "skip only the first token"
        # offset is valid only when label_all_tokens=True.
        scored_toks = [tok for tok, gold in zip(toks, gold_ids) if gold != -100]
        for tok, tag in zip(scored_toks, sent_labels):
            lab_file.write(tok + " " + tag + "\n")

## Compare RoBERTa-large predictions with Flair models (disagreement analysis)

In [None]:
def get_word_idx_from_subword(true_predictions, tokenizer, tokenized_datasets, split="test"):
    """Return, for each sentence, the positions of its word-initial sub-word tokens.

    For each of the first ``len(true_predictions)`` sentences of ``split``, the
    input ids are converted back to tokens and the "<s>" token is removed. The
    positions of tokens that start with "Ġ" (the byte-level BPE space marker,
    present on every word because the tokenizer uses add_prefix_space=True)
    are then collected.

    Args:
        true_predictions: per-sentence prediction lists; only their count is used.
        tokenizer: tokenizer providing ``convert_ids_to_tokens``.
        tokenized_datasets: mapping of split name -> dataset with an "input_ids" column.
        split: which split to read (default "test").

    Returns:
        list (one per sentence) of lists of int positions.

    Raises:
        ValueError: if a sentence's tokens contain no "<s>" (e.g. a non-RoBERTa tokenizer).
    """
    # Fetch the column once; indexing tokenized_datasets[split]["input_ids"] inside the
    # loop may re-read the whole column on every iteration.
    input_ids = tokenized_datasets[split]["input_ids"]
    G_index = []
    for sent_idx in range(len(true_predictions)):
        toks = tokenizer.convert_ids_to_tokens(input_ids[sent_idx])
        toks.remove("<s>")
        G_index.append([i for i, tok in enumerate(toks) if tok.startswith("Ġ")])
    return G_index

In [None]:
# Word-initial sub-token positions per test sentence; used to align word-level Flair output with RoBERTa's sub-word predictions.
word_idx = get_word_idx_from_subword(true_predictions, tokenizer, tokenized_datasets)

In [None]:
def read_flair_preds(filename="best_predictions.tsv"):
    """Split a Flair prediction file into per-sentence blocks.

    Sentences are separated by blank lines. Each returned block holds the
    sentence's lines (space-separated columns; the mismatch functions read
    column 0 as the token and column 2 as the prediction) followed by exactly
    one "\\n\\n".

    Args:
        filename: path to the Flair prediction file, read as UTF-8.

    Returns:
        list of str blocks; empty blocks are dropped.
    """
    with open(filename, encoding="utf-8") as file:
        list_d = file.read()
    delimiter = "\n\n"
    # Strip stray newlines around each block (e.g. a file ending in a single "\n", or 3+ newlines
    # between sentences). Otherwise a block would contain extra empty lines, and callers that
    # index column 2 of every line would fail on them.
    blocks = (x.strip("\n") for x in list_d.split(delimiter))
    flair_sent = [x + delimiter for x in blocks if x]
    return flair_sent

In [None]:
def find_roberta_mismatch(flair_sent, word_idx, robert_preds, labels, filename="mismatches.txt"):
    """Write a token-by-token comparison of Flair and RoBERTa predictions to `filename`.

    For sentence i, each Flair token j that has a RoBERTa word position
    (j < len(word_idx[i])) produces one line:
    ``token flair_pred roberta_pred gold flag``. The flag is "1" on disagreement
    and "0" on agreement. A Flair "O" counts as agreeing with RoBERTa's "B-O".
    The last element of each split block (the empty line before the blank
    separator) is written as "SENTEND". Tokens beyond len(word_idx[i]) are skipped.

    Args:
        flair_sent: sentence blocks as returned by read_flair_preds
            (lines "token gold pred" separated by single spaces).
        word_idx: per-sentence word-initial sub-word positions.
        robert_preds: per-sentence RoBERTa tag lists indexed by those positions.
        labels: per-sentence gold tag lists indexed by those positions.
        filename: output path, written as UTF-8.
    """
    with open(filename, "w", encoding="utf-8") as fileobj:
        for i, block in enumerate(flair_sent):
            sent_list = block.split("\n")[:-1]
            last = len(sent_list) - 1
            for j, line in enumerate(sent_list):
                if j == last:  # the last token of sentence
                    fileobj.write("SENTEND" + "\n")
                    continue
                if j >= len(word_idx[i]):
                    continue  # no RoBERTa position for this token (e.g. truncated input)
                cols = line.split(" ")
                token, flair_pred = cols[0], cols[2]
                pos = word_idx[i][j]
                rob_pred = robert_preds[i][pos]
                gold = labels[i][pos]
                if flair_pred != "O":
                    mismatch = flair_pred != rob_pred
                else:
                    mismatch = rob_pred != "B-O"  # Flair's "O" corresponds to RoBERTa's "B-O"
                flag = "1" if mismatch else "0"
                fileobj.write(" ".join([token, flair_pred, rob_pred, gold, flag]) + "\n")

In [None]:
# Load two sets of Flair predictions for comparison.
# NOTE(review): judging by the file name, `flair_preds11` presumably comes from a PubMed-embedding Flair model; consider a descriptive name.
flair_preds = read_flair_preds(filename="best_predictions.tsv")
flair_preds11 = read_flair_preds(filename="pubmed_predictions.tsv")

In [None]:
# Token-level agreement between the Flair predictions in best_predictions.tsv and RoBERTa.
find_roberta_mismatch(flair_preds, word_idx, true_predictions, true_labels, filename="mismatch_best_robert_all_tok.txt")

In [None]:
# NOTE(review): identical to the previous cell (same inputs, same output file); re-running it only rewrites
# mismatch_best_robert_all_tok.txt. Remove it, or pass flair_preds11 if a PubMed-vs-RoBERTa comparison was intended.
find_roberta_mismatch(flair_preds, word_idx, true_predictions, true_labels, filename="mismatch_best_robert_all_tok.txt")

In [None]:
def find_flair_mismatch(flair_sent1, flair_sent2, word_idx, labels, filename="mismatch_best_pubmed_all_tok.txt"):
    """Write a token-by-token comparison of two Flair prediction sets to `filename`.

    For sentence i, each token j (taken from `flair_sent1`) with j < len(word_idx[i])
    produces one line: ``token pred1 pred2 gold flag``. The flag is "1" when the two
    Flair predictions differ and "0" otherwise. The gold tag comes from `labels`
    at RoBERTa position word_idx[i][j]. The last element of each split block is
    written as "SENTEND".

    Args:
        flair_sent1, flair_sent2: sentence blocks as returned by read_flair_preds,
            aligned sentence-by-sentence and token-by-token.
        word_idx: per-sentence word-initial sub-word positions.
        labels: per-sentence gold tag lists indexed by those positions.
        filename: output path, written as UTF-8.
    """
    with open(filename, "w", encoding="utf-8") as fileobj:
        for i, block1 in enumerate(flair_sent1):
            sent_list1 = block1.split("\n")[:-1]
            sent_list2 = flair_sent2[i].split("\n")[:-1]
            last = len(sent_list1) - 1
            for j, line1 in enumerate(sent_list1):
                if j == last:  # the last token of sentence
                    fileobj.write("SENTEND" + "\n")
                    continue
                if j >= len(word_idx[i]):
                    continue  # no RoBERTa position for this token (e.g. truncated input)
                cols1 = line1.split(" ")
                token, pred1 = cols1[0], cols1[2]
                pred2 = sent_list2[j].split(" ")[2]
                gold = labels[i][word_idx[i][j]]
                flag = "1" if pred1 != pred2 else "0"
                fileobj.write(" ".join([token, pred1, pred2, gold, flag]) + "\n")

In [None]:
# Token-level agreement between the two Flair prediction sets (best vs. pubmed).
find_flair_mismatch(flair_preds, flair_preds11, word_idx, true_labels, filename="mismatch_best_pubmed_all_tok.txt")

## Confusion matrix of the test-set predictions

In [None]:
# Flatten the per-sentence tag lists into single token-level sequences for the confusion matrix.
true_labels_flat = [tag for sentence_tags in true_labels for tag in sentence_tags]
true_predictions_flat = [tag for sentence_tags in true_predictions for tag in sentence_tags]

In [None]:
def build_cm_annotations(cm):
    """Build per-cell heatmap annotations for a confusion matrix `cm` (2-D int array).

    Each cell gets "<row %>\\n<count>". Diagonal cells get "<row %>\\n<count>/<row total>".
    Off-diagonal zero cells are left blank. Rows with a zero total get 0.0% instead of NaN.
    """
    row_totals = cm.sum(axis=1, keepdims=True)
    # Row-normalised percentages; rows with no gold tokens (tags seen only in predictions) stay 0.
    cm_perc = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0) * 100
    # dtype=object: a fixed-width unicode array derived from the int matrix would silently
    # truncate long annotations (e.g. large counts).
    annot = np.full(cm.shape, "", dtype=object)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, row_totals[i, 0])
            elif c != 0:
                annot[i, j] = '%.1f%%\n%d' % (p, c)
    return annot


def plot_cm(y_true, y_pred, figsize=(10,10), save_path='file.png'):
    """Plot a row-normalised confusion-matrix heatmap of tag sequences and save it to `save_path`.

    Rows are gold tags ("Actual") and columns are predicted tags ("Predicted").
    Both cover every tag that occurs in `y_true` or `y_pred`.
    """
    # Use the union of gold and predicted tags: restricting the labels to np.unique(y_true)
    # silently drops tokens predicted as a tag that never occurs in y_true.
    tags = np.union1d(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=tags)
    annot = build_cm_annotations(cm)
    cm_df = pd.DataFrame(cm, index=tags, columns=tags)
    cm_df.index.name = 'Actual'
    cm_df.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm_df, cmap="YlGnBu", annot=annot, fmt='', ax=ax)
    ax.set_title('Token-level confusion matrix (row-normalised %)')
    # Save only after drawing; saving before the heatmap is drawn writes a blank image.
    fig.savefig(save_path)
    
# Plot the token-level confusion matrix for the test set.
plot_cm(true_labels_flat, true_predictions_flat)