In [26]:
#imports 
import pandas as pd
import numpy as np
import os
import pickle
import datetime as datetime
from sklearn.preprocessing import StandardScaler
import statsmodels.formula.api as sm
import dgl.function as fn
from tqdm import tqdm

#imports for graph creation
import torch
from sklearn.preprocessing import StandardScaler
from itertools import combinations
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

#imports for graph learning
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
from tqdm import trange
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms

In [6]:
class TradeNetwork:
    """
    Yearly trade network between countries, built up in three steps.

    Call ``prepare_features()``, then ``prepare_network()``, then
    ``graph_create()``; each step reads attributes set by the previous one.
    The resulting PyTorch Geometric graph is stored in ``self.pyg_graph``.
    """

    #columns kept when building the import-centric table
    _IMPORT_COLUMNS = ['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]

    def __init__(self, year = 1962, data_dir = "data"):
        """
        Args:
            year: year whose trade data and indicators are loaded.
            data_dir: directory holding the indicator pickle, the ISO3166
                spreadsheet and the yearly ``.dta`` trade files.
        """
        self.year = year
        self.data_dir = data_dir

    def _year_window(self, indicator, stem):
        """
        Return the prev/current/future columns of one indicator table.

        Picks the ``YR<year-1>``, ``YR<year>`` and ``YR<year+1>`` columns,
        names them ``prev_<stem>``, ``current_<stem>`` and ``future_<stem>``,
        adds ``country_code`` (copied from ``economy``) and drops rows with
        any missing value.
        """
        window = pd.DataFrame({
            "country_code": indicator["economy"],
            "prev_" + stem: indicator['YR' + str(self.year - 1)],
            "current_" + stem: indicator['YR' + str(self.year)],
            "future_" + stem: indicator['YR' + str(self.year + 1)],
        })
        return window.dropna()

    @staticmethod
    def _shifted_log_zscore(share):
        """
        Standardise ``log(share)`` and add the absolute value of its minimum,
        which moves the smallest standardised value up to zero.
        """
        scaled = StandardScaler().fit_transform(np.log(share.to_frame()))[:, 0]
        return scaled + abs(min(scaled))

    def prepare_features(self, filter_gdp = True):
        """
        Load the World Bank indicators and assemble ``self.features``.

        Reads ``all_wb_indicators.pickle`` from ``data_dir`` and sets
        ``self.gdp``, ``self.gdp_growth``, ``self.gdp_per_capita``,
        ``self.gdp_per_capita_growth`` and their inner join on
        ``country_code``, ``self.features``.

        Args:
            filter_gdp: currently unused; kept so existing callers keep working.

        Raises:
            KeyError: if an indicator or one of the required ``YR<year>``
                columns is missing from the pickle.
        """
        ###IMPORT GDP###
        #NOTE: unpickling can execute arbitrary code; only load trusted files
        with open(os.path.join(self.data_dir, 'all_wb_indicators.pickle'), 'rb') as handle:
            features_dict = pickle.load(handle)

        #GDP level: log-transform, then standardise each year separately
        self.gdp = features_dict['NY.GDP.MKTP.CD']
        scaler = StandardScaler()
        for name, offset in (("prev_gdp", -1), ("current_gdp", 0)):
            year_col = 'YR' + str(self.year + offset)
            self.gdp[[name]] = scaler.fit_transform(np.log(self.gdp[[year_col]]))
        #rename and keep relevant columns
        self.gdp["country_code"] = self.gdp["economy"]
        self.gdp = self.gdp[["country_code", "prev_gdp", "current_gdp"]].dropna()

        ###IMPORT GDP GROWTH###
        self.gdp_growth = self._year_window(
            features_dict['NY.GDP.MKTP.KD.ZG'], "gdp_growth")

        ###IMPORT GDP PER CAPITA###
        self.gdp_per_capita = self._year_window(
            features_dict['NY.GDP.PCAP.CD'], "gdp_per_cap")

        ###IMPORT GDP PER CAPITA GROWTH###
        self.gdp_per_capita_growth = self._year_window(
            features_dict['NY.GDP.PCAP.KD.ZG'], "gdp_per_cap_growth")

        ###MERGE ALL DATA FEATURES###
        self.features = pd.merge(self.gdp_growth, self.gdp, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita_growth, on = "country_code").dropna()

    def prepare_network(self):
        """
        Build the import-centric trade table for ``self.year``.

        Downloads the SITC product classification, reads the country codes
        and the yearly trade file from ``data_dir``, and sets:

        * ``self.product_codes`` / ``self.country_codes``: lookup tables,
        * ``self.trade_data1``: the merged product-level trade records,
        * ``self.export_types``: per-country export mix (``resource_*``
          columns), which is also left-merged into ``self.features``,
        * ``self.trade_data``: one row per (importer, exporter) pair with
          totals, shares and standardised log-share features.

        Requires ``prepare_features()`` to have been called first.
        """
        #get product codes; the first two entries of the classification are skipped
        #(presumably aggregate/header rows -- TODO confirm)
        data_dict = pd.read_json('https://comtrade.un.org/data/cache/classificationS2.json')
        data_cross = [item_def["text"].split(" - ", 1)
                      for item_def in list(data_dict["results"])[2:]]

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]

        #get country codes; duplicated under the names used as merge keys below
        self.country_codes = pd.read_excel(os.path.join(self.data_dir, "ISO3166.xlsx"))
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]

        #get trade data for a given year
        trade_data = pd.read_stata(os.path.join(
            self.data_dir,
            "country_partner_sitcproduct4digit_year_" + str(self.year) + ".dta"))
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]],
                              on = ["sitc_product_code"])
        ###select level of product aggregation: first digit of the SITC code
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])

        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data

        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[self._IMPORT_COLUMNS]
        imports1 = imports1[imports1["import_value"] != 0]

        #transform records of exports into imports: swap the two country codes
        #and treat the export value as the partner's import value.
        #NOTE(review): ids, names and ECI/COI columns keep their original
        #orientation, so a mirrored row probably never matches a reported
        #import in drop_duplicates below -- confirm the double counting is intended.
        export_columns = ['export_value' if col == 'import_value' else col
                          for col in self._IMPORT_COLUMNS]
        imports2 = trade_data[export_columns].rename(columns = {
            "location_code": "partner_code",
            "partner_code": "location_code",
            "export_value": "import_value"})
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[self._IMPORT_COLUMNS]

        imports_table = pd.concat([imports1, imports2]).drop_duplicates()

        #rename columns for better clarity
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table.rename(columns = {
            "location_code": "importer_code",
            "partner_code": "exporter_code",
            "country_i": "importer_name",
            "country_j": "exporter_name"})[cols]

        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})

        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})

        #sum imports across all products between countries into single value
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()

        #sum exports in each category
        self.export_types = imports_table.groupby(["exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(self.export_types, exporter_total, on = "exporter_code")
        #multiply by 10 to allow weights to scale better in GNN
        self.export_types["category_fraction"] = self.export_types.import_value/self.export_types.export_total*10
        self.export_types = self.export_types[["exporter_code", "product_category", "category_fraction"]]\
        .pivot(index = ["exporter_code"], columns = ["product_category"], values = "category_fraction")\
        .reset_index().fillna(0)
        #prefix every category column with "resource_"
        self.export_types.columns = [col if col == "exporter_code" else "resource_" + col
                                     for col in self.export_types.columns]
        self.export_types = self.export_types.rename(columns = {"exporter_code": "country_code"})
        self.features = pd.merge(self.features, self.export_types,
                                on = "country_code", how = "left")

        #look at fraction of goods traded with each counterparty
        imports_table_grouped = pd.merge(imports_table_grouped, exporter_total, how = "left")
        imports_table_grouped["export_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["export_total"]
        imports_table_grouped["export_percent_feature"] = self._shifted_log_zscore(
            imports_table_grouped["export_percent"])

        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, how = "left")
        imports_table_grouped["import_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["import_total"]
        imports_table_grouped["import_percent_feature"] = self._shifted_log_zscore(
            imports_table_grouped["import_percent"])

        self.trade_data = imports_table_grouped

    def graph_create(self, exporter = True,
            node_features = ['prev_gdp_growth', 'current_gdp_growth','prev_gdp','current_gdp'],
            node_labels = 'future_gdp_growth'):
        """
        Turn ``self.features`` and ``self.trade_data`` into ``self.pyg_graph``.

        Args:
            exporter: if True, edges point from the importer to the exporter
                and carry ``export_percent``; if False the roles are mirrored
                and edges carry ``import_percent``.
            node_features: columns of ``self.features`` used as node attributes.
            node_labels: column of ``self.features`` used as the node target ``y``.

        Side effects:
            Filters ``self.features`` and ``self.trade_data`` to countries
            present in both, and sets ``node_lookup1``/``node_lookup2``,
            ``regression_table``, ``graph_table``, ``node_attributes``,
            ``source_nodes``, ``target_nodes``, ``edge_attributes``,
            ``brexit_status`` and ``pyg_graph``.
        """
        if(exporter):
            center_node = "exporter_code"
            neighbors = "importer_code"
            edge_features = 'export_percent'
        else:
            #mirrored view; previously these names were left undefined
            center_node = "importer_code"
            neighbors = "exporter_code"
            edge_features = 'import_percent'

        #filter features and nodes to ones that are connected to others in trade data
        list_active_countries = list(set(list(self.trade_data["importer_code"])+\
                        list(self.trade_data["exporter_code"])))

        self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        self.features = self.features.fillna(0)
        self.features["node_numbers"] = self.features.index
        #create lookup dictionaries making node number / node features compatible
        #with the ordering of nodes in our edge table
        self.node_lookup1 = self.features.set_index('node_numbers').to_dict()['country_code']
        self.node_lookup2 = self.features.set_index('country_code').to_dict()['node_numbers']

        #get individual country's features
        self.regression_table = pd.merge(self.features, self.trade_data,
                        left_on = "country_code",
                        right_on = center_node, how = 'right')
        #get features for trade partners
        self.regression_table = pd.merge(self.features, self.regression_table,
                                        left_on = "country_code",
                                        right_on = neighbors, how = "right",
                                        suffixes = ("_neighbors", ""))

        #keep only edges whose two endpoints both have features
        self.trade_data = self.trade_data[self.trade_data[neighbors].isin(self.node_lookup2)]
        self.trade_data = self.trade_data[self.trade_data[center_node].isin(self.node_lookup2)]

        #assigned by index alignment: rows filtered out above become NaN here
        #and are removed by the dropna below
        self.regression_table["source"] = self.trade_data[neighbors].apply(lambda x: self.node_lookup2[x])
        self.regression_table["target"] = self.trade_data[center_node].apply(lambda x: self.node_lookup2[x])

        self.regression_table = self.regression_table.dropna()
        #filter only to relevant columns
        self.relevant_columns = ["source", "target"]
        self.relevant_columns.extend(node_features)
        self.relevant_columns.append(node_labels)
        self.graph_table = self.regression_table[self.relevant_columns]

        if(self.graph_table.isnull().values.any()): print("edges contain null / inf values")

        self.node_attributes = torch.tensor(np.array(self.features[node_features]))\
        .to(torch.float)
        #the aligned assignment above can leave node ids as floats; cast back
        #so edge_index is always an integer tensor
        self.source_nodes = self.graph_table["source"].astype(int).tolist()
        self.target_nodes = self.graph_table["target"].astype(int).tolist()
        #NOTE(review): edge attributes are read from the filtered trade table
        #while the edge index comes from graph_table; they line up only if
        #dropna removed exactly the filtered rows -- worth asserting.
        self.edge_attributes = list(self.trade_data[edge_features])

        #graph-level label: 0 before 2016, 1 from 2016 onwards
        self.brexit_status = 0 if self.year < 2016 else 1

        self.pyg_graph = data.Data(x = self.node_attributes,
                                   edge_index = torch.tensor([self.source_nodes, self.target_nodes],
                                                             dtype = torch.long),
                                   edge_attr = torch.tensor(self.edge_attributes).to(torch.float),
                                   graph_attr = torch.tensor(self.brexit_status),
                                   y = torch.tensor(list(self.features[node_labels])).to(torch.float))

        

# Graph Setup

In [60]:
#Build one graph per year and split the graphs into train / validation / test.
years = range(1990,2021)

val_years = [2017, 2016, 2013, 2000, 1995]
test_years = [2018, 2001, 2002, 1997, 2011]

#node inputs: per-capita growth plus the export mix over the ten product categories
NODE_FEATURES = ['prev_gdp_per_cap_growth', 'current_gdp_per_cap_growth'] + \
    ['resource_' + str(category) for category in range(10)]
NODE_LABEL = 'future_gdp_per_cap_growth'

train_graphs = []
val_graphs = []
test_graphs = []
#per-year feature tables, concatenated once after the loop
yearly_features = []

for year in tqdm(years):
    print(str(year), end='\r')

    trade = TradeNetwork(year = year)
    trade.prepare_features()
    trade.prepare_network()
    trade.graph_create(node_features = NODE_FEATURES, node_labels = NODE_LABEL)

    if(year in val_years):
        val_graphs.append(trade.pyg_graph)
    elif(year in test_years):
        test_graphs.append(trade.pyg_graph)
    else:
        train_graphs.append(trade.pyg_graph)

    trade.features["year"] = year
    yearly_features.append(trade.features)

    print(trade.node_attributes.size())

#a single concat avoids re-copying the growing frame on every iteration
trade_df = pd.concat(yearly_features)

  0%|          | 0/31 [00:00<?, ?it/s]

1990

  3%|▎         | 1/31 [00:10<05:20, 10.69s/it]

torch.Size([141, 12])
1991

  6%|▋         | 2/31 [00:21<05:10, 10.70s/it]

torch.Size([142, 12])
1992

 10%|▉         | 3/31 [00:32<05:07, 10.98s/it]

torch.Size([159, 12])
1993

 13%|█▎        | 4/31 [00:44<05:07, 11.38s/it]

torch.Size([162, 12])
1994

 16%|█▌        | 5/31 [00:57<05:08, 11.85s/it]

torch.Size([166, 12])
1995

 19%|█▉        | 6/31 [01:09<05:01, 12.05s/it]

torch.Size([167, 12])
1996

 23%|██▎       | 7/31 [01:23<05:04, 12.68s/it]

torch.Size([168, 12])
1997

 26%|██▌       | 8/31 [01:37<05:00, 13.07s/it]

torch.Size([177, 12])
1998

 29%|██▉       | 9/31 [01:52<05:02, 13.76s/it]

torch.Size([177, 12])
1999

 32%|███▏      | 10/31 [02:08<04:58, 14.19s/it]

torch.Size([179, 12])
2000

 35%|███▌      | 11/31 [02:23<04:51, 14.56s/it]

torch.Size([184, 12])
2001

 39%|███▊      | 12/31 [02:39<04:45, 15.02s/it]

torch.Size([185, 12])
2002

 42%|████▏     | 13/31 [02:56<04:38, 15.49s/it]

torch.Size([189, 12])
2003

 45%|████▌     | 14/31 [03:13<04:33, 16.08s/it]

torch.Size([190, 12])
2004

 48%|████▊     | 15/31 [03:31<04:26, 16.64s/it]

torch.Size([194, 12])
2005

 52%|█████▏    | 16/31 [03:51<04:24, 17.66s/it]

torch.Size([195, 12])
2006

 55%|█████▍    | 17/31 [04:11<04:16, 18.35s/it]

torch.Size([196, 12])
2007

 58%|█████▊    | 18/31 [04:31<04:04, 18.81s/it]

torch.Size([196, 12])
2008

 61%|██████▏   | 19/31 [04:53<03:56, 19.70s/it]

torch.Size([197, 12])
2009

 65%|██████▍   | 20/31 [05:14<03:41, 20.14s/it]

torch.Size([197, 12])
2010

 68%|██████▊   | 21/31 [05:35<03:24, 20.44s/it]

torch.Size([197, 12])
2011

 71%|███████   | 22/31 [05:55<03:03, 20.43s/it]

torch.Size([197, 12])
2012

 74%|███████▍  | 23/31 [06:17<02:45, 20.64s/it]

torch.Size([200, 12])
2013

 77%|███████▋  | 24/31 [06:38<02:25, 20.83s/it]

torch.Size([201, 12])
2014

 81%|████████  | 25/31 [06:59<02:04, 20.81s/it]

torch.Size([200, 12])
2015

 84%|████████▍ | 26/31 [07:21<01:46, 21.26s/it]

torch.Size([201, 12])
2016

 87%|████████▋ | 27/31 [07:42<01:25, 21.33s/it]

torch.Size([201, 12])
2017

 90%|█████████ | 28/31 [08:04<01:04, 21.51s/it]

torch.Size([201, 12])
2018

 94%|█████████▎| 29/31 [08:27<00:43, 21.80s/it]

torch.Size([199, 12])
2019

 97%|█████████▋| 30/31 [08:52<00:22, 22.88s/it]

torch.Size([197, 12])
2020

100%|██████████| 31/31 [09:14<00:00, 17.90s/it]

torch.Size([192, 12])





In [61]:
#Merge the graphs of each split into one batch object.
#from_data_list is a classmethod, so no throwaway Batch() instance is needed.
test_batch = data.Batch.from_data_list(test_graphs)
val_batch = data.Batch.from_data_list(val_graphs)
train_batch = data.Batch.from_data_list(train_graphs)

In [62]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch_geometric.transforms as T

In [63]:
from torch_geometric.nn import global_mean_pool
from torch.nn import Linear

class GCN(torch.nn.Module):
    """
    Graph-level classifier: three GCNConv layers, mean pooling over the
    nodes of each graph, dropout, and a linear output layer.

    Args:
        num_features: number of input features per node.
        hidden_channels: width of the three convolution layers.
        num_classes: number of output logits per graph (default 2).
    """

    def __init__(self, num_features, hidden_channels, num_classes = 2):
        super().__init__()
        self.num_features = num_features
        self.conv1 = GCNConv(self.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """
        Args:
            x: node feature matrix.
            edge_index: edge list passed to every convolution layer.
            batch: graph assignment of each node, used by the mean pooling.

        Returns:
            One row of logits per graph in the batch.
        """
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier (dropout is only active in training mode)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

    def reset(self):
        """Re-initialise the parameters of every layer of the model."""
        # the original referenced a non-existent `linear2` and skipped
        # conv2 / conv3 / lin
        for layer in (self.conv1, self.conv2, self.conv3, self.lin):
            layer.reset_parameters()

In [68]:
#Train the graph classifier on the training batch and track the best
#validation loss seen so far.

#CUDA is deliberately not used: the model and all batches stay on the CPU
device = "cpu"

loss_fn = torch.nn.CrossEntropyLoss()  # classification loss on the graph label
model = GCN(num_features = train_batch.num_features, hidden_channels = 64).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
train_loss = []
valid_loss = []
best_val_loss = np.inf
best_model = None
best_weights = None
best_iteration = None

#edge indices must be int64; converting once up front is enough
for graph_batch in (train_batch, val_batch, test_batch):
    graph_batch.edge_index = graph_batch.edge_index.long()

for epoch in trange(10):
    #one full-batch optimisation step
    model.train()
    optimizer.zero_grad()

    out = model(train_batch.x, train_batch.edge_index, train_batch.batch)
    loss = loss_fn(out, train_batch.graph_attr)
    train_loss.append(loss.item())

    loss.backward()
    optimizer.step()

    #validation: no gradients needed, dropout disabled by eval()
    model.eval()
    with torch.no_grad():
        out_val = model(val_batch.x, val_batch.edge_index, val_batch.batch)
        current_val_loss = loss_fn(out_val, val_batch.graph_attr).item()
    valid_loss.append(current_val_loss)

    # Calculate accuracy
    pred = out_val.argmax(dim=1)
    correct = float(pred.eq(val_batch.graph_attr).sum().item())
    acc = correct / len(val_batch.graph_attr)
    print('Accuracy: {:.4f}'.format(acc))

    #keep a snapshot of the best weights; state_dict() returns references to
    #the live parameters, so they must be cloned to survive later updates
    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        best_weights = {name: tensor.detach().clone()
                        for name, tensor in model.state_dict().items()}
        best_iteration = epoch

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:00<00:02,  3.07it/s]

Accuracy: 0.6000


 20%|██        | 2/10 [00:00<00:02,  2.85it/s]

Accuracy: 0.6000


 30%|███       | 3/10 [00:01<00:02,  2.80it/s]

Accuracy: 0.6000


 40%|████      | 4/10 [00:01<00:02,  2.93it/s]

Accuracy: 0.6000


 50%|█████     | 5/10 [00:01<00:01,  3.23it/s]

Accuracy: 0.6000


 60%|██████    | 6/10 [00:01<00:01,  3.06it/s]

Accuracy: 0.6000


 70%|███████   | 7/10 [00:02<00:00,  3.16it/s]

Accuracy: 0.6000


 80%|████████  | 8/10 [00:02<00:00,  3.26it/s]

Accuracy: 0.6000


 90%|█████████ | 9/10 [00:02<00:00,  3.45it/s]

Accuracy: 0.6000


100%|██████████| 10/10 [00:03<00:00,  3.21it/s]

Accuracy: 0.6000



