<a href="https://colab.research.google.com/github/ruyncuc/test/blob/master/bert4keras_ner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only magic: select the TensorFlow 1.x runtime before TensorFlow is imported.
%tensorflow_version 1.x
import tensorflow
# Print the active version to confirm the selection (1.15.2 in the recorded output).
print(tensorflow.__version__)

TensorFlow 1.x selected.
1.15.2


In [2]:
# Mount Google Drive at /content/drive; the BERT config/checkpoint/vocab paths
# configured in this notebook point into '/content/drive/My Drive'.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
# Download the People's Daily NER corpus archive and unpack it into the
# current working directory (creates china-people-daily-ner-corpus/).
! wget http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
! tar -zxvf china-people-daily-ner-corpus.tar.gz

--2020-04-18 08:11:43--  http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
Resolving s3.bmio.net (s3.bmio.net)... 52.219.0.67
Connecting to s3.bmio.net (s3.bmio.net)|52.219.0.67|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2443473 (2.3M) [application/x-gzip]
Saving to: ‘china-people-daily-ner-corpus.tar.gz’


2020-04-18 08:11:44 (3.67 MB/s) - ‘china-people-daily-ner-corpus.tar.gz’ saved [2443473/2443473]

./._china-people-daily-ner-corpus
china-people-daily-ner-corpus/
china-people-daily-ner-corpus/._example.dev
china-people-daily-ner-corpus/example.dev
china-people-daily-ner-corpus/._example.train
china-people-daily-ner-corpus/example.train
china-people-daily-ner-corpus/._example.test
china-people-daily-ner-corpus/example.test


In [4]:
# Install bert4keras straight from the default branch of its GitHub repository.
# NOTE(review): the install is unpinned; the recorded output shows 0.7.3 was
# built. Consider pinning to a matching revision so re-runs stay reproducible.
pip install git+https://www.github.com/bojone/bert4keras.git

Collecting git+https://www.github.com/bojone/bert4keras.git
  Cloning https://www.github.com/bojone/bert4keras.git to /tmp/pip-req-build-m4vjypn8
  Running command git clone -q https://www.github.com/bojone/bert4keras.git /tmp/pip-req-build-m4vjypn8
Building wheels for collected packages: bert4keras
  Building wheel for bert4keras (setup.py) ... [?25l[?25hdone
  Created wheel for bert4keras: filename=bert4keras-0.7.3-cp36-none-any.whl size=39430 sha256=4b09136f37142889982cf533eb20e2cfec56aa51fe6d37a635d0637d2f7eeb89
  Stored in directory: /tmp/pip-ephem-wheel-cache-w_e8fbf2/wheels/12/58/83/8ff5c864b80c860e6d9e9e0d90c04fafca05d01d21f9f6fcba
Successfully built bert4keras
Installing collected packages: bert4keras
Successfully installed bert4keras-0.7.3


In [5]:
#! -*- coding: utf-8 -*-
# 用CRF做中文命名实体识别
# 数据集 http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# 实测验证集的F1可以到96.18%，测试集的F1可以到95.35%

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import ConditionalRandomField
from keras.layers import Dense
from keras.models import Model
from tqdm import tqdm

# Pre-trained Chinese BERT (L-12, H-768, A-12) files stored on Google Drive.
config_path = '/content/drive/My Drive/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/content/drive/My Drive/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/content/drive/My Drive/chinese_L-12_H-768_A-12/vocab.txt'

# Sequence / schedule settings.
maxlen = 256
batch_size = 32
epochs = 10

# Model and optimizer settings.
bert_layers = 12
# The fewer BERT layers are used, the larger the learning rate should be.
learing_rate = 1e-5
# Enlarge the learning rate of the CRF layer when necessary.
crf_lr_multiplier = 1000


def load_data(filename):
    """Read a character-level NER file and merge characters into segments.

    The file is expected to hold one ``<char> <tag>`` pair per line, with
    sentences separated by an empty line. A tag is either ``O`` or a
    prefixed entity tag such as ``B-LOC`` / ``I-LOC``.

    Args:
        filename: path of the corpus file; it is read as UTF-8.

    Returns:
        A list of sentences. Each sentence is a list of ``[text, label]``
        pairs, where consecutive characters sharing a segment are joined
        into ``text`` and ``label`` is ``'O'`` or the entity type (the tag
        with its two-character prefix removed).

    Raises:
        ValueError: if a non-blank line contains no space separator.
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        content = f.read()
    # Normalise Windows line endings so the blank-line sentence split works.
    content = content.replace('\r\n', '\n')
    for block in content.split('\n\n'):
        d = []
        for line in block.split('\n'):
            line = line.rstrip()
            # Skip blank lines (e.g. the trailing newline at end of file),
            # which would otherwise fail to unpack into (char, tag).
            if not line:
                continue
            # Split on the LAST space so a literal space character can be
            # the labelled token.
            char, this_flag = line.rsplit(' ', 1)
            label = 'O' if this_flag == 'O' else this_flag[2:]
            # A character extends the previous segment only when it is not
            # an explicit beginning and carries the same label. A stray
            # continuation tag (at sentence start, after 'O', or after a
            # different type) therefore opens a new segment instead of
            # crashing or being glued onto an unrelated segment.
            if this_flag[:1] != 'B' and d and d[-1][1] == label:
                d[-1][0] += char
            else:
                d.append([char, label])
        if d:
            D.append(d)
    return D


# Labelled data
# Load the train / dev / test splits of the People's Daily corpus from the
# corpus directory under /content.
train_data = load_data('/content/china-people-daily-ner-corpus/example.train')
valid_data = load_data('/content/china-people-daily-ner-corpus/example.dev')
test_data = load_data('/content/china-people-daily-ner-corpus/example.test')

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Label mapping: one B id and one I id per entity type, plus id 0 for 'O'.
labels = ['PER', 'LOC', 'ORG']
id2label = {idx: name for idx, name in enumerate(labels)}
label2id = {name: idx for idx, name in enumerate(labels)}
num_labels = 2 * len(labels) + 1


class data_generator(DataGenerator):
    """Batch generator for the NER model.

    Each sample is a list of ``[text, label]`` segments. Every segment is
    tokenized and labelled per token: id 0 for ``'O'``, ``2k + 1`` for the
    first token of entity type ``k`` and ``2k + 2`` for its remaining
    tokens, where ``k = label2id[label]``.
    """
    def __iter__(self, random=False):
        """Yield ``([token_ids, segment_ids], label_ids)`` padded batches.

        Args:
            random: forwarded to ``self.sample`` to control sample order.
        """
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, item in self.sample(random):
            # Start token and its 'O' label.
            token_ids, label_ids = [tokenizer._token_start_id], [0]
            for w, l in item:
                # Drop the start/end tokens that encode() adds per segment.
                w_token_ids = tokenizer.encode(w)[0][1:-1]
                if not w_token_ids:
                    # A segment that yields no tokens must not contribute a
                    # label either; otherwise an entity segment would add a
                    # lone B id and shift labels against token_ids.
                    continue
                if len(token_ids) + len(w_token_ids) >= maxlen:
                    # Segment does not fit (room is kept for the end
                    # token): truncate the sample here.
                    break
                token_ids += w_token_ids
                if l == 'O':
                    label_ids += [0] * len(w_token_ids)
                else:
                    B = label2id[l] * 2 + 1
                    I = label2id[l] * 2 + 2
                    label_ids += ([B] + [I] * (len(w_token_ids) - 1))
            # End token and its 'O' label.
            token_ids += [tokenizer._token_end_id]
            label_ids += [0]
            segment_ids = [0] * len(token_ids)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append(label_ids)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []




Using TensorFlow backend.


In [7]:
# Inspect the first parsed training sample: a list of [text, label] segments.
train_data[0]

[['海钓比赛地点在', 'O'], ['厦门', 'LOC'], ['与', 'O'], ['金门', 'LOC'], ['之间的海域。', 'O']]

In [8]:
# Inspect the first parsed test sample: a list of [text, label] segments.
test_data[0]

[['我们变而以书会友，以书结缘，把', 'O'],
 ['欧', 'LOC'],
 ['美', 'LOC'],
 ['、', 'O'],
 ['港', 'LOC'],
 ['台', 'LOC'],
 ['流行的食品类图谱、画册、工具书汇集一堂。', 'O']]

In [0]:
# The code below targets BERT-type models. If you use ALBERT instead, replace
# the first few lines with:
#
#     model = build_transformer_model(
#         config_path,
#         checkpoint_path,
#         model='albert',
#     )
#     output_layer = 'Transformer-FeedForward-Norm'
#     output = model.get_layer(output_layer).get_output_at(bert_layers - 1)

# Pre-trained transformer.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
)

# Tap the output of the layer selected by `bert_layers`, project it to one
# score per label id and put the CRF layer on top.
output_layer = 'Transformer-%s-FeedForward-Norm' % (bert_layers - 1)
output = model.get_layer(output_layer).output
output = Dense(units=num_labels)(output)
CRF = ConditionalRandomField(lr_multiplier=crf_lr_multiplier)
output = CRF(output)

# Re-bind `model` to the full tagging network.
model = Model(inputs=model.input, outputs=output)
model.summary()

model.compile(
    optimizer=Adam(learing_rate),
    loss=CRF.sparse_loss,
    metrics=[CRF.sparse_accuracy],
)




In [0]:
def viterbi_decode(nodes, trans):
    """Find the highest-scoring label path with the Viterbi algorithm.

    The first position is forced to label 0. The number of labels is taken
    from ``nodes`` itself, and ``nodes`` is left unmodified.

    Args:
        nodes: per-position label scores, shape ``[seq_len, num_labels]``
            (floating point).
        trans: label transition scores, shape ``[num_labels, num_labels]``,
            indexed as ``trans[previous, current]``.

    Returns:
        A 1-D integer array of length ``seq_len`` holding the best label id
        for every position (an empty array for an empty ``nodes``).
    """
    nodes = np.asarray(nodes)
    if len(nodes) == 0:
        return np.zeros(0, dtype=int)
    n_labels = nodes.shape[1]
    labels = np.arange(n_labels).reshape((1, -1))
    # reshape() returns a view of `nodes`; copy it so that masking the
    # first position below does not overwrite the caller's array.
    scores = nodes[0].reshape((-1, 1)).copy()
    scores[1:] -= np.inf  # the first label must be 0
    paths = labels
    for l in range(1, len(nodes)):
        # M[i, j]: best score of a path ending in label i at step l - 1
        # that continues with label j at step l.
        M = scores + trans + nodes[l].reshape((1, -1))
        idxs = M.argmax(0)
        scores = M.max(0).reshape((-1, 1))
        # Keep, for every current label, the best path leading to it.
        paths = np.concatenate([paths[:, idxs], labels], 0)
    return paths[:, scores[:, 0].argmax()]


def named_entity_recognize(text):
    """Recognize named entities in ``text``.

    Args:
        text: the raw input string.

    Returns:
        A list of ``(entity_text, label)`` tuples, where ``entity_text`` is
        the slice of ``text`` covered by the entity's tokens and ``label``
        comes from ``id2label``.
    """
    tokens = tokenizer.tokenize(text)
    # Trim from the tail, keeping the final token, until at most 512 remain.
    while len(tokens) > 512:
        tokens.pop(-2)
    # mapping[i] lists the character positions of `text` covered by token i.
    mapping = tokenizer.rematch(text, tokens)
    token_ids = tokenizer.tokens_to_ids(tokens)
    segment_ids = [0] * len(token_ids)
    nodes = model.predict([[token_ids], [segment_ids]])[0]
    trans = K.eval(CRF.trans)
    labels = viterbi_decode(nodes, trans)
    entities, starting = [], False
    for i, label in enumerate(labels):
        # Tokens without a character span are treated like 'O'.
        # rematch is expected to return an empty list for special tokens
        # (e.g. the closing one, whose label the decoder does not
        # constrain); indexing such a span below would raise IndexError.
        if label > 0 and mapping[i]:
            if label % 2 == 1:
                # Odd id: first token of a new entity.
                starting = True
                entities.append([[i], id2label[(label - 1) // 2]])
            elif starting:
                # Even id: continuation of the entity being built.
                entities[-1][0].append(i)
        else:
            starting = False

    return [
        (text[mapping[w[0]][0]:mapping[w[-1]][-1] + 1], l) for w, l in entities
    ]


def evaluate(data):
    """Score the recognizer on ``data`` at entity level.

    Each sample is a list of ``[text, label]`` segments. Its segments are
    joined back into a sentence, the sentence is passed to
    ``named_entity_recognize`` and the predicted ``(text, label)`` set is
    compared with the gold set of non-'O' segments.

    Returns:
        ``(f1, precision, recall)``. The counters start at 1e-10, so the
        ratios stay defined when nothing is predicted or annotated.
    """
    hits, n_pred, n_gold = 1e-10, 1e-10, 1e-10
    for sample in tqdm(data):
        sentence = ''.join(segment[0] for segment in sample)
        predicted = set(named_entity_recognize(sentence))
        gold = {tuple(segment) for segment in sample if segment[1] != 'O'}
        hits += len(predicted & gold)
        n_pred += len(predicted)
        n_gold += len(gold)
    precision = hits / n_pred
    recall = hits / n_gold
    f1 = 2 * hits / (n_pred + n_gold)
    return f1, precision, recall


class Evaluate(keras.callbacks.Callback):
    """Callback that evaluates after every epoch and keeps the best weights.

    At the end of each epoch it prints the CRF transition matrix, scores
    the model on ``valid_data`` and ``test_data``, and saves the weights to
    './best_model.weights' whenever the validation F1 is at least as good
    as the best one seen so far.
    """
    def __init__(self):
        # Run the base-class constructor so the attributes Keras sets up
        # for callbacks exist before training starts.
        super(Evaluate, self).__init__()
        self.best_val_f1 = 0

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the validation and test sets; checkpoint on improvement."""
        trans = K.eval(CRF.trans)
        print(trans)
        f1, precision, recall = evaluate(valid_data)
        # Keep the best weights (ties overwrite the earlier checkpoint).
        if f1 >= self.best_val_f1:
            self.best_val_f1 = f1
            model.save_weights('./best_model.weights')
        print(
            'valid:  f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )
        f1, precision, recall = evaluate(test_data)
        print(
            'test:  f1: %.5f, precision: %.5f, recall: %.5f\n' %
            (f1, precision, recall)
        )


In [0]:
# Training setup. These two assignments replace the batch_size / epochs
# values configured earlier in the notebook.
batch_size = 10
epochs = 10
evaluator = Evaluate()
train_generator = data_generator(train_data, batch_size)

# Train; `evaluator` runs at the end of every epoch.
history = model.fit_generator(
    train_generator.forfit(),
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    callbacks=[evaluator],
)

In [0]:
# Prediction: run the recognizer on a sample sentence. The cell displays the
# returned list of (entity text, label) tuples. The model is used with
# whatever weights it currently holds; './best_model.weights' is not reloaded.
text = 'MALLET是美国麻省大学（UMASS）阿姆斯特（Amherst）分校开发的一个统计自然语言处理开源软件包'
named_entity_recognize(text)