<a href="https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproducing Experimental Results of LUKE on Open Entity Using Hugging Face Transformers

This notebook shows how to reproduce the state-of-the-art results on the [Open Entity entity typing dataset](https://www.cs.utexas.edu/~eunsol/html_pages/open_entity.html) reported in [the LUKE paper](https://arxiv.org/abs/2010.01057), using the Hugging Face Transformers library and the [fine-tuned model checkpoint](https://huggingface.co/studio-ousia/luke-large-finetuned-open-entity) available on the Model Hub.
The source code used in the original experiments is available [here](https://github.com/studio-ousia/luke/tree/master/examples/entity_typing).

In [1]:
!pip install git+https://github.com/huggingface/transformers.git@refs/pull/11223/head

Collecting git+https://github.com/huggingface/transformers.git@refs/pull/11223/head
  Cloning https://github.com/huggingface/transformers.git (to revision refs/pull/11223/head) to /tmp/pip-req-build-afyah90k
  Running command git clone -q https://github.com/huggingface/transformers.git /tmp/pip-req-build-afyah90k
  Running command git fetch -q https://github.com/huggingface/transformers.git refs/pull/11223/head
  Running command git checkout -q c29e2f974b2ac3b78b4a9aaa238d49f736170dd9
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
     |████████████████████████████████| 3.3MB 5.9MB/s
Collecting sacremoses
  Downloading https://files.pythonhos

In [2]:
import json
import torch
from tqdm import trange
from transformers import LukeTokenizer, LukeForEntityClassification

In [3]:
# Download the Open Entity dataset distributed via https://github.com/thunlp/ERNIE
!gdown --id 1HlWw7Q6-dFSm9jNSCh4VaBf1PlGqt9im
!tar xzf /content/data.tar.gz

# Copy test.json into the working directory
!cp data/OpenEntity/test.json .

Downloading...
From: https://drive.google.com/uc?id=1HlWw7Q6-dFSm9jNSCh4VaBf1PlGqt9im
To: /content/data.tar.gz
322MB [00:03, 99.9MB/s]


In [4]:
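def load_examples(dataset_file):
    """Read an Open Entity JSON file and keep the fields used for evaluation:
    the sentence text, the mention span as (start, end), and the gold labels."""
    with open(dataset_file, "r") as f:
        data = json.load(f)

    examples = []
    for item in data:
        examples.append(dict(
            text=item["sent"],
            entity_spans=[(item["start"], item["end"])],
            label=item["labels"]
        ))

    return examples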
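
To make the record format concrete, the sketch below applies the same conversion to a single hand-made record. The sentence, offsets, and labels are invented for illustration; only the keys (`sent`, `start`, `end`, `labels`) come from the loader above. It also assumes that `start` and `end` are character offsets with an exclusive end, which is how `LukeTokenizer` interprets `entity_spans`.

```python
# Hypothetical record: the values are made up, only the keys mirror the dataset.
item = {
    "sent": "Tokyo is the capital of Japan.",
    "start": 0,
    "end": 5,
    "labels": ["place", "location"],
}

# Same conversion as in load_examples, applied to one record.
example = dict(
    text=item["sent"],
    entity_spans=[(item["start"], item["end"])],
    label=item["labels"],
)

# Assumption: the span is a pair of character offsets with an exclusive end,
# so slicing the text recovers the mention being typed.
start, end = example["entity_spans"][0]
mention = example["text"][start:end]
print(mention, example["label"])
```

Each example carries exactly one span because entity typing assigns labels to a single mention per sentence.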

In [5]:
test_examples = load_examples("test.json")

In [6]:
# Load the model checkpoint
model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
model.eval()
model.to("cuda")

# Load the tokenizer
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")

In [7]:
batch_size = 128

# Counters for micro-averaged precision, recall, and F1
num_predicted = 0
num_gold = 0
num_correct = 0

for batch_start_idx in trange(0, len(test_examples), batch_size):
    batch_examples = test_examples[batch_start_idx:batch_start_idx + batch_size]
    texts = [example["text"] for example in batch_examples]
    entity_spans = [example["entity_spans"] for example in batch_examples]
    gold_labels = [example["label"] for example in batch_examples]

    inputs = tokenizer(texts, entity_spans=entity_spans, return_tensors="pt", padding=True)
    inputs = inputs.to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)

    num_gold += sum(len(labels) for labels in gold_labels)
    for logits, labels in zip(outputs.logits, gold_labels):
        for index, logit in enumerate(logits):
            # Multi-label decision: a positive logit means sigmoid(logit) > 0.5
            if logit > 0:
                num_predicted += 1
                predicted_label = model.config.id2label[index]
                if predicted_label in labels:
                    num_correct += 1

# Micro-averaged scores over all (example, label) pairs
precision = num_correct / num_predicted
recall = num_correct / num_gold
f1 = 2 * precision * recall / (precision + recall)

print(f"\n\nprecision: {precision} recall: {recall} f1: {f1}")

100%|██████████| 16/16 [01:12<00:00,  4.54s/it]



precision: 0.7980295566502463 recall: 0.7657563025210085 f1: 0.781559903511123
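
The scores above are micro-averaged: every positive logit counts as one predicted label, and counts are pooled over all examples before precision, recall, and F1 are computed. The following is a minimal sketch of that counting logic on toy inputs, without the model. The `id2label` mapping, logits, and gold labels are invented for illustration, and the zero-division guards are an addition not present in the evaluation cell.

```python
# Toy stand-ins for model outputs; label names and logit values are invented.
id2label = {0: "person", 1: "place", 2: "organization"}
batch_logits = [
    [2.1, -0.3, -1.5],  # only "person" has a positive logit
    [-0.7, 1.2, 0.4],   # "place" and "organization" have positive logits
]
gold_labels = [["person"], ["place"]]

num_predicted = 0
num_gold = 0
num_correct = 0

for logits, labels in zip(batch_logits, gold_labels):
    num_gold += len(labels)
    for index, logit in enumerate(logits):
        # A positive logit is equivalent to sigmoid(logit) > 0.5.
        if logit > 0:
            num_predicted += 1
            if id2label[index] in labels:
                num_correct += 1

# Guards against empty predictions or gold sets (not in the original cell).
precision = num_correct / num_predicted if num_predicted else 0.0
recall = num_correct / num_gold if num_gold else 0.0
denominator = precision + recall
f1 = 2 * precision * recall / denominator if denominator else 0.0

print(f"precision: {precision:.3f} recall: {recall:.3f} f1: {f1:.3f}")
```

Here three labels are predicted, two of which are correct, against two gold labels, so the spurious "organization" prediction lowers precision while recall stays at 1.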



