In [1]:
import logging
import numpy as np
from tqdm import tqdm
import torch

from graphmae.utils import (
    build_args,
    create_optimizer,
    set_random_seed,
    TBLogger,
    get_current_lr,
    load_best_configs,
)
from graphmae.datasets.data_util import load_dataset
from graphmae.evaluation import node_classification_evaluation
from graphmae.models import build_model
from ogb.nodeproppred import DglNodePropPredDataset

2022-09-09 15:18:44,334 - INFO - Enabling RDKit 2022.03.5 jupyter extensions


In [2]:
import argparse

# Argument parser holding the default hyper-parameters used by this notebook.
parser = argparse.ArgumentParser(description="GAT")

# (flag, add_argument keyword options) pairs, registered in the order listed.
_ARGUMENT_SPECS = [
    # run setup
    ("--seeds", dict(type=int, nargs="+", default=[0])),
    ("--dataset", dict(type=str, default="cora")),
    ("--device", dict(type=int, default=-1)),
    ("--max_epoch", dict(type=int, default=200, help="number of training epochs")),
    ("--warmup_steps", dict(type=int, default=-1)),
    # model architecture and regularisation
    ("--num_heads", dict(type=int, default=4, help="number of hidden attention heads")),
    ("--num_out_heads", dict(type=int, default=1, help="number of output attention heads")),
    ("--num_layers", dict(type=int, default=2, help="number of hidden layers")),
    ("--num_hidden", dict(type=int, default=256, help="number of hidden units")),
    ("--residual", dict(action="store_true", default=False, help="use residual connection")),
    ("--in_drop", dict(type=float, default=0.2, help="input feature dropout")),
    ("--attn_drop", dict(type=float, default=0.1, help="attention dropout")),
    ("--norm", dict(type=str, default=None)),
    ("--lr", dict(type=float, default=0.005, help="learning rate")),
    ("--weight_decay", dict(type=float, default=5e-4, help="weight decay")),
    ("--negative_slope", dict(type=float, default=0.2, help="the negative slope of leaky relu for GAT")),
    ("--activation", dict(type=str, default="prelu")),
    ("--mask_rate", dict(type=float, default=0.5)),
    ("--drop_edge_rate", dict(type=float, default=0.0)),
    ("--replace_rate", dict(type=float, default=0.0)),
    # encoder / decoder / objective
    ("--encoder", dict(type=str, default="gat")),
    ("--decoder", dict(type=str, default="gat")),
    ("--loss_fn", dict(type=str, default="byol")),
    ("--alpha_l", dict(type=float, default=2, help="`pow`inddex for `sce` loss")),
    ("--optimizer", dict(type=str, default="adam")),
    # evaluation stage
    ("--max_epoch_f", dict(type=int, default=30)),
    ("--lr_f", dict(type=float, default=0.001, help="learning rate for evaluation")),
    ("--weight_decay_f", dict(type=float, default=0.0, help="weight decay for evaluation")),
    ("--linear_prob", dict(action="store_true", default=False)),
    # boolean switches
    ("--load_model", dict(action="store_true")),
    ("--save_model", dict(action="store_true")),
    ("--use_cfg", dict(action="store_true")),
    ("--logging", dict(action="store_true")),
    ("--scheduler", dict(action="store_true", default=False)),
    ("--concat_hidden", dict(action="store_true", default=False)),
    # for graph classification
    ("--pooling", dict(type=str, default="mean")),
    ("--deg4feat", dict(action="store_true", default=False, help="use node degree as input feature")),
]
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Kept as the cell's final bare expression so the created action is still echoed.
parser.add_argument("--batch_size", type=int, default=32)

_StoreAction(option_strings=['--batch_size'], dest='batch_size', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [3]:
# Parse an explicitly empty argument list so every option takes the default
# declared on the parser above.
args = parser.parse_args([])

In [4]:
# Hyper-parameter overrides layered on top of the parser defaults.
_overrides = dict(
    # optimisation
    lr=0.001,
    lr_f=0.01,
    weight_decay=2e-4,
    weight_decay_f=1e-4,
    max_epoch=1500,
    max_epoch_f=300,
    optimizer="adam",
    scheduler=True,
    # architecture
    num_hidden=512,
    num_heads=4,
    num_layers=2,
    encoder="gat",
    decoder="gat",
    activation="prelu",
    in_drop=0.2,
    attn_drop=0.1,
    # masking / objective
    mask_rate=0.5,
    drop_edge_rate=0.0,
    replace_rate=0.05,
    loss_fn="sce",
    alpha_l=3,
    # evaluation
    linear_prob=True,
)
for _name, _value in _overrides.items():
    setattr(args, _name, _value)

In [5]:
# Echo the dataset name held in the configuration.
args.dataset

'cora'

In [6]:
# Unpack the configuration into plain module-level names for the cells below.

# A negative index selects the CPU; any other value is kept as the integer index.
device = "cpu" if args.device < 0 else args.device

# Run setup and model architecture.
seeds, dataset_name = args.seeds, args.dataset
max_epoch, max_epoch_f = args.max_epoch, args.max_epoch_f
num_hidden, num_layers = args.num_hidden, args.num_layers
encoder_type, decoder_type = args.encoder, args.decoder
replace_rate = args.replace_rate

# Pre-training objective and optimiser.
optim_type, loss_fn = args.optimizer, args.loss_fn
lr, weight_decay = args.lr, args.weight_decay

# Settings passed on to the evaluation step.
lr_f, weight_decay_f = args.lr_f, args.weight_decay_f
linear_prob = args.linear_prob

# Checkpoint, logging and LR-schedule switches.
load_model, save_model = args.load_model, args.save_model
logs, use_scheduler = args.logging, args.scheduler

In [7]:
# Load the graph for the selected dataset together with its feature and class counts.
graph, (num_features, num_classes) = load_dataset(dataset_name)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [8]:
# Echo the loaded graph's summary (node/edge counts and data schemes).
graph

Graph(num_nodes=2708, num_edges=13264,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [9]:
# Store the dataset's feature count on args, the object later handed to build_model.
args.num_features = num_features

In [10]:

# Two independent, initially empty result accumulators
# (final and early-stopping accuracy, judging by the names).
acc_list, estp_acc_list = [], []

In [11]:
# Announce each run and call set_random_seed with its seed.
# NOTE(review): the loop body does nothing else, so when several seeds are
# configured only the last set_random_seed call is the most recent one by the
# time the cells below run — confirm that is intended.
for i, seed in enumerate(seeds):
    print(f"####### Run {i} for seed {seed}")
    set_random_seed(seed)

####### Run 0 for seed 0


In [12]:
# Echo the list of configured seeds.
seeds

[0]

In [13]:
# With logging enabled, build a TBLogger whose name encodes the run's key
# hyper-parameters; otherwise leave the logger unset.
logger = (
    TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
    if logs
    else None
)

In [14]:
# Echo the loss function name currently stored in the configuration.
args.loss_fn

'sce'

In [15]:
# Same value (512) that the override cell above already assigned; this repeat
# does not change the configuration.
args.num_hidden = 512

In [16]:
# Build the model from the full configuration, including the num_features
# field attached to args earlier.
model = build_model(args)

In [17]:
# Echo the model's module tree.
model

PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=1433, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=512, out_features=1433, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head):

In [18]:
# Choose the device from what this machine actually exposes instead of assuming
# a second GPU exists: GPU 1 when at least two are visible, GPU 0 when exactly
# one is, and the CPU when there is none.
_visible_gpus = torch.cuda.device_count()
device = min(1, _visible_gpus - 1) if _visible_gpus > 0 else "cpu"

In [19]:
# Move the model to the selected device; the returned module is echoed as the cell output.
model.to(device)

PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=1433, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=512, out_features=1433, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head):

In [20]:
# Create the pre-training optimiser for the model from the configured type,
# learning rate and weight decay.
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

In [21]:
# Echo the optimiser's parameter-group settings.
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0.0002
)

In [22]:
def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None, eval_every=200):
    """Run the pre-training loop for ``max_epoch`` epochs and return the model.

    Each epoch performs one forward pass over the whole graph, one backward
    pass and one optimiser step. Every ``eval_every`` epochs the current model
    is additionally passed to ``node_classification_evaluation`` (muted); the
    result of that call is not used for model selection.

    Args:
        model: Module called as ``model(graph, x)``, expected to return a
            ``(loss, loss_dict)`` pair. It is updated in place.
        graph: Graph object; moved to ``device`` with ``.to``.
        feat: Node feature tensor; moved to ``device`` with ``.to``.
        optimizer: Optimiser over the model's parameters.
        max_epoch: Number of training epochs.
        device: Target device for ``graph`` and ``feat``.
        scheduler: Learning-rate scheduler stepped once per epoch, or ``None``.
        num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob:
            Forwarded unchanged to ``node_classification_evaluation``.
        logger: Optional object with a ``note(dict, step=...)`` method that
            receives the loss dictionary plus the current learning rate.
        eval_every: Epoch interval between intermediate evaluations. Defaults
            to 200 (the previously hard-coded value); ``None`` or ``0``
            disables intermediate evaluation.

    Returns:
        The same ``model`` object that was passed in.
    """
    logging.info("start training..")
    graph = graph.to(device)
    x = feat.to(device)

    epoch_iter = tqdm(range(max_epoch))

    for epoch in epoch_iter:
        # Evaluation may be run in between epochs, so training mode is
        # re-enabled at the start of every epoch.
        model.train()

        loss, loss_dict = model(graph, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
        if logger is not None:
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        # A falsy interval switches the periodic evaluation off (and avoids a
        # modulo-by-zero).
        if eval_every and (epoch + 1) % eval_every == 0:
            node_classification_evaluation(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)

    return model


In [23]:
if not use_scheduler:
    scheduler = None
else:
    logging.info("Use schedular")

    def _cosine_lr_factor(epoch):
        """Multiplier on the base learning rate: cosine decay from 1 at epoch 0 to 0 at max_epoch."""
        return (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5

    # Alternative kept for reference (linear warm-up followed by cosine decay):
    # scheduler = lambda epoch: epoch / warmup_steps if epoch < warmup_steps \
            # else ( 1 + np.cos((epoch - warmup_steps) * np.pi / (max_epoch - warmup_steps))) * 0.5
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_cosine_lr_factor)

# Node features used as the model input.
x = graph.ndata["feat"]

# Train from scratch unless a saved model is to be loaded, then move the
# trained model back to the CPU.
if not load_model:
    model = pretrain(
        model, graph, x, optimizer, max_epoch, device, scheduler,
        num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger,
    ).cpu()

2022-09-09 15:18:49,497 - INFO - Use schedular
2022-09-09 15:18:49,499 - INFO - start training..
# Epoch 207: train_loss: 0.4237:  14%|█████████████████▊                                                                                                                | 206/1500 [00:06<01:06, 19.57it/s]

# IGNORE: --- TestAcc: 0.8290, early-stopping-TestAcc: 0.8290, Best ValAcc: 0.8020 in epoch 237 --- 


# Epoch 408: train_loss: 0.4159:  27%|███████████████████████████████████                                                                                               | 405/1500 [00:11<00:56, 19.47it/s]

# IGNORE: --- TestAcc: 0.8390, early-stopping-TestAcc: 0.8290, Best ValAcc: 0.7960 in epoch 84 --- 


# Epoch 607: train_loss: 0.4162:  40%|████████████████████████████████████████████████████▍                                                                             | 605/1500 [00:16<00:47, 18.87it/s]

# IGNORE: --- TestAcc: 0.8360, early-stopping-TestAcc: 0.8240, Best ValAcc: 0.7960 in epoch 35 --- 


# Epoch 808: train_loss: 0.4025:  54%|██████████████████████████████████████████████████████████████████████                                                            | 808/1500 [00:21<00:31, 21.87it/s]

# IGNORE: --- TestAcc: 0.8360, early-stopping-TestAcc: 0.8290, Best ValAcc: 0.7920 in epoch 13 --- 


# Epoch 1007: train_loss: 0.3976:  67%|█████████████████████████████████████████████████████████████████████████████████████▋                                          | 1004/1500 [00:27<00:33, 14.64it/s]

# IGNORE: --- TestAcc: 0.8360, early-stopping-TestAcc: 0.8360, Best ValAcc: 0.7940 in epoch 299 --- 


# Epoch 1208: train_loss: 0.4075:  81%|███████████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 1209/1500 [00:32<00:14, 19.61it/s]

# IGNORE: --- TestAcc: 0.8400, early-stopping-TestAcc: 0.8270, Best ValAcc: 0.8020 in epoch 21 --- 


# Epoch 1407: train_loss: 0.3996:  94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉        | 1406/1500 [00:37<00:04, 19.12it/s]

# IGNORE: --- TestAcc: 0.8410, early-stopping-TestAcc: 0.8300, Best ValAcc: 0.8020 in epoch 19 --- 


# Epoch 1499: train_loss: 0.3969: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:39<00:00, 37.95it/s]


In [25]:
# Echo the load_model switch (False means the model above was trained in this session).
load_model

False