In [1]:
import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_schema

tf.__version__

'1.13.1'

### High Performance Input Tensors

In [19]:
feature_spec = {
    'state': tf.io.FixedLenFeature([21,21,2], tf.float32),
    'distr': tf.io.FixedLenFeature([21,21,1], tf.float32)
}
schema = dataset_schema.from_feature_spec(feature_spec)

In [20]:
def make_tfr_input_fn(filename_pattern, batch_size, options):
    
    def _input_fn():
        dataset = tf.data.experimental.make_batched_features_dataset(
            file_pattern=filename_pattern,
            batch_size=batch_size,
            features=feature_spec,
            shuffle_buffer_size=options['shuffle_buffer_size'],
            prefetch_buffer_size=options['prefetch_buffer_size'],
            reader_num_threads=options['reader_num_threads'],
            parser_num_threads=options['parser_num_threads'],
            sloppy_ordering=options['sloppy_ordering'],
            num_epochs=options['num_epochs'],
            label_key='distr')

        if options['distribute']:
            return dataset 
        else:
            return dataset.make_one_shot_iterator().get_next()
    return _input_fn

In [21]:
file_pattern='deleteme.tfr'

In [22]:
train_input_fn = make_tfr_input_fn(
    filename_pattern=file_pattern,
    batch_size=5, 
    options={'num_epochs': None,  # repeat infinitely
             'shuffle_buffer_size': 40,
             'prefetch_buffer_size': 40,
             'reader_num_threads': 10,
             'parser_num_threads': 10,
             'sloppy_ordering': True,
             'distribute': False})

This design pattern lets us supply parameters to a function that is not allowed to take any. The outer function passes its own parameters to the inner (*daughter*) function, which sees them as constants.
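The pattern is an ordinary Python closure. Here is a minimal sketch without TensorFlow. The names `make_input_fn` and `toy_input_fn` are hypothetical and exist only for illustration, and the inner function returns a plain dictionary rather than a dataset.

```python
def make_input_fn(filename_pattern, batch_size, options):
    """Return a zero-argument function that remembers the arguments above."""

    def _input_fn():
        # The inner function takes no parameters; it reads the outer
        # function's arguments from its closure.
        return {
            'pattern': filename_pattern,
            'batch_size': batch_size,
            'num_epochs': options['num_epochs'],
        }

    return _input_fn


# Hypothetical usage: configure once, call later without arguments.
toy_input_fn = make_input_fn('deleteme.tfr', 5, {'num_epochs': None})
config = toy_input_fn()
```

The caller of ```toy_input_fn``` never needs to know the file pattern or the batch size; both were fixed when the function was created.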

Later, we will provide this ```train_input_fn``` to the so-called ```estimator```. It is then up to the ```estimator``` to call ```train_input_fn``` and thereby create the input-generating computational sub-graph within its own graph and session context.
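The deferred call can be pictured with a toy stand-in. The sketch below is plain Python, not the TensorFlow estimator API: ```ToyEstimator``` and ```make_counting_input_fn``` are hypothetical names, and the "context" is just a string. It only shows that the consumer decides when the input function runs.

```python
class ToyEstimator:
    """Hypothetical stand-in for an estimator; not the TensorFlow API."""

    def __init__(self):
        self.context = None

    def train(self, input_fn, steps):
        # The estimator sets up its own context first ...
        self.context = 'estimator-owned context'
        # ... and only then calls input_fn, so whatever input_fn builds
        # is created while that context is active.
        batches = input_fn()
        return [next(batches) for _ in range(steps)]


def make_counting_input_fn(batch_size):
    """Build a zero-argument input function yielding batches of integers."""

    def _input_fn():
        def _generate():
            start = 0
            while True:  # repeat indefinitely, like num_epochs=None
                yield list(range(start, start + batch_size))
                start += batch_size
        return _generate()

    return _input_fn


estimator = ToyEstimator()
result = estimator.train(make_counting_input_fn(batch_size=5), steps=2)
```

Nothing is generated until ```train``` calls the input function, which is the same ordering the real estimator relies on.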

For demonstration purposes, we call the function ourselves and see what it returns.

In [23]:
samples, labels = train_input_fn()

In [24]:
samples

{'state': <tf.Tensor 'IteratorGetNext_2:0' shape=(5, 21, 21, 2) dtype=float32>}

In [25]:
labels

<tf.Tensor 'IteratorGetNext_2:1' shape=(5, 21, 21, 1) dtype=float32>

Now, each time we evaluate ```samples``` and ```labels```, we get a new batch of 5 samples together with the associated 'distr' labels.

In [30]:
with tf.Session() as sess:
    s, l = sess.run([samples, labels])

In [31]:
s['state'].shape

(5, 21, 21, 2)