@@ -65,6 +65,9 @@ class TestLoader(object):
6565 def __init__ (self ):
6666 super (TestLoader , self ).__init__ ()
6767 self .errors = []
68+ # Tracks packages which we have called into via load_tests, to
69+ # avoid infinite re-entrancy.
70+ self ._loading_packages = set ()
6871
6972 def loadTestsFromTestCase (self , testCaseClass ):
7073 """Return a suite of all tests cases contained in testCaseClass"""
@@ -229,9 +232,13 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
229232
230233 If a test package name (directory with '__init__.py') matches the
231234 pattern then the package will be checked for a 'load_tests' function. If
232- this exists then it will be called with loader, tests, pattern.
235+ this exists then it will be called with (loader, tests, pattern) unless
236+ the package has already had load_tests called from the same discovery
237+ invocation, in which case the package module object is not scanned for
238+ tests - this ensures that when a package uses discover to further
239+ discover child tests that infinite recursion does not happen.
233240
234- If load_tests exists then discovery does *not* recurse into the package,
241+ If load_tests exists then discovery does *not* recurse into the package,
235242 load_tests is responsible for loading all tests in the package.
236243
237244 The pattern is deliberately not stored as a loader attribute so that
@@ -355,69 +362,110 @@ def _match_path(self, path, full_path, pattern):
355362
356363 def _find_tests (self , start_dir , pattern , namespace = False ):
357364 """Used by discovery. Yields test suites it loads."""
365+ # Handle the __init__ in this package
366+ name = self ._get_name_from_path (start_dir )
367+ # name is '.' when start_dir == top_level_dir (and top_level_dir is by
368+ # definition not a package).
369+ if name != '.' and name not in self ._loading_packages :
370+ # name is in self._loading_packages while we have called into
371+ # loadTestsFromModule with name.
372+ tests , should_recurse = self ._find_test_path (
373+ start_dir , pattern , namespace )
374+ if tests is not None :
375+ yield tests
376+ if not should_recurse :
377+ # Either an error occured, or load_tests was used by the
378+ # package.
379+ return
380+ # Handle the contents.
358381 paths = sorted (os .listdir (start_dir ))
359-
360382 for path in paths :
361383 full_path = os .path .join (start_dir , path )
362- if os .path .isfile (full_path ):
363- if not VALID_MODULE_NAME .match (path ):
364- # valid Python identifiers only
365- continue
366- if not self ._match_path (path , full_path , pattern ):
367- continue
368- # if the test file matches, load it
384+ tests , should_recurse = self ._find_test_path (
385+ full_path , pattern , namespace )
386+ if tests is not None :
387+ yield tests
388+ if should_recurse :
389+ # we found a package that didn't use load_tests.
369390 name = self ._get_name_from_path (full_path )
391+ self ._loading_packages .add (name )
370392 try :
371- module = self ._get_module_from_name (name )
372- except case .SkipTest as e :
373- yield _make_skipped_test (name , e , self .suiteClass )
374- except :
375- error_case , error_message = \
376- _make_failed_import_test (name , self .suiteClass )
377- self .errors .append (error_message )
378- yield error_case
379- else :
380- mod_file = os .path .abspath (getattr (module , '__file__' , full_path ))
381- realpath = _jython_aware_splitext (os .path .realpath (mod_file ))
382- fullpath_noext = _jython_aware_splitext (os .path .realpath (full_path ))
383- if realpath .lower () != fullpath_noext .lower ():
384- module_dir = os .path .dirname (realpath )
385- mod_name = _jython_aware_splitext (os .path .basename (full_path ))
386- expected_dir = os .path .dirname (full_path )
387- msg = ("%r module incorrectly imported from %r. Expected %r. "
388- "Is this module globally installed?" )
389- raise ImportError (msg % (mod_name , module_dir , expected_dir ))
390- yield self .loadTestsFromModule (module , pattern = pattern )
391- elif os .path .isdir (full_path ):
392- if (not namespace and
393- not os .path .isfile (os .path .join (full_path , '__init__.py' ))):
394- continue
395-
396- load_tests = None
397- tests = None
398- name = self ._get_name_from_path (full_path )
393+ yield from self ._find_tests (full_path , pattern , namespace )
394+ finally :
395+ self ._loading_packages .discard (name )
396+
397+ def _find_test_path (self , full_path , pattern , namespace = False ):
398+ """Used by discovery.
399+
400+ Loads tests from a single file, or a directories' __init__.py when
401+ passed the directory.
402+
403+ Returns a tuple (None_or_tests_from_file, should_recurse).
404+ """
405+ basename = os .path .basename (full_path )
406+ if os .path .isfile (full_path ):
407+ if not VALID_MODULE_NAME .match (basename ):
408+ # valid Python identifiers only
409+ return None , False
410+ if not self ._match_path (basename , full_path , pattern ):
411+ return None , False
412+ # if the test file matches, load it
413+ name = self ._get_name_from_path (full_path )
414+ try :
415+ module = self ._get_module_from_name (name )
416+ except case .SkipTest as e :
417+ return _make_skipped_test (name , e , self .suiteClass ), False
418+ except :
419+ error_case , error_message = \
420+ _make_failed_import_test (name , self .suiteClass )
421+ self .errors .append (error_message )
422+ return error_case , False
423+ else :
424+ mod_file = os .path .abspath (
425+ getattr (module , '__file__' , full_path ))
426+ realpath = _jython_aware_splitext (
427+ os .path .realpath (mod_file ))
428+ fullpath_noext = _jython_aware_splitext (
429+ os .path .realpath (full_path ))
430+ if realpath .lower () != fullpath_noext .lower ():
431+ module_dir = os .path .dirname (realpath )
432+ mod_name = _jython_aware_splitext (
433+ os .path .basename (full_path ))
434+ expected_dir = os .path .dirname (full_path )
435+ msg = ("%r module incorrectly imported from %r. Expected "
436+ "%r. Is this module globally installed?" )
437+ raise ImportError (
438+ msg % (mod_name , module_dir , expected_dir ))
439+ return self .loadTestsFromModule (module , pattern = pattern ), False
440+ elif os .path .isdir (full_path ):
441+ if (not namespace and
442+ not os .path .isfile (os .path .join (full_path , '__init__.py' ))):
443+ return None , False
444+
445+ load_tests = None
446+ tests = None
447+ name = self ._get_name_from_path (full_path )
448+ try :
449+ package = self ._get_module_from_name (name )
450+ except case .SkipTest as e :
451+ return _make_skipped_test (name , e , self .suiteClass ), False
452+ except :
453+ error_case , error_message = \
454+ _make_failed_import_test (name , self .suiteClass )
455+ self .errors .append (error_message )
456+ return error_case , False
457+ else :
458+ load_tests = getattr (package , 'load_tests' , None )
459+ # Mark this package as being in load_tests (possibly ;))
460+ self ._loading_packages .add (name )
399461 try :
400- package = self ._get_module_from_name (name )
401- except case .SkipTest as e :
402- yield _make_skipped_test (name , e , self .suiteClass )
403- except :
404- error_case , error_message = \
405- _make_failed_import_test (name , self .suiteClass )
406- self .errors .append (error_message )
407- yield error_case
408- else :
409- load_tests = getattr (package , 'load_tests' , None )
410462 tests = self .loadTestsFromModule (package , pattern = pattern )
411- if tests is not None :
412- # tests loaded from package file
413- yield tests
414-
415463 if load_tests is not None :
416- # loadTestsFromModule(package) has load_tests for us.
417- continue
418- # recurse into the package
419- yield from self . _find_tests ( full_path , pattern ,
420- namespace = namespace )
464+ # loadTestsFromModule(package) has loaded tests for us.
465+ return tests , False
466+ return tests , True
467+ finally :
468+ self . _loading_packages . discard ( name )
421469
422470
423471defaultTestLoader = TestLoader ()
0 commit comments