In [2]:
from transformers import AutoTokenizer, DistilBertModel, pipeline, DistilBertTokenizer, DistilBertForTokenClassification
import torch
import torch.nn as nn

from _RE import combine_entities, merge_result

model_name = 'dslim/distilbert-NER'
tokenizer = AutoTokenizer.from_pretrained(model_name)
ner_pipeline = pipeline('ner', model=model_name, tokenizer=tokenizer)

tokenizer = DistilBertTokenizer.from_pretrained('dslim/distilbert-NER')
ner_model = DistilBertForTokenClassification.from_pretrained('dslim/distilbert-NER')

ner = pipeline()

def tokenize_and_ner(sentences):
    tokenized_inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        ner_outputs = ner_model(**tokenized_inputs).logits
    ner_predictions = torch.argmax(ner_outputs, dim=2)
    return tokenized_inputs, ner_predictions

sentences = [
    "Zest Airways, Inc. is an airline headquartered in Pasay City, Metro Manila, Philippines.",
    "AirAsia Zest operates flights out of Ninoy Aquino International Airport."
]



tokenized_inputs, ner_predictions = tokenize_and_ner(sentences)

In [None]:
reverse_label_map = {
    '0': '0',
    '1': 'B-PER',
    '2': 'I-PER',
    '3': 'B-ORG',
    '4': 'I-ORG',
    '5': 'B-LOC',
    '6': 'I-LOC',
    '7': 'B-MISC',
    '8': 'I-MISC'
}

In [4]:
print(tokenized_inputs)

{'input_ids': tensor([[  101,   163,  2556, 14099,   117,  3561,   119,  1110,  1126,  8694,
          9514,  1107, 19585, 27911,  1392,   117,  6431,  9002,   117,  4336,
           119,   102],
        [  101,  1806, 23390,  1465,   163,  2556,  5049,  7306,  1149,  1104,
         27453, 21244, 27194,  1570,  3369,   119,   102,     0,     0,     0,
             0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])}


In [3]:
print(ner_predictions)

tensor([[0, 3, 3, 4, 4, 4, 4, 0, 0, 0, 0, 0, 5, 5, 6, 0, 5, 6, 0, 5, 0, 0],
        [0, 3, 3, 3, 4, 4, 0, 0, 0, 0, 5, 5, 6, 6, 6, 0, 0, 5, 0, 0, 0, 5]])
