Skip to content

Commit

Permalink
Add codegen_util.load_generated_function
Browse files Browse the repository at this point in the history
Defines the new function load_generated_function which is
meant to be a more user friendly version of `load_generated_package`
(which will be particularly more friendly should we no longer re-export
all sub-modules of generated packages).

Also, explains the arguments of `load_generated_package` in its
doc-string.

As an example of the expected usage of this function, for `co` a
`Codegen` object:
``` python3
generated_paths = co.generate_function()

func = load_generated_function(co.name, generated_paths.function_dir)
```

Tested in `test/symforce_codegen_util_test.py`
  • Loading branch information
bradley-solliday-skydio committed Nov 3, 2022
1 parent 443997a commit 63ad07b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 9 deletions.
26 changes: 26 additions & 0 deletions symforce/codegen/codegen_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,13 @@ def _load_generated_package_internal(name: str, path: Path) -> T.Tuple[T.Any, T.
def load_generated_package(name: str, path: T.Openable) -> T.Any:
"""
Dynamically load generated package (or module).
name is the full name of the package or module to load (for example,
"pkg.sub_pkg" for a package called "sub_pkg" inside of another package
"pkg", or "pkg.sub_pkg.mod" for a module called "mod" inside of pkg.sub_pkg).
path is the path to the directory (or __init__.py) of the package, or the python
file of the module.
"""
# NOTE(brad): We remove all possibly conflicting modules from the cache. This is
# to ensure that when name is executed, it loads local modules (if any) rather
Expand Down Expand Up @@ -574,6 +581,25 @@ def load_generated_package(name: str, path: T.Openable) -> T.Any:
return module


def load_generated_function(func_name: str, path_to_package: T.Openable) -> T.Any:
"""
Loads the function with name func_name found inside the package located at
path_to_package.
Preconditions:
path_to_package is a python package with an `__init__.py` containing a module
defined in `func_name.py` which in turn defines an attribute named `func_name`.
Note: the precondition will be satisfied if the package was generated by
`Codegen.generate_function` from a `Codegen` function with name `func_name`.
"""
pkg_path = Path(path_to_package)
if pkg_path.name == "__init__.py":
pkg_path = pkg_path.parent
pkg_name = pkg_path.name
func_module = load_generated_package(f"{pkg_name}.{func_name}", pkg_path / f"{func_name}.py")
return getattr(func_module, func_name)


def load_generated_lcmtype(
package: str, type_name: str, lcmtypes_path: T.Union[str, Path]
) -> T.Type:
Expand Down
40 changes: 31 additions & 9 deletions test/symforce_codegen_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from symforce.codegen import codegen_util
from symforce.test_util import TestCase

PKG_LOCATIONS = Path(__file__).parent / "test_data" / "codegen_util_test_data"
RELATIVE_PATH = Path("example_pkg", "__init__.py")
PACKAGE_NAME = "example_pkg"


class SymforceCodegenUtilTest(TestCase):
"""
Expand All @@ -22,32 +26,50 @@ def test_load_generated_package(self) -> None:
codegen_util.load_generated_package
"""

pkg_locations = Path(__file__).parent / "test_data" / "codegen_util_test_data"

relative_path = Path("example_pkg", "__init__.py")

package_name = "example_pkg"

pkg1 = codegen_util.load_generated_package(
name=package_name, path=pkg_locations / "example_pkg_1" / relative_path
name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH
)

# Testing that the module was loaded correctly
self.assertEqual(pkg1.package_id, 1)
self.assertEqual(pkg1.sub_module.sub_module_id, 1)

# Testing that sys.modules was not polluted
self.assertFalse(package_name in sys.modules)
self.assertFalse(PACKAGE_NAME in sys.modules)

pkg2 = codegen_util.load_generated_package(
name=package_name, path=pkg_locations / "example_pkg_2" / relative_path
name=PACKAGE_NAME, path=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH
)

# Testing that the module was loaded correctly when a module with the same name has
# already been loaded
self.assertEqual(pkg2.package_id, 2)
self.assertEqual(pkg2.sub_module.sub_module_id, 2)

def test_load_generated_function(self) -> None:
"""
Tests:
codegen_util.load_generated_function
"""

func1 = codegen_util.load_generated_function(
func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_1" / RELATIVE_PATH
)

# Testing that the function was loaded correctly
self.assertEqual(func1(), 1)

# Testing that sys.modules was not polluted
self.assertFalse(PACKAGE_NAME in sys.modules)

func2 = codegen_util.load_generated_function(
func_name="func", path_to_package=PKG_LOCATIONS / "example_pkg_2" / RELATIVE_PATH
)

# Testing that the function was loaded correctly when a function of the same name
# has already been loaded.
self.assertEqual(func2(), 2)


if __name__ == "__main__":
TestCase.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------


def func() -> int:
return 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------


def func() -> int:
return 2

0 comments on commit 63ad07b

Please sign in to comment.