In [1]:
import json
from collections import defaultdict
from tqdm import tqdm

import spacy
import numpy as np
import torch

import networkx as nx
import pandas as pd
import dgl

https://stackoverflow.com/questions/70835924/how-to-get-a-description-for-each-spacy-ner-entity

In [2]:
nlp1 = spacy.load('en_core_web_trf')
nlp2 = spacy.load('xx_ent_wiki_sm')

In [2]:
DATE_FORMAT = '%m/%d/%Y'

In [3]:
data_2022 = pd.read_csv('ICEWS-2022.csv', header=0, index_col=['Event ID'], parse_dates=['Event Date'], date_format=DATE_FORMAT)
data_2023 = pd.read_csv('ICEWS-2023.csv', header=0, index_col=['Event ID'], parse_dates=['Event Date'], date_format=DATE_FORMAT)

data = pd.concat([data_2022, data_2023])
data.drop(columns=['Latitude', 'Longitude', ])
data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 997258 entries, 36615529 to 37764080
Data columns (total 21 columns):
 #   Column           Non-Null Count   Dtype         
---  ------           --------------   -----         
 0   Event Date       997258 non-null  datetime64[ns]
 1   Source Name      997258 non-null  object        
 2   Source Sectors   729968 non-null  object        
 3   Source Country   954831 non-null  object        
 4   Source Type      997258 non-null  object        
 5   Event Text       997258 non-null  object        
 6   CAMEO Code       997258 non-null  int64         
 7   Intensity        997258 non-null  float64       
 8   Target Name      997258 non-null  object        
 9   Target Sectors   620065 non-null  object        
 10  Target Country   937119 non-null  object        
 11  Target Type      997258 non-null  object        
 12  Story ID         997258 non-null  int64         
 13  Sentence Number  997258 non-null  int64         
 14  Publisher       

In [4]:
data.head()

Unnamed: 0_level_0,Event Date,Source Name,Source Sectors,Source Country,Source Type,Event Text,CAMEO Code,Intensity,Target Name,Target Sectors,...,Target Type,Story ID,Sentence Number,Publisher,City,District,Province,Country,Latitude,Longitude
Event ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
36615529,2022-01-01,South Korea,,South Korea,Country,Demand rights,1043,-5.0,Government (South Korea),Government,...,Other,54339603,2,Yonhap English News,Seoul,,Seoul-teukbyeolsi,South Korea,37.5683,126.978
36615519,2022-01-01,Ministry (Indonesia),Government,Indonesia,Other,Reject,120,-4.0,Reuters,"Media,News,Social",...,Organization,54339616,2,Unknown,Jakarta,,Daerah Khusus Ibukota Jakarta,Indonesia,-6.21462,106.845
36615520,2022-01-01,Denis Moncada,"Executive,Foreign Ministry,Government",Nicaragua,Organization,Host a visit,43,2.8,China,,...,Country,54339617,1,Al Jazeera English,Managua,,Departamento de Managua,Nicaragua,12.1328,-86.2504
36615521,2022-01-01,Denis Moncada,"Executive,Foreign Ministry,Government",Nicaragua,Organization,Praise or endorse,51,3.4,Foreign Affairs (China),"Executive,Foreign Ministry,Government",...,Organization,54339617,1,Al Jazeera English,Managua,,Departamento de Managua,Nicaragua,12.1328,-86.2504
36615523,2022-01-01,China,,China,Country,Make a visit,42,1.9,Denis Moncada,"Executive,Foreign Ministry,Government",...,Organization,54339617,1,Al Jazeera English,Managua,,Departamento de Managua,Nicaragua,12.1328,-86.2504


In [5]:
ent_type_map = {}
for _, name, ent_type in data[['Source Name', 'Source Type']].itertuples(name=None):
    ent_type_map[name] = ent_type
for _, name, ent_type in data[['Target Name', 'Target Type']].itertuples(name=None):
    ent_type_map[name] = ent_type

In [6]:
from collections import Counter

ntype_count = Counter(ent_type_map.values())
print(ntype_count)

Counter({'Other': 12277, 'Organization': 7928, 'Name': 4759, 'Country': 456, 'Political Group': 323, 'Object': 17, 'Infrastructure': 10, 'Date': 7, 'Art': 4, 'Percent': 2, 'Law': 2, 'Event': 1})


In [5]:
entities = list(set(data['Source Name'].unique()) | set(data['Target Name'].unique()))
entities.sort()
print(len(entities))

25786


In [36]:
count, mismatch, bad = 0, 0, 0
ents = set()
for entity, doc1, doc2 in zip(entities, nlp1.pipe(entities, disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer']), nlp2.pipe(entities, disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])):
    # if count >= 10: break
    try:
       ent1 = doc1.ents[0].label_
       ents.add(ent1)
    except IndexError:
        ent1 = None
    try:
        ent2 = doc2.ents[0].label_
        if ent2 == 'PER':
            ent2 = 'PERSON'
        ents.add(ent2)
    except IndexError:
        ent2 = None

    if ent1 or ent2:
        count += 1
    if ent1 != ent2:
        mismatch += 1
    if not ent1 and not ent2:
        bad += 1

print(f'{count} / {len(entities)}')
print(f'{mismatch} / {len(entities)}')
print(f'{bad} / {len(entities)}')
print(ents)

1846 / 25786
60 / 25786
set()


In [8]:
data['Source Type'].value_counts()

Source Type
Other              320298
Country            246877
Name               222461
Organization       198118
Political Group      8884
Object                515
Infrastructure         52
Date                   44
Art                     3
Law                     3
Percent                 2
Event                   1
Name: count, dtype: int64

In [6]:
etype_counts = data['Event Text'].value_counts()
etypes = etype_counts[etype_counts > 10].index.tolist()

In [7]:
data_redux = data[data['Event Text'].isin(etypes)]

In [8]:
data_redux.info()

<class 'pandas.core.frame.DataFrame'>
Index: 997089 entries, 36615529 to 37764080
Data columns (total 21 columns):
 #   Column           Non-Null Count   Dtype         
---  ------           --------------   -----         
 0   Event Date       997089 non-null  datetime64[ns]
 1   Source Name      997089 non-null  object        
 2   Source Sectors   729857 non-null  object        
 3   Source Country   954666 non-null  object        
 4   Source Type      997089 non-null  object        
 5   Event Text       997089 non-null  object        
 6   CAMEO Code       997089 non-null  int64         
 7   Intensity        997089 non-null  float64       
 8   Target Name      997089 non-null  object        
 9   Target Sectors   619973 non-null  object        
 10  Target Country   936958 non-null  object        
 11  Target Type      997089 non-null  object        
 12  Story ID         997089 non-null  int64         
 13  Sentence Number  997089 non-null  int64         
 14  Publisher       

In [9]:
node_counts = data_redux['Source Name'].value_counts() + data_redux['Target Name'].value_counts()

In [10]:
node_entities = node_counts[node_counts >= 5].index.tolist()
len(node_entities)

11819

In [11]:
data_redux = data_redux[(data_redux['Source Name'].isin(node_entities)) & (data_redux['Target Name'].isin(node_entities))]

In [12]:
data_redux.info()

<class 'pandas.core.frame.DataFrame'>
Index: 950359 entries, 36615529 to 37764068
Data columns (total 21 columns):
 #   Column           Non-Null Count   Dtype         
---  ------           --------------   -----         
 0   Event Date       950359 non-null  datetime64[ns]
 1   Source Name      950359 non-null  object        
 2   Source Sectors   686161 non-null  object        
 3   Source Country   909074 non-null  object        
 4   Source Type      950359 non-null  object        
 5   Event Text       950359 non-null  object        
 6   CAMEO Code       950359 non-null  int64         
 7   Intensity        950359 non-null  float64       
 8   Target Name      950359 non-null  object        
 9   Target Sectors   577849 non-null  object        
 10  Target Country   907127 non-null  object        
 11  Target Type      950359 non-null  object        
 12  Story ID         950359 non-null  int64         
 13  Sentence Number  950359 non-null  int64         
 14  Publisher       

In [13]:
data_redux.to_csv('icews_redux.csv', header=True, index=True)

In [16]:
entities = list(set(data_redux['Source Name'].unique()) | set(data_redux['Target Name'].unique()))
entities.sort()
print(len(entities))

11814


In [17]:
entity_map = {name.strip(): i for i, name in enumerate(entities)}

In [18]:
with open('entity_map.json', 'w') as f:
    json.dump(entity_map, f, indent=2)

In [17]:
with open('entity_map.json') as f:
    entity_map = json.load(f)

In [19]:
def get_sectors(x):
    if pd.notna(x):
        return x.split(',')
    else:
        return None

In [20]:
source_types = data_redux['Source Sectors'].apply(get_sectors).tolist()
target_types = data_redux['Target Sectors'].apply(get_sectors).tolist()

In [21]:
source_types = set(
    v.strip()
    for vals in source_types if vals is not None
    for v in vals
)

target_types = set(
    v.strip()
    for vals in target_types if vals is not None
    for v in vals
)

sector_types = source_types | target_types

In [23]:
len(sector_types)

250

In [24]:
sector_types

{'(National) Major Party',
 '(National) Minor Party',
 'Afar',
 'Afro-American',
 'Afro-Caribbean',
 'Agricultural',
 'Agricultural NGOs',
 'Agriculture / Fishing / Forestry Ministry',
 'Air Force',
 'Arab',
 'Army',
 'Army Special Forces',
 'Atheist',
 'Aymara',
 'Banned Parties',
 'Bantu',
 'Buddhist',
 'Business',
 'Business IGOs',
 'Business NGOs',
 'Cabinet',
 'Catholic',
 'Center Left',
 'Center Right',
 'Central-East',
 'Centrist',
 'Charity IGOs',
 'Charity NGOs',
 'Chechen',
 'Christian',
 'Coast Guard',
 'Communist',
 'Consulting / Financial Services Business',
 'Consumer Goods Business',
 'Consumer Services Business',
 'Coptic',
 'Criminals / Gangs',
 'Defense / Security Business',
 'Defense / Security Ministry',
 'Development IGOs',
 'Development NGOs',
 'Disaster Ministry',
 'Dissident',
 'Drugs Ministry',
 'Druze',
 'Durable Goods Business',
 'East-Coastal',
 'Education',
 'Education IGOs',
 'Education Ministry',
 'Education NGOs',
 'Elections Ministry',
 'Elite',
 'Energ

In [25]:
def make_node_feat_set(x):
    if pd.notna(x):
        return ','.join(x)

In [26]:
source_feats = data_redux.groupby('Source Name')['Source Sectors'].agg(lambda x: set(e for l in x if not pd.isna(l) for e in l.split(',')))
target_feats = data_redux.groupby('Target Name')['Target Sectors'].agg(lambda x: set(e for l in x if not pd.isna(l) for e in l.split(',')))

In [27]:
source_feats

Source Name
10 Downing Street                    {Executive, Government, Executive Office}
A Just Russia Party          {(National) Minor Party, Parties, Ideological,...
A.K. Antony                  {Executive, Elite, Government, Defense / Secur...
A.K. Sharma                                               {Government, Police}
A.P. Sharma                  {Executive, Elite, Finance / Economy / Commerc...
                                                   ...                        
Zuzana ─îaputov├í                    {Executive, Government, Executive Office}
al-Aqsa Martyrs' Brigades                       {Dissident, Organized Violent}
Ã“scar Arias SÃ¡nchez        {(National) Major Party, Elite, Ideological, C...
Ã“scar IvÃ¡n Zuluaga         {(National) Major Party, Government Major Part...
Ã‰douard Philippe             {Executive, Elite, Executive Office, Government}
Name: Source Sectors, Length: 11731, dtype: object

In [28]:
target_feats

Target Name
10 Downing Street                    {Executive, Government, Executive Office}
A Just Russia Party          {(National) Minor Party, Parties, Ideological,...
A.K. Antony                  {Executive, Elite, Government, Defense / Secur...
A.K. Sharma                                               {Government, Police}
A.P. Sharma                  {Executive, Elite, Finance / Economy / Commerc...
                                                   ...                        
Zuzana ─îaputov├í                    {Executive, Government, Executive Office}
al-Aqsa Martyrs' Brigades                       {Dissident, Organized Violent}
Ã“scar Arias SÃ¡nchez        {(National) Major Party, Elite, Ideological, C...
Ã“scar IvÃ¡n Zuluaga         {(National) Major Party, Government Major Part...
Ã‰douard Philippe             {Executive, Elite, Executive Office, Government}
Name: Target Sectors, Length: 11727, dtype: object

In [29]:
node_feats = pd.merge(source_feats, target_feats, how='outer', left_index=True, right_index=True)

In [30]:
def merge_sets(x):
    source = x['Source Sectors']
    target = x['Target Sectors']
    if not pd.isna(source) and not pd.isna(target):
        return source | target
    elif pd.isna(target):
        return source
    elif pd.isna(source):
        return target

In [31]:
node_feats['Sectors'] = node_feats.apply(merge_sets, axis=1)

In [32]:
node_feats.drop(columns=['Source Sectors', 'Target Sectors'], inplace=True)

In [33]:
node_feats

Unnamed: 0,Sectors
10 Downing Street,"{Executive, Government, Executive Office}"
A Just Russia Party,"{(National) Minor Party, Ideological, Center L..."
A.K. Antony,"{Executive, Defense / Security Ministry, Elite..."
A.K. Sharma,"{Government, Police}"
A.P. Sharma,"{Executive, Elite, Finance / Economy / Commerc..."
...,...
Zuzana ─îaputov├í,"{Executive, Government, Executive Office}"
al-Aqsa Martyrs' Brigades,"{Dissident, Organized Violent}"
Ã“scar Arias SÃ¡nchez,"{(National) Major Party, Elite, Ideological, C..."
Ã“scar IvÃ¡n Zuluaga,"{(National) Major Party, Government Major Part..."


In [35]:
node_feats_encoded = node_feats['Sectors'].str.join(sep='*').str.get_dummies(sep='*')

In [36]:
node_feats_encoded.head()

Unnamed: 0,Central-East,East-Coastal,Gikuyu-Kamba,Planes),Sotho-Tswana,Sudan,Tanks,(National) Major Party,(National) Minor Party,Afar,...,Tutsi,Unicameral,Unidentified Forces,Upper House,Utilities Business,Uyghur,Uzbek,Water Ministry,Women / Children / Social / Welfare / Development / Religion Ministry,Yoruba
10 Downing Street,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A Just Russia Party,0,0,0,0,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
A.K. Antony,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A.K. Sharma,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A.P. Sharma,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [219]:
node_feats_encoded.to_csv('node_feats.csv', header=True, index=True)

In [62]:
node_feat_dict = {}
count = 0
for i, row in node_feats_encoded.iterrows():
    node_feat_dict[entity_map[i.strip()]] = row.to_numpy()

In [64]:
len(node_feat_dict)

11814

In [37]:
len(node_feats)

11814

In [14]:
edge_data = pd.DataFrame()

In [18]:
edge_data['date'] = data_redux['Event Date']
edge_data['timestamp'] = data_redux['Event Date'].astype(int)
edge_data['source_name'] = data_redux['Source Name'].apply(lambda x: x.strip())
edge_data['source_node'] = data_redux['Source Name'].apply(lambda x: entity_map[x.strip()])
edge_data['source_node_type'] = data_redux['Source Type']

edge_data['target_name'] = data_redux['Target Name'].apply(lambda x: x.strip())
edge_data['target_node'] = data_redux['Target Name'].apply(lambda x: entity_map[x.strip()])
edge_data['target_node_type'] = data_redux['Target Type']
edge_data['edge_type'] = data_redux['Event Text']
edge_data['intensity'] = data_redux['Intensity']

In [91]:
edge_data.head()

Unnamed: 0_level_0,date,timestamp,source_name,source_node,source_node_type,target_name,target_node,target_node_type,edge_type,intensity
Event ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
36615529,2022-01-01,1640995200000000000,South Korea,10828,Country,Government (South Korea),4285,Other,Demand rights,-5.0
36615519,2022-01-01,1640995200000000000,Ministry (Indonesia),7992,Other,Reuters,10208,Organization,Reject,-4.0
36615520,2022-01-01,1640995200000000000,Denis Moncada,2837,Organization,China,1773,Country,Host a visit,2.8
36615521,2022-01-01,1640995200000000000,Denis Moncada,2837,Organization,Foreign Affairs (China),3870,Organization,Praise or endorse,3.4
36615523,2022-01-01,1640995200000000000,China,1773,Country,Denis Moncada,2837,Organization,Make a visit,1.9


In [97]:
edge_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 950359 entries, 36615529 to 37764068
Data columns (total 10 columns):
 #   Column            Non-Null Count   Dtype         
---  ------            --------------   -----         
 0   date              950359 non-null  datetime64[ns]
 1   timestamp         950359 non-null  int64         
 2   source_name       950359 non-null  object        
 3   source_node       950359 non-null  int64         
 4   source_node_type  950359 non-null  object        
 5   target_name       950359 non-null  object        
 6   target_node       950359 non-null  int64         
 7   target_node_type  950359 non-null  object        
 8   edge_type         950359 non-null  object        
 9   intensity         950359 non-null  float64       
dtypes: datetime64[ns](1), float64(1), int64(3), object(5)
memory usage: 112.0+ MB


In [94]:
print(edge_data['date'].min())
print(edge_data['date'].max())

2022-01-01 00:00:00
2023-04-10 00:00:00


In [99]:
from datetime import datetime

In [100]:
train_cutoff = datetime.strptime('2023-01-01', '%Y-%m-%d')
val_cutoff = datetime.strptime('2023-02-15', '%Y-%m-%d')

In [101]:
print(train_cutoff)
print(val_cutoff)

2023-01-01 00:00:00
2023-02-15 00:00:00


In [93]:
edge_data['date'] < '2023-01-01'

Event ID
36615529     True
36615519     True
36615520     True
36615521     True
36615523     True
            ...  
37764060    False
37764062    False
37764083    False
37764075    False
37764068    False
Name: date, Length: 950359, dtype: bool

In [103]:
graph_data = defaultdict(lambda: ([], []))
e_intensity = defaultdict(list)
e_date = defaultdict(list)
e_train = defaultdict(list)
e_val = defaultdict(list)
e_test = defaultdict(list)
for i, row in tqdm(edge_data.iterrows()):
    edge_schema = (dgl.NTYPE, row['edge_type'], dgl.NTYPE)

    graph_data[edge_schema][0].append(row['source_node'])
    graph_data[edge_schema][1].append(row['target_node'])

    e_intensity[edge_schema].append(row['intensity'])
    e_date[edge_schema].append(row['timestamp'])
    if row['date'] < train_cutoff:
        e_train[edge_schema].append(True)
        e_val[edge_schema].append(False)
        e_test[edge_schema].append(False)
    elif row['date'] < val_cutoff:
        e_train[edge_schema].append(False)
        e_val[edge_schema].append(True)
        e_test[edge_schema].append(False)
    else:
        e_train[edge_schema].append(False)
        e_val[edge_schema].append(False)
        e_test[edge_schema].append(True)

950359it [02:37, 6034.12it/s]


In [104]:
G = dgl.heterograph(graph_data)

In [105]:
G.ndata['feat'] = torch.FloatTensor(
    np.array([node_feat_dict[node.item()] for node in G.nodes()])
)

In [106]:
e_intensity = {k: torch.FloatTensor(v) for k, v in e_intensity.items()}
e_date = {k: torch.LongTensor(v) for k, v in e_date.items()}
e_train = {k: torch.LongTensor(v) for k, v in e_train.items()}
e_val = {k: torch.LongTensor(v) for k, v in e_val.items()}
e_test = {k: torch.LongTensor(v) for k, v in e_test.items()}

In [107]:
G.edata['intensity'] = e_intensity
G.edata['timestamp'] = e_date
G.edata['train_mask'] = e_train
G.edata['val_mask'] = e_val
G.edata['test_mask'] = e_test

In [109]:
G.edata.keys()

dict_keys(['intensity', 'timestamp', 'train_mask', 'val_mask', 'test_mask'])

In [110]:
dgl.edge_subgraph(G, G.edata['test_mask'])

Graph(num_nodes={'_TYPE': 7401},
      num_edges={('_TYPE', 'Abduct, hijack, or take hostage', '_TYPE'): 207, ('_TYPE', 'Accede to demands for change in institutions, regime', '_TYPE'): 1, ('_TYPE', 'Accede to demands for change in leadership', '_TYPE'): 13, ('_TYPE', 'Accede to demands for change in policy', '_TYPE'): 14, ('_TYPE', 'Accede to demands for rights', '_TYPE'): 13, ('_TYPE', 'Accede to requests or demands for political reform', '_TYPE'): 9, ('_TYPE', 'Accuse', '_TYPE'): 2591, ('_TYPE', 'Accuse of aggression', '_TYPE'): 10, ('_TYPE', 'Accuse of crime, corruption', '_TYPE'): 97, ('_TYPE', 'Accuse of espionage, treason', '_TYPE'): 49, ('_TYPE', 'Accuse of human rights abuses', '_TYPE'): 38, ('_TYPE', 'Accuse of war crimes', '_TYPE'): 93, ('_TYPE', 'Acknowledge or claim responsibility', '_TYPE'): 74, ('_TYPE', 'Apologize', '_TYPE'): 14, ('_TYPE', 'Appeal for aid', '_TYPE'): 29, ('_TYPE', 'Appeal for change in institutions, regime', '_TYPE'): 2, ('_TYPE', 'Appeal for change in 

In [69]:
G.ndata['feat'].shape

torch.Size([11814, 250])

In [81]:
G.find_edges(0, 'Reject')

(tensor([7992]), tensor([10208]))

In [111]:
dgl.save_graphs('dgl_icews_graph.bin', [G])