Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmarks for flipping and random flipping #15348

Merged
merged 2 commits into from
Dec 14, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
83 changes: 83 additions & 0 deletions tensorflow/python/ops/image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,89 @@ def testInvalidShapes(self):
self._adjustHueTf(x_np, delta_h)


class FlipImageBenchmark(test.Benchmark):

def _benchmarkFlipLeftRight(self, device, cpu_count):
image_shape = [299, 299, 3]
warmup_rounds = 100
benchmark_rounds = 1000
config = config_pb2.ConfigProto()
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
with session.Session("", graph=ops.Graph(), config=config) as sess:
with ops.device(device):
inputs = variables.Variable(
random_ops.random_uniform(
image_shape, dtype=dtypes.float32) * 255,
trainable=False,
dtype=dtypes.float32)
run_op = image_ops.flip_left_right(inputs)
sess.run(variables.global_variables_initializer())
for i in xrange(warmup_rounds + benchmark_rounds):
if i == warmup_rounds:
start = time.time()
sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
print("benchmarkFlipLeftRight_299_299_3_%s step_time: %.2f us" %
(tag, step_time * 1e6))
self.report_benchmark(
name="benchmarkFlipLeftRight_299_299_3_%s" % (tag),
iters=benchmark_rounds,
wall_time=step_time)

def _benchmarkRandomFlipLeftRight(self, device, cpu_count):
image_shape = [299, 299, 3]
warmup_rounds = 100
benchmark_rounds = 1000
config = config_pb2.ConfigProto()
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
with session.Session("", graph=ops.Graph(), config=config) as sess:
with ops.device(device):
inputs = variables.Variable(
random_ops.random_uniform(
image_shape, dtype=dtypes.float32) * 255,
trainable=False,
dtype=dtypes.float32)
run_op = image_ops.random_flip_left_right(inputs)
sess.run(variables.global_variables_initializer())
for i in xrange(warmup_rounds + benchmark_rounds):
if i == warmup_rounds:
start = time.time()
sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
print("benchmarkRandomFlipLeftRight_299_299_3_%s step_time: %.2f us" %
(tag, step_time * 1e6))
self.report_benchmark(
name="benchmarkRandomFlipLeftRight_299_299_3_%s" % (tag),
iters=benchmark_rounds,
wall_time=step_time)

def benchmarkFlipLeftRightCpu1(self):
self._benchmarkFlipLeftRight("/cpu:0", 1)

def benchmarkFlipLeftRightCpuAll(self):
self._benchmarkFlipLeftRight("/cpu:0", None)

def benchmarkFlipLeftRightGpu(self):
self._benchmarkFlipLeftRight(test.gpu_device_name(), None)

def benchmarkRandomFlipLeftRightCpu1(self):
self._benchmarkRandomFlipLeftRight("/cpu:0", 1)

def benchmarkRandomFlipLeftRightCpuAll(self):
self._benchmarkRandomFlipLeftRight("/cpu:0", None)

def benchmarkRandomFlipLeftRightGpu(self):
self._benchmarkRandomFlipLeftRight(test.gpu_device_name(), None)


class AdjustHueBenchmark(test.Benchmark):

def _benchmarkAdjustHue(self, device, cpu_count):
Expand Down