Add minor performance improvements to resnet input pipeline #4247
Changes from all commits
e04280e
d93b56e
ee0d0e8
ed4e723
9f58547
ea6d6aa
09300a4
838060b
1137995
@@ -79,7 +79,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
           batch_size=batch_size,
-          num_parallel_batches=1))
+          num_parallel_batches=1,
+          drop_remainder=True))
Review thread on drop_remainder:

This could matter for CIFAR, which has only 60k images.

It should not matter much as long as the batch size is reasonable, since this only drops the part of the dataset that doesn't fit into a full batch.

Right, it only drops the last partial batch, so even with 60k images and a batch size of 2048 it should only drop around 608 images (see the sketch below this hunk).

   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
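The effect debated in the thread above is easy to quantify. A back-of-the-envelope sketch in plain Python, using the dataset size and batch size mentioned in the comments (the numbers are illustrative, not measurements from the PR):

    # With drop_remainder=True only the final partial batch of each epoch is
    # discarded. For 60k examples and a batch size of 2048:
    num_examples = 60000
    batch_size = 2048
    full_batches = num_examples // batch_size            # 29 full batches kept
    dropped = num_examples - full_batches * batch_size   # 608 examples dropped
    print(full_batches, dropped)                         # -> 29 608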
@@ -111,7 +112,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
   """
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
-    labels = tf.zeros((batch_size, num_classes), tf.int32)
+    labels = tf.zeros((batch_size), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()

   return input_fn
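The synthetic labels are now integer class ids of shape [batch_size] instead of a one-hot [batch_size, num_classes] tensor, which matches the sparse loss and metric changes below. A minimal sketch of the relationship between the two label forms, with made-up shapes and assuming the TF 1.x API used elsewhere in this file:

    import tensorflow as tf

    batch_size, num_classes = 4, 10
    sparse_labels = tf.zeros((batch_size,), tf.int32)              # new form: class ids, shape [4]
    onehot_labels = tf.one_hot(sparse_labels, depth=num_classes)   # old form: shape [4, 10]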
@@ -227,8 +228,8 @@ def resnet_model_fn(features, labels, mode, model_class,
     })

   # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=logits, labels=labels)

   # Create a tensor named cross_entropy for logging purposes.
   tf.identity(cross_entropy, name='cross_entropy')
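A small self-contained sketch (invented logits and labels, TF 1.x session API) illustrating that the sparse loss yields the same value as the previous one-hot loss when the labels encode the same classes:

    import tensorflow as tf

    logits = tf.constant([[2.0, 0.5, -1.0],
                          [0.1, 0.2, 3.0]])
    labels = tf.constant([0, 2], tf.int32)          # integer class ids

    sparse_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    onehot_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.one_hot(labels, depth=3), logits=logits)

    with tf.Session() as sess:
      print(sess.run([sparse_loss, onehot_loss]))   # the two values should match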
@@ -282,8 +283,7 @@ def exclude_batch_norm(name):
     train_op = None

   if not tf.contrib.distribute.has_distribution_strategy():
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
+    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
   else:
     # Metrics are currently not compatible with distribution strategies during
     # training. This does not affect the overall performance of the model.
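With sparse labels the tf.argmax over the label tensor is no longer needed; the labels are already the class ids. A sketch with invented inputs (TF 1.x API):

    import tensorflow as tf

    labels = tf.constant([1, 0, 2], tf.int32)       # ground-truth class ids
    classes = tf.constant([1, 2, 2], tf.int32)      # predicted class ids

    accuracy, update_op = tf.metrics.accuracy(labels, classes)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())    # metric state lives in local variables
      sess.run(update_op)
      print(sess.run(accuracy))                     # 2 of 3 correct -> ~0.667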
Review thread:

Did you get a chance to check how sensitive performance is to this on both something big (DGX, V100 GCE) and something small (1x K80/P100)? I prefer not to have a performance flag unless it makes a big difference. And if it is a constant, it would be nice to have a brief comment so it isn't just a magic number.
No, I haven't had the chance to run on anything other than DGX-1V. I don't think the performance difference will show up on K80s, because the input pipeline will not be the bottleneck there, but I haven't tested it. I am talking to the input team to figure out whether a constant here makes sense or whether it should be tuned (in which case we may need to just remove it).
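A minimal sketch of the reviewer's suggestion, assuming the value ends up staying a constant; the name and value below are purely illustrative and are not the ones in the PR:

    # Hypothetical example of the suggestion: give the tuning value a name and a
    # brief comment instead of leaving a bare magic number in the pipeline code.
    # (Name and value are illustrative only.)
    _SOME_INPUT_PIPELINE_CONSTANT = 1  # Chosen on DGX-1V; revisit on other hardware.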