Skip to content

Commit

Permalink
Fixed and updated import_default_cases_module
Browse files Browse the repository at this point in the history
  • Loading branch information
Sylvain MARIE committed Mar 21, 2022
1 parent cadffa2 commit def94ce
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions src/pytest_cases/case_parametrizer_new.py
Expand Up @@ -300,7 +300,7 @@ def get_all_cases(parametrization_target=None, # type: Callable
# module
if c is AUTO:
# First try `test_<name>_cases.py` Then `case_<name>.py`
c = import_default_cases_module(parametrization_target)
c = import_default_cases_module(caller_module_name)

elif c is THIS_MODULE or c == '.':
c = caller_module_name
Expand Down Expand Up @@ -647,34 +647,26 @@ def _get_fixture_cases(module_or_class # type: Union[ModuleType, Type]
return cache, imported_fixtures_list


def import_default_cases_module(context):
def import_default_cases_module(test_module_name):
"""
Implements the `module=AUTO` behaviour of `@parameterize_cases`: based on the context
passed in. This can either a <module> object or a decorated test function in which
case it finds its containing module name "test_<module>.py" and then tries to import
the python module "test_<module>_cases.py".
Implements the `module=AUTO` behaviour of `@parameterize_cases`.
If "test_<module>_cases.py" module is not found it looks for the alternate
file `cases_<module>.py`.
`test_module_name` will have the format "test_<module>.py", the associated python module "test_<module>_cases.py"
will be loaded to load the cases.
:param f: the decorated test function or a module
If "test_<module>_cases.py" module is not found it looks for the alternate file `cases_<module>.py`.
:param test_module_name: the test module
:return:
"""
if ismodule(context):
module_name = context.__name__
elif hasattr(context, "__module__"):
module_name = context.__module__
else:
raise ValueError("Can't get module from context %s" % context)

# First try `test_<name>_cases.py`
cases_module_name1 = "%s_cases" % module_name
cases_module_name1 = "%s_cases" % test_module_name

try:
cases_module = import_module(cases_module_name1)
except ModuleNotFoundError:
# Then try `case_<name>.py`
parts = module_name.split('.')
parts = test_module_name.split('.')
assert parts[-1][0:5] == 'test_'
cases_module_name2 = "%s.cases_%s" % ('.'.join(parts[:-1]), parts[-1][5:])
try:
Expand All @@ -684,7 +676,7 @@ def import_default_cases_module(context):
raise ValueError("Error importing test cases module to parametrize %r: unable to import AUTO "
"cases module %r nor %r. Maybe you wish to import cases from somewhere else ? In that case"
"please specify `cases=...`."
% (f, cases_module_name1, cases_module_name2))
% (test_module_name, cases_module_name1, cases_module_name2))

return cases_module

Expand Down

0 comments on commit def94ce

Please sign in to comment.