# Link prediction data preparation
Here, we'll prep the data for input to STHN.

In [1]:
import networkx as nx
import pandas as pd
import numpy as np

## Read in the graph

In [2]:
# Location of the GraphML file holding the graph to prepare.
graph_path = '../data/kg/all_drought_dt_co_occurrence_graph_02May2024.graphml'
graph = nx.read_graphml(graph_path)

KeyboardInterrupt: 

## Create a node mapping from strings to integers

In [3]:
# Give every node a consecutive integer ID (0, 1, 2, ...) following the graph's
# node iteration order, then build a relabelled graph that uses those IDs.
node_mappings = dict(zip(graph.nodes, range(graph.number_of_nodes())))
int_graph = nx.relabel_nodes(graph, mapping=node_mappings)

## Get edgelist

In [4]:
# One row per edge: the two endpoint IDs plus one column per edge attribute.
edgelist = nx.to_pandas_edgelist(int_graph, source='source', target='target')

In [5]:
# Preview the first five edges.
edgelist.head(n=5)

Unnamed: 0,source,target,is_desiccation,uids_of_origin,is_drought,num_doc_mentions_all_time,first_year_mentioned
0,0,1,False,WOS:000623658100043,True,1,2021
1,0,2,False,WOS:000623658100043,True,1,2021
2,0,49583,False,WOS:000621810600016,True,1,2020
3,0,67147,False,WOS:000621810600016,True,1,2020
4,0,313607,False,WOS:000621810600016,True,1,2020


## Make int mapping for edge labels
Importantly, edge labels can't be 0: a label of 0 throws an error during training.

In [6]:
def labels_to_int(row):
    """
    Map an edge's drought/desiccation flags to its integer class label.

    parameters:
        row: mapping-like object (e.g. a DataFrame row) whose 'is_drought'
            and 'is_desiccation' entries are evaluated for truthiness

    returns:
        1 if both flags are set, 2 if only 'is_drought' is set, 3 if only
        'is_desiccation' is set, and None if neither flag is set
    """
    drought = bool(row['is_drought'])
    desiccation = bool(row['is_desiccation'])
    if drought and desiccation:
        return 1
    if drought:
        return 2
    if desiccation:
        return 3
    # Neither flag set: there is no label for this combination.
    return None

In [7]:
# Human-readable names for the integer edge labels; numbering starts at 1.
label_map = dict(both=1, drought_only=2, desiccation_only=3)

In [8]:
# Label every edge; axis=1 hands labels_to_int one row at a time.
int_labels = edgelist.apply(labels_to_int, axis=1)

# labels_to_int returns None for an edge with neither flag set, which would
# end up in the column as a missing value. Fail here instead of carrying
# unlabelled edges into the training file.
n_unlabelled = int(int_labels.isna().sum())
if n_unlabelled:
    raise ValueError(
        f'{n_unlabelled} edge(s) have neither is_drought nor is_desiccation set, '
        'so no label can be assigned'
    )
edgelist['int_label'] = int_labels

In [9]:
# Preview the first five edges, now including the 'int_label' column.
edgelist.head(n=5)

Unnamed: 0,source,target,is_desiccation,uids_of_origin,is_drought,num_doc_mentions_all_time,first_year_mentioned,int_label
0,0,1,False,WOS:000623658100043,True,1,2021,2
1,0,2,False,WOS:000623658100043,True,1,2021,2
2,0,49583,False,WOS:000621810600016,True,1,2020,2
3,0,67147,False,WOS:000621810600016,True,1,2020,2
4,0,313607,False,WOS:000621810600016,True,1,2020,2


## Rename and drop columns for final format

In [10]:
# Keep only the endpoint, time and label columns under their final names, and
# store each edge's row number as an explicit 'idx' column.
edges = (
    edgelist
    .drop(columns=['num_doc_mentions_all_time', 'is_desiccation', 'is_drought', 'uids_of_origin'])
    .rename(columns={'source': 'src', 'target': 'dst', 'int_label': 'label', 'first_year_mentioned': 'time'})
    .reset_index()
    .rename(columns={'index': 'idx'})
)
edges.head()

Unnamed: 0,idx,src,dst,time,label
0,0,0,1,2021,2
1,1,0,2,2021,2
2,2,0,49583,2020,2
3,3,0,67147,2020,2
4,4,0,313607,2020,2


## Train/test split
To do this, we need to sort the edges chronologically and then perform the split. We're going to use the same split that the STHN paper authors used, which is 70/15/15. However, since many edges share the same year, we need to make sure that the splits don't fall in the middle of a year, which would cause data leakage. So we'll get as close as we can to those proportions by placing each cut on the first year boundary past the ideal split size.

In [11]:
# Order edges by year and renumber the rows 0..n-1, so a row's index label
# equals its position in the chronological ordering.
edges = edges.sort_values(by='time', ignore_index=True)
edges.head()

Unnamed: 0,idx,src,dst,time,label
0,779588,21407,108514,1985,2
1,779605,21409,108515,1985,2
2,779606,21409,108516,1985,2
3,779607,21409,56921,1985,2
4,945054,56921,108515,1985,2


In [12]:
# Target number of edges in each partition for a 70/15/15 split.
n_edges = len(edges)
split_fractions = {'train': 0.7, 'validation': 0.15, 'test': 0.15}
ideal_splits = {name: round(fraction * n_edges) for name, fraction in split_fractions.items()}
ideal_splits

{'train': 902450, 'validation': 193382, 'test': 193382}

In [13]:
# Number of edges recorded for each year, as a {year: count} dict.
year_counts = edges.groupby('time')['idx'].count().to_dict()

In [14]:
# Choose the row positions at which the chronologically sorted edges are cut
# into train / validation / test. A cut may only sit on a year boundary, so
# each cut is the first boundary whose preceding row count exceeds the ideal
# cumulative size of the splits before it.
# Assumes year_counts iterates in chronological order.
cumulative_targets = [
    ('train', ideal_splits['train']),
    ('validation', ideal_splits['train'] + ideal_splits['validation']),
]
idx_cuts = {}
tracker = 0  # number of edges in all years before the current one
for year, count in year_counts.items():
    # At most one cut is placed per boundary, so the validation cut always
    # lands on a later boundary than the train cut.
    if cumulative_targets and tracker > cumulative_targets[0][1]:
        split_name, _ = cumulative_targets.pop(0)
        idx_cuts[split_name] = tracker
    tracker += count

# The position after the final year is never a candidate cut, so a target that
# is only exceeded there gets no entry. Report that here rather than letting a
# later lookup fail with a bare KeyError.
if cumulative_targets:
    missing = ', '.join(name for name, _ in cumulative_targets)
    raise ValueError(f'No year boundary found for split cut(s): {missing}')

In [15]:
# Show the chosen cuts: idx_cuts['train'] is the first row after the training
# split and idx_cuts['validation'] is the first row after the validation split.
idx_cuts

{'train': 939296, 'validation': 1155781}

In [16]:
# Assign each edge to a partition: 0 = train, 1 = validation, 2 = test.
# Positional (.iloc) slices are half-open, so the cut positions can be used
# directly: no row is written twice, and the result does not depend on the
# index labels being 0..n-1.
train_end = idx_cuts['train']
validation_end = idx_cuts['validation']
edges['ext_roll'] = 0
split_col = edges.columns.get_loc('ext_roll')
edges.iloc[train_end:validation_end, split_col] = 1
edges.iloc[validation_end:, split_col] = 2

# A year divided between two partitions would leak information across them.
splits_per_year = edges.groupby('time')['ext_roll'].nunique()
assert (splits_per_year == 1).all(), 'a year is divided between two splits'

In [17]:
# Count edges per (partition, year) pair to inspect where the cuts fell.
split_year_counts = edges.groupby(by=['ext_roll', 'time']).count()
split_year_counts

Unnamed: 0_level_0,Unnamed: 1_level_0,idx,src,dst,label
ext_roll,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,1985,11,11,11,11
0,1986,157,157,157,157
0,1987,13,13,13,13
0,1988,20,20,20,20
0,1989,44,44,44,44
0,1990,1332,1332,1332,1332
0,1991,7350,7350,7350,7350
0,1992,9185,9185,9185,9185
0,1993,8838,8838,8838,8838
0,1994,10904,10904,10904,10904


In [18]:
# Preview the first five rows of the table that will be saved.
edges.head(n=5)

Unnamed: 0,idx,src,dst,time,label,ext_roll
0,779588,21407,108514,1985,2,0
1,779605,21409,108515,1985,2,0
2,779606,21409,108516,1985,2,0
3,779607,21409,56921,1985,2,0
4,945054,56921,108515,1985,2,0


## Save

In [19]:
# Write the prepared edge table; index=False leaves the row index out of the file.
output_path = '../data/ml_inputs/sthn_co_occurrence_input_20May2024.csv'
edges.to_csv(output_path, index=False)