In [1]:
# coding=utf-8
import torch
import os
import datetime
import unicodedata
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset,DataLoader
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from evaluating import Metrics
from torch.autograd import Variable
import ipdb
import time
import json
from datetime import timedelta

In [2]:
class Config(object):
    base_epoch = 1
    batch_size = 16
    max_length = 110
    require_improvement = 1000  
    bert_embedding = 768
    rnn_hidden = 100
    tagset_size = 10
    bert_path = './data/bert-base-chinese'
    rnn_layers = 1
    dropout1 = 0.5
    dropout_ratio = 0.5
    # learning_rate = 1e-5
    lr = 0.0001
    lr_decay = 0.00001
    weight_decay = 0.00005
    use_cuda = True
    optim = 'Adam'
    save_path = './models/Bert_BiLSTM_CRF.ckpt'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_path = 'data/all_data.txt'


In [3]:
def build_corpus(data_path,make_word2id =True):
    word_lists = []
    tag_lists = []
    with open(data_path,'r',encoding='utf-8') as f:
        word_list = []
        tag_list = []
        
        for line in f:
            if line !='\n':
                line = line.strip('\n').split()
                if len(line) < 2:
                    continue
                word,tag = line[0],line[1]

                word_list.append(word)
                tag_list.append(tag)
            else:
                word_lists.append(word_list)
                tag_lists.append(tag_list)
                word_list = []
                tag_list = []
    def build_map(lists):
        maps = {}
        for sent in lists:
            for word in sent:
                if word not in maps:
                    maps[word]=len(maps)
        return maps

    if make_word2id:
        word2id = build_map(word_lists)
        tag2id = build_map(tag_lists)
        return word_lists,tag_lists,word2id,tag2id
    else:
        return word_lists,tag_lists

In [4]:
def pad_data(word_lists,tag_lists,sent_length):
    for i in range(len(word_lists)):
        word_lists[i] = ['[CLS]'] + word_lists[i] + ['[SEP]']
        word_lists[i] = word_lists[i] + (sent_length - len(word_lists[i]))*['<pad>']

        tag_lists[i] = ['<start>'] + tag_lists[i] + ['<eos>']
        tag_lists[i] = tag_lists[i] + (sent_length - len(tag_lists[i]))*['<pad>']
    return word_lists,tag_lists

In [5]:
def extend_maps(word2id, tag2id):
    word2id['<unk>'] = len(word2id)
    word2id['<pad>'] = len(word2id)
    word2id['[CLS]'] = len(word2id)
    word2id['[SEP]'] = len(word2id)

    tag2id['<pad>'] = len(tag2id)
    tag2id['<start>'] = len(tag2id)
    tag2id['<eos>'] = len(tag2id)
    return word2id, tag2id

In [6]:
#LSTM模型 工具函数
def tensorized(data,maps):
    UNK = maps.get('<unk>')
    PAD = maps.get('<pad>')

    max_len = len(data[0])
    sent_length = len(data)
    data_tensor = torch.ones(sent_length,max_len).long()*PAD
    for i,sen in enumerate(data):
        for j,word in enumerate(sen):
            data_tensor[i][j] = maps.get(word,UNK)
    return data_tensor

In [7]:
def get_mask_tensor(data):
    max_len = len(data[0])
    sent_length = len(data)

    mask_tensor = torch.ones(sent_length,max_len).long()
    for i,sen in enumerate(data):
        for j, word in enumerate(sen):
            if word == '<pad>':
                mask_tensor[i][j] = 0
    return mask_tensor

In [8]:
class BERT_LSTM_CRF(nn.Module):
    """
    bert_lstm_crf model
    """
    def __init__(self, config):
        super(BERT_LSTM_CRF, self).__init__()
        self.embedding_dim = config.bert_embedding
        self.hidden_dim = config.rnn_hidden
        self.bert = BertModel.from_pretrained(config.bert_path)
        self.lstm = nn.LSTM(config.bert_embedding, config.rnn_hidden,
                            num_layers=config.rnn_layers, bidirectional=True, dropout=config.dropout_ratio, batch_first=True)
        self.rnn_layers = config.rnn_layers
        self.dropout1 = nn.Dropout(p=config.dropout1)
        self.crf = CRF(target_size=config.tagset_size, average_batch=True, use_cuda=config.use_cuda)
        self.liner = nn.Linear(config.rnn_hidden*2, config.tagset_size + 2)
        self.tagset_size = config.tagset_size

    def rand_init_hidden(self, batch_size):
        """
        random initialize hidden variable
        """
        hidden_state = torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)
        cell_state = torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)
        return hidden_state,cell_state
        
    def forward(self, sentence, attention_mask=None):
        '''
        args:
            sentence (word_seq_len, batch_size) : word-level representation of sentence
            hidden: initial hidden state

        return:
            crf output (word_seq_len, batch_size, tag_size, tag_size), hidden
        '''
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds, _ = self.bert(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        # hidden = self.rand_init_hidden(batch_size)
        # if embeds.is_cuda:
        #     hidden = (i.cuda() for i in hidden)
        lstm_out, hidden = self.lstm(embeds)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim*2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.liner(d_lstm_out)
        lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1)
        return lstm_feats

    def loss(self, feats, mask, tags):
        """
        feats: size=(batch_size, seq_len, tag_size)
            mask: size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)
        :return:
        """
        loss_value = self.crf.neg_log_likelihood_loss(feats, mask, tags)
        batch_size = feats.size(0)
        loss_value /= float(batch_size)
        return loss_value

In [9]:
def log_sum_exp(vec, m_size):
    """
    Args:
        vec: size=(batch_size, vanishing_dim, hidden_dim)
        m_size: hidden_dim

    Returns:
        size=(batch_size, hidden_dim)
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(
        torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)

In [10]:
class CRF(nn.Module):

    def __init__(self, **kwargs):
        """
        Args:
            target_size: int, target size
            use_cuda: bool, 是否使用gpu, default is True
            average_batch: bool, loss是否作平均, default is True
        """
        super(CRF, self).__init__()
        for k in kwargs:
            self.__setattr__(k, kwargs[k])
        self.START_TAG_IDX, self.END_TAG_IDX = -2, -1
        init_transitions = torch.zeros(self.target_size+2, self.target_size+2)
        init_transitions[:, self.START_TAG_IDX] = -1000.
        init_transitions[self.END_TAG_IDX, :] = -1000.
        if self.use_cuda:
            init_transitions = init_transitions.cuda()
        self.transitions = nn.Parameter(init_transitions)

    def _forward_alg(self, feats, mask=None):
        """
        Do the forward algorithm to compute the partition function (batched).

        Args:
            feats: size=(batch_size, seq_len, self.target_size+2)
            mask: size=(batch_size, seq_len)

        Returns:
            xxx
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(-1)

        mask = mask.transpose(1, 0).contiguous()
        ins_num = batch_size * seq_len
        feats = feats.transpose(1, 0).contiguous().view(
            ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

        scores = feats + self.transitions.view(
            1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
        seq_iter = enumerate(scores)
        try:
            _, inivalues = seq_iter.__next__()
        except:
            _, inivalues = seq_iter.next()

        partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
        for idx, cur_values in seq_iter:
            cur_values = cur_values + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            cur_partition = log_sum_exp(cur_values, tag_size)
            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
            masked_cur_partition = cur_partition.masked_select(mask_idx.byte())
            if masked_cur_partition.dim() != 0:
                mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
                partition.masked_scatter_(mask_idx.byte(), masked_cur_partition)
        cur_values = self.transitions.view(1, tag_size, tag_size).expand(
            batch_size, tag_size, tag_size) + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        cur_partition = log_sum_exp(cur_values, tag_size)
        final_partition = cur_partition[:, self.END_TAG_IDX]
        return final_partition.sum(), scores

    def _viterbi_decode(self, feats, mask=None):
        """
        Args:
            feats: size=(batch_size, seq_len, self.target_size+2)
            mask: size=(batch_size, seq_len)

        Returns:
            decode_idx: (batch_size, seq_len), viterbi decode结果
            path_score: size=(batch_size, 1), 每个句子的得分
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(-1)

        length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
        mask = mask.transpose(1, 0).contiguous()
        ins_num = seq_len * batch_size
        feats = feats.transpose(1, 0).contiguous().view(
            ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

        scores = feats + self.transitions.view(
            1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        seq_iter = enumerate(scores)
        # record the position of the best score
        back_points = list()
        partition_history = list()
        mask = (1 - mask.long()).byte()
        try:
            _, inivalues = seq_iter.__next__()
        except:
            _, inivalues = seq_iter.next()
        partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
        partition_history.append(partition)

        for idx, cur_values in seq_iter:
            cur_values = cur_values + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            partition, cur_bp = torch.max(cur_values, 1)
            partition_history.append(partition.unsqueeze(-1))

            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
            back_points.append(cur_bp)

        partition_history = torch.cat(partition_history).view(
            seq_len, batch_size, -1).transpose(1, 0).contiguous()

        last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
        last_partition = torch.gather(
            partition_history, 1, last_position).view(batch_size, tag_size, 1)

        last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
            self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
        _, last_bp = torch.max(last_values, 1)
        pad_zero = Variable(torch.zeros(batch_size, tag_size)).long()
        if self.use_cuda:
            pad_zero = pad_zero.cuda()
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)

        pointer = last_bp[:, self.END_TAG_IDX]
        insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
        back_points = back_points.transpose(1, 0).contiguous()

        back_points.scatter_(1, last_position, insert_last)

        back_points = back_points.transpose(1, 0).contiguous()

        decode_idx = Variable(torch.LongTensor(seq_len, batch_size))
        if self.use_cuda:
            decode_idx = decode_idx.cuda()
        decode_idx[-1] = pointer.data
        for idx in range(len(back_points)-2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
            decode_idx[idx] = pointer.view(-1).data
        path_score = None
        decode_idx = decode_idx.transpose(1, 0)
        return path_score, decode_idx

    def forward(self, feats, mask=None):
        path_score, best_path = self._viterbi_decode(feats, mask)
        return path_score, best_path

    def _score_sentence(self, scores, mask, tags):
        """
        Args:
            scores: size=(seq_len, batch_size, tag_size, tag_size)
            mask: size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)

        Returns:
            score:
        """
        batch_size = scores.size(1)
        seq_len = scores.size(0)
        tag_size = scores.size(-1)

        new_tags = Variable(torch.LongTensor(batch_size, seq_len))
        if self.use_cuda:
            new_tags = new_tags.cuda()
        for idx in range(seq_len):
            if idx == 0:
                new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
            else:
                new_tags[:, idx] = tags[:, idx-1] * tag_size + tags[:, idx]

        end_transition = self.transitions[:, self.END_TAG_IDX].contiguous().view(
            1, tag_size).expand(batch_size, tag_size)
        length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
        end_ids = torch.gather(tags, 1, length_mask-1)

        end_energy = torch.gather(end_transition, 1, end_ids)

        new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(
            seq_len, batch_size)
        tg_energy = tg_energy.masked_select(mask.transpose(1, 0))

        gold_score = tg_energy.sum() + end_energy.sum()

        return gold_score

    def neg_log_likelihood_loss(self, feats, mask, tags):
        """
        Args:
            feats: size=(batch_size, seq_len, tag_size)
            mask: size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)
        """
        batch_size = feats.size(0)
        mask = mask.byte()
        forward_score, scores = self._forward_alg(feats, mask)
        gold_score = self._score_sentence(scores, mask, tags)
        if self.average_batch:
            return (forward_score - gold_score) / batch_size
        return forward_score - gold_score


In [11]:
config = Config
#读取数据，划分训练集，验证集测试集
word_lists,tag_lists,word2id,tag2id = build_corpus(config.data_path,make_word2id =True)
train_word_lists,train_tag_lists = word_lists[:800000],tag_lists[:800000]
dev_word_lists,dev_tag_lists = word_lists[800000:900000],tag_lists[800000:900000]
test_word_lists,test_tag_lists = word_lists[900000:],tag_lists[900000:]

#先把word2id,tag2id保存
bert_word2id,bert_tag2id = extend_maps(word2id, tag2id)
json.dump(bert_word2id,open('data/bert_word2id.txt','w'))
json.dump(bert_tag2id,open('data/bert_tag2id.txt','w'))

#在句子首尾分别加上[cls],[sep],并把每条数据用0补全至max_length
train_word_lists,train_tag_lists = pad_data(train_word_lists,train_tag_lists,Config.max_length)
dev_word_lists,dev_tag_lists = pad_data(dev_word_lists,dev_tag_lists,Config.max_length)
test_word_lists,test_tag_lists = pad_data(test_word_lists,test_tag_lists,Config.max_length)

#将每个字，标签转化为索引，并将数据格式转化为tensor
train_data_tensor,train_tag_tensor = tensorized(train_word_lists,word2id),tensorized(train_tag_lists,tag2id)
dev_data_tensor,dev_tag_tensor = tensorized(dev_word_lists,word2id),tensorized(dev_tag_lists,tag2id)
test_data_tensor,test_tag_tensor = tensorized(test_word_lists,word2id),tensorized(test_tag_lists,tag2id)

#获取mask_tensor,例如：word_list:['我','在','成','都','[<pad>]','[<pad>]'],mask_tensor:tensor([1,1,1,1,0,0])
train_mask_tensor = get_mask_tensor(train_word_lists)
dev_mask_tensor = get_mask_tensor(dev_word_lists)
test_mask_tensor = get_mask_tensor(test_word_lists)

#把数据制作成data_loader
train_dataset = TensorDataset(train_data_tensor, train_mask_tensor, train_tag_tensor)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size)

dev_dataset = TensorDataset(dev_data_tensor, dev_mask_tensor, dev_tag_tensor)
dev_loader = DataLoader(dev_dataset, shuffle=True, batch_size=config.batch_size)

test_dataset = TensorDataset(test_data_tensor, test_mask_tensor, test_tag_tensor)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=config.batch_size)

In [28]:
def get_time_dif(start_time):
    """获取已使用时间"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

In [29]:
model = BERT_LSTM_CRF(config)
optimizer = optim.Adam(model.parameters(),lr=config.lr,weight_decay=config.weight_decay)
eval_loss = 10000
last_improve = 0
require_improvement = 1000
flag = False
step = 0
start_time = time.time()
model.train()

for epoch in range(config.base_epoch):
    for i,batch in enumerate(train_loader):
        step +=1
        model.zero_grad()
        inputs_id,masks_attention,tags_id = batch
        # inputs_id,masks_attention,tags_id = Variable(inputs_id),Variable(masks_attention),Variable(tags_id)
        if config.use_cuda:
            inputs_id,masks_attention,tags_id = inputs_id.cuda(),masks_attention.cuda(),tags_id.cuda()
        feats = model(inputs_id,masks_attention)
        loss = model.loss(feats,masks_attention,tags_id)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            loss_temp = dev(model,dev_loader,epoch,config)
            if loss_temp < eval_loss:
                eval_loss = loss_temp
                torch.save(model,config.save_path)
                last_improve = step
                improve = '*'
            else:
                improve = ''
            time_dif = get_time_dif(start_time)
            print('step: {} | epoch: {} | train_loss: {:5.4} | dev_loss:{:5.4} | improve:{} | time:{}'.format(step,epoch+1,loss.item(),loss_temp,improve,time_dif))
        if step - last_improve > require_improvement:
            # 验证集loss超过1000batch没下降，结束训练
            print("No optimization for a long time, auto-stopping...")
            flag = True
            break
    if flag:
        break


step: 100 | epoch: 1 | train_loss: 1.623 | dev_loss:1.552 | improve:* | time:0:26:19
step: 200 | epoch: 1 | train_loss: 0.8484 | dev_loss:0.8132 | improve:* | time:0:51:57
step: 300 | epoch: 1 | train_loss: 0.3706 | dev_loss:0.564 | improve:* | time:1:17:43
step: 400 | epoch: 1 | train_loss: 0.4235 | dev_loss:0.4644 | improve:* | time:1:43:43
step: 500 | epoch: 1 | train_loss: 0.433 | dev_loss:0.3855 | improve:* | time:2:10:23
step: 600 | epoch: 1 | train_loss: 0.3956 | dev_loss:0.3631 | improve:* | time:2:35:57
step: 700 | epoch: 1 | train_loss: 0.362 | dev_loss:0.3216 | improve:* | time:3:01:35
step: 800 | epoch: 1 | train_loss: 0.4344 | dev_loss:0.3103 | improve:* | time:3:27:13
step: 900 | epoch: 1 | train_loss: 0.2897 | dev_loss:0.2722 | improve:* | time:3:52:46
step: 1000 | epoch: 1 | train_loss: 0.3525 | dev_loss:0.2689 | improve:* | time:4:18:18
step: 1100 | epoch: 1 | train_loss: 0.209 | dev_loss:0.2608 | improve:* | time:4:44:01
step: 1200 | epoch: 1 | train_loss: 0.1903 | de

In [29]:
def dev(model,dev_loader,epoch,config):
    model.eval()
    eval_loss = 0
    true = []
    pred = []
    length = 0
    with torch.no_grad():
        for i,batch in enumerate(dev_loader):
            inputs,masks,tags = batch
            length += 1
            # inputs,masks,tags = Variable(inputs),Variable(masks),Variable(tags)
            if config.use_cuda:
                inputs,masks,tags = inputs.cuda(),masks.cuda(),tags.cuda()
            feats = model(inputs,masks)
            path_score,best_path = model.crf(feats,masks.byte())
            loss = model.loss(feats,masks,tags)
            eval_loss += loss.item()
        # pred.extend([t for t in  best_path])
        # true.extend([t for t in tags])
    # pred = flatten_lists(pred)
    # true = flatten_lists(true)
    # metrics = Metrics(true_list, predict_list, remove_O=False)
    # print('eval epoch:{}| loss:{}'.format(epoch,eval_loss/length))
    model.train()
    return eval_loss/length

In [29]:
def test(config,test_data):
    model = torch.load(config.save_path)
    model.eval()
    true = []
    pred = []
    with torch.no_grad():
        for i,batch in enumerate(test_data):
            inputs,masks,tags = batch
            if config.use_cuda:
                inputs,masks,tags = inputs.cuda(),masks.cuda(),tags.cuda()
            feats = model(inputs,masks)
            path_score,best_path = model.crf(feats,masks.byte())
            pred.extend([t.cpu().numpy() for t in  best_path])
            true.extend([t.cpu().numpy() for t in tags])
    return pred,true

In [30]:
pred,true = test(config,test_loader)

In [31]:
#预测数据和真实数据在[sep]以前都是相同的，但是在[sep]以后就不同了，具体原因不太清楚，不过[sep]后面都是[pad]不太重要，我们截取[sep]前面的信息就可以了，在tag2id里面，[sep]对应的是9
pred_lists = []
true_lists = []
for i in range(len(pred)):
    pred_list = []
    true_list = []
    for j in range(len(pred[i])):
        if pred[i][j] != 9:
            pred_list.append(pred[i][j])
            true_list.append(true[i][j])
        else:
            pred_list.append(pred[i][j])
            true_list.append(true[i][j])
            break
    pred_lists.append(pred_list)
    true_lists.append(true_list)

In [30]:
def flatten_lists(lists):
    flatten_list = []
    for l in lists:
        if type(l) == list:
            flatten_list += l
        else:
            flatten_list.append(l)
    return flatten_list

In [32]:
pred_lists = flatten_lists(pred_lists)
true_lists = flatten_lists(true_lists)

In [33]:
id2tag = {}
for key,value in tag2id.items():
    id2tag[value] = key

for i in range(len(pred_lists)):
    pred_lists[i] = id2tag[pred_lists[i]]
    true_lists[i] = id2tag[true_lists[i]]

In [34]:
metrics = Metrics(true_lists, pred_lists, remove_O=False)
metrics.report_scores()
metrics.report_confusion_matrix()

           precision    recall  f1-score   support
    B-ORG     0.9525    0.9444    0.9484     52071
    B-LOC     0.9826    0.9989    0.9907     11642
        O     0.9936    0.9929    0.9932    791286
    <eos>     1.0000    1.0000    1.0000     22159
  <start>     1.0000    1.0000    1.0000     22159
    I-LOC     0.9985    0.9980    0.9983      2019
    E-LOC     0.9824    0.9987    0.9905     11642
    I-ORG     0.9573    0.9744    0.9658     59220
    E-ORG     0.9574    0.9493    0.9534     52071
avg/total     0.9876    0.9876    0.9876   1024269

Confusion Matrix:
          B-ORG   B-LOC       O   <eos> <start>   I-LOC   E-LOC   I-ORG   E-ORG 
  B-ORG   49178       3    2236       0       0       0       0     654       0 
  B-LOC       0   11629      13       0       0       0       0       0       0 
      O    1984     201  785629       0       1       1     205    1435    1830 
  <eos>       0       0       0   22159       0       0       0       0       0 
<start>       0

In [11]:
config = Config
bert_word2id = json.load(open('data/bert_word2id.txt'))
bert_tag2id = json.load(open('data/bert_tag2id.txt'))

In [14]:
def extract_entity(sent,bert_word2id,bert_tag2id,config):
    model = torch.load(config.save_path)
    sent_list = []
    for i in range(len(sent)):
        sent_list.append(sent[i])
    sent_list,_ = pad_data([sent_list],[[]],config.max_length)
    input_ids = tensorized(sent_list,bert_word2id).cuda()
    masked_data = get_mask_tensor(sent_list).cuda()

    feats = model(input_ids,masked_data)
    path_score,best_path = model.crf(feats,masked_data.byte())
    pred_list = best_path[0].cpu().numpy()

    id2tag = {}
    for key,value in bert_tag2id.items():
        id2tag[value] = key
    pred_tag_lists = []
    for i in range(len(pred_list)):
        pred_tag_lists.append(id2tag[pred_list[i]])


    entity_idxs = []
    for i in range(len(pred_tag_lists)):
        entity_idx = []
        if pred_tag_lists[i]=='B-ORG':
            entity_idx.append(i-1)
            for j in range(i+1,len(pred_tag_lists)):
                if pred_tag_lists[j]=='E-ORG':
                    entity_idx.append(j-1)
                    break
        
        if pred_tag_lists[i]=='B-LOC':
            entity_idx.append(i-1)
            for j in range(i+1,len(pred_tag_lists)):
                if pred_tag_lists[j]=='E-LOC':
                    entity_idx.append(j-1)
                    break

        if entity_idx!=[]:
            entity_idxs.append(entity_idx)
    
    words_list = []
    for pos in entity_idxs:
        words = sent[pos[0]:(pos[1]+1)]
        words_list.append(words)
    return list(set(words_list))  

In [15]:
# sent = '我人在摩西，最近想在割双眼皮'
sent = '我人在三京，三点式双眼皮在北京做'
extract_entity(sent,bert_word2id,bert_tag2id,config)

['三京', '北京', '双眼皮']