diff --git a/Lib/lib2to3/fixes/fix_import.py b/Lib/lib2to3/fixes/fix_import.py index 734ca294699c364..7573e46513997af 100644 --- a/Lib/lib2to3/fixes/fix_import.py +++ b/Lib/lib2to3/fixes/fix_import.py @@ -12,7 +12,7 @@ # Local imports from .. import fixer_base -from os.path import dirname, join, exists, sep +from os.path import dirname, join, exists, isdir, sep from ..fixer_util import FromImport, syms, token @@ -93,7 +93,12 @@ def probably_a_local_import(self, imp_name): # so can't be a relative import. if not exists(join(dirname(base_path), "__init__.py")): return False - for ext in [".py", sep, ".pyc", ".so", ".sl", ".pyd"]: + for ext in [".py", ".pyc", ".so", ".sl", ".pyd"]: if exists(base_path + ext): return True + + # If the path is a directory, check that is is a valid import + if isdir(base_path) and exists(join(base_path, "__init__.py")): + return True + return False diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py index 3da5dd845c93c66..237259eb219b6ee 100644 --- a/Lib/lib2to3/tests/test_fixers.py +++ b/Lib/lib2to3/tests/test_fixers.py @@ -3803,17 +3803,24 @@ def setUp(self): # so we can check that it's doing the right thing self.files_checked = [] self.present_files = set() + self.present_directories = set() self.always_exists = True def fake_exists(name): self.files_checked.append(name) return self.always_exists or (name in self.present_files) + def fake_isdir(name): + self.files_checked.append(name) + return self.always_exists or (name in self.present_directories) + from lib2to3.fixes import fix_import fix_import.exists = fake_exists + fix_import.isdir = fake_isdir def tearDown(self): from lib2to3.fixes import fix_import fix_import.exists = os.path.exists + fix_import.isdir = os.path.isdir def check_both(self, b, a): self.always_exists = True @@ -3828,7 +3835,7 @@ def p(path): self.always_exists = False self.present_files = set(['__init__.py']) - expected_extensions = ('.py', os.path.sep, '.pyc', '.so', '.sl', '.pyd') + expected_extensions = ('.py', '', '.pyc', '.so', '.sl', '.pyd') names_to_test = (p("/spam/eggs.py"), "ni.py", p("../../shrubbery.py")) for name in names_to_test: @@ -3868,9 +3875,18 @@ def test_import_from_package(self): b = "import bar" a = "from . import bar" self.always_exists = False - self.present_files = set(["__init__.py", "bar" + os.path.sep]) + self.present_files = set(["__init__.py", "bar" + os.path.sep + "__init__.py"]) + self.present_directories = set(["bar"]) self.check(b, a) + def test_import_from_non_package(self): + # import with a directory that is not a python package (no change) + a = "import json" + self.always_exists = False + self.present_files = set(["__init__.py"]) + self.present_directories = set(["json"]) + self.unchanged(a) + def test_already_relative_import(self): s = "from . import bar" self.unchanged(s)