In [1]:
import numpy as np
import pandas as pd
import re
import glob
from   os import path
import os
from tqdm.notebook import tqdm
from dateutil.parser import parse
from dateutil.tz import gettz

import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='transformers')

import torch

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# use the first GPU if available, otherwise use CPU
device = torch.device("mps"if torch.backends.mps.is_available()else "cpu")
# device = 0 if torch.cuda.is_available() else -1
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

nlp = pipeline("ner", model=model, tokenizer=tokenizer, device=device)

In [2]:
def combineHeadlineText(row):
    if isinstance(row["Headline"], str):
        return row["Headline"] + ". " + row["Text"]
    else:
        return row["Text"]

In [11]:
def process_entities(ner_results):
    # 首先，我们需要将NER的结果转换成一个更方便处理的格式
    entities = [{'word': d['word'], 'entity': d['entity'], 'score': d['score']} for d in ner_results]

    # 然后，我们创建一个新的列表来存储处理后的实体
    processed_entities = []
    current_entity = []
    for entity in entities:
        if entity['entity'].startswith('B-') or (entity['entity'].startswith('I-') and not current_entity):
            if current_entity:
                processed_entities.append(current_entity)
            current_entity = [entity]
        elif entity['entity'].startswith('I-') and current_entity:
            current_entity.append(entity)
    if current_entity:
        processed_entities.append(current_entity)

    return processed_entities

In [18]:
# 定义一个字典储存每个国家的别名
country_aliases = {
"United States": ["USA", "America", "US", "United States"],
"Canada": ["Canada", "CA"],
"United Kingdom": ["UK", "United Kingdom", "Britain", "England", "Scotland", "Wales", "Northern Ireland"],
"Australia": ["Australia", "AU", "Aussie"],
"China": ["China", "PRC"],
"Denmark": ["Denmark", "DK"],
"Finland": ["Finland", "FI"],
"France": ["France", "French Republic", "FR"],
"Germany": ["Germany", "DE"],
"Japan": ["Japan", "JP"],
"Italy": ["Italy", "Italian Republic", "IT"],
"Netherlands": ["Netherlands", "Holland", "NL"],
"Norway": ["Norway", "NO"],
"Portugal": ["Portugal", "PT"],
"Singapore": ["Singapore", "SG"],
"South Korea": ["South Korea", "KR"],
"Spain": ["Spain", "ES"],
"Sweden": ["Sweden", "SE"],
"Switzerland": ["Switzerland", "Swiss Confederation", "Swiss", "CH"],
"New Zealand": ["New Zealand", "NZ"]
}

In [22]:
lst_files = []

Path = "../Data/ReutersArticles/*.csv"

for fname in glob.glob(Path):
    lst_files.append(fname)

lst_files = sorted(lst_files)

for file in lst_files:
    file_name = os.path.basename(file)  # Get the file name from the full file path
    country_name = os.path.splitext(file_name)[0]  # Remove the file extension
    country_name = country_name.replace("_articles", "")  # Remove the "_articles" part of the file name

    # 仅对单个国家的数据进行实体识别
    if file == '../Data/ReutersArticles/United States_articles.csv':
    # 对所有的未进行过NER的国家进行实体识别
    # if os.path.isfile('../Data/Reuters_NER/{0}.csv'.format(country_name)):
    #     print("File already exists: {0}.csv".format(country_name))
    # else:
        print(file)
        df = pd.read_csv(file)

        if df.shape[0] != 0:
            df = df.drop(['Unnamed: 0'], axis = 1)
            df = df.drop_duplicates(['Date','Headline'],keep= 'last')
            df['Text'] = df['Text'].astype(str)
            df['Text'] = df.apply(lambda row: combineHeadlineText(row), axis=1)
            df['Date'] = pd.to_datetime(df['Date']).dt.date  # change date format to YYYY-MM-DD
            df = df.sort_values(by = ['Date'], ascending = True)

            count = []
            check = []
            # 新建一个DataFrame储存实体识别结果
            ner_results_data = pd.DataFrame()
            for i in tqdm(range(len(df)), desc=f"Processing {country_name}"):
                ner_results = nlp(df["Text"].iloc[i])

                processed_entities = process_entities(ner_results)

                country_instances = []
                country_check = []
                aliases = country_aliases.get(country_name, [country_name])
                for entity_group in processed_entities:
                    words = [entity['word'] for entity in entity_group]
                    entity_name = ' '.join(words)
                    entity_type = entity_group[0]['entity']
                    entity_score = sum(entity['score'] for entity in entity_group) / len(entity_group)
                    if entity_type in ["B-LOC", "B-ORG"] and entity_score > 0.98:
                        country_check.append(entity_name)
                        if any(alias in entity_name for alias in aliases):
                            country_instances.append(entity_name)

                count.append(len(country_instances))
                check.append(country_check)
                ner_results_data = pd.concat([ner_results_data, pd.json_normalize(ner_results)], ignore_index=True)


            df['Count'] = count
            df_check = df.copy()
            df_check['Check'] = check
            df = df[df['Count'] > 0]
            df = df.drop(['Count'], axis = 1)
            df.to_csv(r'../Data/Reuters_NER/{0}.csv'.format(country_name))
            # 添加df_check为csv文件, 并命名为country_name_check.csv
            df_check.to_csv(r'../Data/NER_Entity/{0}_check.csv'.format(country_name))
            ner_results_data.to_csv(r'../Data/NER_Entity/{0}.csv'.format(country_name), index=False)  # Save the ner_results DataFrame to a CSV file


../Data/ReutersArticles/United States_articles.csv


Processing United States:   0%|          | 0/22907 [00:00<?, ?it/s]

In [23]:
# 读取所有国家的实体识别后剩余的数据, 然后输出剩余数据的数量
lst_ner_files = []
for fname in glob.glob("../Data/Reuters_NER/*.csv"):
    lst_ner_files.append(fname)
for file in lst_ner_files:
    file_name = os.path.basename(file)  # Get the file name from the full file path
    country_name = os.path.splitext(file_name)[0]  # Remove the file extension
    df = pd.read_csv(file)
    print(country_name, df.shape[0])

Netherlands 2661
New Zealand 3171
Singapore 2999
Denmark 1675
Italy 6109
Norway 1998
Japan 9954
Finland 1329
United States 17947
United Kingdom 13011
Germany 11026
Canada 6950
France 9304
Portugal 1771
Spain 4669
Sweden 2202
South Korea 3760
Australia 9234
Switzerland 2144
China 20741
