diff --git a/launch_testing/launch_testing/pytest/hooks.py b/launch_testing/launch_testing/pytest/hooks.py index 4e55f5d2d..86c71d16c 100644 --- a/launch_testing/launch_testing/pytest/hooks.py +++ b/launch_testing/launch_testing/pytest/hooks.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import pathlib from unittest import TestCase from _pytest._code.code import ReprFileLocation @@ -23,6 +24,13 @@ from ..test_runner import LaunchTestRunner +def _pytest_version_ge(major, minor=0, patch=0): + """Return True if pytest version is >= the given version.""" + pytest_version = tuple([int(v) for v in pytest.__version__.split('.')]) + assert pytest_version + return pytest_version >= (major, minor, patch) + + class LaunchTestFailure(Exception): def __init__(self, message, results): @@ -119,29 +127,42 @@ def repr_failure(self, excinfo): return super().repr_failure(excinfo) def reportinfo(self): - return self.fspath, 0, 'launch tests: {}'.format(self.name) + if _pytest_version_ge(7): + path = self.path + else: + path = self.fspath + return path, 0, 'launch tests: {}'.format(self.name) class LaunchTestModule(pytest.File): - def __init__(self, parent, *, fspath): - super().__init__(parent=parent, fspath=fspath) + def __init__(self, *args, **kwargs): + if _pytest_version_ge(7): + if 'fspath' in kwargs: + if kwargs['fspath'] is not None: + kwargs['path'] = pathlib.Path(kwargs['fspath']) + del kwargs['fspath'] + super().__init__(*args, **kwargs) @classmethod - def from_parent(cls, parent, *, fspath): + def from_parent(cls, *args, **kwargs): """Override from_parent for compatibility.""" # pytest.File.from_parent didn't exist before pytest 5.4 - if hasattr(super(), 'from_parent'): - instance = getattr(super(), 'from_parent')(parent=parent, fspath=fspath) - else: - instance = cls(parent=parent, fspath=fspath) - return instance + if _pytest_version_ge(5, 4): + return super().from_parent(*args, **kwargs) + args_without_parent = args[1:] + return cls(*args_without_parent, **kwargs) def makeitem(self, *args, **kwargs): return LaunchTestItem.from_parent(*args, **kwargs) def collect(self): - module = self.fspath.pyimport() + if _pytest_version_ge(7): + # self.path exists since 7 + from _pytest.pathlib import import_path + module = import_path(self.path, root=None) + else: + module = self.fspath.pyimport() yield self.makeitem( name=module.__name__, parent=self, test_runs=LoadTestsFromPythonModule( @@ -152,7 +173,13 @@ def collect(self): def find_launch_test_entrypoint(path): try: - return getattr(path.pyimport(), 'generate_test_description', None) + if _pytest_version_ge(7): + from _pytest.pathlib import import_path + module = import_path(path, root=None) + else: + # Assume we got legacy path in earlier versions of pytest + module = path.pyimport() + return getattr(module, 'generate_test_description', None) except SyntaxError: return None @@ -166,18 +193,20 @@ def pytest_pycollect_makemodule(path, parent): ) if module is not None: return module - if path.basename == '__init__.py': - try: - # since https://docs.pytest.org/en/latest/changelog.html#deprecations - # todo: remove fallback once all platforms use pytest >=5.4 + + if _pytest_version_ge(7): + path = pathlib.Path(path) + if path.name == '__init__.py': + return pytest.Package.from_parent(parent, path=path) + return pytest.Module.from_parent(parent=parent, path=path) + elif _pytest_version_ge(5, 4): + if path.basename == '__init__.py': return pytest.Package.from_parent(parent, fspath=path) - except AttributeError: - return pytest.Package(path, parent) - try: - # since https://docs.pytest.org/en/latest/changelog.html#deprecations - # todo: remove fallback once all platforms use pytest >=5.4 return pytest.Module.from_parent(parent, fspath=path) - except AttributeError: + else: + # todo: remove fallback once all platforms use pytest >=5.4 + if path.basename == '__init__.py': + return pytest.Package(path, parent) return pytest.Module(path, parent) @@ -185,7 +214,11 @@ def pytest_pycollect_makemodule(path, parent): def pytest_launch_collect_makemodule(path, parent, entrypoint): marks = getattr(entrypoint, 'pytestmark', []) if marks and any(m.name == 'launch_test' for m in marks): - return LaunchTestModule.from_parent(parent, fspath=path) + if _pytest_version_ge(7): + path = pathlib.Path(path) + return LaunchTestModule.from_parent(parent=parent, path=path) + else: + return LaunchTestModule.from_parent(parent=parent, fspath=path) def pytest_addhooks(pluginmanager):