In [36]:
import numpy as np
import tensorflow as tf
from ogb.nodeproppred import Evaluator, NodePropPredDataset
from tensorflow.keras.layers import BatchNormalization, Dropout, Input, Dense
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from spektral.datasets.ogb import OGB
from spektral.layers import GCNConv
from spektral.transforms import AdjToSpTensor, GCNFilter

from warnings import filterwarnings
filterwarnings(action='ignore', category=DeprecationWarning, message='`np.bool` is a deprecated alias')

# Load data
dataset_name = "ogbn-arxiv"
ogb_dataset = NodePropPredDataset(dataset_name)
dataset = OGB(ogb_dataset, transforms=[GCNFilter(), AdjToSpTensor()])
graph = dataset[0]
x, adj, y = graph.x, graph.a, graph.y

# Parameters
channels = 256  # Number of channels for GCN layers
dropout = 0.35  # Dropout rate for the features
learning_rate = 0.01  # Learning rate
epochs = 1000  # Number of training epochs

N = dataset.n_nodes  # Number of nodes in the graph
F = dataset.n_node_features  # Original size of node features
n_out = ogb_dataset.num_classes  # OGB labels are sparse indices

# Data splits
idx = ogb_dataset.get_idx_split()
idx_tr, idx_va, idx_te = idx["train"], idx["valid"], idx["test"]
mask_tr = np.zeros(N, dtype=bool)
mask_va = np.zeros(N, dtype=bool)
mask_te = np.zeros(N, dtype=bool)
mask_tr[idx_tr] = True
mask_va[idx_va] = True
mask_te[idx_te] = True
masks = [mask_tr, mask_va, mask_te]


In [49]:
# Model definition
x_in = Input(shape=(F,))
a_in = Input((N,), sparse=True)
x_1 = GCNConv(channels, activation="relu")([x_in, a_in])
x_1 = BatchNormalization()(x_1)
x_1 = Dropout(dropout)(x_1)
x_2 = GCNConv(channels, activation="relu")([x_1, a_in])
x_2 = BatchNormalization()(x_2)
x_2 = Dropout(dropout)(x_2)
x_3 = GCNConv(n_out, activation="softmax")([x_2, a_in])
# output = Dense(10)(x_3)
# Build model
model = Model(inputs=[x_in, a_in], outputs=x_3)
optimizer = Adam(learning_rate=learning_rate)
loss_fn = SparseCategoricalCrossentropy()
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_9 (InputLayer)           [(None, 128)]        0           []                               
                                                                                                  
 input_10 (InputLayer)          [(None, 169343)]     0           []                               
                                                                                                  
 gcn_conv_12 (GCNConv)          (None, 256)          33024       ['input_9[0][0]',                
                                                                  'input_10[0][0]']               
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 256)         1024        ['gcn_conv_12[0][0]']      

In [65]:
# Training function
@tf.function
def train(inputs, target, mask):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target[mask], predictions[mask]) + sum(model.losses)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


# Evaluation with OGB
evaluator = Evaluator(dataset_name)

def evaluate(x, a, y, model, masks, evaluator):
    p = model([x, a], training=False)
    p = p.numpy().argmax(-1)[:, None]
    tr_mask, va_mask, te_mask = masks
    tr_auc = evaluator.eval({"y_true": y[tr_mask], "y_pred": p[tr_mask]})["acc"]
    va_auc = evaluator.eval({"y_true": y[va_mask], "y_pred": p[va_mask]})["acc"]
    te_auc = evaluator.eval({"y_true": y[te_mask], "y_pred": p[te_mask]})["acc"]
    return tr_auc, va_auc, te_auc

In [66]:
# Train model
for i in range(1, 1 + epochs):
    tr_loss = train([x, adj], y, mask_tr)
    tr_acc, va_acc, te_acc = evaluate(x, adj, y, model, masks, evaluator)
    print(
        "Ep. {} - Loss: {:.3f} - Acc: {:.3f} - Val acc: {:.3f} - Test acc: "
        "{:.3f}".format(i, tr_loss, tr_acc, va_acc, te_acc)
    )
# model.compile(optimizer=optimizer, loss='categorical_crossentropy')

# from spektral.data import BatchLoader
# loader = BatchLoader(dataset, batch_size=32)

# model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=1000)

# Evaluate model
print("Evaluating model.")
tr_acc, va_acc, te_acc = evaluate(x, adj, y, model, masks, evaluator)
print("Done! - Test acc: {:.3f}".format(te_acc))

Ep. 1 - Loss: 1.097 - Acc: 0.485 - Val acc: 0.497 - Test acc: 0.527
Ep. 2 - Loss: 1.092 - Acc: 0.496 - Val acc: 0.501 - Test acc: 0.531
Ep. 3 - Loss: 1.098 - Acc: 0.493 - Val acc: 0.491 - Test acc: 0.524
Ep. 4 - Loss: 1.093 - Acc: 0.495 - Val acc: 0.488 - Test acc: 0.522
Ep. 5 - Loss: 1.092 - Acc: 0.494 - Val acc: 0.492 - Test acc: 0.523
Ep. 6 - Loss: 1.093 - Acc: 0.498 - Val acc: 0.496 - Test acc: 0.525
Ep. 7 - Loss: 1.093 - Acc: 0.502 - Val acc: 0.487 - Test acc: 0.520
Ep. 8 - Loss: 1.088 - Acc: 0.516 - Val acc: 0.512 - Test acc: 0.541
Ep. 9 - Loss: 1.092 - Acc: 0.517 - Val acc: 0.516 - Test acc: 0.545
Ep. 10 - Loss: 1.089 - Acc: 0.517 - Val acc: 0.516 - Test acc: 0.546
Ep. 11 - Loss: 1.092 - Acc: 0.509 - Val acc: 0.500 - Test acc: 0.534
Ep. 12 - Loss: 1.083 - Acc: 0.509 - Val acc: 0.491 - Test acc: 0.525
Ep. 13 - Loss: 1.094 - Acc: 0.534 - Val acc: 0.523 - Test acc: 0.549
Ep. 14 - Loss: 1.085 - Acc: 0.534 - Val acc: 0.520 - Test acc: 0.544
Ep. 15 - Loss: 1.090 - Acc: 0.530 - Val acc

KeyboardInterrupt: 

In [67]:
p = model([x, adj], training=False)
# p = p.numpy().argmax(-1)[:, None]
# tr_mask, va_mask, te_mask = masks
# te_auc = evaluator.eval({"y_true": y[te_mask], "y_pred": p[te_mask]})["acc"]
p

<tf.Tensor: shape=(169343, 40), dtype=float32, numpy=
array([[0.00462996, 0.00427004, 0.01627544, ..., 0.02665814, 0.00355087,
        0.00650889],
       [0.00410189, 0.00601544, 0.02579289, ..., 0.0175071 , 0.00256565,
        0.01400298],
       [0.00686559, 0.0024032 , 0.00238143, ..., 0.00804066, 0.00104982,
        0.00649454],
       ...,
       [0.00218012, 0.00310576, 0.01254589, ..., 0.02632165, 0.00287421,
        0.00130697],
       [0.00298416, 0.00756151, 0.00810144, ..., 0.01523414, 0.00422839,
        0.00280293],
       [0.00205035, 0.02974777, 0.01468056, ..., 0.02579349, 0.00984326,
        0.00285765]], dtype=float32)>

In [71]:
graph.x[0]


array([-0.057943, -0.05253 , -0.072603, -0.026555,  0.130435, -0.241386,
       -0.449242, -0.018443, -0.087218,  0.11232 , -0.092125, -0.28956 ,
       -0.081012,  0.074489, -0.156198, -0.097413,  0.11937 ,  0.645755,
        0.077375, -0.09386 , -0.400367,  0.311369, -0.541764,  0.080455,
       -0.00695 ,  0.542316, -0.01223 , -0.180773,  0.016466,  0.050778,
       -0.208276, -0.08701 ,  0.012363,  0.281671,  0.100448, -0.164255,
        0.026892,  0.078199,  0.079534, -0.013387,  0.291491,  0.041601,
       -0.141369, -0.134461,  0.016178,  0.280961, -0.091925, -0.240312,
        0.461786,  0.187323,  0.15335 ,  0.033118,  0.01076 ,  0.012446,
       -0.158857,  0.09798 ,  0.03052 ,  0.016234, -0.095681,  0.05214 ,
        0.321836, -0.105675,  0.222873, -0.120619, -0.172259,  0.395426,
        0.088274, -0.221882,  0.231014, -0.209604, -0.112524, -0.064443,
        0.069746, -0.157444,  0.02228 , -0.418984,  0.134391,  0.26046 ,
        0.041681, -0.093468, -0.051622, -0.025531, 