In [None]:
from pprint import pformat

import numpy as np
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer

import wandb
from tqdm import tqdm

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score

from narrative_time import conversion_utils
from narrative_time import modeling_utils
from narrative_time import event_relations
from narrative_time.modeling import TransformerForRelationPrediction

# Number of relation classes, taken from the size of `event_relations.REL_TO_ID`.
# Used below as the size of the model's output layer and of the logits' last dimension.
NUM_RELATIONS = len(event_relations.REL_TO_ID)

# Show matplotlib figures inline in the notebook output.
%matplotlib inline

# Automatically reload imported modules before running code, so edits to the
# local `narrative_time` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Raise the figure resolution above the matplotlib default.
plt.rcParams['figure.dpi'] = 150

This file contains only the code to train the model and compute simple metrics. If you want to reproduce the plots, they are available in the notebook `modeling_and_plotting.ipynb`. 😉

In [None]:
# Name of the pretrained checkpoint. It selects the tokenizer here and is later
# passed to `TransformerForRelationPrediction` when the model is built.
MODEL_NAME = "google/long-t5-tglobal-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the two annotation files of the corpus as dictionaries (`as_dict=True`).
# The keys are later compared against the ids in TEST_DOCUMENTS.
# NOTE(review): presumably a1/a2 are two independent annotators of the same
# documents — confirm against `conversion_utils.get_annotations`.
a1_annotations = conversion_utils.get_annotations("../corpus/timebank/nt_format/tbd_a1.jsonl", as_dict=True)
a2_annotations = conversion_utils.get_annotations("../corpus/timebank/nt_format/tbd_a2.jsonl", as_dict=True)

In [None]:
# Ids of the documents held out for evaluation. Annotations whose id appears
# here are placed in the test set; every other annotation is used for training.
TEST_DOCUMENTS = [
    "PRI19980115.2000.0186",
    "PRI19980213.2000.0313",
    "PRI19980121.2000.2591",
    "ABC19980114.1830.0611",
    "APW19980213.1380",
    "NYT19980402.0453",
]

In [None]:
def _split_annotations(annotations):
    """Preprocess every annotation and split the results into train/test examples.

    Args:
        annotations: mapping from annotation (document) id to the raw annotation
            accepted by `modeling_utils.NTAnnotation.from_json`.

    Returns:
        A `(train_examples, test_examples)` pair of lists. Each example is an
        `(input_ids, event_left_tokens, event_relation_matrix)` tuple as returned
        by `modeling_utils.preprocess_document`. An example goes to the test list
        when its id is listed in `TEST_DOCUMENTS`, otherwise to the train list.
        Iteration order of `annotations` is preserved within each list.
    """
    held_out_ids = set(TEST_DOCUMENTS)  # O(1) membership checks
    train_examples, test_examples = [], []

    for annotation_id, raw_annotation in annotations.items():
        annotation = modeling_utils.NTAnnotation.from_json(raw_annotation)
        input_ids, event_left_tokens, event_relation_matrix = modeling_utils.preprocess_document(annotation, tokenizer)
        example = (input_ids, event_left_tokens, event_relation_matrix)

        if annotation_id in held_out_ids:
            test_examples.append(example)
        else:
            train_examples.append(example)

    return train_examples, test_examples


# Re-created from scratch on every run of this cell, so re-running it is safe.
train_dataset = []
test_dataset = []

# Both annotation files go through exactly the same pipeline (a1 first, then a2).
for _n_processed, _annotations in enumerate((a1_annotations, a2_annotations), start=1):
    _train_examples, _test_examples = _split_annotations(_annotations)
    train_dataset.extend(_train_examples)
    test_dataset.extend(_test_examples)

    # Every held-out document must be present exactly once in each annotation file;
    # otherwise the test set silently shrinks (or grows) and metrics are not comparable.
    assert len(test_dataset) == len(TEST_DOCUMENTS) * _n_processed, (
        f"Expected {len(TEST_DOCUMENTS) * _n_processed} test examples after processing "
        f"{_n_processed} annotation file(s), got {len(test_dataset)}"
    )

In [None]:
# Number of documents whose gradients are accumulated before one optimizer step.
ACCUM_STEPS = 1
# Early-stopping patience in epochs. Values <= 0 disable early stopping
# (and the best-model checkpointing that comes with it).
EARLY_STOPPING = -1
# Seed for the random number generators used in training (numpy shuffles the
# training set, torch initializes any newly created weights).
SEED = 42

# Seed before the model is created so that Restart & Run All gives the same
# initialization and the same order of training documents.
np.random.seed(SEED)
torch.manual_seed(SEED)

dtype = torch.bfloat16
# bfloat is only supported on 30-series and A-series GPUs (released 2020)
# if you have older ones try float16
# but fp16 usually requires some extra tricks that we didn't implement here
# worst case, you can use float32, but it will be slower and use more memory
model = TransformerForRelationPrediction(MODEL_NAME, num_relations=NUM_RELATIONS).to(device="cuda", dtype=dtype)
model.transformer.gradient_checkpointing_enable()  # needed to fit bigger documents into GPU memory

# Only parameters that receive gradients are counted.
_n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of model parameters: {_n_params/1e6:.2f}M")

# NOTE: no weight decay is applied (a value of 1e-2 was left disabled in the original code).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

wandb.init(
    project="narrative-time",
    config={
        "model_name": MODEL_NAME,
        "gradient_accumulation_steps": ACCUM_STEPS,
        "seed": SEED,
    },
)
wandb.watch(model)

# Counts processed training documents across all epochs; also used as the wandb step.
global_step = 0

In [None]:
# Maximum number of passes over the training set.
NUM_EPOCHS = 30

metrics = None  # metrics of the most recent evaluation (kept for later cells / the linter)
best_metrics = None  # metrics of the epoch with the highest macro-F1 seen so far
early_stopping = EARLY_STOPPING  # remaining patience; only used when EARLY_STOPPING > 0
# Start below any reachable F1 so that the first evaluation always counts as the
# best one. With an initial value of 0, a run whose F1 never exceeds 0 would never
# write `best_model.pt`, and loading it after the loop would fail.
best_f1 = float("-inf")
best_epoch = 0

for epoch in tqdm(range(NUM_EPOCHS)):
    # Make sure we are in training mode even if a previous run of this cell
    # was interrupted in the middle of an evaluation.
    model.train()

    shuffled_train_dataset = train_dataset.copy()
    np.random.shuffle(shuffled_train_dataset)

    # One training example is one whole document.
    for input_ids, event_left_tokens, event_relation_matrix in shuffled_train_dataset:
        input_ids = input_ids.to("cuda")
        event_left_tokens = event_left_tokens.to("cuda")
        event_relation_matrix = event_relation_matrix.to("cuda")

        relation_logits = model(input_ids, event_left_tokens)

        # Flatten the (num_events, num_events, NUM_RELATIONS) logits and the
        # (num_events, num_events) targets into one row per event pair.
        num_events = relation_logits.shape[0]
        relation_logits = relation_logits.view(num_events * num_events, NUM_RELATIONS)
        targets = event_relation_matrix.view(num_events * num_events)

        # Pairs labeled -1 do not contribute to the loss.
        loss = F.cross_entropy(relation_logits, targets, ignore_index=-1)
        # Scale only the gradient, so that the logged `loss` stays comparable
        # between different values of ACCUM_STEPS.
        (loss / ACCUM_STEPS).backward()

        # Step after every ACCUM_STEPS-th document. `global_step` starts at 0, so
        # the `+ 1` prevents an optimizer step on a single micro-batch at step 0.
        if (global_step + 1) % ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

            # Exclude ignored (-1) pairs, exactly as the test accuracy does;
            # otherwise they are always counted as wrong predictions.
            labeled = targets != -1
            accuracy = (relation_logits.argmax(dim=1)[labeled] == targets[labeled]).float().mean()
            wandb.log({"loss": loss.item(), "train_accuracy": accuracy.item(), "epoch": epoch}, step=global_step)

        global_step += 1

    # evaluate
    model.eval()

    all_predictions_list = []
    all_targets_list = []

    with torch.no_grad():
        for input_ids, event_left_tokens, event_relation_matrix in test_dataset:
            input_ids = input_ids.to("cuda")
            event_left_tokens = event_left_tokens.to("cuda")

            relation_logits = model(input_ids, event_left_tokens)

            num_events = relation_logits.shape[0]
            relation_logits = relation_logits.view(num_events * num_events, NUM_RELATIONS)
            targets = event_relation_matrix.view(num_events * num_events)

            # Metrics are computed on CPU over all documents at once, so only
            # the predictions have to leave the GPU.
            all_predictions_list.append(relation_logits.argmax(dim=1).cpu())
            all_targets_list.append(targets.cpu())

    all_predictions = torch.cat(all_predictions_list)
    all_targets = torch.cat(all_targets_list)

    # -1 are on the main diagonal (SELF relation) and are ignored
    labeled = all_targets != -1
    all_predictions = all_predictions[labeled]
    all_targets = all_targets[labeled]

    accuracy = (all_predictions == all_targets).float().mean()
    p = precision_score(all_targets, all_predictions, average="macro", zero_division=0)
    r = recall_score(all_targets, all_predictions, average="macro", zero_division=0)
    f1 = f1_score(all_targets, all_predictions, average="macro", zero_division=0)

    # Plain Python floats (not numpy scalars) keep the saved checkpoint loadable
    # with the restricted unpickler that newer `torch.load` versions use by default.
    metrics = {
        "test_accuracy": accuracy.item(),
        "test_precision_macro": float(p),
        "test_recall_macro": float(r),
        "test_f1_macro": float(f1),
    }

    wandb.log(metrics, step=global_step)

    # Track the best epoch even when early stopping is disabled, so that the
    # summary printed after the loop really reports the best metrics.
    if f1 > best_f1:
        best_f1 = float(f1)
        best_epoch = epoch
        best_metrics = metrics
        early_stopping = EARLY_STOPPING  # reset the patience

        if EARLY_STOPPING > 0:
            # save model
            checkpoint = {
                "model": model.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "metrics": metrics,
            }
            torch.save(checkpoint, "best_model.pt")
            del checkpoint
    elif EARLY_STOPPING > 0:
        # No improvement: spend one unit of patience and stop once it runs out.
        early_stopping -= 1
        if early_stopping == 0:
            break

# Leave the model in training mode, as the original loop did after each evaluation.
model.train()

print(f"Best metrics:\n{pformat(best_metrics)}")
print(f"Best epoch: {best_epoch}")
if EARLY_STOPPING > 0:
    # load best model
    # Without early stopping no checkpoint is written and `model` keeps the
    # weights of the last epoch.
    checkpoint = torch.load("best_model.pt")
    model.load_state_dict(checkpoint["model"])

wandb.finish()