In [None]:
import pickle
from datetime import datetime
from pathlib import Path
from typing import Dict

import evaluate
import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

In [None]:
# Input data: pickled annotation file produced by the labelling step.
DATA_DIR = Path.cwd().parent / "data"
DATASET_PATH = DATA_DIR / "train_anno_2023_08_30_14_39.pkl"

HF_MODEL_NAME = "roberta-base"
# Label id for positions that must not be scored: special tokens and
# continuation sub-tokens (-100 is also PyTorch cross-entropy's default
# ignore_index, so the loss skips them too).
IGNORE_ENCODING = -100
RANDOM_SEED = 23

TIMESTAMP = datetime.now().strftime("%Y_%m_%d_%H_%M")
# Hub ids may contain "/" (e.g. "org/model"); flatten it so every run gets a
# single "<model>_<timestamp>" directory under models/ rather than a nested path.
MODEL_DIR = (
    Path.cwd().parent / "models" / f"{HF_MODEL_NAME.replace('/', '_')}_{TIMESTAMP}"
)
# All entity types are collapsed into this single class (see the tag-remap cell).
PRIORITY_LABEL = "PRIORITY"

In [None]:
# Load the annotated documents. Each document is iterated as a sequence of
# chunk dicts with "tokens" and "ner_tags" entries by the cells below.
# SECURITY: pickle.load can execute arbitrary code -- only load trusted,
# locally produced files.
# The `with` block guarantees the file handle is closed after loading.
with DATASET_PATH.open("rb") as dataset_file:
    dataset = pickle.load(dataset_file)
len(dataset)

In [None]:
# Collapse every entity type into one class: each non-"O" tag keeps its first
# two characters (presumably the BIO prefix such as "B-"/"I-" -- TODO confirm
# against the annotation schema) followed by PRIORITY_LABEL; "O" is unchanged.
# NOTE: this rewrites the chunks of `dataset` in place.
for doc in dataset:
    for chunk in doc:
        chunk["ner_tags"] = [
            tag if tag == "O" else tag[:2] + PRIORITY_LABEL
            for tag in chunk["ner_tags"]
        ]

In [None]:
# Split whole documents (chunks are flattened only afterwards), so chunks of
# one document never end up in both train and validation.
train_dataset, val_dataset = train_test_split(
    dataset,
    random_state=RANDOM_SEED,
    test_size=0.2,
)
(len(train_dataset), len(val_dataset))

In [None]:
def flatten_chunks(docs) -> Dict[str, list]:
    """Flatten documents into the column layout expected by ``Dataset.from_dict``.

    Args:
        docs: Iterable of documents; each document is an iterable of chunk
            dicts carrying "tokens" and "ner_tags" entries.

    Returns:
        ``{"tokens": [...], "ner_tags": [...]}`` with one entry per chunk,
        in document order.
    """
    chunks = [chunk for doc in docs for chunk in doc]
    return {
        "tokens": [chunk["tokens"] for chunk in chunks],
        "ner_tags": [chunk["ner_tags"] for chunk in chunks],
    }


# Both splits go through the same helper so they cannot drift apart; each
# comprehension binds its own `chunk` instead of reading a leftover global.
train_dataset_dict = flatten_chunks(train_dataset)
val_dataset_dict = flatten_chunks(val_dataset)

In [None]:
# Wrap both column dicts ("tokens", "ner_tags") as Hugging Face Datasets.
train_dataset_hf, val_dataset_hf = map(
    Dataset.from_dict, (train_dataset_dict, val_dataset_dict)
)

In [None]:
# Label vocabulary, sorted so ids are stable across runs. Both splits are
# scanned: a tag occurring only in the validation split would otherwise be
# absent from label2id and raise KeyError when that split is tokenized.
labels = sorted(
    {
        tag
        for split_dict in (train_dataset_dict, val_dataset_dict)
        for chunk_tags in split_dict["ner_tags"]
        for tag in chunk_tags
    }
)
labels

In [None]:
# Integer id <-> tag lookups; ids follow the sorted order of `labels`.
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}
id2label, label2id

In [None]:
def tokenize_and_align_labels(
    examples: Dict[str, list], tokenizer: AutoTokenizer, label2id: Dict
):
    """Tokenize a batch of pre-split words and align word-level tags to sub-tokens.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch dict with "tokens" (one word list per row) and
            "ner_tags" (one tag list per row, one tag per word).
        tokenizer: A *fast* tokenizer, since ``word_ids`` is required.
        label2id: Mapping from tag string to integer id.

    Returns:
        The tokenizer output (a ``BatchEncoding``) with an added "labels"
        entry. Only the first sub-token of each word carries that word's tag
        id; special tokens and continuation sub-tokens get IGNORE_ENCODING.

    Raises:
        KeyError: if a tag is missing from ``label2id``.
    """
    # Truncate to the tokenizer's model_max_length: a chunk longer than the
    # model's position embeddings (512 tokens for roberta-base) could not be
    # fed to the model otherwise. Labels stay aligned because word_ids()
    # only covers the tokens that were kept.
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    aligned_labels = []
    for batch_index, word_tags in enumerate(examples["ner_tags"]):
        # Index of the source word for every sub-token (None for special tokens).
        word_ids = tokenized_inputs.word_ids(batch_index=batch_index)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None or word_idx == previous_word_idx:
                # Special token, or a continuation piece of the current word.
                label_ids.append(IGNORE_ENCODING)
            else:
                # First sub-token of a new word gets the word's tag.
                label_ids.append(label2id[word_tags[word_idx]])
            previous_word_idx = word_idx
        aligned_labels.append(label_ids)

    tokenized_inputs["labels"] = aligned_labels
    return tokenized_inputs

In [None]:
# A fast tokenizer is required: label alignment relies on `word_ids()`.
# RoBERTa's byte-level BPE tokenizer needs add_prefix_space=True to accept
# pre-split words (is_split_into_words=True in tokenize_and_align_labels).
tokenizer = AutoTokenizer.from_pretrained(
    HF_MODEL_NAME,
    use_fast=True,
    add_prefix_space=True,
)

In [None]:
# Tokenize both splits with the same alignment function; fn_kwargs forwards
# the tokenizer and label mapping to every batched call.
tokenized_train_dataset, tokenized_val_dataset = (
    split_hf.map(
        tokenize_and_align_labels,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer, "label2id": label2id},
    )
    for split_hf in (train_dataset_hf, val_dataset_hf)
)

In [None]:
tuple(map(len, (tokenized_train_dataset, tokenized_val_dataset)))  # rows per split

In [None]:
# Token-classification model sized to the label vocabulary; id2label and
# label2id are stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    HF_MODEL_NAME,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
# Pads each batch dynamically, including the "labels" column.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
# Keep only the tokenizer outputs plus "labels"; the raw word/tag columns
# are not model inputs.
tokenized_train_dataset, tokenized_val_dataset = (
    tokenized.remove_columns(["tokens", "ner_tags"])
    for tokenized in (tokenized_train_dataset, tokenized_val_dataset)
)
(len(tokenized_train_dataset), len(tokenized_val_dataset))

In [None]:
seqeval = evaluate.load("seqeval")


def compute_metrics(p):
    """Entity-level seqeval metrics for one Trainer evaluation pass.

    Args:
        p: ``(logits, label_ids)`` pair supplied by the Trainer. Logits are
            argmax-ed over axis 2, so presumably shaped
            (batch, seq_len, num_labels).

    Returns:
        Dict with the overall "precision", "recall", "f1" and "accuracy".

    Positions whose gold id equals IGNORE_ENCODING (special tokens and
    continuation sub-tokens) are dropped before scoring; the remaining ids are
    mapped back to tag strings through the global ``labels`` list.
    """
    logits, gold_ids = p
    pred_ids = np.argmax(logits, axis=2)

    pred_tags, gold_tags = [], []
    for seq_pred_ids, seq_gold_ids in zip(pred_ids, gold_ids):
        scored = [
            (pred_id, gold_id)
            for pred_id, gold_id in zip(seq_pred_ids, seq_gold_ids)
            if gold_id != IGNORE_ENCODING
        ]
        pred_tags.append([labels[pred_id] for pred_id, _ in scored])
        gold_tags.append([labels[gold_id] for _, gold_id in scored])

    scores = seqeval.compute(predictions=pred_tags, references=gold_tags)
    metric_keys = {
        "precision": "overall_precision",
        "recall": "overall_recall",
        "f1": "overall_f1",
        "accuracy": "overall_accuracy",
    }
    return {name: scores[key] for name, key in metric_keys.items()}

In [None]:
training_args = TrainingArguments(
    # Plain str: output_dir is declared as a string in TrainingArguments.
    output_dir=str(MODEL_DIR),
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    warmup_ratio=0.1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Use the notebook's seed for data shuffling / dropout as well. Without
    # this the Trainer applies its own default seed, and RANDOM_SEED would
    # only control the train/val split.
    # NOTE(review): the classification head was already initialized when the
    # model was loaded, before this seed is applied; call
    # transformers.set_seed(RANDOM_SEED) before loading for fully
    # reproducible runs.
    seed=RANDOM_SEED,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()