In [1]:
from tqdm import tqdm
import tensorflow as tf
import numpy as np
import sklearn
import pprint
import itertools
import os
import sys

sys.path.append(os.path.dirname(os.getcwd()))
from data import WN18

In [2]:
class Config:
    """Hyperparameters read by the batching, model and training code in this notebook."""
    # Number of times the driver cell calls `model.train`.
    n_epochs = 30
    # Number of training triples per batch yielded by `next_train_batch`.
    batch_size = 100
    # Width of both the entity and the predicate embedding tables in `forward`.
    embed_dim = 200

In [3]:
"""
e: entity
s: subject
p: predicate
o: object
"""
def glance_dict(d, n=5):
    """Return a new dict holding only the first `n` items of `d`.

    Items are taken in the dict's own iteration order; this is meant for
    peeking at a large mapping without printing all of it.
    """
    leading_items = itertools.islice(d.items(), n)
    return {key: value for key, value in leading_items}


def read_triples(path):
    """Read whitespace-separated (subject, predicate, object) triples from a text file.

    Args:
        path: Path of a text file holding one triple per line.

    Returns:
        A list of `(s, p, o)` string tuples in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly three fields.
    """
    triples = []
    # Stream the file line by line instead of materialising it with readlines().
    with open(path, 'rt', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            # split() with no argument already strips surrounding whitespace
            # from every field, so no per-field strip() is needed.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing empty line) carry no triple.
                continue
            if len(fields) != 3:
                raise ValueError(
                    '{}:{}: expected 3 whitespace-separated fields, got {}'.format(
                        path, line_no, len(fields)))
            triples.append(tuple(fields))
    return triples


def load_triple():
    """Download WN18 (via `WN18.download`) and read its three split files.

    Returns:
        A 4-tuple `(all, train, valid, test)` of triple lists, where `all`
        is the concatenation train + valid + test.
    """
    WN18.download()
    split_paths = (
        '../data/WN18/wn18/train.txt',
        '../data/WN18/wn18/valid.txt',
        '../data/WN18/wn18/test.txt',
    )
    # The generator is consumed in order, so the files are read train, valid, test.
    train, valid, test = (read_triples(split_path) for split_path in split_paths)
    everything = train + valid + test
    return everything, train, valid, test


def build_vocab(triples):
    """Build entity and predicate index mappings from an iterable of triples.

    The triples are traversed exactly once, so one-shot iterables
    (generators, iterators) work as well as lists.

    Args:
        triples: Iterable of `(subject, predicate, object)` tuples.

    Returns:
        A tuple `(e2idx, p2idx, params)`:
        - `e2idx`: entity -> index, indices assigned in sorted entity order;
          entities are collected from both subject and object positions.
        - `p2idx`: predicate -> index, indices assigned in sorted order.
        - `params`: dict with `'e_vocab_size'` and `'p_vocab_size'`.
    """
    e_set, p_set = set(), set()
    for s, p, o in triples:
        e_set.add(s)
        e_set.add(o)
        p_set.add(p)

    # Sorting makes the index assignment independent of set iteration order.
    e2idx = {e: idx for idx, e in enumerate(sorted(e_set))}
    p2idx = {p: idx for idx, p in enumerate(sorted(p_set))}
    params = {
        'e_vocab_size': len(e_set),
        'p_vocab_size': len(p_set),
    }
    return e2idx, p2idx, params


def build_multi_label(triples_tr, e2idx=None, p2idx=None):
    """Group object indices by their (subject index, predicate index) pair.

    Args:
        triples_tr: Iterable of `(subject, predicate, object)` tuples.
        e2idx: Entity -> index mapping. When omitted, the notebook-level
            global `e2idx` is used, which keeps the one-argument call working.
        p2idx: Predicate -> index mapping, with the same global fallback.

    Returns:
        Dict mapping `(s_idx, p_idx)` to a list of distinct object indices,
        in order of first appearance.

    Raises:
        NameError: If a mapping is neither passed nor defined as a global
            and there is at least one triple to convert.
        KeyError: If a triple mentions an entity/predicate missing from the
            mappings.
    """
    triples_tr = list(triples_tr)
    if not triples_tr:
        # Nothing to convert, so the mappings are never needed.
        return {}

    if e2idx is None or p2idx is None:
        # Backward-compatible fallback to the globals the original relied on.
        module_vars = globals()
        for name in ('e2idx', 'p2idx'):
            if name not in module_vars and locals()[name] is None:
                raise NameError("name '{}' is not defined".format(name))
        if e2idx is None:
            e2idx = module_vars['e2idx']
        if p2idx is None:
            p2idx = module_vars['p2idx']

    sp2o = {}
    # Parallel sets give O(1) duplicate checks; the lists keep insertion order
    # and stay usable as NumPy fancy indices.
    seen = {}
    for (_s, _p, _o) in triples_tr:
        s, p, o = e2idx[_s], p2idx[_p], e2idx[_o]
        key = (s, p)
        if key not in sp2o:
            sp2o[key] = [o]
            seen[key] = {o}
        elif o not in seen[key]:
            sp2o[key].append(o)
            seen[key].add(o)
    return sp2o


def get_y(triples_tr, e2idx, p2idx, sp2o):
    """Build one multi-hot label row per triple.

    For each triple, the row has length `len(e2idx)` and holds 1.0 at every
    object index listed in `sp2o` for that triple's (subject, predicate)
    index pair, and 0.0 elsewhere.

    Returns:
        The rows stacked with `np.asarray`.
    """
    n_entities = len(e2idx)
    rows = []
    for subj, pred, obj in triples_tr:
        # The object lookup is kept so an unknown object still raises KeyError.
        s_idx, p_idx, _ = e2idx[subj], p2idx[pred], e2idx[obj]
        row = np.zeros([n_entities])
        positives = sp2o[(s_idx, p_idx)]
        row[positives] = 1.
        rows.append(row)
    return np.asarray(rows)


def next_train_batch(triples_tr, e2idx, p2idx, sp2o):
    """Yield `(x_s, x_p, y)` batches of `Config.batch_size` consecutive triples.

    `x_s` and `x_p` are int32 arrays of subject and predicate indices; `y` is
    the multi-hot label matrix produced by `get_y` for the same triples. The
    last batch may be shorter than `Config.batch_size`.
    """
    batch_size = Config.batch_size
    for start in range(0, len(triples_tr), batch_size):
        chunk = triples_tr[start: start + batch_size]
        subject_ids = [e2idx[subj] for (subj, _, _) in chunk]
        predicate_ids = [p2idx[pred] for (_, pred, _) in chunk]
        x_s = np.asarray(subject_ids, dtype=np.int32)
        x_p = np.asarray(predicate_ids, dtype=np.int32)
        yield (x_s, x_p, get_y(chunk, e2idx, p2idx, sp2o))


def train_input_fn(triples_tr, e2idx, p2idx, s2p2o):
    """Estimator input function for training.

    Wraps `next_train_batch` (fed a freshly shuffled copy of `triples_tr`
    each time the generator is started) in a `tf.data.Dataset`.

    Returns:
        `({'s': subject ids, 'p': predicate ids}, labels)` tensors taken from
        a one-shot iterator over that dataset.
    """
    def shuffled_batches():
        return next_train_batch(sklearn.utils.shuffle(triples_tr),
                                e2idx,
                                p2idx,
                                s2p2o)

    output_types = (tf.int32, tf.int32, tf.float32)
    # Batch dimension is left open; label rows are as wide as the entity vocabulary.
    output_shapes = (tf.TensorShape([None]),
                     tf.TensorShape([None]),
                     tf.TensorShape([None, len(e2idx)]))

    dataset = tf.data.Dataset.from_generator(shuffled_batches,
                                             output_types,
                                             output_shapes)
    x_s, x_p, y = dataset.make_one_shot_iterator().get_next()
    features = {'s': x_s, 'p': x_p}
    return features, y

In [4]:
def o_next_batch(eval_triples,
                 e2idx,
                 p2idx):
    """Yield one `(subject ids, predicate ids)` pair per evaluation triple.

    Each yielded array has a single element (the index is promoted with
    `np.atleast_1d`). Iteration is wrapped in a tqdm progress bar.
    """
    progress = tqdm(eval_triples, total=len(eval_triples), ncols=70)
    for (subj, pred, _obj) in progress:
        subj_idx = e2idx[subj]
        pred_idx = p2idx[pred]
        yield np.atleast_1d(subj_idx), np.atleast_1d(pred_idx)


def o_input_fn(eval_triples,
               e2idx,
               p2idx):
    """Estimator input function for prediction over `eval_triples`.

    Returns:
        A feature dict `{'s': subject ids, 'p': predicate ids}` of tensors
        drawn from a one-shot iterator over `o_next_batch`.
    """
    def query_batches():
        return o_next_batch(eval_triples,
                            e2idx,
                            p2idx)

    output_types = (tf.int32, tf.int32)
    output_shapes = (tf.TensorShape([None,]),
                     tf.TensorShape([None,]))

    dataset = tf.data.Dataset.from_generator(query_batches,
                                             output_types,
                                             output_shapes)
    s, p = dataset.make_one_shot_iterator().get_next()
    return {'s': s, 'p': p}


def evaluate_rank(model,
                  valid_triples,
                  test_triples,
                  all_triples,
                  e2idx,
                  p2idx):
    """Print raw and filtered object-ranking metrics for the test triples.

    For every test triple `(s, p, o)` the model scores all entities as
    candidate objects, and the rank of the true object `o` is recorded:
    - Raw: rank among all entities.
    - Filtered: rank after every *other* object known to complete `(s, p)`
      in `all_triples` has its score set to -inf.
    MRR and Hits@{1,3,5,10} are printed for both settings.

    Args:
        model: Object whose `predict(input_fn)` yields one score vector of
            length `len(e2idx)` per evaluation triple.
        valid_triples: Accepted for interface compatibility; currently not
            evaluated (only the test split is).
        test_triples: Triples to evaluate.
        all_triples: Every known triple, used for the filtered setting.
        e2idx: Entity -> index mapping.
        p2idx: Predicate -> index mapping.
    """
    # Index known objects by (subject, predicate) once. The filtered setting
    # then needs one dict lookup per query instead of a scan over all triples.
    known_objects = {}
    for (fs, fp, fo) in all_triples:
        known_objects.setdefault((fs, fp), set()).add(fo)

    for eval_name, eval_triples in [('test', test_triples)]:
        _scores_o = list(model.predict(
            lambda: o_input_fn(eval_triples,
                               e2idx,
                               p2idx)))
        ScoresO = np.reshape(_scores_o, [len(eval_triples), len(e2idx)])
        ranks_o, filtered_ranks_o = [], []
        for ((s, p, o), scores_o) in tqdm(zip(eval_triples, ScoresO),
                                          total=len(eval_triples),
                                          ncols=70):
            o_idx = e2idx[o]
            # argsort of argsort turns scores into 0-based ranks (descending).
            ranks_o.append(1 + np.argsort(np.argsort(- scores_o))[o_idx])

            filtered_scores_o = scores_o.copy()
            rm_idx_o = [e2idx[fo] for fo in known_objects.get((s, p), ()) if fo != o]
            if rm_idx_o:
                filtered_scores_o[rm_idx_o] = - np.inf
            filtered_ranks_o.append(
                1 + np.argsort(np.argsort(- filtered_scores_o))[o_idx])

        for setting_name, setting_ranks in [('Raw', ranks_o), ('Filtered', filtered_ranks_o)]:
            rank_array = np.asarray(setting_ranks)
            # Mean reciprocal rank (not the mean rank).
            mrr = np.mean(1 / rank_array)
            print('[{}] {} MRR: {}'.format(eval_name, setting_name, mrr))
            for k in [1, 3, 5, 10]:
                hits_at_k = np.mean(rank_array <= k) * 100
                print('[{}] {} Hits@{}: {}'.format(eval_name, setting_name, k, hits_at_k))

In [5]:
def forward(features, params):
    """Score every entity as a candidate object for each (subject, predicate) pair.

    Subject and predicate embeddings are multiplied element-wise and the
    result is matrix-multiplied with the transposed entity embedding table
    (the same table the subjects are looked up in).

    Args:
        features: Dict with index tensors under `'s'` and `'p'`.
        params: Dict with `'e_vocab_size'` and `'p_vocab_size'`.

    Returns:
        The logits tensor produced by that matmul (one column per entity).
    """
    def embedding_table(name, vocab_size):
        # Each table gets its own Xavier initializer instance.
        return tf.get_variable(name,
                               [vocab_size, Config.embed_dim],
                               initializer=tf.contrib.layers.xavier_initializer())

    e_embed = embedding_table('e_embed', params['e_vocab_size'])
    p_embed = embedding_table('p_embed', params['p_vocab_size'])

    subject_vecs = tf.nn.embedding_lookup(e_embed, features['s'])
    predicate_vecs = tf.nn.embedding_lookup(p_embed, features['p'])

    return tf.matmul(subject_vecs * predicate_vecs, e_embed, transpose_b=True)
    
    
def model_fn(features, labels, mode, params):
    """Estimator model function handling the PREDICT and TRAIN modes.

    PREDICT returns the sigmoid of the logits as predictions. TRAIN logs the
    trainable variables and their parameter count, then minimises the summed
    sigmoid cross-entropy with Adam. Any other mode falls through and
    returns None.
    """
    logits = forward(features, params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions = tf.sigmoid(logits))

    if mode == tf.estimator.ModeKeys.TRAIN:
        variables_repr = pprint.pformat(tf.trainable_variables())
        tf.logging.info('\n'+variables_repr)
        tf.logging.info('params: %d'%count_train_params())

        # Summed (not averaged) over both the batch and the entity axis.
        element_losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                 labels=labels)
        loss_op = tf.reduce_sum(element_losses)

        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(
            loss_op, global_step=tf.train.get_global_step())

        return tf.estimator.EstimatorSpec(mode = mode,
                                          loss = loss_op,
                                          train_op = train_op)


def count_train_params():
    """Return the total number of scalar entries across all trainable variables.

    Each variable contributes the product of its static shape dimensions.
    """
    sizes = []
    for variable in tf.trainable_variables():
        dims = [dim.value for dim in variable.get_shape()]
        sizes.append(np.prod(dims))
    return np.sum(sizes)

In [6]:
# Load the WN18 splits; the vocabulary is built over train + valid + test so
# every entity/predicate seen at evaluation time has an index.
triples_all, triples_tr, triples_va, triples_te = load_triple()
e2idx, p2idx, params = build_vocab(triples_all)
# Multi-label targets come from the training split only.
# NOTE: build_multi_label reads `e2idx`/`p2idx` from this cell's globals.
sp2o = build_multi_label(triples_tr)

# No model_dir is passed here, so checkpoint location is left to the Estimator.
model = tf.estimator.Estimator(model_fn,
                               params = params)

# One `train` call per epoch; each call builds a fresh input pipeline over a
# newly shuffled copy of the training triples (see train_input_fn).
for n_epoch in range(Config.n_epochs):
    model.train(lambda: train_input_fn(triples_tr, e2idx, p2idx, sp2o))

# Prints raw and filtered MRR / Hits@k; evaluate_rank only scores the test
# split even though the validation triples are passed in.
evaluate_rank(model,
              triples_va,
              triples_te,
              triples_all,
              e2idx,
              p2idx,)

Files Already Downloaded
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x112a75668>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create Ch

INFO:tensorflow:global_step/sec: 3.27102
INFO:tensorflow:loss = 8125.1074, step = 4346 (30.573 sec)
INFO:tensorflow:global_step/sec: 3.07237
INFO:tensorflow:loss = 4305.2583, step = 4446 (32.548 sec)
INFO:tensorflow:global_step/sec: 3.12156
INFO:tensorflow:loss = 7119.0664, step = 4546 (32.035 sec)
INFO:tensorflow:global_step/sec: 3.05598
INFO:tensorflow:loss = 5056.994, step = 4646 (32.723 sec)
INFO:tensorflow:global_step/sec: 3.11539
INFO:tensorflow:loss = 4314.9775, step = 4746 (32.099 sec)
INFO:tensorflow:global_step/sec: 3.27807
INFO:tensorflow:loss = 7272.088, step = 4846 (30.506 sec)
INFO:tensorflow:global_step/sec: 3.18642
INFO:tensorflow:loss = 4958.759, step = 4946 (31.383 sec)
INFO:tensorflow:global_step/sec: 2.916
INFO:tensorflow:loss = 3960.3843, step = 5046 (34.294 sec)
INFO:tensorflow:global_step/sec: 3.14511
INFO:tensorflow:loss = 5483.0386, step = 5146 (31.795 sec)
INFO:tensorflow:global_step/sec: 2.97648
INFO:tensorflow:loss = 4991.9453, step = 5246 (33.597 sec)
INFO:

INFO:tensorflow:Loss for final step: 696.5001.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt-9905
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 9906 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt.
INFO:tensorflow:loss = 2172.543, step = 9906
INFO:tensorflow:global_step/sec: 3.16719
INFO:tensorflow:loss = 2479.875, step = 10006 (31.575 sec)
INFO:tensorflow:global_step/sec: 3.14525
INFO:tensorflow:loss = 2182.3467, step = 10106 (31.794 sec)
INFO:tensorflow:global_step/sec: 3.18638
INFO:tensorf

INFO:tensorflow:loss = 930.5989, step = 14851 (35.991 sec)
INFO:tensorflow:global_step/sec: 2.91699
INFO:tensorflow:loss = 770.5972, step = 14951 (34.283 sec)
INFO:tensorflow:global_step/sec: 2.61697
INFO:tensorflow:loss = 828.5442, step = 15051 (38.212 sec)
INFO:tensorflow:global_step/sec: 2.55096
INFO:tensorflow:loss = 898.90265, step = 15151 (39.201 sec)
INFO:tensorflow:global_step/sec: 2.63354
INFO:tensorflow:loss = 934.0525, step = 15251 (37.972 sec)
INFO:tensorflow:global_step/sec: 2.838
INFO:tensorflow:loss = 1039.7983, step = 15351 (35.236 sec)
INFO:tensorflow:global_step/sec: 2.74509
INFO:tensorflow:loss = 844.0864, step = 15451 (36.431 sec)
INFO:tensorflow:global_step/sec: 2.77919
INFO:tensorflow:loss = 872.28577, step = 15551 (35.980 sec)
INFO:tensorflow:Saving checkpoints for 15565 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt.
INFO:tensorflow:Loss for final step: 409.01862.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_e

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 19811 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt.
INFO:tensorflow:loss = 235.28838, step = 19811
INFO:tensorflow:global_step/sec: 2.83866
INFO:tensorflow:loss = 287.703, step = 19911 (35.230 sec)
INFO:tensorflow:global_step/sec: 2.85495
INFO:tensorflow:loss = 280.86102, step = 20011 (35.026 sec)
INFO:tensorflow:global_step/sec: 3.08468
INFO:tensorflow:loss = 268.48914, step = 20111 (32.418 sec)
INFO:tensorflow:global_step/sec: 2.9483
INFO:tensorflow:loss = 241.77159, step = 20211 (33.918 sec)
INFO:tensorflow:global_step/sec: 2.7615
INFO:tensorflow:loss = 277.81287, step = 20311 (36.212 sec)
INFO:tensorflow:global_step/sec: 2.92006
INFO:tensorflow:loss = 240.6268, step = 20411 (34.246 sec)
INFO:tensorflow:global_step/sec: 2.96338
INFO:tensorflow:loss = 253.66588, step = 20511 (33.745 sec)
INFO:tensorflow:global_step/sec: 2.81748
INFO

INFO:tensorflow:global_step/sec: 2.829
INFO:tensorflow:loss = 267.90814, step = 25256 (35.348 sec)
INFO:tensorflow:global_step/sec: 3.00006
INFO:tensorflow:loss = 183.53436, step = 25356 (33.333 sec)
INFO:tensorflow:global_step/sec: 3.03094
INFO:tensorflow:loss = 330.0274, step = 25456 (32.993 sec)
INFO:tensorflow:Saving checkpoints for 25470 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt.
INFO:tensorflow:Loss for final step: 70.04525.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt-25470
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_o

INFO:tensorflow:global_step/sec: 3.41856
INFO:tensorflow:loss = 118.018936, step = 30016 (29.252 sec)
INFO:tensorflow:global_step/sec: 3.41422
INFO:tensorflow:loss = 103.53102, step = 30116 (29.289 sec)
INFO:tensorflow:global_step/sec: 3.40647
INFO:tensorflow:loss = 184.49255, step = 30216 (29.356 sec)
INFO:tensorflow:global_step/sec: 3.41435
INFO:tensorflow:loss = 170.09634, step = 30316 (29.288 sec)
INFO:tensorflow:global_step/sec: 3.41329
INFO:tensorflow:loss = 169.36855, step = 30416 (29.297 sec)
INFO:tensorflow:global_step/sec: 3.41493
INFO:tensorflow:loss = 212.79138, step = 30516 (29.283 sec)
INFO:tensorflow:global_step/sec: 3.40377
INFO:tensorflow:loss = 175.13034, step = 30616 (29.379 sec)
INFO:tensorflow:global_step/sec: 3.41734
INFO:tensorflow:loss = 133.15599, step = 30716 (29.262 sec)
INFO:tensorflow:global_step/sec: 3.40859
INFO:tensorflow:loss = 178.4163, step = 30816 (29.338 sec)
INFO:tensorflow:global_step/sec: 3.41511
INFO:tensorflow:loss = 149.86278, step = 30916 (29

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_embed:0' shape=(40943, 200) dtype=float32_ref>,
 <tf.Variable 'p_embed:0' shape=(18, 200) dtype=float32_ref>]
INFO:tensorflow:params: 8192200
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt-35375
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 35376 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt.
INFO:tensorflow:loss = 144.90439, step = 35376
INFO:tensorflow:global_step/sec: 3.10067
INFO:tensorflow:loss = 117.857, step = 35476 (32.252 sec)
INFO:tensorflow:global_step/sec: 3.0776
INFO:tensorflow:loss = 153.94852, step = 35576 (32.493 sec)
INFO:tensorflow:global_step/sec: 3.07743
INFO:tensorflow:loss = 161.11603, step = 35676 (32.494 se

INFO:tensorflow:loss = 264.38156, step = 40321 (32.380 sec)
INFO:tensorflow:global_step/sec: 3.07982
INFO:tensorflow:loss = 164.46753, step = 40421 (32.469 sec)
INFO:tensorflow:global_step/sec: 3.08902
INFO:tensorflow:loss = 131.55156, step = 40521 (32.373 sec)
INFO:tensorflow:global_step/sec: 3.0898
INFO:tensorflow:loss = 145.63284, step = 40621 (32.365 sec)
INFO:tensorflow:global_step/sec: 2.96838
INFO:tensorflow:loss = 162.33733, step = 40721 (33.688 sec)
INFO:tensorflow:global_step/sec: 3.11578
INFO:tensorflow:loss = 115.8383, step = 40821 (32.095 sec)
INFO:tensorflow:global_step/sec: 3.0635
INFO:tensorflow:loss = 228.30219, step = 40921 (32.643 sec)
INFO:tensorflow:global_step/sec: 2.927
INFO:tensorflow:loss = 190.0309, step = 41021 (34.165 sec)
INFO:tensorflow:Saving checkpoints for 41035 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8s2uekfg/model.ckpt.
INFO:tensorflow:Loss for final step: 65.98539.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:
[<tf.Variable 'e_e

100%|████████████████████████████| 5000/5000 [00:13<00:00, 373.18it/s]
100%|█████████████████████████████| 5000/5000 [01:48<00:00, 45.97it/s]


[test] Raw MRR: 0.48660244632331834
[test] Raw Hits@1: 33.62
[test] Raw Hits@3: 58.4
[test] Raw Hits@5: 68.76
[test] Raw Hits@10: 77.28
[test] Filtered MRR: 0.6707606927510673
[test] Filtered Hits@1: 53.300000000000004
[test] Filtered Hits@3: 79.12
[test] Filtered Hits@5: 83.46000000000001
[test] Filtered Hits@10: 86.61999999999999
