Skip to content

Commit 96cc4d9

Browse files
authored
Bug: Executor - Fix executor for Benchmark Execution Without Explicit Framework Field (#636)
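**Description** Fix the executor so that benchmarks are still executed when the benchmark config has no explicit `frameworks` field: the framework list now defaults to `[Framework.NONE.value]`, whereas previously the whole execution loop was skipped when the field was absent. The commit also makes the runner fill in the default MPI `mca` settings when `mca` (rather than `machinefile`) is missing from the mode, and changes the runner's default benchmark `timeout` from 60 to `None`.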
1 parent 7af75df commit 96cc4d9

File tree

4 files changed

+68
-24
lines changed

4 files changed

+68
-24
lines changed

superbench/executor/executor.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -228,29 +228,16 @@ def exec(self):
228228
logger.warning('Monitor can not support CPU platform.')
229229

230230
benchmark_real_name = benchmark_name.split(':')[0]
231-
if 'frameworks' in benchmark_config:
232-
for framework in benchmark_config.frameworks or [Framework.NONE.value]:
233-
if benchmark_real_name == 'model-benchmarks' or (
234-
':' not in benchmark_name and benchmark_name.endswith('_models')
235-
):
236-
for model in benchmark_config.models:
237-
full_name = f'{benchmark_name}/{framework}-{model}'
238-
logger.info('Executor is going to execute %s.', full_name)
239-
context = BenchmarkRegistry.create_benchmark_context(
240-
model,
241-
platform=self.__get_platform(),
242-
framework=Framework(framework.lower()),
243-
parameters=self.__get_arguments(
244-
{} if 'parameters' not in benchmark_config else benchmark_config.parameters
245-
)
246-
)
247-
result = self.__exec_benchmark(full_name, context)
248-
benchmark_results.append(result)
249-
else:
250-
full_name = benchmark_name
231+
frameworks = benchmark_config.get('frameworks', [Framework.NONE.value])
232+
for framework in frameworks:
233+
if benchmark_real_name == 'model-benchmarks' or (
234+
':' not in benchmark_name and benchmark_name.endswith('_models')
235+
):
236+
for model in benchmark_config.models:
237+
full_name = f'{benchmark_name}/{framework}-{model}'
251238
logger.info('Executor is going to execute %s.', full_name)
252239
context = BenchmarkRegistry.create_benchmark_context(
253-
benchmark_real_name,
240+
model,
254241
platform=self.__get_platform(),
255242
framework=Framework(framework.lower()),
256243
parameters=self.__get_arguments(
@@ -259,6 +246,18 @@ def exec(self):
259246
)
260247
result = self.__exec_benchmark(full_name, context)
261248
benchmark_results.append(result)
249+
else:
250+
full_name = benchmark_name
251+
logger.info('Executor is going to execute %s.', full_name)
252+
context = BenchmarkRegistry.create_benchmark_context(
253+
benchmark_real_name,
254+
platform=self.__get_platform(),
255+
framework=Framework(framework.lower()),
256+
parameters=self.
257+
__get_arguments({} if 'parameters' not in benchmark_config else benchmark_config.parameters)
258+
)
259+
result = self.__exec_benchmark(full_name, context)
260+
benchmark_results.append(result)
262261

263262
if monitor:
264263
monitor.stop()

superbench/runner/runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __validate_sb_config(self): # noqa: C901
8484
if 'proc_num' not in mode:
8585
self._sb_benchmarks[name].modes[idx].proc_num = 8
8686
elif mode.name == 'mpi':
87-
if 'machinefile' not in mode:
87+
if 'mca' not in mode:
8888
self._sb_benchmarks[name].modes[idx].mca = {
8989
'pml': 'ob1',
9090
'btl': '^openib',
@@ -448,7 +448,7 @@ def _run_proc(self, benchmark_name, mode, vars):
448448
mode.env.update({'SB_MODE_SERIAL_INDEX': mode.serial_index, 'SB_MODE_PARALLEL_INDEX': mode.parallel_index})
449449
logger.info('Runner is going to run %s in %s mode, proc rank %d.', benchmark_name, mode.name, mode.proc_rank)
450450

451-
timeout = self._sb_benchmarks[benchmark_name].get('timeout', 60)
451+
timeout = self._sb_benchmarks[benchmark_name].get('timeout', None)
452452
if isinstance(timeout, int):
453453
timeout = max(timeout, 60)
454454

tests/executor/test_executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,7 @@ def test_exec_default_benchmarks(self, mock_launch_benchmark):
166166
self.assertTrue(p.is_dir())
167167
self.assertTrue((p / 'results.json').is_file())
168168
with (p / 'results.json').open() as f:
169-
for result in json.load(f):
169+
results = json.load(f)
170+
self.assertTrue(len(results) > 0)
171+
for result in results:
170172
self.assertIn(benchmark_name, result['name'])

tests/runner/test_runner.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ def test_set_logger(self):
4141
expected_log_file = Path(self.runner._sb_output_dir) / 'sb-run.log'
4242
self.assertTrue(expected_log_file.is_file())
4343

44+
def test_validate_sb_config(self):
45+
"""Test validate_sb_config."""
46+
self.runner._SuperBenchRunner__validate_sb_config()
47+
self.assertIn('env', self.runner._sb_config.superbench)
48+
for name in self.runner._sb_benchmarks:
49+
self.assertIn('modes', self.runner._sb_config.superbench.benchmarks[name])
50+
for mode in self.runner._sb_config.superbench.benchmarks[name].modes:
51+
self.assertIn('env', mode)
52+
if mode.name == 'local':
53+
self.assertIn('proc_num', mode)
54+
self.assertIn('prefix', mode)
55+
if mode.name == 'torch.distributed':
56+
self.assertIn('proc_num', mode)
57+
if mode.name == 'mpi':
58+
self.assertIn('mca', mode)
59+
4460
def test_get_failure_count(self):
4561
"""Test get_failure_count."""
4662
self.assertEqual(0, self.runner.get_failure_count())
@@ -410,3 +426,30 @@ def test_generate_metric_name(self):
410426
test_case['run_count'], test_case['curr_rank'], test_case['curr_run']
411427
), test_case['expected']
412428
)
429+
430+
def test_run_proc_timeout(self):
431+
"""Test run_proc_ timeout."""
432+
self.runner._sb_benchmarks = {
433+
'benchmark1': {
434+
'timeout': 120
435+
},
436+
'benchmark2': {
437+
'timeout': None
438+
},
439+
'benchmark3': {
440+
'timeout': 30
441+
},
442+
}
443+
444+
test_cases = [
445+
('benchmark1', 120),
446+
('benchmark2', None),
447+
('benchmark3', 60),
448+
]
449+
450+
for benchmark_name, expected_timeout in test_cases:
451+
with self.subTest(benchmark_name=benchmark_name):
452+
timeout = self.runner._sb_benchmarks[benchmark_name].get('timeout', None)
453+
if isinstance(timeout, int):
454+
timeout = max(timeout, 60)
455+
self.assertEqual(timeout, expected_timeout)

0 commit comments

Comments
 (0)