From a9a9823d5227829205cf437592bba26e7183e3b1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 1 Jun 2022 17:16:11 +0200 Subject: [PATCH] Fix class resolving on windows --- src/zenml/utils/source_utils.py | 6 +++--- tests/unit/utils/test_source_utils.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/zenml/utils/source_utils.py b/src/zenml/utils/source_utils.py index d72bc3cfc59..252d3ed1500 100644 --- a/src/zenml/utils/source_utils.py +++ b/src/zenml/utils/source_utils.py @@ -193,7 +193,7 @@ def get_module_source_from_module(module: ModuleType) -> str: # Kick out the .py and replace `/` with `.` to get the module source module_path = module_path.replace(".py", "") - module_source = module_path.replace("/", ".") + module_source = module_path.replace(os.path.sep, ".") logger.debug( f"Resolved module source for module {module} to: {module_source}" @@ -210,7 +210,7 @@ def get_relative_path_from_module_source(module_source: str) -> str: Args: module_source: A module e.g. zenml.core.step """ - return module_source.replace(".", "/") + return module_source.replace(".", os.path.sep) def get_absolute_path_from_module_source(module: str) -> str: @@ -470,7 +470,7 @@ def import_python_file(file_path: str) -> types.ModuleType: # module full_module_path = os.path.splitext( os.path.relpath(file_path, os.getcwd()) - )[0].replace("/", ".") + )[0].replace(os.path.sep, ".") if full_module_path not in sys.modules: with prepend_python_path(os.path.dirname(file_path)): diff --git a/tests/unit/utils/test_source_utils.py b/tests/unit/utils/test_source_utils.py index 26aa69e64a7..128e511b1cb 100644 --- a/tests/unit/utils/test_source_utils.py +++ b/tests/unit/utils/test_source_utils.py @@ -34,6 +34,24 @@ def test_is_third_party_module(): assert not source_utils.is_third_party_module(non_third_party_file) +class EmptyClass: + pass + + +def test_resolve_class(): + """Tests that class resolving works as expected.""" + os.getcwd() + parent_directory = os.path.dirname(os.path.dirname(__file__)) + os.chdir(parent_directory) + try: + assert ( + source_utils.resolve_class(EmptyClass) + == "utils.test_source_utils.EmptyClass" + ) + finally: + os.chdir(parent_directory) + + def test_get_source(): """Tests if source of objects is gotten properly.""" assert source_utils.get_source(pytest.Cache)