# Entity Recognizer service local implementation with HuggingFace library model

In [1]:
from dataclasses import dataclass
from typing import Literal

from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline

from noisemon.domain.models.entity_span import EntitySpan
from noisemon.domain.services.entity_recognition.entity_recognizer import EntityRecognizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
@dataclass
class HFEntity:
    entity_group: Literal["MISC", "ORG", "PER", "LOC", "O"]
    score: float
    word: str
    start: int
    end: int

In [3]:
def hf_entity_to_entity_span(hf_entity: HFEntity) -> EntitySpan:
    return EntitySpan(
        span_start=hf_entity.start,
        span_end=hf_entity.end,
        span=hf_entity.word
    )

In [9]:
class EntityRecognizerLocalImpl(EntityRecognizer):
    def __init__(self):
        model_name = "philschmid/distilroberta-base-ner-conll2003"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.nlp = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple"
        )

    def recognize_entities(self, text):
        output = self.nlp(text)
        output: list[HFEntity] = [HFEntity(**e) for e in output]
        result = [hf_entity_to_entity_span(e) for e in output if e.entity_group == "ORG"]

        # strip entitite out of trailing spaces
        for es in result:
            
        
        return result

In [10]:
entity_recognizer = EntityRecognizerLocalImpl()

Downloading (…)okenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 293/293 [00:00<00:00, 577kB/s]
Downloading (…)lve/main/config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1.03k/1.03k [00:00<00:00, 2.08MB/s]
Downloading (…)olve/main/vocab.json: 100%|████████████████████████████████████████████████████████████████████████████████████████| 798k/798k [00:00<00:00, 1.45MB/s]
Downloading (…)olve/main/merges.txt: 100%|████████████████████████████████████████████████████████████████████████████████████████| 456k/456k [00:00<00:00, 1.09MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 1.99MB/s]
Downloading (…)cial_tokens_map.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 239/239 [00:00<00:00, 436kB/s]
Down

## Test data

In [76]:
test_text = "Apple Inc. is a leading tech company. Microsoft Corporation is also well-known."

In [77]:
test_text.index("Apple Inc."), len("Apple Inc.")

(0, 10)

In [78]:
test_text.index("Microsoft Corporation"), len("Microsoft Corporation") 

(38, 21)

In [16]:
test_text = "Amazon.com is an e-commerce giant. Google LLC is a tech company."

In [17]:
entity_recognizer.recognize_entities(test_text)

[EntitySpan(span=' Amazon', span_start=0, span_end=6),
 EntitySpan(span='.', span_start=6, span_end=7),
 EntitySpan(span='com', span_start=7, span_end=10),
 EntitySpan(span=' Google LLC', span_start=35, span_end=45)]

## Resolving issues

In [46]:
from copy import deepcopy

In [64]:
def merge_consecutive(data):
    data = deepcopy(data)
    merged_data = []
    current_obj = None
    
    for obj in data:
        if current_obj is None:
            current_obj = obj
        elif obj['start'] == current_obj['end']:
            current_obj['end'] = obj['end']
            current_obj['word'] += obj['word']
            current_obj['score'] = (current_obj['score'] + obj['score']) / 2
        else:
            merged_data.append(current_obj)
            current_obj = obj
            
    if current_obj:
        merged_data.append(current_obj)
        
    return merged_data

In [65]:
data = entity_recognizer.nlp(test_text)
data

[{'entity_group': 'ORG',
  'score': 0.9996345,
  'word': ' Amazon',
  'start': 0,
  'end': 6},
 {'entity_group': 'ORG',
  'score': 0.9950395,
  'word': '.',
  'start': 6,
  'end': 7},
 {'entity_group': 'ORG',
  'score': 0.9642482,
  'word': 'com',
  'start': 7,
  'end': 10},
 {'entity_group': 'ORG',
  'score': 0.9988612,
  'word': ' Google LLC',
  'start': 35,
  'end': 45}]

In [66]:
updated_data = merge_consecutive(data)
print(updated_data)

[{'entity_group': 'ORG', 'score': 0.9807925820350647, 'word': ' Amazon.com', 'start': 0, 'end': 10}, {'entity_group': 'ORG', 'score': 0.9988612, 'word': ' Google LLC', 'start': 35, 'end': 45}]


In [67]:
def strip_whitespaces(datum, text):
    datum = deepcopy(datum)
    if datum["word"].startswith(" "):
        word = datum["word"][1:]
        start = text.index(
            word, 
            max([datum["start"] - 2, 0]), 
            datum["end"] + 2
        )
        end = start + len(word)

        datum = {
            'entity_group': 'ORG',
            'score': 0.9807925820350647,
            'word': word,
            'start': start,
            'end': end
        }
        
    return datum

In [68]:
for datum in updated_data:
    print(strip_whitespaces(datum, test_text))

{'entity_group': 'ORG', 'score': 0.9807925820350647, 'word': 'Amazon.com', 'start': 0, 'end': 10}
{'entity_group': 'ORG', 'score': 0.9807925820350647, 'word': 'Google LLC', 'start': 35, 'end': 45}


In [58]:
datum

{'entity_group': 'ORG',
 'score': 0.9807925820350647,
 'word': ' Amazon.com',
 'start': 0,
 'end': 10}

In [43]:
test_text.index("Google LLC")

35

In [69]:
"Amazon.com is an e-commerce giant. Google LLC is a tech company.".index("Google LLC")

35

In [70]:
len("Google LLC")

10