LUKE for parsing out named entity with the replacement token.

In [3]:
import unicodedata

import numpy as np
import seqeval.metrics
import spacy
import torch
from tqdm import tqdm, trange
from transformers import LukeTokenizer, LukeForEntitySpanClassification

# Load the model checkpoint
model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
model.eval()

# Load the tokenizer
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

def sentence_LUKE_replace(text, token_str):
    nlp = spacy.load("en_core_web_sm")
    doc = nlp(text)

    entity_spans = []
    original_word_spans = []
    for token_start in doc:
        for token_end in doc[token_start.i:]:
            entity_spans.append((token_start.idx, token_end.idx + len(token_end)))
            original_word_spans.append((token_start.i, token_end.i + 1))

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    max_logits, max_indices = logits[0].max(dim=1)

    predictions = []
    for logit, index, span in zip(max_logits, max_indices, original_word_spans):
        if index != 0:  # the span is not NIL
            predictions.append((logit, span, model.config.id2label[int(index)]))

    # construct an IOB2 label sequence
    predicted_sequence = ["O"] * len(doc)
    for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
        if all([o == "O" for o in predicted_sequence[span[0] : span[1]]]):
            predicted_sequence[span[0]] = "B-" + label
            if span[1] - span[0] > 1:
                predicted_sequence[span[0] + 1 : span[1]] = ["I-" + label] * (span[1] - span[0] - 1)

    lst_nnp_words = []

    row_continue = False
    for token, label in zip(doc, predicted_sequence):
        if label != "O":
            if row_continue:
                lst_nnp_words[len(lst_nnp_words) - 1] = str(lst_nnp_words[len(lst_nnp_words) - 1]) + " " + str(token)
            else:
                lst_nnp_words.append(str(token))
            row_continue = True
        else:
            row_continue = False

    for word in lst_nnp_words:
        text = text.replace(word, token_str)
    return text

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 were not used when initializing LukeForEntitySpanClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Convert Amazon Product Review data from https://cseweb.ucsd.edu/~jmcauley/datasets/amazon_v2/ to CSV file with LUKE parsed out named entities

In [4]:
import json
import csv

file_name = "Movies_and_TV.json"

save_file = "Movies_and_TV.csv"

f = open(file_name)
lines = f.readlines()

with open(save_file, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["reviewerID", "overall", "reviewText_LUKE", "reviewText"])


for line in lines:
    data = json.loads(line)
    if "reviewText" in data and "overall" in data and "reviewerID" in data and len(data["reviewText"]) < 150:
        with open(save_file, "a", newline='') as file:
            writer = csv.writer(file)
            print(data["reviewerID"])
            writer.writerow([str(data["reviewerID"]), str(data["overall"]), sentence_LUKE_replace(data["reviewText"], "<NNP>"), data["reviewText"]])

f.close()

AHTYUW2H1276L
A1OMHX76O2NC6V
A3603DM2201OZP
A3HE4QW1655VB9
A1CLLD4H7TYKKJ
A2CFV9UPFTTM10
A3J84USGWMGUHS
A2MM0696TZ4AIL
A3139J3877Y61F
A3YX36D6RB2BV
A2BIXV1ED5XAST
A30Z25ZYLRUQYT
A3LNAPLZ5JO0VO
A1Z4040VKCTXAO
A3OFILCH8S2DUC
A28CD9IB14CT8K
A2PANT8U0OJNT4
A1TN0V94A3ECSH
A3TS9EQCNLU0SM
AKEHF7IT6FZX8
A35PMPEBHXAQ2T
A2AZ7CE08CWK4F
A12B2K6X03UA6Z
A2QWTVQ90KYZZP
A2P9HCN2JM55DS
AK1BZVQ7X9EJR
ADUN41PCA098H
A15UN81O3NDPE3
A1R6FXSZVCIJKK
A3JUFDBFDX0A0J
A2JS4XEZCZVJSQ
A3UI4K5IZKSNYW
A1VNYYUC24CGKF
A1VNYYUC24CGKF
A3R9UEJKGS2419
A2E95VSDA88WNL
APBJQBITFEC
A2D77RTO2BC8VJ
A3V1BXUGMM6F63
A1E7VTRDMI4XMV
A3JDW8YT5N55Y7
ANEKVYF88H156
A131408USYVUN5
A1AYN8825GNXSP
A2U7DG83EXUSFP
A2EFK8RB157JSE
AD8CFJUEVV8BQ
A1NMYH2IITNOC4
A2YAP2BVW3MVUW
A29HHWJB3JYPX7
AZPHVD6XM61V
A19H3CR53ME00T
A21A5AK06GGXRT
A3TPZM4CU8KLVI
A3Q7WXBOTKP05H
A101IGU6UDKW3X
A2BY5BZZ5UAV9B
AN61PQENFFEUR
A3W4LWZ37M9VZC
A1PLBB1E7VMZPK
A14OJJGL4LIDIG
A1RAPP3YN10O53
A0160612BLIWRHROHLLE
A340KTL9KUGYB7
AM87PNTXTKLOI
A3I6UVX7UWDYJM
A65UCXN2TPSGC
A3PI

KeyboardInterrupt: 