diff --git a/symforce/codegen/codegen_util.py b/symforce/codegen/codegen_util.py index 91af73a2d..f2b4329d8 100644 --- a/symforce/codegen/codegen_util.py +++ b/symforce/codegen/codegen_util.py @@ -543,6 +543,13 @@ def _load_generated_package_internal(name: str, path: Path) -> T.Tuple[T.Any, T. def load_generated_package(name: str, path: T.Openable) -> T.Any: """ Dynamically load generated package (or module). + + name is the full name of the package or module to load (for example, + "pkg.sub_pkg" for a package called "sub_pkg" inside of another package + "pkg", or "pkg.sub_pkg.mod" for a module called "mod" inside of pkg.sub_pkg). + + path is the path to the directory (or __init__.py) of the package, or the python + file of the module. """ # NOTE(brad): We remove all possibly conflicting modules from the cache. This is # to ensure that when name is executed, it loads local modules (if any) rather @@ -574,6 +581,25 @@ def load_generated_package(name: str, path: T.Openable) -> T.Any: return module +def load_generated_function(func_name: str, path_to_package: T.Openable) -> T.Any: + """ + Loads the function with name func_name found inside the package located at + path_to_package. + + Preconditions: + path_to_package is a python package with an `__init__.py` containing a module + defined in `func_name.py` which in turn defines an attribute named `func_name`. + Note: the precondition will be satisfied if the package was generated by + `Codegen.generate_function` from a `Codegen` function with name `func_name`. + """ + pkg_path = Path(path_to_package) + if pkg_path.name == "__init__.py": + pkg_path = pkg_path.parent + pkg_name = pkg_path.name + func_module = load_generated_package(f"{pkg_name}.{func_name}", pkg_path / f"{func_name}.py") + return getattr(func_module, func_name) + + def load_generated_lcmtype( package: str, type_name: str, lcmtypes_path: T.Union[str, Path] ) -> T.Type: diff --git a/test/symforce_codegen_util_test.py b/test/symforce_codegen_util_test.py index 017392b98..e9872ea75 100644 --- a/test/symforce_codegen_util_test.py +++ b/test/symforce_codegen_util_test.py @@ -10,6 +10,10 @@ from symforce.codegen import codegen_util from symforce.test_util import TestCase +PKG_LOCATIONS = Path(__file__).parent / "test_data" / "codegen_util_test_data" +RELATIVE_PATH = Path("example_pkg", "__init__.py") +PACKAGE_NAME = "example_pkg" + class SymforceCodegenUtilTest(TestCase): """ @@ -22,14 +26,8 @@ def test_load_generated_package(self) -> None: codegen_util.load_generated_package """ - pkg_locations = Path(__file__).parent / "test_data" / "codegen_util_test_data" - - relative_path = Path("example_pkg", "__init__.py") - - package_name = "example_pkg" - pkg1 = codegen_util.load_generated_package( - name=package_name, path=pkg_locations / "example_pkg_1" / relative_path + name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH ) # Testing that the module was loaded correctly @@ -37,10 +35,10 @@ def test_load_generated_package(self) -> None: self.assertEqual(pkg1.sub_module.sub_module_id, 1) # Testing that sys.modules was not polluted - self.assertFalse(package_name in sys.modules) + self.assertFalse(PACKAGE_NAME in sys.modules) pkg2 = codegen_util.load_generated_package( - name=package_name, path=pkg_locations / "example_pkg_2" / relative_path + name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH ) # Testing that the module was loaded correctly when a module with the same name has @@ -48,6 +46,30 @@ def test_load_generated_package(self) -> None: self.assertEqual(pkg2.package_id, 2) self.assertEqual(pkg2.sub_module.sub_module_id, 2) + def test_load_generated_function(self) -> None: + """ + Tests: + codegen_util.load_generated_function + """ + + func1 = codegen_util.load_generated_function( + func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH + ) + + # Testing that the function was loaded correctly + self.assertEqual(func1(), 1) + + # Testing that sys.modules was not polluted + self.assertFalse(PACKAGE_NAME in sys.modules) + + func2 = codegen_util.load_generated_function( + func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH + ) + + # Testing that the function was loaded correctly when a function of the same name + # has already been loaded. + self.assertEqual(func2(), 2) + if __name__ == "__main__": TestCase.main() diff --git a/test/test_data/codegen_util_test_data/example_pkg_1/example_pkg/func.py b/test/test_data/codegen_util_test_data/example_pkg_1/example_pkg/func.py new file mode 100644 index 000000000..208476796 --- /dev/null +++ b/test/test_data/codegen_util_test_data/example_pkg_1/example_pkg/func.py @@ -0,0 +1,8 @@ +# ---------------------------------------------------------------------------- +# SymForce - Copyright 2022, Skydio, Inc. +# This source code is under the Apache 2.0 license found in the LICENSE file. +# ---------------------------------------------------------------------------- + + +def func() -> int: + return 1 diff --git a/test/test_data/codegen_util_test_data/example_pkg_2/example_pkg/func.py b/test/test_data/codegen_util_test_data/example_pkg_2/example_pkg/func.py new file mode 100644 index 000000000..a0ae2aa52 --- /dev/null +++ b/test/test_data/codegen_util_test_data/example_pkg_2/example_pkg/func.py @@ -0,0 +1,8 @@ +# ---------------------------------------------------------------------------- +# SymForce - Copyright 2022, Skydio, Inc. +# This source code is under the Apache 2.0 license found in the LICENSE file. +# ---------------------------------------------------------------------------- + + +def func() -> int: + return 2