## Testing the TF data pipeline

I updated the data loading code, but something in it is very slow, so let's run some tests here to track down the cause.

In [1]:
import os
from functools import partial

import numpy as np
import tensorflow as tf

In [102]:
def _parse_data(sample_proto, shape):
    """Parse one serialized example into a normalized ``(x, y)`` pair.

    Args:
        sample_proto: Serialized example handed to ``tf.parse_single_example``.
        shape: Shape of the ``x`` feature stored in the record.

    Returns:
        Tuple ``(x, y)`` of float32 tensors: ``x`` has the given ``shape`` and
        is divided by its own mean; ``y`` has 4 elements and is returned as
        parsed.
    """
    feature_spec = {
        'x': tf.FixedLenFeature(shape, tf.float32),
        'y': tf.FixedLenFeature([4], tf.float32),
    }
    example = tf.parse_single_example(sample_proto, features=feature_spec)
    x = example['x']
    y = example['y']

    # Mean of x, computed as (sum of all elements) / (number of elements).
    x_mean = tf.reduce_sum(x) / np.prod(shape)
    x = x / x_mean
    return x, y

def construct_dataset(filenames, batch_size, n_epochs, sample_shape,
                      rank=0, n_ranks=1, shard=False, shuffle=False):
    """Build a batched ``tf.data`` pipeline over a list of TFRecord files.

    Args:
        filenames: Sequence of TFRecord file paths (its length sizes the
            shuffle buffer).
        batch_size: Number of samples per batch; a trailing partial batch is
            dropped.
        n_epochs: Number of times the parsed samples are repeated.
        sample_shape: Shape of each sample's ``x`` feature, forwarded to
            ``_parse_data``.
        rank: Shard index used when ``shard`` is true.
        n_ranks: Total number of shards used when ``shard`` is true.
        shard: If true, keep only this rank's share of the file list.
        shuffle: If true, shuffle the order of the files (not of individual
            samples), reshuffling on every iteration.

    Returns:
        Dataset yielding batches of ``(x, y)``, with 4 batches prefetched.
    """
    # Start from a dataset of file names; sharding and shuffling act on files.
    file_ds = tf.data.Dataset.from_tensor_slices(filenames)
    if shard:
        file_ds = file_ds.shard(num_shards=n_ranks, index=rank)
    if shuffle:
        file_ds = file_ds.shuffle(len(filenames), reshuffle_each_iteration=True)

    # Read the serialized records and decode them into (x, y) samples.
    record_ds = file_ds.apply(tf.data.TFRecordDataset)
    parser = partial(_parse_data, shape=sample_shape)
    sample_ds = record_ds.map(parser, num_parallel_calls=4)

    # Repeat happens before batching, so a batch may span an epoch boundary.
    repeated_ds = sample_ds.repeat(n_epochs)
    batched_ds = repeated_ds.batch(batch_size, drop_remainder=True)
    return batched_ds.prefetch(4)

In [103]:
# Dataset location; $SCRATCH is expanded from the environment.
data_dir = os.path.expandvars(
    '$SCRATCH/cosmoflow-benchmark/data/cosmoUniverse_2019_05_4parE_tf'
)

# Shape handed to the parser for each sample's x feature.
sample_shape = (128, 128, 128, 4)

# Small test configuration: files to use, batch size, number of epochs.
n_train, batch_size, n_epochs = 64, 1, 1

In [104]:
# Collect the full paths of every .tfrecord file in data_dir, in sorted order.
all_files = sorted(
    os.path.join(data_dir, f)
    for f in os.listdir(data_dir)
    if f.endswith('.tfrecord')
)
# Keep only the first n_train files for this test.
train_files = all_files[:n_train]

In [106]:
# Build the input pipeline over the selected training files.
train_dataset = construct_dataset(
    train_files,
    batch_size=batch_size,
    n_epochs=n_epochs,
    sample_shape=sample_shape,
)
# Graph mode: create a one-shot iterator and the op that yields the next batch.
train_iter = train_dataset.make_one_shot_iterator()
train_batch = train_iter.get_next()

In [109]:
%%time

# Open a fresh session and fetch a single batch from the iterator.
# The %%time magic on this cell covers session creation as well as the fetch.
with tf.Session() as sess:
    x, y = sess.run(train_batch)

CPU times: user 362 ms, sys: 55.7 ms, total: 418 ms
Wall time: 164 ms


In [110]:
# Display the y array returned by the session run for the fetched batch.
y

array([[-0.0219778, -0.9541416, -0.4267502, -0.7982607]], dtype=float32)