# Evaluating a Model Checkpoint

This notebook demonstrates how to evaluate a model checkpoint trained on NYT: first on the NYT test set (in-domain), then on the WebNLG test set (out-of-domain).

## Imports

In [1]:
import os
from transformers import BertTokenizerFast, Trainer
from dataprocess.data_extractor import unirel_span_extractor
from dataprocess.data_processor import UniRelDataProcessor
from dataprocess.dataset import UniRelSpanDataset
from model.model_transformers import UniRelModel

## Evaluation Function

In [2]:
def evaluate_checkpoint(
    checkpoint: str,
    dataset_name: str,
) -> tuple[float, float, float]:
    """Evaluate a saved UniRel checkpoint on the test split of a dataset.

    Builds the test dataset for ``dataset_name`` with a ``bert-base-cased``
    tokenizer, loads the model stored at ``checkpoint``, runs prediction
    through a ``Trainer`` and passes the predictions to
    ``unirel_span_extractor`` together with a dump directory located at
    ``<checkpoint>/<dataset_name>`` (created if it does not exist yet).

    Args:
        checkpoint: Path handed to ``UniRelModel.from_pretrained``. It is also
            the parent directory of the dump directory.
        dataset_name: Dataset identifier forwarded to ``UniRelDataProcessor``
            (this notebook uses ``"nyt"`` and ``"webnlg"``). The processor is
            pointed at the relative ``data`` root.

    Returns:
        The value returned by ``unirel_span_extractor``. In the runs recorded
        in this notebook it is a ``(precision, recall, f1)`` tuple.
    """
    # LOAD TEST DATASET
    # Register [unused1] .. [unused16] as additional special tokens.
    # NOTE(review): presumably this has to mirror the tokenizer setup used at
    # training time -- confirm against the training script.
    num_unused_tokens = 16
    added_token = [f"[unused{i}]" for i in range(1, num_unused_tokens + 1)]
    tokenizer = BertTokenizerFast.from_pretrained(
        "bert-base-cased",
        additional_special_tokens=added_token,
        do_basic_tokenize=False,
    )
    data_processor = UniRelDataProcessor(
        root="data",
        tokenizer=tokenizer,
        dataset_name=dataset_name,
    )
    test_samples = data_processor.get_test_sample()
    # NOTE(review): the "+ 2" presumably leaves room for [CLS] and [SEP] on top
    # of 150 text tokens -- TODO confirm against UniRelSpanDataset.
    max_length = 150 + 2
    test_dataset = UniRelSpanDataset(
        test_samples,
        data_processor,
        tokenizer,
        mode='test',
        ignore_label=-100,
        model_type='bert',
        ngram_dict=None,
        max_length=max_length,
        predict=True,
        eval_type="test"
    )
    print(f"Loaded test dataset {dataset_name} of size {len(test_dataset)}")

    # LOAD MODEL CHECKPOINT
    model = UniRelModel.from_pretrained(checkpoint)
    print(f"Loaded model from checkpoint {checkpoint}")

    # GET MODEL PREDICTIONS ON TEST DATA
    trainer = Trainer(model=model)
    test_prediction = trainer.predict(test_dataset, ignore_keys=["loss"])

    # COMPUTE METRICS
    dump_path = os.path.join(checkpoint, dataset_name)
    # exist_ok=True replaces the separate os.path.exists() check: it is a
    # single call and cannot fail if the directory appears between the check
    # and the creation.
    os.makedirs(dump_path, exist_ok=True)
    print(f"Saving test dump in {dump_path}")
    return unirel_span_extractor(
        tokenizer=tokenizer,
        dataset=test_dataset,
        predictions=test_prediction,
        path=dump_path,
    )

## Evaluate NYT Checkpoint on NYT

In [3]:
# In-domain run: the NYT-trained checkpoint scored on the NYT test set.
evaluate_checkpoint(
    checkpoint="./output/nyt/checkpoint-final",
    dataset_name="nyt",
)

100%|██████████| 5000/5000 [00:03<00:00, 1515.34it/s]


139
more than 100: 49
more than 150: 0
Loaded test dataset nyt of size 5000
Loaded model from checkpoint ./output/nyt/checkpoint-final




Saving test dump in ./output/nyt/checkpoint-final/nyt

all:  {'p': 8057, 'c': 7548, 'g': 8120} 
 {'all-prec': 0.9368251210127839, 'all-recall': 0.9295566502463054, 'all-f1': 0.9331767323978487}


(0.9368251210127839, 0.9295566502463054, 0.9331767323978487)

## Evaluate NYT Checkpoint on WebNLG

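The checkpoint was trained only on NYT, so we don't expect it to do well on WebNLG at all; this run is a sanity check rather than a meaningful benchmark.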

In [4]:
# Cross-dataset run: the same NYT-trained checkpoint scored on the WebNLG test set.
evaluate_checkpoint(
    checkpoint="./output/nyt/checkpoint-final",
    dataset_name="webnlg",
)

100%|██████████| 703/703 [00:01<00:00, 646.53it/s]


99
more than 100: 0
more than 150: 0
Loaded test dataset webnlg of size 703
Loaded model from checkpoint ./output/nyt/checkpoint-final


Saving test dump in ./output/nyt/checkpoint-final/webnlg

all:  {'p': 474, 'c': 36, 'g': 1607} 
 {'all-prec': 0.0759493670886076, 'all-recall': 0.0224019912881145, 'all-f1': 0.03459875060067275}


(0.0759493670886076, 0.0224019912881145, 0.03459875060067275)