examples/mnist/tf/mnist_dist.py (4 additions & 4 deletions)
@@ -41,14 +41,14 @@ def map_fun(args, ctx):
   def read_csv_examples(image_dir, label_dir, batch_size=100, num_epochs=None, task_index=None, num_workers=None):
     print_log(worker_num, "num_epochs: {0}".format(num_epochs))
     # Setup queue of csv image filenames
-    tf_record_pattern = os.path.join(image_dir, 'part-*')
-    images = tf.gfile.Glob(tf_record_pattern)
+    csv_file_pattern = os.path.join(image_dir, 'part-*')
+    images = tf.gfile.Glob(csv_file_pattern)
     print_log(worker_num, "images: {0}".format(images))
     image_queue = tf.train.string_input_producer(images, shuffle=False, capacity=1000, num_epochs=num_epochs, name="image_queue")

     # Setup queue of csv label filenames
-    tf_record_pattern = os.path.join(label_dir, 'part-*')
-    labels = tf.gfile.Glob(tf_record_pattern)
+    csv_file_pattern = os.path.join(label_dir, 'part-*')
+    labels = tf.gfile.Glob(csv_file_pattern)
     print_log(worker_num, "labels: {0}".format(labels))
     label_queue = tf.train.string_input_producer(labels, shuffle=False, capacity=1000, num_epochs=num_epochs, name="label_queue")
examples/mnist/tf/mnist_dist_dataset.py (4 additions & 6 deletions)
@@ -66,16 +66,15 @@ def _parse_tfr(example_proto):
       # Dataset for input data
       image_dir = TFNode.hdfs_path(ctx, args.images_labels)
       file_pattern = os.path.join(image_dir, 'part-*')
-      files = tf.gfile.Glob(file_pattern)

+      ds = tf.data.Dataset.list_files(file_pattern)
+      ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
       if args.format == 'csv2':
-        ds = tf.data.TextLineDataset(files)
+        ds = ds.interleave(tf.data.TextLineDataset, cycle_length=args.readers, block_length=1)
         parse_fn = _parse_csv
       else:  # args.format == 'tfr'
-        ds = tf.data.TFRecordDataset(files)
+        ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=args.readers, block_length=1)
         parse_fn = _parse_tfr

-      ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
       ds = ds.map(parse_fn).batch(args.batch_size)
       iterator = ds.make_initializable_iterator()
       x, y_ = iterator.get_next()
@@ -159,7 +158,6 @@ def _parse_tfr(example_proto):
       # See `tf.train.SyncReplicasOptimizer` for additional details on how to
       # perform *synchronous* training.

-      # using QueueRunners/Readers
       if args.mode == "train":
         if (step % 100 == 0):
           print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
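For readers following the change in mnist_dist_dataset.py: instead of globbing every part file into a Python list and feeding it to a single TextLineDataset or TFRecordDataset, the pipeline is now built from file names, sharded per worker, and read through an interleave so that up to args.readers files are consumed in parallel. Below is a minimal sketch of that pattern, assuming TF 1.x and hypothetical path/worker values standing in for what TFNode and argparse supply in the real example.

  import tensorflow as tf  # TF 1.x, matching the APIs used by these examples

  # Hypothetical stand-ins for values supplied by TFNode/argparse in the example
  file_pattern = "hdfs:///user/mnist/csv/train/images/part-*"
  num_workers, task_index = 4, 0
  epochs, shuffle_size, readers, batch_size = 1, 1000, 10, 100

  # Build the pipeline from file names: shard the files across workers, then
  # interleave reads so up to `readers` part files are consumed concurrently.
  ds = tf.data.Dataset.list_files(file_pattern)
  ds = ds.shard(num_workers, task_index).repeat(epochs).shuffle(shuffle_size)
  ds = ds.interleave(tf.data.TextLineDataset, cycle_length=readers, block_length=1)
  ds = ds.batch(batch_size)

  iterator = ds.make_initializable_iterator()
  batch = iterator.get_next()

  with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(batch))  # first batch of raw CSV lines

Sharding on file names before the interleave keeps each worker reading a disjoint subset of part files, which is why the shard call moves ahead of dataset construction in the diff above.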
examples/mnist/tf/mnist_dist_pipeline.py (3 additions & 4 deletions)
@@ -59,12 +59,12 @@ def _parse_tfr(example_proto):
       sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
       tf.summary.histogram("softmax_weights", sm_w)

-      # read from saved tf records
+      # Read from saved tf records
       images = TFNode.hdfs_path(ctx, args.tfrecord_dir)
       tf_record_pattern = os.path.join(images, 'part-*')
-      tfr_files = tf.gfile.Glob(tf_record_pattern)
-      ds = tf.data.TFRecordDataset(tfr_files)
+      ds = tf.data.Dataset.list_files(tf_record_pattern)
       ds = ds.shard(num_workers, task_index).repeat(args.epochs).shuffle(args.shuffle_size)
+      ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=args.readers, block_length=1)
       ds = ds.map(_parse_tfr).batch(args.batch_size)
       iterator = ds.make_initializable_iterator()
       x, y_ = iterator.get_next()
@@ -122,7 +122,6 @@ def _parse_tfr(example_proto):
       # See `tf.train.SyncReplicasOptimizer` for additional details on how to
       # perform *synchronous* training.

-      # using QueueRunners/Readers
       if (step % 100 == 0):
         print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
       _, summary, step = sess.run([train_op, summary_op, global_step])
examples/mnist/tf/mnist_spark_dataset.py (1 addition & 1 deletion)
@@ -33,7 +33,7 @@
parser.add_argument("--num_ps", help="number of ps nodes", default=1)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--rdma", help="use rdma connection", default=False)
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
examples/mnist/tf/mnist_spark_pipeline.py (1 addition & 1 deletion)
@@ -39,7 +39,7 @@
parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
You will need to set cluster_size = num_executors + num_ps""", default=False)
parser.add_argument("--protocol", help="Tensorflow network protocol (grpc|rdma)", default="grpc")
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
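On the last two files: the --readers default moves from 1 to 10 and the help text now reads "per worker" because the flag no longer sizes a queue-runner thread pool; it is passed through to the dataset interleave's cycle_length in the dist scripts above. A rough sketch of that wiring, assuming a stripped-down parser with only the flags relevant here (the real scripts define many more, and the file pattern is a placeholder):

  import argparse
  import tensorflow as tf  # TF 1.x

  parser = argparse.ArgumentParser()
  parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
  parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
  args = parser.parse_args([])  # use the defaults for illustration

  # Each worker reads up to args.readers part files at a time.
  ds = tf.data.Dataset.list_files("part-*")  # placeholder pattern
  ds = ds.shuffle(args.shuffle_size)
  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=args.readers, block_length=1)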