In [1]:
from __future__ import print_function

from keras.layers import Input, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.regularizers import l2

from kegra.layers.graph import GraphConvolution
from kegra.utils import *

import time

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Experiment configuration

# Dataset handed to load_data()
DATASET = 'cora'

# Graph filter: 'localpool' or 'chebyshev'
FILTER = 'localpool'
# Maximum polynomial degree (only read by the 'chebyshev' filter)
MAX_DEGREE = 2
# Normalization mode: True = symmetric, False = left-only
SYM_NORM = True

# Training schedule
NB_EPOCH = 200
# Early stopping patience (counted in epochs)
PATIENCE = 10

In [3]:
# Get data
X, A, y = load_data(dataset=DATASET)
y_train, y_val, y_test, idx_train, idx_val, idx_test, train_mask = get_splits(y)

# Row-normalize X so that each row sums to 1.
# Rows whose sum is zero are divided by 1 instead of 0: they stay all-zero
# rather than turning into NaN/inf that would then propagate into the loss.
row_sums = X.sum(1).reshape(-1, 1)
row_sums[row_sums == 0] = 1
X /= row_sums

# Build the model inputs for the selected filter:
#   support - number of graph (adjacency-like) inputs per convolution layer
#   graph   - list of arrays fed to model.fit()/model.predict(): [X] + filter matrices
#   G       - matching list of sparse Keras Input placeholders
if FILTER == 'localpool':
    # Local pooling filters (see 'renormalization trick' in Kipf & Welling, arXiv 2016)
    print('Using local pooling filters...')
    A_ = preprocess_adj(A, SYM_NORM)
    support = 1
    graph = [X, A_]
    G = [Input(shape=(None, None), batch_shape=(None, None), sparse=True)]

elif FILTER == 'chebyshev':
    # Chebyshev polynomial basis filters (Defferard et al., NIPS 2016)
    print('Using Chebyshev polynomial basis filters...')
    L = normalized_laplacian(A, SYM_NORM)
    L_scaled = rescale_laplacian(L)
    T_k = chebyshev_polynomial(L_scaled, MAX_DEGREE)
    support = MAX_DEGREE + 1
    graph = [X]+T_k
    G = [Input(shape=(None, None), batch_shape=(None, None), sparse=True) for _ in range(support)]

else:
    # ValueError is a subclass of Exception, so existing handlers still catch it.
    raise ValueError('Invalid filter type.')

# Placeholder for the node feature matrix (one row per node, X.shape[1] features)
X_in = Input(shape=(X.shape[1],))

Loading cora dataset...
Dataset has 2708 nodes, 5429 edges, 1433 features.
Using local pooling filters...


In [4]:
# Build a two-layer graph convolutional network.
# NOTE: each GraphConvolution layer is called with a list of tensors
# ([features] + G). This is somewhat hacky; a more elegant interface would
# require rewriting the Layer base class.
weight_decay = l2(5e-4)

H = Dropout(0.5)(X_in)
H = GraphConvolution(16, support,
                     activation='relu',
                     kernel_regularizer=weight_decay)([H] + G)
H = Dropout(0.5)(H)
# Output layer: one softmax unit per column of y
Y = GraphConvolution(y.shape[1], support, activation='softmax')([H] + G)

# Assemble and compile the model
model_inputs = [X_in] + G
model = Model(inputs=model_inputs, outputs=Y)
optimizer = Adam(lr=0.01)
model.compile(optimizer=optimizer, loss='categorical_crossentropy')

# State variables read and updated by the main training loop
wait, preds, best_val_loss = 0, None, 99999

In [5]:
# Fit
# (Re)initialise the early-stopping state inside this cell, so that re-running
# the cell does not pick up counters left over from a previous run.
wait = 0                        # consecutive epochs without val-loss improvement
preds = None                    # predictions from the most recent epoch
best_val_loss = float('inf')    # any finite first loss counts as an improvement

for epoch in range(1, NB_EPOCH+1):

    # Log wall-clock time
    t = time.time()

    # Single training iteration (we mask nodes without labels for loss calculation).
    # batch_size=A.shape[0] with shuffle=False feeds the whole graph as one batch.
    model.fit(graph, y_train, sample_weight=train_mask,
              batch_size=A.shape[0], epochs=1, shuffle=False, verbose=0)

    # Predict on full dataset
    preds = model.predict(graph, batch_size=A.shape[0])

    # Train / validation scores (index 0 = train split, index 1 = validation split)
    train_val_loss, train_val_acc = evaluate_preds(preds, [y_train, y_val],
                                                   [idx_train, idx_val])
    print("Epoch: {:04d}".format(epoch),
          "train_loss= {:.4f}".format(train_val_loss[0]),
          "train_acc= {:.4f}".format(train_val_acc[0]),
          "val_loss= {:.4f}".format(train_val_loss[1]),
          "val_acc= {:.4f}".format(train_val_acc[1]),
          "time= {:.4f}".format(time.time() - t))

    # Early stopping: stop once the validation loss has failed to improve for
    # PATIENCE consecutive epochs. The counter is incremented *before* the
    # check; checking first would let PATIENCE + 1 bad epochs pass.
    if train_val_loss[1] < best_val_loss:
        best_val_loss = train_val_loss[1]
        wait = 0
    else:
        wait += 1
        if wait >= PATIENCE:
            print('Epoch {}: early stopping'.format(epoch))
            break

Epoch: 0001 train_loss= 1.9369 train_acc= 0.3643 val_loss= 1.9383 val_acc= 0.3633 time= 2.6145
Epoch: 0002 train_loss= 1.9266 train_acc= 0.4214 val_loss= 1.9296 val_acc= 0.4100 time= 0.0214
Epoch: 0003 train_loss= 1.9149 train_acc= 0.4643 val_loss= 1.9197 val_acc= 0.4367 time= 0.0223
Epoch: 0004 train_loss= 1.9023 train_acc= 0.4643 val_loss= 1.9090 val_acc= 0.4500 time= 0.0217
Epoch: 0005 train_loss= 1.8893 train_acc= 0.4643 val_loss= 1.8981 val_acc= 0.4100 time= 0.0254
Epoch: 0006 train_loss= 1.8761 train_acc= 0.4500 val_loss= 1.8870 val_acc= 0.4033 time= 0.0262
Epoch: 0007 train_loss= 1.8623 train_acc= 0.4500 val_loss= 1.8756 val_acc= 0.4033 time= 0.0289
Epoch: 0008 train_loss= 1.8482 train_acc= 0.4500 val_loss= 1.8641 val_acc= 0.4100 time= 0.0247
Epoch: 0009 train_loss= 1.8339 train_acc= 0.4500 val_loss= 1.8525 val_acc= 0.4167 time= 0.0230
Epoch: 0010 train_loss= 1.8201 train_acc= 0.4500 val_loss= 1.8411 val_acc= 0.4100 time= 0.0234
Epoch: 0011 train_loss= 1.8060 train_acc= 0.4500 v

Epoch: 0090 train_loss= 0.8877 train_acc= 0.8786 val_loss= 1.1461 val_acc= 0.7567 time= 0.0290
Epoch: 0091 train_loss= 0.8807 train_acc= 0.8786 val_loss= 1.1399 val_acc= 0.7500 time= 0.0288
Epoch: 0092 train_loss= 0.8739 train_acc= 0.8714 val_loss= 1.1334 val_acc= 0.7467 time= 0.0296
Epoch: 0093 train_loss= 0.8669 train_acc= 0.8643 val_loss= 1.1271 val_acc= 0.7533 time= 0.0284
Epoch: 0094 train_loss= 0.8593 train_acc= 0.8643 val_loss= 1.1211 val_acc= 0.7500 time= 0.0280
Epoch: 0095 train_loss= 0.8522 train_acc= 0.8714 val_loss= 1.1161 val_acc= 0.7467 time= 0.0292
Epoch: 0096 train_loss= 0.8453 train_acc= 0.8786 val_loss= 1.1114 val_acc= 0.7567 time= 0.0292
Epoch: 0097 train_loss= 0.8386 train_acc= 0.8786 val_loss= 1.1067 val_acc= 0.7567 time= 0.0306
Epoch: 0098 train_loss= 0.8321 train_acc= 0.8786 val_loss= 1.1023 val_acc= 0.7600 time= 0.0276
Epoch: 0099 train_loss= 0.8258 train_acc= 0.8857 val_loss= 1.0981 val_acc= 0.7633 time= 0.0273
Epoch: 0100 train_loss= 0.8195 train_acc= 0.8929 v

Epoch: 0177 train_loss= 0.5142 train_acc= 0.9500 val_loss= 0.8560 val_acc= 0.8067 time= 0.0290
Epoch: 0178 train_loss= 0.5110 train_acc= 0.9500 val_loss= 0.8536 val_acc= 0.8033 time= 0.0289
Epoch: 0179 train_loss= 0.5081 train_acc= 0.9500 val_loss= 0.8519 val_acc= 0.8033 time= 0.0311
Epoch: 0180 train_loss= 0.5054 train_acc= 0.9500 val_loss= 0.8496 val_acc= 0.8000 time= 0.0313
Epoch: 0181 train_loss= 0.5030 train_acc= 0.9500 val_loss= 0.8483 val_acc= 0.7967 time= 0.0282
Epoch: 0182 train_loss= 0.5008 train_acc= 0.9500 val_loss= 0.8483 val_acc= 0.7967 time= 0.0291
Epoch: 0183 train_loss= 0.4986 train_acc= 0.9500 val_loss= 0.8482 val_acc= 0.7967 time= 0.0265
Epoch: 0184 train_loss= 0.4965 train_acc= 0.9500 val_loss= 0.8464 val_acc= 0.7967 time= 0.0255
Epoch: 0185 train_loss= 0.4942 train_acc= 0.9500 val_loss= 0.8437 val_acc= 0.7967 time= 0.0270
Epoch: 0186 train_loss= 0.4923 train_acc= 0.9500 val_loss= 0.8416 val_acc= 0.8000 time= 0.0289
Epoch: 0187 train_loss= 0.4904 train_acc= 0.9500 v

In [6]:
# Score the predictions from the last training epoch on the test split
test_loss, test_acc = evaluate_preds(preds, [y_test], [idx_test])

# Assemble the report fields first, then emit them as one space-separated line
report_fields = [
    "Test set results:",
    "loss= {:.4f}".format(test_loss[0]),
    "accuracy= {:.4f}".format(test_acc[0]),
]
print(" ".join(report_fields))

Test set results: loss= 0.8773 accuracy= 0.7960
