<a href="https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproducing experimental results of LUKE on Open Entity Using Hugging Face Transformers

This notebook shows how to reproduce the state-of-the-art results on the [Open Entity entity typing dataset](https://www.cs.utexas.edu/~eunsol/html_pages/open_entity.html) reported in [this paper](https://arxiv.org/abs/2010.01057) using the Transformers library and the [fine-tuned model checkpoint](https://huggingface.co/studio-ousia/luke-large-finetuned-open-entity) available on the Model Hub.
The source code used in the experiments is also available [here](https://github.com/studio-ousia/luke/tree/master/examples/entity_typing).

There are two other related notebooks:

* [Reproducing experimental results of LUKE on TACRED Using Hugging Face Transformers](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_tacred.ipynb)
* [Reproducing experimental results of LUKE on CoNLL-2003 Using Hugging Face Transformers](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_conll_2003.ipynb)

In [1]:
# Currently, LUKE is only available on the master branch
!pip install git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-g03z3zcr
  Running command git clone -q https://github.com/huggingface/transformers.git /tmp/pip-req-build-g03z3zcr
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
Collecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29

In [2]:
import json
import torch
from tqdm import trange
from transformers import LukeTokenizer, LukeForEntityClassification

## Loading the dataset

The dataset is downloaded from the link provided in [this GitHub repository](https://github.com/thunlp/ERNIE). After the archive is extracted, `test.json` is copied to the current directory and loaded with the `load_examples` function.

In [3]:
!gdown --id 1HlWw7Q6-dFSm9jNSCh4VaBf1PlGqt9im
!tar xzf /content/data.tar.gz

# Copy test.json into the working directory
!cp data/OpenEntity/test.json .

Downloading...
From: https://drive.google.com/uc?id=1HlWw7Q6-dFSm9jNSCh4VaBf1PlGqt9im
To: /content/data.tar.gz
322MB [00:03, 89.4MB/s]


In [4]:
def load_examples(dataset_file):
    with open(dataset_file, "r") as f:
        data = json.load(f)

    examples = []
    for item in data:
        examples.append(dict(
            text=item["sent"],
            entity_spans=[(item["start"], item["end"])],
            label=item["labels"]
        ))

    return examples

In [5]:
test_examples = load_examples("test.json")
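
For reference, `load_examples` assumes that each record in `test.json` carries `sent`, `start`, `end`, and `labels` fields. The sketch below applies the same conversion to a single hand-written record to show the resulting shape. The sentence, offsets, and labels are made up for illustration and are not taken from the dataset, and `convert` is a hypothetical helper name.

```python
import json

# Hypothetical record in the format that load_examples expects;
# the values are illustrative and do not come from the dataset.
raw = json.loads("""
[{"sent": "She moved to Paris in 2010.", "start": 13, "end": 18,
  "labels": ["location", "place"]}]
""")

def convert(item):
    # Same mapping as in load_examples: one character-level span per example.
    return dict(
        text=item["sent"],
        entity_spans=[(item["start"], item["end"])],
        label=item["labels"],
    )

example = convert(raw[0])
start, end = example["entity_spans"][0]
print(example["text"][start:end])  # Paris
```

Slicing the text with the stored span is a quick sanity check that the offsets are character-based and end-exclusive.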

## Loading the fine-tuned model and tokenizer

We construct the model and tokenizer using the [fine-tuned model checkpoint](https://huggingface.co/studio-ousia/luke-large-finetuned-open-entity).

In [6]:
# Load the model checkpoint
model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
model.eval()
model.to("cuda")

# Load the tokenizer
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-open-entity were not used when initializing LukeForEntityClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Measuring performance

We classify the entity mention in each test example and compute micro-averaged precision, recall, and F1 over the predicted and gold labels.
The resulting scores reproduce the performance reported in the [original paper](https://arxiv.org/abs/2010.01057).
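
Because an entity can have several types at once, the evaluation cell treats every label independently: a label counts as predicted when its logit is positive, which is the same as a sigmoid probability above 0.5. A minimal standalone sketch of that decision rule follows. The `id2label` mapping and the logit values here are made up for illustration and are not the model's actual configuration or output, and `decode_labels` is a hypothetical helper name.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def decode_labels(logits, id2label):
    """Return every label whose logit is positive (probability above 0.5)."""
    return [id2label[i] for i, logit in enumerate(logits) if logit > 0]

# Hypothetical label mapping and logits, for illustration only.
id2label = {0: "person", 1: "location", 2: "place"}
logits = [-3.2, 1.7, 0.4]

labels = decode_labels(logits, id2label)
print(labels)  # ['location', 'place']
```

Thresholding the raw logit at zero avoids computing the sigmoid at all, which is why the evaluation loop never calls it.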

In [7]:
batch_size = 128

num_predicted = 0
num_gold = 0
num_correct = 0

for batch_start_idx in trange(0, len(test_examples), batch_size):
    batch_examples = test_examples[batch_start_idx:batch_start_idx + batch_size]
    texts = [example["text"] for example in batch_examples]
    entity_spans = [example["entity_spans"] for example in batch_examples]
    gold_labels = [example["label"] for example in batch_examples]

    inputs = tokenizer(texts, entity_spans=entity_spans, return_tensors="pt", padding=True)
    inputs = inputs.to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)

    num_gold += sum(len(labels) for labels in gold_labels)
    for logits, labels in zip(outputs.logits, gold_labels):
        for index, logit in enumerate(logits):
            if logit > 0:
                num_predicted += 1
                predicted_label = model.config.id2label[index]
                if predicted_label in labels:
                    num_correct += 1

precision = num_correct / num_predicted
recall = num_correct / num_gold
f1 = 2 * precision * recall / (precision + recall)

print(f"\n\nprecision: {precision} recall: {recall} f1: {f1}")

100%|██████████| 16/16 [00:33<00:00,  2.08s/it]



precision: 0.7980295566502463 recall: 0.7657563025210085 f1: 0.781559903511123
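
The counts accumulated in the loop give micro-averaged scores: every (example, label) pair contributes equally, regardless of which example it belongs to. The same computation can be sketched on plain Python lists without the model. The function name `micro_prf` and the toy predictions are illustrative assumptions, not part of the LUKE codebase, and the zero-division guards are an addition that the notebook cell does not have.

```python
def micro_prf(predicted, gold):
    """Micro-averaged precision, recall, and F1 for multi-label predictions.

    predicted, gold: lists of label lists, one entry per example.
    """
    num_predicted = sum(len(p) for p in predicted)
    num_gold = sum(len(g) for g in gold)
    # A predicted label is correct if it appears among that example's gold labels.
    num_correct = sum(
        1 for p, g in zip(predicted, gold) for label in p if label in g
    )

    precision = num_correct / num_predicted if num_predicted else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

# Toy data, for illustration only.
predicted = [["person"], ["location", "place"], ["group"]]
gold = [["person"], ["location"], ["organization", "group"]]

precision, recall, f1 = micro_prf(predicted, gold)
print(precision, recall, f1)  # 0.75 0.75 0.75
```

In the toy data, three of the four predicted labels are correct and three of the four gold labels are recovered, so all three scores come out to 0.75.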





## Detecting types of entities in a text

Finally, we detect types of entities in a text using the [fine-tuned model](https://huggingface.co/studio-ousia/luke-large-finetuned-open-entity).

In [8]:
text = "Beyoncé lives in Los Angeles."
entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs.to("cuda")
outputs = model(**inputs)

predicted_indices = [index for index, logit in enumerate(outputs.logits[0]) if logit > 0]
print("Predicted entity type for Beyoncé:", [model.config.id2label[index] for index in predicted_indices])

entity_spans = [(17, 28)]  # character-based entity span corresponding to "Los Angeles"
inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs.to("cuda")
outputs = model(**inputs)

predicted_indices = [index for index, logit in enumerate(outputs.logits[0]) if logit > 0]
print("Predicted entity type for Los Angeles:", [model.config.id2label[index] for index in predicted_indices])

Predicted entity type for Beyoncé: ['person']
Predicted entity type for Los Angeles: ['location', 'place']
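
The entity spans above are hard-coded character offsets, which are easy to get wrong by hand. One way to derive them is to search for the mention string in the text, as in the sketch below. `find_span` is a hypothetical helper, not part of the Transformers API, and it only locates the first occurrence of the mention.

```python
def find_span(text, mention, start=0):
    """Return the end-exclusive (start, end) character span of the first
    occurrence of mention in text, searching from index start."""
    begin = text.find(mention, start)
    if begin == -1:
        raise ValueError(f"{mention!r} not found in text")
    return begin, begin + len(mention)

text = "Beyoncé lives in Los Angeles."
for mention in ["Beyoncé", "Los Angeles"]:
    begin, end = find_span(text, mention)
    # Slicing with the span should give back the mention itself.
    print(mention, (begin, end), text[begin:end] == mention)
```

The returned tuple can be passed directly as an element of `entity_spans`. If a mention occurs more than once, the `start` argument can be used to skip earlier occurrences.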
