## Import Libraries
Please note that the device setup below needs to be adjusted to fit your system (CUDA GPU, Apple MPS, or CPU).

In [1]:
import numpy as np
import pandas as pd
import re
import glob
from   os import path
import os
import json
from tqdm.notebook import tqdm
from dateutil.parser import parse
from dateutil.tz import gettz

import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='transformers')

import torch

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
# Adjust this for your hardware, e.g. torch.device("mps") when
# torch.backends.mps.is_available() on Apple machines.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the pretrained BERT NER checkpoint and wrap tokenizer + model in a
# token-level "ner" pipeline placed on the selected device.
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
nlp = pipeline("ner", tokenizer=tokenizer, model=model, device=device)

In [2]:
def combineHeadlineText(row):
    """Merge an article's headline and body text into one string.

    Args:
        row: A mapping-like record (e.g. a DataFrame row passed by
            ``DataFrame.apply(..., axis=1)``) with "Headline" and "Text".

    Returns:
        "<Headline>. <Text>" when the headline is a string; otherwise (e.g. a
        missing / NaN headline) the text on its own.
    """
    headline, body = row["Headline"], row["Text"]
    if not isinstance(headline, str):
        return body
    return ". ".join((headline, body))

In [4]:
def preprocess_dataframe(df, use_parse=False):
    """Clean a raw article DataFrame and return it sorted chronologically.

    Steps: drop a stray ``Unnamed: 0`` column (left by an earlier ``to_csv``)
    if present, de-duplicate on (Date, Headline) keeping the last occurrence,
    prepend the headline to ``Text`` via ``combineHeadlineText``, normalise
    ``Date``, then sort by date.

    Args:
        df: Raw articles with at least ``Date``, ``Headline`` and ``Text``
            columns. It is not modified in place.
        use_parse: If True, ``Date`` holds free-form strings; the literal
            substrings "Published: " and "First" are blanked out, the rest is
            parsed with ``dateutil.parser.parse`` (the abbreviation "ET" is
            mapped to America/New_York) and reduced to calendar dates. If
            False, ``Date`` is converted with ``pd.to_datetime``.

    Returns:
        A new DataFrame sorted by ``Date`` (same-date rows keep their original
        relative order) with a fresh 0..n-1 index.
    """
    df = df.drop(['Unnamed: 0'], axis=1, errors='ignore')
    # Explicit copy: drop_duplicates returns a filtered frame, and the column
    # assignments below must not act on (or warn about) a view.
    df = df.drop_duplicates(['Date', 'Headline'], keep='last').copy()
    df['Text'] = df['Text'].astype(str)
    df['Text'] = df.apply(combineHeadlineText, axis=1)

    if use_parse:
        # Blank out the literal prefixes so dateutil only sees the timestamp;
        # regex=False keeps this independent of the pandas version's default.
        df['Date'] = df['Date'].str.replace('Published: ', ' ', regex=False)
        df['Date'] = df['Date'].str.replace('First', ' ', regex=False)
        eastern = gettz('America/New_York')
        df['Date'] = df['Date'].apply(lambda date_str: parse(date_str, tzinfos={'ET': eastern}))
        df['Date'] = df['Date'].dt.date
    else:
        df['Date'] = pd.to_datetime(df['Date'])

    # Sort first, then reset: resetting before sorting left the index as a
    # shuffled permutation. A stable sort keeps same-day articles in order.
    df = df.sort_values(by=['Date'], ascending=True, kind='stable').reset_index(drop=True)

    return df

In [5]:
def process_entities(ner_results):
    """Group token-level BIO NER predictions into multi-token entities.

    Args:
        ner_results: Token predictions as produced by the "ner" pipeline
            without aggregation: dicts with at least ``word``, ``entity`` (a
            ``B-``/``I-`` prefixed label) and ``score``. An optional ``index``
            (token position) is used to detect interrupted entities.

    Returns:
        A list of entities in input order; each entity is a list of
        ``{'word', 'entity', 'score'}`` dicts. A new entity starts on every
        ``B-`` tag, on an ``I-`` tag when no entity is open, on an ``I-`` tag
        whose type differs from the open entity (e.g. ``I-LOC`` after
        ``B-ORG``), and on an ``I-`` tag whose ``index`` does not directly
        follow the previous token's. Labels with neither prefix are ignored.
    """
    processed_entities = []
    current_entity = []
    prev_index = None

    for d in ner_results:
        label = d['entity']
        # Keep only the fields downstream code uses.
        entity = {'word': d['word'], 'entity': label, 'score': d['score']}
        index = d.get('index')

        if label.startswith('I-') and current_entity:
            same_type = current_entity[-1]['entity'][2:] == label[2:]
            # Non-entity ("O") tokens are absent from the input, so a jump in
            # token index means the entity was interrupted. Without index
            # information, fall back to treating tokens as adjacent.
            contiguous = index is None or prev_index is None or index == prev_index + 1
            if same_type and contiguous:
                current_entity.append(entity)
                prev_index = index
                continue

        if label.startswith(('B-', 'I-')):
            if current_entity:
                processed_entities.append(current_entity)
            current_entity = [entity]
            prev_index = index

    if current_entity:
        processed_entities.append(current_entity)

    return processed_entities

In [6]:
def json_serializable(item):
    """``json.dumps`` ``default=`` hook that converts NumPy values to Python ones.

    ``json`` cannot encode NumPy types such as the ``np.float32`` scores found
    in the NER output. Any NumPy scalar (``np.generic``: float32, float16,
    int64, bool_, ...) is converted with ``.item()``, and any ``np.ndarray``
    with ``.tolist()``.

    Raises:
        TypeError: For any other type, as ``json.dumps`` expects from a
            ``default`` hook.
    """
    if isinstance(item, np.generic):
        return item.item()
    if isinstance(item, np.ndarray):
        return item.tolist()
    raise TypeError(f"Type {type(item)} not serializable")

In [7]:
def perform_ner_on_dataframe(df, country_name, min_mentions=3, min_score=0.98):
    """Run NER over every article and keep those that mention the country often.

    For each row's ``Text`` the module-level ``nlp`` pipeline is run and its
    tokens are grouped with ``process_entities``. An entity group counts as a
    mention when it starts with ``B-LOC`` or ``B-ORG``, its mean token score
    exceeds ``min_score``, and one of the country's aliases (from the
    module-level ``country_aliases``, falling back to ``[country_name]``)
    occurs as a substring of the group's space-joined words.

    Args:
        df: Preprocessed articles with ``Text``, ``Date`` and ``Headline``
            columns. It is not modified.
        country_name: Key into ``country_aliases``; also shown in the
            progress bar.
        min_mentions: Minimum number of mentions an article needs to be kept
            (default 3, the previous hard-coded value).
        min_score: Exclusive mean-score threshold for an entity group
            (default 0.98, the previous hard-coded value).

    Returns:
        ``(filtered_df, df_ner_results)``: the rows of ``df`` with at least
        ``min_mentions`` mentions, and a DataFrame with one row per input
        article holding ``Date``, ``Headline`` and the raw NER output
        serialised as a JSON string in ``NER``.
    """
    # Loop-invariant: the alias list is the same for every article.
    aliases = country_aliases.get(country_name, [country_name])

    mention_counts = []
    ner_results_data_list = []

    for text, date, headline in tqdm(
        zip(df['Text'], df['Date'], df['Headline']),
        total=len(df),
        desc=f"Processing {country_name}",
    ):
        ner_results = nlp(text)

        mentions = 0
        for entity_group in process_entities(ner_results):
            if entity_group[0]['entity'] not in ("B-LOC", "B-ORG"):
                continue
            entity_score = sum(entity['score'] for entity in entity_group) / len(entity_group)
            if entity_score <= min_score:
                continue
            # NOTE(review): substring match on space-joined word pieces (e.g.
            # "AN ##Z"); short uppercase aliases such as "US", "SE" or "NO"
            # can also match inside unrelated acronyms -- confirm intended.
            entity_name = ' '.join(entity['word'] for entity in entity_group)
            if any(alias in entity_name for alias in aliases):
                mentions += 1
        mention_counts.append(mentions)

        ner_results_data_list.append({
            'Date': date,
            'Headline': headline,
            # Raw token predictions as a JSON string; json_serializable turns
            # NumPy scalars (e.g. the scores) into plain Python numbers.
            'NER': json.dumps(ner_results, default=json_serializable),
        })

    # Positional boolean mask instead of a temporary 'Count' column, so the
    # caller's DataFrame is left untouched.
    keep = np.asarray(mention_counts) >= min_mentions
    df = df[keep].copy()
    df_ner_results = pd.DataFrame(ner_results_data_list)

    return df, df_ner_results

In [8]:
# Aliases for each country, keyed by the country name taken from the article
# file name; perform_ner_on_dataframe looks for them (as substrings) inside
# the names of detected LOC/ORG entities.
country_aliases = {
    "United States": ["USA", "America", "US", "United States", "UnitedStates"],
    "Canada": ["Canada", "CA"],
    "United Kingdom": [
        "UK", "United Kingdom", "Britain", "England", "Scotland",
        "Wales", "Northern Ireland", "UnitedKingdom",
    ],
    "Australia": ["Australia", "AU", "Aussie"],
    "China": ["China", "PRC"],
    "Denmark": ["Denmark", "DK"],
    "Finland": ["Finland", "FI"],
    "France": ["France", "French Republic", "FR"],
    "Germany": ["Germany", "DE"],
    "Japan": ["Japan", "JP"],
    "Italy": ["Italy", "Italian Republic", "IT"],
    "Netherlands": ["Netherlands", "Holland", "NL"],
    "Norway": ["Norway", "NO"],
    "Portugal": ["Portugal", "PT"],
    "Singapore": ["Singapore", "SG"],
    "South Korea": ["South Korea", "KR", "SouthKorea"],
    "Spain": ["Spain", "ES"],
    "Sweden": ["Sweden", "SE"],
    "Switzerland": ["Switzerland", "Swiss Confederation", "Swiss", "CH"],
    "New Zealand": ["New Zealand", "NZ", "NewZealand"],
}

In [None]:
# Run NER over the per-country article files ("<Country>_articles.csv") and
# save, per country, the articles that mention the country often enough plus
# the raw NER output for every article.
articles_pattern = "../Data/Articles/MWArticles/*.csv"
ner_articles_dir = "../Data/MW_NER"
ner_raw_dir = "../Data/MW_NER_Results"
# NOTE(review): a later cell reads the filtered articles back from
# "../Data/NER/MW_NER/" -- confirm which output directory is current.

# Countries processed in this run (the original only ran "China").
target_countries = {"China"}

# Create the output directories up front so a long NER run cannot fail at the
# final to_csv call because a directory is missing.
os.makedirs(ner_articles_dir, exist_ok=True)
os.makedirs(ner_raw_dir, exist_ok=True)

lst_files = sorted(glob.glob(articles_pattern))

for file in lst_files:
    file_name = os.path.basename(file)  # e.g. "China_articles.csv"
    country_name = os.path.splitext(file_name)[0].replace("_articles", "")

    # Only run NER for the selected countries.
    if country_name not in target_countries:
        continue
    # To instead process every country without existing output, replace the
    # check above with a skip when
    # os.path.isfile(os.path.join(ner_articles_dir, f"{country_name}.csv")).

    print(file)
    df = pd.read_csv(file)
    if df.empty:
        continue

    df = preprocess_dataframe(df, use_parse=True)
    df, df_ner_results = perform_ner_on_dataframe(df, country_name)

    # Save the filtered articles and the per-article raw NER output.
    df.to_csv(os.path.join(ner_articles_dir, f"{country_name}.csv"))
    df_ner_results.to_csv(os.path.join(ner_raw_dir, f"{country_name}.csv"))

../Data/Articles/MWArticles\China_articles.csv


Processing China:   0%|          | 0/59958 [00:00<?, ?it/s]

In [10]:
# Read each country's articles remaining after the NER filter and print how
# many there are (one CSV per country, named "<Country>.csv").
# NOTE(review): the NER driver writes to "../Data/MW_NER/" while this reads
# "../Data/NER/MW_NER/" -- confirm the directory layout.
# Sorted so the report order is deterministic on every platform.
lst_ner_files = sorted(glob.glob("../Data/NER/MW_NER/*.csv"))
for file in lst_ner_files:
    file_name = os.path.basename(file)  # e.g. "China.csv"
    country_name = os.path.splitext(file_name)[0]
    df = pd.read_csv(file)
    print(country_name, df.shape[0])

Australia 1989
Canada 1069
China 15272
Denmark 75
Finland 64
France 1055
Germany 1102
Italy 828
Japan 2660
Netherlands 66
New Zealand 279
Norway 165
Portugal 125
Singapore 362
South Korea 395
Spain 991
Sweden 150
Switzerland 173
United Kingdom 217
United States 1044


In [4]:
def process_single_article(text, country_name, ner=None):
    """Count high-confidence ``B-ORG`` tokens naming ``country_name`` in ``text``.

    Args:
        text: Article text to run NER on.
        country_name: Country to look for. A token counts when its word equals
            one of the whitespace-separated words of the name (so "South"
            counts for "South Korea").
        ner: Optional callable returning token-level predictions (dicts with
            ``entity``, ``word`` and ``score``) for ``text``. Defaults to the
            module-level ``nlp`` pipeline.

    Returns:
        The number of tokens labelled exactly ``B-ORG`` with score above 0.98
        whose word is one of the country name's words.
    """
    if ner is None:
        ner = nlp
    ner_results = ner(text)
    # Exact comparisons: a substring test would also accept labels contained in
    # "B-ORG" (such as "O") and fragments of the name (such as "a" in "China").
    name_words = set(country_name.split())
    country_instances = [
        d for d in ner_results
        if d['entity'] == "B-ORG" and d['word'] in name_words and d['score'] > 0.98
    ]
    return len(country_instances)

../Data/countries_integration/Denmark_articles.csv


Processing Denmark:   0%|          | 0/1412 [00:00<?, ?it/s]

In [6]:
# Pick one sample article from the alphabetically first CSV file in
# ../Data/countries_integration for manual inspection.
csv_files = sorted(glob.glob('../Data/countries_integration/*.csv'))
if not csv_files:
    raise FileNotFoundError("No CSV files found in ../Data/countries_integration/")

df = pd.read_csv(csv_files[0])

# 0-based row position (not index label) of the sample article.
sample_position = 43
# NOTE: despite its name, `first_row` is the row at `sample_position`, not the
# first row; the name is kept in case later cells reference it.
first_row = df.iloc[sample_position]

# Extract the 'Date', 'Headline' and 'Text' columns of the sample article.
date = first_row['Date']
headline = first_row['Headline']
text = first_row['Text']

1412
839


In [10]:
# Display the raw token-level NER output in the notebook.
# NOTE(review): `ner_results` is not assigned at module level in any visible
# cell (only inside functions); presumably it comes from a cell missing from
# this export -- confirm before re-running.
ner_results

[{'entity': 'B-ORG',
  'score': 0.9932053,
  'index': 1,
  'word': 'AN',
  'start': 0,
  'end': 2},
 {'entity': 'I-ORG',
  'score': 0.992149,
  'index': 2,
  'word': '##Z',
  'start': 2,
  'end': 3},
 {'entity': 'B-ORG',
  'score': 0.99892753,
  'index': 4,
  'word': 'RB',
  'start': 9,
  'end': 11},
 {'entity': 'I-ORG',
  'score': 0.9988914,
  'index': 5,
  'word': '##A',
  'start': 11,
  'end': 12},
 {'entity': 'B-ORG',
  'score': 0.9988292,
  'index': 20,
  'word': 'Australia',
  'start': 56,
  'end': 65},
 {'entity': 'I-ORG',
  'score': 0.99927264,
  'index': 21,
  'word': '&',
  'start': 66,
  'end': 67},
 {'entity': 'I-ORG',
  'score': 0.9993352,
  'index': 22,
  'word': 'New',
  'start': 68,
  'end': 71},
 {'entity': 'I-ORG',
  'score': 0.9992262,
  'index': 23,
  'word': 'Zealand',
  'start': 72,
  'end': 79},
 {'entity': 'I-ORG',
  'score': 0.999316,
  'index': 24,
  'word': 'Banking',
  'start': 80,
  'end': 87},
 {'entity': 'I-ORG',
  'score': 0.99929905,
  'index': 25,
  'w