In [None]:
import json
import os
import sys

import numpy as np
import tensorflow as tf

import nrekit
import nrekit.data_loader
import nrekit.framework
# import nrekit.rl
import nrekit.network.embedding
import nrekit.network.encoder
import nrekit.network.selector
import nrekit.network.classifier
from evaluate_input import evaluate_line

dataset_name = 'nyt'

dataset_dir = os.path.join('./data', dataset_name)
if not os.path.isdir(dataset_dir):
    # FileNotFoundError is an Exception subclass, so existing `except Exception` handlers still catch it.
    raise FileNotFoundError("[ERROR] Dataset dir %s doesn't exist!" % (dataset_dir))

# Files shared by the train and test loaders.
word_vec_path = os.path.join(dataset_dir, 'word_vec.json')
rel2id_path = os.path.join(dataset_dir, 'rel2id.json')

# Positional args: data file, word-embedding file, relation-id mapping file.
# Training uses MODE_RELFACT_BAG bags and is shuffled; testing uses MODE_ENTPAIR_BAG bags in file order.
train_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'),
                                                        word_vec_path,
                                                        rel2id_path,
                                                        mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,
                                                        shuffle=True)
test_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'),
                                                       word_vec_path,
                                                       rel2id_path,
                                                       mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,
                                                       shuffle=False)

framework = nrekit.framework.re_framework(train_loader, test_loader)

class model(nrekit.framework.re_model):
    """Bag-level relation-extraction model assembled from nrekit components.

    The architecture is chosen by the class attributes below; set them (on this
    class or on a subclass) before the model is instantiated.

    Attributes:
        encoder: sentence encoder, one of "pcnn", "cnn", "rnn", "birnn".
        selector: bag-level selector, one of "att", "ave", "one", "cross_max".
    """
    encoder = "pcnn"
    selector = "att"

    def __init__(self, train_data_loader, batch_size, max_length=120):
        """Build the training loss and the train/test logits.

        Args:
            train_data_loader: nrekit data loader forwarded to ``re_model``; its
                ``data_rel`` entries are also counted by ``get_weights``.
            batch_size: batch size forwarded to ``re_model``.
            max_length: padded sentence length; also the width of ``self.mask``.

        Raises:
            NotImplementedError: if ``encoder`` or ``selector`` is not a supported name.
        """
        nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)
        # Per-token int mask consumed by the PCNN encoder (exact value encoding comes
        # from the data loader — TODO confirm).
        self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name="mask")

        # Read the architecture through ``self``, not the module-level name ``model``:
        # the notebook later rebinds ``model`` to an instance, and subclasses that
        # override ``encoder``/``selector`` must be honoured.
        encoder = self.encoder
        selector = self.selector

        # Embedding: word embeddings plus the two position embeddings (pos1, pos2).
        x = nrekit.network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)

        # Encoder: built twice over the same input, with dropout for training
        # (keep_prob=0.5) and without it for testing (keep_prob=1.0).
        if encoder == "pcnn":
            x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)
            x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)
        elif encoder == "cnn":
            x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)
            x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)
        elif encoder == "rnn":
            x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)
            x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)
        elif encoder == "birnn":
            x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)
            x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)
        else:
            raise NotImplementedError("Unsupported encoder: %r" % (encoder,))

        # Selector: turns sentence representations into bag logits. The "att"
        # selector's test logits are used as returned; for the other selectors the
        # test logits are passed through softmax here.
        if selector == "att":
            self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
        elif selector == "ave":
            self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        elif selector == "one":
            self._train_logit, train_repre = nrekit.network.selector.bag_one(x_train, self.scope, self.label, self.rel_tot, True, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_one(x_test, self.scope, self.label, self.rel_tot, False, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        elif selector == "cross_max":
            self._train_logit, train_repre = nrekit.network.selector.bag_cross_max(x_train, self.scope, self.rel_tot, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_cross_max(x_test, self.scope, self.rel_tot, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        else:
            raise NotImplementedError("Unsupported selector: %r" % (selector,))

        # Classifier: cross-entropy on the training logits, weighted per relation.
        self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())

    def loss(self):
        """Return the weighted training loss tensor."""
        return self._loss

    def train_logit(self):
        """Return the bag logits of the training (dropout) branch."""
        return self._train_logit

    def test_logit(self):
        """Return the bag scores of the inference (no-dropout) branch."""
        return self._test_logit

    def get_weights(self):
        """Return the non-trainable per-relation loss weights.

        Relation id ``r`` gets weight ``1 / count(r) ** 0.05``, where ``count(r)`` is
        the number of occurrences of ``r`` in ``train_data_loader.data_rel``, so more
        frequent relations are weighted slightly lower. A relation with no training
        occurrences is counted as 1 (weight 1.0) instead of producing ``inf``.

        Returns:
            float32 variable ``weights_table/weights_table`` with shape ``[rel_tot]``.
            Because the scope uses ``AUTO_REUSE``, an existing variable of that name
            in the graph is returned instead of a new one.

        Raises:
            IndexError: if a relation id in ``data_rel`` is >= ``rel_tot``.
        """
        with tf.variable_scope("weights_table", reuse=tf.AUTO_REUSE):
            print("Calculating weights_table...")
            rel_ids = np.asarray(self.train_data_loader.data_rel, dtype=np.int64)
            counts = np.bincount(rel_ids, minlength=self.rel_tot)
            if counts.shape[0] > self.rel_tot:
                raise IndexError("relation id %d out of range for rel_tot=%d"
                                 % (counts.shape[0] - 1, self.rel_tot))
            # Clamp to 1 so relations absent from training do not divide by zero.
            counts = np.maximum(counts, 1).astype(np.float32)
            _weights_table = 1 / (counts ** 0.05)
            weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)
            print("Finish calculating")
        return weights_table

# Training configuration.
CKPT_DIR = "checkpoint"
MAX_EPOCH = 60
GPU_NUMS = 1

# `model` may already have been rebound to an instance by the inference cell; recover the class.
model_cls = model if isinstance(model, type) else type(model)
model_name = dataset_name + "_" + model_cls.encoder + "_" + model_cls.selector

# Resume only from this model's own checkpoint (<CKPT_DIR>/<model_name>). The
# directory-wide tf.train.latest_checkpoint returns whichever model was saved last,
# which may be a different encoder/selector whose variables do not fit this graph.
# Existence is tested via the ".index" file of TF's default (V2) checkpoint format.
ckpt_prefix = os.path.join(CKPT_DIR, model_name)
ckpt = ckpt_prefix if os.path.exists(ckpt_prefix + ".index") else None

framework.train(model_cls, model_name=model_name, max_epoch=MAX_EPOCH, ckpt_dir=CKPT_DIR, gpu_nums=GPU_NUMS, pretrain_model=ckpt)


Pre-processed files exist. Loading them...
Finish loading
Total relation fact: 1700
Pre-processed files exist. Loading them...
Finish loading
Total relation fact: 617
Start training...
Calculating weights_table...
Finish calculating
  name = %s, shape = %s word_embedding/word_embedding:0 (62073, 50)
  name = %s, shape = %s word_embedding/unk_word_embedding:0 (1, 50)
  name = %s, shape = %s pos_embedding/real_pos1_embedding:0 (240, 5)
  name = %s, shape = %s pos_embedding/real_pos2_embedding:0 (240, 5)
  name = %s, shape = %s pcnn/conv1d/kernel:0 (3, 60, 230)
  name = %s, shape = %s pcnn/conv1d/bias:0 (230,)
  name = %s, shape = %s attention/logit/relation_matrix:0 (4, 690)
  name = %s, shape = %s attention/logit/bias:0 (4,)
INFO:tensorflow:Restoring parameters from ./checkpoint/nyt_pcnn_att
pretrain_model is loaded...
###### Epoch 0 ######
epoch 0 step 0 time 23.66 | loss: 0.004573, not NA accuracy: 1.000000, accuracy: 1.000000
epoch 0 step 1 time 7.01 | loss: 0.008204, not NA accuracy

epoch 1 step 7 time 10.16 | loss: 0.002939, not NA accuracy: 0.979798, accuracy: 0.996875
epoch 1 step 8 time 8.23 | loss: 0.004934, not NA accuracy: 0.981818, accuracy: 0.997222
epoch 1 step 9 time 5.77 | loss: 0.001204, not NA accuracy: 0.983264, accuracy: 0.997500
epoch 1 step 10 time 12.44 | loss: 0.009663, not NA accuracy: 0.980392, accuracy: 0.997159
epoch 1 step 11 time 8.22 | loss: 0.001441, not NA accuracy: 0.982079, accuracy: 0.997396
epoch 1 step 12 time 15.39 | loss: 0.000289, not NA accuracy: 0.983444, accuracy: 0.997596


In [3]:
# Build a batch-size-1 inference copy of the model and restore the trained weights.
# The global `model` is rebound from the class to that instance (the prediction cell
# uses it), so recover the class first; otherwise re-running this cell raises TypeError.
model_cls = model if isinstance(model, type) else type(model)

# Use a dedicated graph so re-running this cell does not keep adding duplicate
# placeholders/ops to the default graph used for training.
inference_graph = tf.Graph()
with inference_graph.as_default():
    model = model_cls(framework.train_data_loader, 1)
    saver = tf.train.Saver()

sess = tf.Session(graph=inference_graph)

ckpt = os.path.join("./checkpoint", dataset_name + "_" + model_cls.encoder + "_" + model_cls.selector)
saver.restore(sess, ckpt)

Calculating weights_table...
Finish calculating
INFO:tensorflow:Restoring parameters from ./checkpoint/nyt_pcnn_att


In [4]:
# Predict the relation between two entity mentions in one raw sentence.
s = "郡县因为他鳏穷，给他每天五升粮食，食物不足，以乞食为生，乞讨不要多。"

head = "粮食"
tail = "食物"

# id -> relation name, inverted from the same rel2id.json the data loaders read,
# so the labels cannot drift out of sync with the model's output indices.
with open(os.path.join(dataset_dir, 'rel2id.json'), encoding='utf-8') as rel2id_file:
    id2rel = {rel_id: rel_name for rel_name, rel_id in json.load(rel2id_file).items()}

# evaluate_line (local helper) builds the feed data for framework.one_step — TODO confirm its keys.
batch = evaluate_line(s, head, tail)

# One bag in, so the score matrix has a single row; index it before argmax to get a
# true scalar (int() on a 1-element array is deprecated in NumPy >= 1.25).
scores = framework.one_step(sess, model, batch, [model.test_logit()])[0]
assert scores.shape[-1] == len(id2rel), "rel2id.json does not match the model's relation count"
predict_label = int(scores[0].argmax())

print(s)
print("head:", head, "tail:", tail)
print("relation:", id2rel[predict_label])

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\leo\AppData\Local\Temp\jieba.cache
Loading model cost 0.888 seconds.
Prefix dict has been built succesfully.


郡县因为他鳏穷，给他每天五升粮食，食物不足，以乞食为生，乞讨不要多。
head: 粮食 tail: 食物
relation: subclass of
