increase prefetch distributed example (#833)
zsdonghao authored and wagamamaz committed Sep 10, 2018
1 parent 00e8dc3 commit 4621cc7
Showing 2 changed files with 4 additions and 4 deletions.
File 1 of 2 (filename not shown in this view):

@@ -99,8 +99,8 @@ def build_validation(x, y_):
     # validation_dataset = make_dataset(X_test, y_test)
     # validation_dataset = training_dataset.map(data_aug_valid, num_parallel_calls=multiprocessing.cpu_count())
     trainer = tl.distributed.Trainer(
-        build_training_func=build_train, training_dataset=training_dataset, batch_size=128,
-        optimizer=tf.train.RMSPropOptimizer, optimizer_args={'learning_rate': 0.0001}
+        build_training_func=build_train, training_dataset=training_dataset, optimizer=tf.train.AdamOptimizer,
+        optimizer_args={'learning_rate': 0.0001}, batch_size=128, num_epochs=50000, prefetch_buffer_size=4096
         # validation_dataset=validation_dataset, build_validation_func=build_validation
     )

File 2 of 2 (filename not shown in this view):

@@ -52,8 +52,8 @@ def build_validation(x, y_):
     training_dataset = make_dataset(X_train, y_train)
     # validation_dataset = make_dataset(X_val, y_val)
     trainer = tl.distributed.Trainer(
-        build_training_func=build_train, training_dataset=training_dataset, batch_size=32,
-        optimizer=tf.train.RMSPropOptimizer, optimizer_args={'learning_rate': 0.001}
+        build_training_func=build_train, training_dataset=training_dataset, optimizer=tf.train.AdamOptimizer,
+        optimizer_args={'learning_rate': 0.001}, batch_size=500, num_epochs=500, prefetch_buffer_size=4096
         # validation_dataset=validation_dataset, build_validation_func=build_validation
     )
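
For context, a larger prefetch_buffer_size lets the input pipeline prepare data ahead of the training loop instead of producing it on demand. A minimal sketch of roughly what this corresponds to in a plain tf.data pipeline (TF 1.x-era API; the dataset construction, shuffle size, and placement of prefetch() below are illustrative assumptions, not part of this commit):

    import tensorflow as tf

    # Sketch only: the effect of a large prefetch buffer in a raw tf.data pipeline.
    # X_train, y_train are the arrays passed to make_dataset() in these examples.
    dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    dataset = dataset.shuffle(buffer_size=10000)      # placeholder shuffle size
    dataset = dataset.prefetch(buffer_size=4096)      # keep up to 4096 examples ready ahead of the consumer
    dataset = dataset.batch(500)

Prefetching overlaps data preparation with the training step, which matters more in distributed training where stalled workers waste the whole cluster's time.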

