In [6]:
import sys
sys.path.append(r"D:\work\nlp")
from fennlp.datas import dataloader
import tensorflow as tf
from fennlp.datas.checkpoint import LoadCheckpoint
from fennlp.datas.dataloader import TFWriter, TFLoader
from fennlp.metrics import Metric
from fennlp.metrics.crf import CrfLogLikelihood
from fennlp.models import bert
from fennlp.optimizers import optim
from fennlp.tools import bert_init_weights_from_checkpoint

unable to import 'smart_open.gcs', disabling that module


In [7]:
load_check = LoadCheckpoint(language='zh', is_download=False, file_path=r"D:\work\nlp\tests\NER\NER_ZH\chinese_L-12_H-768_A-12")

In [8]:
param, vocab_file, model_path = load_check.load_bert_param()

In [9]:
param.batch_size = 1
param.maxlen = 64
param.label_size = 7

In [10]:
class BERT_NER(tf.keras.Model):
    def __init__(self, param, **kwargs):
        super(BERT_NER, self).__init__(**kwargs)
        self.batch_size = param.batch_size
        self.maxlen = param.maxlen
        self.label_size = param.label_size
        self.bert = bert.BERT(param)
        self.dense = tf.keras.layers.Dense(self.label_size, activation="relu")
        self.crf = CrfLogLikelihood()

    def call(self, inputs, is_training=True):
        # 数据切分
        input_ids, token_type_ids, input_mask, Y = tf.split(inputs, 4, 0)
        input_ids = tf.cast(tf.squeeze(input_ids, axis=0), tf.int64)
        token_type_ids = tf.cast(tf.squeeze(token_type_ids, axis=0), tf.int64)
        input_mask = tf.cast(tf.squeeze(input_mask, axis=0), tf.int64)
        Y = tf.cast(tf.squeeze(Y, axis=0), tf.int64)
        # 模型构建
        bert = self.bert([input_ids, token_type_ids, input_mask], is_training)
        sequence_output = bert.get_sequence_output()  # batch,sequence,768
        predict = self.dense(sequence_output)
        predict = tf.reshape(predict, [self.batch_size, self.maxlen, -1])
        # 损失计算
        log_likelihood, transition = self.crf(predict, Y, sequence_lengths=tf.reduce_sum(input_mask, 1))
        loss = tf.math.reduce_mean(-log_likelihood)
        predict, viterbi_score = self.crf.crf_decode(predict, transition,
                                                     sequence_length=tf.reduce_sum(input_mask, 1))
        return loss, predict

    def predict(self, inputs, is_training=False):
        loss, predict = self(inputs, is_training)
        return predict

In [6]:
writer = TFWriter(param.maxlen, vocab_file, modes=["test"], input_dir="ner_data", output_dir="ner_data", check_exist=False)

Writing test
Totally use 7 labels!

ner_data\test has been converted into ner_data\test.tfrecords


In [11]:
loader = TFLoader(param.maxlen, 1, input_dir="ner_data")

In [76]:
ds = loader.load_test()
it = ds.__iter__()

In [77]:
inp, seg, mask, Y  = next(it)

In [80]:
inp = tf.cast(inp, tf.int32)

In [1]:
import tensorflow_hub as hub 
import tensorflow as tf

In [2]:
max_seq_length = 64

In [62]:
input_word_ids = tf.keras.layers.Input(
  shape=(max_seq_length,), dtype=tf.int64, name='input_word_ids')
input_mask = tf.keras.layers.Input(
  shape=(max_seq_length,), dtype=tf.int64, name='input_mask')
input_type_ids = tf.keras.layers.Input(
  shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
bert_model = hub.KerasLayer(
  r"D:\work\bert_layer", trainable=True)
pooled_output, sequence_output = bert_model([input_word_ids, input_mask, input_type_ids])



In [74]:
a = tf.constant(value=1, shape=[1, 64], dtype=tf.int32)

In [81]:
bert_model([inp, inp, inp])

[<tf.Tensor: shape=(1, 768), dtype=float32, numpy=
 array([[ 0.8392251 , -0.80132633,  0.98686516, -0.9139344 , -0.5557141 ,
         -0.0104067 ,  0.98447055,  0.8084494 ,  0.05775368, -0.30066064,
          0.7366882 ,  0.9787452 ,  0.76530236, -0.8515802 ,  0.9258017 ,
         -0.8980483 ,  0.40986073, -0.43893543,  0.89892876,  0.01975894,
         -0.53013664, -0.24254909,  0.36107895, -0.94788027,  0.20870344,
          0.77322304, -0.13459991,  0.61280257, -0.14669281,  0.9845042 ,
          0.5732013 ,  0.44861558,  0.44586775, -0.26926103, -0.8440546 ,
         -0.9376162 , -0.88875955,  0.8457493 ,  0.72938055, -0.35184875,
         -0.2374227 ,  0.7819521 ,  0.29069796, -0.53531337,  0.7373173 ,
         -0.78149456, -0.9527479 , -0.86609405,  0.17786375, -0.19581993,
          0.6258323 , -0.8061827 , -0.40804785, -0.81702006, -0.8993888 ,
          0.49208155, -0.06506751, -0.15342882,  0.9794398 ,  0.01716135,
          0.94269305,  0.8772418 , -0.9681446 , -0.97101116, 