<a href="https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_tacred.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproducing experimental results of LUKE on TACRED using Hugging Face Transformers

This notebook shows how to reproduce the state-of-the-art results on the TACRED relation classification task reported in [this paper](https://arxiv.org/abs/2010.01057) using the Transformers library and the [fine-tuned model](https://huggingface.co/studio-ousia/luke-large-finetuned-tacred) available on the model hub.
The source code used in the experiments is also available [here](https://github.com/studio-ousia/luke/tree/master/examples/relation_classification).

**NOTE:** The TACRED dataset is not publicly available. In the cell below, we copy the test set (test.json), which we uploaded to our Google Drive, to the working directory. Please obtain the dataset by following the instructions on the [TACRED website](https://nlp.stanford.edu/projects/tacred/), and modify the cell below so that it places test.json in the working directory.
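Because `load_examples()` reads only a handful of fields from each record, it can be worth checking a locally obtained test.json before running the full evaluation. The following is a minimal sketch, assuming the file is a JSON list of records carrying the `token`, `subj_start`, `subj_end`, `obj_start`, `obj_end`, and `relation` keys that `load_examples()` uses, with inclusive end indices (which is why that function adds 1 to them). The helper names `validate_tacred_records` and `validate_tacred_file` and the toy record are hypothetical and exist only for illustration.

```python
import json

# Keys that load_examples() reads from each TACRED record.
REQUIRED_KEYS = ("token", "subj_start", "subj_end", "obj_start", "obj_end", "relation")


def validate_tacred_records(records):
    """Return a list of (index, problem) tuples; an empty list means every record looks usable."""
    problems = []
    for i, record in enumerate(records):
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            problems.append((i, f"missing keys: {missing}"))
            continue
        num_tokens = len(record["token"])
        for role in ("subj", "obj"):
            start, end = record[f"{role}_start"], record[f"{role}_end"]
            # End indices are assumed inclusive, so a valid end is strictly below num_tokens.
            if not (0 <= start <= end < num_tokens):
                problems.append((i, f"invalid {role} span ({start}, {end})"))
    return problems


def validate_tacred_file(path="test.json"):
    """Load a TACRED-style JSON file and return (number of records, list of problems)."""
    with open(path, "r") as f:
        records = json.load(f)
    return len(records), validate_tacred_records(records)


# Hypothetical toy record: subject "Alice" (token 0), object "Acme" (token 3).
toy_records = [{
    "token": ["Alice", "works", "for", "Acme", "."],
    "subj_start": 0, "subj_end": 0,
    "obj_start": 3, "obj_end": 3,
    "relation": "per:employee_of",
}]
print(validate_tacred_records(toy_records))  # []
```

An empty problem list only means that every record has the expected keys and in-range spans; it does not check the relation labels against the label set of the fine-tuned model.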

In [1]:
from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/projects/luke/data/tacred/test.json test.json

Mounted at /content/drive


In [2]:
!pip install git+https://github.com/huggingface/transformers.git@refs/pull/11223/head


In [3]:
import json
import torch
from tqdm import trange
from transformers import LukeTokenizer, LukeForEntityPairClassification

In [4]:
def load_examples(dataset_file):
    """Convert TACRED records into examples with text, character-level entity spans, and a label."""
    with open(dataset_file, "r") as f:
        data = json.load(f)

    examples = []
    for item in data:
        tokens = item["token"]
        # TACRED end indices are inclusive; convert them to half-open token ranges.
        token_spans = dict(
            subj=(item["subj_start"], item["subj_end"] + 1),
            obj=(item["obj_start"], item["obj_end"] + 1)
        )

        # Process the two entities in the order in which they appear in the sentence.
        if token_spans["subj"][0] < token_spans["obj"][0]:
            entity_order = ("subj", "obj")
        else:
            entity_order = ("obj", "subj")

        # Rebuild the sentence by joining tokens with spaces, recording the
        # character offsets of each entity mention along the way.
        text = ""
        cur = 0
        char_spans = {}
        for target_entity in entity_order:
            token_span = token_spans[target_entity]
            text += " ".join(tokens[cur : token_span[0]])
            if text:
                text += " "
            char_start = len(text)
            text += " ".join(tokens[token_span[0] : token_span[1]])
            char_end = len(text)
            char_spans[target_entity] = (char_start, char_end)
            text += " "
            cur = token_span[1]
        text += " ".join(tokens[cur:])
        text = text.rstrip()

        # The tokenizer takes character-based entity spans; the subject always comes first.
        examples.append(dict(
            text=text,
            entity_spans=[char_spans["subj"], char_spans["obj"]],
            label=item["relation"]
        ))

    return examples
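The subtle step in `load_examples()` is the conversion from TACRED's token indices to the character offsets that are passed to the tokenizer as `entity_spans`. To make that step easy to check in isolation, the sketch below restates the loop as a standalone helper, `build_text_and_spans` (a hypothetical name, not part of the notebook or of Transformers), and applies it to a made-up sentence. The token spans it takes are half-open, that is, already after the `+ 1` conversion that `load_examples()` applies; slicing the rebuilt text with each character span should give back the entity tokens joined by spaces.

```python
def build_text_and_spans(tokens, token_spans):
    """Hypothetical standalone restatement of the text-building loop in load_examples().

    token_spans maps "subj" and "obj" to half-open (start, end) token ranges.
    Returns the space-joined text and half-open character spans for both entities.
    """
    # Visit the entities in sentence order, as load_examples() does.
    if token_spans["subj"][0] < token_spans["obj"][0]:
        entity_order = ("subj", "obj")
    else:
        entity_order = ("obj", "subj")

    text = ""
    cur = 0
    char_spans = {}
    for name in entity_order:
        start, end = token_spans[name]
        text += " ".join(tokens[cur:start])
        if text:
            text += " "
        char_start = len(text)
        text += " ".join(tokens[start:end])
        char_spans[name] = (char_start, len(text))
        text += " "
        cur = end
    text += " ".join(tokens[cur:])
    return text.rstrip(), char_spans


# Hypothetical toy sentence with the subject before the object.
tokens = ["Alice", "Smith", "joined", "Acme", "Corp", "in", "2010", "."]
text, spans = build_text_and_spans(tokens, {"subj": (0, 2), "obj": (3, 5)})
for name, (start, end) in spans.items():
    print(name, repr(text[start:end]))
# subj 'Alice Smith'
# obj 'Acme Corp'
```

Swapping the two spans (object first) yields the same text with the character spans exchanged, since the helper always walks the entities from left to right.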

In [5]:
test_examples = load_examples("test.json")

In [6]:
model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
model.eval()
model.to("cuda")
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")

In [7]:
batch_size = 128

# Counters for micro-averaged precision/recall/F1, with "no_relation" as the negative class.
num_predicted = 0
num_gold = 0
num_correct = 0

for batch_start_idx in trange(0, len(test_examples), batch_size):
    batch_examples = test_examples[batch_start_idx:batch_start_idx + batch_size]
    texts = [example["text"] for example in batch_examples]
    entity_spans = [example["entity_spans"] for example in batch_examples]
    gold_labels = [example["label"] for example in batch_examples]

    # Tokenize the batch together with the character-based entity spans.
    inputs = tokenizer(texts, entity_spans=entity_spans, return_tensors="pt", padding=True)
    inputs = inputs.to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_indices = outputs.logits.argmax(-1)
    predicted_labels = [model.config.id2label[index.item()] for index in predicted_indices]
    for predicted_label, gold_label in zip(predicted_labels, gold_labels):
        # A "no_relation" prediction is not counted as a prediction.
        if predicted_label != "no_relation":
            num_predicted += 1
        # Only gold relations other than "no_relation" count toward recall.
        if gold_label != "no_relation":
            num_gold += 1
            if predicted_label == gold_label:
                num_correct += 1

precision = num_correct / num_predicted
recall = num_correct / num_gold
f1 = 2 * precision * recall / (precision + recall)

print(f"\n\nprecision: {precision} recall: {recall} f1: {f1}")

100%|██████████| 122/122 [14:51<00:00,  7.31s/it]



precision: 0.7034638130104196 recall: 0.7512781954887218 f1: 0.7265852239674229
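The evaluation cell computes micro-averaged precision, recall, and F1 in which `no_relation` acts as the negative class: a `no_relation` prediction is not counted as a prediction, only gold labels other than `no_relation` count toward recall, and a prediction is correct only when it matches such a gold label. As a sketch, the same counting can be factored into a pure function so that it can be checked on toy labels without a GPU. The function name `tacred_micro_prf`, its zero-division guards, and the example labels are hypothetical additions for illustration, not part of the original notebook.

```python
def tacred_micro_prf(predicted_labels, gold_labels, negative_label="no_relation"):
    """Return (precision, recall, f1), treating negative_label as the absence of a relation."""
    if len(predicted_labels) != len(gold_labels):
        raise ValueError("predicted_labels and gold_labels must have the same length")
    num_predicted = sum(label != negative_label for label in predicted_labels)
    num_gold = sum(label != negative_label for label in gold_labels)
    # A prediction is correct only if it matches a gold label that is not negative_label.
    num_correct = sum(
        gold != negative_label and pred == gold
        for pred, gold in zip(predicted_labels, gold_labels)
    )
    # Return 0.0 instead of dividing by zero when there are no predictions or no gold relations.
    precision = num_correct / num_predicted if num_predicted else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Hypothetical toy labels: one correct relation, one missed relation,
# one spurious relation, and one wrong relation.
toy_predicted = ["per:title", "no_relation", "org:founded_by", "per:title"]
toy_gold = ["per:title", "per:age", "no_relation", "per:origin"]
print(tacred_micro_prf(toy_predicted, toy_gold))
```

Here 3 relations are predicted, 3 gold relations exist, and 1 prediction is correct, so precision, recall, and F1 are all 1/3. Unlike the loop above, which would raise `ZeroDivisionError` if the model predicted `no_relation` for every example, this sketch returns 0.0 in that case.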



