1 change: 0 additions & 1 deletion official/resnet/cifar10_main.py
@@ -73,7 +73,6 @@ def parse_record(raw_record, is_training):
   # The first byte represents the label, which we convert from uint8 to int32
   # and then to one-hot.
   label = tf.cast(record_vector[0], tf.int32)
-  label = tf.one_hot(label, _NUM_CLASSES)
 
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
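Note on the change above: CIFAR-10 labels now stay as scalar class indices instead of being expanded to one-hot vectors in the input pipeline. A minimal TF 1.x sketch of the two representations (shapes are per example, before batching):

```python
import tensorflow as tf  # TF 1.x, matching the tf.contrib APIs in this PR

label = tf.constant(7, tf.int32)  # sparse label: scalar class index, shape ()
onehot = tf.one_hot(label, 10)    # old representation: shape (10,),
                                  # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
```

This is exactly what the updated test below asserts: `label.shape` is `()` and the value is the raw index 7.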
4 changes: 2 additions & 2 deletions official/resnet/cifar10_test.py
@@ -64,13 +64,13 @@ def test_dataset_input_fn(self):
         lambda val: cifar10_main.parse_record(val, False))
     image, label = fake_dataset.make_one_shot_iterator().get_next()
 
-    self.assertAllEqual(label.shape, (10,))
+    self.assertAllEqual(label.shape, ())
     self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
 
     with self.test_session() as sess:
       image, label = sess.run([image, label])
 
-    self.assertAllEqual(label, np.array([int(i == 7) for i in range(10)]))
+    self.assertEqual(label, 7)
 
     for row in image:
       for pixel in row:
13 changes: 8 additions & 5 deletions official/resnet/imagenet_main.py
@@ -39,7 +39,7 @@
 }
 
 _NUM_TRAIN_FILES = 1024
-_SHUFFLE_BUFFER = 1500
+_SHUFFLE_BUFFER = 10000
 
 DATASET_NAME = 'ImageNet'
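The larger shuffle buffer trades host memory for shuffle quality. A minimal sketch of the mechanism (illustrative values only):

```python
import tensorflow as tf  # TF 1.x

dataset = tf.data.Dataset.range(100000)
# shuffle() fills a buffer of `buffer_size` elements and samples uniformly
# from it, refilling as elements are consumed. A buffer that is small
# relative to the dataset only shuffles locally; a larger one approximates
# a full shuffle at the cost of holding that many records in memory.
dataset = dataset.shuffle(buffer_size=10000)
```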

@@ -152,8 +152,6 @@ def parse_record(raw_record, is_training):
       num_channels=_NUM_CHANNELS,
       is_training=is_training)
 
-  label = tf.one_hot(tf.reshape(label, shape=[]), _NUM_CLASSES)
-
   return image, label
 
 
@@ -176,8 +174,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     # Shuffle the input files
     dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
 
-  # Convert to individual records
-  dataset = dataset.flat_map(tf.data.TFRecordDataset)
+  # Convert to individual records.
+  # cycle_length = 10 means 10 files will be read and deserialized in parallel.
+  # This number is low enough to not cause too much contention on small systems
+  # but high enough to provide the benefits of parallelization. You may want
+  # to increase this number if you have a large number of CPU cores.
+  dataset = dataset.apply(tf.contrib.data.parallel_interleave(
+      tf.data.TFRecordDataset, cycle_length=10))
 
   return resnet_run_loop.process_record_dataset(
       dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
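For reference, the flat_map to parallel_interleave swap in isolation (a sketch: the file names are placeholders, and tf.contrib.data.parallel_interleave was the contrib-era spelling of what later became tf.data.experimental.parallel_interleave):

```python
import tensorflow as tf  # TF 1.x

filenames = tf.data.Dataset.from_tensor_slices(
    ["train-0.tfrecord", "train-1.tfrecord"])  # placeholder file names

# Before: flat_map opens and drains one file at a time, sequentially.
records_seq = filenames.flat_map(tf.data.TFRecordDataset)

# After: parallel_interleave keeps cycle_length files open at once and
# interleaves their records, overlapping I/O with deserialization.
records_par = filenames.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=10))
```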
12 changes: 6 additions & 6 deletions official/resnet/resnet_run_loop.py
@@ -79,7 +79,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
           batch_size=batch_size,
-          num_parallel_batches=1))
+          num_parallel_batches=1,
+          drop_remainder=True))
 
   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
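drop_remainder=True discards a final short batch so every batch has the same static shape, which the distribution-strategy path generally requires. A small sketch of the effect (illustrative values):

```python
import tensorflow as tf  # TF 1.x

dataset = tf.data.Dataset.range(10)
# 10 elements at batch_size=4: two full batches of 4 are produced and the
# trailing batch of 2 is dropped, so the batch dimension is statically 4.
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    lambda x: x * 2, batch_size=4, drop_remainder=True))
```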
@@ -111,7 +112,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
   """
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
-    labels = tf.zeros((batch_size, num_classes), tf.int32)
+    labels = tf.zeros((batch_size), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()
 
   return input_fn
@@ -227,8 +228,8 @@ def resnet_model_fn(features, labels, mode, model_class,
       })
 
   # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=logits, labels=labels)
 
   # Create a tensor named cross_entropy for logging purposes.
   tf.identity(cross_entropy, name='cross_entropy')
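The sparse loss consumes the integer labels directly; it computes the same cross entropy as the dense form without materializing one-hot vectors. A sketch of the equivalence (shapes here are hypothetical):

```python
import tensorflow as tf  # TF 1.x

logits = tf.random_normal([8, 1001])                # hypothetical logits
labels = tf.random_uniform([8], 0, 1001, tf.int32)  # integer class indices

sparse = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
dense = tf.losses.softmax_cross_entropy(
    onehot_labels=tf.one_hot(labels, 1001), logits=logits)
# sparse and dense evaluate to the same scalar; the sparse form skips the
# one-hot expansion entirely.
```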
@@ -282,8 +283,7 @@ def exclude_batch_norm(name):
     train_op = None
 
   if not tf.contrib.distribute.has_distribution_strategy():
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
+    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
   else:
     # Metrics are currently not compatible with distribution strategies during
     # training. This does not affect the overall performance of the model.
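Likewise, tf.metrics.accuracy compares the integer labels against the predicted class ids directly, so the tf.argmax over one-hot labels goes away. A runnable sketch (illustrative values):

```python
import tensorflow as tf  # TF 1.x

labels = tf.constant([3, 1, 4], tf.int64)   # integer class indices
classes = tf.constant([3, 1, 2], tf.int64)  # model's predicted classes
accuracy, update_op = tf.metrics.accuracy(labels, classes)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # tf.metrics state is local vars
  sess.run(update_op)
  print(sess.run(accuracy))                   # 2 of 3 correct -> ~0.667
```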