In [2]:
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import os
import pickle
from datetime import datetime
import pandas as pd
from pathlib import Path, PosixPath
import json
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_networkx
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score, roc_auc_score, f1_score

from sample import sample_pairs
from misc import collate
from model import GraphSiamese
from embedding import GCN

In [10]:
# Assuming root_dir is the path to your root directory
root_dir = '../synthetic_experiments/results/synthetic'

clique_data = {}
cp_times = {}
label_data = {}

# Walk through all directories and files in root_dir
for dirpath, dirnames, filenames in os.walk(root_dir):
    # If there's a data.p file in this directory, read it
    args_file = os.path.join(dirpath, 'args.json')
    if os.path.isfile(args_file):
        with open(args_file, 'rb') as f:
            arg_data = json.load(f)
            clique_size = arg_data['size_clique']

    data_file = os.path.join(dirpath, 'data.p')
    if os.path.isfile(data_file):
        with open(data_file, 'rb') as f:
            data = pickle.load(f)
            clique_data[clique_size] = data

    # If there's a time.json file in this directory, read it
    time_file = os.path.join(dirpath, 'time.json')
    if os.path.isfile(time_file):
        with open(time_file, 'r') as f:
            time_data = json.load(f)
            cp_times[clique_size] = time_data

    label_file = os.path.join(dirpath, 'labels.p')
    if os.path.isfile(label_file):
        with open(label_file, 'rb') as f:
            data = pickle.load(f)
            label_data[clique_size] = data

In [12]:
for s in [70]:
    for j, i in enumerate(clique_data[s]):
        edge_index = i.edge_index.to(torch.int64)
        networkx_graph = to_networkx(i)
        adjacency = nx.adjacency_matrix(networkx_graph)
        
        attributes = np.eye(adjacency.shape[0])
        clique_data[s][j].x = attributes
    
    train = clique_data[s][:1000]
    train_labels = label_data[s][:1000]

    val = clique_data[s][1000:2000]
    val_labels = label_data[s][1000:2000]

    test = clique_data[s][2000:]
    test_labels = label_data[s][2000:]

    graph_pairs_train = sample_pairs(train,train_labels,nsamples=2000)
    graph_pairs_val = sample_pairs(val,val_labels,nsamples=1000)
    graph_pairs_test = sample_pairs(test,test_labels,nsamples=1000)



1000 positive and 1000 negative examples
500 positive and 500 negative examples
500 positive and 500 negative examples


In [13]:
s = 70
time_test = [t-2000 for t in cp_times[s] if t>=2000]
time_train = [t for t in cp_times[s] if t<1000]

with open(f'../../results/test_synthetic/test/{s}-data.p', 'wb') as f:
    pickle.dump(test, f)

with open(f'../../results/test_synthetic/test/{s}-labels.p', 'wb') as f:
    pickle.dump(test_labels, f)

with open(f'../../results/test_synthetic/test/{s}-time.json', 'w') as f:
    json.dump(time_test, f)

In [15]:
with open(f'../../graph_pairs_train_{s}.p', 'wb') as f:
    pickle.dump(graph_pairs_train, f)
with open(f'../../graph_pairs_val_{s}.p', 'wb') as f:
    pickle.dump(graph_pairs_val, f)

In [5]:
with open(f'../../graph_pairs_train_{s}.p', 'rb') as f:
    graph_pairs_train = pickle.load(f)

with open(f'../../graph_pairs_val_{s}.p', 'rb') as f:
    graph_pairs_val = pickle.load(f)

In [16]:
training_data_pairs = DataLoader(graph_pairs_train, batch_size=6, shuffle=True, collate_fn=collate,
                               drop_last=True)
validation_data_pairs = DataLoader(graph_pairs_val, batch_size=6, shuffle=True, collate_fn=collate,
                               drop_last=True)



In [65]:
topk = 30
dropout = 0.1
input_dim = training_data_pairs.dataset[0][0].x.shape[1]*6
embedding = GCN(input_dim=input_dim, type='gcn', hidden_dim=16, layers=3, dropout=dropout)
model = GraphSiamese(embedding, 'euclidean', 'topk', 'bce', topk, nlinear=2,
                         nhidden=16, dropout=dropout, features=None)

In [66]:
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
loss_fn = torch.nn.BCELoss(reduction='none')

In [67]:
logging = {'train_loss': [], 'train_acc': [], 'train_recall': [], 'train_precision': [],
                   'valid_loss': [], 'valid_acc': [], 'valid_precision': [], 'valid_recall': []}

best_f1, best_weights, best_loss = 0., None, np.Inf
final_metrics = {'loss' : [0.0, 0., 0.], 'accuracy': [0., 0., 0.], 'recall': [0., 0., 0.], 'precision': [0., 0., 0.]}

In [68]:
# for early stopping
patience = 10
patience_counter = 0

# training loop
for epoch in range(100):

    # training updates
    train_loss, train_acc, train_precision, train_recall = [], [], [], []

    model.train()

    # minibatch loop
    for (graph1, graph2, labels) in training_data_pairs:

        graph1, graph2, labels = graph1, graph2, labels

        predictions = model(graph1, graph2)
        predictions = torch.sigmoid(predictions) # predictions between 0 and 1

        loss = loss_fn(predictions, labels.float())

        # balanced accuracy score instead of plain accuracy
        accuracy = torch.tensor(np.array((predictions.squeeze().cpu().detach() > 0.5) == labels.cpu(),
                                            dtype=float).mean().item()).unsqueeze(dim=0)
        recall = torch.tensor(
            [recall_score(labels.cpu(), (predictions.squeeze().cpu().detach() > 0.5).float(), zero_division=0.)])
        precision = torch.tensor(
            [precision_score(labels.cpu(), (predictions.squeeze().cpu().detach() > 0.5).float(), zero_division=0.)])

        train_loss.append(loss), train_acc.append(accuracy), train_recall.append(recall), train_precision.append(precision), \

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()


    logging['train_loss'].append(torch.cat(train_loss).mean().item())
    logging['train_acc'].append(torch.cat(train_acc).mean().item())
    logging['train_recall'].append(torch.cat(train_recall).mean().item())
    logging['train_precision'].append(torch.cat(train_precision).mean().item())

    scheduler.step()

    # validation updates
    model.eval()

    valid_loss, valid_acc, valid_recall, valid_precision = [], [], [], []
    with torch.no_grad():
        for (graph1, graph2, labels) in validation_data_pairs:

            graph1, graph2 = graph1, graph2

            predictions = model(graph1, graph2)

            predictions = torch.sigmoid(predictions)

            loss = loss_fn(predictions, labels.float())

            recall = torch.tensor(
                [recall_score(labels.float(), (predictions.squeeze().detach().cpu() > 0.5).float(), zero_division=0.)])
            precision = torch.tensor(
                [precision_score(labels.float(), (predictions.squeeze().detach().cpu() > 0.5).float(), zero_division=0.)])
            accuracy = torch.tensor(np.array((predictions.squeeze().detach().cpu() > 0.5).float() == labels.float(),
                                                dtype=float).mean().item()).unsqueeze(dim=0)

            valid_loss.append(loss), valid_acc.append(accuracy), valid_recall.append(
                recall), valid_precision.append(
                precision)

        logging['valid_loss'].append(torch.cat(valid_loss).mean().item())
        logging['valid_acc'].append(torch.cat(valid_acc).mean().item())
        logging['valid_recall'].append(torch.cat(valid_recall).mean().item())
        logging['valid_precision'].append(torch.cat(valid_precision).mean().item())

        # save best weights
        #if logging['valid_f1'][-1] > best_f1 and epoch > 0:
        if logging['valid_loss'][-1] < best_loss and epoch > 0:
            best_loss = logging['valid_loss'][-1]
            #best_f1 = logging['valid_f1'][-1]
            final_metrics['loss'][:2] = [logging['train_loss'][-1], logging['valid_loss'][-1]]
            final_metrics['accuracy'][:2] = [logging['train_acc'][-1], logging['valid_acc'][-1]]
            final_metrics['recall'][:2] = [logging['train_recall'][-1], logging['valid_recall'][-1]]
            final_metrics['precision'][:2] = [logging['train_precision'][-1], logging['valid_precision'][-1]]
            best_weights = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1


    if patience == patience_counter:
        break

    if epoch % 1 == 0:

        train_acc, train_loss = logging['train_acc'][-1], logging['train_loss'][-1]

        valid_acc, valid_loss = logging['valid_acc'][-1], logging['valid_loss'][-1]
        print("Epoch, Training loss, Valid loss, Valid Acc", epoch, train_loss, valid_loss,
                valid_acc)
        print("Patience counter : ", patience_counter)

Epoch, Training loss, Valid loss, Valid Acc 0 0.5527933239936829 0.3908514380455017 0.5956790447235107
Patience counter :  1
Epoch, Training loss, Valid loss, Valid Acc 1 0.39355647563934326 0.3225877285003662 0.612139880657196
Patience counter :  0
Epoch, Training loss, Valid loss, Valid Acc 2 0.3523678779602051 0.3239592909812927 0.588477373123169
Patience counter :  1
Epoch, Training loss, Valid loss, Valid Acc 3 0.33206942677497864 0.2096770703792572 0.6013374924659729
Patience counter :  0
Epoch, Training loss, Valid loss, Valid Acc 4 0.32733848690986633 0.27473321557044983 0.6131687164306641
Patience counter :  1
Epoch, Training loss, Valid loss, Valid Acc 5 0.3146028220653534 0.18777506053447723 0.6090534925460815
Patience counter :  0
Epoch, Training loss, Valid loss, Valid Acc 6 0.2913880944252014 0.3033105134963989 0.5943930149078369
Patience counter :  1
Epoch, Training loss, Valid loss, Valid Acc 7 0.291880339384079 0.5142190456390381 0.6203703880310059
Patience counter :  

In [69]:
model_path = (f's_{s}_k_{topk}')

In [70]:
save_dir = PosixPath('../../synthetic/trained_models/').expanduser() / model_path
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

pd.DataFrame(logging).to_csv(save_dir / 'logging.csv')

torch.save(best_weights, save_dir /'model.pt')

with open(save_dir / 'results.json', 'w') as fp:
    json.dump(final_metrics, fp, indent=2)