In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from torch_geometric.datasets import GNNBenchmarkDataset
from datasets.SRDataset import SRDataset

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``no`` to a bool.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True`` because
    every non-empty string is truthy, so such a flag could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# NOTE(review): the first positional argument of ArgumentParser is `prog`, so this text is
# shown as the program name in usage messages — confirm `description=` was not intended.
parser = argparse.ArgumentParser('arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
parser.add_argument('--dataset_name', type=str, default="sr25", help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.6,
                    help='Probability of zeroing an activation in dropout layers.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--l2_wd', type=float, default=5e-6, help='L2 weight decay.')
parser.add_argument('--num_epochs', type=int, default=500, help='Number of epochs.')
parser.add_argument("--hidden_size", type=int, default=200, help="Hidden size of the model")
# `choices` must be a container of options: a bare string makes argparse test membership
# against the string itself (substring / per-character matching), not against one option.
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv",), help="Base GNN model")
parser.add_argument("--K", type=int, default=3, help="Number of hop to consider")
parser.add_argument("--num_layer", type=int, default=3, help="Number of layer for feature encoder")
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method")
# NOTE(review): `store_true` with `default=True` means this flag can never be turned off
# from the command line; left unchanged so existing invocations keep their behaviour.
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False,
                    help="If true, add virtual node information in each layer")
parser.add_argument('--split', type=int, default=10, help='Number of fold in cross validation')
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation")
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification")
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model")
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
# Boolean-valued options go through `_str2bool` so that "--pos_attr False" really yields False.
parser.add_argument("--pos_attr", type=_str2bool, default=False, help="Positional attributes.")
parser.add_argument("--feature_augmentation", type=_str2bool, default=False, help="If true, feature augmentation.")

_StoreAction(option_strings=['--feature_augmentation'], dest='feature_augmentation', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, required=False, help='If true, feature augmentation.', metavar=None)

In [3]:
# Parse an explicitly empty argument list so every option takes its default value
# (nothing is read from sys.argv).
args = parser.parse_args(args=[])

In [4]:
# Run identifier: "<model>_<K>_<search flag>".
args.name = f"{args.model_name}_{args.K}_{args.search}"

In [5]:
# Set up logging and devices.
# NOTE(review): `args.save_dir` is overwritten with the helper's return value, so
# re-running this cell feeds the previous result back in — presumably nesting the path.
run_save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
args.save_dir = run_save_dir
log = train_utils.get_logger(run_save_dir, args.name)
device, available_gpu_ids = train_utils.get_available_devices()
args.gpu_ids = available_gpu_ids

In [6]:
# Multi-GPU mode needs both the --parallel flag and more than one visible GPU.
use_multi_gpu = bool(args.parallel and len(args.gpu_ids) > 1)

if use_multi_gpu:
    log.info('Using multi-gpu training')
    loader = DataListLoader
    # The configured batch size is per GPU, so scale it by the number of devices.
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info('Using single-gpu training')
    loader = DataLoader

# Record the effective mode so later cells can rely on the flag alone.
args.parallel = use_multi_gpu

[08.06.24 16:40:57] Using single-gpu training


In [7]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')

# Seed every RNG source in the same order as before: the PyG helper first, then the
# Python, NumPy and torch (CPU / current GPU / all GPUs) generators explicitly.
for seed_fn in (
    seed_everything,
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed)

[08.06.24 16:40:57] Using random seed 2...


In [8]:
def edge_feature_transform(g):
    """Single-argument transform wrapper around ``extract_edge_attributes``.

    Binds the positional-encoding dimension from the global ``args`` so the function
    can be used where a one-argument graph transform is expected.
    """
    pos_enc_dim = args.pos_enc_dim
    return extract_edge_attributes(g, pos_enc_dim)

In [9]:
# Whole-second Unix timestamp, used as a unique suffix for files written by this run.
tag = f"{int(time.time())}"

In [10]:
class LogReg(nn.Module):
    """Linear classification head: one fully connected layer from ``hid_dim`` to ``out_dim``."""

    def __init__(self, hid_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        """Return the raw (un-normalised) scores produced by the linear layer."""
        return self.fc(x)

In [11]:
def sr25_train(loader, model, optimizer, device, parallel=False):
    """Train ``model`` for one pass over ``loader`` on pre-computed graph embeddings.

    Args:
        loader: iterable of batches. With ``parallel=False`` each batch is one object
            exposing ``graph_embeds``, ``y``, ``num_graphs`` and ``to(device)``; with
            ``parallel=True`` each batch is a list of per-graph objects exposing
            ``graph_embeds`` and ``y``. ``loader.dataset`` must support ``len``.
        model: module mapping graph embeddings to class scores.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.
        parallel: whether batches arrive as lists of graphs.

    Returns:
        float: cross-entropy loss averaged over the graphs of ``loader.dataset``.
    """
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        if parallel:
            # A list batch has no attributes of its own: gather labels AND embeddings
            # from the individual graphs (reading `.graph_embeds` off the list raised).
            num_graphs = len(data)
            y = torch.cat([d.y for d in data]).to(device)
            graph_embeds = torch.cat([d.graph_embeds for d in data]).to(device)
        else:
            num_graphs = data.num_graphs
            data = data.to(device)
            y = data.y
            graph_embeds = data.graph_embeds
        out = model(graph_embeds)
        # Class-index targets must be integer typed and 1-D. Float targets are read by
        # `cross_entropy` as class probabilities, and squeezing the scores collapsed the
        # class axis whenever the batch or the head had size one.
        loss = F.cross_entropy(out, y.long().view(-1))
        loss.backward()
        total_loss += loss.item() * num_graphs
        optimizer.step()
    return total_loss / len(loader.dataset)

In [12]:
@torch.no_grad()
def sr25_test(loader, model, device, parallel=False):
    """Compute classification accuracy of ``model`` on pre-computed graph embeddings.

    Args:
        loader: iterable of batches. With ``parallel=False`` each batch exposes
            ``graph_embeds``, ``y`` and ``to(device)``; with ``parallel=True`` each batch
            is a list of per-graph objects exposing ``graph_embeds`` and ``y``.
        model: module mapping graph embeddings to class scores.
        device: device the batch tensors are moved to.
        parallel: whether batches arrive as lists of graphs.

    Returns:
        torch.Tensor: 0-dim tensor holding the fraction of correct argmax predictions.
    """
    model.train()  # eliminate the effect of BN
    y_preds, y_trues = [], []
    for data in loader:
        if parallel:
            # Gather from the individual graphs; the list batch itself has no attributes.
            y = torch.cat([d.y for d in data]).to(device)
            graph_embeds = torch.cat([d.graph_embeds for d in data]).to(device)
        else:
            data = data.to(device)
            y = data.y
            graph_embeds = data.graph_embeds
        y_preds.append(torch.argmax(model(graph_embeds), dim=-1))
        # Flatten so the comparison below is element-wise rather than broadcast.
        y_trues.append(y.view(-1))
    y_preds = torch.cat(y_preds, -1)
    y_trues = torch.cat(y_trues, -1)
    return (y_preds == y_trues).float().mean()

In [13]:
def get_model(args):
    """Assemble a ``GraphClassification`` model from the settings in ``args``.

    Builds the GNN layer, the node encoder and the two edge-attribute encoders, wires
    them into the GNN class returned by ``make_GNN`` and wraps the result in
    ``GraphClassification``. Parameters are reset before returning. When
    ``args.parallel`` is set the model is wrapped in ``DataParallel`` over
    ``args.gpu_ids``.
    """
    hidden = args.hidden_size

    # Construction order is kept as-is: layer, node encoder, edge encoders, GNN class.
    gnn_layer = make_gnn_layer(args)
    node_encoder = LinearEncoder(args.input_size, hidden, pos_attr=args.pos_attr)
    edge_encoder = LinearEdgeEncoder(args.edge_attr_size, hidden, edge_attr=True)
    edge_encoder_v2 = LinearEdgeEncoder(args.edge_attr_v2_size, hidden, edge_attr=False)
    gnn_cls = make_GNN(args)

    gnn_kwargs = dict(
        num_layer=args.num_layer,
        gnn_layer=gnn_layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=node_encoder,
        init_edge_attr_emb=edge_encoder,
        init_edge_attr_v2_emb=edge_encoder_v2,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob,
    )
    backbone = gnn_cls(**gnn_kwargs)

    classifier = GraphClassification(
        embedding_model=backbone,
        pooling_method=args.pooling_method,
        output_size=args.output_size,
    )
    classifier.reset_parameters()

    if not args.parallel:
        return classifier
    return DataParallel(classifier, args.gpu_ids)

In [14]:
class Projection(nn.Module):
    """Three-layer ReLU MLP with a parallel linear shortcut.

    The output is ``fc(x) + linear(x)``, where ``fc`` is Linear/ReLU repeated three
    times and ``linear`` maps the input straight to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Build Linear -> ReLU three times; only the first Linear sees `input_dim`.
        stages = []
        in_features = input_dim
        for _ in range(3):
            stages.extend((nn.Linear(in_features, output_dim), nn.ReLU()))
            in_features = output_dim
        self.fc = nn.Sequential(*stages)
        # Shortcut branch, created after the MLP so parameter initialisation order is unchanged.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the sum of the MLP branch and the linear shortcut."""
        mlp_out = self.fc(x)
        shortcut = self.linear(x)
        return mlp_out + shortcut

In [None]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    VICReg-style loss comparing two sets of embeddings.

    Three terms are combined:
      * invariance: mean squared error between the two embedding sets;
      * variance: mean absolute difference between the hinge values ``relu(1 - std)``
        of the two sets, computed per feature;
      * covariance: summed absolute difference between the squared off-diagonal
        covariance entries of the two sets, divided by the feature dimension.

    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance term.
    lambda_cov (float): Coefficient for the covariance term.
    mu (float): Coefficient for the invariance term.

    Returns:
    torch.Tensor: ``mu * invariance + lambda_var * variance + lambda_cov * covariance``.
    """
    # --- invariance ---------------------------------------------------------------
    invariance_term = F.mse_loss(embeddings1, embeddings2)

    # Both remaining terms work on mean-centred embeddings.
    centered1 = embeddings1 - embeddings1.mean(dim=0)
    centered2 = embeddings2 - embeddings2.mean(dim=0)

    # --- variance -----------------------------------------------------------------
    # Per-feature standard deviation, with a small epsilon inside the square root.
    std1 = torch.sqrt(centered1.var(dim=0) + 1e-4)
    std2 = torch.sqrt(centered2.var(dim=0) + 1e-4)
    variance_term = torch.mean(torch.abs(F.relu(1 - std1) - F.relu(1 - std2)))

    # --- covariance ---------------------------------------------------------------
    num_samples, num_features = embeddings1.size()

    def _off_diagonal_cov(centered):
        # Sample covariance with the diagonal (the variances) zeroed out.
        cov = torch.matmul(centered.T, centered) / (num_samples - 1)
        cov.fill_diagonal_(0)
        return cov

    cov1 = _off_diagonal_cov(centered1)
    cov2 = _off_diagonal_cov(centered2)
    covariance_term = torch.abs(cov1.pow(2) - cov2.pow(2)).sum() / num_features

    return mu * invariance_term + lambda_var * variance_term + lambda_cov * covariance_term

In [17]:
class Encoder(torch.nn.Module):
    """Two-branch contrastive encoder.

    ``model_1`` and ``model_2`` each return a ``(node_embeddings, graph_embeddings)``
    pair; the source labels them the structural and positional branches. ``mlp1``
    projects node embeddings and ``mlp2`` graph embeddings, shared by both branches.
    ``aug1``/``aug2`` produce the two views used in ``forward``.
    """

    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super().__init__()
        # Assigned one by one, in this order, so sub-module registration order is unchanged.
        components = (
            ("model_1", model_1),
            ("model_2", model_2),
            ("mlp1", mlp1),
            ("mlp2", mlp2),
            ("aug1", aug1),
            ("aug2", aug2),
        )
        for attr_name, component in components:
            setattr(self, attr_name, component)

    def get_embedding(self, data):
        """Embed ``data`` without augmentation.

        Returns:
            tuple: ``(graph_embeddings, node_embeddings)``, each the concatenation of
            both branches along dim 1 and detached from the autograd graph.
        """
        node_struct, graph_struct = self.model_1(data)
        node_pos, graph_pos = self.model_2(data)

        node_struct = self.mlp1(node_struct)
        graph_struct = self.mlp2(graph_struct)

        node_pos = self.mlp1(node_pos)
        graph_pos = self.mlp2(graph_pos)

        graph_out = torch.cat((graph_struct, graph_pos), 1)
        node_out = torch.cat((node_struct, node_pos), 1)
        return graph_out.detach(), node_out.detach()

    def forward(self, data):
        """Encode two augmented views of ``data``.

        Returns:
            tuple: ``(h1, h2, g1, g2)`` — projected node embeddings of view 1 and 2,
            then projected graph embeddings of view 1 and 2, each concatenating the
            two branches along dim 1.
        """
        fields = (data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                  data.edge_attr_v2, data.batch, data.ptr)
        view1 = self.aug1(*fields)
        view2 = self.aug2(*fields)

        # Call order (model_1 on both views, then model_2 on both views) is preserved.
        # Structural features
        (z1, g1), (z2, g2) = self.model_1(view1), self.model_1(view2)
        # Positional features
        (z1_pos, g1_pos), (z2_pos, g2_pos) = self.model_2(view1), self.model_2(view2)

        h1, h2 = self.mlp1(z1), self.mlp1(z2)
        g1, g2 = self.mlp2(g1), self.mlp2(g2)

        h1_pos, h2_pos = self.mlp1(z1_pos), self.mlp1(z2_pos)
        g1_pos, g2_pos = self.mlp2(g1_pos), self.mlp2(g2_pos)

        def _join(structural, positional):
            return torch.cat((structural, positional), 1)

        return _join(h1, h1_pos), _join(h2, h2_pos), _join(g1, g1_pos), _join(g2, g2_pos)

In [18]:
def train(encoder_model, dataloader, optimizer, device, epoch=None, checkpoint_path=None):
    """Run one contrastive pre-training epoch and return the summed batch loss.

    Each batch is encoded into two views; the graph-level embeddings are trained with
    NT-Xent (same index in both views = positive pair) and the node-level embeddings
    add a small ``vicreg_loss`` regulariser.

    Args:
        encoder_model: module returning ``(h1, h2, g1, g2)`` for a batch.
        dataloader: iterable of batches exposing ``to(device)``, ``x`` and ``batch``.
        optimizer: optimizer over ``encoder_model``'s parameters.
        device: device each batch is moved to.
        epoch: optional epoch number used only for progress printing. It is an explicit
            argument so the function does not depend on a global ``epoch`` existing in
            the notebook namespace. Nothing is printed when omitted.
        checkpoint_path: where to save the state dict after the epoch. Defaults to
            ``./pkl/best_model_<dataset_name><tag>.pkl``.

    Returns:
        float: sum of the per-batch losses over the whole epoch.

    Notes:
        The patience / "best loss" bookkeeping used to live inside the batch loop with
        counters that were reset on every call. It compared a running sum that was
        expected to only grow, so it would cut an epoch short after ``args.patience``
        batches and always save after the first batch. Every batch is now processed and
        the checkpoint is written once per call, after the last batch; cross-epoch early
        stopping belongs to the caller, which sees the returned loss.
    """
    loss_func = NTXentLoss(temperature=0.10)

    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            # Featureless graphs get a constant scalar feature per node.
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)

        embeddings = torch.cat((g1, g2))

        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=g1.device)
        labels = torch.cat((indices, indices))

        reg_loss = vicreg_loss(h1, h2, lambda_var=24, lambda_cov=24, mu=1)
        loss = loss_func(embeddings, labels) + 0.005*reg_loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    if epoch is not None and epoch % 20 == 0:
        print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

    if checkpoint_path is None:
        checkpoint_path = './pkl/best_model_'+ args.dataset_name + tag + '.pkl'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)
    torch.save(encoder_model.state_dict(), checkpoint_path)

    return epoch_loss

In [19]:
def test(encoder_model, dataset, dataloader, device):
    """Fit a linear probe on frozen encoder embeddings and report its accuracy.

    Every graph in ``dataset`` is embedded with ``encoder_model.get_embedding``; a
    ``LogReg`` head is then trained for ``args.num_epochs`` epochs on those embeddings
    and scored on the same graphs after every epoch.

    Args:
        encoder_model: trained encoder exposing ``get_embedding``.
        dataset: graphs to embed (iterated one graph at a time).
        dataloader: unused; kept so existing calls keep working.
        device: device used for embedding and for the probe.

    Returns:
        tuple: ``(last training loss, best accuracy seen)``.

    NOTE(review): the probe has ``args.n_classes`` outputs and is scored on the graphs it
    was trained on. With a single output, argmax is always 0, so the accuracy only says
    whether every stored label is 0. The unique-label assignment further down runs after
    ``data_list`` is built and therefore never reaches the probe — confirm the intended
    evaluation protocol before relying on the reported number.
    """
    encoder_model.eval()
    data_list = []
    # The embeddings are fixed inputs for the probe, so no autograd graph is needed.
    with torch.no_grad():
        for data in dataset:
            data = data.to(device)
            if data.x is None:
                # Graphs taken one at a time from the dataset presumably carry no `batch`
                # vector, so size the constant feature from the node count instead.
                data.x = torch.ones((data.num_nodes, 1), dtype=torch.float32, device=device)
            graph_embeds, node_embeds = encoder_model.get_embedding(data)
            data = Data(x=node_embeds, edge_index=data.edge_index, y=data.y, pos=data.pos, edge_attr=data.edge_attr,
                        edge_attr_v2=data.edge_attr_v2, batch=data.batch, ptr=None)
            data.graph_embeds = graph_embeds
            data_list.append(data)

    # NOTE(review): these two assignments modify the caller's dataset only; `data_list`
    # above already holds its own tensors. Kept for backward compatibility.
    dataset.data.x = dataset.data.x.long()
    dataset.data.y = torch.arange(len(dataset.data.y)).long()  # each graph is a unique class

    model = LogReg(hid_dim=data_list[0].graph_embeds.shape[1], out_dim=args.n_classes)

    # 2. create loader (training and evaluation deliberately share the same graphs)
    train_loader = loader(data_list, args.batch_size, shuffle=True, num_workers=args.num_workers)
    test_loader = loader(data_list, args.batch_size, shuffle=False, num_workers=args.num_workers)

    # additional parameter for SR dataset and training
    args.input_size = 2
    args.output_size = len(data_list)

    # output argument to log file
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    # get model
    model.to(device)
    pytorch_total_params = train_utils.count_parameters(model)
    log.info(f'The total parameters of model :{[pytorch_total_params]}')

    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
    best_test_acc = 0
    # Defined up front so the summary below also works when `args.num_epochs` is 0.
    train_loss = float('nan')
    start_outer = time.time()
    for epoch in range(args.num_epochs):
        start = time.time()
        train_loss = sr25_train(train_loader, model, optimizer, device=device, parallel=args.parallel)
        lr = optimizer.param_groups[0]['lr']
        test_acc = sr25_test(test_loader, model, device=device, parallel=args.parallel)
        if test_acc >= best_test_acc:
            best_test_acc = test_acc
        time_per_epoch = time.time() - start

        log.info(f'Epoch: {epoch + 1:03d}, LR: {lr:7f}, Train Loss: {train_loss:.4f}, Test Acc: {test_acc:.4f}, '
                 f'Best Test Acc: {best_test_acc:.4f}, Seconds: {time_per_epoch:.4f}')

    total_seconds = time.time() - start_outer
    epochs_run = max(1, args.num_epochs)
    log.info(f'Loss: {train_loss:.4f}, Best test: {best_test_acc:.4f}, Seconds/epoch: {total_seconds / epochs_run:.4f}')

    return train_loss, best_test_acc

In [20]:
# Root directory of the dataset named by --dataset_name.
path = f"./data/{args.dataset_name}"

In [21]:
# Drop any previously processed copy of the dataset, if present.
processed_dir = path + '/processed'
if os.path.exists(processed_dir):
    shutil.rmtree(processed_dir)

In [22]:
# Edge attributes are extracted as a pre-transform passed to the dataset constructor.
sr_pre_transform = T.Compose([edge_feature_transform])
dataset = SRDataset(path, pre_transform=sr_pre_transform)

Processing...
Done!


In [23]:
# Size of the one-hot degree encoding. Note that this is the largest node count in the
# dataset, not the largest degree — it is used as the bound passed to OneHotDegree.
max_degree = max(data.num_nodes for data in dataset)

# Apply the OneHotDegree transform
dataset.transform = T.OneHotDegree(max_degree)

In [24]:
# Number of target classes reported by the dataset; read later as the output width
# of the linear probe.
args.n_classes = dataset.num_classes

In [25]:
# Derive the model dimensions from the dataset and the parsed options.
vars(args).update(
    input_size=dataset.num_features,
    pos_size=args.pos_enc_dim,
    output_size=args.hidden_size,
    edge_attr_size=dataset.edge_attr.shape[1],
    edge_attr_v2_size=dataset.edge_attr_v2.shape[1],
)

In [26]:
# View 1 is the unmodified graph.
aug1 = A.Identity()

# View 2 draws one augmentation at random from a feature-level or a structure-level pool.
if args.feature_augmentation:
    # Feature Augmentation
    aug2_candidates = [
        A.FeatureDropout(pf=0.1),
        A.FeatureMasking(pf=0.1),
        A.EdgeAttrMasking(pf=0.1),
    ]
else:
    # Structure Augmentation
    aug2_candidates = [
        A.RWSampling(num_seeds=1000, walk_length=10),
        A.NodeDropping(pn=0.1),
        A.EdgeRemoving(pe=0.1),
    ]
aug2 = A.RandomChoice(aug2_candidates, 1)

# First branch: built from the current `args`.
model_1 = get_model(args)
model_1.to(device)

# Second branch: same architecture, but with positional attributes switched on and the
# input width set to the positional size.
# NOTE(review): `args` is modified in place here, so re-running this cell would build
# model_1 with the positional settings as well.
args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

# Projection heads: mlp1 is passed as the node head, mlp2 as the graph head.
hidden = args.hidden_size
mlp1 = Projection(input_dim=hidden, output_dim=hidden)
mlp2 = Projection(input_dim=hidden, output_dim=hidden)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2)
encoder_model = encoder_model.to(device)

In [27]:
# Number of contrastive pre-training epochs run by this cell.
NUM_PRETRAIN_EPOCHS = 10

optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)

with tqdm(total=NUM_PRETRAIN_EPOCHS, desc='(T)') as pbar:
    for epoch in range(1, NUM_PRETRAIN_EPOCHS + 1):
        loss = train(encoder_model, dataloader, optimizer, device)
        # Show the latest epoch loss next to the progress bar.
        pbar.set_postfix(loss=loss)
        pbar.update()

(T): 100%|████████████████████████████████████████| 10/10 [00:01<00:00,  7.10it/s, loss=3.84]


In [28]:
# Only the second element of the result (the best accuracy) is reported.
test_accs = test(encoder_model, dataset, dataloader, device)[1]

log.info("-------------------Print final result-------------------------")
log.info(f"Test result: Mean: {test_accs}")

[08.06.24 16:40:58] Args: {
    "JK": "sum",
    "K": 3,
    "aggr": "add",
    "batch_size": 32,
    "combine": "geometric",
    "dataset_name": "sr25",
    "drop_prob": 0.6,
    "edge_attr_size": 4,
    "edge_attr_v2_size": 3,
    "factor": 0.5,
    "feature_augmentation": false,
    "gpu_ids": [],
    "hidden_size": 200,
    "input_size": 2,
    "l2_wd": 5e-06,
    "load_path": null,
    "lr": 0.001,
    "model_name": "KHopGNNConv",
    "n_classes": 1,
    "name": "KHopGNNConv_3_False",
    "norm_type": "Batch",
    "num_epochs": 500,
    "num_layer": 3,
    "num_workers": 0,
    "output_size": 15,
    "parallel": false,
    "patience": 20,
    "pooling_method": "sum",
    "pos_attr": true,
    "pos_enc_dim": 6,
    "pos_size": 6,
    "reprocess": false,
    "residual": true,
    "save_dir": "./save/sr25/KHopGNNConv_3_False-02",
    "search": false,
    "seed": 2,
    "split": 10,
    "train_eps": false,
    "virtual_node": false
}
[08.06.24 16:40:58] The total parameters of model :

[08.06.24 16:40:59] Epoch: 246, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0018
[08.06.24 16:40:59] Epoch: 247, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0013
[08.06.24 16:40:59] Epoch: 248, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0009
[08.06.24 16:40:59] Epoch: 249, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0008
[08.06.24 16:40:59] Epoch: 250, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0010
[08.06.24 16:40:59] Epoch: 251, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0012
[08.06.24 16:40:59] Epoch: 252, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0010
[08.06.24 16:40:59] Epoch: 253, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0010
[08.06.24 16:40:

[08.06.24 16:40:59] Epoch: 473, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0013
[08.06.24 16:40:59] Epoch: 474, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0011
[08.06.24 16:40:59] Epoch: 475, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0009
[08.06.24 16:40:59] Epoch: 476, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0008
[08.06.24 16:40:59] Epoch: 477, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0008
[08.06.24 16:40:59] Epoch: 478, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0009
[08.06.24 16:40:59] Epoch: 479, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0010
[08.06.24 16:40:59] Epoch: 480, LR: 0.001000, Train Loss: 0.0000, Test Acc: 1.0000, Best Test Acc: 1.0000, Seconds: 0.0011
[08.06.24 16:40: