diff --git a/examples/mnist/estimator/mnist_inference.py b/examples/mnist/estimator/mnist_inference.py index 2b2c9cb3..8e520d00 100644 --- a/examples/mnist/estimator/mnist_inference.py +++ b/examples/mnist/estimator/mnist_inference.py @@ -47,7 +47,7 @@ def parse_tfr(example_proto): return (image, label) # define a new tf.data.Dataset (for inferencing) - ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels)) + ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False) ds = ds.shard(num_workers, worker_num) ds = ds.interleave(tf.data.TFRecordDataset) ds = ds.map(parse_tfr) diff --git a/examples/mnist/keras/mnist_inference.py b/examples/mnist/keras/mnist_inference.py index 21df737a..401f89f8 100644 --- a/examples/mnist/keras/mnist_inference.py +++ b/examples/mnist/keras/mnist_inference.py @@ -38,7 +38,7 @@ def parse_tfr(example_proto): return (image, label) # define a new tf.data.Dataset (for inferencing) - ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels)) + ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False) ds = ds.shard(ctx.num_workers, ctx.worker_num) ds = ds.interleave(tf.data.TFRecordDataset) ds = ds.map(parse_tfr)