From 2eed5e5575e8b7200fd6b30f2884918b221c272d Mon Sep 17 00:00:00 2001 From: qsbao Date: Thu, 13 Feb 2020 16:12:52 +0800 Subject: [PATCH] list_files in inference examples should be deterministically --- examples/mnist/estimator/mnist_inference.py | 2 +- examples/mnist/keras/mnist_inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mnist/estimator/mnist_inference.py b/examples/mnist/estimator/mnist_inference.py index 2b2c9cb3..8e520d00 100644 --- a/examples/mnist/estimator/mnist_inference.py +++ b/examples/mnist/estimator/mnist_inference.py @@ -47,7 +47,7 @@ def parse_tfr(example_proto): return (image, label) # define a new tf.data.Dataset (for inferencing) - ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels)) + ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False) ds = ds.shard(num_workers, worker_num) ds = ds.interleave(tf.data.TFRecordDataset) ds = ds.map(parse_tfr) diff --git a/examples/mnist/keras/mnist_inference.py b/examples/mnist/keras/mnist_inference.py index 21df737a..401f89f8 100644 --- a/examples/mnist/keras/mnist_inference.py +++ b/examples/mnist/keras/mnist_inference.py @@ -38,7 +38,7 @@ def parse_tfr(example_proto): return (image, label) # define a new tf.data.Dataset (for inferencing) - ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels)) + ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False) ds = ds.shard(ctx.num_workers, ctx.worker_num) ds = ds.interleave(tf.data.TFRecordDataset) ds = ds.map(parse_tfr)