简单的tokenizer

英文都会拆到char

基本相当于 `list(inputs_str.lower().replace(' ', '-'))`，但 `[pad]`、`[unk]`、`[cls]`、`[sep]`、`[mask]` 这几个特殊 token 会保持完整，不会被拆成单个字符

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import re
import tensorflow as tf
import unicodedata
from collections import Counter
import random

In [2]:
# Build the token -> id mapping from the vocabulary file: the id of a token
# is its zero-based line number. Tokens are lower-cased so that they match
# the lower-cased text produced by the tokenizer.
word_index = {}
# Tokens of the form "[...]" (e.g. "[pad]", "[unused1]"), in file order.
special = []
# Decode explicitly as UTF-8: the vocabulary contains non-ASCII (Chinese)
# tokens, and the platform default encoding is not UTF-8 everywhere, which
# would either raise UnicodeDecodeError or silently produce wrong keys.
with open('./vocab.txt', encoding='utf-8') as fp:
    for i, line in enumerate(fp):
        line = line.strip().lower()
        word_index[line] = i
        if line.startswith('[') and line.endswith(']'):
            special.append(line)

print(len(special))

104


In [3]:
# Peek at the first three bracketed tokens collected from the vocabulary.
special[:3]

['[pad]', '[unused1]', '[unused2]']

In [4]:
class BertTokenizer(tf.keras.models.Model):
    """Character-level tokenizer built only from TensorFlow string ops.

    Each input string is lower-cased, every whitespace character is replaced
    by '-', and the result is split into single Unicode characters. The
    special tokens listed in ``_SPECIAL_TOKENS`` are kept whole when they
    appear literally in the text. Characters are mapped to ids through a
    static hash table; unknown characters map to the id of ``'[unk]'``.

    The output row layout is ``[cls] tok_1 ... tok_n [sep] [pad]*``.

    Args:
        word_index: dict mapping lower-cased token string -> integer id.
            Must contain ``'[unk]'``. The ids of ``'[pad]'``, ``'[cls]'`` and
            ``'[sep]'`` are read from it, falling back to 0 / 101 / 102 when
            a key is absent.
        max_length: maximum number of ids per row including ``[cls]`` and
            ``[sep]``; longer inputs are truncated. Defaults to 512.
        **args: forwarded to ``tf.keras.models.Model``.

    Raises:
        ValueError: if ``max_length`` is smaller than 2.
        KeyError: if ``word_index`` has no ``'[unk]'`` entry.
    """

    # Special tokens that must survive the per-character split.
    _SPECIAL_TOKENS = ('[pad]', '[unk]', '[cls]', '[sep]', '[mask]')

    def __init__(self, word_index, max_length=512, **args):
        super(BertTokenizer, self).__init__(**args)
        if max_length < 2:
            raise ValueError(
                'max_length must be at least 2 to fit [cls] and [sep], '
                'got {}'.format(max_length))
        # Room left for real tokens once [cls] and [sep] are accounted for.
        self.max_tokens = max_length - 2
        self.construct(word_index)

    def construct(self, word_index):
        """Create the token -> id lookup table and cache special-token ids."""
        keys = tf.constant(list(word_index.keys()), dtype=tf.string)
        values = tf.constant(list(word_index.values()), dtype=tf.int32)
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            # Default value: anything not in the vocabulary becomes [unk].
            tf.constant(word_index['[unk]'], dtype=tf.int32))
        # Read the ids from the vocabulary instead of hard-coding them; the
        # fallbacks are the values that were previously hard-coded.
        self.pad_id = int(word_index.get('[pad]', 0))
        self.cls_id = int(word_index.get('[cls]', 101))
        self.sep_id = int(word_index.get('[sep]', 102))

    @tf.function(experimental_relax_shapes=True)
    def call(self, inputs):
        """Tokenize a batch of strings.

        Args:
            inputs: string tensor of shape ``(batch, 1)``.

        Returns:
            int32 tensor of shape ``(batch, min(longest, max_tokens) + 2)``
            where ``longest`` is the token count of the longest row.
        """
        # Lower-case with Unicode awareness so non-ASCII letters are folded
        # the same way the (Python-lower-cased) vocabulary keys were.
        x = tf.strings.lower(inputs, encoding='utf-8')
        # Whitespace is used as the token separator below, so real
        # whitespace in the text is turned into a visible '-' token first.
        x = tf.strings.regex_replace(x, r'\s', '-')
        # Split into characters and re-join with single spaces.
        x = tf.strings.unicode_split(x, 'UTF-8')
        x = tf.strings.reduce_join(x, separator=' ', axis=-1)
        # Glue the special tokens back together: "[ p a d ]" -> " [pad] ".
        # (Plain Python loop over a constant tuple; unrolled at trace time.)
        for token in self._SPECIAL_TOKENS:
            spaced = r'\[ ' + ' '.join(token[1:-1]) + r' \]'
            x = tf.strings.regex_replace(x, spaced, ' ' + token + ' ')
        x = tf.strings.split(x)

        # Pad with the empty string: whitespace splitting never yields an
        # empty token, so '' unambiguously marks padding. This keeps a
        # literal "[pad]" in the text from being mistaken for padding when
        # the sequence length is computed.
        tokens = x.to_tensor('')
        tokens = tf.squeeze(tokens, 1)
        tokens = tokens[:, :self.max_tokens]
        is_token = tf.not_equal(tokens, '')

        pad = tf.constant(self.pad_id, dtype=tf.int32)
        ids = tf.where(is_token, self.table.lookup(tokens), pad)

        batch = tf.shape(ids)[0]
        cls_col = tf.fill([batch, 1], tf.constant(self.cls_id, dtype=tf.int32))
        # Extra trailing column so the longest row has a slot for [sep].
        pad_col = tf.fill([batch, 1], pad)
        ids = tf.concat([cls_col, ids, pad_col], axis=1)

        # [sep] goes right after the last real token: index n_tokens + 1
        # (the +1 accounts for the leading [cls]).
        sep_col = tf.reduce_sum(tf.cast(is_token, tf.int32), axis=1) + 1
        positions = tf.range(tf.shape(ids)[1])
        is_sep = tf.equal(positions[tf.newaxis, :], sep_col[:, tf.newaxis])
        # Overwrite rather than add, so the result is correct for any pad id.
        return tf.where(
            is_sep, tf.constant(self.sep_id, dtype=tf.int32), ids)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, None)``: the sequence length is data dependent."""
        batch = tf.TensorShape(input_shape)[0]
        return tf.TensorShape([batch, None])

In [5]:
# Instantiate the tokenizer with the vocabulary mapping built earlier.
t = BertTokenizer(word_index)

In [6]:
# Smoke test: two one-string rows of different lengths in a single batch.
t([['aa'], ['a']])

<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[101, 143, 143, 102],
       [101, 143, 102,   0]], dtype=int32)>

In [7]:
# Benchmark input: the two sample rows repeated 250 times (500 rows).
x = tf.constant([['aa'], ['a']] * 250)

In [8]:
%%timeit
# Time one tokenizer call on the benchmark batch `x`.
t(x)

1.11 ms ± 41.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [9]:
# Directory the tokenizer model is exported to.
save_path = 'bert_tokenizer'

# Earlier experiment kept for reference (note: `BertIds` is not defined in
# this notebook).
# model = tf.keras.Sequential([
#     BertTokenizer(word_index, name='bert_tokenizer'),
#     BertIds(word_index, name='bert_token_to_ids'),
# ])
# model.save(save_path, include_optimizer=False)
model = BertTokenizer(word_index)
# model._set_inputs(tf.keras.backend.placeholder((None, None), dtype=tf.int32))
# NOTE(review): `_set_inputs` is a private Keras API; presumably it fixes the
# model's input signature to a (batch, 1) string tensor so that `save` can
# export it without the model having been called - confirm against the
# installed TensorFlow version.
model._set_inputs(tf.keras.backend.placeholder((None, 1), dtype='string'))
model.save(save_path, include_optimizer=False)

INFO:tensorflow:Assets written to: bert_tokenizer/assets


In [10]:
# Print the Keras summary of the tokenizer model.
model.summary()

Model: "bert_tokenizer_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [11]:
# Load the model back from the directory it was just saved to.
m2 = tf.keras.models.load_model(save_path)



In [12]:
# Sample batch of one-string rows: Chinese characters, special tokens in
# lower and upper case, an embedded space, and a plain English word.
text = [
    ['我爱你[mask] [mask]哦'],
    ['我[unk][MASK]'],
    ['important']
]
vec = tf.constant(text)

In [13]:
# Run the reloaded model on the sample batch.
m2(vec)

<tf.Tensor: shape=(3, 11), dtype=int32, numpy=
array([[ 101, 2769, 4263,  872,  103,  118,  103, 1521,  102,    0,    0],
       [ 101, 2769,  100,  103,  102,    0,    0,    0,    0,    0,    0],
       [ 101,  151,  155,  158,  157,  160,  162,  143,  156,  162,  102]],
      dtype=int32)>