In [1]:
from tqdm import tqdm
import tensorflow as tf
import numpy as np
import pprint

import os, sys
sys.path.append(os.path.dirname(os.getcwd()))

from data import WN18
from model.metrics import evaluate_rank

In [2]:
class Config:
    """Hyper-parameters shared by the cells of this notebook."""
    # Seed for the NumPy RandomState that draws the corrupted (negative) entities.
    seed = 21
    # Number of passes of the training loop; each pass calls model.train once.
    n_epochs = 10
    # Mini-batch size for training; the evaluation cell uses 10x this value.
    batch_size = 128
    # Width of both the entity and the predicate embedding tables.
    embed_dim = 150

In [3]:
"""
e: entity
s: subject
p: predicate
o: object
"""

def read_triples(path):
    """Read (subject, predicate, object) triples from a whitespace-separated file.

    Each non-blank line must contain exactly three whitespace-separated
    fields. Blank or whitespace-only lines (for example a trailing empty
    line at the end of the file) are skipped.

    Args:
        path: path of the text file to read.

    Returns:
        A list of ``(s, p, o)`` string tuples in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    triples = []
    with open(path, 'rt') as f:
        # Iterate the file object directly so the file is streamed rather
        # than loaded into memory in one go by readlines().
        for line_no, line in enumerate(f, start=1):
            # split() with no argument already drops surrounding whitespace,
            # so the individual fields need no further strip().
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 3:
                raise ValueError(
                    '{}:{}: expected 3 fields (s, p, o), got {}: {!r}'.format(
                        path, line_no, len(fields), line))
            triples.append(tuple(fields))
    return triples


def load_triple():
    """Download WN18 (via ``WN18.download``) and load its three splits.

    Returns:
        A 4-tuple ``(all, train, valid, test)`` of triple lists, where
        ``all`` is the concatenation train + valid + test.
    """
    WN18.download()

    # The generator is consumed by the unpacking, so the files are read in
    # train, valid, test order.
    train, valid, test = (
        read_triples(split_path)
        for split_path in ('../data/WN18/wn18/train.txt',
                           '../data/WN18/wn18/valid.txt',
                           '../data/WN18/wn18/test.txt')
    )

    return train + valid + test, train, valid, test


def build_vocab(triples):
    """Build entity and predicate vocabularies from (s, p, o) triples.

    Entities are the union of all subjects and objects. Ids are assigned
    in sorted order, so the mapping is deterministic for a given input.

    Returns:
        ``(e2idx, p2idx, params)`` where ``e2idx`` / ``p2idx`` map each
        entity / predicate to its integer id and ``params`` holds the two
        vocabulary sizes under 'e_vocab_size' and 'p_vocab_size'.
    """
    entities, predicates = set(), set()
    for subj, pred, obj in triples:
        entities.add(subj)
        entities.add(obj)
        predicates.add(pred)

    params = {
        'e_vocab_size': len(entities),
        'p_vocab_size': len(predicates),
    }

    e2idx = dict(zip(sorted(entities), range(len(entities))))
    p2idx = dict(zip(sorted(predicates), range(len(predicates))))

    return e2idx, p2idx, params


def build_train_data(triples_tr, e2idx, p2idx):
    """Encode training triples as parallel id arrays with all-ones labels.

    Returns:
        ``(x, y)`` where ``x`` is a dict with int32 arrays under the keys
        's', 'p' and 'o' (one entry per triple) and ``y`` is a float32
        array of ones of the same length.
    """
    subject_ids = [e2idx[subj] for subj, _, _ in triples_tr]
    predicate_ids = [p2idx[pred] for _, pred, _ in triples_tr]
    object_ids = [e2idx[obj] for _, _, obj in triples_tr]

    features = dict(
        s=np.array(subject_ids, dtype=np.int32),
        p=np.array(predicate_ids, dtype=np.int32),
        o=np.array(object_ids, dtype=np.int32),
    )
    # Every observed triple is a positive example.
    labels = np.ones([len(subject_ids)], dtype=np.float32)

    return features, labels


def train_input_fn(triples_tr, e2idx, p2idx, random_state, params):
    """Build a one-epoch shuffled input function with negative samples.

    For every training triple (s, p, o) two corrupted copies are added:
    one with the subject replaced by a random entity id and one with the
    object replaced. Positives are labelled 1.0 and corrupted triples 0.0.
    The replacements are drawn uniformly from all entity ids without
    checking whether the corrupted triple happens to be a true one.

    Args:
        triples_tr: training triples as (s, p, o) strings.
        e2idx, p2idx: entity / predicate to id mappings.
        random_state: object providing ``choice(n, shape)``, e.g. a
            ``np.random.RandomState``.
        params: dict providing 'e_vocab_size'.

    Returns:
        The callable produced by ``tf.estimator.inputs.numpy_input_fn``.
    """
    positives, pos_labels = build_train_data(triples_tr, e2idx, p2idx)
    subj, pred, obj = (positives[key] for key in ('s', 'p', 'o'))
    n_entities = params['e_vocab_size']

    # Subject replacements are drawn before object replacements; keep this
    # order so a given random_state yields the same samples.
    corrupt_subj = random_state.choice(n_entities, subj.shape)
    corrupt_obj = random_state.choice(n_entities, obj.shape)

    # Layout: [positives | subject-corrupted | object-corrupted].
    features = {
        's': np.concatenate([subj, corrupt_subj, subj]),
        'p': np.concatenate([pred, pred, pred]),
        'o': np.concatenate([obj, obj, corrupt_obj]),
    }
    neg_labels = np.zeros([2 * len(pos_labels)], dtype=np.float32)
    labels = np.concatenate([pos_labels, neg_labels])

    return tf.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=Config.batch_size,
        num_epochs=1,
        shuffle=True)

In [4]:
def forward(features, params):
    """Score triples by a three-way element-wise product of embeddings.

    Subjects and objects share one entity embedding table ('e_embed');
    predicates have their own table ('p_embed'). Both tables have
    ``Config.embed_dim`` columns and use Xavier initialisation.

    Args:
        features: dict with id tensors under 's', 'p' and 'o'
            (presumably 1-D batches, as produced by train_input_fn).
        params: dict providing 'e_vocab_size' and 'p_vocab_size'.

    Returns:
        The logits tensor: the product s * p * o summed over axis 1.
    """
    def embedding_table(name, n_rows):
        # One row per vocabulary item.
        return tf.get_variable(name,
                               [n_rows, Config.embed_dim],
                               initializer=tf.contrib.layers.xavier_initializer())

    entity_table = embedding_table('e_embed', params['e_vocab_size'])
    predicate_table = embedding_table('p_embed', params['p_vocab_size'])

    subj_vec = tf.nn.embedding_lookup(entity_table, features['s'])
    pred_vec = tf.nn.embedding_lookup(predicate_table, features['p'])
    obj_vec = tf.nn.embedding_lookup(entity_table, features['o'])

    interaction = subj_vec * pred_vec * obj_vec
    return tf.reduce_sum(interaction, [1])
    
    
def model_fn(features, labels, mode, params):
    """Estimator model function for the triple-scoring model.

    Args:
        features: dict of id tensors under 's', 'p' and 'o'.
        labels: 1.0 / 0.0 targets; not used in PREDICT mode.
        mode: one of ``tf.estimator.ModeKeys``.
        params: dict providing the vocabulary sizes read by ``forward``.

    Returns:
        A ``tf.estimator.EstimatorSpec``:
        PREDICT -> predictions are the sigmoid of the logits;
        EVAL    -> the mean sigmoid cross-entropy loss;
        TRAIN   -> the same loss plus an Adam training op.

    Raises:
        ValueError: if ``mode`` is not a supported mode key.
    """
    logits = forward(features, params)

    # Handled first because no labels are available when predicting.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions = tf.sigmoid(logits))

    loss_op = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                     labels=labels))

    # Previously this mode fell through and returned None, so the model
    # could not be used with Estimator.evaluate.
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode = mode,
                                          loss = loss_op)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.logging.info('\n'+pprint.pformat(tf.trainable_variables()))

        train_op = tf.train.AdamOptimizer().minimize(
            loss_op, global_step=tf.train.get_global_step())

        return tf.estimator.EstimatorSpec(mode = mode,
                                          loss = loss_op,
                                          train_op = train_op)

    # Fail loudly instead of implicitly returning None.
    raise ValueError('Unsupported mode: {!r}'.format(mode))

In [5]:
# Seeded NumPy generator; train_input_fn draws the corrupted entity ids from it.
random_state = np.random.RandomState(Config.seed)
triples_all, triples_tr, triples_va, triples_te = load_triple()
# The vocabulary is built from all three splits, so entities and predicates
# that occur only in valid/test still receive an id.
e2idx, p2idx, params = build_vocab(triples_all)

# WARN level silences the tf.logging.info call made in model_fn.
tf.logging.set_verbosity(tf.logging.WARN)

model = tf.estimator.Estimator(model_fn,
                               params = params)

# train_input_fn is called once per pass, so each pass draws a fresh set of
# corrupted triples; the input function it returns covers the data once
# (num_epochs = 1).
for n_epoch in tqdm(range(Config.n_epochs), total=Config.n_epochs, ncols=70):
    model.train(train_input_fn(triples_tr,
                               e2idx,
                               p2idx,
                               random_state,
                               params))

Files Already Downloaded


100%|████████████████████████████████| 10/10 [18:08<00:00, 108.83s/it]


In [6]:
# Evaluate the trained model with evaluate_rank (imported from model.metrics)
# on the validation and test triples, using 10x the training batch size.
# NOTE(review): triples_all is presumably what the "Filtered" metrics in the
# printed output are computed against -- confirm in model.metrics.
evaluate_rank(model,
              triples_va,
              triples_te,
              triples_all,
              e2idx,
              p2idx,
              params['e_vocab_size'],
              batch_size = Config.batch_size*10)

100%|████████████████████████████▉| 4996/5000 [02:20<00:00, 35.60it/s]
  0%|                                        | 0/5000 [00:00<?, ?it/s][A
  0%|                                | 4/5000 [00:00<02:18, 36.10it/s][A
  0%|                                | 8/5000 [00:00<02:12, 37.61it/s][A
  0%|                               | 12/5000 [00:00<02:13, 37.43it/s][A
  0%|                               | 16/5000 [00:00<02:12, 37.54it/s][A
  0%|                               | 20/5000 [00:00<02:13, 37.27it/s][A
  0%|▏                              | 24/5000 [00:00<02:14, 36.97it/s][A
  1%|▏                              | 28/5000 [00:00<02:14, 37.02it/s][A
  1%|▏                              | 32/5000 [00:00<02:16, 36.40it/s][A
  1%|▏                              | 36/5000 [00:00<02:17, 36.15it/s][A
  1%|▏                              | 40/5000 [00:01<02:17, 36.02it/s][A
  1%|▎                              | 44/5000 [00:01<02:16, 36.24it/s][A
  1%|▎                              | 48/

  1%|▎                              | 50/5000 [00:02<04:13, 19.52it/s][A
  1%|▎                              | 52/5000 [00:02<04:14, 19.41it/s][A
  1%|▎                              | 54/5000 [00:02<04:15, 19.34it/s][A
  1%|▎                              | 57/5000 [00:02<04:15, 19.37it/s][A
  1%|▎                              | 59/5000 [00:03<04:17, 19.19it/s][A
  1%|▍                              | 61/5000 [00:03<04:18, 19.10it/s][A
  1%|▍                              | 63/5000 [00:03<04:18, 19.11it/s][A
  1%|▍                              | 65/5000 [00:03<04:18, 19.12it/s][A
  1%|▍                              | 68/5000 [00:03<04:17, 19.16it/s][A
  1%|▍                              | 70/5000 [00:03<04:17, 19.17it/s][A
  1%|▍                              | 72/5000 [00:03<04:16, 19.19it/s][A
  2%|▍                              | 75/5000 [00:03<04:16, 19.21it/s][A
  2%|▍                              | 77/5000 [00:04<04:16, 19.22it/s][A
  2%|▍                              | 

[valid] Raw Mean Rank: 828.7368
[valid] Raw Hits@1: 39.53
[valid] Raw Hits@3: 63.92
[valid] Raw Hits@5: 72.67
[valid] Raw Hits@10: 80.39
[valid] Filtered Mean Rank: 816.6368
[valid] Filtered Hits@1: 60.31999999999999
[valid] Filtered Hits@3: 87.25
[valid] Filtered Hits@5: 91.17
[valid] Filtered Hits@10: 93.22


100%|████████████████████████████▉| 4997/5000 [02:31<00:00, 33.00it/s]
  0%|                                        | 0/5000 [00:00<?, ?it/s][A
  0%|                                | 4/5000 [00:00<02:18, 36.16it/s][A
  0%|                                | 8/5000 [00:00<02:14, 37.03it/s][A
  0%|                               | 12/5000 [00:00<02:13, 37.50it/s][A
  0%|                               | 16/5000 [00:00<02:11, 37.82it/s][A
  0%|                               | 20/5000 [00:00<02:13, 37.42it/s][A
  0%|▏                              | 24/5000 [00:00<02:13, 37.39it/s][A
  1%|▏                              | 28/5000 [00:00<02:12, 37.55it/s][A
  1%|▏                              | 32/5000 [00:00<02:12, 37.63it/s][A
  1%|▏                              | 36/5000 [00:00<02:13, 37.16it/s][A
  1%|▏                              | 40/5000 [00:01<02:16, 36.36it/s][A
  1%|▎                              | 44/5000 [00:01<02:16, 36.19it/s][A
  1%|▎                              | 48/

  2%|▌                              | 91/5000 [00:04<04:11, 19.48it/s][A
  2%|▌                              | 94/5000 [00:04<04:11, 19.53it/s][A
  2%|▌                              | 97/5000 [00:04<04:10, 19.59it/s][A
  2%|▌                             | 100/5000 [00:05<04:09, 19.62it/s][A
  2%|▌                             | 103/5000 [00:05<04:08, 19.67it/s][A
  2%|▋                             | 106/5000 [00:05<04:08, 19.71it/s][A
  2%|▋                             | 109/5000 [00:05<04:07, 19.76it/s][A
  2%|▋                             | 112/5000 [00:05<04:07, 19.76it/s][A
  2%|▋                             | 115/5000 [00:05<04:07, 19.76it/s][A
  2%|▋                             | 118/5000 [00:05<04:06, 19.77it/s][A
  2%|▋                             | 121/5000 [00:06<04:06, 19.79it/s][A
100%|█████████████████████████████| 5000/5000 [04:05<00:00, 20.39it/s]


[test] Raw Mean Rank: 783.6124
[test] Raw Hits@1: 39.379999999999995
[test] Raw Hits@3: 63.74999999999999
[test] Raw Hits@5: 72.87
[test] Raw Hits@10: 81.05
[test] Filtered Mean Rank: 771.6143
[test] Filtered Hits@1: 61.050000000000004
[test] Filtered Hits@3: 87.88
[test] Filtered Hits@5: 91.47999999999999
[test] Filtered Hits@10: 93.63
