In [1]:
import time
import tensorflow as tf
import numpy as np
from models import gcn, lstm
from configs import *
from utils import *
import scipy.sparse
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


np.random.seed(123)
FLAGS = tf.flags.FLAGS

dataset = 'datasets/rppa'
time_steps = 5
train_ratio = 0.7
batch_size=200
gcn_layers=2
hidden_dim=5
hidden_size=6
time_steps=5
dropout_prob=0.
learning_rate=0.001


adjs, feats, train_idx, val_idx, test_idx = load_data(dataset, time_steps, train_ratio)

num_node = adjs[0].shape[0]
num_feat = feats[0].shape[1]

for i in range(time_steps):
    adjs[i] = sparse_to_tuple(scipy.sparse.coo_matrix(adjs[i]))
#     feats[i] = sparse_to_tuple(scipy.sparse.coo_matrix(feats[i]))
num_features_nonzeros = [x[1].shape for x in feats]

# define placeholders of the input data 
phs = {
        'adjs': [tf.sparse_placeholder(tf.float32, shape=(None, None), name="adjs") for i in
             range(time_steps)],
        'feats': [tf.placeholder(tf.float32, shape=(None, num_feat), name="feats") for _ in
                 range(time_steps)],
        'train_idx': tf.placeholder(tf.int32, shape=(None,), name="train_idx"),
        'val_idx': tf.placeholder(tf.int32, shape=(None,), name="val_idx"),
        'test_idx': tf.placeholder(tf.int32, shape=(None,), name="test_idx"),
        'sample_idx': tf.placeholder(tf.int32, shape=(batch_size,), name='batch_sample_idx'),
        'dropout_prob': tf.placeholder_with_default(0., shape=()),
        'num_features_nonzeros': [tf.placeholder(tf.int64) for i in range(time_steps)]
        }

# define the GCN model
gcn_model = gcn.GraphConvLayer(time_steps = time_steps,
                               gcn_layers=gcn_layers,
                               input_dim=num_feat,
                               hidden_dim=hidden_dim,
                               output_dim=hidden_size,
                               name='nn_fc1',
                               num_features_nonzeros=phs['num_features_nonzeros'],
                               act=tf.nn.relu,
                               dropout_prob=phs['dropout_prob'],
                               dropout=True)
embeds_list = gcn_model(adjs=phs['adjs'],
                    feats=phs['feats'],
                    sparse=False)

# prepare train data for the LSTM-based prediction model
## replace all missing features at (time_steps-1) with GCN imputed features
# embeds_list[time_steps-1] = tf.add(phs['feats'][time_steps-1], 
#                                    tf.multiply(phs['test_mask'][time_steps-1], embeds_list[time_steps-1]))
## construct training samples for the prediction task
x_train, y_train, x_val, y_val, x_test, y_test = build_train_samples_imputation(embeds_list=embeds_list, 
                                                                     feats=phs['feats'], 
                                                                     train_idx=phs['train_idx'],
                                                                     val_idx=phs['val_idx'],
                                                                     test_idx=phs['test_idx'],
                                                                     time_steps=time_steps)
# define the bi-directional LSTM model
lstm_model = lstm.BiLSTM(hidden_size=hidden_size,
                         seq_len=time_steps-1,
                         holders=phs)
x_input_seq = tf.gather(x_train, phs['sample_idx'])
y_input_seq_real = tf.gather(y_train, phs['sample_idx'])
y_input_seq_pred = lstm_model(input_seq=x_input_seq)

with tf.name_scope('optimizer'):
    # calculate the train mse and ad
    train_mse = tf.losses.mean_squared_error(y_input_seq_real, y_input_seq_pred)
    train_absolute_diff = tf.losses.absolute_difference(y_input_seq_real, y_input_seq_pred)
    
    # calculate the val mse and ad
    val_input_seq_pred = lstm_model(input_seq=x_val)
    val_mse = tf.losses.mean_squared_error(y_val, val_input_seq_pred)
    val_absolute_diff = tf.losses.absolute_difference(y_val, val_input_seq_pred)
    
    # calculate the test mse and ad
    test_input_seq_pred = lstm_model(input_seq=x_test)
    test_mse = tf.losses.mean_squared_error(y_test, test_input_seq_pred)
    test_absolute_diff = tf.losses.absolute_difference(y_test, test_input_seq_pred)
    
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    opt_op = optimizer.minimize(train_mse)

n_cpus = 8
config = tf.ConfigProto(device_count={ "CPU": n_cpus},
                            inter_op_parallelism_threads=n_cpus,
                            intra_op_parallelism_threads=2)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

feed_dict = {phs['train_idx']: train_idx,
             phs['val_idx']: val_idx,
             phs['test_idx']: test_idx,
             phs['sample_idx']: None,
             phs['dropout_prob']: dropout_prob}

feed_dict.update({phs['adjs'][t]: adjs[t] for t in range(time_steps)})
feed_dict.update({phs['feats'][t]: feats[t] for t in range(time_steps)})
feed_dict.update({phs['num_features_nonzeros'][t]: num_features_nonzeros[t] for t in range(time_steps)})

feed_dict_val = {phs['train_idx']: train_idx,
                 phs['val_idx']: val_idx,
                 phs['test_idx']: test_idx,
                 phs['dropout_prob']: 0}

feed_dict_val.update({phs['adjs'][t]: adjs[t] for t in range(time_steps)})
feed_dict_val.update({phs['feats'][t]: feats[t] for t in range(time_steps)})
feed_dict_val.update({phs['num_features_nonzeros'][t]: num_features_nonzeros[t] for t in range(time_steps)})



def get_batch_idx(epoch):
    s = batch_size * epoch
    e = batch_size * (epoch + 1)
    idx = []
    for i in range(s,e):
        idx.append(i%len(train_idx))
    return idx

save_step = 10
t = time.time()






In [2]:
epochs = 500
for epoch in range(epochs):
    batch_samples = get_batch_idx(epoch)
    feed_dict.update({phs['sample_idx']: batch_samples})
    _, train_MSE, train_AD = sess.run((opt_op, train_mse, train_absolute_diff), feed_dict=feed_dict)
    val_MSE, val_AD = sess.run((val_mse, val_absolute_diff), 
                                         feed_dict=feed_dict_val) 
    
    print("Epoch:", '%04d' % (epoch + 1),
      "train_loss=", "{:.5f}".format(train_MSE),
      "train_MSE=", "{:.5f}".format(train_MSE),
      "train_AD=", "{:.5f}".format(train_AD),  # AD means the absolute difference
      "val_MSE=", "{:.5f}".format(val_MSE),
      "val_AD=", "{:.5f}".format(val_AD),
      "time=", "{:.5f}".format(time.time() - t))
    
    if (epoch+1) % save_step == 0:
        test_MSE, test_AD = sess.run((test_mse, test_absolute_diff), 
                                            feed_dict=feed_dict_val) 
        print("-------test_MSE=", "{:.5f}".format(test_MSE),
          "test_AD=", "{:.5f}".format(test_AD))
        

Epoch: 0001 train_loss= 0.23824 train_MSE= 0.23824 train_AD= 0.33978 val_MSE= 0.20709 val_AD= 0.31664 time= 1.32583
Epoch: 0002 train_loss= 0.24079 train_MSE= 0.24079 train_AD= 0.34059 val_MSE= 0.20319 val_AD= 0.31332 time= 1.33780
Epoch: 0003 train_loss= 0.23603 train_MSE= 0.23603 train_AD= 0.33695 val_MSE= 0.19932 val_AD= 0.31003 time= 1.34977
Epoch: 0004 train_loss= 0.23193 train_MSE= 0.23193 train_AD= 0.33407 val_MSE= 0.19547 val_AD= 0.30671 time= 1.36075
Epoch: 0005 train_loss= 0.22160 train_MSE= 0.22160 train_AD= 0.32753 val_MSE= 0.19163 val_AD= 0.30332 time= 1.37171
Epoch: 0006 train_loss= 0.22266 train_MSE= 0.22266 train_AD= 0.32868 val_MSE= 0.18779 val_AD= 0.29985 time= 1.38369
Epoch: 0007 train_loss= 0.21850 train_MSE= 0.21850 train_AD= 0.32535 val_MSE= 0.18397 val_AD= 0.29632 time= 1.39365
Epoch: 0008 train_loss= 0.21355 train_MSE= 0.21355 train_AD= 0.32143 val_MSE= 0.18016 val_AD= 0.29272 time= 1.40365
Epoch: 0009 train_loss= 0.20215 train_MSE= 0.20215 train_AD= 0.31349 val

Epoch: 0080 train_loss= 0.11846 train_MSE= 0.11846 train_AD= 0.22567 val_MSE= 0.12260 val_AD= 0.21059 time= 2.30123
-------test_MSE= 0.11756 test_AD= 0.20580
Epoch: 0081 train_loss= 0.11863 train_MSE= 0.11863 train_AD= 0.22590 val_MSE= 0.12251 val_AD= 0.21032 time= 2.31519
Epoch: 0082 train_loss= 0.11708 train_MSE= 0.11708 train_AD= 0.22250 val_MSE= 0.12241 val_AD= 0.21003 time= 2.32816
Epoch: 0083 train_loss= 0.11873 train_MSE= 0.11873 train_AD= 0.22513 val_MSE= 0.12231 val_AD= 0.20974 time= 2.33912
Epoch: 0084 train_loss= 0.11828 train_MSE= 0.11828 train_AD= 0.22458 val_MSE= 0.12222 val_AD= 0.20946 time= 2.35110
Epoch: 0085 train_loss= 0.11219 train_MSE= 0.11219 train_AD= 0.22049 val_MSE= 0.12209 val_AD= 0.20914 time= 2.36210
Epoch: 0086 train_loss= 0.11648 train_MSE= 0.11648 train_AD= 0.22316 val_MSE= 0.12196 val_AD= 0.20883 time= 2.37303
Epoch: 0087 train_loss= 0.11810 train_MSE= 0.11810 train_AD= 0.22442 val_MSE= 0.12183 val_AD= 0.20850 time= 2.38500
Epoch: 0088 train_loss= 0.1170

Epoch: 0156 train_loss= 0.10843 train_MSE= 0.10843 train_AD= 0.20941 val_MSE= 0.11940 val_AD= 0.20143 time= 3.23174
Epoch: 0157 train_loss= 0.10824 train_MSE= 0.10824 train_AD= 0.20884 val_MSE= 0.11943 val_AD= 0.20149 time= 3.24271
Epoch: 0158 train_loss= 0.10913 train_MSE= 0.10913 train_AD= 0.20993 val_MSE= 0.11944 val_AD= 0.20154 time= 3.25468
Epoch: 0159 train_loss= 0.10773 train_MSE= 0.10773 train_AD= 0.20851 val_MSE= 0.11947 val_AD= 0.20160 time= 3.26664
Epoch: 0160 train_loss= 0.10880 train_MSE= 0.10880 train_AD= 0.20985 val_MSE= 0.11950 val_AD= 0.20165 time= 3.27762
-------test_MSE= 0.11179 test_AD= 0.19561
Epoch: 0161 train_loss= 0.10807 train_MSE= 0.10807 train_AD= 0.20908 val_MSE= 0.11950 val_AD= 0.20166 time= 3.29357
Epoch: 0162 train_loss= 0.10778 train_MSE= 0.10778 train_AD= 0.20884 val_MSE= 0.11953 val_AD= 0.20168 time= 3.30457
Epoch: 0163 train_loss= 0.10355 train_MSE= 0.10355 train_AD= 0.20473 val_MSE= 0.11952 val_AD= 0.20168 time= 3.31651
Epoch: 0164 train_loss= 0.1066

-------test_MSE= 0.11049 test_AD= 0.19463
Epoch: 0231 train_loss= 0.10579 train_MSE= 0.10579 train_AD= 0.20640 val_MSE= 0.11828 val_AD= 0.20032 time= 4.17821
Epoch: 0232 train_loss= 0.10019 train_MSE= 0.10019 train_AD= 0.19971 val_MSE= 0.11824 val_AD= 0.20032 time= 4.19117
Epoch: 0233 train_loss= 0.10421 train_MSE= 0.10421 train_AD= 0.20402 val_MSE= 0.11819 val_AD= 0.20029 time= 4.20414
Epoch: 0234 train_loss= 0.10518 train_MSE= 0.10518 train_AD= 0.20470 val_MSE= 0.11816 val_AD= 0.20026 time= 4.21710
Epoch: 0235 train_loss= 0.09782 train_MSE= 0.09782 train_AD= 0.20323 val_MSE= 0.11810 val_AD= 0.20021 time= 4.22907
Epoch: 0236 train_loss= 0.10554 train_MSE= 0.10554 train_AD= 0.20577 val_MSE= 0.11807 val_AD= 0.20017 time= 4.24203
Epoch: 0237 train_loss= 0.10428 train_MSE= 0.10428 train_AD= 0.20471 val_MSE= 0.11805 val_AD= 0.20015 time= 4.25301
Epoch: 0238 train_loss= 0.10596 train_MSE= 0.10596 train_AD= 0.20669 val_MSE= 0.11806 val_AD= 0.20013 time= 4.26198
Epoch: 0239 train_loss= 0.1052

Epoch: 0306 train_loss= 0.10284 train_MSE= 0.10284 train_AD= 0.20400 val_MSE= 0.11553 val_AD= 0.20050 time= 5.11242
Epoch: 0307 train_loss= 0.10345 train_MSE= 0.10345 train_AD= 0.20437 val_MSE= 0.11556 val_AD= 0.20060 time= 5.12538
Epoch: 0308 train_loss= 0.08705 train_MSE= 0.08705 train_AD= 0.19618 val_MSE= 0.11560 val_AD= 0.20071 time= 5.13835
Epoch: 0309 train_loss= 0.10077 train_MSE= 0.10077 train_AD= 0.19995 val_MSE= 0.11565 val_AD= 0.20080 time= 5.15332
Epoch: 0310 train_loss= 0.09510 train_MSE= 0.09510 train_AD= 0.19621 val_MSE= 0.11569 val_AD= 0.20089 time= 5.16628
-------test_MSE= 0.11052 test_AD= 0.19685
Epoch: 0311 train_loss= 0.10227 train_MSE= 0.10227 train_AD= 0.20333 val_MSE= 0.11568 val_AD= 0.20096 time= 5.18523
Epoch: 0312 train_loss= 0.10348 train_MSE= 0.10348 train_AD= 0.20443 val_MSE= 0.11565 val_AD= 0.20100 time= 5.19819
Epoch: 0313 train_loss= 0.10291 train_MSE= 0.10291 train_AD= 0.20310 val_MSE= 0.11561 val_AD= 0.20101 time= 5.21016
Epoch: 0314 train_loss= 0.1019

Epoch: 0390 train_loss= 0.09985 train_MSE= 0.09985 train_AD= 0.20139 val_MSE= 0.11414 val_AD= 0.20311 time= 6.23043
-------test_MSE= 0.11116 test_AD= 0.20090
Epoch: 0391 train_loss= 0.09827 train_MSE= 0.09827 train_AD= 0.19775 val_MSE= 0.11408 val_AD= 0.20309 time= 6.24838
Epoch: 0392 train_loss= 0.10038 train_MSE= 0.10038 train_AD= 0.20135 val_MSE= 0.11404 val_AD= 0.20308 time= 6.26135
Epoch: 0393 train_loss= 0.10043 train_MSE= 0.10043 train_AD= 0.20177 val_MSE= 0.11403 val_AD= 0.20310 time= 6.27232
Epoch: 0394 train_loss= 0.09473 train_MSE= 0.09473 train_AD= 0.19750 val_MSE= 0.11396 val_AD= 0.20310 time= 6.28829
Epoch: 0395 train_loss= 0.09952 train_MSE= 0.09952 train_AD= 0.20064 val_MSE= 0.11390 val_AD= 0.20315 time= 6.30224
Epoch: 0396 train_loss= 0.10080 train_MSE= 0.10080 train_AD= 0.20234 val_MSE= 0.11386 val_AD= 0.20322 time= 6.31520
Epoch: 0397 train_loss= 0.09978 train_MSE= 0.09978 train_AD= 0.20137 val_MSE= 0.11383 val_AD= 0.20326 time= 6.32917
Epoch: 0398 train_loss= 0.1006

Epoch: 0465 train_loss= 0.09828 train_MSE= 0.09828 train_AD= 0.20130 val_MSE= 0.11434 val_AD= 0.20527 time= 7.17091
Epoch: 0466 train_loss= 0.09822 train_MSE= 0.09822 train_AD= 0.20122 val_MSE= 0.11439 val_AD= 0.20538 time= 7.18388
Epoch: 0467 train_loss= 0.09893 train_MSE= 0.09893 train_AD= 0.20222 val_MSE= 0.11447 val_AD= 0.20552 time= 7.19684
Epoch: 0468 train_loss= 0.09772 train_MSE= 0.09772 train_AD= 0.20092 val_MSE= 0.11459 val_AD= 0.20570 time= 7.21280
Epoch: 0469 train_loss= 0.09877 train_MSE= 0.09877 train_AD= 0.20247 val_MSE= 0.11470 val_AD= 0.20587 time= 7.22577
Epoch: 0470 train_loss= 0.09803 train_MSE= 0.09803 train_AD= 0.20154 val_MSE= 0.11479 val_AD= 0.20599 time= 7.23874
-------test_MSE= 0.11194 test_AD= 0.20489
Epoch: 0471 train_loss= 0.09774 train_MSE= 0.09774 train_AD= 0.20124 val_MSE= 0.11489 val_AD= 0.20610 time= 7.25372
Epoch: 0472 train_loss= 0.09479 train_MSE= 0.09479 train_AD= 0.19800 val_MSE= 0.11489 val_AD= 0.20605 time= 7.26466
Epoch: 0473 train_loss= 0.0969