Skip to content

Commit bf2bda3

Browse files
committed
Close python#22457: Honour load_tests in the start_dir of discovery.
We were not honouring load_tests in a package/__init__.py when that was the start_dir parameter, though we do when it is a child package. The fix required a little care since it introduces the possibility of infinite recursion.
1 parent d39e199 commit bf2bda3

File tree

6 files changed

+167
-59
lines changed

6 files changed

+167
-59
lines changed

Doc/library/unittest.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1668,7 +1668,11 @@ Loading and running tests
16681668

16691669
If a package (a directory containing a file named :file:`__init__.py`) is
16701670
found, the package will be checked for a ``load_tests`` function. If this
1671-
exists then it will be called with *loader*, *tests*, *pattern*.
1671+
exists then it will be called
1672+
``package.load_tests(loader, tests, pattern)``. Test discovery takes care
1673+
to ensure that a package is only checked for tests once during an
1674+
invocation, even if the load_tests function itself calls
1675+
``loader.discover``.
16721676

16731677
If ``load_tests`` exists then discovery does *not* recurse into the
16741678
package, ``load_tests`` is responsible for loading all tests in the

Lib/unittest/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@ def testMultiply(self):
6767

6868
# deprecated
6969
_TextTestResult = TextTestResult
70+
71+
# There are no tests here, so don't try to run anything discovered from
72+
# introspecting the symbols (e.g. FunctionTestCase). Instead, all our
73+
# tests come from within unittest.test.
74+
def load_tests(loader, tests, pattern):
75+
import os.path
76+
# top level directory cached on loader instance
77+
this_dir = os.path.dirname(__file__)
78+
return loader.discover(start_dir=this_dir, pattern=pattern)

Lib/unittest/loader.py

Lines changed: 105 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class TestLoader(object):
6565
def __init__(self):
6666
super(TestLoader, self).__init__()
6767
self.errors = []
68+
# Tracks packages which we have called into via load_tests, to
69+
# avoid infinite re-entrancy.
70+
self._loading_packages = set()
6871

6972
def loadTestsFromTestCase(self, testCaseClass):
7073
"""Return a suite of all tests cases contained in testCaseClass"""
@@ -229,9 +232,13 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
229232
230233
If a test package name (directory with '__init__.py') matches the
231234
pattern then the package will be checked for a 'load_tests' function. If
232-
this exists then it will be called with loader, tests, pattern.
235+
this exists then it will be called with (loader, tests, pattern) unless
236+
the package has already had load_tests called from the same discovery
237+
invocation, in which case the package module object is not scanned for
238+
tests - this ensures that when a package uses discover to further
239+
discover child tests that infinite recursion does not happen.
233240
234-
If load_tests exists then discovery does *not* recurse into the package,
241+
If load_tests exists then discovery does *not* recurse into the package,
235242
load_tests is responsible for loading all tests in the package.
236243
237244
The pattern is deliberately not stored as a loader attribute so that
@@ -355,69 +362,110 @@ def _match_path(self, path, full_path, pattern):
355362

356363
def _find_tests(self, start_dir, pattern, namespace=False):
357364
"""Used by discovery. Yields test suites it loads."""
365+
# Handle the __init__ in this package
366+
name = self._get_name_from_path(start_dir)
367+
# name is '.' when start_dir == top_level_dir (and top_level_dir is by
368+
# definition not a package).
369+
if name != '.' and name not in self._loading_packages:
370+
# name is in self._loading_packages while we have called into
371+
# loadTestsFromModule with name.
372+
tests, should_recurse = self._find_test_path(
373+
start_dir, pattern, namespace)
374+
if tests is not None:
375+
yield tests
376+
if not should_recurse:
377+
# Either an error occured, or load_tests was used by the
378+
# package.
379+
return
380+
# Handle the contents.
358381
paths = sorted(os.listdir(start_dir))
359-
360382
for path in paths:
361383
full_path = os.path.join(start_dir, path)
362-
if os.path.isfile(full_path):
363-
if not VALID_MODULE_NAME.match(path):
364-
# valid Python identifiers only
365-
continue
366-
if not self._match_path(path, full_path, pattern):
367-
continue
368-
# if the test file matches, load it
384+
tests, should_recurse = self._find_test_path(
385+
full_path, pattern, namespace)
386+
if tests is not None:
387+
yield tests
388+
if should_recurse:
389+
# we found a package that didn't use load_tests.
369390
name = self._get_name_from_path(full_path)
391+
self._loading_packages.add(name)
370392
try:
371-
module = self._get_module_from_name(name)
372-
except case.SkipTest as e:
373-
yield _make_skipped_test(name, e, self.suiteClass)
374-
except:
375-
error_case, error_message = \
376-
_make_failed_import_test(name, self.suiteClass)
377-
self.errors.append(error_message)
378-
yield error_case
379-
else:
380-
mod_file = os.path.abspath(getattr(module, '__file__', full_path))
381-
realpath = _jython_aware_splitext(os.path.realpath(mod_file))
382-
fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
383-
if realpath.lower() != fullpath_noext.lower():
384-
module_dir = os.path.dirname(realpath)
385-
mod_name = _jython_aware_splitext(os.path.basename(full_path))
386-
expected_dir = os.path.dirname(full_path)
387-
msg = ("%r module incorrectly imported from %r. Expected %r. "
388-
"Is this module globally installed?")
389-
raise ImportError(msg % (mod_name, module_dir, expected_dir))
390-
yield self.loadTestsFromModule(module, pattern=pattern)
391-
elif os.path.isdir(full_path):
392-
if (not namespace and
393-
not os.path.isfile(os.path.join(full_path, '__init__.py'))):
394-
continue
395-
396-
load_tests = None
397-
tests = None
398-
name = self._get_name_from_path(full_path)
393+
yield from self._find_tests(full_path, pattern, namespace)
394+
finally:
395+
self._loading_packages.discard(name)
396+
397+
def _find_test_path(self, full_path, pattern, namespace=False):
398+
"""Used by discovery.
399+
400+
Loads tests from a single file, or a directories' __init__.py when
401+
passed the directory.
402+
403+
Returns a tuple (None_or_tests_from_file, should_recurse).
404+
"""
405+
basename = os.path.basename(full_path)
406+
if os.path.isfile(full_path):
407+
if not VALID_MODULE_NAME.match(basename):
408+
# valid Python identifiers only
409+
return None, False
410+
if not self._match_path(basename, full_path, pattern):
411+
return None, False
412+
# if the test file matches, load it
413+
name = self._get_name_from_path(full_path)
414+
try:
415+
module = self._get_module_from_name(name)
416+
except case.SkipTest as e:
417+
return _make_skipped_test(name, e, self.suiteClass), False
418+
except:
419+
error_case, error_message = \
420+
_make_failed_import_test(name, self.suiteClass)
421+
self.errors.append(error_message)
422+
return error_case, False
423+
else:
424+
mod_file = os.path.abspath(
425+
getattr(module, '__file__', full_path))
426+
realpath = _jython_aware_splitext(
427+
os.path.realpath(mod_file))
428+
fullpath_noext = _jython_aware_splitext(
429+
os.path.realpath(full_path))
430+
if realpath.lower() != fullpath_noext.lower():
431+
module_dir = os.path.dirname(realpath)
432+
mod_name = _jython_aware_splitext(
433+
os.path.basename(full_path))
434+
expected_dir = os.path.dirname(full_path)
435+
msg = ("%r module incorrectly imported from %r. Expected "
436+
"%r. Is this module globally installed?")
437+
raise ImportError(
438+
msg % (mod_name, module_dir, expected_dir))
439+
return self.loadTestsFromModule(module, pattern=pattern), False
440+
elif os.path.isdir(full_path):
441+
if (not namespace and
442+
not os.path.isfile(os.path.join(full_path, '__init__.py'))):
443+
return None, False
444+
445+
load_tests = None
446+
tests = None
447+
name = self._get_name_from_path(full_path)
448+
try:
449+
package = self._get_module_from_name(name)
450+
except case.SkipTest as e:
451+
return _make_skipped_test(name, e, self.suiteClass), False
452+
except:
453+
error_case, error_message = \
454+
_make_failed_import_test(name, self.suiteClass)
455+
self.errors.append(error_message)
456+
return error_case, False
457+
else:
458+
load_tests = getattr(package, 'load_tests', None)
459+
# Mark this package as being in load_tests (possibly ;))
460+
self._loading_packages.add(name)
399461
try:
400-
package = self._get_module_from_name(name)
401-
except case.SkipTest as e:
402-
yield _make_skipped_test(name, e, self.suiteClass)
403-
except:
404-
error_case, error_message = \
405-
_make_failed_import_test(name, self.suiteClass)
406-
self.errors.append(error_message)
407-
yield error_case
408-
else:
409-
load_tests = getattr(package, 'load_tests', None)
410462
tests = self.loadTestsFromModule(package, pattern=pattern)
411-
if tests is not None:
412-
# tests loaded from package file
413-
yield tests
414-
415463
if load_tests is not None:
416-
# loadTestsFromModule(package) has load_tests for us.
417-
continue
418-
# recurse into the package
419-
yield from self._find_tests(full_path, pattern,
420-
namespace=namespace)
464+
# loadTestsFromModule(package) has loaded tests for us.
465+
return tests, False
466+
return tests, True
467+
finally:
468+
self._loading_packages.discard(name)
421469

422470

423471
defaultTestLoader = TestLoader()

Lib/unittest/test/test_discovery.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,51 @@ def _find_tests(start_dir, pattern, namespace=None):
368368
self.assertEqual(_find_tests_args, [(start_dir, 'pattern')])
369369
self.assertIn(top_level_dir, sys.path)
370370

371+
def test_discover_start_dir_is_package_calls_package_load_tests(self):
372+
# This test verifies that the package load_tests in a package is indeed
373+
# invoked when the start_dir is a package (and not the top level).
374+
# http://bugs.python.org/issue22457
375+
376+
# Test data: we expect the following:
377+
# an isfile to verify the package, then importing and scanning
378+
# as per _find_tests' normal behaviour.
379+
# We expect to see our load_tests hook called once.
380+
vfs = {abspath('/toplevel'): ['startdir'],
381+
abspath('/toplevel/startdir'): ['__init__.py']}
382+
def list_dir(path):
383+
return list(vfs[path])
384+
self.addCleanup(setattr, os, 'listdir', os.listdir)
385+
os.listdir = list_dir
386+
self.addCleanup(setattr, os.path, 'isfile', os.path.isfile)
387+
os.path.isfile = lambda path: path.endswith('.py')
388+
self.addCleanup(setattr, os.path, 'isdir', os.path.isdir)
389+
os.path.isdir = lambda path: not path.endswith('.py')
390+
self.addCleanup(sys.path.remove, abspath('/toplevel'))
391+
392+
class Module(object):
393+
paths = []
394+
load_tests_args = []
395+
396+
def __init__(self, path):
397+
self.path = path
398+
399+
def load_tests(self, loader, tests, pattern):
400+
return ['load_tests called ' + self.path]
401+
402+
def __eq__(self, other):
403+
return self.path == other.path
404+
405+
loader = unittest.TestLoader()
406+
loader._get_module_from_name = lambda name: Module(name)
407+
loader.suiteClass = lambda thing: thing
408+
409+
suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel')
410+
411+
# We should have loaded tests from the package __init__.
412+
# (normally this would be nested TestSuites.)
413+
self.assertEqual(suite,
414+
[['load_tests called startdir']])
415+
371416
def setup_import_issue_tests(self, fakefile):
372417
listdir = os.listdir
373418
os.listdir = lambda _: [fakefile]

Lib/unittest/test/test_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def test_loadTestsFromNames__unknown_attr_name(self):
841841
loader = unittest.TestLoader()
842842

843843
suite = loader.loadTestsFromNames(
844-
['unittest.loader.sdasfasfasdf', 'unittest'])
844+
['unittest.loader.sdasfasfasdf', 'unittest.test.dummy'])
845845
error, test = self.check_deferred_error(loader, list(suite)[0])
846846
expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'"
847847
self.assertIn(

Misc/NEWS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ Library
212212

213213
- Issue #22217: Implemented reprs of classes in the zipfile module.
214214

215+
- Issue #22457: Honour load_tests in the start_dir of discovery.
216+
215217
- Issue #18216: gettext now raises an error when a .mo file has an
216218
unsupported major version number. Patch by Aaron Hill.
217219

0 commit comments

Comments
 (0)