In [1]:
import torch
import torch.nn as nn
import math,json
from transformers import AutoModel
import os,re
from tqdm import tqdm
import numpy as np
import time
import transformers
from collections import defaultdict
from transformers import AutoTokenizer
from IPython import embed
# Local directory of the pretrained checkpoint; the same path is passed to
# AutoModel.from_pretrained when the TransferNet encoder is built.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
bert_name='/home/xhsun/Desktop/huggingfaceModels/chinese-roberta-wwm/'
# Tokenizer loaded from that checkpoint directory; used to encode questions.
tokenizer=AutoTokenizer.from_pretrained(bert_name)

# 获取实体、关系、三元组

In [2]:
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
kg_folder='/home/xhsun/Desktop/KG/nlpcc2018/knowledge/small_knowledge/'

# --- Entities ---------------------------------------------------------------
# Only the first tab-separated column (the entity name) is read; ids are
# assigned in insertion order, any id column in the file is ignored.
# NOTE(review): a repeated entity name is re-assigned the current dict size,
# which the next new entity would also receive (silent id collision). The
# assignment rule is kept unchanged because existing checkpoints presumably
# depend on it — verify that entities.dict has no duplicates.
ent2id = {}
with open(os.path.join(kg_folder, 'entities.dict'), encoding='utf-8') as f:
    lines = f.readlines()
for line in tqdm(lines):
    fields = line.strip().split('\t')
    ent2id[fields[0].strip()] = len(ent2id)
id2ent = {eid: ent for ent, eid in ent2id.items()}

# --- Relations --------------------------------------------------------------
# Unlike entities, the relation id is taken from the second column of the file.
rel2id = {}
with open(os.path.join(kg_folder, 'relations.dict'), encoding='utf-8') as f:
    lines = f.readlines()
for line in tqdm(lines):
    fields = line.strip().split('\t')
    rel2id[fields[0].strip()] = int(fields[1])

# --- Triples ----------------------------------------------------------------
# Each line is "subject|||predicate|||object". Lines that have fewer than three
# fields, or that mention an entity/relation missing from the vocabularies,
# are counted in bad_count and skipped. Only those two failure modes are
# caught so that real errors (and Ctrl-C) are no longer swallowed.
triples = []
bad_count = 0
with open(os.path.join(kg_folder, 'small_kb'), encoding='utf-8') as f:
    lines = f.readlines()
for line in tqdm(lines):
    fields = line.strip().split('|||')
    try:
        s = ent2id[fields[0].strip()]
        p = rel2id[fields[1].strip()]
        o = ent2id[fields[2].strip()]
    except (KeyError, IndexError):
        bad_count += 1
        continue
    triples.append((s, p, o))
if bad_count:
    print('skipped {} malformed or out-of-vocabulary triples'.format(bad_count))
# reshape keeps the (N, 3) layout even when no triple survived (N == 0),
# so the column slicing below does not fail on a 1-D tensor.
triples = torch.LongTensor(triples).reshape(-1, 3)

Tsize = len(triples)
Esize = len(ent2id)
num_relations = len(rel2id)

# Sparse 0/1 incidence matrices, one row per triple:
#   Msubj[t, e] = 1 iff entity e is the subject of triple t   -> [Tsize, Esize]
#   Mobj[t, e]  = 1 iff entity e is the object of triple t    -> [Tsize, Esize]
#   Mrel[t, r]  = 1 iff relation r is the predicate of t      -> [Tsize, num_relations]
# torch.sparse_coo_tensor replaces the legacy torch.sparse.FloatTensor constructor.
idx = torch.arange(Tsize)
Msubj = torch.sparse_coo_tensor(
    torch.stack((idx, triples[:, 0])), torch.ones(Tsize, dtype=torch.float32), (Tsize, Esize))
Mobj = torch.sparse_coo_tensor(
    torch.stack((idx, triples[:, 2])), torch.ones(Tsize, dtype=torch.float32), (Tsize, Esize))
Mrel = torch.sparse_coo_tensor(
    torch.stack((idx, triples[:, 1])), torch.ones(Tsize, dtype=torch.float32), (Tsize, num_relations))

100%|██████████████████████████████████████████████████████████████████████████████████████████| 191000/191000 [00:00<00:00, 1853139.17it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3906/3906 [00:00<00:00, 1996703.40it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 369812/369812 [00:00<00:00, 920876.40it/s]


# 加载模型

In [15]:
class TransferNet(nn.Module):
    """Multi-hop KGQA model that propagates entity scores along KG relations.

    A BERT-style encoder embeds the question. For each of ``num_steps`` hops a
    step-specific query attends over the question tokens, a relation
    distribution is predicted from the attended context, and the current
    entity scores are pushed through the knowledge graph (see ``follow``).
    The per-hop entity scores are finally mixed with a learned hop attention.

    Relies on module-level globals defined outside the class: ``bert_name``,
    ``num_relations``, ``Msubj``, ``Mrel`` and ``Mobj``.
    """

    def __init__(self):
        super().__init__()
        # Number of reasoning hops performed in forward().
        self.num_steps = 2
        
        self.bert_encoder = AutoModel.from_pretrained(bert_name, return_dict=True)
        dim_hidden = self.bert_encoder.config.hidden_size

        # One Linear+Tanh query transform per hop. They live in a plain list
        # and are registered explicitly under 'step_encoders_{i}'.
        # NOTE(review): parameter names in the state_dict derive from these
        # registration names, so switching to e.g. nn.ModuleList would
        # presumably break loading of existing checkpoints.
        self.step_encoders = []
        for i in range(self.num_steps):
            m = nn.Sequential(
                nn.Linear(dim_hidden, dim_hidden),
                nn.Tanh()
            )
            self.step_encoders.append(m)
            self.add_module('step_encoders_{}'.format(i), m)

        # Per-hop relation scorer and the hop-mixing head.
        self.rel_classifier = nn.Linear(dim_hidden, num_relations)
        self.hop_selector = nn.Linear(dim_hidden, self.num_steps)


    def follow(self, e, r):
        """Propagate entity scores one hop through the KG.

        Args:
            e: entity scores, [bsz, Esize].
            r: relation scores, [bsz, num_relations].

        Returns:
            Tensor [bsz, Esize]: for every entity, the sum over triples that
            have it as object of (subject score * relation score).
        """
        # Per-triple activation: score of its subject times score of its relation -> [Tsize, bsz]
        x = torch.sparse.mm(Msubj, e.t()) * torch.sparse.mm(Mrel, r.t())
        # Scatter-add each triple's activation onto its object entity.
        return torch.sparse.mm(Mobj.t(), x).t() # [bsz, Esize]

    def forward(self, heads, questions, answers=None, entity_range=None):
        """Run ``num_steps`` hops from ``heads`` and score every entity.

        Args:
            heads: initial entity scores, [bsz, Esize] (callers in this
                notebook pass a one-hot vector for the topic entity).
            questions: tokenizer output; unpacked into the encoder and also
                read for its 'attention_mask' entry.
            answers: target entity vector, [bsz, Esize]; only used, and
                required, when the module is in training mode.
            entity_range: per-entity mask/weight applied inside the loss;
                only used, and required, in training mode.

        Returns:
            In eval mode a dict with 'e_score' ([bsz, Esize]), 'word_attns',
            'rel_probs', 'ent_probs' (lists with one tensor per hop) and
            'hop_attn' ([bsz, num_steps]). In training mode ``{'loss': loss}``.
        """
        q = self.bert_encoder(**questions)
        q_embeddings, q_word_h = q.pooler_output, q.last_hidden_state # (bsz, dim_h), (bsz, len, dim_h)

        # NOTE(review): `device` is assigned but never used below.
        device = heads.device
        last_e = heads
        word_attns = []
        rel_probs = []
        ent_probs = []
        for t in range(self.num_steps):
            # Hop-specific query vector derived from the pooled question embedding.
            cq_t = self.step_encoders[t](q_embeddings) # [bsz, dim_h]
            # Dot-product attention of the query over the token states.
            q_logits = torch.sum(cq_t.unsqueeze(1) * q_word_h, dim=2) # [bsz, max_q]
            q_dist = torch.softmax(q_logits, 1) # [bsz, max_q]
            # Zero out padding positions after the softmax, then renormalise
            # (the 1e-6 guards against division by zero).
            q_dist = q_dist * questions['attention_mask'].float()
            q_dist = q_dist / (torch.sum(q_dist, dim=1, keepdim=True) + 1e-6) # [bsz, max_q]
            word_attns.append(q_dist)
            # Attention-weighted sum of the token states.
            ctx_h = (q_dist.unsqueeze(1) @ q_word_h).squeeze(1) # [bsz, dim_h]

            rel_logit = self.rel_classifier(ctx_h) # [bsz, num_relations]
            # Relations are scored independently with a sigmoid; the softmax
            # variant below was tried and marked as bad by the original author.
            # rel_dist = torch.softmax(rel_logit, 1) # bad
            rel_dist = torch.sigmoid(rel_logit)
            rel_probs.append(rel_dist)

            # Reference (dense) formulation of the hop, kept for documentation:
            # sub, rel, obj = self.triples[:,0], self.triples[:,1], self.triples[:,2]
            # sub_p = last_e[:, sub] # [bsz, #tri]
            # rel_p = rel_dist[:, rel] # [bsz, #tri]
            # obj_p = sub_p * rel_p
            # last_e = torch.index_add(torch.zeros_like(last_e), 1, obj, obj_p)
            last_e = self.follow(last_e, rel_dist) # faster than index_add

            # reshape >1 scores to 1 in a differentiable way
            # m flags entries above 1; z equals the score itself there and 1
            # elsewhere. z is detached, so dividing caps the value at 1 while
            # gradients still flow through the numerator.
            m = last_e.gt(1).float()
            z = (m * last_e + (1-m)).detach()
            last_e = last_e / z

            ent_probs.append(last_e)

        # Mix the per-hop entity scores with a softmax over hops.
        hop_res = torch.stack(ent_probs, dim=1) # [bsz, num_hop, num_ent]
        hop_attn = torch.softmax(self.hop_selector(q_embeddings), dim=1).unsqueeze(2) # [bsz, num_hop, 1]
        last_e = torch.sum(hop_res * hop_attn, dim=1) # [bsz, num_ent]

        if not self.training:
            return {
                'e_score': last_e,
                'word_attns': word_attns,
                'rel_probs': rel_probs,
                'ent_probs': ent_probs,
                'hop_attn': hop_attn.squeeze(2)
            }
        else:
            # Weighted squared error: positions where answers == 1 get weight
            # 100, positions where answers == 0 get weight 1
            # (assumes `answers` is a 0/1 vector — TODO confirm against the
            # training data pipeline).
            weight = answers * 99 + 1
            loss = torch.sum(entity_range * weight * torch.pow(last_e - answers, 2)) / torch.sum(entity_range * weight)

            return {'loss': loss}

# Restore the trained weights onto CPU and put the network in inference mode.
checkpoint_path = '/home/xhsun/Desktop/code/KG/TransferNet-master/save_dir/model.pt'
state_dict = torch.load(checkpoint_path, map_location='cpu')
model = TransferNet()
model.load_state_dict(state_dict)
# Assign the result so the cell does not echo the module repr.
model = model.eval()

Some weights of the model checkpoint at /home/xhsun/Desktop/huggingfaceModels/chinese-roberta-wwm/ were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# 加载测试数据

In [6]:
def convert_tokens_to_ids(topic_entity,question,answer=None,max_length=64):
    """Turn one (topic entity, question[, answer]) example into model inputs.

    The marked mention ``<topic_entity>`` inside ``question`` is replaced by
    the placeholder ``NE`` before tokenization. If the question does not
    contain the entity wrapped in angle brackets it is tokenized unchanged.

    Args:
        topic_entity: entity name; must be a key of the global ``ent2id``.
        question: question text.
        answer: optional answer entity name; must be a key of ``ent2id``
            when given.
        max_length: length the tokenizer pads to (default 64, the value that
            was previously hard-coded).

    Returns:
        ``(head, token_ids)`` when ``answer`` is None, otherwise
        ``(head, token_ids, ans_ids)``. ``head`` and ``ans_ids`` are
        one-element lists of entity ids; ``token_ids`` is whatever the global
        ``tokenizer`` returns.

    Raises:
        KeyError: if ``topic_entity`` or ``answer`` is not in ``ent2id``.
    """
    question = question.replace('<' + topic_entity + '>', 'NE')
    head = [ent2id[topic_entity]]
    token_ids = tokenizer(question.strip(), max_length=max_length, padding='max_length', return_tensors="pt")
    # Identity check: `== None` can be hijacked by objects overriding __eq__
    # (or evaluate elementwise for array-like answers).
    if answer is None:
        return head, token_ids
    ans_ids = [ent2id[answer]]
    return head, token_ids, ans_ids
    
def toOneHot(indices):
    """Return a float32 vector of length ``len(ent2id)`` with ones at ``indices``.

    Args:
        indices: sequence of integer entity ids.

    Returns:
        1-D float tensor, zero everywhere except at the given positions.
    """
    positions = torch.LongTensor(indices)
    encoded = torch.zeros(len(ent2id), dtype=torch.float32)
    # scatter_ writes in place and returns the same tensor.
    return encoded.scatter_(0, positions, 1)

# Load the test set: one example per line, tab-separated fields
# (the evaluation loop unpacks each example as question, answer).
all_examples=[]
with open("/home/xhsun/Desktop/KG/nlpcc2018/knowledge/small_knowledge/test.txt", encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line would become a one-field example and break the
            # later `question, answer = example` unpacking.
            continue
        all_examples.append(line.split('\t'))

# 测试模型

In [7]:
# Top-1 evaluation: predict the highest-scoring entity for every test question
# and count exact string matches with the gold answer.
correct = 0
all_predict_examples = []
for example in tqdm(all_examples):
    question, answer = example
    # The topic entity is the text wrapped in <...> inside the question.
    topic_entity = re.findall(pattern='<(.*)>', string=question)[0]
    head, token_ids = convert_tokens_to_ids(topic_entity, question, answer=None)
    one_hot_head = toOneHot(head)
    with torch.no_grad():
        # Add a batch dimension of size 1 to the head vector.
        result = model(one_hot_head.unsqueeze(0), token_ids)
    e_score = result['e_score']
    scores, idx = e_score.max(dim=1)
    score = scores.tolist()[0]
    predict_id = idx.tolist()[0]
    predict_answer = id2ent[predict_id]
    all_predict_examples.append({'score': score, 'idx': idx, 'predict_answer': predict_answer})

    correct += int(predict_answer == answer)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [02:52<00:00, 11.57it/s]


In [8]:
# Top-1 accuracy: fraction of test questions whose predicted entity equals the gold answer.
correct/len(all_examples)

0.738

In [16]:
# Single-question demo: the question text and its topic entity (a key of ent2id).
question="《碧血剑》的导演"
topic_entity='《碧血剑》'

In [17]:
# Encode the demo question.
# NOTE(review): convert_tokens_to_ids only substitutes 'NE' for the entity when
# it appears wrapped as '<entity>'; the demo question has no angle brackets, so
# it is tokenized unchanged — confirm this matches the intended input format.
head,token_ids=convert_tokens_to_ids(topic_entity,question,answer=None)

In [18]:
# One-hot vector over all entities marking the topic entity.
one_hot_head=toOneHot(head)

In [19]:
# Forward pass without gradient tracking; unsqueeze(0) adds the batch dimension.
with torch.no_grad():
    result=model(*(one_hot_head.unsqueeze(0),token_ids))

In [20]:
# Pick the highest-scoring entity for the single example in the batch.
e_score=result['e_score']
scores,idx=torch.max(e_score,dim=1)
score=scores.tolist()[0]
predict_id=idx.tolist()[0]

In [21]:
# Map the predicted entity id back to its name.
predict_answer=id2ent[predict_id]

In [22]:
# Display the predicted entity name as the cell output.
predict_answer

'张纪中'