In [5]:
import numpy as np
import pandas as pd
import re
import glob
from   os import path
import os
import json
from tqdm.notebook import tqdm
from dateutil.parser import parse
from dateutil.tz import gettz

import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='transformers')

import torch

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# use the first GPU if available, otherwise use CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps"if torch.backends.mps.is_available()else "cpu")

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

nlp = pipeline("ner", model=model, tokenizer=tokenizer, device=device)

In [7]:
def combineHeadlineText(row):
    if isinstance(row["Headline"], str):
        return row["Headline"] + ". " + row["Text"]
    else:
        return row["Text"]

device(type='mps')

In [7]:
def preprocess_dataframe(df, use_parse=False):
    df = df.drop(['Unnamed: 0'], axis=1, errors='ignore')
    df = df.drop_duplicates(['Date', 'Headline'], keep='last')
    df['Text'] = df['Text'].astype(str)
    df['Text'] = df.apply(lambda row: combineHeadlineText(row), axis=1)
    
    if use_parse:
        df['Date'] = df['Date'].str.replace(r'Published: ', ' ')
        df['Date'] = df['Date'].str.replace(r'First', ' ')
        df['Date'] = df['Date'].apply(lambda date_str: parse(date_str, tzinfos={'ET': gettz('America/New_York')}))
        df['Date'] = df['Date'].dt.date
    else:
        df['Date'] = pd.to_datetime(df['Date'])
    
    df = df.reset_index(drop=True).sort_values(by=['Date'], ascending=True)
    
    return df

In [8]:
def process_entities(ner_results):
    # 首先，我们需要将NER的结果转换成一个更方便处理的格式
    entities = [{'word': d['word'], 'entity': d['entity'], 'score': d['score']} for d in ner_results]

    # 然后，我们创建一个新的列表来存储处理后的实体
    processed_entities = []
    current_entity = []
    for entity in entities:
        if entity['entity'].startswith('B-') or (entity['entity'].startswith('I-') and not current_entity):
            if current_entity:
                processed_entities.append(current_entity)
            current_entity = [entity]
        elif entity['entity'].startswith('I-') and current_entity:
            current_entity.append(entity)
    if current_entity:
        processed_entities.append(current_entity)

    return processed_entities

In [9]:
def json_serializable(item):
    """Convert non-serializable items to serializable."""
    if isinstance(item, np.float32):
        return float(item)
    raise TypeError(f"Type {type(item)} not serializable")

In [10]:
def perform_ner_on_dataframe(df, country_name):
    count = []
    check = []
    ner_results_data_list = []
    
    for i in tqdm(range(len(df)), desc=f"Processing {country_name}"):
        ner_results = nlp(df["Text"].iloc[i])
        processed_entities = process_entities(ner_results)

        country_instances = []
        country_check = []
        aliases = country_aliases.get(country_name, [country_name])
        for entity_group in processed_entities:
            words = [entity['word'] for entity in entity_group]
            entity_name = ' '.join(words)
            entity_type = entity_group[0]['entity']
            entity_score = sum(entity['score'] for entity in entity_group) / len(entity_group)
            if entity_type in ["B-LOC", "B-ORG"] and entity_score > 0.98:
                country_check.append(entity_name)
                if any(alias in entity_name for alias in aliases):
                    country_instances.append(entity_name)

        count.append(len(country_instances))
        check.append(country_check)
        
        ner_result = {
            'Date': df['Date'].iloc[i],
            'Headline': df['Headline'].iloc[i],
            'NER': json.dumps(ner_results, default=json_serializable)  # convert ner_results to string
        }
        ner_results_data_list.append(ner_result)

    df['Count'] = count
    df = df[df['Count'] >= 3]
    df = df.drop(['Count'], axis=1)
    df_ner_results = pd.DataFrame(ner_results_data_list)
    
    return df, df_ner_results

In [11]:
# 定义一个字典储存每个国家的别名
country_aliases = {
"United States": ["USA", "America", "US", "United States"],
"Canada": ["Canada", "CA"],
"United Kingdom": ["UK", "United Kingdom", "Britain", "England", "Scotland", "Wales", "Northern Ireland"],
"Australia": ["Australia", "AU", "Aussie"],
"China": ["China", "PRC"],
"Denmark": ["Denmark", "DK"],
"Finland": ["Finland", "FI"],
"France": ["France", "French Republic", "FR"],
"Germany": ["Germany", "DE"],
"Japan": ["Japan", "JP"],
"Italy": ["Italy", "Italian Republic", "IT"],
"Netherlands": ["Netherlands", "Holland", "NL"],
"Norway": ["Norway", "NO"],
"Portugal": ["Portugal", "PT"],
"Singapore": ["Singapore", "SG"],
"South Korea": ["South Korea", "KR"],
"Spain": ["Spain", "ES"],
"Sweden": ["Sweden", "SE"],
"Switzerland": ["Switzerland", "Swiss Confederation", "Swiss", "CH"],
"New Zealand": ["New Zealand", "NZ"]
}

In [12]:
lst_files = []

Path = "../Data/Articles/ReutersArticles/*.csv"

for fname in glob.glob(Path):
    lst_files.append(fname)

lst_files = sorted(lst_files)

for file in lst_files:
    file_name = os.path.basename(file)  # Get the file name from the full file path
    country_name = os.path.splitext(file_name)[0]  # Remove the file extension
    country_name = country_name.replace("_articles", "")  # Remove the "_articles" part of the file name

    # 仅对单个国家的数据进行实体识别
    # if file == '../Data/ReutersArticles/United States_articles.csv':
    # 对所有的未进行过NER的国家进行实体识别
    if os.path.isfile('../Data/NER/Reuters_NER/{0}.csv'.format(country_name)):
        print("File already exists: {0}.csv".format(country_name))
    else:
        print(file)
        df = pd.read_csv(file)

        if df.shape[0] != 0:
            
            df = preprocess_dataframe(df)

            df, df_ner_results = perform_ner_on_dataframe(df, country_name)


            # Convert the list of relevant articles back to a DataFrame
            df.to_csv(r'../Data/NER/Reuters_NER/{0}.csv'.format(country_name))
            df_ner_results.to_csv(r'../Data/NER/Reuters_NER_Results/{0}.csv'.format(country_name))

../Data/Articles/ReutersArticles\Australia_articles.csv


Processing Australia:   0%|          | 0/11187 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Canada_articles.csv


Processing Canada:   0%|          | 0/8908 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\China_articles.csv


Processing China:   0%|          | 0/23510 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Denmark_articles.csv


Processing Denmark:   0%|          | 0/1971 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Finland_articles.csv


Processing Finland:   0%|          | 0/1592 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\France_articles.csv


Processing France:   0%|          | 0/11407 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Germany_articles.csv


Processing Germany:   0%|          | 0/12946 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Italy_articles.csv


Processing Italy:   0%|          | 0/6967 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Japan_articles.csv


Processing Japan:   0%|          | 0/12224 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Netherlands_articles.csv


Processing Netherlands:   0%|          | 0/3661 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\New Zealand_articles.csv


Processing New Zealand:   0%|          | 0/4339 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Norway_articles.csv


Processing Norway:   0%|          | 0/2444 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Portugal_articles.csv


Processing Portugal:   0%|          | 0/2072 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Singapore_articles.csv


Processing Singapore:   0%|          | 0/4214 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\South Korea_articles.csv


Processing South Korea:   0%|          | 0/4938 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Spain_articles.csv


Processing Spain:   0%|          | 0/5374 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Sweden_articles.csv


Processing Sweden:   0%|          | 0/2689 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\Switzerland_articles.csv


Processing Switzerland:   0%|          | 0/2726 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\United Kingdom_articles.csv


Processing United Kingdom:   0%|          | 0/14940 [00:00<?, ?it/s]

../Data/Articles/ReutersArticles\United States_articles.csv


Processing United States:   0%|          | 0/22907 [00:00<?, ?it/s]

In [14]:
# 读取所有国家的实体识别后剩余的数据, 然后输出剩余数据的数量
lst_ner_files = []
for fname in glob.glob("../Data/NER/Reuters_NER/*.csv"):
    lst_ner_files.append(fname)
for file in lst_ner_files:
    file_name = os.path.basename(file)  # Get the file name from the full file path
    country_name = os.path.splitext(file_name)[0]  # Remove the file extension
    df = pd.read_csv(file)
    print(country_name, df.shape[0])

Australia 3606
Canada 2587
China 10251
Denmark 334
Finland 473
France 2317
Germany 3069
Italy 2065
Japan 3734
Netherlands 236
New Zealand 1120
Norway 524
Portugal 505
Singapore 777
South Korea 969
Spain 1399
Sweden 712
Switzerland 422
United Kingdom 6977
United States 2771
