In [2]:
import numpy as np
import tensorflow as tf

In [10]:
# --- Sequence handling and batching ---
max_sequence_length = 512    # longer token sequences are randomly cropped to this length
fix_sequence_length = False  # True pads every batch to max_sequence_length
batch_size = 20              # number of sequences per padded batch
buffer_size = 1024           # shuffle buffer size

# --- Token ids ---
vocab_size = 22              # placeholder; reassigned to len(vocab) in the vocabulary cell
mask_index = 1               # id written where a token is replaced by the mask token
vocab_start = 2              # smallest id used when sampling a random replacement token

# --- Masking ---
masking_freq = 0.15          # a position is selected when its random score is below this
mask_token_freq = 0.8        # share of selected positions replaced by mask_index
mask_random_freq = 0.1       # share of selected positions replaced by a random id
no_mask_pad = 1              # positions at each end given a score of 1.0 in the masking scratch cell

# --- Input filtering ---
filter_bzux = True           # drop input lines containing any of B, Z, U, O, X

In [4]:
# Twenty single-letter residue codes (presumably the standard amino acids --
# confirm against the data source), followed by the start ('^') and end ('$')
# markers that encode() wraps around every sequence.
vocab = list('ACDEFGHIKLMNPQRSTVWY') + ['^', '$']
vocab_size = len(vocab)

# Ids are assigned in vocabulary order starting at 2; any string that is not
# in the vocabulary looks up as 0.
token_ids = tf.range(vocab_size) + 2
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys=vocab, values=token_ids),
    default_value=0)

@tf.function
def encode(x):
    """Convert one sequence string into a 1-D tensor of token ids.

    The string is split into single bytes, wrapped in the '^' (start) and
    '$' (end) markers, and -- when the result is longer than
    ``max_sequence_length`` -- cut down to a randomly positioned contiguous
    window of exactly ``max_sequence_length`` tokens. Because the window is
    random, a cropped sequence can lose its start and/or end marker.

    Parameters
    ----------
    x : string tensor
        A single sequence (assumed scalar, e.g. one line of the input file;
        the axis-0 concat with 1-D marker tensors requires the split result
        to be 1-D).

    Returns
    -------
    Tensor of ids obtained from ``table``; bytes that are not in the
    vocabulary map to the table's default value 0.
    """
    chars = tf.strings.bytes_split(x)

    # Append start and end tokens
    chars = tf.concat([tf.constant(['^']), chars, tf.constant(['$'])], 0)

    # tf.shape() of a 1-D tensor is a one-element vector; index it so that
    # tf.cond receives the scalar predicate its contract asks for.
    num_tokens = tf.shape(chars)[0]

    # If chars is longer than max_sequence_length, take a random crop
    chars = tf.cond(
        num_tokens > max_sequence_length,
        lambda: tf.image.random_crop(chars, (max_sequence_length,)),
        lambda: chars)

    return table.lookup(chars)

@tf.function
def mask_input(input_tensor):
    """Randomly mask token ids following the recipe prescribed by BERT.

    Every position draws a uniform score in [0, 1). Positions whose score is
    below ``masking_freq`` are selected as prediction targets. Among them:

    * score <= ``masking_freq * mask_token_freq``: replaced by ``mask_index``;
    * score >= ``masking_freq * (1 - mask_random_freq)``: replaced by a random
      id drawn from ``[vocab_start, vocab_start + vocab_size)``;
    * anything in between: left unchanged.

    With the notebook's settings (0.15 / 0.8 / 0.1) that is 15% of tokens
    selected, of which 80% receive the mask id, 10% are randomized and 10%
    are kept as-is.

    NOTE(review): every position is eligible, so the start/end markers can be
    masked, and the random ids span the whole vocabulary including the
    marker ids -- confirm this is intended.

    Parameters
    ----------
    input_tensor : int32 tensor
        Token ids to mask. Must be int32 because it is combined via
        ``tf.where`` with int32 replacement tensors.

    Returns
    -------
    masked_tensor : same shape as ``input_tensor``
        Input with the selected positions masked / randomized.
    true_tensor : same shape as ``input_tensor``
        Original ids at the selected positions and 0 (the pad value)
        everywhere else.
    """

    input_shape = tf.shape(input_tensor)
    mask_score = tf.random.uniform(input_shape, maxval=1, dtype=tf.float32)

    # Positions chosen as prediction targets
    input_mask = mask_score < masking_freq

    # Lowest slice of the selected scores: replace with the mask id
    mask_mask = mask_score <= masking_freq * mask_token_freq

    # Highest slice of the selected scores: replace with a random id
    mask_random = (mask_score >= masking_freq * (1. - mask_random_freq)) & input_mask

    # Replacement values for masked / randomized / unselected positions
    mask_value_tensor = tf.ones(input_shape, dtype=tf.int32) * mask_index
    random_value_tensor = tf.random.uniform(
        input_shape, minval=vocab_start, maxval=vocab_start + vocab_size,
        dtype=tf.int32)
    pad_value_tensor = tf.zeros(input_shape, dtype=tf.int32)

    # Apply the replacements; the two boolean masks never overlap for the
    # configured frequencies, so the order of the two tf.where calls is moot.
    masked_tensor = tf.where(mask_mask, mask_value_tensor, input_tensor)
    masked_tensor = tf.where(mask_random, random_value_tensor, masked_tensor)

    # Set true values to zero (pad value) where not selected
    true_tensor = tf.where(input_mask, input_tensor, pad_value_tensor)

    return masked_tensor, true_tensor

# Stream the gzip-compressed training file, one line per element.
dataset = tf.data.TextLineDataset(
    '../uniparc_data/train_uniref100.txt.gz', compression_type='GZIP')

if filter_bzux:
    def bzux_filter(string):
        """Return True for strings containing none of the letters B, Z, U, O, X."""
        has_excluded_letter = tf.strings.regex_full_match(string, '.*[BZUOX].*')
        return tf.math.logical_not(has_excluded_letter)

    dataset = dataset.filter(bzux_filter)

# Turn every remaining line into token ids, letting tf.data pick the
# degree of parallelism.
encoded_data = dataset.map(
    encode, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [7]:
# Pull the first element out of `encoded_data` to experiment on interactively.
input_tensor = next(iter(encoded_data))

In [11]:
# Prototype: keep the first/last `no_mask_pad` positions from being selected.
# Random scores are drawn for the interior only; the ends get a fixed score
# of 1.0, which is never below a masking_freq of at most 1.
input_shape = tf.shape(input_tensor)
inner_scores = tf.random.uniform(
    input_shape - 2 * no_mask_pad, maxval=1, dtype=tf.float32)
edge_scores = tf.ones(no_mask_pad)
mask_score = tf.concat([edge_scores, inner_scores, edge_scores], 0)
input_mask = mask_score < masking_freq

In [16]:
# Display the scores with `no_mask_pad` ones attached to each end.
# NOTE(review): if `mask_score` already carries the padding added in the
# previous cell, this pads it a second time -- confirm which version of
# `mask_score` was in the kernel when the output below was recorded.
tf.concat([tf.ones(no_mask_pad), mask_score, tf.ones(no_mask_pad)], 0)

<tf.Tensor: id=190, shape=(432,), dtype=float32, numpy=
array([1.00000000e+00, 5.40298820e-01, 7.86184788e-01, 7.21072435e-01,
       8.18890333e-02, 4.95592237e-01, 3.14937830e-01, 9.45066571e-01,
       7.03465939e-03, 7.00511694e-01, 9.21649218e-01, 5.74735045e-01,
       8.46982360e-01, 2.82406688e-01, 6.44581437e-01, 7.64594078e-02,
       9.74052906e-01, 6.69313431e-01, 5.07784128e-01, 1.97598696e-01,
       3.46693993e-02, 1.51729107e-01, 9.64540958e-01, 7.55560398e-01,
       9.91850615e-01, 3.14059377e-01, 7.30306983e-01, 5.75053692e-01,
       9.40854669e-01, 7.13680863e-01, 5.60191154e-01, 1.99417353e-01,
       5.07982492e-01, 7.67915249e-01, 6.66952610e-01, 3.02054882e-01,
       9.05965209e-01, 4.92723465e-01, 8.62968206e-01, 9.97504473e-01,
       8.40231895e-01, 9.30689573e-01, 8.77267838e-01, 1.73936844e-01,
       3.39296699e-01, 9.14276719e-01, 2.82260180e-01, 8.32860827e-01,
       5.84446669e-01, 9.13235784e-01, 4.00929332e-01, 8.72491121e-01,
       1.20351076e-01

In [9]:
# Inspect the per-position masking scores currently held in the kernel.
mask_score

<tf.Tensor: id=160, shape=(432,), dtype=float32, numpy=
array([0.36031353, 0.9718274 , 0.9704088 , 0.5596117 , 0.9622735 ,
       0.6275754 , 0.02057183, 0.98879457, 0.67847407, 0.511451  ,
       0.31457484, 0.2030307 , 0.05781162, 0.9700577 , 0.87887096,
       0.33954692, 0.441589  , 0.28692305, 0.7514062 , 0.11643517,
       0.99769855, 0.83928204, 0.55498767, 0.85603154, 0.8213414 ,
       0.9926305 , 0.19381595, 0.44474804, 0.3775823 , 0.76051295,
       0.784451  , 0.1514591 , 0.10342324, 0.7423258 , 0.59138846,
       0.6553309 , 0.5350299 , 0.41742134, 0.06401265, 0.12608814,
       0.58775437, 0.1468221 , 0.6490656 , 0.52453506, 0.74527395,
       0.3102647 , 0.7814647 , 0.31444097, 0.44431448, 0.17624521,
       0.04818022, 0.64374053, 0.18182671, 0.7144389 , 0.5234299 ,
       0.10299766, 0.15352654, 0.536371  , 0.5257994 , 0.9598707 ,
       0.26844645, 0.8016484 , 0.58619595, 0.81033206, 0.93123806,
       0.8747127 , 0.92220986, 0.38942564, 0.39678204, 0.30834222,
      

In [6]:
    .map(mask_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# This argument controls whether to fix the size of the sequences
tf_seq_len = -1 if not fix_sequence_length else max_sequence_length

encoded_data = encoded_data\
    .shuffle(buffer_size=buffer_size)\
    .padded_batch(batch_size, padded_shapes=(
        ([tf_seq_len], [tf_seq_len])))

In [11]:
# Grab the first raw (not yet encoded) element of `dataset` for experiments.
test = next(iter(dataset))

In [59]:
# Sanity check: concatenating 1-D tensors of different lengths along axis 0.
tf.concat([tf.range(1), tf.range(2)], 0)

<tf.Tensor: id=930, shape=(3,), dtype=int32, numpy=array([0, 0, 1], dtype=int32)>

<tf.Tensor: id=913, shape=(), dtype=string, numpy=b'A'>

<tf.Tensor: id=939, shape=(432,), dtype=string, numpy=
array([b'^', b'M', b'L', b'F', b'G', b'T', b'A', b'K', b'M', b'N', b'R',
       b'E', b'N', b'H', b'L', b'E', b'I', b'G', b'G', b'C', b'D', b'T',
       b'V', b'K', b'L', b'A', b'Q', b'K', b'F', b'G', b'T', b'P', b'L',
       b'F', b'V', b'Y', b'D', b'V', b'A', b'H', b'I', b'R', b'A', b'Q',
       b'A', b'R', b'G', b'F', b'K', b'Q', b'T', b'L', b'N', b'Q', b'L',
       b'G', b'I', b'K', b'N', b'K', b'V', b'V', b'Y', b'A', b'S', b'K',
       b'A', b'F', b'S', b'C', b'L', b'A', b'I', b'Y', b'Q', b'V', b'L',
       b'K', b'E', b'E', b'D', b'I', b'A', b'C', b'D', b'V', b'V', b'S',
       b'G', b'G', b'E', b'L', b'F', b'T', b'A', b'L', b'K', b'G', b'G',
       b'M', b'E', b'P', b'A', b'E', b'I', b'E', b'F', b'H', b'G', b'N',
       b'N', b'K', b'T', b'P', b'E', b'E', b'L', b'R', b'Y', b'A', b'L',
       b'D', b'N', b'K', b'I', b'G', b'T', b'I', b'V', b'I', b'D', b'N',
       b'F', b'Y', b'E', b'I', b'D', b'L', b'L', b'E', b'E', b'L', b'

In [56]:
# Failed experiment kept for reference. The error recorded below says the
# valid concat axis range is empty ([0, 0)), i.e. an operand had rank 0.
# NOTE(review): `tf.ten` is not used anywhere else in this notebook and looks
# like a truncated `tf.constant` -- confirm what was actually executed.
tf.concat([tf.ten('A'), chars], axis=0)

InvalidArgumentError: ConcatOp : Expected concatenating dimensions in the range [0, 0), but got 0 [Op:ConcatV2] name: concat

In [53]:
# Failed experiment kept for reference: the bare strings '^' and '$' are
# scalars, and the recorded error shows concat rejecting axis 0 for them.
# encode() avoids this by wrapping each marker in a one-element list.
tf.concat(['^', chars, '$'], axis=0)

InvalidArgumentError: ConcatOp : Expected concatenating dimensions in the range [0, 0), but got 0 [Op:ConcatV2] name: concat

In [18]:
# Split the sample string into single bytes ...
chars = tf.strings.bytes_split(test)
# ... and translate them to ids (no start/end markers are added here,
# unlike in encode()).
encoded = table.lookup(chars)

In [19]:
# 1-based position of every token in `encoded`; cropping this instead of the
# ids makes it easy to see which window a random crop selected.
num_tokens = tf.shape(encoded)[0]
indices = tf.range(1, num_tokens + 1)

In [27]:
# Reference behaviour: a seeded 64-element random window taken with the
# built-in crop op.
tf.image.random_crop(indices, (64,), seed=1)

<tf.Tensor: id=689, shape=(64,), dtype=int32, numpy=
array([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348,
       349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361,
       362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374,
       375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387,
       388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399],
      dtype=int32)>

In [40]:
def random_crop(value, size, seed=None):
    """Return a randomly positioned slice of ``value`` with shape ``size``.

    Parameters
    ----------
    value : tensor
        Tensor to crop.
    size : sequence of int or int32 tensor
        Shape of the slice to extract, one entry per dimension of ``value``.
        Every entry must be no larger than the matching dimension.
    seed : int, optional
        Passed to ``tf.random.uniform`` as its operation-level seed.

    Returns
    -------
    Tensor of shape ``size`` taken from ``value`` at a random offset.

    Raises
    ------
    tf.errors.InvalidArgumentError
        If ``size`` exceeds the shape of ``value`` in any dimension.
    """
    shape = tf.shape(value)
    size = tf.convert_to_tensor(size, dtype=tf.int32)

    # Without this guard an oversized crop makes `limit` zero or negative and
    # the modulo below yields an invalid offset (or a division error).
    tf.debugging.assert_greater_equal(
        shape, size,
        message='random_crop: size must not exceed the shape of value')

    # Number of valid start positions along each dimension
    limit = shape - size + 1

    # Draw a large non-negative integer per dimension and fold it into
    # [0, limit) to obtain the start offset of the slice.
    offset = tf.random.uniform(
        tf.shape(shape),
        dtype=size.dtype,
        maxval=size.dtype.max,
        seed=seed) % limit
    return tf.slice(value, offset, size)

In [44]:
# Try the hand-written crop on the position indices; the output should be a
# run of 64 consecutive positions.
random_crop(indices, (64,))

<tf.Tensor: id=760, shape=(64,), dtype=int32, numpy=
array([197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
       210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222,
       223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235,
       236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248,
       249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260],
      dtype=int32)>