In [1]:
import tensorflow as tf
import numpy as np

import os, sys
sys.path.append(os.path.dirname(os.getcwd()))

from data import WN18
from model.distmult import model_fn
from model.metrics import evaluate_rank

In [2]:
class Config:
    seed = 21
    n_epochs = 10
    batch_size = 100
    embed_dim = 150

In [3]:
"""
e: entity
s: subject
p: predicate
o: object
"""

def read_triples(path):
    triples = []
    with open(path, 'rt') as f:
        for line in f.readlines():
            s, p, o = line.split()
            triples += [(s.strip(), p.strip(), o.strip())]
    return triples


def load_triple():
    WN18.download()
    triples_tr = read_triples('../data/WN18/wn18/train.txt')
    triples_va = read_triples('../data/WN18/wn18/valid.txt')
    triples_te = read_triples('../data/WN18/wn18/test.txt')
    
    triples_all = triples_tr + triples_va + triples_te
    
    return triples_all, triples_tr, triples_va, triples_te


def build_vocab(triples):
    params = {}
    
    e_set = {s for (s, p, o) in triples} | {o for (s, p, o) in triples}
    p_set = {p for (s, p, o) in triples}

    params['e_vocab_size'] = len(e_set)
    params['p_vocab_size'] = len(p_set)

    e2idx = {e: idx for idx, e in enumerate(sorted(e_set))}
    p2idx = {p: idx for idx, p in enumerate(sorted(p_set))}
    
    return e2idx, p2idx, params


def build_train_data(triples_tr, e2idx, p2idx):
    x_s = np.array([e2idx[s] for (s, p, o) in triples_tr], dtype=np.int32)
    x_p = np.array([p2idx[p] for (s, p, o) in triples_tr], dtype=np.int32)
    x_o = np.array([e2idx[o] for (s, p, o) in triples_tr], dtype=np.int32)

    x = {'s': x_s,
         'p': x_p,
         'o': x_o}
    y = np.ones([len(x_s)], dtype=np.float32)
    
    return x, y


def train_input_fn(triples_tr, e2idx, p2idx, random_state, params):
    x, y = build_train_data(triples_tr, e2idx, p2idx)
    s, p, o = x['s'], x['p'], x['o']
    
    s_ = random_state.choice(params['e_vocab_size'], s.shape)
    o_ = random_state.choice(params['e_vocab_size'], o.shape)
    
    x_ = {
        's': np.concatenate([s, s_, s]),
        'p': np.concatenate([p, p, p]),
        'o': np.concatenate([o, o, o_])}
    y_ = np.concatenate([y, np.zeros([2*len(y)], dtype=np.float32)])
    
    return tf.estimator.inputs.numpy_input_fn(x = x_,
                                              y = y_,
                                              batch_size = Config.batch_size,
                                              num_epochs = 1,
                                              shuffle = True)

In [4]:
random_state = np.random.RandomState(Config.seed)
triples_all, triples_tr, triples_va, triples_te = load_triple()
e2idx, p2idx, params = build_vocab(triples_all)
params['embed_dim'] = Config.embed_dim

model = tf.estimator.Estimator(model_fn,
                               params = params)

for _ in range(Config.n_epochs):
    model.train(train_input_fn(triples_tr,
                               e2idx,
                               p2idx,
                               random_state,
                               params))
evaluate_rank(model,
              triples_va,
              triples_te,
              triples_all,
              e2idx,
              p2idx,
              params['e_vocab_size'],
              batch_size = 1000)

Files Already Downloaded
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x115ed12b0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 150) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 150) dtype=float32_ref>]
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorfl

INFO:tensorflow:loss = 0.7309796, step = 6145 (3.420 sec)
INFO:tensorflow:global_step/sec: 27.7062
INFO:tensorflow:loss = 0.715828, step = 6245 (3.609 sec)
INFO:tensorflow:global_step/sec: 33.4408
INFO:tensorflow:loss = 0.702667, step = 6345 (2.990 sec)
INFO:tensorflow:global_step/sec: 33.0134
INFO:tensorflow:loss = 0.71627426, step = 6445 (3.029 sec)
INFO:tensorflow:global_step/sec: 29.2175
INFO:tensorflow:loss = 0.71140045, step = 6545 (3.423 sec)
INFO:tensorflow:global_step/sec: 26.0775
INFO:tensorflow:loss = 0.73570246, step = 6645 (3.835 sec)
INFO:tensorflow:global_step/sec: 25.5713
INFO:tensorflow:loss = 0.70175844, step = 6745 (3.911 sec)
INFO:tensorflow:global_step/sec: 33.4815
INFO:tensorflow:loss = 0.70875883, step = 6845 (2.987 sec)
INFO:tensorflow:global_step/sec: 33.7086
INFO:tensorflow:loss = 0.7222971, step = 6945 (2.967 sec)
INFO:tensorflow:global_step/sec: 33.9772
INFO:tensorflow:loss = 0.7222709, step = 7045 (2.943 sec)
INFO:tensorflow:global_step/sec: 33.8148
INFO:te

INFO:tensorflow:loss = 0.70917124, step = 12733
INFO:tensorflow:global_step/sec: 32.8713
INFO:tensorflow:loss = 0.80611855, step = 12833 (3.044 sec)
INFO:tensorflow:global_step/sec: 30.5916
INFO:tensorflow:loss = 0.7168196, step = 12933 (3.270 sec)
INFO:tensorflow:global_step/sec: 25.4284
INFO:tensorflow:loss = 0.7278025, step = 13033 (3.931 sec)
INFO:tensorflow:global_step/sec: 25.8299
INFO:tensorflow:loss = 0.67771417, step = 13133 (3.872 sec)
INFO:tensorflow:global_step/sec: 28.5935
INFO:tensorflow:loss = 0.68729264, step = 13233 (3.497 sec)
INFO:tensorflow:global_step/sec: 33.698
INFO:tensorflow:loss = 0.71599364, step = 13333 (2.968 sec)
INFO:tensorflow:global_step/sec: 27.7952
INFO:tensorflow:loss = 0.7631733, step = 13433 (3.598 sec)
INFO:tensorflow:global_step/sec: 30.2434
INFO:tensorflow:loss = 0.70607597, step = 13533 (3.306 sec)
INFO:tensorflow:global_step/sec: 34.5401
INFO:tensorflow:loss = 0.67985797, step = 13633 (2.895 sec)
INFO:tensorflow:global_step/sec: 29.7098
INFO:t

INFO:tensorflow:global_step/sec: 30.4727
INFO:tensorflow:loss = 0.08664522, step = 20177 (3.281 sec)
INFO:tensorflow:global_step/sec: 32.6971
INFO:tensorflow:loss = 0.060435828, step = 20277 (3.059 sec)
INFO:tensorflow:global_step/sec: 33.1698
INFO:tensorflow:loss = 0.05628078, step = 20377 (3.015 sec)
INFO:tensorflow:global_step/sec: 24.7392
INFO:tensorflow:loss = 0.048502926, step = 20477 (4.042 sec)
INFO:tensorflow:global_step/sec: 24.4542
INFO:tensorflow:loss = 0.03744556, step = 20577 (4.089 sec)
INFO:tensorflow:global_step/sec: 24.3119
INFO:tensorflow:loss = 0.030374315, step = 20677 (4.113 sec)
INFO:tensorflow:global_step/sec: 25.8419
INFO:tensorflow:loss = 0.028963072, step = 20777 (3.870 sec)
INFO:tensorflow:global_step/sec: 24.9974
INFO:tensorflow:loss = 0.0214863, step = 20877 (4.000 sec)
INFO:tensorflow:global_step/sec: 24.4741
INFO:tensorflow:loss = 0.022748737, step = 20977 (4.086 sec)
INFO:tensorflow:global_step/sec: 23.8252
INFO:tensorflow:loss = 0.014563782, step = 210

INFO:tensorflow:global_step/sec: 36.623
INFO:tensorflow:loss = 0.0038815173, step = 26665 (2.730 sec)
INFO:tensorflow:global_step/sec: 36.64
INFO:tensorflow:loss = 0.0022108175, step = 26765 (2.729 sec)
INFO:tensorflow:global_step/sec: 36.4452
INFO:tensorflow:loss = 0.07114625, step = 26865 (2.744 sec)
INFO:tensorflow:global_step/sec: 34.2764
INFO:tensorflow:loss = 0.11766375, step = 26965 (2.917 sec)
INFO:tensorflow:global_step/sec: 31.1598
INFO:tensorflow:loss = 0.0900916, step = 27065 (3.209 sec)
INFO:tensorflow:global_step/sec: 28.3586
INFO:tensorflow:loss = 0.08912421, step = 27165 (3.526 sec)
INFO:tensorflow:global_step/sec: 28.7163
INFO:tensorflow:loss = 0.09885471, step = 27265 (3.482 sec)
INFO:tensorflow:global_step/sec: 33.4292
INFO:tensorflow:loss = 0.098865576, step = 27365 (2.993 sec)
INFO:tensorflow:global_step/sec: 31.7302
INFO:tensorflow:loss = 0.04963805, step = 27465 (3.152 sec)
INFO:tensorflow:global_step/sec: 31.9464
INFO:tensorflow:loss = 0.089617625, step = 27565 

INFO:tensorflow:loss = 0.012641861, step = 33909 (2.797 sec)
INFO:tensorflow:Saving checkpoints for 33952 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua/model.ckpt.
INFO:tensorflow:Loss for final step: 0.012894905.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 150) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 150) dtype=float32_ref>]
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua/model.ckpt-33952
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 33953 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua/model.ckpt.
INFO:tensorflow:loss = 0.053680774, step = 33953
INFO:tensorflow:global_step/sec: 34.8857
INFO:tensorflow:loss = 0.0066638505, step = 34053 (2

INFO:tensorflow:global_step/sec: 36.4487
INFO:tensorflow:loss = 0.057505332, step = 40397 (2.744 sec)
INFO:tensorflow:global_step/sec: 36.328
INFO:tensorflow:loss = 0.0057719066, step = 40497 (2.753 sec)
INFO:tensorflow:global_step/sec: 36.3775
INFO:tensorflow:loss = 0.088864565, step = 40597 (2.749 sec)
INFO:tensorflow:global_step/sec: 36.244
INFO:tensorflow:loss = 0.001446834, step = 40697 (2.759 sec)
INFO:tensorflow:global_step/sec: 36.1255
INFO:tensorflow:loss = 0.0022569206, step = 40797 (2.768 sec)
INFO:tensorflow:global_step/sec: 33.6195
INFO:tensorflow:loss = 0.0006243239, step = 40897 (2.974 sec)
INFO:tensorflow:global_step/sec: 28.5796
INFO:tensorflow:loss = 0.001350181, step = 40997 (3.499 sec)
INFO:tensorflow:global_step/sec: 34.2803
INFO:tensorflow:loss = 0.0034266622, step = 41097 (2.917 sec)
INFO:tensorflow:global_step/sec: 34.3978
INFO:tensorflow:loss = 0.0007036687, step = 41197 (2.908 sec)
INFO:tensorflow:global_step/sec: 31.9801
INFO:tensorflow:loss = 0.00031730323, 

100%|█████████████████████████████| 5000/5000 [02:15<00:00, 36.98it/s]

INFO:tensorflow:Calling model_fn.





INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:07<00:00, 39.12it/s]
100%|█████████████████████████████| 5000/5000 [04:07<00:00, 20.22it/s]

[valid] Raw Mean Rank: 781.3609
[valid] Raw Hits@1: 40.46
[valid] Raw Hits@3: 64.36
[valid] Raw Hits@5: 73.18
[valid] Raw Hits@10: 80.9
[valid] Filtered Mean Rank: 769.7638
[valid] Filtered Hits@1: 61.22
[valid] Filtered Hits@3: 86.42
[valid] Filtered Hits@5: 90.5
[valid] Filtered Hits@10: 92.88
INFO:tensorflow:Calling model_fn.





INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:07<00:00, 39.22it/s]


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_ud__2ua/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:08<00:00, 39.06it/s]
100%|█████████████████████████████| 5000/5000 [03:54<00:00, 21.34it/s]


[test] Raw Mean Rank: 700.2553
[test] Raw Hits@1: 40.45
[test] Raw Hits@3: 64.42999999999999
[test] Raw Hits@5: 73.32
[test] Raw Hits@10: 81.37
[test] Filtered Mean Rank: 688.6422
[test] Filtered Hits@1: 61.58
[test] Filtered Hits@3: 86.83999999999999
[test] Filtered Hits@5: 90.92
[test] Filtered Hits@10: 93.05
