diff --git a/examples/mnist/keras/mnist_tf_ds.py b/examples/mnist/keras/mnist_tf_ds.py index 5f36dcbc..eaf8bcc6 100644 --- a/examples/mnist/keras/mnist_tf_ds.py +++ b/examples/mnist/keras/mnist_tf_ds.py @@ -39,14 +39,14 @@ def parse_tfos(example_proto): # tfos: /path/to/mnist/tfr/train/part-r-* image_pattern = ctx.absolute_path(args.images_labels) - options = tf.data.Options() - options.experimental_distribute.auto_shard = False - ds = tf.data.Dataset.list_files(image_pattern) - ds = ds.with_options(options) ds = ds.repeat(args.epochs).shuffle(BUFFER_SIZE) ds = ds.interleave(tf.data.TFRecordDataset) - train_datasets_unbatched = ds.map(parse_tfos) + + if args.data_format == 'tfds': + train_datasets_unbatched = ds.map(parse_tfds) + else: # 'tfos' + train_datasets_unbatched = ds.map(parse_tfos) def build_and_compile_cnn_model(): model = tf.keras.Sequential([