In [1]:
import pandas as pd
import numpy as np
import pickle as pkl
import datetime as datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.utils.CreateFeatures import CreateFeatures
from src.pygcn.SiameseGNN import SiameseGNN
from src.pygcn.GraphSAGE import SiameseGNN_SAGE

import torch
import torch.nn as nn
import torch
import torch_geometric.data as data

## All Events

In [2]:
years = range(1962,2019)

In [3]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [28]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [29]:
from torch_geometric.data import DataLoader
all_loader = DataLoader(all_graphs, batch_size=32)

## sGNN with GCN Encoder and 3 Features

In [30]:
def check_crisis_years(year_pairs, crisis_years):
    result = []
    for pair in year_pairs:
        start, end = pair
        # Check if any crisis year is between the pair or equals the later year
        if any(start < year <= end for year in crisis_years):
            result.append(0)
        else:
            result.append(1)
    return result

In [31]:
crisis_years = [1962, 1963, 1967, 1978, 1982, 1983, 1986, 1989, 1993, 1996, 2002, 2008, 2012, 2016]

def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 >= year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j >= i]

def get_graph_pairs(graphs):
    return [(graphs[i], graphs[j]) for i in range(len(graphs)) for j in range(len(graphs)) if j >= i]

all_pairs = get_year_pairs(years)
all_y = check_crisis_years(all_pairs, crisis_years)
all_loader_pairs = get_loader_pairs(all_loader.dataset)

In [32]:
all_graph_pairs = get_graph_pairs(all_graphs)
all_torch_y = torch.tensor(np.array(all_y))
labeled_pairs_all = list(zip(all_loader_pairs, all_y))
flattened_all = [(a, b, c) for ((a, b), c) in labeled_pairs_all]

In [33]:
from sklearn.model_selection import train_test_split
flattened_train, flattened_test = train_test_split(flattened_all, test_size=0.35, random_state=42)
flattened_test, flattened_val = train_test_split(flattened_test, test_size=0.5, random_state=42)

In [34]:
import random
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = flattened_train

random.shuffle(balanced_data)

In [35]:
def run_model(model, train_data, val_data):
    torch.manual_seed(42)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
    criterion = nn.BCELoss()

    for epoch in tqdm(range(10)):
        model.train()
        train_losses = []
        for data1, data2, label in train_data:

            optimizer.zero_grad()
            out = model(data1, data2)
            label = torch.tensor(label).view(1).float()
            loss = criterion(out.squeeze(0), label)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()  # Add this line to update the learning rate

        model.eval()
        with torch.no_grad():
            val_losses = []
            correct = 0
            total = 0
            for data1, data2, label in val_data:
                out = model(data1, data2)
                label = torch.tensor(label).view(1).float()
                val_loss = criterion(out.squeeze(0), label)
                val_losses.append(val_loss.item())

                predictions = torch.round(out.squeeze())
                correct += (predictions == label).sum().item()
                total += 1

            val_loss = sum(val_losses) / len(val_losses)
            val_accuracy = correct / total

        print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

In [36]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:42<06:25, 42.83s/it]

Epoch: 1, Training Loss: 0.5324986961623943, Validation Loss: 0.24172229642883458, Validation Accuracy: 0.8931034482758621


 20%|██        | 2/10 [01:24<05:36, 42.09s/it]

Epoch: 2, Training Loss: 0.34540652826681006, Validation Loss: 0.47851934396372786, Validation Accuracy: 0.7620689655172413


 30%|███       | 3/10 [02:08<05:02, 43.21s/it]

Epoch: 3, Training Loss: 0.2640163874425723, Validation Loss: 0.3076775413456148, Validation Accuracy: 0.8620689655172413


 40%|████      | 4/10 [02:51<04:17, 42.97s/it]

Epoch: 4, Training Loss: 0.22656424609292491, Validation Loss: 0.2612652181511231, Validation Accuracy: 0.8896551724137931


 50%|█████     | 5/10 [03:33<03:32, 42.47s/it]

Epoch: 5, Training Loss: 0.19669847885160235, Validation Loss: 0.2364403671091262, Validation Accuracy: 0.896551724137931


 60%|██████    | 6/10 [04:15<02:50, 42.55s/it]

Epoch: 6, Training Loss: 0.20064263855522055, Validation Loss: 0.3099876101979793, Validation Accuracy: 0.8482758620689655


 70%|███████   | 7/10 [04:57<02:07, 42.37s/it]

Epoch: 7, Training Loss: 0.17827085249909772, Validation Loss: 0.37610372780533186, Validation Accuracy: 0.8275862068965517


 80%|████████  | 8/10 [05:40<01:24, 42.34s/it]

Epoch: 8, Training Loss: 0.15876027569914414, Validation Loss: 0.23900277402436618, Validation Accuracy: 0.9


 90%|█████████ | 9/10 [06:23<00:42, 42.82s/it]

Epoch: 9, Training Loss: 0.15665681889856065, Validation Loss: 0.2453529735647768, Validation Accuracy: 0.9


100%|██████████| 10/10 [07:06<00:00, 42.67s/it]

Epoch: 10, Training Loss: 0.14069674108932737, Validation Loss: 0.11532949556391847, Validation Accuracy: 0.9482758620689655





In [37]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:30<04:32, 30.25s/it]

Epoch: 1, Training Loss: 0.44167864507577126, Validation Loss: 0.3648904321008715, Validation Accuracy: 0.8482758620689655


 20%|██        | 2/10 [01:01<04:08, 31.01s/it]

Epoch: 2, Training Loss: 0.2686441535188953, Validation Loss: 0.3053710645032597, Validation Accuracy: 0.8724137931034482


 30%|███       | 3/10 [01:31<03:34, 30.61s/it]

Epoch: 3, Training Loss: 0.2103439143094317, Validation Loss: 0.2640289279728614, Validation Accuracy: 0.896551724137931


 40%|████      | 4/10 [02:03<03:04, 30.83s/it]

Epoch: 4, Training Loss: 0.16205016073159098, Validation Loss: 0.15930006205393324, Validation Accuracy: 0.9413793103448276


 50%|█████     | 5/10 [02:34<02:34, 30.87s/it]

Epoch: 5, Training Loss: 0.18180875047202816, Validation Loss: 0.12393115296124899, Validation Accuracy: 0.9551724137931035


 60%|██████    | 6/10 [03:05<02:03, 30.98s/it]

Epoch: 6, Training Loss: 0.1398240532312088, Validation Loss: 0.11266487801110307, Validation Accuracy: 0.9482758620689655


 70%|███████   | 7/10 [03:36<01:33, 31.16s/it]

Epoch: 7, Training Loss: 0.1301676899483466, Validation Loss: 0.15107364359167244, Validation Accuracy: 0.9344827586206896


 80%|████████  | 8/10 [04:07<01:02, 31.06s/it]

Epoch: 8, Training Loss: 0.12360062082763279, Validation Loss: 0.06999885823762302, Validation Accuracy: 0.9689655172413794


 90%|█████████ | 9/10 [04:46<00:33, 33.45s/it]

Epoch: 9, Training Loss: 0.10182513446907665, Validation Loss: 0.0783726198305697, Validation Accuracy: 0.9620689655172414


100%|██████████| 10/10 [05:16<00:00, 31.69s/it]

Epoch: 10, Training Loss: 0.10328742304409388, Validation Loss: 0.04297151446908488, Validation Accuracy: 0.9758620689655172





## sGNN with Feature Subset

In [15]:
with open("feature_dicts/filtered_features_dict.pkl", "rb") as f:
    feat_dict = pkl.load(f)

In [16]:
all_loader = DataLoader(all_graphs, batch_size=4)



In [17]:
def add_features(years, graphs, feat_dict, dim):

    zeros = torch.zeros(dim)

    for i in range(len(years)):
        new_x = torch.empty(0, dim)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [18]:
all_graphs = add_features(years, all_graphs, feat_dict, 59)

In [19]:
all_pairs = get_year_pairs(years)
all_y = check_crisis_years(all_pairs, crisis_years)
all_loader_pairs = get_loader_pairs(all_loader.dataset)

In [20]:
all_graph_pairs = get_graph_pairs(all_graphs)
all_torch_y = torch.tensor(np.array(all_y))
labeled_pairs_all = list(zip(all_loader_pairs, all_y))
flattened_all = [(a, b, c) for ((a, b), c) in labeled_pairs_all]

In [22]:
flattened_train, flattened_test = train_test_split(flattened_all, test_size=0.35, random_state=42)
flattened_test, flattened_val = train_test_split(flattened_test, test_size=0.5, random_state=42)

In [23]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [24]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:43<06:34, 43.86s/it]

Epoch: 1, Training Loss: 0.4930234749142575, Validation Loss: 0.37312743186837793, Validation Accuracy: 0.803030303030303


 20%|██        | 2/10 [01:24<05:36, 42.05s/it]

Epoch: 2, Training Loss: 0.27495769605077264, Validation Loss: 0.2724040243153771, Validation Accuracy: 0.8636363636363636


 30%|███       | 3/10 [02:05<04:51, 41.64s/it]

Epoch: 3, Training Loss: 0.2248954466960482, Validation Loss: 0.28988171698469106, Validation Accuracy: 0.8727272727272727


 40%|████      | 4/10 [02:46<04:07, 41.31s/it]

Epoch: 4, Training Loss: 0.1896331178642596, Validation Loss: 0.2617701005042446, Validation Accuracy: 0.8787878787878788


 50%|█████     | 5/10 [03:27<03:26, 41.22s/it]

Epoch: 5, Training Loss: 0.18098452751793082, Validation Loss: 0.23646158467957074, Validation Accuracy: 0.9030303030303031


 60%|██████    | 6/10 [04:08<02:44, 41.11s/it]

Epoch: 6, Training Loss: 0.16029696564362297, Validation Loss: 0.16636041654256228, Validation Accuracy: 0.9212121212121213


 70%|███████   | 7/10 [04:49<02:03, 41.10s/it]

Epoch: 7, Training Loss: 0.1376725297154754, Validation Loss: 0.1458797445155638, Validation Accuracy: 0.9212121212121213


 80%|████████  | 8/10 [05:31<01:22, 41.47s/it]

Epoch: 8, Training Loss: 0.11100866558417051, Validation Loss: 0.17651364895423421, Validation Accuracy: 0.9484848484848485


 90%|█████████ | 9/10 [06:13<00:41, 41.51s/it]

Epoch: 9, Training Loss: 0.15310454724894348, Validation Loss: 0.2425753045432044, Validation Accuracy: 0.9181818181818182


100%|██████████| 10/10 [06:54<00:00, 41.46s/it]

Epoch: 10, Training Loss: 0.19624207845531394, Validation Loss: 0.21172594710616535, Validation Accuracy: 0.9393939393939394





In [42]:
torch.save(model.state_dict(), "src/pygcn/siamese_gnn_gcn_mis.pt")

In [25]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:34<05:07, 34.16s/it]

Epoch: 1, Training Loss: 0.41258192223040735, Validation Loss: 0.44412957070338904, Validation Accuracy: 0.8303030303030303


 20%|██        | 2/10 [01:06<04:26, 33.31s/it]

Epoch: 2, Training Loss: 0.20428897188329118, Validation Loss: 0.3490122114237624, Validation Accuracy: 0.8696969696969697


 30%|███       | 3/10 [01:39<03:50, 32.98s/it]

Epoch: 3, Training Loss: 0.14108232489250372, Validation Loss: 0.3122414473691165, Validation Accuracy: 0.906060606060606


 40%|████      | 4/10 [02:11<03:16, 32.76s/it]

Epoch: 4, Training Loss: 0.12607321945552596, Validation Loss: 0.323149046525853, Validation Accuracy: 0.8939393939393939


 50%|█████     | 5/10 [02:45<02:44, 32.99s/it]

Epoch: 5, Training Loss: 0.09868617400928416, Validation Loss: 0.26738550154682755, Validation Accuracy: 0.9090909090909091


 60%|██████    | 6/10 [03:17<02:11, 32.84s/it]

Epoch: 6, Training Loss: 0.083377370365413, Validation Loss: 0.16394078917307497, Validation Accuracy: 0.9545454545454546


 70%|███████   | 7/10 [03:49<01:37, 32.60s/it]

Epoch: 7, Training Loss: 0.07001826200714911, Validation Loss: 0.17328869412824244, Validation Accuracy: 0.9484848484848485


 80%|████████  | 8/10 [04:21<01:04, 32.42s/it]

Epoch: 8, Training Loss: 0.08141045554800873, Validation Loss: 0.20506158780666406, Validation Accuracy: 0.9484848484848485


 90%|█████████ | 9/10 [04:54<00:32, 32.45s/it]

Epoch: 9, Training Loss: 0.05857951407345192, Validation Loss: 0.19294081961764306, Validation Accuracy: 0.9242424242424242


100%|██████████| 10/10 [05:27<00:00, 32.73s/it]

Epoch: 10, Training Loss: 0.05575741643960972, Validation Loss: 0.15667895708148452, Validation Accuracy: 0.9606060606060606





In [48]:
torch.save(model.state_dict(), "src/pygcn/siamese_gnn_sage_mis.pt")

## Random Feature Subset

In [40]:
with open("feature_dicts/random_features_dict.pkl", "rb") as f:
    feat_dict_random = pkl.load(f)

In [43]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [44]:
all_graphs = add_features(years, all_graphs, feat_dict_random, 434)

In [45]:
all_pairs = get_year_pairs(years)
all_y = check_crisis_years(all_pairs, crisis_years)
all_loader_pairs = get_loader_pairs(all_loader.dataset)

In [46]:
all_graph_pairs = get_graph_pairs(all_graphs)
all_torch_y = torch.tensor(np.array(all_y))
labeled_pairs_all = list(zip(all_loader_pairs, all_y))
flattened_all = [(a, b, c) for ((a, b), c) in labeled_pairs_all]

In [47]:
lst = flattened_all  # Replace this with your list
flattened_train, flattened_val, flattened_test = get_random_percentages(lst)

In [48]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [49]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model)

 10%|█         | 1/10 [00:43<06:29, 43.24s/it]

Epoch: 1, Training Loss: 0.544859388932245, Validation Loss: 0.34816615782452354, Validation Accuracy: 0.8090909090909091


 20%|██        | 2/10 [01:30<06:04, 45.58s/it]

Epoch: 2, Training Loss: 0.3392501510916909, Validation Loss: 0.30394700742467784, Validation Accuracy: 0.8666666666666667


 30%|███       | 3/10 [02:15<05:16, 45.16s/it]

Epoch: 3, Training Loss: 0.26689445104531306, Validation Loss: 0.31279571309375265, Validation Accuracy: 0.8515151515151516


 40%|████      | 4/10 [03:00<04:32, 45.44s/it]

Epoch: 4, Training Loss: 0.2068697312217992, Validation Loss: 0.22516020783928758, Validation Accuracy: 0.896969696969697


 50%|█████     | 5/10 [03:43<03:42, 44.43s/it]

Epoch: 5, Training Loss: 0.16928869296208252, Validation Loss: 0.1768982171130395, Validation Accuracy: 0.9151515151515152


 60%|██████    | 6/10 [04:36<03:09, 47.43s/it]

Epoch: 6, Training Loss: 0.15923998523289348, Validation Loss: 0.15861676974737113, Validation Accuracy: 0.9333333333333333


 70%|███████   | 7/10 [05:25<02:23, 47.71s/it]

Epoch: 7, Training Loss: 0.15724181305050022, Validation Loss: 0.18814892782265263, Validation Accuracy: 0.9181818181818182


 80%|████████  | 8/10 [06:16<01:37, 48.83s/it]

Epoch: 8, Training Loss: 0.13051756557245742, Validation Loss: 0.1601937293960487, Validation Accuracy: 0.9424242424242424


 90%|█████████ | 9/10 [07:08<00:49, 49.76s/it]

Epoch: 9, Training Loss: 0.12332007361746833, Validation Loss: 0.10835096457458072, Validation Accuracy: 0.9606060606060606


100%|██████████| 10/10 [08:02<00:00, 48.22s/it]

Epoch: 10, Training Loss: 0.13025680503896847, Validation Loss: 0.1940113610416919, Validation Accuracy: 0.9090909090909091





In [50]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:34<05:07, 34.18s/it]

Epoch: 1, Training Loss: 0.4442165199707876, Validation Loss: 0.19106294969378998, Validation Accuracy: 0.9212121212121213


 20%|██        | 2/10 [01:06<04:25, 33.14s/it]

Epoch: 2, Training Loss: 0.25832506298792324, Validation Loss: 0.19086683025075632, Validation Accuracy: 0.9424242424242424


 30%|███       | 3/10 [01:39<03:49, 32.82s/it]

Epoch: 3, Training Loss: 0.1691997451430599, Validation Loss: 0.1932731779962496, Validation Accuracy: 0.9090909090909091


 40%|████      | 4/10 [02:11<03:15, 32.66s/it]

Epoch: 4, Training Loss: 0.13047860322443655, Validation Loss: 0.11865929832301018, Validation Accuracy: 0.9575757575757575


 50%|█████     | 5/10 [02:47<02:49, 34.00s/it]

Epoch: 5, Training Loss: 0.10779943096805669, Validation Loss: 0.18853859430510608, Validation Accuracy: 0.9333333333333333


 60%|██████    | 6/10 [03:27<02:23, 35.78s/it]

Epoch: 6, Training Loss: 0.10272896803320476, Validation Loss: 0.16195754717654465, Validation Accuracy: 0.9454545454545454


 70%|███████   | 7/10 [04:02<01:46, 35.63s/it]

Epoch: 7, Training Loss: 0.08266344443156012, Validation Loss: 0.18734804903629773, Validation Accuracy: 0.9272727272727272


 80%|████████  | 8/10 [04:37<01:11, 35.61s/it]

Epoch: 8, Training Loss: 0.07332483878021086, Validation Loss: 0.10584011013958264, Validation Accuracy: 0.9666666666666667


 90%|█████████ | 9/10 [05:12<00:35, 35.22s/it]

Epoch: 9, Training Loss: 0.0708977775151172, Validation Loss: 0.08185857179200728, Validation Accuracy: 0.9666666666666667


100%|██████████| 10/10 [05:54<00:00, 35.49s/it]

Epoch: 10, Training Loss: 0.06833519310347812, Validation Loss: 0.5299345482604677, Validation Accuracy: 0.8363636363636363





## All Features

In [34]:
with open("feature_dicts/features_dict.pkl", "rb") as f:
    feat_dict_all = pkl.load(f)

In [35]:
with open("src/pygcn/train_graphs.pickle", "rb") as f:
    train_graphs = pkl.load(f)

with open("src/pygcn/val_graphs.pickle", "rb") as f:  
    val_graphs = pkl.load(f)

with open("src/pygcn/test_graphs.pickle", "rb") as f:         
    test_graphs = pkl.load(f)

In [41]:
def add_features(years, graphs, feat_dict, dim):

    zeros = torch.zeros(dim)

    for i in range(len(years)):
        new_x = torch.empty(0, dim)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [42]:
train_graphs = add_features(train_years, train_graphs, feat_dict_all)
val_graphs = add_features(val_years, val_graphs, feat_dict_all)
test_graphs = add_features(test_years, test_graphs, feat_dict_all)

In [43]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=4)
train_loader = DataLoader(train_graphs, batch_size=4)
val_loader = DataLoader(val_graphs, batch_size=4)



In [44]:
train_pairs = get_year_pairs(train_years)
val_pairs = get_year_pairs(val_years)

train_y = check_crisis_years(train_pairs, crisis_years)
val_y = check_crisis_years(val_pairs, crisis_years)

train_loader_pairs = get_loader_pairs(train_loader.dataset)
val_loader_pairs = get_loader_pairs(val_loader.dataset)

In [45]:
train_graph_pairs = get_graph_pairs(train_graphs)
val_graph_pairs = get_graph_pairs(val_graphs)

train_torch_y = torch.tensor(np.array(train_y))
val_torch_y = torch.tensor(np.array(val_y))

labeled_pairs_train = list(zip(train_loader_pairs, train_y))
labeled_pairs_val = list(zip(val_loader_pairs, val_y))

flattened_train = [(a, b, c) for ((a, b), c) in labeled_pairs_train]
flattened_val  = [(a, b, c) for ((a, b), c) in labeled_pairs_val]

In [46]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [47]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model)

 10%|█         | 1/10 [00:28<04:18, 28.70s/it]

Epoch: 1, Training Loss: 0.7038765474718217, Validation Loss: 0.6810447085987438, Validation Accuracy: 0.803030303030303


 20%|██        | 2/10 [00:53<03:31, 26.40s/it]

Epoch: 2, Training Loss: 0.6960015451466596, Validation Loss: 0.6861390462427428, Validation Accuracy: 0.6666666666666666


 30%|███       | 3/10 [01:18<02:59, 25.68s/it]

Epoch: 3, Training Loss: 0.692367503401364, Validation Loss: 0.6572099048079867, Validation Accuracy: 0.7121212121212122


 40%|████      | 4/10 [01:43<02:32, 25.36s/it]

Epoch: 4, Training Loss: 0.6717249420099324, Validation Loss: 0.6290249801946409, Validation Accuracy: 0.696969696969697


 50%|█████     | 5/10 [02:07<02:05, 25.15s/it]

Epoch: 5, Training Loss: 0.6480736955128915, Validation Loss: 0.6430482489593101, Validation Accuracy: 0.6666666666666666


 60%|██████    | 6/10 [02:32<01:40, 25.00s/it]

Epoch: 6, Training Loss: 0.6153214025305725, Validation Loss: 0.6232328715197968, Validation Accuracy: 0.6666666666666666


 70%|███████   | 7/10 [02:59<01:17, 25.70s/it]

Epoch: 7, Training Loss: 0.5915779746291873, Validation Loss: 0.5917616424461206, Validation Accuracy: 0.696969696969697


 80%|████████  | 8/10 [03:24<00:50, 25.41s/it]

Epoch: 8, Training Loss: 0.5808850291128802, Validation Loss: 0.5979167244425325, Validation Accuracy: 0.696969696969697


 90%|█████████ | 9/10 [03:49<00:25, 25.17s/it]

Epoch: 9, Training Loss: 0.567361297784827, Validation Loss: 0.6328929368067872, Validation Accuracy: 0.6212121212121212


100%|██████████| 10/10 [04:17<00:00, 25.78s/it]

Epoch: 10, Training Loss: 0.5595406232180244, Validation Loss: 0.6325685942376201, Validation Accuracy: 0.5909090909090909





In [48]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model)

 10%|█         | 1/10 [00:53<07:58, 53.12s/it]

Epoch: 1, Training Loss: 0.6997897207562687, Validation Loss: 0.6875970408771978, Validation Accuracy: 0.803030303030303


 20%|██        | 2/10 [01:46<07:04, 53.06s/it]

Epoch: 2, Training Loss: 0.6955334810834182, Validation Loss: 0.686434802683917, Validation Accuracy: 0.7424242424242424


 30%|███       | 3/10 [02:38<06:10, 52.90s/it]

Epoch: 3, Training Loss: 0.6917367833022021, Validation Loss: 0.6708159130631071, Validation Accuracy: 0.803030303030303


 40%|████      | 4/10 [03:31<05:17, 52.95s/it]

Epoch: 4, Training Loss: 0.6810245707079217, Validation Loss: 0.6506446861859524, Validation Accuracy: 0.803030303030303


 50%|█████     | 5/10 [04:25<04:25, 53.07s/it]

Epoch: 5, Training Loss: 0.6649606728501487, Validation Loss: 0.6540512085864039, Validation Accuracy: 0.6818181818181818


 60%|██████    | 6/10 [05:18<03:32, 53.03s/it]

Epoch: 6, Training Loss: 0.6487234287530358, Validation Loss: 0.6536229751778372, Validation Accuracy: 0.48484848484848486


 70%|███████   | 7/10 [06:11<02:39, 53.01s/it]

Epoch: 7, Training Loss: 0.6490531198927301, Validation Loss: 0.6537377166025566, Validation Accuracy: 0.5909090909090909


 80%|████████  | 8/10 [07:03<01:45, 52.93s/it]

Epoch: 8, Training Loss: 0.6225556447579149, Validation Loss: 0.6484664204445753, Validation Accuracy: 0.5909090909090909


 90%|█████████ | 9/10 [07:57<00:53, 53.02s/it]

Epoch: 9, Training Loss: 0.6143977640171265, Validation Loss: 0.5539399132584081, Validation Accuracy: 0.7272727272727273


100%|██████████| 10/10 [08:49<00:00, 52.98s/it]

Epoch: 10, Training Loss: 0.5987672834107053, Validation Loss: 0.5931275310841474, Validation Accuracy: 0.6666666666666666



