In [1]:
import json
from pprint import pprint
from tqdm import tqdm
# !pip install -U nervaluate
from nervaluate import Evaluator

from exciton.nlp.named_entity_recognition import Exciton_NER
from exciton.nlp.named_entity_recognition.utils import clean_result

In [2]:
def process_spans(input_data):
    """Convert character-level entity spans into whitespace tokens with BIO labels.

    The text is cut at every entity boundary, and each resulting segment is split
    on whitespace. Words inside an entity segment are labelled ``B-<label>`` (the
    first word) or ``I-<label>`` (the following words). All other words are
    labelled ``O``.

    Args:
        input_data: dict with a ``"text"`` string and a ``"named_entities"`` list.
            Each entity has a ``"span"`` of ``[start, end)`` character offsets
            into ``text`` and a ``"label"`` string. The dict is not modified.

    Returns:
        A list of ``{"token": str, "label": str}`` dicts in text order.

    Note:
        Overlapping or nested spans are clipped so that every character is
        emitted at most once. A span that starts or ends inside a word still
        splits that word into several tokens. In that case the token count can
        differ from ``len(text.split())``.
    """
    text = input_data["text"]
    # Sort a copy so the caller's entity list is left untouched.
    entities = sorted(input_data["named_entities"], key=lambda ent: ent["span"][0])

    # Build (segment_text, label) pieces; the label "O" marks text outside entities.
    segments = []
    cursor = 0
    for ent in entities:
        start, end = ent["span"][0], ent["span"][1]
        # Clip overlapping/nested spans. Otherwise already-emitted characters
        # would be tokenised again and produce duplicate tokens.
        start = max(start, cursor)
        if end <= start:
            continue
        segments.append((text[cursor:start], "O"))
        segments.append((text[start:end], ent["label"]))
        cursor = end
    segments.append((text[cursor:], "O"))

    tokens = []
    for segment, label in segments:
        for k, word in enumerate(segment.split()):
            if label == "O":
                tag = "O"
            else:
                tag = ("B-" if k == 0 else "I-") + label
            tokens.append({"token": word, "label": tag})
    return tokens

In [3]:
# Checkpoint directory handed to Exciton_NER, and the device it should run on.
MODEL_DIR = "/tmp/conll2003_xlmroberta/"
DEVICE = "cuda:0"
model = Exciton_NER(
    path_to_model=MODEL_DIR,
    device=DEVICE,
)

In [4]:
# CoNLL-2003 test split in JSON-lines format (one JSON record per line).
TEST_PATH = "/home/tshi/exciton/datasets/nlp/named_entity_recognition/conll2003_v1/test.jsonl"

# Load the test records. The model reads a "text" field, so each record's
# "tokens" list is joined with single spaces to build it.
data = []
with open(TEST_PATH, "r", encoding="utf-8") as fp:  # JSON text is UTF-8
    for line in fp:
        if not line.strip():
            continue  # tolerate blank/trailing lines in the jsonl file
        itm = json.loads(line)
        itm["text"] = " ".join(itm["tokens"])
        itm["etokens"] = itm["tokens"]
        data.append(itm)
print(len(data))

# Gold BIO label sequences, one list per document.
gold = [
    [tok["label"] for tok in process_spans(clean_result(itm))]
    for itm in tqdm(data, desc="gold")
]

# Run inference once up front. The progress bar below then covers only the
# span-to-BIO conversion, which is all it ever measured.
predictions = model.predict(data)
pred = [
    [tok["label"] for tok in process_spans(itm)]
    for itm in tqdm(predictions, desc="pred")
]

# nervaluate's list loader locates entities by token offset. A document whose
# gold and pred sequences differ in length is therefore misaligned; surface
# how many there are.
n_misaligned = sum(len(g) != len(p) for g, p in zip(gold, pred))
print(f"{n_misaligned} of {len(gold)} documents have mismatched gold/pred token counts")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3682/3682 [00:00<00:00, 27827.10it/s]

3682



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3682/3682 [00:00<00:00, 150769.08it/s]


In [5]:
# Entity types to score. The evaluator receives gold and pred as lists of BIO tag lists.
ENTITY_TAGS = ["LOC", "PER", "ORG", "MISC"]
evaluator = Evaluator(gold, pred, tags=ENTITY_TAGS, loader="list")
results, results_by_tag = evaluator.evaluate()
pprint(results)

{'ent_type': {'actual': 5852,
              'correct': 5056,
              'f1': 0.8754978354978356,
              'incorrect': 434,
              'missed': 208,
              'partial': 0,
              'possible': 5698,
              'precision': 0.8639781271360218,
              'recall': 0.8873288873288874,
              'spurious': 362},
 'exact': {'actual': 5852,
           'correct': 5256,
           'f1': 0.9101298701298701,
           'incorrect': 234,
           'missed': 208,
           'partial': 0,
           'possible': 5698,
           'precision': 0.8981544771018455,
           'recall': 0.9224289224289224,
           'spurious': 362},
 'partial': {'actual': 5852,
             'correct': 5256,
             'f1': 0.9303896103896104,
             'incorrect': 0,
             'missed': 208,
             'partial': 234,
             'possible': 5698,
             'precision': 0.9181476418318524,
             'recall': 0.9429624429624429,
             'spurious': 362},
 'str

In [6]:
# Qualitative check on a Spanish sentence. In the saved output below, entity
# boundaries fall inside the word "Tartiere" ('Tar' / 'tiere').
text = ["La presencia femenina en el homenaje a los socios del Oviedo: De aquella en el Tartiere todo eran paisanos"]
spanish_predictions = model.predict(text)
pprint(spanish_predictions)

[{'named_entities': [{'label': 'ORG',
                      'span': [54, 72],
                      'text': 'Oviedo: De aquella'},
                     {'label': 'ORG', 'span': [79, 82], 'text': 'Tar'},
                     {'label': 'LOC', 'span': [82, 87], 'text': 'tiere'}],
  'text': 'La presencia femenina en el homenaje a los socios del Oviedo: De '
          'aquella en el Tartiere todo eran paisanos'}]


In [7]:
# Qualitative check on Chinese input, which has no whitespace between words.
text = ["我要去纽约。"]
chinese_predictions = model.predict(text)
pprint(chinese_predictions)

[{'named_entities': [{'label': 'LOC', 'span': [3, 5], 'text': '纽约'}],
  'text': '我要去纽约。'}]
