In [1]:
import tensorflow as tf
import numpy as np
import pprint

import os, sys
sys.path.append(os.path.dirname(os.getcwd()))

from data import WN18
from model.metrics import evaluate_rank

In [2]:
class Config:
    seed = 21
    n_epochs = 10
    batch_size = 128
    embed_dim = 150

In [3]:
"""
e: entity
s: subject
p: predicate
o: object
"""

def read_triples(path):
    triples = []
    with open(path, 'rt') as f:
        for line in f.readlines():
            s, p, o = line.split()
            triples += [(s.strip(), p.strip(), o.strip())]
    return triples


def load_triple():
    WN18.download()
    triples_tr = read_triples('../data/WN18/wn18/train.txt')
    triples_va = read_triples('../data/WN18/wn18/valid.txt')
    triples_te = read_triples('../data/WN18/wn18/test.txt')
    
    triples_all = triples_tr + triples_va + triples_te
    
    return triples_all, triples_tr, triples_va, triples_te


def build_vocab(triples):
    params = {}
    
    e_set = {s for (s, p, o) in triples} | {o for (s, p, o) in triples}
    p_set = {p for (s, p, o) in triples}

    params['e_vocab_size'] = len(e_set)
    params['p_vocab_size'] = len(p_set)

    e2idx = {e: idx for idx, e in enumerate(sorted(e_set))}
    p2idx = {p: idx for idx, p in enumerate(sorted(p_set))}
    
    return e2idx, p2idx, params


def build_train_data(triples_tr, e2idx, p2idx):
    x_s = np.array([e2idx[s] for (s, p, o) in triples_tr], dtype=np.int32)
    x_p = np.array([p2idx[p] for (s, p, o) in triples_tr], dtype=np.int32)
    x_o = np.array([e2idx[o] for (s, p, o) in triples_tr], dtype=np.int32)

    x = {'s': x_s,
         'p': x_p,
         'o': x_o}
    y = np.ones([len(x_s)], dtype=np.float32)
    
    return x, y


def train_input_fn(triples_tr, e2idx, p2idx, random_state, params):
    x, y = build_train_data(triples_tr, e2idx, p2idx)
    s, p, o = x['s'], x['p'], x['o']
    
    s_ = random_state.choice(params['e_vocab_size'], s.shape)
    o_ = random_state.choice(params['e_vocab_size'], o.shape)
    
    x_ = {
        's': np.concatenate([s, s_, s]),
        'p': np.concatenate([p, p, p]),
        'o': np.concatenate([o, o, o_])}
    y_ = np.concatenate([y, np.zeros([2*len(y)], dtype=np.float32)])
    
    return tf.estimator.inputs.numpy_input_fn(x = x_,
                                              y = y_,
                                              batch_size = Config.batch_size,
                                              num_epochs = 1,
                                              shuffle = True)

In [4]:
def forward(features, params):
    e_embed = tf.get_variable('e_embed',
                              [params['e_vocab_size'], Config.embed_dim],
                              initializer=tf.contrib.layers.xavier_initializer())
    p_embed = tf.get_variable('p_embed',
                              [params['p_vocab_size'], Config.embed_dim],
                              initializer=tf.contrib.layers.xavier_initializer())
    
    s = tf.nn.embedding_lookup(e_embed, features['s'])
    p = tf.nn.embedding_lookup(p_embed, features['p'])
    o = tf.nn.embedding_lookup(e_embed, features['o'])
    
    logits = tf.reduce_sum(s*p*o, axis=1)
    
    return logits
    
    
def model_fn(features, labels, mode, params):
    logits = forward(features, params)
    
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.logging.info('\n'+pprint.pformat(tf.trainable_variables()))
        
        loss_op = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                         labels=labels))
        
        train_op = tf.train.AdamOptimizer().minimize(
            loss_op, global_step=tf.train.get_global_step())
        
        return tf.estimator.EstimatorSpec(mode = mode,
                                          loss = loss_op,
                                          train_op = train_op)
    
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions = tf.sigmoid(logits))

In [5]:
random_state = np.random.RandomState(Config.seed)
triples_all, triples_tr, triples_va, triples_te = load_triple()
e2idx, p2idx, params = build_vocab(triples_all)

model = tf.estimator.Estimator(model_fn,
                               params = params)

for _ in range(Config.n_epochs):
    model.train(train_input_fn(triples_tr,
                               e2idx,
                               p2idx,
                               random_state,
                               params))
evaluate_rank(model,
              triples_va,
              triples_te,
              triples_all,
              e2idx,
              p2idx,
              params['e_vocab_size'],
              batch_size = 1000)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x10fb7aa58>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 150) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 150) dtype=float32_ref>]
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
I

INFO:tensorflow:loss = 0.71231425, step = 6117 (3.325 sec)
INFO:tensorflow:global_step/sec: 29.4602
INFO:tensorflow:loss = 0.71152484, step = 6217 (3.395 sec)
INFO:tensorflow:global_step/sec: 24.8957
INFO:tensorflow:loss = 0.7345394, step = 6317 (4.017 sec)
INFO:tensorflow:global_step/sec: 31.1505
INFO:tensorflow:loss = 0.7126507, step = 6417 (3.210 sec)
INFO:tensorflow:global_step/sec: 30.8777
INFO:tensorflow:loss = 0.687209, step = 6517 (3.239 sec)
INFO:tensorflow:global_step/sec: 31.8546
INFO:tensorflow:loss = 0.7205084, step = 6617 (3.139 sec)
INFO:tensorflow:Saving checkpoints for 6632 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt.
INFO:tensorflow:Loss for final step: 0.8112661.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 150) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 150) dtype=float32_ref>]
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow

INFO:tensorflow:global_step/sec: 35.6059
INFO:tensorflow:loss = 0.6651845, step = 12749 (2.809 sec)
INFO:tensorflow:global_step/sec: 32.5993
INFO:tensorflow:loss = 0.59930086, step = 12849 (3.067 sec)
INFO:tensorflow:global_step/sec: 31.0543
INFO:tensorflow:loss = 0.6423048, step = 12949 (3.220 sec)
INFO:tensorflow:global_step/sec: 32.3578
INFO:tensorflow:loss = 0.6690872, step = 13049 (3.091 sec)
INFO:tensorflow:global_step/sec: 35.7813
INFO:tensorflow:loss = 0.6308319, step = 13149 (2.795 sec)
INFO:tensorflow:global_step/sec: 28.9538
INFO:tensorflow:loss = 0.5923948, step = 13249 (3.454 sec)
INFO:tensorflow:Saving checkpoints for 13264 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt.
INFO:tensorflow:Loss for final step: 0.47123444.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 150) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 150) dtype=float32_ref>]
INFO:tensorflow:Done calling model_fn.
INFO:tenso

INFO:tensorflow:global_step/sec: 32.2587
INFO:tensorflow:loss = 0.1296507, step = 19281 (3.100 sec)
INFO:tensorflow:global_step/sec: 32.1246
INFO:tensorflow:loss = 0.093185276, step = 19381 (3.113 sec)
INFO:tensorflow:global_step/sec: 32.118
INFO:tensorflow:loss = 0.110508494, step = 19481 (3.113 sec)
INFO:tensorflow:global_step/sec: 32.1403
INFO:tensorflow:loss = 0.07584869, step = 19581 (3.111 sec)
INFO:tensorflow:global_step/sec: 32.3941
INFO:tensorflow:loss = 0.082338914, step = 19681 (3.087 sec)
INFO:tensorflow:global_step/sec: 32.3478
INFO:tensorflow:loss = 0.06519491, step = 19781 (3.092 sec)
INFO:tensorflow:global_step/sec: 28.4546
INFO:tensorflow:loss = 0.0811238, step = 19881 (3.514 sec)
INFO:tensorflow:Saving checkpoints for 19896 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt.
INFO:tensorflow:Loss for final step: 0.04960853.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 150) dtype=float32_ref>,
 <tf.

INFO:tensorflow:loss = 0.019847065, step = 25713 (3.272 sec)
INFO:tensorflow:global_step/sec: 35.6222
INFO:tensorflow:loss = 0.013405206, step = 25813 (2.807 sec)
INFO:tensorflow:global_step/sec: 34.7899
INFO:tensorflow:loss = 0.014136009, step = 25913 (2.874 sec)
INFO:tensorflow:global_step/sec: 25.3717
INFO:tensorflow:loss = 0.022028655, step = 26013 (3.942 sec)
INFO:tensorflow:global_step/sec: 35.1386
INFO:tensorflow:loss = 0.009318233, step = 26113 (2.845 sec)
INFO:tensorflow:global_step/sec: 36.187
INFO:tensorflow:loss = 0.00789922, step = 26213 (2.763 sec)
INFO:tensorflow:global_step/sec: 35.3425
INFO:tensorflow:loss = 0.0077727344, step = 26313 (2.830 sec)
INFO:tensorflow:global_step/sec: 35.0371
INFO:tensorflow:loss = 0.006402545, step = 26413 (2.854 sec)
INFO:tensorflow:global_step/sec: 34.5593
INFO:tensorflow:loss = 0.010766362, step = 26513 (2.894 sec)
INFO:tensorflow:Saving checkpoints for 26528 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt.
I

INFO:tensorflow:global_step/sec: 33.2216
INFO:tensorflow:loss = 0.0016217246, step = 32145 (3.010 sec)
INFO:tensorflow:global_step/sec: 34.0003
INFO:tensorflow:loss = 0.0036441355, step = 32245 (2.941 sec)
INFO:tensorflow:global_step/sec: 30.3426
INFO:tensorflow:loss = 0.000829707, step = 32345 (3.296 sec)
INFO:tensorflow:global_step/sec: 31.3272
INFO:tensorflow:loss = 0.00053780444, step = 32445 (3.192 sec)
INFO:tensorflow:global_step/sec: 34.7413
INFO:tensorflow:loss = 0.00032626322, step = 32545 (2.878 sec)
INFO:tensorflow:global_step/sec: 34.6219
INFO:tensorflow:loss = 0.01333243, step = 32645 (2.889 sec)
INFO:tensorflow:global_step/sec: 24.5971
INFO:tensorflow:loss = 0.0043383376, step = 32745 (4.065 sec)
INFO:tensorflow:global_step/sec: 27.1304
INFO:tensorflow:loss = 0.004561611, step = 32845 (3.686 sec)
INFO:tensorflow:global_step/sec: 29.828
INFO:tensorflow:loss = 0.017407393, step = 32945 (3.352 sec)
INFO:tensorflow:global_step/sec: 33.2604
INFO:tensorflow:loss = 0.011607796, 

100%|█████████████████████████████| 5000/5000 [02:05<00:00, 39.93it/s]

INFO:tensorflow:Calling model_fn.





INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt-33160
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:06<00:00, 39.55it/s]
100%|█████████████████████████████| 5000/5000 [04:02<00:00, 20.65it/s]


[valid] Raw Mean Rank: 901.5622
[valid] Raw Hits@1: 41.33
[valid] Raw Hits@3: 64.9
[valid] Raw Hits@5: 73.22
[valid] Raw Hits@10: 80.78
[valid] Filtered Mean Rank: 889.8345
[valid] Filtered Hits@1: 65.35
[valid] Filtered Hits@3: 89.88000000000001
[valid] Filtered Hits@5: 92.2
[valid] Filtered Hits@10: 93.63
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt-33160
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:05<00:00, 39.80it/s]


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpzsaed428/model.ckpt-33160
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:05<00:00, 39.90it/s]
100%|█████████████████████████████| 5000/5000 [04:03<00:00, 20.57it/s]


[test] Raw Mean Rank: 856.6072
[test] Raw Hits@1: 41.15
[test] Raw Hits@3: 64.96
[test] Raw Hits@5: 73.92
[test] Raw Hits@10: 81.14
[test] Filtered Mean Rank: 844.9209
[test] Filtered Hits@1: 65.73
[test] Filtered Hits@3: 89.99000000000001
[test] Filtered Hits@5: 92.34
[test] Filtered Hits@10: 93.8
