In [1]:
import os

from data_utils import *

from keras.layers import Dropout,Input
from keras.regularizers import l2
from keras.optimizers import Adam
from keras.models import Model
from keras.callbacks import EarlyStopping,TensorBoard,ModelCheckpoint

import numpy as np
from layer import GraphAtten

Using TensorFlow backend.


In [2]:
# Read the Cora dataset.
#   A       : adjacency matrix, (2708, 2708)
#   X       : node features, multi-hot, (2708, 1433)
#   Y_train : one-hot labels, (2708, 7); Y_val / Y_test presumably share this shape
#   idx_*   : per-split selectors, later passed to Keras as `sample_weight`
#             (presumably masks over all 2708 nodes -- TODO confirm in data_utils)
(A, X,
 Y_train, Y_val, Y_test,
 idx_train, idx_val, idx_test) = load_data('cora')

In [3]:
# Hyper-parameters.

# ---- Dimensions derived from the loaded data ----
N, F = X.shape[:2]               # N: number of nodes in the graph, F: original feature dimension
n_classes = Y_train.shape[1]     # number of label classes

# ---- Architecture ----
F_ = 8                           # per-head output size of the first GraphAtten layer (heads are concatenated)
n_attn_heads = 8                 # number of attention heads in the first GAT layer

# ---- Regularization ----
dropout_rate = 0.6               # dropout rate, used both between and inside the GAT layers
l2_reg = 5e-4/2                  # L2 regularization factor

# ---- Optimization ----
learning_rate = 5e-3             # Adam learning rate
epochs = 10                      # number of training epochs
es_patience = 100                # early-stopping patience, in epochs
# NOTE(review): es_patience > epochs, so early stopping can never trigger with
# these values; raise `epochs` for a full training run.

In [4]:
# Preprocessing.
# Row-normalize the node features (implemented by data_utils.preprocess_features).
X = preprocess_features(X)
# Add self-loops so every node is also part of its own neighbourhood.
# NOTE(review): this rebinds A, so re-running the cell adds the identity a second time.
A = A + np.identity(A.shape[0])

In [5]:
# Model definition (as per Section 3.3 of the paper).
# Both inputs receive one row per node:
#   X_in -> that node's F-dimensional feature vector
#   A_in -> that node's row of the N x N adjacency matrix
X_in, A_in = Input(shape=(F,)), Input(shape=(N,))

In [6]:
# Input dropout on the node features, applied before the first attention layer.
dropout1 = Dropout(rate=dropout_rate)(X_in)

In [8]:
# Hidden GAT layer: n_attn_heads heads of size F_ each, concatenated
# (n_attn_heads * F_ features per node), ELU activation.
gat_layer_1 = GraphAtten(
    F_,
    attn_heads=n_attn_heads,
    attn_heads_reduction='concat',
    dropout_rate=dropout_rate,
    activation='elu',
    kernel_regularizer=l2(l2_reg),
    attn_kernel_regularizer=l2(l2_reg),
)
graph_attention_1 = gat_layer_1([dropout1, A_in])

# Dropout between the two attention layers.
dropout2 = Dropout(dropout_rate)(graph_attention_1)

# Output GAT layer: a single head producing n_classes values per node,
# softmax-activated so the output is a class distribution.
gat_layer_2 = GraphAtten(
    n_classes,
    attn_heads=1,
    attn_heads_reduction='average',
    dropout_rate=dropout_rate,
    activation='softmax',
    kernel_regularizer=l2(l2_reg),
    attn_kernel_regularizer=l2(l2_reg),
)
graph_attention_2 = gat_layer_2([dropout2, A_in])

In [10]:
# Build, train and evaluate the model.
#
# Training is full-batch: the whole graph (N nodes) is a single batch, and the
# train/val/test splits are selected through `sample_weight` (idx_* selectors).

# Single source of truth for where the best weights are saved and reloaded from.
best_model_path = 'logs/best_model.h5'

# build Model
model = Model(inputs=[X_in, A_in], outputs=graph_attention_2)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])

model.summary()

# Callbacks
# NOTE(review): both callbacks below monitor 'val_weighted_acc'; this assumes the
# installed Keras version logs the weighted 'acc' metric under that name.
# NOTE(review): with the current settings (epochs=10, es_patience=100) early
# stopping cannot trigger.
es_callback = EarlyStopping(monitor='val_weighted_acc', patience=es_patience)
tb_callback = TensorBoard(log_dir='./logs/loss',  # log directory
                          histogram_freq=0,       # epoch frequency for histograms; 0 disables them
                          write_graph=True,       # store the network graph
                          write_grads=True,       # gradient histograms; only written when histogram_freq > 0
                          write_images=True,      # visualize layer weights as images
                          batch_size=N)           # batch size used when computing histograms
mc_callback = ModelCheckpoint(best_model_path,
                              monitor='val_weighted_acc',
                              save_best_only=True,
                              save_weights_only=True)

# Ensure the checkpoint directory exists, and remove any checkpoint left over from
# an earlier run: otherwise, if this run never saves (metric missing or NaN), the
# reload below would silently pick up stale weights.
os.makedirs(os.path.dirname(best_model_path) or '.', exist_ok=True)
if os.path.exists(best_model_path):
    os.remove(best_model_path)

# Train model
validation_data = ([X, A], Y_val, idx_val)
model.fit([X, A],
          Y_train,
          sample_weight=idx_train,
          epochs=epochs,
          batch_size=N,
          validation_data=validation_data,
          shuffle=False,  # Shuffling data means shuffling the whole graph
          callbacks=[es_callback, tb_callback, mc_callback])

# Load best model (fail loudly rather than evaluating untrained/stale weights).
if not os.path.exists(best_model_path):
    raise RuntimeError(
        'ModelCheckpoint did not save any weights to {} '
        "(was 'val_weighted_acc' logged, and was it finite?)".format(best_model_path))
model.load_weights(best_model_path)

# Evaluate model
eval_results = model.evaluate([X, A],
                              Y_test,
                              sample_weight=idx_test,
                              batch_size=N,
                              verbose=0)
print('Done.\n'
      'Test loss: {}\n'
      'Test accuracy: {}'.format(*eval_results))

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 1433)         0                                            
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 1433)         0           input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 2708)         0                                            
__________________________________________________________________________________________________
graph_atten_1 (GraphAtten)      (None, 64)           91904       dropout_1[0][0]                  
                                                                 input_2[0][0]                    
__________