In [1]:
import tensorflow as tf
import numpy as np

import os, sys
sys.path.append(os.path.dirname(os.getcwd()))

from data import WN18
from model.distmult import model_fn
from model.metrics import evaluate_rank

In [2]:
class Config:
    seed = 21
    n_epochs = 10
    batch_size = 100
    embed_dim = 200

In [3]:
"""
e: entity
s: subject
p: predicate
o: object
"""

def read_triples(path):
    triples = []
    with open(path, 'rt') as f:
        for line in f.readlines():
            s, p, o = line.split()
            triples += [(s.strip(), p.strip(), o.strip())]
    return triples


def load_triple():
    WN18.download()
    triples_tr = read_triples('../data/WN18/wn18/train.txt')
    triples_va = read_triples('../data/WN18/wn18/valid.txt')
    triples_te = read_triples('../data/WN18/wn18/test.txt')
    
    triples_all = triples_tr + triples_va + triples_te
    
    return triples_all, triples_tr, triples_va, triples_te


def build_vocab(triples):
    params = {}
    
    e_set = {s for (s, p, o) in triples} | {o for (s, p, o) in triples}
    p_set = {p for (s, p, o) in triples}

    params['e_vocab_size'] = len(e_set)
    params['p_vocab_size'] = len(p_set)

    e2idx = {e: idx for idx, e in enumerate(sorted(e_set))}
    p2idx = {p: idx for idx, p in enumerate(sorted(p_set))}
    
    return e2idx, p2idx, params


def build_train_data(triples_tr, e2idx, p2idx):
    x_s = np.array([e2idx[s] for (s, p, o) in triples_tr], dtype=np.int32)
    x_p = np.array([p2idx[p] for (s, p, o) in triples_tr], dtype=np.int32)
    x_o = np.array([e2idx[o] for (s, p, o) in triples_tr], dtype=np.int32)

    x = {'s': x_s,
         'p': x_p,
         'o': x_o}
    y = np.ones([len(x_s)], dtype=np.float32)
    
    return x, y


def train_input_fn(triples_tr, e2idx, p2idx, random_state, params):
    x, y = build_train_data(triples_tr, e2idx, p2idx)
    s, p, o = x['s'], x['p'], x['o']
    
    s_ = random_state.choice(params['e_vocab_size'], s.shape)
    o_ = random_state.choice(params['e_vocab_size'], o.shape)
    
    x_ = {
        's': np.concatenate([s, s_, s]),
        'p': np.concatenate([p, p, p]),
        'o': np.concatenate([o, o, o_])}
    y_ = np.concatenate([y, np.zeros([2*len(y)], dtype=np.float32)])
    
    return tf.estimator.inputs.numpy_input_fn(x = x_,
                                              y = y_,
                                              batch_size = Config.batch_size,
                                              num_epochs = 1,
                                              shuffle = True)

In [4]:
random_state = np.random.RandomState(Config.seed)
triples_all, triples_tr, triples_va, triples_te = load_triple()
e2idx, p2idx, params = build_vocab(triples_all)
params['embed_dim'] = Config.embed_dim

model = tf.estimator.Estimator(model_fn,
                               params = params)

for _ in range(Config.n_epochs):
    model.train(train_input_fn(triples_tr,
                               e2idx,
                               p2idx,
                               random_state,
                               params))
evaluate_rank(model,
              triples_va,
              triples_te,
              triples_all,
              e2idx,
              p2idx,
              params['e_vocab_size'],
              batch_size = 10*Config.batch_size)

Files Already Downloaded
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x119c39358>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create Ch

INFO:tensorflow:global_step/sec: 22.6678
INFO:tensorflow:loss = 0.7255524, step = 6145 (4.411 sec)
INFO:tensorflow:global_step/sec: 22.1466
INFO:tensorflow:loss = 0.69558656, step = 6245 (4.516 sec)
INFO:tensorflow:global_step/sec: 22.7202
INFO:tensorflow:loss = 0.7044462, step = 6345 (4.401 sec)
INFO:tensorflow:global_step/sec: 21.7561
INFO:tensorflow:loss = 0.6922625, step = 6445 (4.596 sec)
INFO:tensorflow:global_step/sec: 22.4206
INFO:tensorflow:loss = 0.69820595, step = 6545 (4.460 sec)
INFO:tensorflow:global_step/sec: 23.5059
INFO:tensorflow:loss = 0.7306244, step = 6645 (4.254 sec)
INFO:tensorflow:global_step/sec: 21.4113
INFO:tensorflow:loss = 0.69160074, step = 6745 (4.671 sec)
INFO:tensorflow:global_step/sec: 22.1323
INFO:tensorflow:loss = 0.712002, step = 6845 (4.518 sec)
INFO:tensorflow:global_step/sec: 22.1644
INFO:tensorflow:loss = 0.72676605, step = 6945 (4.512 sec)
INFO:tensorflow:global_step/sec: 19.9895
INFO:tensorflow:loss = 0.7143452, step = 7045 (5.004 sec)
INFO:te

INFO:tensorflow:loss = 0.032129385, step = 12733
INFO:tensorflow:global_step/sec: 18.6796
INFO:tensorflow:loss = 0.025794316, step = 12833 (5.355 sec)
INFO:tensorflow:global_step/sec: 17.5306
INFO:tensorflow:loss = 0.025480025, step = 12933 (5.704 sec)
INFO:tensorflow:global_step/sec: 19.4081
INFO:tensorflow:loss = 0.028373208, step = 13033 (5.153 sec)
INFO:tensorflow:global_step/sec: 22.8685
INFO:tensorflow:loss = 0.025203254, step = 13133 (4.373 sec)
INFO:tensorflow:global_step/sec: 22.8483
INFO:tensorflow:loss = 0.019631417, step = 13233 (4.377 sec)
INFO:tensorflow:global_step/sec: 20.4557
INFO:tensorflow:loss = 0.022823164, step = 13333 (4.888 sec)
INFO:tensorflow:global_step/sec: 23.6672
INFO:tensorflow:loss = 0.023027796, step = 13433 (4.225 sec)
INFO:tensorflow:global_step/sec: 23.5299
INFO:tensorflow:loss = 0.017684508, step = 13533 (4.250 sec)
INFO:tensorflow:global_step/sec: 23.6671
INFO:tensorflow:loss = 0.02058577, step = 13633 (4.225 sec)
INFO:tensorflow:global_step/sec: 2

INFO:tensorflow:loss = 0.38071603, step = 20077 (5.353 sec)
INFO:tensorflow:global_step/sec: 18.2312
INFO:tensorflow:loss = 0.42231464, step = 20177 (5.485 sec)
INFO:tensorflow:global_step/sec: 19.0026
INFO:tensorflow:loss = 0.4413134, step = 20277 (5.262 sec)
INFO:tensorflow:global_step/sec: 19.6108
INFO:tensorflow:loss = 0.39315507, step = 20377 (5.099 sec)
INFO:tensorflow:global_step/sec: 19.8379
INFO:tensorflow:loss = 0.36730403, step = 20477 (5.042 sec)
INFO:tensorflow:global_step/sec: 19.5469
INFO:tensorflow:loss = 0.3799267, step = 20577 (5.116 sec)
INFO:tensorflow:global_step/sec: 19.6204
INFO:tensorflow:loss = 0.39125264, step = 20677 (5.096 sec)
INFO:tensorflow:global_step/sec: 19.6292
INFO:tensorflow:loss = 0.31001917, step = 20777 (5.095 sec)
INFO:tensorflow:global_step/sec: 19.6249
INFO:tensorflow:loss = 0.29270282, step = 20877 (5.097 sec)
INFO:tensorflow:global_step/sec: 19.6859
INFO:tensorflow:loss = 0.28594056, step = 20977 (5.078 sec)
INFO:tensorflow:global_step/sec: 

INFO:tensorflow:loss = 0.011526352, step = 26565 (5.406 sec)
INFO:tensorflow:global_step/sec: 19.9533
INFO:tensorflow:loss = 0.012076569, step = 26665 (5.013 sec)
INFO:tensorflow:global_step/sec: 22.3956
INFO:tensorflow:loss = 0.01925107, step = 26765 (4.464 sec)
INFO:tensorflow:global_step/sec: 21.9916
INFO:tensorflow:loss = 0.012211575, step = 26865 (4.547 sec)
INFO:tensorflow:global_step/sec: 21.0095
INFO:tensorflow:loss = 0.011163148, step = 26965 (4.760 sec)
INFO:tensorflow:global_step/sec: 23.3256
INFO:tensorflow:loss = 0.008911761, step = 27065 (4.287 sec)
INFO:tensorflow:global_step/sec: 20.0187
INFO:tensorflow:loss = 0.008803696, step = 27165 (4.995 sec)
INFO:tensorflow:global_step/sec: 23.6818
INFO:tensorflow:loss = 0.008499556, step = 27265 (4.223 sec)
INFO:tensorflow:global_step/sec: 22.6648
INFO:tensorflow:loss = 0.0066078384, step = 27365 (4.412 sec)
INFO:tensorflow:global_step/sec: 24.0057
INFO:tensorflow:loss = 0.010687338, step = 27465 (4.166 sec)
INFO:tensorflow:globa

INFO:tensorflow:loss = 0.031058263, step = 33809 (4.946 sec)
INFO:tensorflow:global_step/sec: 18.8324
INFO:tensorflow:loss = 0.0214027, step = 33909 (5.310 sec)
INFO:tensorflow:Saving checkpoints for 33952 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge/model.ckpt.
INFO:tensorflow:Loss for final step: 0.008243542.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge/model.ckpt-33952
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 33953 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge/model.ckpt.
INFO:tenso

INFO:tensorflow:loss = 0.0030050764, step = 40197 (4.587 sec)
INFO:tensorflow:global_step/sec: 22.9329
INFO:tensorflow:loss = 0.002852875, step = 40297 (4.361 sec)
INFO:tensorflow:global_step/sec: 22.4905
INFO:tensorflow:loss = 0.0020960588, step = 40397 (4.446 sec)
INFO:tensorflow:global_step/sec: 20.2737
INFO:tensorflow:loss = 0.0028847111, step = 40497 (4.933 sec)
INFO:tensorflow:global_step/sec: 22.9953
INFO:tensorflow:loss = 0.0029658796, step = 40597 (4.349 sec)
INFO:tensorflow:global_step/sec: 20.9784
INFO:tensorflow:loss = 0.0039291144, step = 40697 (4.767 sec)
INFO:tensorflow:global_step/sec: 22.5737
INFO:tensorflow:loss = 0.04586746, step = 40797 (4.430 sec)
INFO:tensorflow:global_step/sec: 23.1496
INFO:tensorflow:loss = 0.0011953926, step = 40897 (4.320 sec)
INFO:tensorflow:global_step/sec: 21.6783
INFO:tensorflow:loss = 0.0017189255, step = 40997 (4.613 sec)
INFO:tensorflow:global_step/sec: 22.5247
INFO:tensorflow:loss = 0.0017062488, step = 41097 (4.440 sec)
INFO:tensorflo

100%|█████████████████████████████| 5000/5000 [02:08<00:00, 38.98it/s]


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:06<00:00, 39.47it/s]
100%|█████████████████████████████| 5000/5000 [03:55<00:00, 21.20it/s]

[valid] Raw Mean Rank: 902.2788
[valid] Raw Hits@1: 41.79
[valid] Raw Hits@3: 65.92
[valid] Raw Hits@5: 74.17
[valid] Raw Hits@10: 80.96
[valid] Filtered Mean Rank: 891.5644
[valid] Filtered Hits@1: 69.37
[valid] Filtered Hits@3: 90.14999999999999
[valid] Filtered Hits@5: 92.25
[valid] Filtered Hits@10: 93.45
INFO:tensorflow:Calling model_fn.





INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:13<00:00, 37.56it/s]


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp7ngcvwge/model.ckpt-42440
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


100%|█████████████████████████████| 5000/5000 [02:08<00:00, 38.86it/s]
100%|█████████████████████████████| 5000/5000 [03:59<00:00, 20.86it/s]

[test] Raw Mean Rank: 836.8393
[test] Raw Hits@1: 41.33
[test] Raw Hits@3: 65.55
[test] Raw Hits@5: 73.87
[test] Raw Hits@10: 81.37
[test] Filtered Mean Rank: 826.2041
[test] Filtered Hits@1: 68.69
[test] Filtered Hits@3: 90.14
[test] Filtered Hits@5: 92.10000000000001
[test] Filtered Hits@10: 93.37



