In [1]:
import tensorflow as tf
import numpy as np

import os, sys
sys.path.append(os.path.dirname(os.getcwd()))

from data import WN18
from model.distmult import model_fn
from model.metrics import evaluate_rank

In [2]:
class Config:
    seed = 21
    n_epochs = 10
    batch_size = 100
    embed_dim = 200

In [3]:
"""
e: entity
s: subject
p: predicate
o: object
"""

def read_triples(path):
    triples = []
    with open(path, 'rt') as f:
        for line in f.readlines():
            s, p, o = line.split()
            triples += [(s.strip(), p.strip(), o.strip())]
    return triples


def load_triple():
    WN18.download()
    triples_tr = read_triples('../data/WN18/wn18/train.txt')
    triples_va = read_triples('../data/WN18/wn18/valid.txt')
    triples_te = read_triples('../data/WN18/wn18/test.txt')
    
    triples_all = triples_tr + triples_va + triples_te
    
    return triples_all, triples_tr, triples_va, triples_te


def build_vocab(triples):
    params = {}
    
    e_set = {s for (s, p, o) in triples} | {o for (s, p, o) in triples}
    p_set = {p for (s, p, o) in triples}

    params['e_vocab_size'] = len(e_set)
    params['p_vocab_size'] = len(p_set)
    params['embed_dim'] = Config.embed_dim

    e2idx = {e: idx for idx, e in enumerate(sorted(e_set))}
    p2idx = {p: idx for idx, p in enumerate(sorted(p_set))}
    
    return e2idx, p2idx, params


def build_train_data(triples_tr, e2idx, p2idx):
    x_s = np.array([e2idx[s] for (s, p, o) in triples_tr], dtype=np.int32)
    x_p = np.array([p2idx[p] for (s, p, o) in triples_tr], dtype=np.int32)
    x_o = np.array([e2idx[o] for (s, p, o) in triples_tr], dtype=np.int32)

    x = {'s': x_s,
         'p': x_p,
         'o': x_o}
    y = np.ones([len(x_s)], dtype=np.float32)
    
    return x, y


def train_input_fn(triples_tr, e2idx, p2idx, random_state, params):
    x, y = build_train_data(triples_tr, e2idx, p2idx)
    s, p, o = x['s'], x['p'], x['o']
    
    s_ = random_state.choice(params['e_vocab_size'], s.shape)
    o_ = random_state.choice(params['e_vocab_size'], o.shape)
    
    x_ = {
        's': np.concatenate([s, s_, s]),
        'p': np.concatenate([p, p, p]),
        'o': np.concatenate([o, o, o_])}
    y_ = np.concatenate([y, np.zeros([2*len(y)], dtype=np.float32)])
    
    return tf.estimator.inputs.numpy_input_fn(x = x_,
                                              y = y_,
                                              batch_size = Config.batch_size,
                                              num_epochs = 1,
                                              shuffle = True)

In [4]:
random_state = np.random.RandomState(Config.seed)
triples_all, triples_tr, triples_va, triples_te = load_triple()
e2idx, p2idx, params = build_vocab(triples_tr)

model = tf.estimator.Estimator(model_fn,
                               params = params)

for _ in range(Config.n_epochs):
    model.train(train_input_fn(triples_tr,
                               e2idx,
                               p2idx,
                               random_state,
                               params))
evaluate_rank(model,
              triples_va,
              triples_te,
              triples_all,
              e2idx,
              p2idx,
              params['e_vocab_size'],
              batch_size = 10*Config.batch_size)

Files Already Downloaded
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpycangdvc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x114979588>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create Ch

INFO:tensorflow:global_step/sec: 18.7359
INFO:tensorflow:loss = 0.6939635, step = 6145 (5.337 sec)
INFO:tensorflow:global_step/sec: 20.9413
INFO:tensorflow:loss = 0.70438325, step = 6245 (4.776 sec)
INFO:tensorflow:global_step/sec: 18.0192
INFO:tensorflow:loss = 0.69270545, step = 6345 (5.549 sec)
INFO:tensorflow:global_step/sec: 22.6687
INFO:tensorflow:loss = 0.7047264, step = 6445 (4.411 sec)
INFO:tensorflow:global_step/sec: 20.4828
INFO:tensorflow:loss = 0.69245744, step = 6545 (4.883 sec)
INFO:tensorflow:global_step/sec: 16.2736
INFO:tensorflow:loss = 0.6952773, step = 6645 (6.144 sec)
INFO:tensorflow:global_step/sec: 17.1719
INFO:tensorflow:loss = 0.6920192, step = 6745 (5.824 sec)
INFO:tensorflow:global_step/sec: 20.4738
INFO:tensorflow:loss = 0.69314414, step = 6845 (4.884 sec)
INFO:tensorflow:global_step/sec: 19.0128
INFO:tensorflow:loss = 0.70809495, step = 6945 (5.260 sec)
INFO:tensorflow:global_step/sec: 17.8861
INFO:tensorflow:loss = 0.71810555, step = 7045 (5.590 sec)
INFO

INFO:tensorflow:loss = 0.59760755, step = 12733
INFO:tensorflow:global_step/sec: 21.1025
INFO:tensorflow:loss = 0.029754262, step = 12833 (4.740 sec)
INFO:tensorflow:global_step/sec: 21.1444
INFO:tensorflow:loss = 0.042501785, step = 12933 (4.729 sec)
INFO:tensorflow:global_step/sec: 21.5083
INFO:tensorflow:loss = 0.035414997, step = 13033 (4.649 sec)
INFO:tensorflow:global_step/sec: 20.8188
INFO:tensorflow:loss = 0.025310071, step = 13133 (4.803 sec)
INFO:tensorflow:global_step/sec: 22.1181
INFO:tensorflow:loss = 0.037494536, step = 13233 (4.521 sec)
INFO:tensorflow:global_step/sec: 22.012
INFO:tensorflow:loss = 0.03186948, step = 13333 (4.543 sec)
INFO:tensorflow:global_step/sec: 18.6341
INFO:tensorflow:loss = 0.043697115, step = 13433 (5.367 sec)
INFO:tensorflow:global_step/sec: 19.312
INFO:tensorflow:loss = 0.029997258, step = 13533 (5.177 sec)
INFO:tensorflow:global_step/sec: 19.6594
INFO:tensorflow:loss = 0.020515874, step = 13633 (5.087 sec)
INFO:tensorflow:global_step/sec: 21.8

INFO:tensorflow:loss = 0.17531095, step = 20077 (4.997 sec)
INFO:tensorflow:global_step/sec: 20.9194
INFO:tensorflow:loss = 0.17477173, step = 20177 (4.780 sec)
INFO:tensorflow:global_step/sec: 21.1411
INFO:tensorflow:loss = 0.099013254, step = 20277 (4.730 sec)
INFO:tensorflow:global_step/sec: 19.9933
INFO:tensorflow:loss = 0.120860964, step = 20377 (5.002 sec)
INFO:tensorflow:global_step/sec: 21.2127
INFO:tensorflow:loss = 0.08131825, step = 20477 (4.714 sec)
INFO:tensorflow:global_step/sec: 22.1502
INFO:tensorflow:loss = 0.0973268, step = 20577 (4.515 sec)
INFO:tensorflow:global_step/sec: 20.909
INFO:tensorflow:loss = 0.14144398, step = 20677 (4.783 sec)
INFO:tensorflow:global_step/sec: 21.1521
INFO:tensorflow:loss = 0.092243716, step = 20777 (4.729 sec)
INFO:tensorflow:global_step/sec: 21.6032
INFO:tensorflow:loss = 0.08280894, step = 20877 (4.627 sec)
INFO:tensorflow:global_step/sec: 20.3031
INFO:tensorflow:loss = 0.06306779, step = 20977 (4.927 sec)
INFO:tensorflow:global_step/se

INFO:tensorflow:global_step/sec: 23.9407
INFO:tensorflow:loss = 0.010356201, step = 26565 (4.177 sec)
INFO:tensorflow:global_step/sec: 23.846
INFO:tensorflow:loss = 0.0070259185, step = 26665 (4.194 sec)
INFO:tensorflow:global_step/sec: 23.7468
INFO:tensorflow:loss = 0.0074805934, step = 26765 (4.211 sec)
INFO:tensorflow:global_step/sec: 24.0204
INFO:tensorflow:loss = 0.0045488332, step = 26865 (4.163 sec)
INFO:tensorflow:global_step/sec: 23.9869
INFO:tensorflow:loss = 0.004188918, step = 26965 (4.169 sec)
INFO:tensorflow:global_step/sec: 23.9063
INFO:tensorflow:loss = 0.002104497, step = 27065 (4.183 sec)
INFO:tensorflow:global_step/sec: 24.0007
INFO:tensorflow:loss = 0.0026165517, step = 27165 (4.166 sec)
INFO:tensorflow:global_step/sec: 24.0515
INFO:tensorflow:loss = 0.0026309416, step = 27265 (4.158 sec)
INFO:tensorflow:global_step/sec: 23.8923
INFO:tensorflow:loss = 0.0015587851, step = 27365 (4.185 sec)
INFO:tensorflow:global_step/sec: 23.8756
INFO:tensorflow:loss = 0.0011784815,

INFO:tensorflow:global_step/sec: 23.8951
INFO:tensorflow:loss = 0.0028378027, step = 33809 (4.185 sec)
INFO:tensorflow:global_step/sec: 24.0757
INFO:tensorflow:loss = 0.0017313372, step = 33909 (4.154 sec)
INFO:tensorflow:Saving checkpoints for 33952 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpycangdvc/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0014505293.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpycangdvc/model.ckpt-33952
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 33953 into /var/folders/sx/fv0r97j96fz8njp14dt5g

INFO:tensorflow:global_step/sec: 23.6786
INFO:tensorflow:loss = 0.00050450524, step = 40197 (4.223 sec)
INFO:tensorflow:global_step/sec: 23.8995
INFO:tensorflow:loss = 0.00045299216, step = 40297 (4.184 sec)
INFO:tensorflow:global_step/sec: 23.9589
INFO:tensorflow:loss = 0.0069279796, step = 40397 (4.174 sec)
INFO:tensorflow:global_step/sec: 23.9975
INFO:tensorflow:loss = 0.00056100485, step = 40497 (4.167 sec)
INFO:tensorflow:global_step/sec: 23.9202
INFO:tensorflow:loss = 0.00046801046, step = 40597 (4.181 sec)
INFO:tensorflow:global_step/sec: 23.9374
INFO:tensorflow:loss = 0.001019123, step = 40697 (4.178 sec)
INFO:tensorflow:global_step/sec: 23.9176
INFO:tensorflow:loss = 0.007912719, step = 40797 (4.181 sec)
INFO:tensorflow:global_step/sec: 23.9203
INFO:tensorflow:loss = 0.00018534885, step = 40897 (4.181 sec)
INFO:tensorflow:global_step/sec: 24.1055
INFO:tensorflow:loss = 0.0005909806, step = 40997 (4.148 sec)
INFO:tensorflow:global_step/sec: 23.8864
INFO:tensorflow:loss = 0.0004

100%|█████████████████████████████| 5000/5000 [02:04<00:00, 40.20it/s]


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpycangdvc/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:03<00:00, 40.59it/s]
100%|█████████████████████████████| 5000/5000 [03:40<00:00, 22.63it/s]

[valid] Raw Mean Rank: 820.1222
[valid] Raw Hits@1: 42.1
[valid] Raw Hits@3: 65.68
[valid] Raw Hits@5: 73.89
[valid] Raw Hits@10: 81.14
[valid] Filtered Mean Rank: 808.493
[valid] Filtered Hits@1: 65.16999999999999
[valid] Filtered Hits@3: 88.83
[valid] Filtered Hits@5: 91.59
[valid] Filtered Hits@10: 93.33
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.





INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpycangdvc/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:02<00:00, 40.97it/s]


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpycangdvc/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:01<00:00, 41.00it/s]
100%|█████████████████████████████| 5000/5000 [03:37<00:00, 23.04it/s]


[test] Raw Mean Rank: 744.5139
[test] Raw Hits@1: 41.949999999999996
[test] Raw Hits@3: 65.95
[test] Raw Hits@5: 74.46000000000001
[test] Raw Hits@10: 81.98
[test] Filtered Mean Rank: 732.7748
[test] Filtered Hits@1: 65.86
[test] Filtered Hits@3: 89.18
[test] Filtered Hits@5: 91.86999999999999
[test] Filtered Hits@10: 93.58999999999999
