Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions benchmarks/operator_benchmark/benchmark_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def _build_test(configs, bench_op, OperatorTestCase, run_backward, op_name_funct
run_backward: a bool parameter indicating backward path
op_name_function: a dictionary includes operator name and function
"""
test_list = []
for config in configs:
test_attrs = {}
tags = None
Expand Down Expand Up @@ -132,7 +131,7 @@ def _build_test(configs, bench_op, OperatorTestCase, run_backward, op_name_funct
# which use auto_set().
if op._num_inputs_require_grads > 0:
input_name = 'all'
test_list.append(_create_test(op, test_attrs, tags, OperatorTestCase, run_backward, input_name))
yield _create_test(op, test_attrs, tags, OperatorTestCase, run_backward, input_name)

# This for loop is only used when auto_set is used.
# _pass_count counts how many times init has been called.
Expand All @@ -147,9 +146,7 @@ def _build_test(configs, bench_op, OperatorTestCase, run_backward, op_name_funct
new_op.init(**init_dict)
# Input name index will start from input1
input_name = i + 1
test_list.append(_create_test(new_op, test_attrs, tags, OperatorTestCase, run_backward, input_name))

return test_list
yield _create_test(new_op, test_attrs, tags, OperatorTestCase, run_backward, input_name)


class BenchmarkRunner(object):
Expand Down Expand Up @@ -362,10 +359,7 @@ def run(self):
self._print_header()

for test_metainfo in BENCHMARK_TESTER:
# If auto_set is used, _build_test will return a list of tests including
# forward and backward ones
test_list = _build_test(*test_metainfo)
for test in test_list:
for test in _build_test(*test_metainfo):
full_test_id, test_case = test
op_test_config = test_case.test_config

Expand Down