diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 0d50f1882da..7cf0594b934 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -86,6 +86,29 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: return False +def maybe_get_tosa_collate_path() -> str | None: + """ + Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the + path to the where to store the current tests if it is set. + """ + tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH") + if tosa_test_base: + current_test = os.environ.get("PYTEST_CURRENT_TEST") + #'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)' + test_class = current_test.split("::")[1] + test_name = current_test.split("::")[-1].split(" ")[0] + if "BI" in test_name: + tosa_test_base = os.path.join(tosa_test_base, "tosa-bi") + elif "MI" in test_name: + tosa_test_base = os.path.join(tosa_test_base, "tosa-mi") + else: + tosa_test_base = os.path.join(tosa_test_base, "other") + + return os.path.join(tosa_test_base, test_class, test_name) + + return None + + def get_tosa_compile_spec( permute_memory_to_nhwc=True, custom_path=None ) -> list[CompileSpec]: @@ -101,7 +124,13 @@ def get_tosa_compile_spec_unbuilt( """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. """ - intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_") + if not custom_path: + intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp( + prefix="arm_tosa_" + ) + else: + intermediate_path = custom_path + if not os.path.exists(intermediate_path): os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = ( diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index aa9703f9eba..1b77914a208 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -6,6 +6,7 @@ import logging import os +import shutil import tempfile import unittest @@ -149,3 +150,39 @@ def test_dump_ops_and_dtypes(self): .dump_operator_distribution() ) # Just test that there are no execeptions. + + +class TestCollateTosaTests(unittest.TestCase): + """Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASE_BASE_PATH.""" + + def test_collate_tosa_BI_tests(self): + # Set the environment variable to trigger the collation of TOSA tests + os.environ["TOSA_TESTCASES_BASE_PATH"] = "test_collate_tosa_tests" + # Clear out the directory + + model = Linear(20, 30) + ( + ArmTester( + model, + example_inputs=model.get_inputs(), + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + ) + # test that the output directory is created and contains the expected files + assert os.path.exists( + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests" + ) + assert os.path.exists( + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa" + ) + assert os.path.exists( + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json" + ) + + os.environ.pop("TOSA_TESTCASES_BASE_PATH") + shutil.rmtree("test_collate_tosa_tests", ignore_errors=True)