Use LUKE to detect the named entities in a piece of text and replace each one with a replacement token.

In [1]:
import unicodedata

import numpy as np
import seqeval.metrics
import spacy
import torch
from tqdm import tqdm, trange
from transformers import LukeTokenizer, LukeForEntitySpanClassification

# Tokenizer that accompanies the CoNLL-2003 fine-tuned LUKE checkpoint
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

# Entity-span classification model, switched to evaluation mode right after loading
# (``eval()`` returns the module itself, so the call can be chained)
model = LukeForEntitySpanClassification.from_pretrained(
    "studio-ousia/luke-large-finetuned-conll-2003"
).eval()

# Lazily-loaded spaCy pipeline shared by every call to sentence_LUKE_replace.
_SPACY_NLP = None


def _get_spacy_nlp():
    """Return the shared ``en_core_web_sm`` pipeline, loading it on first use only."""
    global _SPACY_NLP
    if _SPACY_NLP is None:
        _SPACY_NLP = spacy.load("en_core_web_sm")
    return _SPACY_NLP


def sentence_LUKE_replace(text, token_str):
    """Replace every named entity that LUKE finds in ``text`` with ``token_str``.

    The text is tokenized with spaCy, every contiguous token span is scored by
    the module-level LUKE ``model`` (via the module-level ``tokenizer``), and
    non-overlapping spans are selected greedily from the highest logit down.
    Each contiguous run of selected tokens is then replaced by a single
    ``token_str``; entities that are directly adjacent are therefore collapsed
    into one replacement token.

    Replacement is done on the character offsets of the selected tokens, so
    only the detected mentions are touched: text that merely contains the same
    characters (e.g. "Also" when "Al" is an entity) is left alone, and entities
    whose tokens are not separated by exactly one space are still replaced.

    Args:
        text: Raw input text.
        token_str: String substituted for each detected entity run.

    Returns:
        ``text`` with the detected entity runs replaced. Text that yields no
        tokens (e.g. the empty string) is returned unchanged.
    """
    doc = _get_spacy_nlp()(text)
    if len(doc) == 0:
        # No tokens means no candidate spans; skip the model call entirely.
        return text

    # Enumerate every contiguous token span, both as character offsets (what the
    # LUKE tokenizer consumes) and as token indices (used to mark tokens below).
    entity_spans = []
    original_word_spans = []
    for token_start in doc:
        for token_end in doc[token_start.i:]:
            entity_spans.append((token_start.idx, token_end.idx + len(token_end)))
            original_word_spans.append((token_start.i, token_end.i + 1))

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Best class and its logit for every candidate span of the single input text.
    best_logits, best_classes = outputs.logits[0].max(dim=1)

    # Keep only spans whose best class is not class 0 (the NIL / "no entity" class).
    candidates = [
        (score, span)
        for score, class_id, span in zip(
            best_logits.tolist(), best_classes.tolist(), original_word_spans
        )
        if class_id != 0
    ]

    # Greedy selection: highest-scoring span first, skipping any span that
    # overlaps a token already claimed by a better-scoring span.
    is_entity_token = [False] * len(doc)
    for _, (start, end) in sorted(candidates, key=lambda c: c[0], reverse=True):
        if not any(is_entity_token[start:end]):
            is_entity_token[start:end] = [True] * (end - start)

    # Turn each run of consecutive entity tokens into one character range.
    char_ranges = []
    run_start = None
    run_end = None
    for token, flagged in zip(doc, is_entity_token):
        if flagged:
            if run_start is None:
                run_start = token.idx
            run_end = token.idx + len(token)
        elif run_start is not None:
            char_ranges.append((run_start, run_end))
            run_start = None
    if run_start is not None:
        char_ranges.append((run_start, run_end))

    # Rebuild the text, substituting token_str for each entity range.
    pieces = []
    cursor = 0
    for start, end in char_ranges:
        pieces.append(text[cursor:start])
        pieces.append(token_str)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 were not used when initializing LukeForEntitySpanClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Convert the Amazon Product Review data from https://cseweb.ucsd.edu/~jmcauley/datasets/amazon_v2/ to a CSV file in which the named entities detected by LUKE are replaced by a placeholder token.

In [2]:
import json
import csv

file_name = "Office_Products.json"

save_file = "Office_Products.csv"

# Only reviews shorter than this many characters are converted.
MAX_REVIEW_CHARS = 150

# Stream the JSON-lines input instead of loading it all into memory, and keep
# both files inside one ``with`` so they are closed even if the run is
# interrupted (e.g. by KeyboardInterrupt).
with open(file_name, encoding="utf-8") as f, \
        open(save_file, "w", newline="", encoding="utf-8") as file:
    writer = csv.writer(file)
    writer.writerow(["reviewerID", "overall", "reviewText_LUKE", "reviewText"])

    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        data = json.loads(line)
        if (
            "reviewText" in data
            and "overall" in data
            and "reviewerID" in data
            and len(data["reviewText"]) < MAX_REVIEW_CHARS
        ):
            print(data["reviewerID"])
            writer.writerow([
                str(data["reviewerID"]),
                str(data["overall"]),
                sentence_LUKE_replace(data["reviewText"], "<NNP>"),
                data["reviewText"],
            ])
            # Each row is slow to produce (one LUKE forward pass), so push it to
            # disk immediately; an interrupted run keeps everything done so far.
            file.flush()

A3GIXT0M21V3JR
A2NIJTYWADLK57
A2PNBV0VHHSO2I
AIYOHMZQ53DRL
A6J32ICD0JCGJ
A14JXAEQKIQDZG
AYSQ2SUDA2O8P
A2C7GZTXK2B1NJ
A3QOFKZ5TL7YS2
A190CFAEU43HPA
A3FFJLY0W2DG6L
A1NA2NNAMYR8M9
A17HXU70I328W6
A26D4JPQKISXZG
A23FMQF01YLMV1
A9Y2VQS00QPQ7
A3ALJY3JY0R6RV
A6RN361K3R3CH
A27SRH7N7HRHZW
A28X0A79T17T8M
AZV5O5KBTWN0D
A3MNAV5XM1JKGU
A2NT8TDO8OSXSA
AG4QG2ZON7QJZ
A3D9FICLUNR4FI
A3PNS3IZJGFH6S
AP5RDC5IA9MWI
A1ECYOK98MUEXP
A343LQ1I7G9H6S
A2Q6OZAYUUI0UP
AMLZVWNAXMA1H
AKMZ1PYVUGRMW
A257G5IP5ILE9D
A2CAI3Z43C6KBQ
AI0UXOG2BNP1J
A18OCAY2K0D7AU
A10UZ5LABRZ06C
A3Q7AH0VB0VOJ6
A2Z97U5Q8APOG1
A167C4I2EQH01E
A1I04XJA6WI6ED
A1LP4957D7UHIP
A17O5TWF4H1VNU
A35OZ4VEPPT33O
A1QSDFFWPBNBLS
A3LQMVRXLSG45C
A2EXJ2E3MRUNIO
A1CR78J40QFFYV
A2WV6N4YWOGV1D
A28XQPKHGCEHRF
A3JJTLTBCE7T7L
A1HFSC0ZQDUP6R
A87X3YXRQUZUL
A2W00LK70177UX
A2PGRFNI9IJILQ
A223T9TID8967G
A189124RNF8AG6
ARMT7L6NYJ0XR
A2O5LZ90ZD5AUL
A22ND3F4XJ7BME
A1GUVIKXAUP7AF
A2YO3BMBU0R1M9
A1DIZUDNQ5AOUS
AWPLOS66RWSLV
A32PS4OOKDX1S2
A1W7NSDR6FLBLX
A2N82GT36J33GK
A19U0K6LQ

KeyboardInterrupt: 