In [32]:
import tensorflow as tf

In [60]:
A = tf.data.Dataset.range(10).shuffle(buffer_size=40, seed=3).take(2)

In [57]:
B = A.shard(num_shards=20, index=9)

In [61]:
list(A.as_numpy_iterator())

[4, 7]

In [70]:
sequence_compression='GZIP'
max_sequence_length=512
batch_size=20
buffer_size=1024
fix_sequence_length=False
masking_freq=.15
mask_token_freq=.8
mask_random_freq=.1
filter_bzux=True
no_mask_pad=1
shard_num_workers=None
shard_worker_index=None

tf_seq_len = -1 if not fix_sequence_length else max_sequence_length    

In [4]:
import os
import numpy as np
import tensorflow as tf

    
vocab = ['', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
         '^', '$']

values = tf.range(len(vocab))
mask_index = len(vocab)  # Mask is the last entry

encoding_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys=vocab, values=values),
    default_value=mask_index) # Missing values should just be the mask token

In [17]:
@tf.function
def encode(x):
    chars = tf.strings.bytes_split(x)

    # Append start and end tokens
    chars = tf.concat([tf.constant(['^']), chars, tf.constant(['$'])], 0)

    # If chars is greater than max_sequence_length, take a random crop
    chars = tf.cond(tf.shape(chars) > max_sequence_length,
            lambda: tf.image.random_crop(chars, (max_sequence_length,)),
            lambda: chars)

    return encoding_table.lookup(chars)


@tf.function
def mask_input(input_tensor):
    """ Randomly mask the input tensor according to the formula perscribed by BERT. 
    Randomly masks 15% of input tokens, with 80% recieving the [MASK] token,
    10% randomized, 10% left unchanged. 

    Returns
    -------

    masked_tensor: (batch_size, seq_length) 
        Tensor with masked values
    input_tensor: (batch_size, seq_length)
        Original input tensor (true values)
    input_mask: (batch_size, seq_length)
        Boolean mask that selects the desired inputs.    
    """

    input_shape = tf.shape(input_tensor)
    mask_score = tf.random.uniform(input_shape - no_mask_pad * 2, maxval=1, dtype=tf.float32)
    # Ensure that no_mask_pad tokens on edges are not masked
    mask_score = tf.concat([tf.ones(no_mask_pad), mask_score, tf.ones(no_mask_pad)], 0)
    input_mask = mask_score < masking_freq

    # Mask with [MASK] token 80% of the time
    mask_mask = mask_score <= masking_freq * mask_token_freq

    # Mask with random token 10% of the time
    mask_random = (mask_score >= masking_freq * (1. - mask_random_freq)) & input_mask

    # Tensors to replace with where input is masked or randomized.
    # Only add amino acid tokens as randoms
    mask_value_tensor = tf.ones(input_shape, dtype=tf.int32) * mask_index
    random_value_tensor = tf.random.uniform(
        input_shape, minval=1, maxval=20, dtype=tf.int32)
    pad_value_tensor = tf.zeros(input_shape, dtype=tf.int32)

    # Use the replacements to mask the input tensor
    masked_tensor = tf.where(mask_mask, mask_value_tensor, input_tensor)
    masked_tensor = tf.where(mask_random, random_value_tensor, masked_tensor)

    # Set true values to zero (pad value) where not masked
    true_tensor = tf.where(input_mask, input_tensor, pad_value_tensor)

    return masked_tensor, true_tensor


def parse(filename):
    """Combine operations together into a single function"""

    d = tf.data.TextLineDataset(filename, compression_type=sequence_compression)
    d = d.shuffle(buffer_size=buffer_size)

    if filter_bzux:
        bzux_filter = lambda string: tf.math.logical_not(
            tf.strings.regex_full_match(string, '.*[BZUOX].*'))
        d = d.filter(bzux_filter)

    return d

In [62]:
data_path = '/gpfs/alpine/proj-shared/bie108/split_uniref100'
sequence_path = os.path.join(data_path, 'train_uniref100_split/train_100_*.txt.gz')

In [74]:
dataset = tf.data.Dataset.list_files(sequence_path)
dataset = dataset.interleave(parse)

In [95]:
example = list(dataset.take(1))[0]

In [102]:
example

<tf.Tensor: shape=(), dtype=string, numpy=b'MKRLLIAFMKAYRAVVSPLYGQVCKFHPTCSAYALEALEVHGAAKGSLLAARRLVRCHPWSLGGYDPVPGSERERAWLARLEAEARSSSGASAPADRGTS'>

In [101]:
{let: num for let, num in zip(vocab, values.numpy())}

{'': 0,
 'A': 1,
 'C': 2,
 'D': 3,
 'E': 4,
 'F': 5,
 'G': 6,
 'H': 7,
 'I': 8,
 'K': 9,
 'L': 10,
 'M': 11,
 'N': 12,
 'P': 13,
 'Q': 14,
 'R': 15,
 'S': 16,
 'T': 17,
 'V': 18,
 'W': 19,
 'Y': 20,
 '^': 21,
 '$': 22}

In [105]:
input_tensor = encode(example)

In [106]:
mask_input(encode(example))

(<tf.Tensor: shape=(102,), dtype=int32, numpy=
 array([21, 11,  9, 15, 10, 23,  8,  1,  5, 23,  9,  1, 20, 15,  1, 18, 23,
        16, 13, 10, 20,  6, 14, 18,  2, 13,  5,  7, 13, 23, 23, 16,  1, 20,
         1, 10,  4, 23, 10, 23, 18,  7, 23,  1,  1,  9,  6, 16, 10, 10,  1,
         1, 15, 15, 23, 18, 15,  2,  7, 13, 19, 16, 10,  6, 23, 23,  3, 13,
        18, 13,  6, 16,  4, 15,  4, 15,  1, 23, 10,  1, 23, 10,  4,  1,  4,
         1, 15, 16, 16, 16, 23, 23, 23,  1, 13,  1,  3, 19, 23, 17, 16, 22],
       dtype=int32)>,
 <tf.Tensor: shape=(102,), dtype=int32, numpy=
 array([ 0,  0,  0,  0,  0, 10,  0,  0,  0, 11,  0,  0,  0,  0,  0,  0, 18,
         0, 13,  0,  0,  0,  0,  0,  0,  9,  0,  0,  0, 17,  2,  0,  0,  0,
         0,  0,  0,  1,  0,  4,  0,  0,  6,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  6, 20,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0, 19,  0,  0, 15,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  6,  1, 16,  0,  0,

In [80]:
example = dataset.take(1)

In [78]:
encode(example)

TypeError: in converted code:

    <ipython-input-17-f8b6606dcb32>:3 encode  *
        chars = tf.strings.bytes_split(x)
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/ops/ragged/ragged_string_ops.py:60 string_bytes_split
        name="input")
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/ops/ragged/ragged_tensor.py:2344 convert_to_tensor_or_ragged_tensor
        value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name)
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1314 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:317 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:258 constant
        allow_broadcast=True)
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:296 _constant_impl
        allow_broadcast=allow_broadcast))
    /ccs/home/pstjohn/.conda/envs/tf21-ibm/lib/python3.6/site-packages/tensorflow_core/python/framework/tensor_util.py:547 make_tensor_proto
        "supported type." % (type(values), values))

    TypeError: Failed to convert object of type <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'> to Tensor. Contents: <_VariantDataset shapes: (), types: tf.string>. Consider casting elements to a supported type.
