In [30]:
import pandas as pd
import numpy as np
import pickle as pkl
import datetime as datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.utils.CreateFeatures import CreateFeatures
from src.pygcn.SiameseGNN import SiameseGNN
from src.pygcn.GraphSAGE import SiameseGNN_SAGE
from src.pygcn.graph_isomorphism import SiameseGNN_GIN

import torch
import torch.nn as nn
import torch
import torch_geometric.data as data

## All Events

In [2]:
years = range(1962,2019)

In [3]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [31]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [32]:
from torch_geometric.data import DataLoader
all_loader = DataLoader(all_graphs, batch_size=32)



## sGNN with GCN Encoder and 3 Features

In [33]:
def check_crisis_years(year_pairs, crisis_years):
    result = []
    for pair in year_pairs:
        start, end = pair
        # Check if any crisis year is between the pair or equals the later year
        if any(start < year <= end for year in crisis_years):
            result.append(0)
        else:
            result.append(1)
    return result

In [34]:
def get_year_pairs(year_range):
    return [(year1, year2) for year1 in year_range for year2 in year_range if year2 > year1]

def get_loader_pairs(dataset):
    return [(dataset[i], dataset[j]) for i in range(len(dataset)) for j in range(len(dataset)) if j > i]

In [35]:
crisis_years = [1967, 1973, 1981, 1989, 1990,  1996, 2002, 2007, 2012, 2016]
all_pairs = get_year_pairs(years)
all_y = check_crisis_years(all_pairs, crisis_years)
all_loader_pairs = get_loader_pairs(all_loader.dataset)

In [36]:
labeled_pairs_all = list(zip(all_loader_pairs, all_y))
flattened_all = [(a, b, c) for ((a, b), c) in labeled_pairs_all]

In [37]:
from sklearn.model_selection import train_test_split
flattened_train, flattened_test = train_test_split(flattened_all, test_size=0.40, random_state=42)
flattened_test, flattened_val = train_test_split(flattened_test, test_size=0.5, random_state=42)

In [38]:
import random
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = flattened_train

random.shuffle(balanced_data)

In [39]:
from sklearn.metrics import f1_score
def run_model(model, train_data, val_data):
    torch.manual_seed(42)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Adjust step_size and gamma as needed
    criterion = nn.BCELoss()

    for epoch in tqdm(range(10)):
        model.train()
        train_losses = []
        for data1, data2, label in train_data:

            optimizer.zero_grad()
            out = model(data1, data2)
            label = torch.tensor(label).view(1).float()
            loss = criterion(out.squeeze(0), label)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()  # Add this line to update the learning rate

        model.eval()
        with torch.no_grad():
            val_losses = []

            val_pred = []
            val_truth = []

            correct = 0
            total = 0
            for data1, data2, label in val_data:
                out = model(data1, data2)
                label = torch.tensor(label).view(1).float()
                val_loss = criterion(out.squeeze(0), label)
                val_losses.append(val_loss.item())

                predictions = torch.round(out.squeeze())

                val_pred.append(predictions)
                val_truth.append(label)

                correct += (predictions == label).sum().item()
                total += 1

            val_loss = sum(val_losses) / len(val_losses)
            val_accuracy = correct / total

        print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}, Validation F1 Score: {f1_score(val_truth, val_pred)}')

In [42]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

 10%|█         | 1/10 [00:42<06:24, 42.67s/it]

Epoch: 1, Training Loss: 0.5809871381021697, Validation Loss: 0.4825191564508714, Validation Accuracy: 0.740625, Validation F1 Score: 0.3851851851851852


 20%|██        | 2/10 [01:24<05:39, 42.39s/it]

Epoch: 2, Training Loss: 0.37775655486568044, Validation Loss: 0.3531929041651892, Validation Accuracy: 0.771875, Validation F1 Score: 0.44274809160305345


 30%|███       | 3/10 [02:05<04:52, 41.80s/it]

Epoch: 3, Training Loss: 0.2989453561711178, Validation Loss: 0.35782502031725016, Validation Accuracy: 0.828125, Validation F1 Score: 0.5378151260504201


 40%|████      | 4/10 [02:49<04:15, 42.58s/it]

Epoch: 4, Training Loss: 0.2354259207948582, Validation Loss: 0.2872247249986685, Validation Accuracy: 0.84375, Validation F1 Score: 0.537037037037037


 50%|█████     | 5/10 [03:30<03:28, 41.76s/it]

Epoch: 5, Training Loss: 0.20091226853467312, Validation Loss: 0.1323037738489802, Validation Accuracy: 0.946875, Validation F1 Score: 0.7605633802816901


 60%|██████    | 6/10 [04:13<02:49, 42.27s/it]

Epoch: 6, Training Loss: 0.17614834879760483, Validation Loss: 0.22985842122106986, Validation Accuracy: 0.89375, Validation F1 Score: 0.5952380952380952


 70%|███████   | 7/10 [04:56<02:07, 42.45s/it]

Epoch: 7, Training Loss: 0.16164088432769033, Validation Loss: 0.1333260274677741, Validation Accuracy: 0.934375, Validation F1 Score: 0.7272727272727273


 80%|████████  | 8/10 [05:39<01:25, 42.69s/it]

Epoch: 8, Training Loss: 0.159830437073701, Validation Loss: 0.14592370319696785, Validation Accuracy: 0.95, Validation F1 Score: 0.7419354838709677


 90%|█████████ | 9/10 [06:20<00:42, 42.30s/it]

Epoch: 9, Training Loss: 0.13130723333822308, Validation Loss: 0.10918962542164082, Validation Accuracy: 0.95625, Validation F1 Score: 0.787878787878788


100%|██████████| 10/10 [07:03<00:00, 42.33s/it]

Epoch: 10, Training Loss: 0.13688449602410732, Validation Loss: 0.10635099524442922, Validation Accuracy: 0.959375, Validation F1 Score: 0.8115942028985507





In [41]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:36<05:27, 36.38s/it]

Epoch: 1, Training Loss: 0.46051980130353265, Validation Loss: 0.2155472776852548, Validation Accuracy: 0.915625, Validation F1 Score: 0.6823529411764706


 20%|██        | 2/10 [01:07<04:27, 33.48s/it]

Epoch: 2, Training Loss: 0.2818589578499086, Validation Loss: 0.1798740965867182, Validation Accuracy: 0.921875, Validation F1 Score: 0.6987951807228915


 30%|███       | 3/10 [01:37<03:42, 31.75s/it]

Epoch: 3, Training Loss: 0.21699455473201604, Validation Loss: 0.14815185723127797, Validation Accuracy: 0.928125, Validation F1 Score: 0.7088607594936709


 40%|████      | 4/10 [02:07<03:05, 30.89s/it]

Epoch: 4, Training Loss: 0.19393995221504487, Validation Loss: 0.1560693987074046, Validation Accuracy: 0.921875, Validation F1 Score: 0.6753246753246753


 50%|█████     | 5/10 [02:38<02:35, 31.10s/it]

Epoch: 5, Training Loss: 0.18561852737554793, Validation Loss: 0.12435600982171309, Validation Accuracy: 0.928125, Validation F1 Score: 0.7012987012987012


 60%|██████    | 6/10 [03:09<02:04, 31.10s/it]

Epoch: 6, Training Loss: 0.18478635657851225, Validation Loss: 0.18056951389298775, Validation Accuracy: 0.909375, Validation F1 Score: 0.6666666666666666


 70%|███████   | 7/10 [03:40<01:32, 30.96s/it]

Epoch: 7, Training Loss: 0.1478375409984768, Validation Loss: 0.12964231206969998, Validation Accuracy: 0.93125, Validation F1 Score: 0.7027027027027027


 80%|████████  | 8/10 [04:11<01:01, 30.90s/it]

Epoch: 8, Training Loss: 0.14513157523334402, Validation Loss: 0.13586296867379133, Validation Accuracy: 0.95625, Validation F1 Score: 0.7941176470588235


 90%|█████████ | 9/10 [04:43<00:31, 31.49s/it]

Epoch: 9, Training Loss: 0.1687253097971486, Validation Loss: 0.18464770998398308, Validation Accuracy: 0.93125, Validation F1 Score: 0.65625


100%|██████████| 10/10 [05:14<00:00, 31.43s/it]

Epoch: 10, Training Loss: 0.15512552134292823, Validation Loss: 0.11649000187871934, Validation Accuracy: 0.953125, Validation F1 Score: 0.7945205479452054





In [40]:
model = SiameseGNN_GIN(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

 10%|█         | 1/10 [00:32<04:53, 32.59s/it]

Epoch: 1, Training Loss: 0.5230599435935305, Validation Loss: 0.6204438347369432, Validation Accuracy: 0.734375, Validation F1 Score: 0.4055944055944056


 20%|██        | 2/10 [01:04<04:18, 32.36s/it]

Epoch: 2, Training Loss: 0.39643476392904464, Validation Loss: 0.2742441453767242, Validation Accuracy: 0.925, Validation F1 Score: 0.625


 30%|███       | 3/10 [01:36<03:45, 32.21s/it]

Epoch: 3, Training Loss: 0.35243754913813474, Validation Loss: 0.25783864195109346, Validation Accuracy: 0.903125, Validation F1 Score: 0.507936507936508


 40%|████      | 4/10 [02:11<03:18, 33.06s/it]

Epoch: 4, Training Loss: 0.3298476161685275, Validation Loss: 0.23548866846977035, Validation Accuracy: 0.928125, Validation F1 Score: 0.6666666666666666


 50%|█████     | 5/10 [02:45<02:46, 33.35s/it]

Epoch: 5, Training Loss: 0.3232778276281103, Validation Loss: 0.22404401373059954, Validation Accuracy: 0.9125, Validation F1 Score: 0.5882352941176471


 60%|██████    | 6/10 [03:16<02:11, 32.84s/it]

Epoch: 6, Training Loss: 0.2941389015010579, Validation Loss: 0.26092080785820143, Validation Accuracy: 0.884375, Validation F1 Score: 0.5842696629213483


 70%|███████   | 7/10 [03:49<01:37, 32.63s/it]

Epoch: 7, Training Loss: 0.26384397836472173, Validation Loss: 0.19642589058348675, Validation Accuracy: 0.94375, Validation F1 Score: 0.7428571428571428


 80%|████████  | 8/10 [04:21<01:05, 32.57s/it]

Epoch: 8, Training Loss: 0.2596183811170869, Validation Loss: 0.12632159272689023, Validation Accuracy: 0.95625, Validation F1 Score: 0.7666666666666666


 90%|█████████ | 9/10 [04:53<00:32, 32.29s/it]

Epoch: 9, Training Loss: 0.23406935198284123, Validation Loss: 0.2623385371611221, Validation Accuracy: 0.890625, Validation F1 Score: 0.6153846153846154


100%|██████████| 10/10 [05:25<00:00, 32.59s/it]

Epoch: 10, Training Loss: 0.23129249472642022, Validation Loss: 0.18001112754573115, Validation Accuracy: 0.940625, Validation F1 Score: 0.7466666666666667





## sGNN with Feature Subset

In [15]:
with open("feature_dicts/filtered_features_dict.pkl", "rb") as f:
    feat_dict = pkl.load(f)

In [16]:
all_loader = DataLoader(all_graphs, batch_size=4)



In [17]:
def add_features(years, graphs, feat_dict, dim):

    zeros = torch.zeros(dim)

    for i in range(len(years)):
        new_x = torch.empty(0, dim)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["prev_gdp_growth", "country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [18]:
all_graphs = add_features(years, all_graphs, feat_dict, 59)

In [19]:
all_pairs = get_year_pairs(years)
all_y = check_crisis_years(all_pairs, crisis_years)
all_loader_pairs = get_loader_pairs(all_loader.dataset)

In [20]:
labeled_pairs_all = list(zip(all_loader_pairs, all_y))
flattened_all = [(a, b, c) for ((a, b), c) in labeled_pairs_all]

In [21]:
flattened_train, flattened_test = train_test_split(flattened_all, test_size=0.40, random_state=42)
flattened_test, flattened_val = train_test_split(flattened_test, test_size=0.5, random_state=42)

In [22]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [25]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

 10%|█         | 1/10 [00:41<06:16, 41.87s/it]

0.44776119402985076
Epoch: 1, Training Loss: 0.4987440218937329, Validation Loss: 0.3832175207673572, Validation Accuracy: 0.76875


 20%|██        | 2/10 [01:25<05:43, 42.98s/it]

0.6590909090909091
Epoch: 2, Training Loss: 0.23536320227007804, Validation Loss: 0.18096327264793216, Validation Accuracy: 0.90625


 30%|███       | 3/10 [02:08<04:59, 42.80s/it]

0.7540983606557378
Epoch: 3, Training Loss: 0.15659467354211798, Validation Loss: 0.11486523618004867, Validation Accuracy: 0.953125


 40%|████      | 4/10 [02:51<04:18, 43.13s/it]

0.8115942028985507
Epoch: 4, Training Loss: 0.12247512682224623, Validation Loss: 0.09684106448949023, Validation Accuracy: 0.959375


 50%|█████     | 5/10 [03:33<03:33, 42.70s/it]

0.6597938144329897
Epoch: 5, Training Loss: 0.12023657889410308, Validation Loss: 0.2596082295000087, Validation Accuracy: 0.896875


 60%|██████    | 6/10 [04:14<02:48, 42.13s/it]

0.8101265822784811
Epoch: 6, Training Loss: 0.11489412036346701, Validation Loss: 0.11251693668782536, Validation Accuracy: 0.953125


 70%|███████   | 7/10 [04:58<02:08, 42.76s/it]

0.8311688311688312
Epoch: 7, Training Loss: 0.09849922217121772, Validation Loss: 0.0997648756047056, Validation Accuracy: 0.959375


 80%|████████  | 8/10 [05:41<01:25, 42.59s/it]

0.7764705882352942
Epoch: 8, Training Loss: 0.08981255898506788, Validation Loss: 0.20358384737664892, Validation Accuracy: 0.940625


 90%|█████████ | 9/10 [06:22<00:42, 42.32s/it]

0.7894736842105262
Epoch: 9, Training Loss: 0.10029766877884289, Validation Loss: 0.12007957108180563, Validation Accuracy: 0.95


100%|██████████| 10/10 [07:05<00:00, 42.58s/it]

0.9189189189189189
Epoch: 10, Training Loss: 0.06419937848149856, Validation Loss: 0.04051978894331114, Validation Accuracy: 0.98125





In [42]:
torch.save(model.state_dict(), "src/pygcn/siamese_gnn_gcn_mis.pt")

In [29]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model, balanced_data, flattened_val)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:35<05:15, 35.07s/it]

Epoch: 1, Training Loss: 0.3863306344582022, Validation Loss: 0.17464555113692767, Validation Accuracy: 0.909375, Validation F1 Score: 0.6329113924050633


 20%|██        | 2/10 [01:12<04:53, 36.65s/it]

Epoch: 2, Training Loss: 0.24023351249925326, Validation Loss: 0.16495273133768934, Validation Accuracy: 0.928125, Validation F1 Score: 0.735632183908046


 30%|███       | 3/10 [01:45<04:04, 34.89s/it]

Epoch: 3, Training Loss: 0.21170187257031847, Validation Loss: 0.20896186296304223, Validation Accuracy: 0.90625, Validation F1 Score: 0.5945945945945946


 40%|████      | 4/10 [02:20<03:28, 34.81s/it]

Epoch: 4, Training Loss: 0.21820917689924085, Validation Loss: 0.1437196595550631, Validation Accuracy: 0.940625, Validation F1 Score: 0.7246376811594202


 50%|█████     | 5/10 [02:51<02:48, 33.68s/it]

Epoch: 5, Training Loss: 0.16726874001985628, Validation Loss: 0.24335027272172738, Validation Accuracy: 0.94375, Validation F1 Score: 0.6538461538461539


 60%|██████    | 6/10 [03:23<02:12, 33.00s/it]

Epoch: 6, Training Loss: 0.15210310100274826, Validation Loss: 0.14126516989417723, Validation Accuracy: 0.93125, Validation F1 Score: 0.717948717948718


 70%|███████   | 7/10 [03:58<01:40, 33.47s/it]

Epoch: 7, Training Loss: 0.11477552743308479, Validation Loss: 0.0903649390253122, Validation Accuracy: 0.971875, Validation F1 Score: 0.8524590163934426


 80%|████████  | 8/10 [04:31<01:06, 33.43s/it]

Epoch: 8, Training Loss: 0.1029601419349836, Validation Loss: 0.06083224770663946, Validation Accuracy: 0.978125, Validation F1 Score: 0.8923076923076922


 90%|█████████ | 9/10 [05:06<00:33, 33.88s/it]

Epoch: 9, Training Loss: 0.08131720526438255, Validation Loss: 0.040152629572548906, Validation Accuracy: 0.98125, Validation F1 Score: 0.9166666666666667


100%|██████████| 10/10 [05:40<00:00, 34.08s/it]

Epoch: 10, Training Loss: 0.06820880474660954, Validation Loss: 0.08141275435591525, Validation Accuracy: 0.959375, Validation F1 Score: 0.7999999999999999





In [48]:
torch.save(model.state_dict(), "src/pygcn/siamese_gnn_sage_mis.pt")

## Random Feature Subset

In [40]:
with open("feature_dicts/random_features_dict.pkl", "rb") as f:
    feat_dict_random = pkl.load(f)

In [43]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [44]:
all_graphs = add_features(years, all_graphs, feat_dict_random, 434)

In [45]:
all_pairs = get_year_pairs(years)
all_y = check_crisis_years(all_pairs, crisis_years)
all_loader_pairs = get_loader_pairs(all_loader.dataset)

In [46]:
labeled_pairs_all = list(zip(all_loader_pairs, all_y))
flattened_all = [(a, b, c) for ((a, b), c) in labeled_pairs_all]

In [47]:
flattened_train, flattened_test = train_test_split(flattened_all, test_size=0.40, random_state=42)
flattened_test, flattened_val = train_test_split(flattened_test, test_size=0.5, random_state=42)

In [48]:
positive_samples = [item for item in flattened_train if item[2] == 1]
negative_samples = [item for item in flattened_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = data

# Shuffle the balanced dataset
random.shuffle(balanced_data)

In [49]:
model = SiameseGNN(num_features=balanced_data[0][0].num_node_features)
run_model(model)

 10%|█         | 1/10 [00:43<06:29, 43.24s/it]

Epoch: 1, Training Loss: 0.544859388932245, Validation Loss: 0.34816615782452354, Validation Accuracy: 0.8090909090909091


 20%|██        | 2/10 [01:30<06:04, 45.58s/it]

Epoch: 2, Training Loss: 0.3392501510916909, Validation Loss: 0.30394700742467784, Validation Accuracy: 0.8666666666666667


 30%|███       | 3/10 [02:15<05:16, 45.16s/it]

Epoch: 3, Training Loss: 0.26689445104531306, Validation Loss: 0.31279571309375265, Validation Accuracy: 0.8515151515151516


 40%|████      | 4/10 [03:00<04:32, 45.44s/it]

Epoch: 4, Training Loss: 0.2068697312217992, Validation Loss: 0.22516020783928758, Validation Accuracy: 0.896969696969697


 50%|█████     | 5/10 [03:43<03:42, 44.43s/it]

Epoch: 5, Training Loss: 0.16928869296208252, Validation Loss: 0.1768982171130395, Validation Accuracy: 0.9151515151515152


 60%|██████    | 6/10 [04:36<03:09, 47.43s/it]

Epoch: 6, Training Loss: 0.15923998523289348, Validation Loss: 0.15861676974737113, Validation Accuracy: 0.9333333333333333


 70%|███████   | 7/10 [05:25<02:23, 47.71s/it]

Epoch: 7, Training Loss: 0.15724181305050022, Validation Loss: 0.18814892782265263, Validation Accuracy: 0.9181818181818182


 80%|████████  | 8/10 [06:16<01:37, 48.83s/it]

Epoch: 8, Training Loss: 0.13051756557245742, Validation Loss: 0.1601937293960487, Validation Accuracy: 0.9424242424242424


 90%|█████████ | 9/10 [07:08<00:49, 49.76s/it]

Epoch: 9, Training Loss: 0.12332007361746833, Validation Loss: 0.10835096457458072, Validation Accuracy: 0.9606060606060606


100%|██████████| 10/10 [08:02<00:00, 48.22s/it]

Epoch: 10, Training Loss: 0.13025680503896847, Validation Loss: 0.1940113610416919, Validation Accuracy: 0.9090909090909091





In [50]:
model = SiameseGNN_SAGE(num_features=balanced_data[0][0].num_node_features)
run_model(model)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:34<05:07, 34.18s/it]

Epoch: 1, Training Loss: 0.4442165199707876, Validation Loss: 0.19106294969378998, Validation Accuracy: 0.9212121212121213


 20%|██        | 2/10 [01:06<04:25, 33.14s/it]

Epoch: 2, Training Loss: 0.25832506298792324, Validation Loss: 0.19086683025075632, Validation Accuracy: 0.9424242424242424


 30%|███       | 3/10 [01:39<03:49, 32.82s/it]

Epoch: 3, Training Loss: 0.1691997451430599, Validation Loss: 0.1932731779962496, Validation Accuracy: 0.9090909090909091


 40%|████      | 4/10 [02:11<03:15, 32.66s/it]

Epoch: 4, Training Loss: 0.13047860322443655, Validation Loss: 0.11865929832301018, Validation Accuracy: 0.9575757575757575


 50%|█████     | 5/10 [02:47<02:49, 34.00s/it]

Epoch: 5, Training Loss: 0.10779943096805669, Validation Loss: 0.18853859430510608, Validation Accuracy: 0.9333333333333333


 60%|██████    | 6/10 [03:27<02:23, 35.78s/it]

Epoch: 6, Training Loss: 0.10272896803320476, Validation Loss: 0.16195754717654465, Validation Accuracy: 0.9454545454545454


 70%|███████   | 7/10 [04:02<01:46, 35.63s/it]

Epoch: 7, Training Loss: 0.08266344443156012, Validation Loss: 0.18734804903629773, Validation Accuracy: 0.9272727272727272


 80%|████████  | 8/10 [04:37<01:11, 35.61s/it]

Epoch: 8, Training Loss: 0.07332483878021086, Validation Loss: 0.10584011013958264, Validation Accuracy: 0.9666666666666667


 90%|█████████ | 9/10 [05:12<00:35, 35.22s/it]

Epoch: 9, Training Loss: 0.0708977775151172, Validation Loss: 0.08185857179200728, Validation Accuracy: 0.9666666666666667


100%|██████████| 10/10 [05:54<00:00, 35.49s/it]

Epoch: 10, Training Loss: 0.06833519310347812, Validation Loss: 0.5299345482604677, Validation Accuracy: 0.8363636363636363



