In [45]:
import numpy as np;
import tensorflow as tf
import spektral


In [46]:
cora_dataset = spektral.datasets.citation.Citation(name='cora')
test_mask = cora_dataset.mask_te
train_mask = cora_dataset.mask_tr
val_mask = cora_dataset.mask_va
graph = cora_dataset.graphs[0]
features = graph.x
adj = graph.a
labels = graph.y

In [47]:
adj =  adj + np.eye(adj.shape[0])
adj = tf.cast(adj, dtype=tf.float32)


print(features.shape)
print(adj.shape)
print(labels.shape)

print(np.sum(train_mask))
print(np.sum(val_mask))
print(np.sum(test_mask))

(2708, 1433)
(2708, 2708)
(2708, 7)
140
500
1000


In [48]:
def masked_softmax_cross_entropy(logits, labels, mask):
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    loss *= mask
    return tf.reduce_mean(loss)

def masked_accuracy(logits, labels, mask):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy_all = tf.cast(correct_prediction, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    accuracy_all *= mask
    return tf.reduce_mean(accuracy_all)

In [49]:
def gnn(fts, adj, transform, activation):
    seq_fts = transform(fts)
    # seq_fts = tf.cast(seq_fts, dtype=tf.float32)

    ret_fts = tf.matmul(adj, seq_fts)
    return activation(ret_fts)

In [50]:
def train_cora(fts, adj, gnn_fn, units, epochs, lr):
    lyr_1 = tf.keras.layers.Dense(units)
    lyr_2 = tf.keras.layers.Dense(7)

    def cora_gnn(fts, adj):
        hidden = gnn_fn(fts, adj, lyr_1, tf.nn.relu)
        logits = gnn_fn(hidden, adj, lyr_2, tf.identity)
        return logits
    
    optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=lr)

    best_accuracy = 0.0
    for ep in range(epochs+1):
        with tf.GradientTape() as t:
            logits = cora_gnn(fts, adj)
            loss = masked_softmax_cross_entropy(logits,labels,train_mask)
        
        variables = t.watched_variables()
        grads = t.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads,variables))

        logits = cora_gnn(fts,adj)
        val_accuracy = masked_accuracy(logits, labels, val_mask)
        test_accuracy = masked_accuracy(logits, labels, test_mask)

        if (val_accuracy > best_accuracy):
            best_accuracy = val_accuracy
            #should do model saving here
            print('Epoch', ep, '| Training loss:', loss.numpy(), '| Val accuracy:', val_accuracy.numpy (), '| Test accuracy:', test_accuracy.numpy())

In [51]:
# Raw adj matrix -> sum-pooling -> not very scalable
train_cora(features, adj, gnn ,32, 200, 0.01)

Epoch 0 | Training loss: 4.7426767 | Val accuracy: 0.102 | Test accuracy: 0.11799999
Epoch 1 | Training loss: 5.729286 | Val accuracy: 0.44599998 | Test accuracy: 0.464
Epoch 2 | Training loss: 1.5892051 | Val accuracy: 0.454 | Test accuracy: 0.48400006
Epoch 3 | Training loss: 1.5592555 | Val accuracy: 0.57 | Test accuracy: 0.61800003
Epoch 4 | Training loss: 1.0813582 | Val accuracy: 0.642 | Test accuracy: 0.658
Epoch 5 | Training loss: 0.80808634 | Val accuracy: 0.678 | Test accuracy: 0.698
Epoch 6 | Training loss: 0.61906916 | Val accuracy: 0.69 | Test accuracy: 0.719
Epoch 7 | Training loss: 0.42336544 | Val accuracy: 0.696 | Test accuracy: 0.709
Epoch 10 | Training loss: 0.21703258 | Val accuracy: 0.69600004 | Test accuracy: 0.717
Epoch 12 | Training loss: 0.15955687 | Val accuracy: 0.6999999 | Test accuracy: 0.713
Epoch 31 | Training loss: 0.02082393 | Val accuracy: 0.7 | Test accuracy: 0.721
Epoch 32 | Training loss: 0.019563328 | Val accuracy: 0.70199996 | Test accuracy: 0.721

In [52]:
# Point wise MLP: passing identity matrix instead of adjacency matrix to perceive if there is benefit to adjacency matrix 
# -> there is a benefit
train_cora(features, tf.eye(adj.shape[0]), gnn, 32, 200, 0.01)

Epoch 0 | Training loss: 1.9526836 | Val accuracy: 0.262 | Test accuracy: 0.25399998
Epoch 1 | Training loss: 1.6843466 | Val accuracy: 0.34999996 | Test accuracy: 0.37399998
Epoch 2 | Training loss: 1.4512416 | Val accuracy: 0.41 | Test accuracy: 0.41400003
Epoch 3 | Training loss: 1.206158 | Val accuracy: 0.444 | Test accuracy: 0.44799998
Epoch 4 | Training loss: 0.9612844 | Val accuracy: 0.47 | Test accuracy: 0.483
Epoch 5 | Training loss: 0.7374072 | Val accuracy: 0.50200003 | Test accuracy: 0.496
Epoch 6 | Training loss: 0.54875 | Val accuracy: 0.51199996 | Test accuracy: 0.5
Epoch 7 | Training loss: 0.40108284 | Val accuracy: 0.516 | Test accuracy: 0.508
Epoch 8 | Training loss: 0.29209167 | Val accuracy: 0.524 | Test accuracy: 0.51600003
Epoch 9 | Training loss: 0.21502826 | Val accuracy: 0.532 | Test accuracy: 0.52000004
Epoch 10 | Training loss: 0.1612534 | Val accuracy: 0.54599994 | Test accuracy: 0.523
Epoch 11 | Training loss: 0.123685636 | Val accuracy: 0.546 | Test accura

In [53]:
# Mean Pooling -> Normalized Propagation role -> mitigates exploding/imploding gradient issue
deg =  tf.reduce_sum(adj, axis=-1)
train_cora(features, adj / deg, gnn, 32, 200 , 0.01)

Epoch 0 | Training loss: 1.9567463 | Val accuracy: 0.43199998 | Test accuracy: 0.425
Epoch 1 | Training loss: 1.7688234 | Val accuracy: 0.54600006 | Test accuracy: 0.563
Epoch 2 | Training loss: 1.5519285 | Val accuracy: 0.62399995 | Test accuracy: 0.65099996
Epoch 3 | Training loss: 1.3180251 | Val accuracy: 0.67599994 | Test accuracy: 0.716
Epoch 4 | Training loss: 1.0999073 | Val accuracy: 0.708 | Test accuracy: 0.7519999
Epoch 5 | Training loss: 0.8997514 | Val accuracy: 0.73999995 | Test accuracy: 0.78299993
Epoch 6 | Training loss: 0.72407496 | Val accuracy: 0.758 | Test accuracy: 0.79899985
Epoch 7 | Training loss: 0.5763984 | Val accuracy: 0.7719999 | Test accuracy: 0.80499977
Epoch 8 | Training loss: 0.45550808 | Val accuracy: 0.7779999 | Test accuracy: 0.8029998
Epoch 9 | Training loss: 0.35778803 | Val accuracy: 0.7799999 | Test accuracy: 0.80899984


In [55]:
# Graph Convolution
norm_deg = tf.linalg.diag(1.0 / tf.sqrt(deg))
norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))
train_cora(features, norm_adj, gnn, 32, 200, 0.01)

Epoch 0 | Training loss: 1.9516594 | Val accuracy: 0.592 | Test accuracy: 0.62299997
Epoch 1 | Training loss: 1.7500453 | Val accuracy: 0.702 | Test accuracy: 0.729
Epoch 2 | Training loss: 1.5178025 | Val accuracy: 0.74399996 | Test accuracy: 0.7799999
Epoch 3 | Training loss: 1.2690847 | Val accuracy: 0.76000005 | Test accuracy: 0.798
Epoch 4 | Training loss: 1.0371318 | Val accuracy: 0.778 | Test accuracy: 0.8099999
Epoch 5 | Training loss: 0.8318434 | Val accuracy: 0.7879999 | Test accuracy: 0.81299984
Epoch 6 | Training loss: 0.65540075 | Val accuracy: 0.78999996 | Test accuracy: 0.80499977
Epoch 7 | Training loss: 0.5087329 | Val accuracy: 0.79 | Test accuracy: 0.79899985
