In [1]:
import pandas as pd
import autorootcwd
from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm

In [2]:
class NamedEntityRecognizer:
    def __init__(self, tagger=SequenceTagger.load("flair/ner-english")):
        self.tagger = tagger

    def get_sentences_from_article(self, article):
        sentences = []
        for sentence in article.split('.'):
            sentences.append(sentence + '.')
        return sentences

    def get_named_entities(self, article):
        length = 0
        sentences = self.get_sentences_from_article(article)
        named_entities = []
        for sentence in sentences:
            s = Sentence(sentence)
            self.tagger.predict(s)
            for entity in s.get_spans('ner'):
            
                entity = entity.to_dict(tag_type='ner')
                entity['start_pos'] = int(entity['start_pos']) + length
                entity['end_pos'] = int(entity['end_pos']) + length
                named_entities.append(entity)
            length += len(sentence)

        return self.filter_entities(named_entities)
    
    def filter_entities(self, entities):
        return [entity for entity in entities if entity['labels'][0]['value'] in  ['PER', 'ORG']]
    
    def prepare_entity_format(self, entities, article_id):
        return pd.DataFrame([{'article_id': article_id, 
                'start_pos': entity['start_pos'], 
                'end_pos': entity['end_pos'], 
                'name': entity['text'], 
                'label': entity['labels'][0]['value'], 
                'confidence': entity['labels'][0]['confidence']} 
                for entity in entities])

2024-05-11 02:11:44,794 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>


In [3]:
def save_entities(name, start_index=0, end_index=None, selected_indices=[]):
    ner = NamedEntityRecognizer()
    articles = pd.read_csv(f'data/processed_dataset/{name}/articles.csv', header=None)
    articles.columns = ['id', 'text_german', 'text_english']
    articles.fillna('', inplace=True)
    for index, article in tqdm(articles.iterrows()):
        if selected_indices and index not in selected_indices:
            continue
        if index < start_index or (end_index is not None and index >= end_index):
            continue
        named_entities = ner.prepare_entity_format(ner.get_named_entities(article['text_english']), article['id'])
        named_entities.to_csv(f'data/processed_dataset/{name}/ner.csv', index=False, mode='a', header=not bool(index))
