Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion rope/refactor/importutils/module_imports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rope.base import ast
from rope.base import exceptions
from rope.base import pynames
from rope.base import utils
from rope.refactor.importutils import actions
Expand Down Expand Up @@ -31,8 +32,36 @@ def _get_unbound_names(self, defined_pyobject):
ast.walk(self.pymodule.get_ast(), visitor)
return visitor.unbound

def _get_all_star_list(self, pymodule):
result = set()
try:
all_star_list = pymodule.get_attribute('__all__')
except exceptions.AttributeNotFoundError:
return result

# FIXME: Need a better way to recursively infer possible values.
# Currently pyobjects can recursively infer type, but not values.
# Do a very basic 1-level value inference
for assignment in all_star_list.assignments:
if isinstance(assignment.ast_node, ast.List):
stack = list(assignment.ast_node.elts)
while stack:
el = stack.pop()
if isinstance(el, ast.Str):
result.add(el.s)
elif isinstance(el, ast.Name):
name = pymodule.get_attribute(el.id)
if isinstance(name, pynames.AssignedName):
for av in name.assignments:
if isinstance(av.ast_node, ast.Str):
result.add(av.ast_node.s)
elif isinstance(el, ast.IfExp):
stack.append(el.body)
stack.append(el.orelse)
return result

def remove_unused_imports(self):
can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule))
can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule) | self._get_all_star_list(self.pymodule))
visitor = actions.RemovingVisitor(
self.project, self._current_folder(), can_select)
for import_statement in self.imports:
Expand Down
93 changes: 93 additions & 0 deletions ropetest/refactor/importutilstest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
except ImportError:
import unittest

from textwrap import dedent

from rope.refactor.importutils import ImportTools, importinfo, add_import
from ropetest import testutils

Expand Down Expand Up @@ -937,6 +939,97 @@ def test_sorting_future_imports(self):
'from __future__ import devision\n\nimport os\n',
self.import_tools.sort_imports(pymod))

def test_organizing_imports_all_star(self):
code = expected = dedent('''\
from package import some_name


__all__ = ["some_name"]
''')
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(
expected,
self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_with_variables(self):
code = expected = dedent('''\
from package import name_one, name_two


if something():
foo = 'name_one'
else:
foo = 'name_two'
__all__ = [foo]
''')
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(
expected,
self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_with_inline_if(self):
code = expected = dedent('''\
from package import name_one, name_two


__all__ = ['name_one' if something() else 'name_two']
''')
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(
expected,
self.import_tools.organize_imports(pymod))

@testutils.only_for_versions_higher('3')
def test_organizing_imports_all_star_tolerates_non_list_of_str_1(self):
code = expected = dedent('''\
from package import name_one, name_two


foo = 'name_two'
__all__ = [bar, *abc] + mylist
__all__ = [foo, 'name_one', *abc]
__all__ = [it for it in mylist]
''')
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(
expected,
self.import_tools.organize_imports(pymod))

def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self):
code = expected = dedent('''\
from package import name_one, name_two


foo = 'name_two'
__all__ = [foo, 3, 'name_one']
__all__ = [it for it in mylist]
''')
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(
expected,
self.import_tools.organize_imports(pymod))

@testutils.time_limit(60)
def test_organizing_imports_all_star_no_infinite_loop(self):
code = expected = dedent('''\
from package import name_one, name_two


foo = bar
bar = foo
__all__ = [foo, 'name_one', 'name_two']
''')
self.mod.write(code)
pymod = self.project.get_pymodule(self.mod)
self.assertEqual(
expected,
self.import_tools.organize_imports(pymod))

def test_customized_import_organization(self):
self.mod.write('import sys\nimport sys\n')
pymod = self.project.get_pymodule(self.mod)
Expand Down
12 changes: 12 additions & 0 deletions ropetest/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,15 @@ def only_for_versions_higher(version):
def skipNotPOSIX():
return unittest.skipIf(os.name != 'posix',
'This test works only on POSIX')


def time_limit(timeout):
if not any(procname in sys.argv[0] for procname in {'pytest', 'py.test'}):
# no-op when running tests without pytest
return lambda *args, **kwargs: lambda func: func

# do a local import so we don't import pytest when running without pytest
import pytest

# this prevents infinite loop/recursion from taking forever in CI
return pytest.mark.time_limit(timeout)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def get_version():
classifiers=classifiers,
extras_require={
'dev': [
'pytest'
'pytest',
'pytest-timeout',
]
})