In [264]:
#imports 
import pandas as pd
import numpy as np
import os
import pickle as pkl
import datetime as datetime
from sklearn.preprocessing import StandardScaler
import statsmodels.formula.api as sm
import dgl.function as fn
from tqdm import tqdm
import networkx as nx
from src.utils.TradeNetwork import TradeNetwork

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.pygcn.SiameseGNN import SiameseGNN

#imports for graph creation
import torch
from sklearn.preprocessing import StandardScaler
from itertools import combinations
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

#imports for graph learning
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
from tqdm import trange
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms

## All Events

In [265]:
years = range(1962,2019)

train_years = [2005, 1969, 2002, 1997, 1993, 1982, 2001, 2000, 1962, 1985, 1978, 2016, 1986, 1987, 1989, 1971, 2013, 1996, 1995, 1967, 2017, 1974, 1990, 1977, 1980, 2014, 1965, 1984, 2006, 1973, 1968, 1981, 1970, 1991]
val_years = [1975, 1983, 2009, 1966, 1999, 1988, 2007, 1979, 1972, 2015, 2003]
test_years = [1963, 1964, 1976, 1992, 1994, 1998, 2004, 2008, 2010, 2011, 2012, 2018]

In [274]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [275]:
#12 features
import pickle as pkl
with open("src/pygcn/train_graphs.pickle", "rb") as f:
    train_graphs = pkl.load(f)

with open("src/pygcn/val_graphs.pickle", "rb") as f:  
    val_graphs = pkl.load(f)

with open("src/pygcn/test_graphs.pickle", "rb") as f:         
    test_graphs = pkl.load(f)

In [276]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=4)
train_loader = DataLoader(train_graphs, batch_size=4)
val_loader = DataLoader(val_graphs, batch_size=4)



## sGNN with GCN Encoder and 3 Features

In [277]:
def check_crisis_years(year_pairs, crisis_years):
    result = []
    for pair in year_pairs:
        start, end = pair
        # Check if any crisis year is between the pair or equals the later year
        if any(start < year <= end for year in crisis_years):
            result.append(0)
        else:
            result.append(1)
    return result

In [278]:
crisis_years = [1983, 1982, 2008, 2002, 2016, 1967, 1962, 1989, 2012, 1963, 1993, 1986, 1996,1978]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

def get_graph_pairs(graphs):
    return [(graphs[i], graphs[j]) for i in range(len(graphs)) for j in range(len(graphs)) if j >= i]

train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [280]:
train_graph_pairs = get_graph_pairs(train_graphs)
val_graph_pairs = get_graph_pairs(val_graphs)

train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

labeled_pairs_train = list(zip(train_loader_pairs, train_y))
labeled_pairs_val = list(zip(val_loader_pairs, val_y))

flattened_train = [(a, b, c) for ((a, b), c) in labeled_pairs_train]
flattened_val  = [(a, b, c) for ((a, b), c) in labeled_pairs_val]

In [281]:
import random
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [283]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.pygcn.SiameseGNN import SiameseGNN

model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
criterion = nn.BCELoss()

for epoch in tqdm(range(20)):
    model.train()
    train_losses = []
    for data1, data2, label in balanced_data:

        optimizer.zero_grad()
        out = model(data1, data2)
        label = torch.tensor(label).view(1).float()
        loss = criterion(out.squeeze(0), label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()  # Add this line to update the learning rate

    model.eval()
    with torch.no_grad():
        val_losses = []
        correct = 0
        total = 0
        for data1, data2, label in flattened_val:
            out = model(data1, data2)
            label = torch.tensor(label).view(1).float()
            val_loss = criterion(out.squeeze(0), label)
            val_losses.append(val_loss.item())

            predictions = torch.round(out.squeeze())
            correct += (predictions == label).sum().item()
            total += 1

        val_loss = sum(val_losses) / len(val_losses)
        val_accuracy = correct / total

    print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

  5%|▌         | 1/20 [00:23<07:33, 23.89s/it]

Epoch: 1, Training Loss: 0.704062089469233, Validation Loss: 0.6934422326810432, Validation Accuracy: 0.4090909090909091


 10%|█         | 2/20 [00:46<06:53, 22.98s/it]

Epoch: 2, Training Loss: 0.6994378892940852, Validation Loss: 0.6963548895084497, Validation Accuracy: 0.3787878787878788


 15%|█▌        | 3/20 [01:07<06:20, 22.40s/it]

Epoch: 3, Training Loss: 0.6877583678644767, Validation Loss: 0.7227827464089249, Validation Accuracy: 0.4090909090909091


 20%|██        | 4/20 [01:29<05:50, 21.93s/it]

Epoch: 4, Training Loss: 0.6522456469090535, Validation Loss: 0.6409104231632117, Validation Accuracy: 0.7272727272727273


 25%|██▌       | 5/20 [01:51<05:28, 21.92s/it]

Epoch: 5, Training Loss: 0.637310588467307, Validation Loss: 0.6458863381183508, Validation Accuracy: 0.696969696969697


 30%|███       | 6/20 [02:16<05:23, 23.14s/it]

Epoch: 6, Training Loss: 0.6231229789881852, Validation Loss: 0.6498651793508818, Validation Accuracy: 0.6818181818181818


 35%|███▌      | 7/20 [02:39<05:00, 23.09s/it]

Epoch: 7, Training Loss: 0.6114073890310369, Validation Loss: 0.6658673485120138, Validation Accuracy: 0.6363636363636364


 40%|████      | 8/20 [03:00<04:30, 22.54s/it]

Epoch: 8, Training Loss: 0.588663114874749, Validation Loss: 0.6269845045877226, Validation Accuracy: 0.6515151515151515


 45%|████▌     | 9/20 [03:22<04:05, 22.30s/it]

Epoch: 9, Training Loss: 0.5654159297460425, Validation Loss: 0.6164591463677811, Validation Accuracy: 0.6818181818181818


 50%|█████     | 10/20 [03:44<03:40, 22.00s/it]

Epoch: 10, Training Loss: 0.5310307302150048, Validation Loss: 0.6518444781276312, Validation Accuracy: 0.6363636363636364


 55%|█████▌    | 11/20 [04:05<03:16, 21.79s/it]

Epoch: 11, Training Loss: 0.5181332451543729, Validation Loss: 0.5604770633420257, Validation Accuracy: 0.696969696969697


 60%|██████    | 12/20 [04:28<02:56, 22.11s/it]

Epoch: 12, Training Loss: 0.46514302408086683, Validation Loss: 0.5806865897029638, Validation Accuracy: 0.696969696969697


 65%|██████▌   | 13/20 [04:52<02:38, 22.65s/it]

Epoch: 13, Training Loss: 0.4379944432096567, Validation Loss: 0.6200287806603945, Validation Accuracy: 0.696969696969697


 70%|███████   | 14/20 [05:14<02:15, 22.60s/it]

Epoch: 14, Training Loss: 0.4174281878925163, Validation Loss: 0.6598478405256615, Validation Accuracy: 0.696969696969697


 75%|███████▌  | 15/20 [05:36<01:52, 22.41s/it]

Epoch: 15, Training Loss: 0.39920635540655597, Validation Loss: 0.6995477478132781, Validation Accuracy: 0.696969696969697


 80%|████████  | 16/20 [05:57<01:27, 21.93s/it]

Epoch: 16, Training Loss: 0.38138226939448777, Validation Loss: 0.7474360028564027, Validation Accuracy: 0.7121212121212122


 85%|████████▌ | 17/20 [06:18<01:04, 21.64s/it]

Epoch: 17, Training Loss: 0.36306861417638553, Validation Loss: 0.7844766561562816, Validation Accuracy: 0.7121212121212122


 90%|█████████ | 18/20 [06:39<00:42, 21.41s/it]

Epoch: 18, Training Loss: 0.3464847946662971, Validation Loss: 0.8321419476362114, Validation Accuracy: 0.7121212121212122


 95%|█████████▌| 19/20 [07:00<00:21, 21.27s/it]

Epoch: 19, Training Loss: 0.3328061409255573, Validation Loss: 0.8615875691175461, Validation Accuracy: 0.7121212121212122


100%|██████████| 20/20 [07:21<00:00, 22.09s/it]

Epoch: 20, Training Loss: 0.3174046507292676, Validation Loss: 0.8808878727934577, Validation Accuracy: 0.7272727272727273





## sGNN with Feature Subset

In [284]:
with open("feature_dicts/filtered_features_dict.pkl", "rb") as f:
    feat_dict = pkl.load(f)

In [287]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=32)
train_loader = DataLoader(train_graphs, batch_size=32)
val_loader = DataLoader(val_graphs, batch_size=32)



In [288]:
def add_features(years, graphs):

    zeros = torch.zeros(59)

    for i in range(len(years)):
        new_x = torch.empty(0, 59)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [289]:
train_graphs = add_features(train_years, train_graphs)
val_graphs = add_features(val_years, val_graphs)
test_graphs = add_features(test_years, test_graphs)

In [290]:
crisis_years = [1983, 1982, 2008, 2002, 2016, 1967, 1962, 1989, 2012, 1963, 1993, 1986, 1996,1978]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [291]:
train_graph_pairs = get_graph_pairs(train_graphs)
val_graph_pairs = get_graph_pairs(val_graphs)

train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

labeled_pairs_train = list(zip(train_loader_pairs, train_y))
labeled_pairs_val = list(zip(val_loader_pairs, val_y))

flattened_train = [(a, b, c) for ((a, b), c) in labeled_pairs_train]
flattened_val  = [(a, b, c) for ((a, b), c) in labeled_pairs_val]

In [None]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [292]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
criterion = nn.BCELoss()

for epoch in tqdm(range(20)):
    model.train()
    train_losses = []
    for data1, data2, label in balanced_data:

        optimizer.zero_grad()
        out = model(data1, data2)
        label = torch.tensor(label).view(1).float()
        loss = criterion(out.squeeze(0), label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()  # Add this line to update the learning rate

    model.eval()
    with torch.no_grad():
        val_losses = []
        correct = 0
        total = 0
        for data1, data2, label in flattened_val:
            out = model(data1, data2)
            label = torch.tensor(label).view(1).float()
            val_loss = criterion(out.squeeze(0), label)
            val_losses.append(val_loss.item())

            predictions = torch.round(out.squeeze())
            correct += (predictions == label).sum().item()
            total += 1

        val_loss = sum(val_losses) / len(val_losses)
        val_accuracy = correct / total

    print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

  5%|▌         | 1/20 [00:20<06:38, 20.97s/it]

Epoch: 1, Training Loss: 0.705197432643024, Validation Loss: 0.6934435530142351, Validation Accuracy: 0.45454545454545453


 10%|█         | 2/20 [00:41<06:12, 20.71s/it]

Epoch: 2, Training Loss: 0.6982349139033702, Validation Loss: 0.6935115113402858, Validation Accuracy: 0.5


 15%|█▌        | 3/20 [01:03<06:01, 21.27s/it]

Epoch: 3, Training Loss: 0.6919354542392736, Validation Loss: 0.6994528892365369, Validation Accuracy: 0.4696969696969697


 20%|██        | 4/20 [01:22<05:25, 20.37s/it]

Epoch: 4, Training Loss: 0.6775392101783502, Validation Loss: 0.7219479332367579, Validation Accuracy: 0.4090909090909091


 25%|██▌       | 5/20 [01:41<04:58, 19.90s/it]

Epoch: 5, Training Loss: 0.6562067090622631, Validation Loss: 0.7451615213896289, Validation Accuracy: 0.4090909090909091


 30%|███       | 6/20 [02:03<04:50, 20.76s/it]

Epoch: 6, Training Loss: 0.6296549859306037, Validation Loss: 0.6788908785039728, Validation Accuracy: 0.48484848484848486


 35%|███▌      | 7/20 [02:24<04:30, 20.83s/it]

Epoch: 7, Training Loss: 0.5998527040672407, Validation Loss: 0.6291101066903635, Validation Accuracy: 0.5606060606060606


 40%|████      | 8/20 [02:45<04:10, 20.84s/it]

Epoch: 8, Training Loss: 0.6100037103195695, Validation Loss: 0.7219045534729958, Validation Accuracy: 0.4696969696969697


 45%|████▌     | 9/20 [03:07<03:52, 21.09s/it]

Epoch: 9, Training Loss: 0.5723828582645014, Validation Loss: 0.7478801972712531, Validation Accuracy: 0.48484848484848486


 50%|█████     | 10/20 [03:28<03:31, 21.16s/it]

Epoch: 10, Training Loss: 0.5469491241688588, Validation Loss: 0.7148880430243232, Validation Accuracy: 0.5


 55%|█████▌    | 11/20 [03:50<03:11, 21.32s/it]

Epoch: 11, Training Loss: 0.5133549821169113, Validation Loss: 0.6671838763198166, Validation Accuracy: 0.5606060606060606


 60%|██████    | 12/20 [04:10<02:47, 20.93s/it]

Epoch: 12, Training Loss: 0.49131244545544683, Validation Loss: 0.6794764354655688, Validation Accuracy: 0.5303030303030303


 65%|██████▌   | 13/20 [04:29<02:21, 20.28s/it]

Epoch: 13, Training Loss: 0.47331837278774436, Validation Loss: 0.6848389165983959, Validation Accuracy: 0.5151515151515151


 70%|███████   | 14/20 [04:49<02:01, 20.23s/it]

Epoch: 14, Training Loss: 0.45888525391175444, Validation Loss: 0.6803629942464106, Validation Accuracy: 0.5454545454545454


 75%|███████▌  | 15/20 [05:09<01:41, 20.27s/it]

Epoch: 15, Training Loss: 0.44585570467537955, Validation Loss: 0.6846542558350572, Validation Accuracy: 0.5454545454545454


 80%|████████  | 16/20 [05:32<01:23, 20.89s/it]

Epoch: 16, Training Loss: 0.43396994381050125, Validation Loss: 0.6893549072415088, Validation Accuracy: 0.5757575757575758


 85%|████████▌ | 17/20 [05:53<01:03, 21.11s/it]

Epoch: 17, Training Loss: 0.4220145226285573, Validation Loss: 0.6958202309200935, Validation Accuracy: 0.5757575757575758


 90%|█████████ | 18/20 [06:13<00:41, 20.83s/it]

Epoch: 18, Training Loss: 0.41153957994271717, Validation Loss: 0.7071417463858697, Validation Accuracy: 0.5757575757575758


 95%|█████████▌| 19/20 [06:36<00:21, 21.31s/it]

Epoch: 19, Training Loss: 0.4010203088666394, Validation Loss: 0.7290616339129029, Validation Accuracy: 0.5757575757575758


100%|██████████| 20/20 [06:56<00:00, 20.83s/it]

Epoch: 20, Training Loss: 0.3914697464345818, Validation Loss: 0.7450493534984575, Validation Accuracy: 0.5909090909090909





## Random Feature Subset

In [293]:
with open("feature_dicts/random_features_dict.pkl", "rb") as f:
    feat_dict_random = pkl.load(f)

In [294]:
def add_features(years, graphs, feat_dict):

    zeros = torch.zeros(434)

    for i in range(len(years)):
        new_x = torch.empty(0, 434)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [295]:
with open("src/pygcn/train_graphs.pickle", "rb") as f:
    train_graphs = pkl.load(f)

with open("src/pygcn/val_graphs.pickle", "rb") as f:  
    val_graphs = pkl.load(f)

with open("src/pygcn/test_graphs.pickle", "rb") as f:         
    test_graphs = pkl.load(f)

In [297]:
train_graphs = add_features(train_years, train_graphs, feat_dict_random)
val_graphs = add_features(val_years, val_graphs, feat_dict_random)
test_graphs = add_features(test_years, test_graphs, feat_dict_random)

In [298]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=32)
train_loader = DataLoader(train_graphs, batch_size=32)
val_loader = DataLoader(val_graphs, batch_size=32)



In [299]:
def check_crisis_years(year_pairs, crisis_years):
    result = []
    for pair in year_pairs:
        start, end = pair
        # Check if any crisis year is between the pair or equals the later year
        if any(start < year <= end for year in crisis_years):
            result.append(0)
        else:
            result.append(1)
    return result

In [300]:
crisis_years = [1983, 1982, 2008, 2002, 2016, 1967, 1962, 1989, 2012, 1963, 1993, 1986, 1996,1978]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [301]:
train_graph_pairs = get_graph_pairs(train_graphs)
val_graph_pairs = get_graph_pairs(val_graphs)

train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

labeled_pairs_train = list(zip(train_loader_pairs, train_y))
labeled_pairs_val = list(zip(val_loader_pairs, val_y))

flattened_train = [(a, b, c) for ((a, b), c) in labeled_pairs_train]
flattened_val  = [(a, b, c) for ((a, b), c) in labeled_pairs_val]

In [305]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [307]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
criterion = nn.BCELoss()

for epoch in tqdm(range(20)):
    model.train()
    train_losses = []
    for data1, data2, label in balanced_data:

        optimizer.zero_grad()
        out = model(data1, data2)
        label = torch.tensor(label).view(1).float()
        loss = criterion(out.squeeze(0), label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()  # Add this line to update the learning rate

    model.eval()
    with torch.no_grad():
        val_losses = []
        correct = 0
        total = 0
        for data1, data2, label in flattened_val:
            out = model(data1, data2)
            label = torch.tensor(label).view(1).float()
            val_loss = criterion(out.squeeze(0), label)
            val_losses.append(val_loss.item())

            predictions = torch.round(out.squeeze())
            correct += (predictions == label).sum().item()
            total += 1

        val_loss = sum(val_losses) / len(val_losses)
        val_accuracy = correct / total

    print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

  5%|▌         | 1/20 [00:22<07:14, 22.85s/it]

Epoch: 1, Training Loss: 0.7009320919044302, Validation Loss: 0.583410073410381, Validation Accuracy: 0.803030303030303


 10%|█         | 2/20 [00:44<06:39, 22.22s/it]

Epoch: 2, Training Loss: 0.6978037488332733, Validation Loss: 0.5958270236398234, Validation Accuracy: 0.803030303030303


 15%|█▌        | 3/20 [01:05<06:06, 21.58s/it]

Epoch: 3, Training Loss: 0.6965184409832165, Validation Loss: 0.60185906291008, Validation Accuracy: 0.803030303030303


 20%|██        | 4/20 [01:26<05:40, 21.25s/it]

Epoch: 4, Training Loss: 0.6960067802470098, Validation Loss: 0.6010324620839321, Validation Accuracy: 0.803030303030303


 25%|██▌       | 5/20 [01:48<05:26, 21.77s/it]

Epoch: 5, Training Loss: 0.6930763921019627, Validation Loss: 0.60172654281963, Validation Accuracy: 0.803030303030303


 30%|███       | 6/20 [02:10<05:05, 21.84s/it]

Epoch: 6, Training Loss: 0.6873619026898409, Validation Loss: 0.5911647474223917, Validation Accuracy: 0.7424242424242424


 35%|███▌      | 7/20 [02:31<04:39, 21.48s/it]

Epoch: 7, Training Loss: 0.6557190062416459, Validation Loss: 0.5756148768193794, Validation Accuracy: 0.7575757575757576


 40%|████      | 8/20 [02:53<04:20, 21.67s/it]

Epoch: 8, Training Loss: 0.6383576422973335, Validation Loss: 0.5633337865724708, Validation Accuracy: 0.7424242424242424


 45%|████▌     | 9/20 [03:13<03:53, 21.23s/it]

Epoch: 9, Training Loss: 0.6208129173024879, Validation Loss: 0.5345934942138918, Validation Accuracy: 0.7575757575757576


 50%|█████     | 10/20 [03:34<03:29, 20.95s/it]

Epoch: 10, Training Loss: 0.5915828320691562, Validation Loss: 0.5601606108248234, Validation Accuracy: 0.7575757575757576


 55%|█████▌    | 11/20 [03:54<03:06, 20.76s/it]

Epoch: 11, Training Loss: 0.5587588943265103, Validation Loss: 0.5569406882154219, Validation Accuracy: 0.7575757575757576


 60%|██████    | 12/20 [04:15<02:45, 20.67s/it]

Epoch: 12, Training Loss: 0.5368331838943805, Validation Loss: 0.5625806848208109, Validation Accuracy: 0.7575757575757576


 65%|██████▌   | 13/20 [04:35<02:25, 20.73s/it]

Epoch: 13, Training Loss: 0.5228853276812746, Validation Loss: 0.5644193797400503, Validation Accuracy: 0.7575757575757576


 70%|███████   | 14/20 [04:56<02:03, 20.61s/it]

Epoch: 14, Training Loss: 0.5105333212303764, Validation Loss: 0.5662885615884354, Validation Accuracy: 0.7575757575757576


 75%|███████▌  | 15/20 [05:17<01:44, 20.82s/it]

Epoch: 15, Training Loss: 0.4977249724901196, Validation Loss: 0.5762461281635545, Validation Accuracy: 0.7575757575757576


 80%|████████  | 16/20 [05:37<01:22, 20.60s/it]

Epoch: 16, Training Loss: 0.48583462745673794, Validation Loss: 0.5960603285806648, Validation Accuracy: 0.7575757575757576


 85%|████████▌ | 17/20 [05:58<01:01, 20.60s/it]

Epoch: 17, Training Loss: 0.47470964914819563, Validation Loss: 0.6248706192709506, Validation Accuracy: 0.7575757575757576


 90%|█████████ | 18/20 [06:18<00:41, 20.55s/it]

Epoch: 18, Training Loss: 0.4630036315495842, Validation Loss: 0.6660098929635503, Validation Accuracy: 0.7575757575757576


 95%|█████████▌| 19/20 [06:40<00:20, 20.95s/it]

Epoch: 19, Training Loss: 0.45183994792325544, Validation Loss: 0.7102587904470662, Validation Accuracy: 0.7575757575757576


100%|██████████| 20/20 [07:01<00:00, 21.09s/it]

Epoch: 20, Training Loss: 0.44107343207910427, Validation Loss: 0.7473846449178051, Validation Accuracy: 0.7575757575757576



