In [24]:
#imports 
import pandas as pd
import numpy as np
import os
import pickle as pkl
import datetime as datetime
from sklearn.preprocessing import StandardScaler
import statsmodels.formula.api as sm
import dgl.function as fn
from tqdm import tqdm
import networkx as nx

#imports for graph creation
import torch
from sklearn.preprocessing import StandardScaler
from itertools import combinations
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

#imports for graph learning
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
from tqdm import trange
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms

In [25]:
import pandas as pd
import requests

def get_sitc_codes():
    # URL of the JSON file
    url = 'https://comtradeapi.un.org/files/v1/app/reference/S4.json'

    try:
        # Send a GET request to the URL and fetch the data
        response = requests.get(url)
        response.raise_for_status()  # Check that the request was successful
        
        # Load the JSON data
        data = response.json()

        # Since the JSON data might be nested, use json_normalize with appropriate arguments
        if isinstance(data, list):
            # If the top level is a list
            df = pd.json_normalize(data)
        else:
            # If the top level is a dictionary
            # Identify the key that holds the main data (adjust the path as necessary)
            main_data_key = 'results'  # Adjust this based on the actual structure
            df = pd.json_normalize(data[main_data_key])

    except requests.exceptions.RequestException as e:
        print(f"Error fetching data: {e}")
    except ValueError as e:
        print(f"Error parsing JSON: {e}")
    except KeyError as e:
        print(f"Error processing JSON structure: {e}")

    return df

In [26]:
class TradeNetwork:
    """
    We define a class which computes the MST trade network for a given year 
    """
    
    def __init__(self, year = 1962, data_dir = "data"):
        self.year = year
        self.data_dir = data_dir
        
    def prepare_features(self, filter_gdp = True):
        
        ###IMPORT GDP###
        #prepare GDP as a set of features 
        with open('data/all_wb_indicators.pickle', 'rb') as handle:
            features_dict = pkl.load(handle)

        self.gdp = features_dict['NY.GDP.MKTP.CD']
        scaler = StandardScaler()
        self.gdp[["prev_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year-1)]]))
        self.gdp[["current_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year)]]))
        #rename and keep relevant columns
        self.gdp["country_code"] = self.gdp["economy"]
        self.gdp = self.gdp[["country_code", "prev_gdp", "current_gdp"]].dropna()
        
        ###IMPORT GDP GROWTH###
        #prepare GDP growth
        self.gdp_growth = features_dict['NY.GDP.MKTP.KD.ZG']
        self.gdp_growth["prev_gdp_growth"] = self.gdp_growth['YR'+str(self.year-1)]
        self.gdp_growth["current_gdp_growth"] = self.gdp_growth['YR'+str(self.year)] 
        self.gdp_growth["future_gdp_growth"] = self.gdp_growth['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_growth["country_code"] = self.gdp_growth["economy"]
        self.gdp_growth = self.gdp_growth[["country_code", "prev_gdp_growth",
                                "current_gdp_growth", "future_gdp_growth"]].dropna()
        
        ###IMPORT GDP PER CAPITA###
        self.gdp_per_capita = features_dict['NY.GDP.PCAP.CD']
        self.gdp_per_capita["prev_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year-1)]
        self.gdp_per_capita["current_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year)]
        self.gdp_per_capita["future_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_per_capita["country_code"] = self.gdp_per_capita["economy"]
        self.gdp_per_capita = self.gdp_per_capita[["country_code", "prev_gdp_per_cap",
                                "current_gdp_per_cap", "future_gdp_per_cap"]].dropna()
        
        ###IMPORT GDP PER CAPITA GROWTH###
        self.gdp_per_capita_growth = features_dict['NY.GDP.PCAP.KD.ZG']
        self.gdp_per_capita_growth["prev_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year-1)]
        self.gdp_per_capita_growth["current_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year)]
        self.gdp_per_capita_growth["future_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year+1)]
        
        #rename and keep relevant columns
        self.gdp_per_capita_growth["country_code"] = self.gdp_per_capita_growth["economy"]
        self.gdp_per_capita_growth = self.gdp_per_capita_growth[["country_code", "prev_gdp_per_cap_growth",
                                "current_gdp_per_cap_growth", "future_gdp_per_cap_growth"]].dropna()
        
        ###MERGE ALL DATA FEATURES###
        self.features = pd.merge(self.gdp_growth, self.gdp, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita_growth, on = "country_code").dropna()

    def prepare_network(self):
        """
        We create an initial, import-centric trade link pandas dataframe for a given year
        """
        #get product codes
        data_dict = get_sitc_codes()
        data_cross = []
        i = 0
        for item_def in list(data_dict["text"]):
            if(i >= 2):
                data_cross.append(item_def.split(" - ", 1))
            i = i+1

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]
        
        #get country codes
        self.country_codes = pd.read_excel("data/ISO3166.xlsx")
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]
        
        #get trade data for a given year
        trade_data = pd.read_stata(self.data_dir + "/country_partner_sitcproduct4digit_year_"+ str(self.year)+".dta") 
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]], 
                              on = ["sitc_product_code"])
        ###select level of product aggregation
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])
        
        #keep only nodes that we have features for
        #trade_data = trade_data[trade_data["location_code"].isin(self.features["country_code"])]
        #trade_data = trade_data[trade_data["partner_code"].isin(self.features["country_code"])]
        
        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data
        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports1 = imports1[imports1["import_value"] != 0]
        #transform records of exports into imports
        imports2 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'export_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports2["temp1"] = imports2['partner_code']
        imports2["temp2"] = imports2['location_code']

        imports2['location_code'] = imports2["temp1"]
        imports2['partner_code'] = imports2["temp2"]
        imports2["import_value"] = imports2["export_value"]
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        
        imports_table = pd.concat([imports1, imports2]).drop_duplicates()
        
        #rename columns for better clarity
        imports_table["importer_code"] = imports_table["location_code"]
        imports_table["exporter_code"] = imports_table["partner_code"]
        imports_table["importer_name"] = imports_table["country_i"]
        imports_table["exporter_name"] = imports_table["country_j"]
        
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table[cols]
        
        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})
        
        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})
        
        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        
        #sum exports in each category 
        self.export_types = imports_table.groupby(["exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(self.export_types, exporter_total, on = "exporter_code")
        #multiply by 100 to allow weights to scale better in GNN
        self.export_types["category_fraction"] = self.export_types.import_value/self.export_types.export_total*10
        ss = StandardScaler()
        columns = list(set(self.export_types["product_category"]))
        self.export_types = self.export_types[["exporter_code", "product_category", "category_fraction"]]\
        .pivot(index = ["exporter_code"], columns = ["product_category"], values = "category_fraction")\
        .reset_index().fillna(0)
        #rename columns
        rename_columns = []
        for col in self.export_types.columns:
            if(col == "exporter_code"):
                rename_columns.append(col)
            else:
                rename_columns.append("resource_" + col)
        self.export_types.columns = rename_columns
        self.export_types = self.export_types.rename(columns = {"exporter_code": "country_code"})
        self.features = pd.merge(self.features, self.export_types, 
                                on = "country_code", how = "left")
        
        #look at fraction of goods traded with each counterparty
        imports_table_grouped = pd.merge(imports_table_grouped, exporter_total, how = "left")
        imports_table_grouped["export_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["export_total"]
        scaler = StandardScaler()
        imports_table_grouped[["export_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["export_percent"]]))
        imports_table_grouped["export_percent_feature"] = imports_table_grouped["export_percent_feature"] + abs(min(imports_table_grouped["export_percent_feature"]))
        
        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, how = "left")
        imports_table_grouped["import_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["import_total"]
        scaler = StandardScaler()
        imports_table_grouped[["import_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["import_percent"]]))
        imports_table_grouped["import_percent_feature"] = imports_table_grouped["import_percent_feature"] + abs(min(imports_table_grouped["import_percent_feature"]))
        
        self.trade_data = imports_table_grouped

    def graph_create(self, exporter = True,
            node_features = ['prev_gdp_growth', 'current_gdp_growth','prev_gdp','current_gdp'],
            node_labels = 'future_gdp_growth'):
        
        if(exporter):
            center_node = "exporter_code"
            neighbors = "importer_code"
            edge_features = 'export_percent'
        
        #filter features and nodes to ones that are connected to others in trade data
        # list_active_countries = list(set(list(self.trade_data ["importer_code"])+\
        #                 list(self.trade_data ["exporter_code"])))

        list_active_countries = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

        # Create a new DataFrame with the list of country codes
        df_new = pd.DataFrame(list_active_countries, columns=['country_code'])
        self.features = pd.merge(df_new, self.features, on='country_code', how='left')
        self.features = self.features.fillna(0)

        # self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        # self.features.fillna(0, inplace = True)
        self.features["node_numbers"] = self.features.index

        #create lookup dictionary making node number / node features combatible with ordering of nodes
        #in our edge table

        self.node_lookup1 = self.features.set_index('node_numbers').to_dict()['country_code']
        self.node_lookup2 = self.features.set_index('country_code').to_dict()['node_numbers']
        
        #get individual country's features
        self.regression_table = pd.merge(self.features, self.trade_data,
                        left_on = "country_code",
                        right_on = center_node, how = 'right')
        #get features for trade partners
        self.regression_table = pd.merge(self.features, self.regression_table,
                                        left_on = "country_code",
                                        right_on = neighbors, how = "right",
                                        suffixes = ("_neighbors", ""))
        
        self.trade_data = self.trade_data[self.trade_data[neighbors].isin(self.node_lookup2)]
        self.trade_data = self.trade_data[self.trade_data[center_node].isin(self.node_lookup2)]

        self.regression_table["source"] = self.trade_data[neighbors].apply(lambda x: self.node_lookup2[x])
        self.regression_table["target"] = self.trade_data[center_node].apply(lambda x: self.node_lookup2[x])    

        self.regression_table = self.regression_table.dropna()
        #filter only to relevant columns
        self.relevant_columns = ["source", "target"]
        self.relevant_columns.extend(node_features)
        self.relevant_columns.append(node_labels)
        self.graph_table = self.regression_table[self.relevant_columns]
        
        if(self.graph_table.isnull().values.any()): print("edges contain null / inf values")

        self.node_attributes = torch.tensor(np.array(self.features[node_features]))\
        .to(torch.float)
        self.source_nodes = list(self.graph_table["source"])
        self.target_nodes = list(self.graph_table["target"])

        self.edge_attributes = list(self.trade_data[edge_features])
        
        self.pyg_graph = data.Data(x = self.node_attributes,
                                   edge_index = torch.tensor([self.source_nodes, self.target_nodes]),
                                   edge_attr = torch.tensor(self.edge_attributes).to(torch.float),
                                   y = torch.tensor(list(self.features[node_labels])).to(torch.float))

## All Events

In [27]:
years = range(1962,2019)

train_years = [2005, 1969, 2002, 1997, 1993, 1982, 2001, 2000, 1962, 1985, 1978, 2016, 1986, 1987, 1989, 1971, 2013, 1996, 1995, 1967, 2017, 1974, 1990, 1977, 1980, 2014, 1965, 1984, 2006, 1973, 1968, 1981, 1970, 1991]
val_years = [1975, 1983, 2009, 1966, 1999, 1988, 2007, 1979, 1972, 2015, 2003]
test_years = [1963, 1964, 1976, 1992, 1994, 1998, 2004, 2008, 2010, 2011, 2012, 2018]

In [28]:
train_graphs = []
val_graphs = []
test_graphs = []
i = 0

for year in tqdm(years):
    print(str(year), end='\r')
    
    trade = TradeNetwork(year = year)
    trade.prepare_features()
    trade.prepare_network()
    trade.graph_create(node_features = ['prev_gdp_per_cap_growth', 'current_gdp_per_cap_growth',
    'resource_0', 'resource_1', 'resource_2', 'resource_3', 'resource_4', 'resource_5', 'resource_6', 'resource_7',
       'resource_8', 'resource_9'],
        node_labels = 'future_gdp_per_cap_growth')
    
    if(year in val_years):
        val_graphs.append(trade.pyg_graph)
    elif(year in test_years):
        test_graphs.append(trade.pyg_graph)
    else: 
        train_graphs.append(trade.pyg_graph)
        
    trade.features["year"] = year
    
    if(i == 0):
        trade_df = trade.features
    else: 
        trade_df = pd.concat([trade_df, trade.features])
        
    i = i+1
    print(trade.node_attributes.size())


  0%|          | 0/57 [00:00<?, ?it/s]

1962

  2%|▏         | 1/57 [00:05<05:34,  5.98s/it]

torch.Size([199, 12])
1963

  4%|▎         | 2/57 [00:11<05:14,  5.71s/it]

torch.Size([199, 12])
1964

  5%|▌         | 3/57 [00:17<05:07,  5.69s/it]

torch.Size([199, 12])
1965

  7%|▋         | 4/57 [00:23<05:08,  5.82s/it]

torch.Size([199, 12])
1966

  9%|▉         | 5/57 [00:29<05:09,  5.96s/it]

torch.Size([199, 12])
1967

 11%|█         | 6/57 [00:35<05:06,  6.02s/it]

torch.Size([199, 12])
1968

 12%|█▏        | 7/57 [00:42<05:09,  6.20s/it]

torch.Size([199, 12])
1969

 14%|█▍        | 8/57 [00:48<05:10,  6.34s/it]

torch.Size([199, 12])
1970

 16%|█▌        | 9/57 [00:55<05:16,  6.59s/it]

torch.Size([199, 12])
1971

 18%|█▊        | 10/57 [01:03<05:27,  6.98s/it]

torch.Size([199, 12])
1972

 19%|█▉        | 11/57 [01:11<05:35,  7.30s/it]

torch.Size([199, 12])
1973

 21%|██        | 12/57 [01:20<05:50,  7.79s/it]

torch.Size([199, 12])
1974

 23%|██▎       | 13/57 [01:29<06:00,  8.19s/it]

torch.Size([199, 12])
1975

 25%|██▍       | 14/57 [01:38<06:05,  8.50s/it]

torch.Size([199, 12])
1976

 26%|██▋       | 15/57 [01:43<05:00,  7.17s/it]

torch.Size([199, 12])
1977

 28%|██▊       | 16/57 [01:48<04:26,  6.50s/it]

torch.Size([199, 12])
1978

 30%|██▉       | 17/57 [01:54<04:19,  6.48s/it]

torch.Size([199, 12])
1979

 32%|███▏      | 18/57 [02:01<04:13,  6.51s/it]

torch.Size([199, 12])
1980

 33%|███▎      | 19/57 [02:08<04:17,  6.78s/it]

torch.Size([199, 12])
1981

 35%|███▌      | 20/57 [02:15<04:14,  6.87s/it]

torch.Size([199, 12])
1982

 37%|███▋      | 21/57 [02:22<04:07,  6.87s/it]

torch.Size([199, 12])
1983

 39%|███▊      | 22/57 [02:29<04:02,  6.92s/it]

torch.Size([199, 12])
1984

 40%|████      | 23/57 [02:36<03:55,  6.92s/it]

torch.Size([199, 12])
1985

 42%|████▏     | 24/57 [02:43<03:53,  7.09s/it]

torch.Size([199, 12])
1986

 44%|████▍     | 25/57 [02:52<03:59,  7.49s/it]

torch.Size([199, 12])
1987

 46%|████▌     | 26/57 [03:01<04:08,  8.02s/it]

torch.Size([199, 12])
1988

 47%|████▋     | 27/57 [03:10<04:10,  8.35s/it]

torch.Size([199, 12])
1989

 49%|████▉     | 28/57 [03:19<04:10,  8.65s/it]

torch.Size([199, 12])
1990

 51%|█████     | 29/57 [03:29<04:11,  8.99s/it]

torch.Size([199, 12])
1991

 53%|█████▎    | 30/57 [03:40<04:15,  9.48s/it]

torch.Size([199, 12])
1992

 54%|█████▍    | 31/57 [03:51<04:16,  9.87s/it]

torch.Size([199, 12])
1993

 56%|█████▌    | 32/57 [04:02<04:16, 10.27s/it]

torch.Size([199, 12])
1994

 58%|█████▊    | 33/57 [04:14<04:16, 10.69s/it]

torch.Size([199, 12])
1995

 60%|█████▉    | 34/57 [04:25<04:13, 11.03s/it]

torch.Size([199, 12])
1996

 61%|██████▏   | 35/57 [04:39<04:18, 11.73s/it]

torch.Size([199, 12])
1997

 63%|██████▎   | 36/57 [04:52<04:18, 12.29s/it]

torch.Size([199, 12])
1998

 65%|██████▍   | 37/57 [05:06<04:14, 12.73s/it]

torch.Size([199, 12])
1999

 67%|██████▋   | 38/57 [05:21<04:12, 13.28s/it]

torch.Size([199, 12])
2000

 68%|██████▊   | 39/57 [05:37<04:15, 14.20s/it]

torch.Size([199, 12])
2001

 70%|███████   | 40/57 [05:54<04:18, 15.18s/it]

torch.Size([199, 12])
2002

 72%|███████▏  | 41/57 [06:09<04:01, 15.11s/it]

torch.Size([199, 12])
2003

 74%|███████▎  | 42/57 [06:25<03:48, 15.25s/it]

torch.Size([199, 12])
2004

 75%|███████▌  | 43/57 [06:42<03:39, 15.70s/it]

torch.Size([199, 12])
2005

 77%|███████▋  | 44/57 [07:00<03:35, 16.61s/it]

torch.Size([199, 12])
2006

 79%|███████▉  | 45/57 [07:19<03:25, 17.12s/it]

torch.Size([199, 12])
2007

 81%|████████  | 46/57 [07:37<03:11, 17.38s/it]

torch.Size([199, 12])
2008

 82%|████████▏ | 47/57 [07:57<03:01, 18.12s/it]

torch.Size([199, 12])
2009

 84%|████████▍ | 48/57 [08:16<02:47, 18.65s/it]

torch.Size([199, 12])
2010

 86%|████████▌ | 49/57 [08:36<02:31, 18.90s/it]

torch.Size([199, 12])
2011

 88%|████████▊ | 50/57 [08:56<02:14, 19.14s/it]

torch.Size([199, 12])
2012

 89%|████████▉ | 51/57 [09:16<01:56, 19.46s/it]

torch.Size([199, 12])
2013

 91%|█████████ | 52/57 [09:37<01:40, 20.00s/it]

torch.Size([199, 12])
2014

 93%|█████████▎| 53/57 [09:59<01:21, 20.47s/it]

torch.Size([199, 12])
2015

 95%|█████████▍| 54/57 [10:19<01:01, 20.43s/it]

torch.Size([199, 12])
2016

 96%|█████████▋| 55/57 [10:40<00:41, 20.54s/it]

torch.Size([199, 12])
2017

 98%|█████████▊| 56/57 [11:00<00:20, 20.38s/it]

torch.Size([199, 12])
2018

100%|██████████| 57/57 [11:22<00:00, 11.98s/it]

torch.Size([199, 12])





# Graph Setup

In [4]:
# import random
# years = list(range(1962,2021))

# # Determine the sizes of the train, validation, and test sets
# train_size = int(len(years) * 0.6)
# val_size = int(len(years) * 0.20)

# # Get a random subset for the train set
# train_years = random.sample(years, train_size)

# # Remove the years in the train set from the list of years
# years = [year for year in years if year not in train_years]

# # Get a random subset for the validation set
# val_years = random.sample(years, val_size)

# # Remove the years in the validation set from the list of years
# test_years = [year for year in years if year not in val_years]

# print("Train years:", train_years)
# print("Validation years:", val_years)
# print("Test years:", test_years)

In [29]:
import pickle as pkl
with open("pygcn/train_graphs.pickle", "wb") as f:
    pkl.dump(train_graphs, f)

with open("pygcn/val_graphs.pickle", "wb") as f:
    pkl.dump(val_graphs, f)

with open("pygcn/test_graphs.pickle", "wb") as f:
    pkl.dump(test_graphs, f)

In [114]:
import pickle as pkl
with open("pygcn/train_graphs.pickle", "rb") as f:
    train_graphs = pkl.load(f)

with open("pygcn/val_graphs.pickle", "rb") as f:  
    val_graphs = pkl.load(f)

with open("pygcn/test_graphs.pickle", "rb") as f:         
    test_graphs = pkl.load(f)

In [115]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=4)
train_loader = DataLoader(train_graphs, batch_size=4)
val_loader = DataLoader(val_graphs, batch_size=4)

## sGNN with GCN Encoder and 3 Features

In [116]:
def check_crisis_years(year_pairs, crisis_years):
    result = []
    for pair in year_pairs:
        start, end = pair
        # Check if any crisis year is between the pair or equals the later year
        if any(start < year <= end for year in crisis_years):
            result.append(0)
        else:
            result.append(1)
    return result

In [117]:
crisis_years = [1983, 1982, 2008, 2002, 2016, 1967, 1962, 1989, 2012, 1963, 1993, 1986, 1996,1978]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [118]:
train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

In [119]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv, global_sort_pool
from torch_geometric.nn.aggr import SortAggregation
from torch.nn import Linear, LayerNorm, ReLU, Sigmoid
from torch.nn import Linear, BatchNorm1d, ReLU
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch_geometric.nn import global_add_pool


class GNN(torch.nn.Module):
    def __init__(self, num_features):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, 64)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index.to(torch.int64)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x
    
class FullyConnectedLayer(torch.nn.Module):
    def __init__(self, inner_dim, outer_dim):
        super(FullyConnectedLayer, self).__init__()
        self.affine = Linear(inner_dim, outer_dim)
        self.batch_norm = BatchNorm1d(outer_dim)
        self.relu = ReLU()

    def forward(self, out):
        x = self.affine(out)
        print(x)
        x = self.batch_norm(x)
        return self.relu(x)

class SiameseGNN(torch.nn.Module):
    def __init__(self, num_features):
        super(SiameseGNN, self).__init__()
        self.gnn = GNN(num_features)
        self.sort_aggr = SortAggregation(k=50)
        self.fc1 = FullyConnectedLayer(9950, 64)
        self.fc2 = FullyConnectedLayer(64, 1)
        self.sigmoid = Sigmoid()

    def forward(self, data1, data2):
        out1 = self.gnn(data1)
        out2 = self.gnn(data2)
        out = torch.cdist(out1, out2, p=2) #Euclidean Distance
        out = self.sort_aggr(out, data1.batch) #Sort-K Pooling Layer
        out = out.view(out.size(0), -1)  # Flatten the pooled output
        out = self.fc1(out)
        out = self.fc2(out)
        out = global_add_pool(out)

        return self.sigmoid(out)

In [126]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv
from torch_geometric.nn.aggr import SortAggregation
from torch.nn import Linear, BatchNorm1d, ReLU, Sigmoid
from torch_geometric.nn import global_add_pool
from tqdm import tqdm

class GNN(torch.nn.Module):
    def __init__(self, num_features):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, 64)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index.to(torch.int64)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x
    
class FullyConnectedLayer(torch.nn.Module):
    def __init__(self, inner_dim, outer_dim):
        super(FullyConnectedLayer, self).__init__()
        self.affine = Linear(inner_dim, outer_dim)
        self.batch_norm = BatchNorm1d(outer_dim)
        self.relu = ReLU()

    def forward(self, out):
        x = self.affine(out)
        if len(x.size()) == 3:  # Ensure the input is 2D
            x = x.view(x.size(0) * x.size(1), -1)
        print(x)
        x = self.batch_norm(x)
        return self.relu(x)

class SiameseGNN(torch.nn.Module):
    def __init__(self, num_features):
        super(SiameseGNN, self).__init__()
        self.gnn = GNN(num_features)
        self.sort_aggr = SortAggregation(k=50)
        self.fc1 = FullyConnectedLayer(9950, 64)  # Adjust based on correct input size
        self.fc2 = FullyConnectedLayer(64, 1)
        self.sigmoid = Sigmoid()

    def forward(self, data1, data2):
        out1 = self.gnn(data1)
        out2 = self.gnn(data2)

        # Compute pairwise distances
        out = torch.cdist(out1, out2, p=2) # Euclidean Distance
        
        # Sort-K Pooling Layer
        out = self.sort_aggr(out, data1.batch)
        
        # Ensure the output shape is suitable for the FullyConnectedLayer
        out = out.view(out.size(0), -1)  # Flatten the pooled output

        out = self.fc1(out)
        out = self.fc2(out)

        # Ensure the shape is correct before passing to global_add_pool
        out = out.view(out.size(0), -1)
        out = global_add_pool(out, data1.batch)

        return self.sigmoid(out)


train_loader = train_loader_pairs
val_loader = val_loader_pairs

model = SiameseGNN(num_features=12)  # Adjust the number of features accordingly

optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
criterion = nn.BCELoss()

for epoch in tqdm(range(5)):
    model.train()
    train_losses = []
    for i in range(len(train_loader)):

        optimizer.zero_grad()
        out = model(train_loader[i][0], train_loader[i][1])
        loss = criterion(out.squeeze(), train_torch_y[i].to(torch.float))
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()  # Add this line to update the learning rate

    # Validation
    model.eval()
    with torch.no_grad():
        val_losses = []
        correct = 0
        total = 0
        for i in range(len(val_loader)):
            out = model(val_loader[i][0], val_loader[i][1])
            val_loss = criterion(out.squeeze(), val_torch_y[i].to(torch.float))
            val_losses.append(val_loss.item())

            predictions = torch.round(out.squeeze())
            correct += (predictions == val_torch_y[i]).sum().item()
            total += len(val_torch_y[i])

        val_loss = sum(val_losses) / len(val_losses)
        val_accuracy = correct / total

    print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[-0.1336, -1.6735,  0.7404,  1.5737, -0.1376,  0.9168,  0.1814,  1.0988,
          0.7630, -0.2472, -0.5239,  0.5975, -3.2576, -0.1216,  1.7809, -0.4495,
          0.6621, -0.2564,  2.1918, -1.1255, -0.1888,  2.2564, -0.7542,  0.3808,
         -0.6299,  1.7575,  1.6498, -0.5258,  1.5399, -0.3623,  0.6560,  0.6126,
          0.7954,  0.5136, -3.1180, -0.3158,  0.6231, -1.4749, -0.2006, -0.4121,
          0.5074,  1.9828,  1.5758, -1.2261, -1.4405, -2.9652, -0.9406, -1.4045,
         -0.4538, -0.7102,  1.7911, -1.7468,  0.7116, -0.0946, -0.6196,  0.5326,
         -2.2619, -0.7477, -0.9208,  0.5321, -1.1388, -0.5202, -0.3297,  1.0752]],
       grad_fn=<AddmmBackward0>)





ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 64])

In [104]:
train_loader[0][0].batch

In [99]:
# DataLoader that loads pairs of graphs
train_loader = train_loader_pairs
val_loader = val_loader_pairs

model = SiameseGNN(num_features=train_loader[0][0].num_node_features)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
criterion = nn.BCELoss()

for epoch in tqdm(range(5)):
    model.train()
    train_losses = []
    for i in range(len(train_loader)):

        optimizer.zero_grad()
        print(train_loader[i][0])
        out = model(train_loader[i][0], train_loader[i][1])
        loss = criterion(out.squeeze(), train_torch_y[i].to(torch.float))
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()  # Add this line to update the learning rate

    # Validation
    model.eval()
    with torch.no_grad():
        val_losses = []
        correct = 0
        total = 0
        for i in range(len(val_loader)):
            out = model(val_loader[i][0], val_loader[i][1])
            val_loss = criterion(out.squeeze(), val_torch_y[i].to(torch.float))
            val_losses.append(val_loss.item())

            predictions = torch.round(out.squeeze())
            correct += (predictions == val_torch_y[i]).sum().item()
            total += 1

        val_loss = sum(val_losses) / len(val_losses)
        val_accuracy = correct / total

    print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')


  0%|          | 0/5 [00:00<?, ?it/s]

Data(x=[199, 12], edge_index=[2, 4420], edge_attr=[4420], y=[199])
tensor([[ 1.0423,  1.1136, -0.1196, -0.0634,  0.7821,  2.1531,  1.9750,  0.7092,
         -3.4137,  0.3919, -1.0351,  0.7561, -0.9224,  0.0053,  0.0116,  2.2491,
          1.3743, -1.7310,  0.3565,  1.1817, -1.2984,  0.8230,  1.2342, -2.7720,
          0.0072,  1.2638,  0.5268,  1.4335,  1.6768,  2.1548,  2.0185,  1.0940,
          2.7574,  0.0113,  0.1723, -0.0854,  1.2454, -1.4721,  1.2925,  0.2090,
          0.4599, -0.9316,  0.3393,  0.6429,  0.1661, -0.0868,  1.2037, -2.8661,
         -1.0396,  0.0494,  0.7910, -1.6080,  1.3443,  0.3441, -2.6085, -0.1864,
         -0.9635, -0.5269, -0.6951,  2.5409,  0.0445, -0.0707, -2.2662,  0.2011]],
       grad_fn=<AddmmBackward0>)





ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 64])

## sGNN with Feature Subset

In [7]:
class CreateFeatures:
    """
    We define a class which builds the feature dataframe 
    """
    
    def __init__(self, year = 1962, data_dir = "../data/"):
        self.year = year
        self.data_dir = data_dir
        
    def prepare_econ_features(self, filter_gdp = True):
        
        #DATA IMPORT
        #import dictionary with all features from WB
        with open(self.data_dir + 'all_wb_indicators.pickle', 'rb') as handle:
            features_dict = pkl.load(handle)
            
        self.feature_list = list(features_dict.keys())[1:]
        #import list of all features we want to select for

        #look up each of the features -- add country feature in that year 
        i = 0
        for feature in self.feature_list:
            #find dataframe corresponding to specific feature name
            df = features_dict[feature]
            
            if (i == 0):
                self.features = df[["economy", "YR" + str(self.year)]]
            else: 
                self.features = pd.merge(self.features, 
                                            df[["economy", "YR" + str(self.year)]],
                                            on = "economy", how = "outer")
            self.features.rename(columns = {"YR" + str(self.year): feature}, inplace = True)
            i = i+1
        
        #prepare GDP feature
        self.gdp_growth = features_dict['NY.GDP.MKTP.KD.ZG']
        cols = list(self.gdp_growth.columns.copy())
        cols.remove("economy")
        self.gdp_growth["country_sd"] = self.gdp_growth[cols].std(axis=1)
        #select potential variables 
        self.gdp_growth["prev_gdp_growth"] = self.gdp_growth["YR" + str(self.year-1)]
        self.gdp_growth["current_gdp_growth"] = self.gdp_growth["YR" + str(self.year)]
        #we eliminate countries that are too volatile in growth -- probably an indicator that growth estimates are inaccurate
        self.gdp_growth = self.gdp_growth[["economy", "prev_gdp_growth",
                                "current_gdp_growth"]].dropna()
        
        #combine GDP and other features
        self.features = pd.merge(self.gdp_growth, self.features,
                                   on = "economy", how = "left")
        #we only keep countries where we observe GDP growth -- otherwise nothing to predict
        #we keep countries where other features may be missing -- and fill NAs with 0 
        self.features.rename(columns = {"economy": "country_code"}, inplace = True)
        
    def prepare_network_features(self):
        """
        We create an initial, import-centric trade link pandas dataframe for a given year
        """
        #get product codes
        data_dict = get_sitc_codes()
        data_cross = []
        i = 0
        for item_def in list(data_dict["text"]):
            if(i >= 2):
                data_cross.append(item_def.split(" - ", 1))
            i = i+1

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]
        
        #get country codes
        self.country_codes = pd.read_excel(self.data_dir + "ISO3166.xlsx")
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]
        
        #get trade data for a given year
        trade_data = pd.read_stata(self.data_dir + "country_partner_sitcproduct4digit_year_"+ str(self.year)+".dta") 
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]], 
                              on = ["sitc_product_code"])
        ###select level of product aggregation
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])
        #trade_data = trade_data[trade_data["product_category"] == "1"]
        
        #keep only nodes that we have features for
        trade_data = trade_data[trade_data["location_code"].isin(self.features["country_code"])]
        trade_data = trade_data[trade_data["partner_code"].isin(self.features["country_code"])]
        
        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data
        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports1 = imports1[imports1["import_value"] != 0]
        #transform records of exports into imports
        imports2 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'export_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports2["temp1"] = imports2['partner_code']
        imports2["temp2"] = imports2['location_code']

        imports2['location_code'] = imports2["temp1"]
        imports2['partner_code'] = imports2["temp2"]
        imports2["import_value"] = imports2["export_value"]
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        
        imports_table = pd.concat([imports1, imports2]).drop_duplicates()
        
        #rename columns for better clarity
        imports_table["importer_code"] = imports_table["location_code"]
        imports_table["exporter_code"] = imports_table["partner_code"]
        imports_table["importer_name"] = imports_table["country_i"]
        imports_table["exporter_name"] = imports_table["country_j"]
        
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table[cols]
        
        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})
        
        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})
        
        ##### COMPUTE CENTRALITY FOR COUNTRY
        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, on = "importer_code")
        imports_table_grouped["import_fraction"] = imports_table_grouped["import_value"]\
                        /imports_table_grouped["import_total"]*100
        
        self.trade_data = imports_table_grouped
        
        #filter features and nodes to ones that are connected to others in trade data
        list_active_countries = list(set(list(self.trade_data ["importer_code"])+\
                        list(self.trade_data ["exporter_code"])))
        self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        self.features["node_numbers"] = self.features.index
        
        G=nx.from_pandas_edgelist(self.trade_data, 
                          "exporter_code", "importer_code", create_using = nx.DiGraph())
        
        self.G = G
        self.centrality_overall= nx.eigenvector_centrality(G, max_iter= 10000) 
        self.centrality_overall = pd.DataFrame(list(map(list, self.centrality_overall.items())), 
                                               columns = ["country_code", "centrality_overall"])
        G=nx.from_pandas_edgelist(self.trade_data, 
                          "exporter_code", "importer_code", ["import_fraction"])
        weighted_centrality = nx.eigenvector_centrality(G, weight = "import_fraction", max_iter= 10000) 
        weighted_centrality  = pd.DataFrame(list(map(list, weighted_centrality.items())), 
                                               columns = ["country_code", "weighted_centrality"])
        self.centrality_overall = pd.merge(self.centrality_overall, weighted_centrality, on = "country_code")
        
                               
        ##### COMPUTE CENTRALITY FOR COUNTRY IN PRODUCT CATEGORIES

        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        products_grouped = imports_table.groupby(["product_category"])["import_value"].sum().reset_index()
        products_grouped = products_grouped.rename(columns = {"import_value": "import_product_total"})
        
        #sum exports in each category 
        self.export_types = imports_table.groupby(["importer_code", "exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(products_grouped, self.export_types, on = "product_category")
        
        self.export_types["product_export_fraction"] = self.export_types["import_value"]\
                                                    /self.export_types["import_product_total"]*100
        
        list_products = list(set(self.export_types["product_category"]))
        
        i = 0 
        for product in list_products:
            
            temp = self.export_types[self.export_types["product_category"] == product].copy()
            
            G_w=nx.from_pandas_edgelist(temp, 
                "exporter_code", "importer_code", ["product_export_fraction"], create_using = nx.DiGraph())
            centrality_product_w = nx.eigenvector_centrality(G_w, weight = "product_export_fraction", 
                                                           max_iter= 10000)

            G=nx.from_pandas_edgelist(temp,"exporter_code", "importer_code", create_using = nx.DiGraph())
            centrality_product = nx.eigenvector_centrality(G,max_iter= 10000)

            if(i == 0):
                self.centrality_product = pd.DataFrame(list(map(list, centrality_product.items())), 
                                               columns = ["country_code", "prod_" + product])
                

            else: 
                self.centrality_product = pd.merge(self.centrality_product, 
                                               pd.DataFrame(list(map(list, centrality_product.items())), 
                                               columns = ["country_code", "prod_" + product]), 
                                                  on = "country_code")
                
            self.centrality_product = pd.merge(self.centrality_product, 
                                               pd.DataFrame(list(map(list, centrality_product_w.items())), 
                                               columns = ["country_code", "prod_w_" + product]), 
                                                  on = "country_code")
            
            i = i+1         
    
    def combine_normalize_features(self):
        
        self.combined_features = pd.merge(self.features, self.centrality_overall,on = "country_code")
        self.combined_features = pd.merge(self.combined_features, self.centrality_product,on = "country_code")
        #step eliminates NA and nodes that are not in graph, since they will have NA for graph features
        self.combined_features = self.combined_features.drop(columns = ["index"])
        #filter both trade data and features data to same subset of countries
        self.combined_features = self.combined_features[\
                                self.combined_features.country_code.isin(self.trade_data.importer_code)|\
                                self.combined_features.country_code.isin(self.trade_data.exporter_code)]
        self.trade_data = self.trade_data[\
                          self.trade_data.importer_code.isin(self.combined_features.country_code)&\
                          self.trade_data.exporter_code.isin(self.combined_features.country_code)]
        
        features_to_norm = list(self.combined_features.columns.copy())
        non_norm = ["country_code", "node_numbers"]
        cols_insufficient_data = list(self.combined_features.loc[:, self.combined_features.nunique() < 2].columns.copy())
        non_norm.extend(cols_insufficient_data)
 
        features_to_norm = [x for x in features_to_norm if x not in non_norm]
        scaler = StandardScaler()
        #we preserve NAs in the scaling
        self.combined_features[features_to_norm] = scaler.fit_transform(self.combined_features[features_to_norm])
        self.combined_features.fillna(0, inplace = True) #we fill NA after scaling 
        #check that feature has at least 20% coverage in a given year -- otherwise set to NA
        for feature in self.feature_list:
            coverage = len(self.combined_features[self.combined_features[feature] != 0])/len(self.combined_features)
            if(coverage <= 0.20): self.combined_features[feature] = 0

In [10]:
class TradeNetwork:
    """
    We define a class which computes the MST trade network for a given year 
    """
    
    def __init__(self, year = 1962, data_dir = "data"):
        self.year = year
        self.data_dir = data_dir
        
    def prepare_features(self, filter_gdp = True):
        
        ###IMPORT GDP###
        #prepare GDP as a set of features 
        with open('data/all_wb_indicators.pickle', 'rb') as handle:
            features_dict = pkl.load(handle)

        self.gdp = features_dict['NY.GDP.MKTP.CD']
        scaler = StandardScaler()
        self.gdp[["prev_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year-1)]]))
        self.gdp[["current_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year)]]))
        #rename and keep relevant columns
        self.gdp["country_code"] = self.gdp["economy"]
        self.gdp = self.gdp[["country_code", "prev_gdp", "current_gdp"]].dropna()
        
        ###IMPORT GDP GROWTH###
        #prepare GDP growth
        self.gdp_growth = features_dict['NY.GDP.MKTP.KD.ZG']
        self.gdp_growth["prev_gdp_growth"] = self.gdp_growth['YR'+str(self.year-1)]
        self.gdp_growth["current_gdp_growth"] = self.gdp_growth['YR'+str(self.year)] 
        self.gdp_growth["future_gdp_growth"] = self.gdp_growth['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_growth["country_code"] = self.gdp_growth["economy"]
        self.gdp_growth = self.gdp_growth[["country_code", "prev_gdp_growth",
                                "current_gdp_growth", "future_gdp_growth"]].dropna()
        
        ###IMPORT GDP PER CAPITA###
        self.gdp_per_capita = features_dict['NY.GDP.PCAP.CD']
        self.gdp_per_capita["prev_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year-1)]
        self.gdp_per_capita["current_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year)]
        self.gdp_per_capita["future_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_per_capita["country_code"] = self.gdp_per_capita["economy"]
        self.gdp_per_capita = self.gdp_per_capita[["country_code", "prev_gdp_per_cap",
                                "current_gdp_per_cap", "future_gdp_per_cap"]].dropna()
        
        ###IMPORT GDP PER CAPITA GROWTH###
        self.gdp_per_capita_growth = features_dict['NY.GDP.PCAP.KD.ZG']
        self.gdp_per_capita_growth["prev_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year-1)]
        self.gdp_per_capita_growth["current_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year)]
        self.gdp_per_capita_growth["future_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year+1)]
        
        #rename and keep relevant columns
        self.gdp_per_capita_growth["country_code"] = self.gdp_per_capita_growth["economy"]
        self.gdp_per_capita_growth = self.gdp_per_capita_growth[["country_code", "prev_gdp_per_cap_growth",
                                "current_gdp_per_cap_growth", "future_gdp_per_cap_growth"]].dropna()
        
        ###MERGE ALL DATA FEATURES###
        self.features = pd.merge(self.gdp_growth, self.gdp, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita_growth, on = "country_code").dropna()

    def prepare_network(self):
        """
        We create an initial, import-centric trade link pandas dataframe for a given year
        """
        #get product codes
        data_dict = get_sitc_codes()
        data_cross = []
        i = 0
        for item_def in list(data_dict["text"]):
            if(i >= 2):
                data_cross.append(item_def.split(" - ", 1))
            i = i+1

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]
        
        #get country codes
        self.country_codes = pd.read_excel("data/ISO3166.xlsx")
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]
        
        #get trade data for a given year
        trade_data = pd.read_stata(self.data_dir + "/country_partner_sitcproduct4digit_year_"+ str(self.year)+".dta") 
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]], 
                              on = ["sitc_product_code"])
        ###select level of product aggregation
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])
        
        #keep only nodes that we have features for
        #trade_data = trade_data[trade_data["location_code"].isin(self.features["country_code"])]
        #trade_data = trade_data[trade_data["partner_code"].isin(self.features["country_code"])]
        
        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data
        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports1 = imports1[imports1["import_value"] != 0]
        #transform records of exports into imports
        imports2 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'export_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports2["temp1"] = imports2['partner_code']
        imports2["temp2"] = imports2['location_code']

        imports2['location_code'] = imports2["temp1"]
        imports2['partner_code'] = imports2["temp2"]
        imports2["import_value"] = imports2["export_value"]
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        
        imports_table = pd.concat([imports1, imports2]).drop_duplicates()
        
        #rename columns for better clarity
        imports_table["importer_code"] = imports_table["location_code"]
        imports_table["exporter_code"] = imports_table["partner_code"]
        imports_table["importer_name"] = imports_table["country_i"]
        imports_table["exporter_name"] = imports_table["country_j"]
        
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table[cols]
        
        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})
        
        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})
        
        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        
        #sum exports in each category 
        self.export_types = imports_table.groupby(["exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(self.export_types, exporter_total, on = "exporter_code")
        #multiply by 100 to allow weights to scale better in GNN
        self.export_types["category_fraction"] = self.export_types.import_value/self.export_types.export_total*10
        ss = StandardScaler()
        columns = list(set(self.export_types["product_category"]))
        self.export_types = self.export_types[["exporter_code", "product_category", "category_fraction"]]\
        .pivot(index = ["exporter_code"], columns = ["product_category"], values = "category_fraction")\
        .reset_index().fillna(0)
        #rename columns
        rename_columns = []
        for col in self.export_types.columns:
            if(col == "exporter_code"):
                rename_columns.append(col)
            else:
                rename_columns.append("resource_" + col)
        self.export_types.columns = rename_columns
        self.export_types = self.export_types.rename(columns = {"exporter_code": "country_code"})
        self.features = pd.merge(self.features, self.export_types, 
                                on = "country_code", how = "left")
        
        #look at fraction of goods traded with each counterparty
        imports_table_grouped = pd.merge(imports_table_grouped, exporter_total, how = "left")
        imports_table_grouped["export_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["export_total"]
        scaler = StandardScaler()
        imports_table_grouped[["export_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["export_percent"]]))
        imports_table_grouped["export_percent_feature"] = imports_table_grouped["export_percent_feature"] + abs(min(imports_table_grouped["export_percent_feature"]))
        
        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, how = "left")
        imports_table_grouped["import_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["import_total"]
        scaler = StandardScaler()
        imports_table_grouped[["import_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["import_percent"]]))
        imports_table_grouped["import_percent_feature"] = imports_table_grouped["import_percent_feature"] + abs(min(imports_table_grouped["import_percent_feature"]))
        
        self.trade_data = imports_table_grouped

    def graph_create(self, exporter = True,
            node_features = ['prev_gdp_growth', 'current_gdp_growth','prev_gdp','current_gdp'],
            node_labels = 'future_gdp_growth'):
        
        if(exporter):
            center_node = "exporter_code"
            neighbors = "importer_code"
            edge_features = 'export_percent'
        
        #filter features and nodes to ones that are connected to others in trade data
        # list_active_countries = list(set(list(self.trade_data ["importer_code"])+\
        #                 list(self.trade_data ["exporter_code"])))

        list_active_countries = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

        # Create a new DataFrame with the list of country codes
        df_new = pd.DataFrame(list_active_countries, columns=['country_code'])
        self.features = pd.merge(df_new, self.features, on='country_code', how='left')
        self.features = self.features.fillna(0)

        # self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        # self.features.fillna(0, inplace = True)
        self.features["node_numbers"] = self.features.index

        #create lookup dictionary making node number / node features combatible with ordering of nodes
        #in our edge table

        self.node_lookup1 = self.features.set_index('node_numbers').to_dict()['country_code']
        self.node_lookup2 = self.features.set_index('country_code').to_dict()['node_numbers']
        
        #get individual country's features
        self.regression_table = pd.merge(self.features, self.trade_data,
                        left_on = "country_code",
                        right_on = center_node, how = 'right')
        #get features for trade partners
        self.regression_table = pd.merge(self.features, self.regression_table,
                                        left_on = "country_code",
                                        right_on = neighbors, how = "right",
                                        suffixes = ("_neighbors", ""))
        
        self.trade_data = self.trade_data[self.trade_data[neighbors].isin(self.node_lookup2)]
        self.trade_data = self.trade_data[self.trade_data[center_node].isin(self.node_lookup2)]

        self.regression_table["source"] = self.trade_data[neighbors].apply(lambda x: self.node_lookup2[x])
        self.regression_table["target"] = self.trade_data[center_node].apply(lambda x: self.node_lookup2[x])    

        self.regression_table = self.regression_table.dropna()
        #filter only to relevant columns
        self.relevant_columns = ["source", "target"]
        self.relevant_columns.extend(node_features)
        self.relevant_columns.append(node_labels)
        self.graph_table = self.regression_table[self.relevant_columns]
        
        if(self.graph_table.isnull().values.any()): print("edges contain null / inf values")

        self.node_attributes = torch.tensor(np.array(self.features[node_features]))\
        .to(torch.float)
        self.source_nodes = list(self.graph_table["source"])
        self.target_nodes = list(self.graph_table["target"])

        self.edge_attributes = list(self.trade_data[edge_features])
        
        self.pyg_graph = data.Data(x = self.node_attributes,
                                   edge_index = torch.tensor([self.source_nodes, self.target_nodes]),
                                   edge_attr = torch.tensor(self.edge_attributes).to(torch.float),
                                   y = torch.tensor(list(self.features[node_labels])).to(torch.float))

In [9]:
with open("feature_dicts/filtered_features_dict.pkl", "rb") as f:
    feat_dict = pkl.load(f)

In [10]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [11]:
years = range(1962, 2019)

In [12]:
import pickle as pkl
with open("pygcn/train_graphs.pickle", "rb") as f:
    train_graphs = pkl.load(f)

with open("pygcn/val_graphs.pickle", "rb") as f:  
    val_graphs = pkl.load(f)

with open("pygcn/test_graphs.pickle", "rb") as f:         
    test_graphs = pkl.load(f)

In [13]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=32)
train_loader = DataLoader(train_graphs, batch_size=32)
val_loader = DataLoader(val_graphs, batch_size=32)



In [14]:
def add_features(years, graphs):

    zeros = torch.zeros(59)

    for i in range(len(years)):
        new_x = torch.empty(0, 59)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [15]:
train_graphs = add_features(train_years, train_graphs)
val_graphs = add_features(val_years, val_graphs)
test_graphs = add_features(test_years, test_graphs)

NameError: name 'train_years' is not defined

In [20]:
crisis_years = [1983, 1982, 2008, 2002, 2016, 1967, 1962, 1989, 2012, 1963, 1993, 1986, 1996,1978]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [21]:
train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv, global_sort_pool
from torch_geometric.nn.aggr import SortAggregation
from torch.nn import Linear, LayerNorm, ReLU, Sigmoid
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

class GNN(torch.nn.Module):
    def __init__(self, num_features):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, 64)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index.to(torch.int64)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x

class SiameseGNN(torch.nn.Module):
    def __init__(self, num_features):
        super(SiameseGNN, self).__init__()
        self.gnn = GNN(num_features)
        self.sort_aggr = SortAggregation(k=50)
        self.fc1 = Linear(9950, 128)  # Adjust input size according to pooling output
        self.norm1 = LayerNorm(128)
        self.relu1 = ReLU()
        self.fc2 = Linear(128, 64)
        self.norm2 = LayerNorm(64)
        self.relu2 = ReLU()
        self.fc3 = Linear(64, 1)
        self.sigmoid = Sigmoid()

    def forward(self, data1, data2):
        out1 = self.gnn(data1)
        out2 = self.gnn(data2)
        out = torch.cdist(out1, out2, p=2)
        out = self.sort_aggr(out, data1.batch)
        out = out.view(out.size(0), -1)  # Flatten the pooled output
        out = self.fc1(out)
        out = self.norm1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.norm2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.sigmoid(out)
        return out

train_loader = train_loader_pairs
val_loader = val_loader_pairs

model = SiameseGNN(num_features=train_loader[0][0].num_node_features)
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
criterion = nn.BCELoss()

for epoch in tqdm(range(10)):
    model.train()
    train_losses = []
    for i in range(len(train_loader)):

        optimizer.zero_grad()
        out = model(train_loader[i][0], train_loader[i][1])
        loss = criterion(out.squeeze(), train_torch_y[i].to(torch.float))
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()  # Add this line to update the learning rate

    model.eval()
    with torch.no_grad():
        val_losses = []
        correct = 0
        total = 0
        for i in range(len(val_loader)):
            out = model(val_loader[i][0], val_loader[i][1])
            val_loss = criterion(out.squeeze(), val_torch_y[i].to(torch.float))
            val_losses.append(val_loss.item())

            predictions = torch.round(out.squeeze())
            correct += (predictions == val_torch_y[i]).sum().item()
            total += 1

        val_loss = sum(val_losses) / len(val_losses)
        val_accuracy = correct / total

    print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

 10%|█         | 1/10 [00:13<01:57, 13.00s/it]

Epoch: 1, Training Loss: 0.4229666137895664, Validation Loss: 0.4994966710607211, Validation Accuracy: 0.803030303030303


 20%|██        | 2/10 [00:25<01:42, 12.87s/it]

Epoch: 2, Training Loss: 0.41427947403628285, Validation Loss: 0.5009946116443836, Validation Accuracy: 0.803030303030303


 30%|███       | 3/10 [00:39<01:33, 13.30s/it]

Epoch: 3, Training Loss: 0.4123781653074156, Validation Loss: 0.5026700783408049, Validation Accuracy: 0.803030303030303


 40%|████      | 4/10 [01:34<02:57, 29.54s/it]

Epoch: 4, Training Loss: 0.4100686021462208, Validation Loss: 0.5035527772975691, Validation Accuracy: 0.803030303030303


 50%|█████     | 5/10 [03:29<05:02, 60.53s/it]

Epoch: 5, Training Loss: 0.4088507873921835, Validation Loss: 0.5042239278554916, Validation Accuracy: 0.803030303030303


 60%|██████    | 6/10 [03:43<02:58, 44.54s/it]

Epoch: 6, Training Loss: 0.41012620629758395, Validation Loss: 0.5044262244394331, Validation Accuracy: 0.803030303030303


 70%|███████   | 7/10 [03:55<01:42, 34.13s/it]

Epoch: 7, Training Loss: 0.40759952013232126, Validation Loss: 0.5053463794968345, Validation Accuracy: 0.803030303030303


 80%|████████  | 8/10 [04:07<00:54, 27.05s/it]

Epoch: 8, Training Loss: 0.4073875575068117, Validation Loss: 0.5056198169336175, Validation Accuracy: 0.803030303030303


 90%|█████████ | 9/10 [04:19<00:22, 22.31s/it]

Epoch: 9, Training Loss: 0.4068174197393305, Validation Loss: 0.5069572094715002, Validation Accuracy: 0.803030303030303


100%|██████████| 10/10 [04:31<00:00, 27.13s/it]

Epoch: 10, Training Loss: 0.4066063172676984, Validation Loss: 0.508396189095396, Validation Accuracy: 0.803030303030303





## Random Feature Subset

In [3]:
class CreateFeatures:
    """
    We define a class which builds the feature dataframe 
    """
    
    def __init__(self, year = 1962, data_dir = "../data/"):
        self.year = year
        self.data_dir = data_dir
        
    def prepare_econ_features(self, filter_gdp = True):
        
        #DATA IMPORT
        #import dictionary with all features from WB
        with open(self.data_dir + 'all_wb_indicators.pickle', 'rb') as handle:
            features_dict = pkl.load(handle)
            
        self.feature_list = list(features_dict.keys())[1:]
        #import list of all features we want to select for

        #look up each of the features -- add country feature in that year 
        i = 0
        for feature in self.feature_list:
            #find dataframe corresponding to specific feature name
            df = features_dict[feature]
            
            if (i == 0):
                self.features = df[["economy", "YR" + str(self.year)]]
            else: 
                self.features = pd.merge(self.features, 
                                            df[["economy", "YR" + str(self.year)]],
                                            on = "economy", how = "outer")
            self.features.rename(columns = {"YR" + str(self.year): feature}, inplace = True)
            i = i+1
        
        #prepare GDP feature
        self.gdp_growth = features_dict['NY.GDP.MKTP.KD.ZG']
        cols = list(self.gdp_growth.columns.copy())
        cols.remove("economy")
        self.gdp_growth["country_sd"] = self.gdp_growth[cols].std(axis=1)
        #select potential variables 
        self.gdp_growth["prev_gdp_growth"] = self.gdp_growth["YR" + str(self.year-1)]
        self.gdp_growth["current_gdp_growth"] = self.gdp_growth["YR" + str(self.year)]
        #we eliminate countries that are too volatile in growth -- probably an indicator that growth estimates are inaccurate
        self.gdp_growth = self.gdp_growth[["economy", "prev_gdp_growth",
                                "current_gdp_growth"]].dropna()
        
        #combine GDP and other features
        self.features = pd.merge(self.gdp_growth, self.features,
                                   on = "economy", how = "left")
        #we only keep countries where we observe GDP growth -- otherwise nothing to predict
        #we keep countries where other features may be missing -- and fill NAs with 0 
        self.features.rename(columns = {"economy": "country_code"}, inplace = True)
        
    def prepare_network_features(self):
        """
        We create an initial, import-centric trade link pandas dataframe for a given year
        """
        #get product codes
        data_dict = get_sitc_codes()
        data_cross = []
        i = 0
        for item_def in list(data_dict["text"]):
            if(i >= 2):
                data_cross.append(item_def.split(" - ", 1))
            i = i+1

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]
        
        #get country codes
        self.country_codes = pd.read_excel(self.data_dir + "ISO3166.xlsx")
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]
        
        #get trade data for a given year
        trade_data = pd.read_stata(self.data_dir + "country_partner_sitcproduct4digit_year_"+ str(self.year)+".dta") 
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]], 
                              on = ["sitc_product_code"])
        ###select level of product aggregation
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])
        #trade_data = trade_data[trade_data["product_category"] == "1"]
        
        #keep only nodes that we have features for
        trade_data = trade_data[trade_data["location_code"].isin(self.features["country_code"])]
        trade_data = trade_data[trade_data["partner_code"].isin(self.features["country_code"])]
        
        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data
        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports1 = imports1[imports1["import_value"] != 0]
        #transform records of exports into imports
        imports2 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'export_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports2["temp1"] = imports2['partner_code']
        imports2["temp2"] = imports2['location_code']

        imports2['location_code'] = imports2["temp1"]
        imports2['partner_code'] = imports2["temp2"]
        imports2["import_value"] = imports2["export_value"]
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        
        imports_table = pd.concat([imports1, imports2]).drop_duplicates()
        
        #rename columns for better clarity
        imports_table["importer_code"] = imports_table["location_code"]
        imports_table["exporter_code"] = imports_table["partner_code"]
        imports_table["importer_name"] = imports_table["country_i"]
        imports_table["exporter_name"] = imports_table["country_j"]
        
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table[cols]
        
        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})
        
        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})
        
        ##### COMPUTE CENTRALITY FOR COUNTRY
        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, on = "importer_code")
        imports_table_grouped["import_fraction"] = imports_table_grouped["import_value"]\
                        /imports_table_grouped["import_total"]*100
        
        self.trade_data = imports_table_grouped
        
        #filter features and nodes to ones that are connected to others in trade data
        list_active_countries = list(set(list(self.trade_data ["importer_code"])+\
                        list(self.trade_data ["exporter_code"])))
        self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        self.features["node_numbers"] = self.features.index
        
        G=nx.from_pandas_edgelist(self.trade_data, 
                          "exporter_code", "importer_code", create_using = nx.DiGraph())
        
        self.G = G
        self.centrality_overall= nx.eigenvector_centrality(G, max_iter= 10000) 
        self.centrality_overall = pd.DataFrame(list(map(list, self.centrality_overall.items())), 
                                               columns = ["country_code", "centrality_overall"])
        G=nx.from_pandas_edgelist(self.trade_data, 
                          "exporter_code", "importer_code", ["import_fraction"])
        weighted_centrality = nx.eigenvector_centrality(G, weight = "import_fraction", max_iter= 10000) 
        weighted_centrality  = pd.DataFrame(list(map(list, weighted_centrality.items())), 
                                               columns = ["country_code", "weighted_centrality"])
        self.centrality_overall = pd.merge(self.centrality_overall, weighted_centrality, on = "country_code")
        
                               
        ##### COMPUTE CENTRALITY FOR COUNTRY IN PRODUCT CATEGORIES

        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        products_grouped = imports_table.groupby(["product_category"])["import_value"].sum().reset_index()
        products_grouped = products_grouped.rename(columns = {"import_value": "import_product_total"})
        
        #sum exports in each category 
        self.export_types = imports_table.groupby(["importer_code", "exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(products_grouped, self.export_types, on = "product_category")
        
        self.export_types["product_export_fraction"] = self.export_types["import_value"]\
                                                    /self.export_types["import_product_total"]*100
        
        list_products = list(set(self.export_types["product_category"]))
        
        i = 0 
        for product in list_products:
            
            temp = self.export_types[self.export_types["product_category"] == product].copy()
            
            G_w=nx.from_pandas_edgelist(temp, 
                "exporter_code", "importer_code", ["product_export_fraction"], create_using = nx.DiGraph())
            centrality_product_w = nx.eigenvector_centrality(G_w, weight = "product_export_fraction", 
                                                           max_iter= 10000)

            G=nx.from_pandas_edgelist(temp,"exporter_code", "importer_code", create_using = nx.DiGraph())
            centrality_product = nx.eigenvector_centrality(G,max_iter= 10000)

            if(i == 0):
                self.centrality_product = pd.DataFrame(list(map(list, centrality_product.items())), 
                                               columns = ["country_code", "prod_" + product])
                

            else: 
                self.centrality_product = pd.merge(self.centrality_product, 
                                               pd.DataFrame(list(map(list, centrality_product.items())), 
                                               columns = ["country_code", "prod_" + product]), 
                                                  on = "country_code")
                
            self.centrality_product = pd.merge(self.centrality_product, 
                                               pd.DataFrame(list(map(list, centrality_product_w.items())), 
                                               columns = ["country_code", "prod_w_" + product]), 
                                                  on = "country_code")
            
            i = i+1         
    
    def combine_normalize_features(self):
        
        self.combined_features = pd.merge(self.features, self.centrality_overall,on = "country_code")
        self.combined_features = pd.merge(self.combined_features, self.centrality_product,on = "country_code")
        #step eliminates NA and nodes that are not in graph, since they will have NA for graph features
        self.combined_features = self.combined_features.drop(columns = ["index"])
        #filter both trade data and features data to same subset of countries
        self.combined_features = self.combined_features[\
                                self.combined_features.country_code.isin(self.trade_data.importer_code)|\
                                self.combined_features.country_code.isin(self.trade_data.exporter_code)]
        self.trade_data = self.trade_data[\
                          self.trade_data.importer_code.isin(self.combined_features.country_code)&\
                          self.trade_data.exporter_code.isin(self.combined_features.country_code)]
        
        features_to_norm = list(self.combined_features.columns.copy())
        non_norm = ["country_code", "node_numbers"]
        cols_insufficient_data = list(self.combined_features.loc[:, self.combined_features.nunique() < 2].columns.copy())
        non_norm.extend(cols_insufficient_data)
 
        features_to_norm = [x for x in features_to_norm if x not in non_norm]
        scaler = StandardScaler()
        #we preserve NAs in the scaling
        self.combined_features[features_to_norm] = scaler.fit_transform(self.combined_features[features_to_norm])
        self.combined_features.fillna(0, inplace = True) #we fill NA after scaling 
        #check that feature has at least 20% coverage in a given year -- otherwise set to NA
        for feature in self.feature_list:
            coverage = len(self.combined_features[self.combined_features[feature] != 0])/len(self.combined_features)
            if(coverage <= 0.20): self.combined_features[feature] = 0

In [4]:
class TradeNetwork:
    """
    We define a class which computes the MST trade network for a given year 
    """
    
    def __init__(self, year = 1962, data_dir = "data"):
        self.year = year
        self.data_dir = data_dir
        
    def prepare_features(self, filter_gdp = True):
        
        ###IMPORT GDP###
        #prepare GDP as a set of features 
        with open('data/all_wb_indicators.pickle', 'rb') as handle:
            features_dict = pkl.load(handle)

        self.gdp = features_dict['NY.GDP.MKTP.CD']
        scaler = StandardScaler()
        self.gdp[["prev_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year-1)]]))
        self.gdp[["current_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year)]]))
        #rename and keep relevant columns
        self.gdp["country_code"] = self.gdp["economy"]
        self.gdp = self.gdp[["country_code", "prev_gdp", "current_gdp"]].dropna()
        
        ###IMPORT GDP GROWTH###
        #prepare GDP growth
        self.gdp_growth = features_dict['NY.GDP.MKTP.KD.ZG']
        self.gdp_growth["prev_gdp_growth"] = self.gdp_growth['YR'+str(self.year-1)]
        self.gdp_growth["current_gdp_growth"] = self.gdp_growth['YR'+str(self.year)] 
        self.gdp_growth["future_gdp_growth"] = self.gdp_growth['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_growth["country_code"] = self.gdp_growth["economy"]
        self.gdp_growth = self.gdp_growth[["country_code", "prev_gdp_growth",
                                "current_gdp_growth", "future_gdp_growth"]].dropna()
        
        ###IMPORT GDP PER CAPITA###
        self.gdp_per_capita = features_dict['NY.GDP.PCAP.CD']
        self.gdp_per_capita["prev_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year-1)]
        self.gdp_per_capita["current_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year)]
        self.gdp_per_capita["future_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_per_capita["country_code"] = self.gdp_per_capita["economy"]
        self.gdp_per_capita = self.gdp_per_capita[["country_code", "prev_gdp_per_cap",
                                "current_gdp_per_cap", "future_gdp_per_cap"]].dropna()
        
        ###IMPORT GDP PER CAPITA GROWTH###
        self.gdp_per_capita_growth = features_dict['NY.GDP.PCAP.KD.ZG']
        self.gdp_per_capita_growth["prev_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year-1)]
        self.gdp_per_capita_growth["current_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year)]
        self.gdp_per_capita_growth["future_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year+1)]
        
        #rename and keep relevant columns
        self.gdp_per_capita_growth["country_code"] = self.gdp_per_capita_growth["economy"]
        self.gdp_per_capita_growth = self.gdp_per_capita_growth[["country_code", "prev_gdp_per_cap_growth",
                                "current_gdp_per_cap_growth", "future_gdp_per_cap_growth"]].dropna()
        
        ###MERGE ALL DATA FEATURES###
        self.features = pd.merge(self.gdp_growth, self.gdp, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita_growth, on = "country_code").dropna()

    def prepare_network(self):
        """
        We create an initial, import-centric trade link pandas dataframe for a given year
        """
        #get product codes
        data_dict = get_sitc_codes()
        data_cross = []
        i = 0
        for item_def in list(data_dict["text"]):
            if(i >= 2):
                data_cross.append(item_def.split(" - ", 1))
            i = i+1

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]
        
        #get country codes
        self.country_codes = pd.read_excel("data/ISO3166.xlsx")
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]
        
        #get trade data for a given year
        trade_data = pd.read_stata(self.data_dir + "/country_partner_sitcproduct4digit_year_"+ str(self.year)+".dta") 
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]], 
                              on = ["sitc_product_code"])
        ###select level of product aggregation
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])
        
        #keep only nodes that we have features for
        #trade_data = trade_data[trade_data["location_code"].isin(self.features["country_code"])]
        #trade_data = trade_data[trade_data["partner_code"].isin(self.features["country_code"])]
        
        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data
        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports1 = imports1[imports1["import_value"] != 0]
        #transform records of exports into imports
        imports2 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'export_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports2["temp1"] = imports2['partner_code']
        imports2["temp2"] = imports2['location_code']

        imports2['location_code'] = imports2["temp1"]
        imports2['partner_code'] = imports2["temp2"]
        imports2["import_value"] = imports2["export_value"]
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        
        imports_table = pd.concat([imports1, imports2]).drop_duplicates()
        
        #rename columns for better clarity
        imports_table["importer_code"] = imports_table["location_code"]
        imports_table["exporter_code"] = imports_table["partner_code"]
        imports_table["importer_name"] = imports_table["country_i"]
        imports_table["exporter_name"] = imports_table["country_j"]
        
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table[cols]
        
        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})
        
        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})
        
        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        
        #sum exports in each category 
        self.export_types = imports_table.groupby(["exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(self.export_types, exporter_total, on = "exporter_code")
        #multiply by 100 to allow weights to scale better in GNN
        self.export_types["category_fraction"] = self.export_types.import_value/self.export_types.export_total*10
        ss = StandardScaler()
        columns = list(set(self.export_types["product_category"]))
        self.export_types = self.export_types[["exporter_code", "product_category", "category_fraction"]]\
        .pivot(index = ["exporter_code"], columns = ["product_category"], values = "category_fraction")\
        .reset_index().fillna(0)
        #rename columns
        rename_columns = []
        for col in self.export_types.columns:
            if(col == "exporter_code"):
                rename_columns.append(col)
            else:
                rename_columns.append("resource_" + col)
        self.export_types.columns = rename_columns
        self.export_types = self.export_types.rename(columns = {"exporter_code": "country_code"})
        self.features = pd.merge(self.features, self.export_types, 
                                on = "country_code", how = "left")
        
        #look at fraction of goods traded with each counterparty
        imports_table_grouped = pd.merge(imports_table_grouped, exporter_total, how = "left")
        imports_table_grouped["export_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["export_total"]
        scaler = StandardScaler()
        imports_table_grouped[["export_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["export_percent"]]))
        imports_table_grouped["export_percent_feature"] = imports_table_grouped["export_percent_feature"] + abs(min(imports_table_grouped["export_percent_feature"]))
        
        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, how = "left")
        imports_table_grouped["import_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["import_total"]
        scaler = StandardScaler()
        imports_table_grouped[["import_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["import_percent"]]))
        imports_table_grouped["import_percent_feature"] = imports_table_grouped["import_percent_feature"] + abs(min(imports_table_grouped["import_percent_feature"]))
        
        self.trade_data = imports_table_grouped

    def graph_create(self, exporter = True,
            node_features = ['prev_gdp_growth', 'current_gdp_growth','prev_gdp','current_gdp'],
            node_labels = 'future_gdp_growth'):
        
        if(exporter):
            center_node = "exporter_code"
            neighbors = "importer_code"
            edge_features = 'export_percent'
        
        #filter features and nodes to ones that are connected to others in trade data
        # list_active_countries = list(set(list(self.trade_data ["importer_code"])+\
        #                 list(self.trade_data ["exporter_code"])))

        list_active_countries = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

        # Create a new DataFrame with the list of country codes
        df_new = pd.DataFrame(list_active_countries, columns=['country_code'])
        self.features = pd.merge(df_new, self.features, on='country_code', how='left')
        self.features = self.features.fillna(0)

        # self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        # self.features.fillna(0, inplace = True)
        self.features["node_numbers"] = self.features.index

        #create lookup dictionary making node number / node features combatible with ordering of nodes
        #in our edge table

        self.node_lookup1 = self.features.set_index('node_numbers').to_dict()['country_code']
        self.node_lookup2 = self.features.set_index('country_code').to_dict()['node_numbers']
        
        #get individual country's features
        self.regression_table = pd.merge(self.features, self.trade_data,
                        left_on = "country_code",
                        right_on = center_node, how = 'right')
        #get features for trade partners
        self.regression_table = pd.merge(self.features, self.regression_table,
                                        left_on = "country_code",
                                        right_on = neighbors, how = "right",
                                        suffixes = ("_neighbors", ""))
        
        self.trade_data = self.trade_data[self.trade_data[neighbors].isin(self.node_lookup2)]
        self.trade_data = self.trade_data[self.trade_data[center_node].isin(self.node_lookup2)]

        self.regression_table["source"] = self.trade_data[neighbors].apply(lambda x: self.node_lookup2[x])
        self.regression_table["target"] = self.trade_data[center_node].apply(lambda x: self.node_lookup2[x])    

        self.regression_table = self.regression_table.dropna()
        #filter only to relevant columns
        self.relevant_columns = ["source", "target"]
        self.relevant_columns.extend(node_features)
        self.relevant_columns.append(node_labels)
        self.graph_table = self.regression_table[self.relevant_columns]
        
        if(self.graph_table.isnull().values.any()): print("edges contain null / inf values")

        self.node_attributes = torch.tensor(np.array(self.features[node_features]))\
        .to(torch.float)
        self.source_nodes = list(self.graph_table["source"])
        self.target_nodes = list(self.graph_table["target"])

        self.edge_attributes = list(self.trade_data[edge_features])
        
        self.pyg_graph = data.Data(x = self.node_attributes,
                                   edge_index = torch.tensor([self.source_nodes, self.target_nodes]),
                                   edge_attr = torch.tensor(self.edge_attributes).to(torch.float),
                                   y = torch.tensor(list(self.features[node_labels])).to(torch.float))

In [5]:
years = range(1962,2019)

train_years = [2005, 1969, 2002, 1997, 1993, 1982, 2001, 2000, 1962, 1985, 1978, 2016, 1986, 1987, 1989, 1971, 2013, 1996, 1995, 1967, 2017, 1974, 1990, 1977, 1980, 2014, 1965, 1984, 2006, 1973, 1968, 1981, 1970, 1991]
val_years = [1975, 1983, 2009, 1966, 1999, 1988, 2007, 1979, 1972, 2015, 2003]
test_years = [1963, 1964, 1976, 1992, 1994, 1998, 2004, 2008, 2010, 2011, 2012, 2018]

train_graphs = []
val_graphs = []
test_graphs = []
i = 0

for year in tqdm(years):
    print(str(year), end='\r')
    
    trade = TradeNetwork(year = year)
    trade.prepare_features()
    trade.prepare_network()
    trade.graph_create(node_features = ['prev_gdp_per_cap_growth', 'current_gdp_per_cap_growth',
    'resource_0', 'resource_1', 'resource_2', 'resource_3', 'resource_4', 'resource_5', 'resource_6', 'resource_7',
       'resource_8', 'resource_9'],
        node_labels = 'future_gdp_per_cap_growth')
    
    if(year in val_years):
        val_graphs.append(trade.pyg_graph)
    elif(year in test_years):
        test_graphs.append(trade.pyg_graph)
    else: 
        train_graphs.append(trade.pyg_graph)
        
    trade.features["year"] = year
    
    if(i == 0):
        trade_df = trade.features
    else: 
        trade_df = pd.concat([trade_df, trade.features])
        
    i = i+1
    print(trade.node_attributes.size())

  0%|          | 0/57 [00:00<?, ?it/s]

1962

  2%|▏         | 1/57 [00:05<04:47,  5.13s/it]

torch.Size([199, 12])
1963

  4%|▎         | 2/57 [00:10<04:47,  5.22s/it]

torch.Size([199, 12])
1964

  5%|▌         | 3/57 [00:15<04:50,  5.37s/it]

torch.Size([199, 12])
1965

  7%|▋         | 4/57 [00:21<04:52,  5.51s/it]

torch.Size([199, 12])
1966

  9%|▉         | 5/57 [00:27<05:00,  5.77s/it]

torch.Size([199, 12])
1967

 11%|█         | 6/57 [00:34<05:01,  5.91s/it]

torch.Size([199, 12])
1968

 12%|█▏        | 7/57 [00:40<05:03,  6.06s/it]

torch.Size([199, 12])
1969

 14%|█▍        | 8/57 [00:47<05:09,  6.32s/it]

torch.Size([199, 12])
1970

 16%|█▌        | 9/57 [00:54<05:14,  6.54s/it]

torch.Size([199, 12])
1971

 18%|█▊        | 10/57 [01:01<05:19,  6.79s/it]

torch.Size([199, 12])
1972

 19%|█▉        | 11/57 [01:09<05:32,  7.23s/it]

torch.Size([199, 12])
1973

 21%|██        | 12/57 [01:18<05:42,  7.60s/it]

torch.Size([199, 12])
1974

 23%|██▎       | 13/57 [01:27<05:50,  7.97s/it]

torch.Size([199, 12])
1975

 25%|██▍       | 14/57 [01:36<05:56,  8.28s/it]

torch.Size([199, 12])
1976

 26%|██▋       | 15/57 [01:40<04:59,  7.12s/it]

torch.Size([199, 12])
1977

 28%|██▊       | 16/57 [01:45<04:23,  6.42s/it]

torch.Size([199, 12])
1978

 30%|██▉       | 17/57 [01:51<04:16,  6.40s/it]

torch.Size([199, 12])
1979

 32%|███▏      | 18/57 [01:58<04:09,  6.40s/it]

torch.Size([199, 12])
1980

 33%|███▎      | 19/57 [02:04<04:06,  6.49s/it]

torch.Size([199, 12])
1981

 35%|███▌      | 20/57 [02:11<04:04,  6.61s/it]

torch.Size([199, 12])
1982

 37%|███▋      | 21/57 [02:18<04:03,  6.75s/it]

torch.Size([199, 12])
1983

 39%|███▊      | 22/57 [02:25<03:59,  6.83s/it]

torch.Size([199, 12])
1984

 40%|████      | 23/57 [02:33<03:58,  7.00s/it]

torch.Size([199, 12])
1985

 42%|████▏     | 24/57 [02:41<04:00,  7.27s/it]

torch.Size([199, 12])
1986

 44%|████▍     | 25/57 [02:48<03:51,  7.25s/it]

torch.Size([199, 12])
1987

 46%|████▌     | 26/57 [02:56<03:52,  7.49s/it]

torch.Size([199, 12])
1988

 47%|████▋     | 27/57 [03:05<03:57,  7.93s/it]

torch.Size([199, 12])
1989

 49%|████▉     | 28/57 [03:14<03:57,  8.19s/it]

torch.Size([199, 12])
1990

 51%|█████     | 29/57 [03:23<03:57,  8.49s/it]

torch.Size([199, 12])
1991

 53%|█████▎    | 30/57 [03:32<03:53,  8.64s/it]

torch.Size([199, 12])
1992

 54%|█████▍    | 31/57 [03:43<04:06,  9.46s/it]

torch.Size([199, 12])
1993

 56%|█████▌    | 32/57 [03:54<04:06,  9.84s/it]

torch.Size([199, 12])
1994

 58%|█████▊    | 33/57 [04:06<04:10, 10.44s/it]

torch.Size([199, 12])
1995

 60%|█████▉    | 34/57 [04:17<04:07, 10.77s/it]

torch.Size([199, 12])
1996

 61%|██████▏   | 35/57 [04:30<04:06, 11.22s/it]

torch.Size([199, 12])
1997

 63%|██████▎   | 36/57 [04:42<04:02, 11.56s/it]

torch.Size([199, 12])
1998

 65%|██████▍   | 37/57 [04:55<04:00, 12.00s/it]

torch.Size([199, 12])
1999

 67%|██████▋   | 38/57 [05:09<03:59, 12.62s/it]

torch.Size([199, 12])
2000

 68%|██████▊   | 39/57 [05:24<03:57, 13.20s/it]

torch.Size([199, 12])
2001

 70%|███████   | 40/57 [05:39<03:55, 13.84s/it]

torch.Size([199, 12])
2002

 72%|███████▏  | 41/57 [05:53<03:44, 14.03s/it]

torch.Size([199, 12])
2003

 74%|███████▎  | 42/57 [06:09<03:37, 14.47s/it]

torch.Size([199, 12])
2004

 75%|███████▌  | 43/57 [06:26<03:31, 15.10s/it]

torch.Size([199, 12])
2005

 77%|███████▋  | 44/57 [06:46<03:35, 16.59s/it]

torch.Size([199, 12])
2006

 79%|███████▉  | 45/57 [07:04<03:24, 17.07s/it]

torch.Size([199, 12])
2007

 81%|████████  | 46/57 [07:22<03:13, 17.56s/it]

torch.Size([199, 12])
2008

 82%|████████▏ | 47/57 [07:42<03:00, 18.02s/it]

torch.Size([199, 12])
2009

 84%|████████▍ | 48/57 [08:00<02:43, 18.20s/it]

torch.Size([199, 12])
2010

 86%|████████▌ | 49/57 [08:18<02:25, 18.15s/it]

torch.Size([199, 12])
2011

 88%|████████▊ | 50/57 [08:36<02:06, 18.10s/it]

torch.Size([199, 12])
2012

 89%|████████▉ | 51/57 [08:55<01:49, 18.30s/it]

torch.Size([199, 12])
2013

 91%|█████████ | 52/57 [09:14<01:32, 18.54s/it]

torch.Size([199, 12])
2014

 93%|█████████▎| 53/57 [09:33<01:14, 18.66s/it]

torch.Size([199, 12])
2015

 95%|█████████▍| 54/57 [09:52<00:56, 18.89s/it]

torch.Size([199, 12])
2016

 96%|█████████▋| 55/57 [10:12<00:38, 19.06s/it]

torch.Size([199, 12])
2017

 98%|█████████▊| 56/57 [10:31<00:19, 19.09s/it]

torch.Size([199, 12])
2018

100%|██████████| 57/57 [10:51<00:00, 11.43s/it]

torch.Size([199, 12])





In [None]:
import pickle as pkl
with open("pygcn/train_graphs.pickle", "wb") as f:
    pkl.dump(train_graphs, f)

with open("pygcn/val_graphs.pickle", "wb") as f:
    pkl.dump(val_graphs, f)

with open("pygcn/test_graphs.pickle", "wb") as f:
    pkl.dump(test_graphs, f)

In [6]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [7]:
with open("feature_dicts/random_features_dict.pkl", "rb") as f:
    feat_dict_random = pkl.load(f)

In [8]:
def add_features(years, graphs, feat_dict):

    zeros = torch.zeros(434)

    for i in range(len(years)):
        new_x = torch.empty(0, 434)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [9]:
train_graphs = add_features(train_years, train_graphs, feat_dict_random)
val_graphs = add_features(val_years, val_graphs, feat_dict_random)
test_graphs = add_features(test_years, test_graphs, feat_dict_random)

In [10]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=32)
train_loader = DataLoader(train_graphs, batch_size=32)
val_loader = DataLoader(val_graphs, batch_size=32)



In [19]:
def check_crisis_years(year_pairs, crisis_years):
    result = []
    for pair in year_pairs:
        start, end = pair
        # Check if any crisis year is between the pair or equals the later year
        if any(start < year <= end for year in crisis_years):
            result.append(0)
        else:
            result.append(1)
    return result

In [13]:
crisis_years = [1983, 1982, 2008, 2002, 2016, 1967, 1962, 1989, 2012, 1963, 1993, 1986, 1996,1978]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [14]:
train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

In [23]:
N_neg = torch.eq(train_torch_y, 0).sum().item()
N_pos = torch.eq(train_torch_y, 1).sum().item()

# Calculate the weight for the positive class
weight_for_positive_class = N_neg / N_pos
pos_weight = torch.tensor([weight_for_positive_class])