In [40]:
import numpy as np
import tensorflow as tf
import tensorflow_text as text
import functools

In [2]:
examples = {
    "text_a": [
      "Sponge bob Squarepants is an Avenger",
      "Marvel Avengers"
    ],
    "text_b": [
     "Barack Obama is the President.",
     "President is the highest office"
  ],
}

dataset = tf.data.Dataset.from_tensor_slices(examples)

next(iter(dataset))

{'text_a': <tf.Tensor: shape=(), dtype=string, numpy=b'Sponge bob Squarepants is an Avenger'>,
 'text_b': <tf.Tensor: shape=(), dtype=string, numpy=b'Barack Obama is the President.'>}

In [6]:
_VOCAB = [
    # Special tokens
    b"[UNK]", b"[MASK]", b"[RANDOM]", b"[CLS]", b"[SEP]",
    # Suffixes
    b"##ack", b"##ama", b"##ger", b"##gers", b"##onge", b"##pants",  b"##uare",
    b"##vel", b"##ven", b"an", b"A", b"Bar", b"Hates", b"Mar", b"Ob",
    b"Patrick", b"President", b"Sp", b"Sq", b"bob", b"box", b"has", b"highest",
    b"is", b"office", b"the",
]

_START_TOKEN = _VOCAB.index(b"[CLS]")
_END_TOKEN = _VOCAB.index(b"[SEP]")
_MASK_TOKEN = _VOCAB.index(b"[MASK]")
_RANDOM_TOKEN = _VOCAB.index(b"[RANDOM]")
_UNK_TOKEN = _VOCAB.index(b"[UNK]")
_MAX_SEQ_LEN = 8
_MAX_PREDICTIONS_PER_BATCH = 5

_VOCAB_SIZE = len(_VOCAB)

In [8]:
lookup_table = tf.lookup.StaticVocabularyTable(
    tf.lookup.KeyValueTensorInitializer(keys=_VOCAB,
                                        key_dtype=tf.string,
                                        values=tf.range(tf.size(_VOCAB, out_type=tf.int64), dtype=tf.int64),
                                        value_dtype=tf.int64),
    num_oov_buckets=1)

In [17]:
bertTokenizer=text.BertTokenizer(lookup_table, token_out_type=tf.string)
bertTokenizer.tokenize(examples['text_a'])

<tf.RaggedTensor [[[b'Sp', b'##onge'], [b'bob'], [b'Sq', b'##uare', b'##pants'], [b'is'],
  [b'an'], [b'A', b'##ven', b'##ger']]                                  ,
 [[b'Mar', b'##vel'], [b'A', b'##ven', b'##gers']]]>

In [19]:
bertTokenizer.tokenize(examples['text_b'])

<tf.RaggedTensor [[[b'Bar', b'##ack'], [b'Ob', b'##ama'], [b'is'], [b'the'], [b'President'],
  [b'[UNK]']]                                                              ,
 [[b'President'], [b'is'], [b'the'], [b'highest'], [b'office']]]>

In [20]:
bertTokenizer=text.BertTokenizer(lookup_table, token_out_type=tf.int64)
segment_a = bertTokenizer.tokenize(examples['text_a'])
segment_b=bertTokenizer.tokenize(examples['text_b'])
segment_a, segment_b

(<tf.RaggedTensor [[[22, 9], [24], [23, 11, 10], [28], [14], [15, 13, 7]],
  [[18, 12], [15, 13, 8]]]>,
 <tf.RaggedTensor [[[16, 5], [19, 6], [28], [30], [21], [0]], [[21], [28], [30], [27], [29]]]>)

In [21]:
segment_a = segment_a.merge_dims(-2,-1)
segment_b = segment_b.merge_dims(-2,-1)
segment_a, segment_b

(<tf.RaggedTensor [[22, 9, 24, 23, 11, 10, 28, 14, 15, 13, 7], [18, 12, 15, 13, 8]]>,
 <tf.RaggedTensor [[16, 5, 19, 6, 28, 30, 21, 0], [21, 28, 30, 27, 29]]>)

In [22]:
trimmer = text.RoundRobinTrimmer(max_seq_length=_MAX_SEQ_LEN)
trimmed = trimmer.trim([segment_a, segment_b])
trimmed

[<tf.RaggedTensor [[22, 9, 24, 23],
  [18, 12, 15, 13]]>,
 <tf.RaggedTensor [[16, 5, 19, 6],
  [21, 28, 30, 27]]>]

In [23]:
segments_combined, segments_id = text.combine_segments(trimmed, start_of_sequence_id=_START_TOKEN, end_of_segment_id=_END_TOKEN)
segments_combined, segments_id

(<tf.RaggedTensor [[3, 22, 9, 24, 23, 4, 16, 5, 19, 6, 4],
  [3, 18, 12, 15, 13, 4, 21, 28, 30, 27, 4]]>,
 <tf.RaggedTensor [[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
  [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]>)

In [29]:
random_selector = text.RandomItemSelector(max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH, 
                                         selection_rate=0.2,
                                         unselectable_ids=[_START_TOKEN, _END_TOKEN, _UNK_TOKEN])
selected=random_selector.get_selection_mask(segments_combined, axis=1)
selected

<tf.RaggedTensor [[False, False, True, False, False, False, False, False, False, True,
  False],
 [False, False, True, False, False, False, False, False, True, False,
  False]]>

In [30]:
mask_values_chooser = text.MaskValuesChooser(_VOCAB_SIZE, _MASK_TOKEN, mask_token_rate=0.8)
mask_values_chooser.get_mask_values(segments_combined)

<tf.RaggedTensor [[17, 1, 1, 1, 23, 1, 1, 1, 1, 1, 19],
 [1, 1, 20, 1, 1, 1, 1, 1, 1, 1, 1]]>

In [31]:
masked_tokens, masked_pos, masked_lm_ids= text.mask_language_model(segments_combined, random_selector, mask_values_chooser)
masked_tokens, masked_pos, masked_lm_ids

(<tf.RaggedTensor [[3, 22, 1, 24, 23, 4, 1, 5, 19, 6, 4],
  [3, 18, 12, 15, 13, 4, 21, 1, 1, 27, 4]]>,
 <tf.RaggedTensor [[2, 6],
  [7, 8]]>,
 <tf.RaggedTensor [[9, 16],
  [28, 30]]>)

In [32]:
tf.gather(_VOCAB, masked_tokens)

<tf.RaggedTensor [[b'[CLS]', b'Sp', b'[MASK]', b'bob', b'Sq', b'[SEP]', b'[MASK]',
  b'##ack', b'Ob', b'##ama', b'[SEP]'],
 [b'[CLS]', b'Mar', b'##vel', b'A', b'##ven', b'[SEP]', b'President',
  b'[MASK]', b'[MASK]', b'highest', b'[SEP]']]>

In [33]:
tf.gather(_VOCAB, masked_lm_ids)

<tf.RaggedTensor [[b'##onge', b'Bar'],
 [b'is', b'the']]>

In [34]:
input_word_ids, input_mask =  text.pad_model_inputs(masked_tokens, max_seq_length=_MAX_SEQ_LEN)
input_type_ids, _ =  text.pad_model_inputs(segments_id, max_seq_length=_MAX_SEQ_LEN)
masked_lm_positions, masked_lm_weights = text.pad_model_inputs(masked_pos,max_seq_length=_MAX_PREDICTIONS_PER_BATCH)
masked_lm_ids, _ = text.pad_model_inputs(masked_lm_ids,max_seq_length=_MAX_PREDICTIONS_PER_BATCH)

model_inputs = {
    "input_word_ids": input_word_ids,
    "input_mask": input_mask,
    "input_type_ids": input_type_ids,
    "masked_lm_ids": masked_lm_ids,
    "masked_lm_positions": masked_lm_positions,
    "masked_lm_weights": masked_lm_weights,
}
model_inputs

{'input_word_ids': <tf.Tensor: shape=(2, 8), dtype=int64, numpy=
 array([[ 3, 22,  1, 24, 23,  4,  1,  5],
        [ 3, 18, 12, 15, 13,  4, 21,  1]])>,
 'input_mask': <tf.Tensor: shape=(2, 8), dtype=int64, numpy=
 array([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])>,
 'input_type_ids': <tf.Tensor: shape=(2, 8), dtype=int64, numpy=
 array([[0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1]])>,
 'masked_lm_ids': <tf.Tensor: shape=(2, 5), dtype=int64, numpy=
 array([[ 9, 16,  0,  0,  0],
        [28, 30,  0,  0,  0]])>,
 'masked_lm_positions': <tf.Tensor: shape=(2, 5), dtype=int64, numpy=
 array([[2, 6, 0, 0, 0],
        [7, 8, 0, 0, 0]])>,
 'masked_lm_weights': <tf.Tensor: shape=(2, 5), dtype=int64, numpy=
 array([[1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0]])>}

In [39]:
def bert_pretrain_preprocess(vocab_table, feature):
    text_a = feature['text_a']
    text_b = feature['text_b']
    tokenizer = text.BertTokenizer(vocab_table,token_out_type=tf.int64)
    segments = [ tokenizer.tokenize(text).merge_dims(-2,-1) for text in [text_a, text_b] ]
    trimmer = text.RoundRobinTrimmer(max_seq_length=6)
    trimmed_segments = trimmer.trim(segments)
    segments_combined, segments_id = text.combine_segments(trimmed_segments,
                                                           start_of_sequence_id=_START_TOKEN, 
                                                           end_of_segment_id=_END_TOKEN)
    masked_input_ids, masked_lm_positions, masked_lm_ids = (
        text.mask_language_model(segments_combined, random_selector, mask_values_chooser)
    )
    input_word_ids, input_mask = text.pad_model_inputs(masked_input_ids, max_seq_length=_MAX_SEQ_LEN)
    input_type_ids, _ = text.pad_model_inputs(segments_id, max_seq_length=_MAX_SEQ_LEN)
    masked_lm_positions, masked_lm_weights = text.pad_model_inputs(masked_lm_positions, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)
    masked_lm_ids,_= text.pad_model_inputs(masked_lm_ids, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)
    model ={
        "input_word_ids":input_word_ids,
        "input_mask":input_mask,
        "input_type_ids":input_type_ids,
        "masked_lm_positions":masked_lm_positions,
        "masked_lm_weights":masked_lm_weights,
        "masked_lm_ids":masked_lm_ids
    }
    return model
    

In [41]:
dataset = tf.data.Dataset.from_tensor_slices(examples).map(functools.partial(bert_pretrain_preprocess, lookup_table))

In [42]:
next(iter(dataset))

{'input_word_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[ 3, 22,  1, 24,  4, 16,  1, 19]])>,
 'input_mask': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[1, 1, 1, 1, 1, 1, 1, 1]])>,
 'input_type_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[0, 0, 0, 0, 0, 1, 1, 1]])>,
 'masked_lm_positions': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[2, 6, 0, 0, 0]])>,
 'masked_lm_weights': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[1, 1, 0, 0, 0]])>,
 'masked_lm_ids': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[9, 5, 0, 0, 0]])>}

In [43]:
next(iter(dataset))

{'input_word_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[ 3, 22,  9,  1,  4,  1,  5, 19]])>,
 'input_mask': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[1, 1, 1, 1, 1, 1, 1, 1]])>,
 'input_type_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[0, 0, 0, 0, 0, 1, 1, 1]])>,
 'masked_lm_positions': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[3, 5, 0, 0, 0]])>,
 'masked_lm_weights': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[1, 1, 0, 0, 0]])>,
 'masked_lm_ids': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[24, 16,  0,  0,  0]])>}

In [44]:
next(iter(dataset))

{'input_word_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[ 3, 22,  9, 24,  4,  1,  1, 19]])>,
 'input_mask': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[1, 1, 1, 1, 1, 1, 1, 1]])>,
 'input_type_ids': <tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[0, 0, 0, 0, 0, 1, 1, 1]])>,
 'masked_lm_positions': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[5, 6, 0, 0, 0]])>,
 'masked_lm_weights': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[1, 1, 0, 0, 0]])>,
 'masked_lm_ids': <tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[16,  5,  0,  0,  0]])>}