In [25]:
import tensorflow as tf
import os
import numpy as np
import ujson as json
from importlib import reload
from scipy import stats

from func import cudnn_gru, native_gru, dot_attention, summ, ptr_net
from prepro import word_tokenize, convert_idx
import inference

# reload(inference.InfModel)
# reload(inference.Inference)

# R-NET样例测试，输出置信度

In [2]:
tf.reset_default_graph()

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Must be consistant with training
char_limit = 16
hidden = 75
char_dim = 8
char_hidden = 100
use_cudnn = False

# File path
target_dir = "data"
save_dir = "log/model"
word_emb_file = os.path.join(target_dir, "word_emb.json")
char_emb_file = os.path.join(target_dir, "char_emb.json")
word2idx_file = os.path.join(target_dir, "word2idx.json")
char2idx_file = os.path.join(target_dir, "char2idx.json")

infer = inference.Inference()

INFO:tensorflow:Restoring parameters from log/model\model_44000.ckpt


In [91]:
context = "... michael crawford , right , who is ailing , will not return to his award-winning " \
          "role of count fosco in the andrew lloyd webber musical adaptation of the wilkie_collins " \
          "classic , '' the woman in white , '' in london as scheduled on may 2 ."
ques2 = "Where is the birth place of wilkie_collins?"
# ans2 = infer.response(context, ques2)
# print(infer.response(context, ques2))

ans2, confidence1, confidence2 = infer.response(context, ques2)
print("Answer 2: {}".format(ans2))

Answer 2: london


In [92]:
from scipy import stats

In [93]:
print(stats.entropy(confidence1))
print(stats.entropy(confidence2))
print(stats.entropy(np.ones((10000))/10000))
print(stats.entropy(np.ones((100))/100))
print(stats.entropy(np.ones((10))/10))
print(stats.entropy(np.ones((2))/2))

1.2087417
0.7490903
9.210340371976176
4.605170185988092
2.3025850929940455
0.6931471805599453


# 在riedel-NYT数据集上进行验证

In [3]:
import pandas as pd

In [4]:
file_path = 'origin_data/train.txt'
df = pd.read_csv(file_path, sep='\t', header=None, names=['e1_encoding', 'e2_encoding', 'e1', 'e2', 'relation', 'content'])

  interactivity=interactivity, compiler=compiler, result=result)


In [5]:
df.content.head(3)

0    sen. charles e. schumer called on federal safe...
1    but instead there was a funeral , at st. franc...
2    rosemary antonelle , the daughter of teresa l....
Name: content, dtype: object

In [6]:
df.relation = df.relation.fillna('none')
relation_series = df.relation.value_counts()
selected_relation_series = relation_series[relation_series.values > 1000]
relation_list = selected_relation_series.index.values.tolist()

In [7]:
relation_list

['none',
 '/location/location/contains',
 '/people/person/nationality',
 '/location/country/capital',
 '/people/person/place_lived',
 '/location/neighborhood/neighborhood_of',
 '/location/country/administrative_divisions',
 '/location/administrative_division/country',
 '/business/person/company',
 '/people/person/place_of_birth',
 '/people/deceased_person/place_of_death',
 '/business/company/founders']

In [8]:
# 筛选有效数据集
selected_df = df.loc[df['relation'].isin(relation_list)]
selected_df.head(2)

Unnamed: 0,e1_encoding,e2_encoding,e1,e2,relation,content
0,m.0ccvx,m.05gf08,queens,belle_harbor,/location/location/contains,sen. charles e. schumer called on federal safe...
1,m.0ccvx,m.05gf08,queens,belle_harbor,/location/location/contains,"but instead there was a funeral , at st. franc..."


In [11]:
# 人工构造每种关系的问句
relation_to_questions = {
     '/location/location/contains':[
         'Where is <e2> located in?', 
         'Where is <e2>?',
         'Which place contains <e2>?'],
     '/people/person/nationality':['What\'s the nationality of <e1>?'],
     '/location/country/capital':[
         'What\'s the capital of <e2>?', 
         'Where is the capitcal of <e2>?'],
     '/people/person/place_lived': ['Where does <e1> lived in?'],
     '/location/neighborhood/neighborhood_of':[
         'What is the neighborhood of <e1>?', 
         'Where is <e1> next to?', 
         'What place does <e1> adjacent to?'],
     '/location/administrative_division/country':[
         'Which country does <e1> belong to?', 
         'Which country does <e1> located in?'],
     '/location/country/administrative_divisions':[
         'Which country does <e2> belong to?', 
         'Which country does <e2> located in?'],
     '/business/person/company':[
         'Which company does <e1> work for?', 
         'Which company does <e1> join?',
         'Where does <e1> work for?',
         'What\'s the occupation of <e1>?',
         'Which company hires <e1>?'],
     '/people/person/place_of_birth':[
         'Where is the birth place of <e1>?',
         'Where does <e1> born?',
         'Where is the hometown of <e1>?'],
     '/people/deceased_person/place_of_death':[
         'Where did <e1> died?',
         'Where is the place of death of <e1>?'],
     '/business/company/founders':[
         'Who found <e1>?', 
         'Who is the founder of <e1>?',
         'Who starts <e1>?']
}

question_to_relation = {}
# question_to_relation = {q:relation for q in [qlist ]}
for relation, qlist in relation_to_questions.items():
    for q in qlist:
        question_to_relation[q] = relation

In [12]:
selected_df[selected_df.relation=='/business/person/company'].iloc[0]

e1_encoding                                            m.03wt401
e2_encoding                                             m.04v49y
e1                                                   hank_ratner
e2                                                   cablevision
relation                                /business/person/company
content        cablevision 's $ 600 million offer came in the...
Name: 74, dtype: object

In [13]:
selected_df[selected_df.relation=='/people/person/place_lived'].iloc[0].content

"10 p.m. -lrb- mtv -rrb- the hills -- on '' laguna_beach : the real orange county , '' lauren_conrad -lrb- right -rrb- was the emotional blonde who wore every smidgen of excitement or frustration on her face , sometimes both at once . ###END###"

# 任务构造

- 任务一：判定特定关系的结果
- 任务二：对于每种潜在关系，测试不同的响应，从而获取任意可能的关系与关系结果

In [14]:
# 工具函数

def content_prepro(content):
    # 可能需要过滤部分特殊标记
    content = content[:-10]
    return content

In [56]:
# 任务一
def dprint(s):
#     print(s)
    pass

def test_single_relation(relation_name):
    exact_cnt = 0  # 完全匹配
    hit_cnt = 0  # 部分命中
    total_cnt = 1000
    pred_list = []
    truth_list = []
    for idx, row in selected_df[selected_df.relation==relation_name].reset_index().iterrows():

        dprint('=============')
        dprint(idx)
        dprint(row)
        content = content_prepro(row.content)
        dprint('Content=\t' + content)

        best_loss = 100
        best_pred = ''
        truth = ''  # 实际上是一样的
        for q in relation_to_questions[row.relation]:
            # 将问题模板中的实体进行带入
            question = q.replace('<e1>', row.e1).replace('<e2>', row.e2)
            dprint('Q=\t' + question)
            try:
                pred, d1, d2 = infer.response(content, question)  # c1, c2 are confidence of begin and end
                c1, c2 = stats.entropy(d1), stats.entropy(d2)
                loss = c1*c2
                if loss < best_loss:
                    best_loss = loss
                    best_pred = pred
                    truth = str(row.e1 if row.e2 in question else row.e2)
                dprint('pred=' + str(pred) + 
                       '\tTruth=' + truth + 
                       '\tc1=' + str(c1) + '\tc2=' + str(c2))
            except:
                continue
        if truth !='' and (best_pred in truth or truth in best_pred):
            pred_list.append(best_pred)
            truth_list.append(truth)
            hit_cnt += 1
        if idx %10 == 0:
            dprint(idx)
        if idx > total_cnt:
            break
    dprint(hit_cnt)
    dprint(pred_list[:20])
    dprint(truth_list[:20])
    return hit_cnt, total_cnt

In [57]:
for relation in [
    '/location/location/contains',
    '/people/person/nationality',
    '/location/country/capital',
    '/people/person/place_lived',
    '/location/neighborhood/neighborhood_of',
    '/location/country/administrative_divisions',
    '/location/administrative_division/country',
    '/business/person/company',
    '/people/person/place_of_birth',
    '/people/deceased_person/place_of_death',
    '/business/company/founders'
]:
    hit_cnt, total_cnt = test_single_relation(relation)
    print(relation)
    print('Accuracy:' + str(hit_cnt) + ' / ' + str(total_cnt))


/location/location/contains
Accuracy:459 / 1000
/people/person/nationality
Accuracy:458 / 1000
/location/country/capital
Accuracy:105 / 1000
/people/person/place_lived
Accuracy:604 / 1000
/location/neighborhood/neighborhood_of
Accuracy:412 / 1000
/location/country/administrative_divisions
Accuracy:652 / 1000
/location/administrative_division/country
Accuracy:559 / 1000
/business/person/company
Accuracy:209 / 1000
/people/person/place_of_birth
Accuracy:551 / 1000
/people/deceased_person/place_of_death
Accuracy:413 / 1000
/business/company/founders
Accuracy:267 / 1000
