Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Lib/lib2to3/fixes/fix_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# Local imports
from .. import fixer_base
from os.path import dirname, join, exists, sep
from os.path import dirname, join, exists, isdir, sep
from ..fixer_util import FromImport, syms, token


Expand Down Expand Up @@ -93,7 +93,12 @@ def probably_a_local_import(self, imp_name):
# so can't be a relative import.
if not exists(join(dirname(base_path), "__init__.py")):
return False
for ext in [".py", sep, ".pyc", ".so", ".sl", ".pyd"]:
for ext in [".py", ".pyc", ".so", ".sl", ".pyd"]:
if exists(base_path + ext):
return True

# If the path is a directory, check that is is a valid import
if isdir(base_path) and exists(join(base_path, "__init__.py")):
return True

return False
20 changes: 18 additions & 2 deletions Lib/lib2to3/tests/test_fixers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3803,17 +3803,24 @@ def setUp(self):
# so we can check that it's doing the right thing
self.files_checked = []
self.present_files = set()
self.present_directories = set()
self.always_exists = True
def fake_exists(name):
self.files_checked.append(name)
return self.always_exists or (name in self.present_files)

def fake_isdir(name):
self.files_checked.append(name)
return self.always_exists or (name in self.present_directories)

from lib2to3.fixes import fix_import
fix_import.exists = fake_exists
fix_import.isdir = fake_isdir

def tearDown(self):
from lib2to3.fixes import fix_import
fix_import.exists = os.path.exists
fix_import.isdir = os.path.isdir

def check_both(self, b, a):
self.always_exists = True
Expand All @@ -3828,7 +3835,7 @@ def p(path):

self.always_exists = False
self.present_files = set(['__init__.py'])
expected_extensions = ('.py', os.path.sep, '.pyc', '.so', '.sl', '.pyd')
expected_extensions = ('.py', '', '.pyc', '.so', '.sl', '.pyd')
names_to_test = (p("/spam/eggs.py"), "ni.py", p("../../shrubbery.py"))

for name in names_to_test:
Expand Down Expand Up @@ -3868,9 +3875,18 @@ def test_import_from_package(self):
b = "import bar"
a = "from . import bar"
self.always_exists = False
self.present_files = set(["__init__.py", "bar" + os.path.sep])
self.present_files = set(["__init__.py", "bar" + os.path.sep + "__init__.py"])
self.present_directories = set(["bar"])
self.check(b, a)

def test_import_from_non_package(self):
# import with a directory that is not a python package (no change)
a = "import json"
self.always_exists = False
self.present_files = set(["__init__.py"])
self.present_directories = set(["json"])
self.unchanged(a)

def test_already_relative_import(self):
s = "from . import bar"
self.unchanged(s)
Expand Down