In [1]:
import pandas as pd
import numpy as np
import pickle as pkl
import datetime as datetime
from tqdm import tqdm
import networkx as nx
import random

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.utils.CreateFeatures import CreateFeatures
from src.pygcn.SiameseGNN import SiameseGNN
from src.pygcn.GraphSAGE import SiameseGNN_GraphSAGE
from src.pygcn.graph_isomorphism import SiameseGNN_GIN
from torch_geometric.data import DataLoader
from sklearn.metrics import f1_score, fbeta_score
from src.utils.functions import dist_labels_to_changepoint_labels, dist_labels_to_changepoint_labels_adjusted
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch
import torch_geometric.data as data
from src.synthetic_experiments.sample import sample_pairs
from src.utils.misc import collate
import itertools

## All Events

In [2]:
years = range(1962,2019)

In [3]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [4]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [5]:
crisis_years = [1962, 1967, 1973, 1978, 1981, 1989, 1993, 1996, 2002, 2007, 2012, 2014, 2016]
phases = []
p = -1
for i in range(1962,2019):
    if i in crisis_years:
        p += 1
    phases.append(p)

In [6]:
# labels = dist_labels_to_changepoint_labels(phases)
# graph_pairs = sample_pairs(all_graphs,labels)

In [7]:
# with open('graph_pairs.pkl', 'wb') as f:
#     pkl.dump(graph_pairs, f)

In [8]:
with open('graph_pairs.pkl', 'rb') as f:
    graph_pairs = pkl.load(f)

In [9]:
with open('test_data.pkl', 'rb') as f:
    test_data = pkl.load(f)

with open('val_data.pkl', 'rb') as f:
    val_data = pkl.load(f)

with open('train_data.pkl', 'rb') as f:
    train_data = pkl.load(f)

In [10]:
with open('test_data_indices.pkl', 'rb') as f:
    test_indices = pkl.load(f)

with open('val_data_indices.pkl', 'rb') as f:
    val_indices = pkl.load(f)

with open('train_data_indices.pkl', 'rb') as f:
    train_indices = pkl.load(f)

In [11]:
# train_indices, test_indices = train_test_split(np.arange(len(graph_pairs)), test_size=0.40, random_state=42)
# test_indices, val_indices = train_test_split(train_indices, test_size=0.5, random_state=42)

In [12]:
graph_pairs_train = [graph_pairs[i] for i in train_indices]
graph_pairs_test = [graph_pairs[i] for i in test_indices]
graph_pairs_val = [graph_pairs[i] for i in val_indices]

In [13]:
import random
positive_samples = [item for item in graph_pairs_train if item[2] == 1]
negative_samples = [item for item in graph_pairs_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = graph_pairs_train

random.shuffle(balanced_data)

In [14]:
# with open('train_data.pkl', 'wb') as f:
#     pkl.dump(balanced_data, f)

# with open('val_data.pkl', 'wb') as f:
#     pkl.dump(graph_pairs_val, f)
    
# with open('test_data.pkl', 'wb') as f:
#     pkl.dump(graph_pairs_test, f)

In [15]:
# with open('train_data_indices.pkl', 'wb') as f:
#     pkl.dump(train_indices, f)

# with open('val_data_indices.pkl', 'wb') as f:
#     pkl.dump(val_indices, f)
    
# with open('test_data_indices.pkl', 'wb') as f:
#     pkl.dump(test_indices, f)

In [17]:
def run_model(model, training_data_pairs, validation_data_pairs):
    torch.manual_seed(42)
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = nn.BCELoss()  # Changed to BCEWithLogitsLoss for numerical stability

    for epoch in tqdm(range(30)):
        model.train()
        train_losses = []
        for (graph1, graph2, labels) in training_data_pairs:
            optimizer.zero_grad()
            out = model(graph1, graph2)
    
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()

        model.eval()
        with torch.no_grad():
            val_losses = []

            val_pred = []
            val_truth = []

            correct = 0
            total = 0
            for (graph1, graph2, labels) in validation_data_pairs:
                out = model(graph1, graph2)

                val_loss = criterion(out, labels)
                val_losses.append(val_loss.item())

                predictions = torch.round(out)

                val_pred.extend(predictions.cpu().numpy())
                val_truth.extend(labels.cpu().numpy())

                correct += (predictions == labels).sum().item()
                total += labels.size(0)

            val_loss = sum(val_losses) / len(val_losses)
            val_accuracy = correct / total

        val_f1 = f1_score(val_truth, val_pred)
        val_f2 = fbeta_score(y_true=val_truth, y_pred=val_pred, beta=2)
        val_f05 = fbeta_score(y_true=val_truth, y_pred=val_pred, beta=1 / 2)
        print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}, Validation F1 Score: {val_f1}, Validation F2 Score: {val_f2}, Validation F0.5 Score: {val_f05}')
    return val_accuracy, val_f1, val_loss, val_f2, val_f05

## sGNN with GCN Encoder and 3 Features

In [29]:
train_data

[[Data(x=[199, 12], edge_index=[2, 6037], edge_attr=[6037], y=[199]),
  Data(x=[199, 12], edge_index=[2, 11125], edge_attr=[11125], y=[199]),
  tensor([1.])],
 [Data(x=[199, 12], edge_index=[2, 8656], edge_attr=[8656], y=[199]),
  Data(x=[199, 12], edge_index=[2, 15242], edge_attr=[15242], y=[199]),
  tensor([1.])],
 [Data(x=[199, 12], edge_index=[2, 23901], edge_attr=[23901], y=[199]),
  Data(x=[199, 12], edge_index=[2, 8107], edge_attr=[8107], y=[199]),
  tensor([1.])],
 [Data(x=[199, 12], edge_index=[2, 6037], edge_attr=[6037], y=[199]),
  Data(x=[199, 12], edge_index=[2, 3975], edge_attr=[3975], y=[199]),
  tensor([1.])],
 [Data(x=[199, 12], edge_index=[2, 8273], edge_attr=[8273], y=[199]),
  Data(x=[199, 12], edge_index=[2, 22589], edge_attr=[22589], y=[199]),
  tensor([1.])],
 [Data(x=[199, 12], edge_index=[2, 8107], edge_attr=[8107], y=[199]),
  Data(x=[199, 12], edge_index=[2, 21298], edge_attr=[21298], y=[199]),
  tensor([1.])],
 [Data(x=[199, 12], edge_index=[2, 21298], edge_

In [12]:
input_dim = train_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, train_data, val_data)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:14<06:50, 14.14s/it]

Epoch: 1, Training Loss: 0.6475932522440405, Validation Loss: 0.6338973568600752, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


  7%|▋         | 2/30 [00:28<06:36, 14.15s/it]

Epoch: 2, Training Loss: 0.631888928863073, Validation Loss: 0.6296926105321076, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 10%|█         | 3/30 [00:42<06:20, 14.08s/it]

Epoch: 3, Training Loss: 0.6344158494204688, Validation Loss: 0.623619589340214, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 13%|█▎        | 4/30 [00:56<06:02, 13.95s/it]

Epoch: 4, Training Loss: 0.6265723750291075, Validation Loss: 0.6185277254944803, Validation Accuracy: 0.6680584551148225, Validation F1 Score: 0.8005018820577164, Validation F2 Score: 0.9093500570125428, Validation F0.5 Score: 0.7149260421335724


 17%|█▋        | 5/30 [01:08<05:36, 13.46s/it]

Epoch: 5, Training Loss: 0.623148025794961, Validation Loss: 0.6139665860099434, Validation Accuracy: 0.6701461377870563, Validation F1 Score: 0.801007556675063, Validation F2 Score: 0.9080525414049114, Validation F0.5 Score: 0.7165389815232086


 20%|██        | 6/30 [01:23<05:31, 13.79s/it]

Epoch: 6, Training Loss: 0.6206493449933601, Validation Loss: 0.6044460023930277, Validation Accuracy: 0.7077244258872651, Validation F1 Score: 0.8153034300791556, Validation F2 Score: 0.9008746355685131, Validation F0.5 Score: 0.744578313253012


 23%|██▎       | 7/30 [01:38<05:28, 14.27s/it]

Epoch: 7, Training Loss: 0.6103450069544581, Validation Loss: 0.5896932150476412, Validation Accuracy: 0.6910229645093946, Validation F1 Score: 0.8082901554404145, Validation F2 Score: 0.9022556390977443, Validation F0.5 Score: 0.7320506804317222


 27%|██▋       | 8/30 [01:53<05:18, 14.48s/it]

Epoch: 8, Training Loss: 0.6030247503718363, Validation Loss: 0.5725518988293745, Validation Accuracy: 0.7369519832985386, Validation F1 Score: 0.8283378746594006, Validation F2 Score: 0.898876404494382, Validation F0.5 Score: 0.7680646791308742


 30%|███       | 9/30 [02:07<05:04, 14.48s/it]

Epoch: 9, Training Loss: 0.5827886123393148, Validation Loss: 0.5572921374894383, Validation Accuracy: 0.7411273486430062, Validation F1 Score: 0.8268156424581006, Validation F2 Score: 0.8846383741781231, Validation F0.5 Score: 0.7760880964866282


 33%|███▎      | 10/30 [02:23<04:54, 14.74s/it]

Epoch: 10, Training Loss: 0.5799011532042094, Validation Loss: 0.5486124338648265, Validation Accuracy: 0.755741127348643, Validation F1 Score: 0.8321377331420373, Validation F2 Score: 0.8766626360338573, Validation F0.5 Score: 0.7919169852539596


 37%|███▋      | 11/30 [02:38<04:43, 14.93s/it]

Epoch: 11, Training Loss: 0.5580453568117746, Validation Loss: 0.5445413239813548, Validation Accuracy: 0.7599164926931107, Validation F1 Score: 0.8354792560801144, Validation F2 Score: 0.8816425120772947, Validation F0.5 Score: 0.7939097335508428


 40%|████      | 12/30 [02:53<04:28, 14.94s/it]

Epoch: 12, Training Loss: 0.5546759655669938, Validation Loss: 0.5414043102483411, Validation Accuracy: 0.7578288100208769, Validation F1 Score: 0.8333333333333334, Validation F2 Score: 0.8771929824561403, Validation F0.5 Score: 0.7936507936507936


 43%|████▎     | 13/30 [03:07<04:11, 14.78s/it]

Epoch: 13, Training Loss: 0.5460121669614826, Validation Loss: 0.5384814535464026, Validation Accuracy: 0.7599164926931107, Validation F1 Score: 0.8335745296671491, Validation F2 Score: 0.8737864077669902, Validation F0.5 Score: 0.7969009407858328


 47%|████▋     | 14/30 [03:21<03:52, 14.54s/it]

Epoch: 14, Training Loss: 0.5564902152636457, Validation Loss: 0.5374189988867475, Validation Accuracy: 0.7578288100208769, Validation F1 Score: 0.8304093567251462, Validation F2 Score: 0.8653260207190737, Validation F0.5 Score: 0.7982012366498032


 50%|█████     | 15/30 [03:35<03:36, 14.41s/it]

Epoch: 15, Training Loss: 0.5530320451115117, Validation Loss: 0.5351101490114826, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8402877697841726, Validation F2 Score: 0.8837772397094431, Validation F0.5 Score: 0.8008776741634668


 53%|█████▎    | 16/30 [03:50<03:21, 14.37s/it]

Epoch: 16, Training Loss: 0.5414383820761209, Validation Loss: 0.5318136485857357, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.838150289017341, Validation F2 Score: 0.8793208004851425, Validation F0.5 Score: 0.800662617338487


 57%|█████▋    | 17/30 [04:04<03:08, 14.49s/it]

Epoch: 17, Training Loss: 0.5505499722379626, Validation Loss: 0.5307340769715498, Validation Accuracy: 0.7703549060542797, Validation F1 Score: 0.8414985590778098, Validation F2 Score: 0.884312537855845, Validation F0.5 Score: 0.8026388125343595


 60%|██████    | 18/30 [04:19<02:52, 14.38s/it]

Epoch: 18, Training Loss: 0.5508465861732309, Validation Loss: 0.529937560442594, Validation Accuracy: 0.7703549060542797, Validation F1 Score: 0.8414985590778098, Validation F2 Score: 0.884312537855845, Validation F0.5 Score: 0.8026388125343595


 63%|██████▎   | 19/30 [04:33<02:39, 14.47s/it]

Epoch: 19, Training Loss: 0.5430172730544767, Validation Loss: 0.5285425179624358, Validation Accuracy: 0.7703549060542797, Validation F1 Score: 0.8428571428571429, Validation F2 Score: 0.8901629450814725, Validation F0.5 Score: 0.8003255561584374


 67%|██████▋   | 20/30 [04:47<02:22, 14.25s/it]

Epoch: 20, Training Loss: 0.5535128874277994, Validation Loss: 0.5282244043957466, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 70%|███████   | 21/30 [05:02<02:09, 14.38s/it]

Epoch: 21, Training Loss: 0.5455049943245067, Validation Loss: 0.5280471457506769, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 73%|███████▎  | 22/30 [05:16<01:54, 14.26s/it]

Epoch: 22, Training Loss: 0.5406120806434187, Validation Loss: 0.5275086759276579, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 77%|███████▋  | 23/30 [05:30<01:39, 14.21s/it]

Epoch: 23, Training Loss: 0.5458725253304967, Validation Loss: 0.5275156575727562, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 80%|████████  | 24/30 [05:45<01:27, 14.56s/it]

Epoch: 24, Training Loss: 0.5424739046872845, Validation Loss: 0.5270743063011348, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 83%|████████▎ | 25/30 [06:00<01:12, 14.59s/it]

Epoch: 25, Training Loss: 0.5379968209419879, Validation Loss: 0.5269224514605357, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 87%|████████▋ | 26/30 [06:14<00:58, 14.54s/it]

Epoch: 26, Training Loss: 0.5329643553638658, Validation Loss: 0.5268985306482474, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 90%|█████████ | 27/30 [06:30<00:44, 14.78s/it]

Epoch: 27, Training Loss: 0.5387632309643079, Validation Loss: 0.5268945464386074, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 93%|█████████▎| 28/30 [06:44<00:29, 14.61s/it]

Epoch: 28, Training Loss: 0.544024511875902, Validation Loss: 0.5266065711703828, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


 97%|█████████▋| 29/30 [06:58<00:14, 14.63s/it]

Epoch: 29, Training Loss: 0.5440046424620328, Validation Loss: 0.5263311030471499, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456


100%|██████████| 30/30 [07:13<00:00, 14.45s/it]

Epoch: 30, Training Loss: 0.5466404375050883, Validation Loss: 0.5265062872888151, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8393632416787264, Validation F2 Score: 0.8798543689320388, Validation F0.5 Score: 0.8024349750968456





In [18]:
input_dim = train_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN_GraphSAGE(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, train_data, val_data)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:11<05:34, 11.52s/it]

Epoch: 1, Training Loss: 0.6481233633735461, Validation Loss: 0.6358813381518601, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


  7%|▋         | 2/30 [00:21<04:54, 10.52s/it]

Epoch: 2, Training Loss: 0.6329547450833933, Validation Loss: 0.6256608172224558, Validation Accuracy: 0.6680584551148225, Validation F1 Score: 0.8005018820577164, Validation F2 Score: 0.9093500570125428, Validation F0.5 Score: 0.7149260421335724


 10%|█         | 3/30 [00:31<04:41, 10.44s/it]

Epoch: 3, Training Loss: 0.6247321942445502, Validation Loss: 0.6188318073625106, Validation Accuracy: 0.6826722338204593, Validation F1 Score: 0.7934782608695652, Validation F2 Score: 0.8623744831659775, Validation F0.5 Score: 0.7347760442878711


 13%|█▎        | 4/30 [00:41<04:22, 10.11s/it]

Epoch: 4, Training Loss: 0.6168128790775685, Validation Loss: 0.6071243526915667, Validation Accuracy: 0.6805845511482255, Validation F1 Score: 0.7895460797799174, Validation F2 Score: 0.8521377672209026, Validation F0.5 Score: 0.7355202460276781


 17%|█▋        | 5/30 [00:50<04:08,  9.92s/it]

Epoch: 5, Training Loss: 0.6087924753361486, Validation Loss: 0.5976259243637635, Validation Accuracy: 0.6951983298538622, Validation F1 Score: 0.7955182072829131, Validation F2 Score: 0.8497905445840814, Validation F0.5 Score: 0.7477619799894681


 20%|██        | 6/30 [01:00<03:56,  9.85s/it]

Epoch: 6, Training Loss: 0.590466840794094, Validation Loss: 0.574018704181922, Validation Accuracy: 0.7306889352818372, Validation F1 Score: 0.8149210903873745, Validation F2 Score: 0.8585247883917775, Validation F0.5 Score: 0.7755324959038776


 23%|██▎       | 7/30 [01:10<03:44,  9.77s/it]

Epoch: 7, Training Loss: 0.5846075551699197, Validation Loss: 0.5539753456329752, Validation Accuracy: 0.7515657620041754, Validation F1 Score: 0.8307254623044097, Validation F2 Score: 0.8795180722891566, Validation F0.5 Score: 0.7870619946091644


 27%|██▋       | 8/30 [01:19<03:33,  9.73s/it]

Epoch: 8, Training Loss: 0.5602765888629663, Validation Loss: 0.5408096230104721, Validation Accuracy: 0.7682672233820459, Validation F1 Score: 0.8456189151599444, Validation F2 Score: 0.9069212410501193, Validation F0.5 Score: 0.7920792079207921


 30%|███       | 9/30 [01:29<03:24,  9.73s/it]

Epoch: 9, Training Loss: 0.5460271945879128, Validation Loss: 0.5196730993796987, Validation Accuracy: 0.7766179540709812, Validation F1 Score: 0.8507670850767085, Validation F2 Score: 0.9109916367980884, Validation F0.5 Score: 0.7980115122972266


 33%|███▎      | 10/30 [01:39<03:13,  9.70s/it]

Epoch: 10, Training Loss: 0.5372394089944186, Validation Loss: 0.5065844561710238, Validation Accuracy: 0.7870563674321504, Validation F1 Score: 0.856338028169014, Validation F2 Score: 0.9118176364727054, Validation F0.5 Score: 0.807222517259692


 37%|███▋      | 11/30 [01:48<03:04,  9.69s/it]

Epoch: 11, Training Loss: 0.5102126046843912, Validation Loss: 0.498080731464826, Validation Accuracy: 0.7995824634655533, Validation F1 Score: 0.8620689655172413, Validation F2 Score: 0.9074410163339383, Validation F0.5 Score: 0.8210180623973727


 40%|████      | 12/30 [01:58<02:54,  9.69s/it]

Epoch: 12, Training Loss: 0.5114024890316193, Validation Loss: 0.4957503765385932, Validation Accuracy: 0.8016701461377871, Validation F1 Score: 0.8640915593705293, Validation F2 Score: 0.9118357487922706, Validation F0.5 Score: 0.8210984230560087


 43%|████▎     | 13/30 [02:08<02:46,  9.82s/it]

Epoch: 13, Training Loss: 0.5105279033211456, Validation Loss: 0.4933850051279108, Validation Accuracy: 0.7995824634655533, Validation F1 Score: 0.8628571428571429, Validation F2 Score: 0.9112854556427278, Validation F0.5 Score: 0.8193163320672816


 47%|████▋     | 14/30 [02:18<02:38,  9.89s/it]

Epoch: 14, Training Loss: 0.5134321504283226, Validation Loss: 0.49310266246750856, Validation Accuracy: 0.8016701461377871, Validation F1 Score: 0.8637015781922525, Validation F2 Score: 0.9099153567110037, Validation F0.5 Score: 0.8219552157291098


 50%|█████     | 15/30 [02:28<02:27,  9.85s/it]

Epoch: 15, Training Loss: 0.5081078928497269, Validation Loss: 0.4914052599033881, Validation Accuracy: 0.7995824634655533, Validation F1 Score: 0.8628571428571429, Validation F2 Score: 0.9112854556427278, Validation F0.5 Score: 0.8193163320672816


 53%|█████▎    | 16/30 [02:38<02:16,  9.78s/it]

Epoch: 16, Training Loss: 0.49626845784695545, Validation Loss: 0.4894738402956463, Validation Accuracy: 0.8037578288100209, Validation F1 Score: 0.8653295128939829, Validation F2 Score: 0.9123867069486404, Validation F0.5 Score: 0.8228882833787466


 57%|█████▋    | 17/30 [02:47<02:06,  9.73s/it]

Epoch: 17, Training Loss: 0.506923771281352, Validation Loss: 0.4882308792843948, Validation Accuracy: 0.8058455114822547, Validation F1 Score: 0.8661870503597122, Validation F2 Score: 0.9110169491525424, Validation F0.5 Score: 0.8255622600109709


 60%|██████    | 18/30 [02:57<01:57,  9.81s/it]

Epoch: 18, Training Loss: 0.4953189563072338, Validation Loss: 0.48664330891751545, Validation Accuracy: 0.8079331941544885, Validation F1 Score: 0.867816091954023, Validation F2 Score: 0.9134906231094979, Validation F0.5 Score: 0.8264915161466886


 63%|██████▎   | 19/30 [03:07<01:47,  9.81s/it]

Epoch: 19, Training Loss: 0.5079291255807926, Validation Loss: 0.4846809286476426, Validation Accuracy: 0.8058455114822547, Validation F1 Score: 0.866571018651363, Validation F2 Score: 0.9129383313180169, Validation F0.5 Score: 0.8246859639541234


 67%|██████▋   | 20/30 [03:17<01:38,  9.88s/it]

Epoch: 20, Training Loss: 0.4929841276931663, Validation Loss: 0.4845762709486211, Validation Accuracy: 0.8079331941544885, Validation F1 Score: 0.867816091954023, Validation F2 Score: 0.9134906231094979, Validation F0.5 Score: 0.8264915161466886


 70%|███████   | 21/30 [03:27<01:29,  9.95s/it]

Epoch: 21, Training Loss: 0.49822970240716524, Validation Loss: 0.4839696182649668, Validation Accuracy: 0.8079331941544885, Validation F1 Score: 0.867816091954023, Validation F2 Score: 0.9134906231094979, Validation F0.5 Score: 0.8264915161466886


 73%|███████▎  | 22/30 [03:37<01:19,  9.98s/it]

Epoch: 22, Training Loss: 0.4912161222788483, Validation Loss: 0.4831718907911245, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 77%|███████▋  | 23/30 [03:47<01:10, 10.01s/it]

Epoch: 23, Training Loss: 0.4973105820929369, Validation Loss: 0.4829067086999252, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 80%|████████  | 24/30 [03:58<01:01, 10.23s/it]

Epoch: 24, Training Loss: 0.5032161150357317, Validation Loss: 0.4826376924061825, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 83%|████████▎ | 25/30 [04:09<00:51, 10.38s/it]

Epoch: 25, Training Loss: 0.49085507614298673, Validation Loss: 0.48224353883609894, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 87%|████████▋ | 26/30 [04:19<00:41, 10.33s/it]

Epoch: 26, Training Loss: 0.4981920599875156, Validation Loss: 0.4818181510322527, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 90%|█████████ | 27/30 [04:30<00:31, 10.50s/it]

Epoch: 27, Training Loss: 0.49344056629835625, Validation Loss: 0.48116984409678704, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 93%|█████████▎| 28/30 [04:41<00:21, 10.54s/it]

Epoch: 28, Training Loss: 0.49227350000428305, Validation Loss: 0.48069631051668793, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


 97%|█████████▋| 29/30 [04:51<00:10, 10.52s/it]

Epoch: 29, Training Loss: 0.4883902911388762, Validation Loss: 0.48038555467054095, Validation Accuracy: 0.8100208768267223, Validation F1 Score: 0.8690647482014389, Validation F2 Score: 0.914043583535109, Validation F0.5 Score: 0.8283049917718047


100%|██████████| 30/30 [05:00<00:00, 10.03s/it]

Epoch: 30, Training Loss: 0.49477145925958327, Validation Loss: 0.4803818877716403, Validation Accuracy: 0.8121085594989561, Validation F1 Score: 0.8706896551724138, Validation F2 Score: 0.9165154264972777, Validation F0.5 Score: 0.8292282430213465





In [19]:
input_dim = train_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN_GIN(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, train_data, val_data)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:10<04:57, 10.25s/it]

Epoch: 1, Training Loss: 0.6480919939534916, Validation Loss: 0.6365035688951767, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


  7%|▋         | 2/30 [00:20<04:45, 10.20s/it]

Epoch: 2, Training Loss: 0.6448064789627538, Validation Loss: 0.6366730009464232, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 10%|█         | 3/30 [00:31<04:42, 10.46s/it]

Epoch: 3, Training Loss: 0.6395813736484591, Validation Loss: 0.6366747594700974, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 13%|█▎        | 4/30 [00:41<04:28, 10.34s/it]

Epoch: 4, Training Loss: 0.6422930153684308, Validation Loss: 0.6367560179596903, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 17%|█▋        | 5/30 [00:51<04:16, 10.25s/it]

Epoch: 5, Training Loss: 0.6374633008596665, Validation Loss: 0.6363544685606668, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 20%|██        | 6/30 [01:01<04:03, 10.16s/it]

Epoch: 6, Training Loss: 0.6386842441783056, Validation Loss: 0.6362244814074363, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 23%|██▎       | 7/30 [01:11<03:53, 10.13s/it]

Epoch: 7, Training Loss: 0.6385959388071217, Validation Loss: 0.6365253348564555, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 27%|██▋       | 8/30 [01:21<03:41, 10.08s/it]

Epoch: 8, Training Loss: 0.639627213450684, Validation Loss: 0.6361766690874399, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 30%|███       | 9/30 [01:31<03:30, 10.03s/it]

Epoch: 9, Training Loss: 0.6383901171176037, Validation Loss: 0.6358015788621842, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 33%|███▎      | 10/30 [01:41<03:20, 10.00s/it]

Epoch: 10, Training Loss: 0.6403148087401375, Validation Loss: 0.6355699213038905, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 37%|███▋      | 11/30 [01:51<03:10, 10.01s/it]

Epoch: 11, Training Loss: 0.6390609932730564, Validation Loss: 0.6354558948808523, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 40%|████      | 12/30 [02:01<03:03, 10.19s/it]

Epoch: 12, Training Loss: 0.6381173382358492, Validation Loss: 0.6353064742491489, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 43%|████▎     | 13/30 [02:12<02:54, 10.27s/it]

Epoch: 13, Training Loss: 0.6382449728319015, Validation Loss: 0.635208802892669, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 47%|████▋     | 14/30 [02:23<02:48, 10.53s/it]

Epoch: 14, Training Loss: 0.6361683528002278, Validation Loss: 0.635091978162216, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 50%|█████     | 15/30 [02:34<02:38, 10.54s/it]

Epoch: 15, Training Loss: 0.63619998991178, Validation Loss: 0.634958361782958, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 53%|█████▎    | 16/30 [02:45<02:30, 10.73s/it]

Epoch: 16, Training Loss: 0.637083893161084, Validation Loss: 0.6348655706034324, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 57%|█████▋    | 17/30 [02:56<02:20, 10.78s/it]

Epoch: 17, Training Loss: 0.6392837983002359, Validation Loss: 0.6347446686788492, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 60%|██████    | 18/30 [03:06<02:08, 10.74s/it]

Epoch: 18, Training Loss: 0.6397717082824328, Validation Loss: 0.6346479213187988, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 63%|██████▎   | 19/30 [03:17<01:57, 10.68s/it]

Epoch: 19, Training Loss: 0.6382704797894728, Validation Loss: 0.6345798381932842, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 67%|██████▋   | 20/30 [03:27<01:46, 10.61s/it]

Epoch: 20, Training Loss: 0.633850397392251, Validation Loss: 0.6344633024022575, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 70%|███████   | 21/30 [03:38<01:34, 10.55s/it]

Epoch: 21, Training Loss: 0.6367093426182228, Validation Loss: 0.6344445360851686, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 73%|███████▎  | 22/30 [03:48<01:24, 10.52s/it]

Epoch: 22, Training Loss: 0.6347673732472064, Validation Loss: 0.634431740187404, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 77%|███████▋  | 23/30 [03:59<01:13, 10.50s/it]

Epoch: 23, Training Loss: 0.6339655540953609, Validation Loss: 0.6344139977885188, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 80%|████████  | 24/30 [04:09<01:02, 10.41s/it]

Epoch: 24, Training Loss: 0.6395318818939393, Validation Loss: 0.6343865251491363, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 83%|████████▎ | 25/30 [04:19<00:51, 10.40s/it]

Epoch: 25, Training Loss: 0.6362583662462085, Validation Loss: 0.6343555156149296, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 87%|████████▋ | 26/30 [04:31<00:43, 10.79s/it]

Epoch: 26, Training Loss: 0.6401289324339405, Validation Loss: 0.634340927347012, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 90%|█████████ | 27/30 [04:42<00:32, 10.96s/it]

Epoch: 27, Training Loss: 0.6366008912508017, Validation Loss: 0.634325147546159, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 93%|█████████▎| 28/30 [04:54<00:22, 11.07s/it]

Epoch: 28, Training Loss: 0.6367510453785226, Validation Loss: 0.6343132530597655, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


 97%|█████████▋| 29/30 [05:05<00:11, 11.04s/it]

Epoch: 29, Training Loss: 0.6344919411802242, Validation Loss: 0.6343000412360611, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788


100%|██████████| 30/30 [05:15<00:00, 10.52s/it]

Epoch: 30, Training Loss: 0.6359280158795783, Validation Loss: 0.6342895552982617, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7994987468671679, Validation F2 Score: 0.9088319088319088, Validation F0.5 Score: 0.7136465324384788





## sGNN with Feature Subset

In [20]:
with open("feature_dicts/filtered_features_dict.pkl", "rb") as f:
    feat_dict = pkl.load(f)

In [21]:
def add_features(years, graphs, feat_dict, dim):

    zeros = torch.zeros(dim)

    for i in range(len(years)):
        new_x = torch.empty(0, dim)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [22]:
all_graphs = add_features(years, all_graphs, feat_dict, 27)

In [23]:
labels = dist_labels_to_changepoint_labels(phases)
graph_pairs = sample_pairs(all_graphs,labels)

798 positive and 540 negative examples


In [24]:
graph_pairs_train = [graph_pairs[i] for i in train_indices]
graph_pairs_test = [graph_pairs[i] for i in test_indices]
graph_pairs_val = [graph_pairs[i] for i in val_indices]

In [25]:
import random
positive_samples = [item for item in graph_pairs_train if item[2] == 1]
negative_samples = [item for item in graph_pairs_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = graph_pairs_train

random.shuffle(balanced_data)

In [26]:
input_dim = balanced_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, balanced_data, graph_pairs_val)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:14<07:13, 14.95s/it]

Epoch: 1, Training Loss: 0.6528520211966683, Validation Loss: 0.6442973780482696, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


  7%|▋         | 2/30 [00:28<06:40, 14.32s/it]

Epoch: 2, Training Loss: 0.6406836706206841, Validation Loss: 0.64019581999062, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 10%|█         | 3/30 [00:43<06:33, 14.58s/it]

Epoch: 3, Training Loss: 0.6355618574153418, Validation Loss: 0.6364171766289092, Validation Accuracy: 0.6430062630480167, Validation F1 Score: 0.7793548387096774, Validation F2 Score: 0.8825248392752776, Validation F0.5 Score: 0.6977818853974121


 13%|█▎        | 4/30 [00:58<06:17, 14.52s/it]

Epoch: 4, Training Loss: 0.6288308646362528, Validation Loss: 0.6321766156244377, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7779273216689099, Validation F2 Score: 0.8606313281715307, Validation F0.5 Score: 0.7097249508840865


 17%|█▋        | 5/30 [01:14<06:19, 15.20s/it]

Epoch: 5, Training Loss: 0.6194453006503226, Validation Loss: 0.6240952484045248, Validation Accuracy: 0.6659707724425887, Validation F1 Score: 0.7808219178082192, Validation F2 Score: 0.8553421368547419, Validation F0.5 Score: 0.7182459677419355


 20%|██        | 6/30 [01:31<06:21, 15.91s/it]

Epoch: 6, Training Loss: 0.6129096354374443, Validation Loss: 0.6182681668028702, Validation Accuracy: 0.6680584551148225, Validation F1 Score: 0.7818930041152263, Validation F2 Score: 0.8558558558558559, Validation F0.5 Score: 0.7196969696969697


 23%|██▎       | 7/30 [01:46<05:56, 15.49s/it]

Epoch: 7, Training Loss: 0.590306650446251, Validation Loss: 0.5923610834836462, Validation Accuracy: 0.7056367432150313, Validation F1 Score: 0.7977044476327116, Validation F2 Score: 0.8511941212492345, Validation F0.5 Score: 0.7505399568034558


 27%|██▋       | 8/30 [02:02<05:43, 15.59s/it]

Epoch: 8, Training Loss: 0.5765988381741064, Validation Loss: 0.5694187709085627, Validation Accuracy: 0.7390396659707724, Validation F1 Score: 0.8159057437407953, Validation F2 Score: 0.8575851393188855, Validation F0.5 Score: 0.7780898876404494


 30%|███       | 9/30 [02:16<05:21, 15.31s/it]

Epoch: 9, Training Loss: 0.5514896399922505, Validation Loss: 0.5553145393075923, Validation Accuracy: 0.7536534446764092, Validation F1 Score: 0.8238805970149253, Validation F2 Score: 0.8592777085927771, Validation F0.5 Score: 0.7912844036697247


 33%|███▎      | 10/30 [02:31<05:04, 15.21s/it]

Epoch: 10, Training Loss: 0.5276927527277696, Validation Loss: 0.5377880177132023, Validation Accuracy: 0.7578288100208769, Validation F1 Score: 0.8263473053892215, Validation F2 Score: 0.8603491271820449, Validation F0.5 Score: 0.7949308755760369


 37%|███▋      | 11/30 [02:50<05:05, 16.08s/it]

Epoch: 11, Training Loss: 0.5233535448025013, Validation Loss: 0.5352935830523923, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 40%|████      | 12/30 [03:07<04:54, 16.38s/it]

Epoch: 12, Training Loss: 0.5082758408946801, Validation Loss: 0.5340380400170364, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 43%|████▎     | 13/30 [03:21<04:30, 15.92s/it]

Epoch: 13, Training Loss: 0.5219514575497858, Validation Loss: 0.5343357251494314, Validation Accuracy: 0.7620041753653445, Validation F1 Score: 0.8288288288288288, Validation F2 Score: 0.8614232209737828, Validation F0.5 Score: 0.7986111111111112


 47%|████▋     | 14/30 [03:37<04:11, 15.73s/it]

Epoch: 14, Training Loss: 0.5202987566045335, Validation Loss: 0.533031208685644, Validation Accuracy: 0.7620041753653445, Validation F1 Score: 0.8288288288288288, Validation F2 Score: 0.8614232209737828, Validation F0.5 Score: 0.7986111111111112


 50%|█████     | 15/30 [03:52<03:52, 15.52s/it]

Epoch: 15, Training Loss: 0.5161121425646101, Validation Loss: 0.5300560110483886, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 53%|█████▎    | 16/30 [04:05<03:29, 14.93s/it]

Epoch: 16, Training Loss: 0.5209159038750729, Validation Loss: 0.5308893635277957, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 57%|█████▋    | 17/30 [04:20<03:12, 14.80s/it]

Epoch: 17, Training Loss: 0.514137675115681, Validation Loss: 0.5295077089476934, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 60%|██████    | 18/30 [04:34<02:56, 14.75s/it]

Epoch: 18, Training Loss: 0.5134301931444505, Validation Loss: 0.5288866634533151, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 63%|██████▎   | 19/30 [04:50<02:43, 14.88s/it]

Epoch: 19, Training Loss: 0.5169051020808105, Validation Loss: 0.5279551540697791, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 67%|██████▋   | 20/30 [05:05<02:29, 14.97s/it]

Epoch: 20, Training Loss: 0.5124763831449527, Validation Loss: 0.5279409840609683, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 70%|███████   | 21/30 [05:20<02:14, 14.95s/it]

Epoch: 21, Training Loss: 0.5119050221757082, Validation Loss: 0.5278357617043752, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 73%|███████▎  | 22/30 [05:35<02:00, 15.10s/it]

Epoch: 22, Training Loss: 0.5169253618365919, Validation Loss: 0.5277223561215749, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 77%|███████▋  | 23/30 [05:49<01:43, 14.78s/it]

Epoch: 23, Training Loss: 0.5114048073689142, Validation Loss: 0.5277102918739359, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 80%|████████  | 24/30 [06:03<01:27, 14.55s/it]

Epoch: 24, Training Loss: 0.5100311706214364, Validation Loss: 0.5276891565833062, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8619612742036228, Validation F0.5 Score: 0.8004640371229699


 83%|████████▎ | 25/30 [06:18<01:13, 14.70s/it]

Epoch: 25, Training Loss: 0.5173340677093936, Validation Loss: 0.527479046601841, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 87%|████████▋ | 26/30 [06:32<00:57, 14.48s/it]

Epoch: 26, Training Loss: 0.5067473449895126, Validation Loss: 0.5272664905278817, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 90%|█████████ | 27/30 [06:46<00:42, 14.28s/it]

Epoch: 27, Training Loss: 0.5123157065552479, Validation Loss: 0.5271525858961714, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 93%|█████████▎| 28/30 [07:00<00:28, 14.30s/it]

Epoch: 28, Training Loss: 0.5169264792212622, Validation Loss: 0.5270805072684875, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


 97%|█████████▋| 29/30 [07:16<00:14, 14.59s/it]

Epoch: 29, Training Loss: 0.5086185338137167, Validation Loss: 0.5270868931782271, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488


100%|██████████| 30/30 [07:31<00:00, 15.03s/it]

Epoch: 30, Training Loss: 0.5082922933648497, Validation Loss: 0.5269264378789074, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8313253012048193, Validation F2 Score: 0.8625, Validation F0.5 Score: 0.8023255813953488





In [35]:
input_dim = balanced_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN_GraphSAGE(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, balanced_data, graph_pairs_val)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:10<05:12, 10.76s/it]

Epoch: 1, Training Loss: 0.6463643824418884, Validation Loss: 0.6326048406529278, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7882653061224489, Validation F2 Score: 0.8982558139534884, Validation F0.5 Score: 0.7022727272727273


  7%|▋         | 2/30 [00:20<04:43, 10.14s/it]

Epoch: 2, Training Loss: 0.6161229763035988, Validation Loss: 0.609263271528893, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7780748663101604, Validation F2 Score: 0.8640142517814727, Validation F0.5 Score: 0.7076848249027238


 10%|█         | 3/30 [00:31<04:48, 10.69s/it]

Epoch: 3, Training Loss: 0.5903083585210479, Validation Loss: 0.5738675149472621, Validation Accuracy: 0.7223382045929019, Validation F1 Score: 0.8058394160583942, Validation F2 Score: 0.851326341764343, Validation F0.5 Score: 0.7649667405764967


 13%|█▎        | 4/30 [00:43<04:44, 10.92s/it]

Epoch: 4, Training Loss: 0.5566923924748053, Validation Loss: 0.5378189006653111, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8348082595870207, Validation F2 Score: 0.8767038413878563, Validation F0.5 Score: 0.7967342342342343


 17%|█▋        | 5/30 [00:53<04:24, 10.59s/it]

Epoch: 5, Training Loss: 0.5259593166371996, Validation Loss: 0.49676286399613345, Validation Accuracy: 0.8016701461377871, Validation F1 Score: 0.8575712143928036, Validation F2 Score: 0.8920773549594511, Validation F0.5 Score: 0.825635103926097


 20%|██        | 6/30 [01:03<04:12, 10.53s/it]

Epoch: 6, Training Loss: 0.4797147912914352, Validation Loss: 0.4682116116200708, Validation Accuracy: 0.8204592901878914, Validation F1 Score: 0.8716417910447761, Validation F2 Score: 0.9090909090909091, Validation F0.5 Score: 0.8371559633027523


 23%|██▎       | 7/30 [01:13<03:57, 10.32s/it]

Epoch: 7, Training Loss: 0.46619446816987503, Validation Loss: 0.4575302636635826, Validation Accuracy: 0.824634655532359, Validation F1 Score: 0.8738738738738738, Validation F2 Score: 0.9082397003745318, Validation F0.5 Score: 0.8420138888888888


 27%|██▋       | 8/30 [01:23<03:43, 10.17s/it]

Epoch: 8, Training Loss: 0.4390912171970963, Validation Loss: 0.45926982375923425, Validation Accuracy: 0.826722338204593, Validation F1 Score: 0.8692913385826772, Validation F2 Score: 0.8784213876511776, Validation F0.5 Score: 0.8603491271820449


 30%|███       | 9/30 [01:33<03:32, 10.10s/it]

Epoch: 9, Training Loss: 0.4295627864052385, Validation Loss: 0.40449811474639835, Validation Accuracy: 0.860125260960334, Validation F1 Score: 0.8951486697965572, Validation F2 Score: 0.9079365079365079, Validation F0.5 Score: 0.8827160493827161


 33%|███▎      | 10/30 [01:42<03:20, 10.01s/it]

Epoch: 10, Training Loss: 0.4076511204522979, Validation Loss: 0.39348557246129545, Validation Accuracy: 0.8643006263048016, Validation F1 Score: 0.8992248062015504, Validation F2 Score: 0.9171410499683744, Validation F0.5 Score: 0.8819951338199513


 37%|███▋      | 11/30 [01:53<03:14, 10.26s/it]

Epoch: 11, Training Loss: 0.3894116668927881, Validation Loss: 0.3895416121642127, Validation Accuracy: 0.8684759916492694, Validation F1 Score: 0.9023255813953488, Validation F2 Score: 0.920303605313093, Validation F0.5 Score: 0.885036496350365


 40%|████      | 12/30 [02:03<03:03, 10.19s/it]

Epoch: 12, Training Loss: 0.3876235538951531, Validation Loss: 0.38776374037430034, Validation Accuracy: 0.8684759916492694, Validation F1 Score: 0.9023255813953488, Validation F2 Score: 0.920303605313093, Validation F0.5 Score: 0.885036496350365


 43%|████▎     | 13/30 [02:13<02:49,  9.99s/it]

Epoch: 13, Training Loss: 0.38731039280428037, Validation Loss: 0.3862542766992135, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9037267080745341, Validation F2 Score: 0.9208860759493671, Validation F0.5 Score: 0.8871951219512195


 47%|████▋     | 14/30 [02:22<02:37,  9.85s/it]

Epoch: 14, Training Loss: 0.3829530996183866, Validation Loss: 0.3838259081875357, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 50%|█████     | 15/30 [02:32<02:26,  9.75s/it]

Epoch: 15, Training Loss: 0.38320269800777096, Validation Loss: 0.3830801777712736, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 53%|█████▎    | 16/30 [02:42<02:19,  9.95s/it]

Epoch: 16, Training Loss: 0.3728444862608626, Validation Loss: 0.38819903069971, Validation Accuracy: 0.8663883089770354, Validation F1 Score: 0.9003115264797508, Validation F2 Score: 0.9157160963244614, Validation F0.5 Score: 0.8854166666666666


 57%|█████▋    | 17/30 [02:52<02:07,  9.80s/it]

Epoch: 17, Training Loss: 0.38644888902096663, Validation Loss: 0.38380315184344327, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 60%|██████    | 18/30 [03:01<01:57,  9.77s/it]

Epoch: 18, Training Loss: 0.37743638790637846, Validation Loss: 0.38380810540628335, Validation Accuracy: 0.8684759916492694, Validation F1 Score: 0.9023255813953488, Validation F2 Score: 0.920303605313093, Validation F0.5 Score: 0.885036496350365


 63%|██████▎   | 19/30 [03:12<01:49,  9.93s/it]

Epoch: 19, Training Loss: 0.38314793938566527, Validation Loss: 0.38197824847349793, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 67%|██████▋   | 20/30 [03:22<01:39,  9.98s/it]

Epoch: 20, Training Loss: 0.3802763980072743, Validation Loss: 0.38231093861116, Validation Accuracy: 0.872651356993737, Validation F1 Score: 0.9054263565891473, Validation F2 Score: 0.9234661606578115, Validation F0.5 Score: 0.8880778588807786


 70%|███████   | 21/30 [03:32<01:30, 10.01s/it]

Epoch: 21, Training Loss: 0.38039229472353664, Validation Loss: 0.38170647235404476, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 73%|███████▎  | 22/30 [03:42<01:19,  9.94s/it]

Epoch: 22, Training Loss: 0.38153224328655433, Validation Loss: 0.3815228234439902, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 77%|███████▋  | 23/30 [03:52<01:09,  9.95s/it]

Epoch: 23, Training Loss: 0.3751936148824861, Validation Loss: 0.38129991793806717, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 80%|████████  | 24/30 [04:01<00:59,  9.88s/it]

Epoch: 24, Training Loss: 0.3733235172218913, Validation Loss: 0.3814420901658888, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 83%|████████▎ | 25/30 [04:11<00:49,  9.87s/it]

Epoch: 25, Training Loss: 0.37945554375087953, Validation Loss: 0.38103648127369694, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 87%|████████▋ | 26/30 [04:21<00:38,  9.72s/it]

Epoch: 26, Training Loss: 0.3734965540367483, Validation Loss: 0.3813081807408303, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 90%|█████████ | 27/30 [04:30<00:28,  9.62s/it]

Epoch: 27, Training Loss: 0.3707537597343092, Validation Loss: 0.3811555138336094, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 93%|█████████▎| 28/30 [04:39<00:19,  9.53s/it]

Epoch: 28, Training Loss: 0.376301123711009, Validation Loss: 0.38093539425376066, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


 97%|█████████▋| 29/30 [04:49<00:09,  9.45s/it]

Epoch: 29, Training Loss: 0.37827936313119925, Validation Loss: 0.38092407350873647, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874


100%|██████████| 30/30 [04:58<00:00,  9.95s/it]

Epoch: 30, Training Loss: 0.3790532212116626, Validation Loss: 0.38078919964568353, Validation Accuracy: 0.8705636743215032, Validation F1 Score: 0.9040247678018576, Validation F2 Score: 0.922882427307206, Validation F0.5 Score: 0.8859223300970874





In [36]:
input_dim = balanced_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN_GIN(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, balanced_data, graph_pairs_val)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  0%|          | 0/30 [00:00<?, ?it/s]

  3%|▎         | 1/30 [00:10<04:57, 10.28s/it]

Epoch: 1, Training Loss: 0.6547527795301336, Validation Loss: 0.6463807629020827, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


  7%|▋         | 2/30 [00:20<04:42, 10.09s/it]

Epoch: 2, Training Loss: 0.648663713889685, Validation Loss: 0.6464087550600287, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 10%|█         | 3/30 [00:30<04:37, 10.27s/it]

Epoch: 3, Training Loss: 0.6458484243636594, Validation Loss: 0.6465117869148175, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 13%|█▎        | 4/30 [00:42<04:38, 10.73s/it]

Epoch: 4, Training Loss: 0.6454676866780503, Validation Loss: 0.6462578899178475, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 17%|█▋        | 5/30 [00:54<04:46, 11.46s/it]

Epoch: 5, Training Loss: 0.646629343492484, Validation Loss: 0.6469507018310292, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 20%|██        | 6/30 [01:05<04:24, 11.03s/it]

Epoch: 6, Training Loss: 0.6439754600123802, Validation Loss: 0.6466897128145979, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 23%|██▎       | 7/30 [01:15<04:08, 10.81s/it]

Epoch: 7, Training Loss: 0.6447538180722463, Validation Loss: 0.6465346645213865, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 27%|██▋       | 8/30 [01:25<03:52, 10.57s/it]

Epoch: 8, Training Loss: 0.6443314442066563, Validation Loss: 0.6467143745163537, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 30%|███       | 9/30 [01:36<03:42, 10.60s/it]

Epoch: 9, Training Loss: 0.6421610278582498, Validation Loss: 0.6468669139393187, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.789873417721519, Validation F2 Score: 0.9038238702201622, Validation F0.5 Score: 0.7014388489208633


 33%|███▎      | 10/30 [01:46<03:31, 10.59s/it]

Epoch: 10, Training Loss: 0.6447111508627047, Validation Loss: 0.6462638483913558, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 37%|███▋      | 11/30 [01:57<03:25, 10.79s/it]

Epoch: 11, Training Loss: 0.6417146044215937, Validation Loss: 0.6461155726292437, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 40%|████      | 12/30 [02:09<03:15, 10.88s/it]

Epoch: 12, Training Loss: 0.6426477440036327, Validation Loss: 0.6459681740757817, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 43%|████▎     | 13/30 [02:19<03:02, 10.75s/it]

Epoch: 13, Training Loss: 0.6423913628454617, Validation Loss: 0.6459124192191066, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 47%|████▋     | 14/30 [02:29<02:48, 10.52s/it]

Epoch: 14, Training Loss: 0.641902709891679, Validation Loss: 0.6457091431777015, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 50%|█████     | 15/30 [02:39<02:35, 10.36s/it]

Epoch: 15, Training Loss: 0.6408847343398486, Validation Loss: 0.6454406085741047, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7888748419721872, Validation F2 Score: 0.903300521134916, Validation F0.5 Score: 0.7001795332136446


 53%|█████▎    | 16/30 [02:49<02:23, 10.27s/it]

Epoch: 16, Training Loss: 0.6419343583698432, Validation Loss: 0.6453054595839753, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 57%|█████▋    | 17/30 [02:59<02:10, 10.06s/it]

Epoch: 17, Training Loss: 0.6403562780333413, Validation Loss: 0.6451335571752959, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 60%|██████    | 18/30 [03:08<01:58,  9.87s/it]

Epoch: 18, Training Loss: 0.6414682267539304, Validation Loss: 0.6450140303883523, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 63%|██████▎   | 19/30 [03:18<01:47,  9.77s/it]

Epoch: 19, Training Loss: 0.6402950821874532, Validation Loss: 0.6448146539714988, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 67%|██████▋   | 20/30 [03:28<01:39,  9.93s/it]

Epoch: 20, Training Loss: 0.6414178247362095, Validation Loss: 0.644752287976677, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 70%|███████   | 21/30 [03:38<01:30, 10.01s/it]

Epoch: 21, Training Loss: 0.6411769605604708, Validation Loss: 0.6447351770883811, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 73%|███████▎  | 22/30 [03:48<01:20, 10.02s/it]

Epoch: 22, Training Loss: 0.6412733929035183, Validation Loss: 0.6447278601516014, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 77%|███████▋  | 23/30 [03:58<01:10, 10.06s/it]

Epoch: 23, Training Loss: 0.6415126226538774, Validation Loss: 0.6446802616119385, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 80%|████████  | 24/30 [04:09<01:01, 10.33s/it]

Epoch: 24, Training Loss: 0.639295044675267, Validation Loss: 0.6446632797011254, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 83%|████████▎ | 25/30 [04:19<00:51, 10.21s/it]

Epoch: 25, Training Loss: 0.6409227930266281, Validation Loss: 0.6446385363074085, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 87%|████████▋ | 26/30 [04:29<00:39,  9.95s/it]

Epoch: 26, Training Loss: 0.6401656802233733, Validation Loss: 0.6446214708381007, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 90%|█████████ | 27/30 [04:38<00:29,  9.77s/it]

Epoch: 27, Training Loss: 0.6400392936501758, Validation Loss: 0.6445835120120477, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 93%|█████████▎| 28/30 [04:48<00:19,  9.98s/it]

Epoch: 28, Training Loss: 0.6397936031798459, Validation Loss: 0.6445685593967397, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


 97%|█████████▋| 29/30 [04:59<00:10, 10.17s/it]

Epoch: 29, Training Loss: 0.6403837327299446, Validation Loss: 0.6445352594389547, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027


100%|██████████| 30/30 [05:09<00:00, 10.32s/it]

Epoch: 30, Training Loss: 0.638932671510812, Validation Loss: 0.6445111689587476, Validation Accuracy: 0.6555323590814196, Validation F1 Score: 0.7908745247148289, Validation F2 Score: 0.9043478260869565, Validation F0.5 Score: 0.7027027027027027





## No GDP 27 Features

## Random Feature Subset

In [19]:
years = range(1962,2019)

In [20]:
with open("feature_dicts/random_features_dict.pkl", "rb") as f:
    feat_dict_random = pkl.load(f)

In [21]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [22]:
all_graphs = add_features(years, all_graphs, feat_dict_random, 27)

In [23]:
labels = dist_labels_to_changepoint_labels(phases)
graph_pairs = sample_pairs(all_graphs,labels)

798 positive and 540 negative examples


In [24]:
graph_pairs_train = [graph_pairs[i] for i in train_indices]
graph_pairs_test = [graph_pairs[i] for i in test_indices]
graph_pairs_val = [graph_pairs[i] for i in val_indices]

In [25]:
import random
positive_samples = [item for item in graph_pairs_train if item[2] == 1]
negative_samples = [item for item in graph_pairs_train if item[2] == 0]

# Calculate the difference in count
diff = len(negative_samples) - len(positive_samples)

# Upsample positive samples
if diff > 0:
    positive_samples_upsampled = positive_samples * (diff // len(positive_samples)) + random.sample(positive_samples, diff % len(positive_samples))
    balanced_data = negative_samples + positive_samples + positive_samples_upsampled
else:
    balanced_data = graph_pairs_train

random.shuffle(balanced_data)

In [24]:
input_dim = balanced_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, balanced_data, graph_pairs_val)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:14<07:11, 14.87s/it]

Epoch: 1, Training Loss: 0.6523301030899414, Validation Loss: 0.6295803544541242, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7862595419847328, Validation F2 Score: 0.8956521739130435, Validation F0.5 Score: 0.7006802721088435


  7%|▋         | 2/30 [00:29<06:48, 14.59s/it]

Epoch: 2, Training Loss: 0.6354912189356088, Validation Loss: 0.6180587373596144, Validation Accuracy: 0.6951983298538622, Validation F1 Score: 0.8027027027027027, Validation F2 Score: 0.8844550327575939, Validation F0.5 Score: 0.7347847600197922


 10%|█         | 3/30 [00:41<06:07, 13.63s/it]

Epoch: 3, Training Loss: 0.6282152446646675, Validation Loss: 0.6104174372672039, Validation Accuracy: 0.6889352818371608, Validation F1 Score: 0.7972789115646258, Validation F2 Score: 0.8751493428912783, Validation F0.5 Score: 0.7321339330334833


 13%|█▎        | 4/30 [00:54<05:42, 13.19s/it]

Epoch: 4, Training Loss: 0.6244102213200855, Validation Loss: 0.6057892622803348, Validation Accuracy: 0.6993736951983298, Validation F1 Score: 0.8027397260273973, Validation F2 Score: 0.877771120431396, Validation F0.5 Score: 0.7395254921756689


 17%|█▋        | 5/30 [01:07<05:32, 13.28s/it]

Epoch: 5, Training Loss: 0.619800472116271, Validation Loss: 0.599740352287173, Validation Accuracy: 0.7098121085594989, Validation F1 Score: 0.8055944055944056, Validation F2 Score: 0.8706166868198307, Validation F0.5 Score: 0.7496095783446122


 20%|██        | 6/30 [01:21<05:25, 13.57s/it]

Epoch: 6, Training Loss: 0.6122024107253414, Validation Loss: 0.5915692558119342, Validation Accuracy: 0.7160751565762005, Validation F1 Score: 0.8062678062678063, Validation F2 Score: 0.862279098110908, Validation F0.5 Score: 0.7570893525949706


 23%|██▎       | 7/30 [01:34<05:02, 13.14s/it]

Epoch: 7, Training Loss: 0.6006399247962728, Validation Loss: 0.5900005714380666, Validation Accuracy: 0.7118997912317327, Validation F1 Score: 0.8034188034188035, Validation F2 Score: 0.8592321755027422, Validation F0.5 Score: 0.7544141252006421


 27%|██▋       | 8/30 [01:46<04:40, 12.76s/it]

Epoch: 8, Training Loss: 0.595215388537699, Validation Loss: 0.5841361301676963, Validation Accuracy: 0.7286012526096033, Validation F1 Score: 0.8126801152737753, Validation F2 Score: 0.863441518677281, Validation F0.5 Score: 0.7675557974959173


 30%|███       | 9/30 [01:57<04:22, 12.50s/it]

Epoch: 9, Training Loss: 0.5915764465601093, Validation Loss: 0.5781670664446836, Validation Accuracy: 0.7306889352818372, Validation F1 Score: 0.8133140376266281, Validation F2 Score: 0.8619631901840491, Validation F0.5 Score: 0.7698630136986301


 33%|███▎      | 10/30 [02:10<04:07, 12.35s/it]

Epoch: 10, Training Loss: 0.5762625029904217, Validation Loss: 0.564923644936906, Validation Accuracy: 0.7411273486430062, Validation F1 Score: 0.8187134502923976, Validation F2 Score: 0.862600123228589, Validation F0.5 Score: 0.7790762381747357


 37%|███▋      | 11/30 [02:22<03:54, 12.37s/it]

Epoch: 11, Training Loss: 0.5660838975241191, Validation Loss: 0.5592220047321599, Validation Accuracy: 0.7390396659707724, Validation F1 Score: 0.8169838945827232, Validation F2 Score: 0.8600493218249076, Validation F0.5 Score: 0.7780256553262688


 40%|████      | 12/30 [02:34<03:39, 12.22s/it]

Epoch: 12, Training Loss: 0.5657010867864735, Validation Loss: 0.5546048119321995, Validation Accuracy: 0.7432150313152401, Validation F1 Score: 0.8188512518409425, Validation F2 Score: 0.8590852904820766, Validation F0.5 Score: 0.7822172200337648


 43%|████▎     | 13/30 [02:46<03:25, 12.12s/it]

Epoch: 13, Training Loss: 0.5634631687072627, Validation Loss: 0.5511328446093581, Validation Accuracy: 0.7432150313152401, Validation F1 Score: 0.8177777777777778, Validation F2 Score: 0.8550185873605948, Validation F0.5 Score: 0.7836456558773425


 47%|████▋     | 14/30 [02:58<03:14, 12.15s/it]

Epoch: 14, Training Loss: 0.5606563322970113, Validation Loss: 0.5482197337956917, Validation Accuracy: 0.7515657620041754, Validation F1 Score: 0.8231797919762258, Validation F2 Score: 0.859181141439206, Validation F0.5 Score: 0.7900741585852824


 50%|█████     | 15/30 [03:12<03:09, 12.60s/it]

Epoch: 15, Training Loss: 0.5532635363478645, Validation Loss: 0.5449680813543483, Validation Accuracy: 0.755741127348643, Validation F1 Score: 0.8256333830104322, Validation F2 Score: 0.860248447204969, Validation F0.5 Score: 0.7936962750716332


 53%|█████▎    | 16/30 [03:25<03:01, 12.94s/it]

Epoch: 16, Training Loss: 0.5574582861259564, Validation Loss: 0.5424897449872688, Validation Accuracy: 0.7578288100208769, Validation F1 Score: 0.826865671641791, Validation F2 Score: 0.8607830950901181, Validation F0.5 Score: 0.7955198161975876


 57%|█████▋    | 17/30 [03:40<02:55, 13.49s/it]

Epoch: 17, Training Loss: 0.5536705473498241, Validation Loss: 0.5398933342453831, Validation Accuracy: 0.7578288100208769, Validation F1 Score: 0.8263473053892215, Validation F2 Score: 0.8587429993777225, Validation F0.5 Score: 0.7963069821119446


 60%|██████    | 18/30 [03:53<02:40, 13.39s/it]

Epoch: 18, Training Loss: 0.5512777087955016, Validation Loss: 0.5367721795786895, Validation Accuracy: 0.7599164926931107, Validation F1 Score: 0.8275862068965517, Validation F2 Score: 0.8592777085927771, Validation F0.5 Score: 0.7981492192018508


 63%|██████▎   | 19/30 [04:08<02:31, 13.75s/it]

Epoch: 19, Training Loss: 0.5435286388657559, Validation Loss: 0.5342129435942417, Validation Accuracy: 0.7578288100208769, Validation F1 Score: 0.8258258258258259, Validation F2 Score: 0.8566978193146417, Validation F0.5 Score: 0.7971014492753623


 67%|██████▋   | 20/30 [04:21<02:15, 13.58s/it]

Epoch: 20, Training Loss: 0.549862975174855, Validation Loss: 0.5322529289692578, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8333333333333334, Validation F2 Score: 0.8690254500310366, Validation F0.5 Score: 0.8004574042309891


 70%|███████   | 21/30 [04:35<02:04, 13.85s/it]

Epoch: 21, Training Loss: 0.5431819234584442, Validation Loss: 0.5319370199344851, Validation Accuracy: 0.7599164926931107, Validation F1 Score: 0.828101644245142, Validation F2 Score: 0.861318407960199, Validation F0.5 Score: 0.7973517559009787


 73%|███████▎  | 22/30 [04:50<01:52, 14.08s/it]

Epoch: 22, Training Loss: 0.5436488588651022, Validation Loss: 0.5316400194466736, Validation Accuracy: 0.7620041753653445, Validation F1 Score: 0.8298507462686567, Validation F2 Score: 0.8638906152889994, Validation F0.5 Score: 0.7983917288914417


 77%|███████▋  | 23/30 [05:04<01:38, 14.01s/it]

Epoch: 23, Training Loss: 0.545738631405925, Validation Loss: 0.5312849940860446, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8310911808669657, Validation F2 Score: 0.8644278606965174, Validation F0.5 Score: 0.8002302820955671


 80%|████████  | 24/30 [05:18<01:24, 14.15s/it]

Epoch: 24, Training Loss: 0.5403756628838702, Validation Loss: 0.531015856592043, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8310911808669657, Validation F2 Score: 0.8644278606965174, Validation F0.5 Score: 0.8002302820955671


 83%|████████▎ | 25/30 [05:33<01:10, 14.15s/it]

Epoch: 25, Training Loss: 0.5317314899951812, Validation Loss: 0.5307521797047776, Validation Accuracy: 0.7620041753653445, Validation F1 Score: 0.8293413173652695, Validation F2 Score: 0.8618543870566272, Validation F0.5 Score: 0.7991921523369879


 87%|████████▋ | 26/30 [05:46<00:56, 14.07s/it]

Epoch: 26, Training Loss: 0.5442245607901766, Validation Loss: 0.5304853610678854, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8603491271820449, Validation F0.5 Score: 0.8018593840790238


 90%|█████████ | 27/30 [06:00<00:41, 13.92s/it]

Epoch: 27, Training Loss: 0.543061309504783, Validation Loss: 0.5302165021851566, Validation Accuracy: 0.7640918580375783, Validation F1 Score: 0.8300751879699249, Validation F2 Score: 0.8603491271820449, Validation F0.5 Score: 0.8018593840790238


 93%|█████████▎| 28/30 [06:13<00:27, 13.78s/it]

Epoch: 28, Training Loss: 0.5452263315394131, Validation Loss: 0.5299973061662129, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8318318318318318, Validation F2 Score: 0.8629283489096573, Validation F0.5 Score: 0.8028985507246377


 97%|█████████▋| 29/30 [06:27<00:13, 13.58s/it]

Epoch: 29, Training Loss: 0.539252459489066, Validation Loss: 0.5297640619297864, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8318318318318318, Validation F2 Score: 0.8629283489096573, Validation F0.5 Score: 0.8028985507246377


100%|██████████| 30/30 [06:40<00:00, 13.34s/it]

Epoch: 30, Training Loss: 0.5405990651875828, Validation Loss: 0.5295222561518683, Validation Accuracy: 0.7661795407098121, Validation F1 Score: 0.8318318318318318, Validation F2 Score: 0.8629283489096573, Validation F0.5 Score: 0.8028985507246377





In [25]:
import itertools
input_dim = balanced_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN_GraphSAGE(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, balanced_data, graph_pairs_val)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:12<05:56, 12.30s/it]

Epoch: 1, Training Loss: 0.6483550843711682, Validation Loss: 0.6311088434589681, Validation Accuracy: 0.6430062630480167, Validation F1 Score: 0.7827191867852605, Validation F2 Score: 0.9000584453535944, Validation F0.5 Score: 0.6924460431654677


  7%|▋         | 2/30 [00:22<05:15, 11.26s/it]

Epoch: 2, Training Loss: 0.6282569432520194, Validation Loss: 0.6065233254606888, Validation Accuracy: 0.6951983298538622, Validation F1 Score: 0.7932011331444759, Validation F2 Score: 0.8588957055214724, Validation F0.5 Score: 0.7368421052631579


 10%|█         | 3/30 [00:35<05:20, 11.86s/it]

Epoch: 3, Training Loss: 0.6086731450622482, Validation Loss: 0.5720918593551022, Validation Accuracy: 0.7369519832985386, Validation F1 Score: 0.8102409638554217, Validation F2 Score: 0.8469773299748111, Validation F0.5 Score: 0.7765588914549654


 13%|█▎        | 4/30 [00:47<05:07, 11.83s/it]

Epoch: 4, Training Loss: 0.5875537144010468, Validation Loss: 0.560436920420362, Validation Accuracy: 0.7453027139874739, Validation F1 Score: 0.8221574344023324, Validation F2 Score: 0.8757763975155279, Validation F0.5 Score: 0.7747252747252747


 17%|█▋        | 5/30 [00:58<04:47, 11.50s/it]

Epoch: 5, Training Loss: 0.5719328170784339, Validation Loss: 0.5359888714825684, Validation Accuracy: 0.7849686847599165, Validation F1 Score: 0.8437025796661608, Validation F2 Score: 0.8780795957043588, Validation F0.5 Score: 0.8119158878504673


 20%|██        | 6/30 [01:08<04:30, 11.28s/it]

Epoch: 6, Training Loss: 0.5434470078134238, Validation Loss: 0.506090677979097, Validation Accuracy: 0.8016701461377871, Validation F1 Score: 0.8588410104011887, Validation F2 Score: 0.9048215403882279, Validation F0.5 Score: 0.8173076923076923


 23%|██▎       | 7/30 [01:20<04:21, 11.38s/it]

Epoch: 7, Training Loss: 0.528219882659653, Validation Loss: 0.49824095886037345, Validation Accuracy: 0.7974947807933194, Validation F1 Score: 0.8558692421991084, Validation F2 Score: 0.9016906700062617, Validation F0.5 Score: 0.8144796380090498


 27%|██▋       | 8/30 [01:31<04:10, 11.40s/it]

Epoch: 8, Training Loss: 0.5327192715903435, Validation Loss: 0.4961666111719633, Validation Accuracy: 0.7995824634655533, Validation F1 Score: 0.8588235294117647, Validation F2 Score: 0.9102244389027432, Validation F0.5 Score: 0.8129175946547884


 30%|███       | 9/30 [01:42<03:50, 10.99s/it]

Epoch: 9, Training Loss: 0.5207772721130646, Validation Loss: 0.506409679891421, Validation Accuracy: 0.7870563674321504, Validation F1 Score: 0.8504398826979472, Validation F2 Score: 0.9028642590286425, Validation F0.5 Score: 0.8037694013303769


 33%|███▎      | 10/30 [01:53<03:40, 11.05s/it]

Epoch: 10, Training Loss: 0.5133366770224885, Validation Loss: 0.47838665619648074, Validation Accuracy: 0.8037578288100209, Validation F1 Score: 0.8601190476190477, Validation F2 Score: 0.9053884711779449, Validation F0.5 Score: 0.8191609977324263


 37%|███▋      | 11/30 [02:04<03:31, 11.11s/it]

Epoch: 11, Training Loss: 0.4838218694726003, Validation Loss: 0.4619519055387421, Validation Accuracy: 0.8183716075156576, Validation F1 Score: 0.8695652173913043, Validation F2 Score: 0.9113764927718416, Validation F0.5 Score: 0.8314220183486238


 40%|████      | 12/30 [02:16<03:25, 11.41s/it]

Epoch: 12, Training Loss: 0.47814191018339236, Validation Loss: 0.45730390961210016, Validation Accuracy: 0.8288100208768268, Validation F1 Score: 0.875, Validation F2 Score: 0.9082278481012658, Validation F0.5 Score: 0.8441176470588235


 43%|████▎     | 13/30 [02:27<03:13, 11.37s/it]

Epoch: 13, Training Loss: 0.4780383491061449, Validation Loss: 0.45002732480368685, Validation Accuracy: 0.8308977035490606, Validation F1 Score: 0.8774583963691377, Validation F2 Score: 0.9148264984227129, Validation F0.5 Score: 0.8430232558139535


 47%|████▋     | 14/30 [02:38<02:58, 11.16s/it]

Epoch: 14, Training Loss: 0.46320065884007183, Validation Loss: 0.4478639454147263, Validation Accuracy: 0.8329853862212944, Validation F1 Score: 0.8776758409785933, Validation F2 Score: 0.9093789607097592, Validation F0.5 Score: 0.8481087470449172


 50%|█████     | 15/30 [02:48<02:43, 10.90s/it]

Epoch: 15, Training Loss: 0.47095438861361116, Validation Loss: 0.4428067338615967, Validation Accuracy: 0.837160751565762, Validation F1 Score: 0.8818181818181818, Validation F2 Score: 0.9185606060606061, Validation F0.5 Score: 0.8479020979020979


 53%|█████▎    | 16/30 [02:59<02:32, 10.90s/it]

Epoch: 16, Training Loss: 0.4597243011122338, Validation Loss: 0.4394555048305456, Validation Accuracy: 0.837160751565762, Validation F1 Score: 0.8814589665653495, Validation F2 Score: 0.9165613147914032, Validation F0.5 Score: 0.8489461358313818


 57%|█████▋    | 17/30 [03:10<02:20, 10.82s/it]

Epoch: 17, Training Loss: 0.4650495855007301, Validation Loss: 0.4376104799155651, Validation Accuracy: 0.8392484342379958, Validation F1 Score: 0.8817204301075269, Validation F2 Score: 0.9111111111111111, Validation F0.5 Score: 0.8541666666666666


 60%|██████    | 18/30 [03:20<02:08, 10.73s/it]

Epoch: 18, Training Loss: 0.4593558962629133, Validation Loss: 0.43158325664932395, Validation Accuracy: 0.8434237995824635, Validation F1 Score: 0.8854961832061069, Validation F2 Score: 0.9183027232425586, Validation F0.5 Score: 0.8549528301886793


 63%|██████▎   | 19/30 [03:31<01:57, 10.68s/it]

Epoch: 19, Training Loss: 0.4582776972243529, Validation Loss: 0.43374854064933444, Validation Accuracy: 0.8413361169102297, Validation F1 Score: 0.8834355828220859, Validation F2 Score: 0.9137055837563451, Validation F0.5 Score: 0.8551068883610451


 67%|██████▋   | 20/30 [03:41<01:45, 10.57s/it]

Epoch: 20, Training Loss: 0.4552018995101923, Validation Loss: 0.42560679239992805, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.8885496183206106, Validation F2 Score: 0.9214692843571881, Validation F0.5 Score: 0.8579009433962265


 70%|███████   | 21/30 [03:52<01:35, 10.60s/it]

Epoch: 21, Training Loss: 0.4492283604237229, Validation Loss: 0.42515601594164376, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.8885496183206106, Validation F2 Score: 0.9214692843571881, Validation F0.5 Score: 0.8579009433962265


 73%|███████▎  | 22/30 [04:02<01:23, 10.40s/it]

Epoch: 22, Training Loss: 0.4510882349223552, Validation Loss: 0.4245879974270663, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.8885496183206106, Validation F2 Score: 0.9214692843571881, Validation F0.5 Score: 0.8579009433962265


 77%|███████▋  | 23/30 [04:12<01:11, 10.26s/it]

Epoch: 23, Training Loss: 0.45288226073812166, Validation Loss: 0.4246929670238794, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.888208269525268, Validation F2 Score: 0.9194673430564363, Validation F0.5 Score: 0.8590047393364929


 80%|████████  | 24/30 [04:23<01:03, 10.55s/it]

Epoch: 24, Training Loss: 0.44442498796038493, Validation Loss: 0.424322694477806, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.888208269525268, Validation F2 Score: 0.9194673430564363, Validation F0.5 Score: 0.8590047393364929


 83%|████████▎ | 25/30 [04:34<00:52, 10.55s/it]

Epoch: 25, Training Loss: 0.45075247387714146, Validation Loss: 0.4237807478312411, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.888208269525268, Validation F2 Score: 0.9194673430564363, Validation F0.5 Score: 0.8590047393364929


 87%|████████▋ | 26/30 [04:44<00:41, 10.41s/it]

Epoch: 26, Training Loss: 0.4408738515188951, Validation Loss: 0.42333942965948507, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.888208269525268, Validation F2 Score: 0.9194673430564363, Validation F0.5 Score: 0.8590047393364929


 90%|█████████ | 27/30 [04:54<00:30, 10.26s/it]

Epoch: 27, Training Loss: 0.45037377944124657, Validation Loss: 0.4229352828854061, Validation Accuracy: 0.8496868475991649, Validation F1 Score: 0.8899082568807339, Validation F2 Score: 0.9220532319391636, Validation F0.5 Score: 0.8599290780141844


 93%|█████████▎| 28/30 [05:03<00:20, 10.12s/it]

Epoch: 28, Training Loss: 0.44988762690854045, Validation Loss: 0.4228902540027722, Validation Accuracy: 0.8496868475991649, Validation F1 Score: 0.8899082568807339, Validation F2 Score: 0.9220532319391636, Validation F0.5 Score: 0.8599290780141844


 97%|█████████▋| 29/30 [05:13<00:10, 10.09s/it]

Epoch: 29, Training Loss: 0.44842066361238964, Validation Loss: 0.42277375548020285, Validation Accuracy: 0.8496868475991649, Validation F1 Score: 0.8899082568807339, Validation F2 Score: 0.9220532319391636, Validation F0.5 Score: 0.8599290780141844


100%|██████████| 30/30 [05:23<00:00, 10.80s/it]

Epoch: 30, Training Loss: 0.45029023526355144, Validation Loss: 0.42269675256315004, Validation Accuracy: 0.8475991649269311, Validation F1 Score: 0.888208269525268, Validation F2 Score: 0.9194673430564363, Validation F0.5 Score: 0.8590047393364929





In [26]:
import itertools
input_dim = balanced_data[0][0].x.shape[1]

val_losses = []
parameters = []
# Define hyperparameter grids
learning_rates = [0.01]
dropout_rates = [0.05]
sort_k_values = [50]
hidden_units_values = [16]

# Create combinations of hyperparameters
hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
    model = SiameseGNN_GIN(sort_k, input_dim, dropout = dropout_rate, nhidden=hidden_units)
    val_accuracy, val_f1, val_loss, val_f2, val_05 = run_model(model, balanced_data, graph_pairs_val)

    parameters.append((lr, dropout_rate, sort_k, hidden_units))
    val_losses.append((val_accuracy, val_f1, val_loss, val_f2, val_05))

  3%|▎         | 1/30 [00:11<05:19, 11.00s/it]

Epoch: 1, Training Loss: 0.6485064943382359, Validation Loss: 0.6554418763686322, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


  7%|▋         | 2/30 [00:21<04:58, 10.66s/it]

Epoch: 2, Training Loss: 0.6393981634511718, Validation Loss: 0.6548049792740689, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 10%|█         | 3/30 [00:31<04:38, 10.33s/it]

Epoch: 3, Training Loss: 0.6402039026516483, Validation Loss: 0.6552289033110306, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 13%|█▎        | 4/30 [00:41<04:25, 10.23s/it]

Epoch: 4, Training Loss: 0.6385924387003062, Validation Loss: 0.653638284365668, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 17%|█▋        | 5/30 [00:51<04:15, 10.22s/it]

Epoch: 5, Training Loss: 0.6398660193042696, Validation Loss: 0.6538927863684477, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 20%|██        | 6/30 [01:01<04:03, 10.13s/it]

Epoch: 6, Training Loss: 0.6398784141119496, Validation Loss: 0.6526912889425839, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 23%|██▎       | 7/30 [01:11<03:48,  9.95s/it]

Epoch: 7, Training Loss: 0.6413525053637653, Validation Loss: 0.652222192486542, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 27%|██▋       | 8/30 [01:22<03:45, 10.26s/it]

Epoch: 8, Training Loss: 0.635443735241018, Validation Loss: 0.6521600400854003, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 30%|███       | 9/30 [01:32<03:36, 10.31s/it]

Epoch: 9, Training Loss: 0.6374638205162535, Validation Loss: 0.6507515620463576, Validation Accuracy: 0.6409185803757829, Validation F1 Score: 0.7811704834605598, Validation F2 Score: 0.8992384299941417, Validation F0.5 Score: 0.6905083220872694


 33%|███▎      | 10/30 [01:42<03:25, 10.29s/it]

Epoch: 10, Training Loss: 0.6384459018956407, Validation Loss: 0.6519987333044877, Validation Accuracy: 0.6450939457202505, Validation F1 Score: 0.7831632653061225, Validation F2 Score: 0.9002932551319648, Validation F0.5 Score: 0.6930022573363431


 37%|███▋      | 11/30 [01:53<03:15, 10.30s/it]

Epoch: 11, Training Loss: 0.6330860867208822, Validation Loss: 0.6505813591788855, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7851662404092071, Validation F2 Score: 0.9013505578391074, Validation F0.5 Score: 0.695514272768464


 40%|████      | 12/30 [02:02<03:02, 10.14s/it]

Epoch: 12, Training Loss: 0.6346838466176917, Validation Loss: 0.6494745402470511, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7861715749039693, Validation F2 Score: 0.9018801410105758, Validation F0.5 Score: 0.6967771221062188


 43%|████▎     | 13/30 [02:12<02:50, 10.05s/it]

Epoch: 13, Training Loss: 0.6336449992258097, Validation Loss: 0.6488524901842025, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.785622593068036, Validation F2 Score: 0.9, Validation F0.5 Score: 0.6970387243735763


 47%|████▋     | 14/30 [02:23<02:45, 10.34s/it]

Epoch: 14, Training Loss: 0.632381381566248, Validation Loss: 0.6482471143029677, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7866323907455013, Validation F2 Score: 0.9005297233666862, Validation F0.5 Score: 0.6983112733911456


 50%|█████     | 15/30 [02:33<02:34, 10.29s/it]

Epoch: 15, Training Loss: 0.6330649103179122, Validation Loss: 0.6480306031684039, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7866323907455013, Validation F2 Score: 0.9005297233666862, Validation F0.5 Score: 0.6983112733911456


 53%|█████▎    | 16/30 [02:44<02:26, 10.47s/it]

Epoch: 16, Training Loss: 0.6304426968658232, Validation Loss: 0.6471363968665017, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7835051546391752, Validation F2 Score: 0.8956982911019447, Validation F0.5 Score: 0.6962895098488319


 57%|█████▋    | 17/30 [02:55<02:18, 10.62s/it]

Epoch: 17, Training Loss: 0.6329285324927781, Validation Loss: 0.6469107827090023, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7835051546391752, Validation F2 Score: 0.8956982911019447, Validation F0.5 Score: 0.6962895098488319


 60%|██████    | 18/30 [03:06<02:07, 10.59s/it]

Epoch: 18, Training Loss: 0.6321750009196928, Validation Loss: 0.6472986053076567, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7855297157622739, Validation F2 Score: 0.8967551622418879, Validation F0.5 Score: 0.6988505747126437


 63%|██████▎   | 19/30 [03:17<01:59, 10.90s/it]

Epoch: 19, Training Loss: 0.6305863466756098, Validation Loss: 0.6474682662581601, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7855297157622739, Validation F2 Score: 0.8967551622418879, Validation F0.5 Score: 0.6988505747126437


 67%|██████▋   | 20/30 [03:28<01:48, 10.87s/it]

Epoch: 20, Training Loss: 0.6348322774736112, Validation Loss: 0.6471913977057551, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7855297157622739, Validation F2 Score: 0.8967551622418879, Validation F0.5 Score: 0.6988505747126437


 70%|███████   | 21/30 [03:39<01:38, 10.93s/it]

Epoch: 21, Training Loss: 0.6299880297704675, Validation Loss: 0.647144414313402, Validation Accuracy: 0.6534446764091858, Validation F1 Score: 0.7855297157622739, Validation F2 Score: 0.8967551622418879, Validation F0.5 Score: 0.6988505747126437


 73%|███████▎  | 22/30 [03:50<01:26, 10.78s/it]

Epoch: 22, Training Loss: 0.6326487554215837, Validation Loss: 0.6470788205217469, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7839586028460543, Validation F2 Score: 0.8943329397874853, Validation F0.5 Score: 0.6978350990327038


 77%|███████▋  | 23/30 [04:00<01:13, 10.53s/it]

Epoch: 23, Training Loss: 0.6289027330145442, Validation Loss: 0.6469910452161801, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7839586028460543, Validation F2 Score: 0.8943329397874853, Validation F0.5 Score: 0.6978350990327038


 80%|████████  | 24/30 [04:10<01:02, 10.50s/it]

Epoch: 24, Training Loss: 0.6297003853532837, Validation Loss: 0.6469454972505072, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7839586028460543, Validation F2 Score: 0.8943329397874853, Validation F0.5 Score: 0.6978350990327038


 83%|████████▎ | 25/30 [04:21<00:52, 10.57s/it]

Epoch: 25, Training Loss: 0.6291001294598435, Validation Loss: 0.6468929633467083, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7823834196891192, Validation F2 Score: 0.8919078558771412, Validation F0.5 Score: 0.6968158744808491


 87%|████████▋ | 26/30 [04:31<00:41, 10.42s/it]

Epoch: 26, Training Loss: 0.6314130133969656, Validation Loss: 0.6468374347139251, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7823834196891192, Validation F2 Score: 0.8919078558771412, Validation F0.5 Score: 0.6968158744808491


 90%|█████████ | 27/30 [04:42<00:31, 10.54s/it]

Epoch: 27, Training Loss: 0.6309967938822265, Validation Loss: 0.6468180584758209, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7823834196891192, Validation F2 Score: 0.8919078558771412, Validation F0.5 Score: 0.6968158744808491


 93%|█████████▎| 28/30 [04:53<00:21, 10.72s/it]

Epoch: 28, Training Loss: 0.6327292657259739, Validation Loss: 0.6467693295633121, Validation Accuracy: 0.6492693110647182, Validation F1 Score: 0.7823834196891192, Validation F2 Score: 0.8919078558771412, Validation F0.5 Score: 0.6968158744808491


 97%|█████████▋| 29/30 [05:03<00:10, 10.59s/it]

Epoch: 29, Training Loss: 0.6329341389106483, Validation Loss: 0.6467284303866249, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7833981841763943, Validation F2 Score: 0.892434988179669, Validation F0.5 Score: 0.6981044845122515


100%|██████████| 30/30 [05:13<00:00, 10.46s/it]

Epoch: 30, Training Loss: 0.6285209690328676, Validation Loss: 0.6466965342488816, Validation Accuracy: 0.651356993736952, Validation F1 Score: 0.7833981841763943, Validation F2 Score: 0.892434988179669, Validation F0.5 Score: 0.6981044845122515



