# Synthetic Experiments using Stochastic Block Models

In [9]:
import pickle as pkl
from torch_geometric.data import DataLoader
from itertools import combinations
import random
import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.utils.CreateFeatures import CreateFeatures
from src.pygcn.GCN_synthetic import SiameseGNN
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_networkx

import torch
import torch.nn as nn
import torch
import torch_geometric.data as data

from src.utils.graphs import laplacian_embeddings, random_walk_embeddings, degree_matrix, identity
from torch_geometric.utils import to_networkx
import networkx as nx
import numpy as np
import itertools

from src.utils.sample import sample_pairs

In [10]:
from sklearn.metrics import precision_score, recall_score

def adjusted_f1_score(y_true, y_pred, beta=1.0):
    """
    Calculate the adjusted F1 score.
    
    Parameters:
    y_true (list or array): True labels.
    y_pred (list or array): Predicted labels.
    beta (float): Weight factor.
    
    Returns:
    float: Adjusted F1 score.
    """
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    
    if precision == 0 and recall == 0:
        return 0.0
    
    adjusted_f1 = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
    return adjusted_f1


In [11]:
def run_model(model, train_data, val_data, lr):
    torch.manual_seed(42)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = nn.BCELoss()  # Changed to BCEWithLogitsLoss for numerical stability

    for epoch in tqdm(range(50)):
        model.train()
        train_losses = []
        for data1, data2, label in train_data:

            optimizer.zero_grad()
            out = model(data1, data2)

            label = torch.tensor(label).view(1).float()
            loss = criterion(out.squeeze(0), label)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()

        model.eval()
        with torch.no_grad():
            val_losses = []

            val_pred = []
            val_truth = []

            correct = 0
            total = 0
            for data1, data2, label in val_data:
                out = model(data1, data2)
                label = torch.tensor(label).view(1).float()
                val_loss = criterion(out.squeeze(0), label)
                val_losses.append(val_loss.item())

                predictions = torch.round(out.squeeze())

                val_pred.append(predictions)
                val_truth.append(label)

                correct += (predictions == label).sum().item()
                total += 1

            val_loss = sum(val_losses) / len(val_losses)
            val_accuracy = correct / total

        val_f1 = f1_score(val_truth, val_pred)
        print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}, Validation F1 Score: {val_f1}')
    return val_accuracy, val_f1

## Clique Data

In [12]:
import os
import glob
import pickle
import json

# Assuming root_dir is the path to your root directory
root_dir = 'results/synthetic/'

clique_data = {}
cp_times = {}
label_data = {}

# Walk through all directories and files in root_dir
for dirpath, dirnames, filenames in os.walk(root_dir):
    # If there's a data.p file in this directory, read it
    args_file = os.path.join(dirpath, 'args.json')
    if os.path.isfile(args_file):
        with open(args_file, 'rb') as f:
            arg_data = json.load(f)
            clique_size = arg_data['size_clique']

    data_file = os.path.join(dirpath, 'data.p')
    if os.path.isfile(data_file):
        with open(data_file, 'rb') as f:
            data = pickle.load(f)
            clique_data[clique_size] = data

    # If there's a time.json file in this directory, read it
    time_file = os.path.join(dirpath, 'time.json')
    if os.path.isfile(time_file):
        with open(time_file, 'r') as f:
            time_data = json.load(f)
            cp_times[clique_size] = time_data

    label_file = os.path.join(dirpath, 'labels.p')
    if os.path.isfile(label_file):
        with open(label_file, 'rb') as f:
            data = pickle.load(f)
            label_data[clique_size] = data

In [13]:
sizes = [20, 30, 40, 50, 60, 70, 80]

In [14]:
ex_data = clique_data[30]

In [15]:
for j, i in enumerate(ex_data):
    edge_index = i.edge_index.to(torch.int64)
    networkx_graph = to_networkx(i)
    adjacency = nx.adjacency_matrix(networkx_graph)
    
    attributes = np.eye(adjacency.shape[0])
    ex_data[j].x = attributes
    break

In [16]:
for s in [20]:
    for j, i in enumerate(clique_data[s]):
        edge_index = i.edge_index.to(torch.int64)
        networkx_graph = to_networkx(i)
        adjacency = nx.adjacency_matrix(networkx_graph)
        
        attributes = np.eye(adjacency.shape[0])
        clique_data[s][j].x = attributes
    
    train = clique_data[s][:1000]
    train_labels = label_data[s][:1000]

    val = clique_data[s][1000:2000]
    val_labels = label_data[s][1000:2000]

    test = clique_data[s][2000:]
    test_labels = label_data[s][2000:]

    graph_pairs_train = sample_pairs(train,train_labels,nsamples=2000)
    graph_pairs_val = sample_pairs(train,val_labels,nsamples=2000)

    for j in graph_pairs_train:
        j[2] = int(j[2].item())
    for j in graph_pairs_val:
        j[2] = int(j[2].item())

    # Define hyperparameter grids
    learning_rates = [0.001]
    dropout_rates = [0.1]
    sort_k_values = [30]
    hidden_units_values = [16]

    # Best params: lr=0.001, dropout_rate=0.1, sort_k=40, hidden_units=64

    # Create combinations of hyperparameters
    hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

    best_hyperparams = None
    best_val_score = 0

    for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
        print(f"Running with lr={lr}, dropout_rate={dropout_rate}, sort_k={sort_k}, hidden_units={hidden_units}")
        model = SiameseGNN(hidden_units, sort_k, dropout_rate)
        val_accuracy, val_f1 = run_model(model, graph_pairs_train, graph_pairs_val, lr)
        
        # Update the best hyperparameters based on validation F1 score
        if val_f1 > best_val_score:
            best_val_score = val_f1
            best_hyperparams = (lr, dropout_rate, sort_k, hidden_units)

    print(f"Best Hyperparameters: Learning Rate: {best_hyperparams[0]}, Dropout Rate: {best_hyperparams[1]}, Sort-k: {best_hyperparams[2]}, Hidden Units: {best_hyperparams[3]}")
    print(f"Best Validation F1 Score: {best_val_score}")

1000 positive and 1000 negative examples
1000 positive and 1000 negative examples
Running with lr=0.001, dropout_rate=0.1, sort_k=30, hidden_units=16


  2%|▏         | 1/50 [00:58<48:09, 58.98s/it]

Epoch: 1, Training Loss: 0.6555357049437615, Validation Loss: 0.666357993455398, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


  4%|▍         | 2/50 [01:56<46:23, 57.98s/it]

Epoch: 2, Training Loss: 0.6531924906071535, Validation Loss: 0.6662919792344528, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


  6%|▌         | 3/50 [02:52<44:52, 57.29s/it]

Epoch: 3, Training Loss: 0.6526095377682463, Validation Loss: 0.6661951029495892, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


  8%|▊         | 4/50 [03:47<43:15, 56.43s/it]

Epoch: 4, Training Loss: 0.651518123993203, Validation Loss: 0.6663965426005213, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 10%|█         | 5/50 [04:43<42:00, 56.00s/it]

Epoch: 5, Training Loss: 0.6515535243304632, Validation Loss: 0.6658942049818918, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 12%|█▏        | 6/50 [05:40<41:28, 56.56s/it]

Epoch: 6, Training Loss: 0.6508514600233751, Validation Loss: 0.6656589728566589, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 14%|█▍        | 7/50 [06:46<42:35, 59.43s/it]

Epoch: 7, Training Loss: 0.6510979001859061, Validation Loss: 0.6658825341155306, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 16%|█▌        | 8/50 [08:18<49:00, 70.01s/it]

Epoch: 8, Training Loss: 0.6508192752485003, Validation Loss: 0.665708847360945, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 18%|█▊        | 9/50 [09:52<52:58, 77.52s/it]

Epoch: 9, Training Loss: 0.6507820510238167, Validation Loss: 0.6654149184382243, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 20%|██        | 10/50 [11:24<54:36, 81.91s/it]

Epoch: 10, Training Loss: 0.6503161609246355, Validation Loss: 0.6655837437031565, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 22%|██▏       | 11/50 [12:57<55:23, 85.23s/it]

Epoch: 11, Training Loss: 0.6475597103485811, Validation Loss: 0.6658904826423183, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 24%|██▍       | 12/50 [14:25<54:31, 86.09s/it]

Epoch: 12, Training Loss: 0.6451930132893807, Validation Loss: 0.6659370110017693, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 26%|██▌       | 13/50 [15:57<54:15, 87.99s/it]

Epoch: 13, Training Loss: 0.6420862164071298, Validation Loss: 0.6667842077941089, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 28%|██▊       | 14/50 [17:30<53:38, 89.42s/it]

Epoch: 14, Training Loss: 0.6381986517226734, Validation Loss: 0.6692891835610284, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 30%|███       | 15/50 [19:03<52:45, 90.44s/it]

Epoch: 15, Training Loss: 0.6325616595671132, Validation Loss: 0.6735404612041119, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 32%|███▏      | 16/50 [20:31<50:55, 89.86s/it]

Epoch: 16, Training Loss: 0.6268420768353483, Validation Loss: 0.6793497871247213, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 34%|███▍      | 17/50 [21:48<47:12, 85.84s/it]

Epoch: 17, Training Loss: 0.6210331555887936, Validation Loss: 0.6856931726485831, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 36%|███▌      | 18/50 [23:13<45:40, 85.65s/it]

Epoch: 18, Training Loss: 0.6141097583520105, Validation Loss: 0.6942842229775182, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 38%|███▊      | 19/50 [24:39<44:21, 85.86s/it]

Epoch: 19, Training Loss: 0.6089207398316833, Validation Loss: 0.7001137818888635, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 40%|████      | 20/50 [26:05<42:52, 85.77s/it]

Epoch: 20, Training Loss: 0.6021637632460674, Validation Loss: 0.708187126777981, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 42%|████▏     | 21/50 [27:38<42:33, 88.04s/it]

Epoch: 21, Training Loss: 0.5962029560890927, Validation Loss: 0.7090113552149615, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 44%|████▍     | 22/50 [29:12<41:56, 89.87s/it]

Epoch: 22, Training Loss: 0.593668007915252, Validation Loss: 0.7106614517415424, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 46%|████▌     | 23/50 [30:47<41:03, 91.25s/it]

Epoch: 23, Training Loss: 0.5943430939378729, Validation Loss: 0.7123385825834925, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 48%|████▊     | 24/50 [32:17<39:21, 90.84s/it]

Epoch: 24, Training Loss: 0.5914489166377126, Validation Loss: 0.7137827626220252, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 50%|█████     | 25/50 [33:52<38:23, 92.16s/it]

Epoch: 25, Training Loss: 0.5904079348552719, Validation Loss: 0.7157573400821426, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 52%|█████▏    | 26/50 [35:27<37:10, 92.95s/it]

Epoch: 26, Training Loss: 0.5913854621122284, Validation Loss: 0.7168540677108205, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 54%|█████▍    | 27/50 [36:52<34:43, 90.60s/it]

Epoch: 27, Training Loss: 0.5906329630129893, Validation Loss: 0.7181343209631145, Validation Accuracy: 0.6191926884996192, Validation F1 Score: 0.0


 56%|█████▌    | 28/50 [38:26<33:34, 91.56s/it]

Epoch: 28, Training Loss: 0.5896493380243903, Validation Loss: 0.7195018378355216, Validation Accuracy: 0.6195734958111195, Validation F1 Score: 0.001998001998001998


 58%|█████▊    | 29/50 [39:57<31:58, 91.37s/it]

Epoch: 29, Training Loss: 0.5865347231378806, Validation Loss: 0.721378472675509, Validation Accuracy: 0.6195734958111195, Validation F1 Score: 0.001998001998001998


 60%|██████    | 30/50 [41:31<30:45, 92.30s/it]

Epoch: 30, Training Loss: 0.5877170415360784, Validation Loss: 0.7225449387230949, Validation Accuracy: 0.6161462300076161, Validation F1 Score: 0.017543859649122806


 62%|██████▏   | 31/50 [43:06<29:28, 93.07s/it]

Epoch: 31, Training Loss: 0.5862555563197742, Validation Loss: 0.7226710448726751, Validation Accuracy: 0.6146230007616146, Validation F1 Score: 0.03065134099616858


 64%|██████▍   | 32/50 [44:39<27:55, 93.10s/it]

Epoch: 32, Training Loss: 0.5857316208875302, Validation Loss: 0.7227716563242542, Validation Accuracy: 0.6092916984006093, Validation F1 Score: 0.05


 66%|██████▌   | 33/50 [46:12<26:24, 93.21s/it]

Epoch: 33, Training Loss: 0.5862206122537587, Validation Loss: 0.7228961300363045, Validation Accuracy: 0.6009139375476009, Validation F1 Score: 0.07092198581560284


 68%|██████▊   | 34/50 [47:48<25:02, 93.89s/it]

Epoch: 34, Training Loss: 0.5852023804475608, Validation Loss: 0.723043679504429, Validation Accuracy: 0.5936785986290937, Validation F1 Score: 0.10561609388097234


 70%|███████   | 35/50 [49:13<22:47, 91.14s/it]

Epoch: 35, Training Loss: 0.5862169436031048, Validation Loss: 0.7231314830504704, Validation Accuracy: 0.5788271134805788, Validation F1 Score: 0.13996889580093314


 72%|███████▏  | 36/50 [50:35<20:38, 88.47s/it]

Epoch: 36, Training Loss: 0.5857636362582098, Validation Loss: 0.7232582265254386, Validation Accuracy: 0.5742574257425742, Validation F1 Score: 0.191027496382055


 74%|███████▍  | 37/50 [51:52<18:26, 85.09s/it]

Epoch: 37, Training Loss: 0.5861531708876684, Validation Loss: 0.7233492696128633, Validation Accuracy: 0.5632140137090632, Validation F1 Score: 0.21919673247106874


 74%|███████▍  | 37/50 [52:20<18:23, 84.89s/it]


KeyboardInterrupt: 