In [1]:
import sys, os, _pickle as pickle
import tensorflow as tf
import numpy as np
import nltk
from sklearn.metrics import f1_score

In [2]:
data_dir = '../data'
ckpt_dir = '../checkpoint'
word_embd_dir = '../checkpoint/word_embd'
model_dir = '../checkpoint/modelv4'

In [3]:
word_embd_dim = 100
pos_embd_dim = 25
dep_embd_dim = 25
word_vocab_size = 400001
pos_vocab_size = 10
dep_vocab_size = 21
relation_classes = 19
word_state_size = 100
other_state_size = 50
batch_size = 10
channels = 3
lambda_l2 = 0.0001
max_len_path = 10

In [4]:
with tf.name_scope("input"):
    path_length = tf.placeholder(tf.int32, shape=[2, batch_size], name="path1_length")
    word_ids = tf.placeholder(tf.int32, shape=[2, batch_size, max_len_path], name="word_ids")
    pos_ids = tf.placeholder(tf.int32, [2, batch_size, max_len_path], name="pos_ids")
    dep_ids = tf.placeholder(tf.int32, [2, batch_size, max_len_path], name="dep_ids")
    y = tf.placeholder(tf.int32, [batch_size], name="y")

In [5]:
with tf.name_scope("word_embedding"):
    W = tf.Variable(tf.constant(0.0, shape=[word_vocab_size, word_embd_dim]), name="W")
    embedding_placeholder = tf.placeholder(tf.float32,[word_vocab_size, word_embd_dim])
    embedding_init = W.assign(embedding_placeholder)
    embedded_word = tf.nn.embedding_lookup(W, word_ids)
    word_embedding_saver = tf.train.Saver({"word_embedding/W": W})

with tf.name_scope("pos_embedding"):
    W = tf.Variable(tf.random_uniform([pos_vocab_size, pos_embd_dim]), name="W")
    embedded_pos = tf.nn.embedding_lookup(W, pos_ids)
    pos_embedding_saver = tf.train.Saver({"pos_embedding/W": W})

with tf.name_scope("dep_embedding"):
    W = tf.Variable(tf.random_uniform([dep_vocab_size, dep_embd_dim]), name="W")
    embedded_dep = tf.nn.embedding_lookup(W, dep_ids)
    dep_embedding_saver = tf.train.Saver({"dep_embedding/W": W})

In [6]:
with tf.name_scope("word_dropout"):
    embedded_word_drop = tf.nn.dropout(embedded_word, 0.3)

In [7]:
word_hidden_state = tf.zeros([batch_size, word_state_size], name='word_hidden_state')
word_cell_state = tf.zeros([batch_size, word_state_size], name='word_cell_state')
word_init_state = tf.contrib.rnn.LSTMStateTuple(word_hidden_state, word_cell_state)

other_hidden_states = tf.zeros([channels-1, batch_size, other_state_size], name="hidden_state")
other_cell_states = tf.zeros([channels-1, batch_size, other_state_size], name="cell_state")

other_init_states = [tf.contrib.rnn.LSTMStateTuple(other_hidden_states[i], other_cell_states[i]) for i in range(channels-1)]

with tf.variable_scope("word_lstm1"):
    cell = tf.contrib.rnn.BasicLSTMCell(word_state_size)
    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_word_drop[0], sequence_length=path_length[0], initial_state=word_init_state)
    state_series_word1 = tf.reduce_max(state_series, axis=1)

with tf.variable_scope("word_lstm2"):
    cell = tf.contrib.rnn.BasicLSTMCell(word_state_size)
    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_word_drop[1], sequence_length=path_length[1], initial_state=word_init_state)
    state_series_word2 = tf.reduce_max(state_series, axis=1)

with tf.variable_scope("pos_lstm1"):
    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)
    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_pos[0], sequence_length=path_length[0],initial_state=other_init_states[0])
    state_series_pos1 = tf.reduce_max(state_series, axis=1)

with tf.variable_scope("pos_lstm2"):
    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)
    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_pos[1], sequence_length=path_length[1],initial_state=other_init_states[0])
    state_series_pos2 = tf.reduce_max(state_series, axis=1)

with tf.variable_scope("dep_lstm1"):
    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)
    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_dep[0], sequence_length=path_length[0], initial_state=other_init_states[1])
    state_series_dep1 = tf.reduce_max(state_series, axis=1)

with tf.variable_scope("dep_lstm2"):
    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)
    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_dep[1], sequence_length=path_length[1], initial_state=other_init_states[1])
    state_series_dep2 = tf.reduce_max(state_series, axis=1)

state_series1 = tf.concat([state_series_word1, state_series_pos1, state_series_dep1], 1)
state_series2 = tf.concat([state_series_word2, state_series_pos2, state_series_dep2], 1)

state_series = tf.concat([state_series1, state_series2], 1)

In [8]:
with tf.name_scope("hidden_layer"):
    W = tf.Variable(tf.truncated_normal([400, 100], -0.1, 0.1), name="W")
    b = tf.Variable(tf.zeros([100]), name="b")
    y_hidden_layer = tf.matmul(state_series, W) + b

In [10]:
with tf.name_scope("softmax_layer"):
    W = tf.Variable(tf.truncated_normal([100, relation_classes], -0.1, 0.1), name="W")
    b = tf.Variable(tf.zeros([relation_classes]), name="b")
    logits = tf.matmul(y_hidden_layer, W) + b
    predictions = tf.argmax(logits, 1)

In [11]:
tv_all = tf.trainable_variables()
tv_regu = []
non_reg = ["word_embedding/W:0","pos_embedding/W:0",'dep_embedding/W:0',"global_step:0",'hidden_layer/b:0','softmax_layer/b:0']
for t in tv_all:
    if t.name not in non_reg:
        if(t.name.find('biases')==-1):
            tv_regu.append(t)

with tf.name_scope("loss"):
    l2_loss = lambda_l2 * tf.reduce_sum([ tf.nn.l2_loss(v) for v in tv_regu ])
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y))
    total_loss = loss + l2_loss

global_step = tf.Variable(0, name="global_step")

optimizer = tf.train.AdamOptimizer(0.001).minimize(total_loss, global_step=global_step)

In [15]:
f = open(data_dir + '/vocab.pkl', 'rb')
vocab = pickle.load(f)
f.close()

word2id = dict((w, i) for i,w in enumerate(vocab))
id2word = dict((i, w) for i,w in enumerate(vocab))

unknown_token = "UNKNOWN_TOKEN"
word2id[unknown_token] = word_vocab_size -1
id2word[word_vocab_size-1] = unknown_token

pos_tags_vocab = []
for line in open(data_dir + '/pos_tags.txt'):
        pos_tags_vocab.append(line.strip())

dep_vocab = []
for line in open(data_dir + '/dependency_types.txt'):
    dep_vocab.append(line.strip())

relation_vocab = []
for line in open(data_dir + '/relation_types.txt'):
    relation_vocab.append(line.strip())


rel2id = dict((w, i) for i,w in enumerate(relation_vocab))
id2rel = dict((i, w) for i,w in enumerate(relation_vocab))

pos_tag2id = dict((w, i) for i,w in enumerate(pos_tags_vocab))
id2pos_tag = dict((i, w) for i,w in enumerate(pos_tags_vocab))

dep2id = dict((w, i) for i,w in enumerate(dep_vocab))
id2dep = dict((i, w) for i,w in enumerate(dep_vocab))

pos_tag2id['OTH'] = 9
id2pos_tag[9] = 'OTH'

dep2id['OTH'] = 20
id2dep[20] = 'OTH'

JJ_pos_tags = ['JJ', 'JJR', 'JJS']
NN_pos_tags = ['NN', 'NNS', 'NNP', 'NNPS']
RB_pos_tags = ['RB', 'RBR', 'RBS']
PRP_pos_tags = ['PRP', 'PRP$']
VB_pos_tags = ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']
_pos_tags = ['CC', 'CD', 'DT', 'IN']

def pos_tag(x):
    if x in JJ_pos_tags:
        return pos_tag2id['JJ']
    if x in NN_pos_tags:
        return pos_tag2id['NN']
    if x in RB_pos_tags:
        return pos_tag2id['RB']
    if x in PRP_pos_tags:
        return pos_tag2id['PRP']
    if x in VB_pos_tags:
        return pos_tag2id['VB']
    if x in _pos_tags:
        return pos_tag2id[x]
    else:
        return 9

In [17]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

In [13]:
# f = open('data/word_embedding', 'rb')
# word_embedding = pickle.load(f)
# f.close()

# sess.run(embedding_init, feed_dict={embedding_placeholder:word_embedding})
# word_embedding_saver.save(sess, word_embd_dir + '/word_embd')

In [19]:
# model = tf.train.latest_checkpoint(model_dir)
# saver.restore(sess, model)

In [20]:
latest_embd = tf.train.latest_checkpoint(word_embd_dir)
word_embedding_saver.restore(sess, latest_embd)

In [30]:
f = open(data_dir + '/train_paths', 'rb')
word_p1, word_p2, dep_p1, dep_p2, pos_p1, pos_p2 = pickle.load(f)
f.close()

relations = []
for line in open(data_dir + '/train_relations.txt'):
    relations.append(line.strip().split()[1])

length = len(word_p1)
num_batches = int(length/batch_size)

for i in range(length):
    for j, word in enumerate(word_p1[i]):
        word = word.lower()
        word_p1[i][j] = word if word in word2id else unknown_token 
    for k, word in enumerate(word_p2[i]):
        word = word.lower()
        word_p2[i][k] = word if word in word2id else unknown_token 
    for l, d in enumerate(dep_p1[i]):
        dep_p1[i][l] = d if d in dep2id else 'OTH'
    for m, d in enumerate(dep_p2[i]):
        dep_p2[i][m] = d if d in dep2id else 'OTH'

word_p1_ids = np.ones([length, max_len_path],dtype=int)
word_p2_ids = np.ones([length, max_len_path],dtype=int)
pos_p1_ids = np.ones([length, max_len_path],dtype=int)
pos_p2_ids = np.ones([length, max_len_path],dtype=int)
dep_p1_ids = np.ones([length, max_len_path],dtype=int)
dep_p2_ids = np.ones([length, max_len_path],dtype=int)
rel_ids = np.array([rel2id[rel] for rel in relations])
path1_len = np.array([len(w) for w in word_p1], dtype=int)
path2_len = np.array([len(w) for w in word_p2])

for i in range(length):
    for j, w in enumerate(word_p1[i]):
        word_p1_ids[i][j] = word2id[w]
    for j, w in enumerate(word_p2[i]):
        word_p2_ids[i][j] = word2id[w]
    for j, w in enumerate(pos_p1[i]):
        pos_p1_ids[i][j] = pos_tag(w)
    for j, w in enumerate(pos_p2[i]):
        pos_p2_ids[i][j] = pos_tag(w)
    for j, w in enumerate(dep_p1[i]):
        dep_p1_ids[i][j] = dep2id[w]
    for j, w in enumerate(dep_p2[i]):
        dep_p2_ids[i][j] = dep2id[w]

In [31]:
num_epochs = 20
for i in range(num_epochs):
    for j in range(num_batches):
        path_dict = [path1_len[j*batch_size:(j+1)*batch_size], path2_len[j*batch_size:(j+1)*batch_size]]
        word_dict = [word_p1_ids[j*batch_size:(j+1)*batch_size], word_p2_ids[j*batch_size:(j+1)*batch_size]]
        pos_dict = [pos_p1_ids[j*batch_size:(j+1)*batch_size], pos_p2_ids[j*batch_size:(j+1)*batch_size]]
        dep_dict = [dep_p1_ids[j*batch_size:(j+1)*batch_size], dep_p2_ids[j*batch_size:(j+1)*batch_size]]
        y_dict = rel_ids[j*batch_size:(j+1)*batch_size]
        
        feed_dict = {
            path_length:path_dict,
            word_ids:word_dict,
            pos_ids:pos_dict,
            dep_ids:dep_dict,
            y:y_dict}
        _, loss, step = sess.run([optimizer, total_loss, global_step], feed_dict)
        if step%10==0:
            print("Step:", step, "loss:",loss)
        if step % 1000 == 0:
            saver.save(sess, model_dir + '/model')
            print("Saved Model")

Step: 32010 loss: 0.465849
Step: 32020 loss: 0.989425
Step: 32030 loss: 0.274849
Step: 32040 loss: 0.509953
Step: 32050 loss: 0.149289
Step: 32060 loss: 0.112399
Step: 32070 loss: 0.116044
Step: 32080 loss: 0.11631
Step: 32090 loss: 0.164528
Step: 32100 loss: 0.200584
Step: 32110 loss: 0.339207
Step: 32120 loss: 0.248778
Step: 32130 loss: 0.259998
Step: 32140 loss: 0.135131
Step: 32150 loss: 0.120505
Step: 32160 loss: 0.401769
Step: 32170 loss: 0.189264
Step: 32180 loss: 0.489604
Step: 32190 loss: 0.912464
Step: 32200 loss: 0.122832
Step: 32210 loss: 0.113877
Step: 32220 loss: 0.12392
Step: 32230 loss: 0.210925
Step: 32240 loss: 0.828158
Step: 32250 loss: 0.235346
Step: 32260 loss: 0.141412
Step: 32270 loss: 0.336016
Step: 32280 loss: 0.220871
Step: 32290 loss: 0.125652
Step: 32300 loss: 0.226214
Step: 32310 loss: 0.489488
Step: 32320 loss: 0.114724
Step: 32330 loss: 0.10684
Step: 32340 loss: 0.363203
Step: 32350 loss: 0.700862
Step: 32360 loss: 0.188156
Step: 32370 loss: 0.10889
Step:

Step: 35050 loss: 0.109522
Step: 35060 loss: 0.173212
Step: 35070 loss: 0.114111
Step: 35080 loss: 0.416148
Step: 35090 loss: 0.540476
Step: 35100 loss: 0.254722
Step: 35110 loss: 0.227785
Step: 35120 loss: 0.118375
Step: 35130 loss: 0.356484
Step: 35140 loss: 0.359755
Step: 35150 loss: 0.121271
Step: 35160 loss: 0.503607
Step: 35170 loss: 0.189824
Step: 35180 loss: 0.619622
Step: 35190 loss: 0.124397
Step: 35200 loss: 0.388573
Step: 35210 loss: 0.356752
Step: 35220 loss: 0.343743
Step: 35230 loss: 0.535166
Step: 35240 loss: 0.294901
Step: 35250 loss: 0.259709
Step: 35260 loss: 0.125196
Step: 35270 loss: 0.198215
Step: 35280 loss: 0.120735
Step: 35290 loss: 0.187133
Step: 35300 loss: 0.204983
Step: 35310 loss: 0.190654
Step: 35320 loss: 0.229639
Step: 35330 loss: 0.473693
Step: 35340 loss: 0.116598
Step: 35350 loss: 0.279773
Step: 35360 loss: 0.150126
Step: 35370 loss: 0.138339
Step: 35380 loss: 0.639304
Step: 35390 loss: 0.113349
Step: 35400 loss: 0.131463
Step: 35410 loss: 0.125274
S

Step: 38090 loss: 0.175709
Step: 38100 loss: 0.287245
Step: 38110 loss: 0.193154
Step: 38120 loss: 0.130799
Step: 38130 loss: 0.155663
Step: 38140 loss: 0.139956
Step: 38150 loss: 0.509048
Step: 38160 loss: 0.19132
Step: 38170 loss: 0.158713
Step: 38180 loss: 0.435439
Step: 38190 loss: 0.284737
Step: 38200 loss: 0.136775
Step: 38210 loss: 0.110083
Step: 38220 loss: 0.132664
Step: 38230 loss: 0.132588
Step: 38240 loss: 0.109532
Step: 38250 loss: 0.206175
Step: 38260 loss: 0.109232
Step: 38270 loss: 0.219585
Step: 38280 loss: 0.223173
Step: 38290 loss: 0.231729
Step: 38300 loss: 0.108458
Step: 38310 loss: 0.210494
Step: 38320 loss: 0.234982
Step: 38330 loss: 0.565062
Step: 38340 loss: 0.950462
Step: 38350 loss: 0.161208
Step: 38360 loss: 0.292838
Step: 38370 loss: 0.324109
Step: 38380 loss: 0.332624
Step: 38390 loss: 0.2493
Step: 38400 loss: 0.465916
Step: 38410 loss: 0.19226
Step: 38420 loss: 0.189594
Step: 38430 loss: 0.205738
Step: 38440 loss: 0.578031
Step: 38450 loss: 0.230732
Step:

Step: 41130 loss: 0.196381
Step: 41140 loss: 0.226283
Step: 41150 loss: 0.189477
Step: 41160 loss: 0.109256
Step: 41170 loss: 0.160071
Step: 41180 loss: 0.197342
Step: 41190 loss: 0.135026
Step: 41200 loss: 0.270334
Step: 41210 loss: 0.109521
Step: 41220 loss: 0.134346
Step: 41230 loss: 0.117708
Step: 41240 loss: 0.102373
Step: 41250 loss: 0.117262
Step: 41260 loss: 0.183646
Step: 41270 loss: 0.197597
Step: 41280 loss: 0.259668
Step: 41290 loss: 0.167712
Step: 41300 loss: 0.107362
Step: 41310 loss: 0.15925
Step: 41320 loss: 0.163347
Step: 41330 loss: 0.175903
Step: 41340 loss: 0.168268
Step: 41350 loss: 0.104662
Step: 41360 loss: 0.127753
Step: 41370 loss: 0.134339
Step: 41380 loss: 0.105138
Step: 41390 loss: 0.11211
Step: 41400 loss: 0.168183
Step: 41410 loss: 0.10295
Step: 41420 loss: 0.117025
Step: 41430 loss: 0.116568
Step: 41440 loss: 0.650403
Step: 41450 loss: 0.123113
Step: 41460 loss: 0.138684
Step: 41470 loss: 0.226205
Step: 41480 loss: 0.107843
Step: 41490 loss: 0.106563
Step

Step: 44170 loss: 0.2035
Step: 44180 loss: 0.15654
Step: 44190 loss: 0.104065
Step: 44200 loss: 0.111314
Step: 44210 loss: 0.101138
Step: 44220 loss: 0.106676
Step: 44230 loss: 0.268077
Step: 44240 loss: 0.135465
Step: 44250 loss: 0.109786
Step: 44260 loss: 0.20021
Step: 44270 loss: 0.104894
Step: 44280 loss: 0.287074
Step: 44290 loss: 0.141661
Step: 44300 loss: 0.129823
Step: 44310 loss: 0.108479
Step: 44320 loss: 0.155135
Step: 44330 loss: 0.136926
Step: 44340 loss: 0.109277
Step: 44350 loss: 0.131163
Step: 44360 loss: 0.105929
Step: 44370 loss: 0.110991
Step: 44380 loss: 0.206221
Step: 44390 loss: 0.303211
Step: 44400 loss: 0.116069
Step: 44410 loss: 0.363999
Step: 44420 loss: 0.256382
Step: 44430 loss: 0.110059
Step: 44440 loss: 0.110923
Step: 44450 loss: 0.108623
Step: 44460 loss: 0.102245
Step: 44470 loss: 0.81267
Step: 44480 loss: 0.363868
Step: 44490 loss: 0.101172
Step: 44500 loss: 0.158603
Step: 44510 loss: 0.115151
Step: 44520 loss: 0.144479
Step: 44530 loss: 0.218534
Step: 

Step: 47210 loss: 0.112841
Step: 47220 loss: 0.641932
Step: 47230 loss: 0.261087
Step: 47240 loss: 0.240021
Step: 47250 loss: 0.211281
Step: 47260 loss: 0.100937
Step: 47270 loss: 0.108555
Step: 47280 loss: 0.105594
Step: 47290 loss: 0.103131
Step: 47300 loss: 0.139107
Step: 47310 loss: 0.109371
Step: 47320 loss: 0.418391
Step: 47330 loss: 0.0988733
Step: 47340 loss: 0.167312
Step: 47350 loss: 0.17647
Step: 47360 loss: 0.14746
Step: 47370 loss: 0.219458
Step: 47380 loss: 0.475655
Step: 47390 loss: 0.111831
Step: 47400 loss: 0.113887
Step: 47410 loss: 0.0987209
Step: 47420 loss: 0.184824
Step: 47430 loss: 0.175491
Step: 47440 loss: 0.797833
Step: 47450 loss: 0.103924
Step: 47460 loss: 0.184284
Step: 47470 loss: 0.179754
Step: 47480 loss: 0.230305
Step: 47490 loss: 0.115631
Step: 47500 loss: 0.097848
Step: 47510 loss: 0.305652
Step: 47520 loss: 0.122897
Step: 47530 loss: 0.117728
Step: 47540 loss: 0.112769
Step: 47550 loss: 0.310324
Step: 47560 loss: 0.0983254
Step: 47570 loss: 0.122028


In [32]:
# training accuracy
all_predictions = []
for j in range(num_batches):
    path_dict = [path1_len[j*batch_size:(j+1)*batch_size], path2_len[j*batch_size:(j+1)*batch_size]]
    word_dict = [word_p1_ids[j*batch_size:(j+1)*batch_size], word_p2_ids[j*batch_size:(j+1)*batch_size]]
    pos_dict = [pos_p1_ids[j*batch_size:(j+1)*batch_size], pos_p2_ids[j*batch_size:(j+1)*batch_size]]
    dep_dict = [dep_p1_ids[j*batch_size:(j+1)*batch_size], dep_p2_ids[j*batch_size:(j+1)*batch_size]]
    y_dict = rel_ids[j*batch_size:(j+1)*batch_size]

    feed_dict = {
        path_length:path_dict,
        word_ids:word_dict,
        pos_ids:pos_dict,
        dep_ids:dep_dict,
        y:y_dict}
    batch_predictions = sess.run(predictions, feed_dict)
    all_predictions.append(batch_predictions)

y_pred = []
for i in range(num_batches):
    for pred in all_predictions[i]:
        y_pred.append(pred)

count = 0
for i in range(batch_size*num_batches):
    count += y_pred[i]==rel_ids[i]
accuracy = count/(batch_size*num_batches) * 100

print("training accuracy", accuracy)

training accuracy 96.075


In [33]:
f = open(data_dir + '/test_paths', 'rb')
word_p1, word_p2, dep_p1, dep_p2, pos_p1, pos_p2 = pickle.load(f)
f.close()

relations = []
for line in open(data_dir + '/test_relations.txt'):
    relations.append(line.strip().split()[0])

length = len(word_p1)
num_batches = int(length/batch_size)

for i in range(length):
    for j, word in enumerate(word_p1[i]):
        word = word.lower()
        word_p1[i][j] = word if word in word2id else unknown_token 
    for k, word in enumerate(word_p2[i]):
        word = word.lower()
        word_p2[i][k] = word if word in word2id else unknown_token 
    for l, d in enumerate(dep_p1[i]):
        dep_p1[i][l] = d if d in dep2id else 'OTH'
    for m, d in enumerate(dep_p2[i]):
        dep_p2[i][m] = d if d in dep2id else 'OTH'

word_p1_ids = np.ones([length, max_len_path],dtype=int)
word_p2_ids = np.ones([length, max_len_path],dtype=int)
pos_p1_ids = np.ones([length, max_len_path],dtype=int)
pos_p2_ids = np.ones([length, max_len_path],dtype=int)
dep_p1_ids = np.ones([length, max_len_path],dtype=int)
dep_p2_ids = np.ones([length, max_len_path],dtype=int)
rel_ids = np.array([rel2id[rel] for rel in relations])
path1_len = np.array([len(w) for w in word_p1], dtype=int)
path2_len = np.array([len(w) for w in word_p2])

for i in range(length):
    for j, w in enumerate(word_p1[i]):
        word_p1_ids[i][j] = word2id[w]
    for j, w in enumerate(word_p2[i]):
        word_p2_ids[i][j] = word2id[w]
    for j, w in enumerate(pos_p1[i]):
        pos_p1_ids[i][j] = pos_tag(w)
    for j, w in enumerate(pos_p2[i]):
        pos_p2_ids[i][j] = pos_tag(w)
    for j, w in enumerate(dep_p1[i]):
        dep_p1_ids[i][j] = dep2id[w]
    for j, w in enumerate(dep_p2[i]):
        dep_p2_ids[i][j] = dep2id[w]

In [34]:
# test 
all_predictions = []
for j in range(num_batches):
    path_dict = [path1_len[j*batch_size:(j+1)*batch_size], path2_len[j*batch_size:(j+1)*batch_size]]
    word_dict = [word_p1_ids[j*batch_size:(j+1)*batch_size], word_p2_ids[j*batch_size:(j+1)*batch_size]]
    pos_dict = [pos_p1_ids[j*batch_size:(j+1)*batch_size], pos_p2_ids[j*batch_size:(j+1)*batch_size]]
    dep_dict = [dep_p1_ids[j*batch_size:(j+1)*batch_size], dep_p2_ids[j*batch_size:(j+1)*batch_size]]
    y_dict = rel_ids[j*batch_size:(j+1)*batch_size]

    feed_dict = {
        path_length:path_dict,
        word_ids:word_dict,
        pos_ids:pos_dict,
        dep_ids:dep_dict,
        y:y_dict}
    batch_predictions = sess.run(predictions, feed_dict)
    all_predictions.append(batch_predictions)

y_pred = []
for i in range(num_batches):
    for pred in all_predictions[i]:
        y_pred.append(pred)

count = 0
for i in range(batch_size*num_batches):
    count += y_pred[i]==rel_ids[i]
accuracy = count/(batch_size*num_batches) * 100

print("test accuracy", accuracy)

test accuracy 63.2472324723
