# Synthetic Experiments using Stochastic Block Models

In [6]:
import pickle as pkl
from torch_geometric.loader import DataLoader
import networkx as nx

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.utils.CreateFeatures import CreateFeatures
from src.pygcn.model import GraphSiamese
from torch_geometric.utils import to_networkx
from src.pygcn.embedding import GCN

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import f1_score

from torch_geometric.utils import to_networkx
import networkx as nx
import numpy as np
import itertools

import os
import pickle
import json

from src.utils.sample import sample_pairs
from src.utils.misc import collate

In [7]:
def run_model(model, train_loader, val_loader):
    torch.manual_seed(42)
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
    criterion = nn.BCEWithLogitsLoss()  # Changed to BCEWithLogitsLoss for numerical stability

    for epoch in tqdm(range(30)):
        model.train()
        train_losses = []
        for data1, data2, labels in train_loader:
            optimizer.zero_grad()
            out = model(data1, data2)
    
            labels = labels.float().view(-1, 1)  # Ensure labels are of the shape (batch_size, 1)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()

        model.eval()
        with torch.no_grad():
            val_losses = []

            val_pred = []
            val_truth = []

            correct = 0
            total = 0
            for data1, data2, labels in val_loader:
                out = model(data1, data2)

                labels = labels.float().view(-1, 1)  # Ensure labels are of the shape (batch_size, 1)
                val_loss = criterion(out, labels)
                val_losses.append(val_loss.item())

                predictions = torch.round(out)

                val_pred.extend(predictions.cpu().numpy())
                val_truth.extend(labels.cpu().numpy())

                correct += (predictions == labels).sum().item()
                total += labels.size(0)

            val_loss = sum(val_losses) / len(val_losses)
            val_accuracy = correct / total

        val_f1 = f1_score(val_truth, val_pred)
        print(f'Epoch: {epoch+1}, Training Loss: {sum(train_losses)/len(train_losses)}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}, Validation F1 Score: {val_f1}')
    return val_accuracy, val_f1


## Clique Data

In [8]:
# Assuming root_dir is the path to your root directory
root_dir = 'results/synthetic/'

clique_data = {}
cp_times = {}
label_data = {}

# Walk through all directories and files in root_dir
for dirpath, dirnames, filenames in os.walk(root_dir):
    # If there's a data.p file in this directory, read it
    args_file = os.path.join(dirpath, 'args.json')
    if os.path.isfile(args_file):
        with open(args_file, 'rb') as f:
            arg_data = json.load(f)
            clique_size = arg_data['size_clique']

    data_file = os.path.join(dirpath, 'data.p')
    if os.path.isfile(data_file):
        with open(data_file, 'rb') as f:
            data = pickle.load(f)
            clique_data[clique_size] = data

    # If there's a time.json file in this directory, read it
    time_file = os.path.join(dirpath, 'time.json')
    if os.path.isfile(time_file):
        with open(time_file, 'r') as f:
            time_data = json.load(f)
            cp_times[clique_size] = time_data

    label_file = os.path.join(dirpath, 'labels.p')
    if os.path.isfile(label_file):
        with open(label_file, 'rb') as f:
            data = pickle.load(f)
            label_data[clique_size] = data

In [9]:
sizes = [20, 30, 40, 50, 60, 70, 80]

In [10]:
for s in [20]:
    for j, i in enumerate(clique_data[s]):
        edge_index = i.edge_index.to(torch.int64)
        networkx_graph = to_networkx(i)
        adjacency = nx.adjacency_matrix(networkx_graph)
        
        attributes = np.eye(adjacency.shape[0])
        clique_data[s][j].x = attributes
    
    train = clique_data[s][:1000]
    train_labels = label_data[s][:1000]

    val = clique_data[s][1000:2000]
    val_labels = label_data[s][1000:2000]

    test = clique_data[s][2000:]
    test_labels = label_data[s][2000:]

    graph_pairs_train = sample_pairs(train,train_labels,nsamples=2000)
    graph_pairs_val = sample_pairs(train,val_labels,nsamples=2000)

    for j in graph_pairs_train:
        j[2] = int(j[2].item())
    for j in graph_pairs_val:
        j[2] = int(j[2].item())

    training_data_pairs = DataLoader(graph_pairs_train, batch_size=32, shuffle=True, collate_fn=collate,
                               drop_last=True)
    validation_data_pairs = DataLoader(graph_pairs_val, batch_size=32, shuffle=True, collate_fn=collate,
                               drop_last=True)

    input_dim = training_data_pairs.dataset[0][0].x.shape[1]

    # Define hyperparameter grids
    learning_rates = [1e-4]
    dropout_rates = [0.05]
    sort_k_values = [30]
    hidden_units_values = [16]

    # Create combinations of hyperparameters
    hyperparameter_combinations = list(itertools.product(learning_rates, dropout_rates, sort_k_values, hidden_units_values))

    for lr, dropout_rate, sort_k, hidden_units in hyperparameter_combinations:
        embedding = GCN(input_dim=input_dim, hidden_dim=hidden_units, layers=3, dropout=dropout_rate)
        model = GraphSiamese(embedding, sort_k, nlinear = 2, nhidden=hidden_units, dropout = dropout_rate)
        val_accuracy, val_f1 = run_model(model, training_data_pairs, validation_data_pairs)

    model_name = f"models/sgnn-topk{sort_k}-64hidden-{s}clique.pt"
    torch.save(model.state_dict(), model_name)

    time_test = [t-2000 for t in cp_times[s] if t>=2000]

    with open(f'results/test_synthetic/{s}-data.p', 'wb') as f:
        pickle.dump(test, f)

    with open(f'results/test_synthetic/{s}-labels.p', 'wb') as f:
        pickle.dump(test_labels, f)

    with open(f'results/test_synthetic/{s}-time.json', 'w') as f:
        json.dump(time_test, f)



1000 positive and 1000 negative examples
1000 positive and 1000 negative examples


  3%|▎         | 1/30 [02:31<1:13:14, 151.55s/it]

Epoch: 1, Training Loss: 0.7622540905032047, Validation Loss: 0.7367339553656401, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


  7%|▋         | 2/30 [05:03<1:10:42, 151.53s/it]

Epoch: 2, Training Loss: 0.7322648715141208, Validation Loss: 0.7184506750401155, Validation Accuracy: 0.6141975308641975, Validation F1 Score: 0.0


 10%|█         | 3/30 [07:34<1:08:08, 151.43s/it]

Epoch: 3, Training Loss: 0.7168410434279331, Validation Loss: 0.7094398803181119, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 13%|█▎        | 4/30 [10:04<1:05:24, 150.95s/it]

Epoch: 4, Training Loss: 0.7092043105946031, Validation Loss: 0.7045048438472512, Validation Accuracy: 0.6141975308641975, Validation F1 Score: 0.0


 17%|█▋        | 5/30 [12:35<1:02:55, 151.04s/it]

Epoch: 5, Training Loss: 0.7048270348892656, Validation Loss: 0.701728751629959, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 20%|██        | 6/30 [15:06<1:00:19, 150.80s/it]

Epoch: 6, Training Loss: 0.7021848226702491, Validation Loss: 0.699859337306317, Validation Accuracy: 0.6153549382716049, Validation F1 Score: 0.0


 23%|██▎       | 7/30 [17:36<57:48, 150.82s/it]  

Epoch: 7, Training Loss: 0.7002717457538428, Validation Loss: 0.6985511058642541, Validation Accuracy: 0.6153549382716049, Validation F1 Score: 0.0


 27%|██▋       | 8/30 [20:07<55:15, 150.71s/it]

Epoch: 8, Training Loss: 0.6989756240401157, Validation Loss: 0.6976082383850475, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 30%|███       | 9/30 [22:39<52:50, 150.98s/it]

Epoch: 9, Training Loss: 0.6980475450670997, Validation Loss: 0.6969045621377451, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 33%|███▎      | 10/30 [25:10<50:19, 150.98s/it]

Epoch: 10, Training Loss: 0.6973285709702691, Validation Loss: 0.6963780647442664, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 37%|███▋      | 11/30 [27:40<47:47, 150.89s/it]

Epoch: 11, Training Loss: 0.6967405510503192, Validation Loss: 0.6959623282338366, Validation Accuracy: 0.6153549382716049, Validation F1 Score: 0.0


 40%|████      | 12/30 [30:12<45:19, 151.10s/it]

Epoch: 12, Training Loss: 0.6962920330291571, Validation Loss: 0.6956083752490856, Validation Accuracy: 0.6141975308641975, Validation F1 Score: 0.0


 43%|████▎     | 13/30 [32:43<42:47, 151.00s/it]

Epoch: 13, Training Loss: 0.6959228335424911, Validation Loss: 0.6953343351682028, Validation Accuracy: 0.6141975308641975, Validation F1 Score: 0.0


 47%|████▋     | 14/30 [35:14<40:18, 151.16s/it]

Epoch: 14, Training Loss: 0.6956373674924984, Validation Loss: 0.6951095263163248, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 50%|█████     | 15/30 [37:44<37:41, 150.74s/it]

Epoch: 15, Training Loss: 0.6954003593256307, Validation Loss: 0.6949134748659016, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 53%|█████▎    | 16/30 [40:15<35:11, 150.83s/it]

Epoch: 16, Training Loss: 0.6951661068339681, Validation Loss: 0.694768339763453, Validation Accuracy: 0.6153549382716049, Validation F1 Score: 0.0


 57%|█████▋    | 17/30 [42:46<32:40, 150.80s/it]

Epoch: 17, Training Loss: 0.6950010816718257, Validation Loss: 0.694623883123751, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 60%|██████    | 18/30 [45:17<30:10, 150.83s/it]

Epoch: 18, Training Loss: 0.6948401803194091, Validation Loss: 0.6945070123966829, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 63%|██████▎   | 19/30 [47:47<27:38, 150.76s/it]

Epoch: 19, Training Loss: 0.6947101243706637, Validation Loss: 0.6944085779013457, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 67%|██████▋   | 20/30 [50:19<25:10, 151.01s/it]

Epoch: 20, Training Loss: 0.6946037654266801, Validation Loss: 0.6943139264613022, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 70%|███████   | 21/30 [52:50<22:39, 151.03s/it]

Epoch: 21, Training Loss: 0.6945027237714723, Validation Loss: 0.694239717942697, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 73%|███████▎  | 22/30 [55:20<20:06, 150.87s/it]

Epoch: 22, Training Loss: 0.6944117490635362, Validation Loss: 0.6941676220776122, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 77%|███████▋  | 23/30 [57:52<17:37, 151.03s/it]

Epoch: 23, Training Loss: 0.6943380908910618, Validation Loss: 0.6941098460444698, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 80%|████████  | 24/30 [1:00:22<15:04, 150.68s/it]

Epoch: 24, Training Loss: 0.6942750097707261, Validation Loss: 0.6940490029476307, Validation Accuracy: 0.6141975308641975, Validation F1 Score: 0.0


 83%|████████▎ | 25/30 [1:02:53<12:34, 150.89s/it]

Epoch: 25, Training Loss: 0.6942030385483143, Validation Loss: 0.6940077594768854, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 87%|████████▋ | 26/30 [1:05:23<10:02, 150.75s/it]

Epoch: 26, Training Loss: 0.6941541121449581, Validation Loss: 0.6939695448051264, Validation Accuracy: 0.6153549382716049, Validation F1 Score: 0.0


 90%|█████████ | 27/30 [1:07:54<07:32, 150.74s/it]

Epoch: 27, Training Loss: 0.6941059980281564, Validation Loss: 0.6939243696354054, Validation Accuracy: 0.6145833333333334, Validation F1 Score: 0.0


 93%|█████████▎| 28/30 [1:10:25<05:01, 150.78s/it]

Epoch: 28, Training Loss: 0.6940660892530929, Validation Loss: 0.6938927548903006, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


 97%|█████████▋| 29/30 [1:12:55<02:30, 150.58s/it]

Epoch: 29, Training Loss: 0.6940211591332458, Validation Loss: 0.6938624382019043, Validation Accuracy: 0.6149691358024691, Validation F1 Score: 0.0


100%|██████████| 30/30 [1:15:26<00:00, 150.89s/it]

Epoch: 30, Training Loss: 0.6939832554307095, Validation Loss: 0.6938371614173606, Validation Accuracy: 0.6153549382716049, Validation F1 Score: 0.0



