diff --git a/rope/refactor/importutils/module_imports.py b/rope/refactor/importutils/module_imports.py index a205c1075..2436879a4 100644 --- a/rope/refactor/importutils/module_imports.py +++ b/rope/refactor/importutils/module_imports.py @@ -1,4 +1,5 @@ from rope.base import ast +from rope.base import exceptions from rope.base import pynames from rope.base import utils from rope.refactor.importutils import actions @@ -31,8 +32,36 @@ def _get_unbound_names(self, defined_pyobject): ast.walk(self.pymodule.get_ast(), visitor) return visitor.unbound + def _get_all_star_list(self, pymodule): + result = set() + try: + all_star_list = pymodule.get_attribute('__all__') + except exceptions.AttributeNotFoundError: + return result + + # FIXME: Need a better way to recursively infer possible values. + # Currently pyobjects can recursively infer type, but not values. + # Do a very basic 1-level value inference + for assignment in all_star_list.assignments: + if isinstance(assignment.ast_node, ast.List): + stack = list(assignment.ast_node.elts) + while stack: + el = stack.pop() + if isinstance(el, ast.Str): + result.add(el.s) + elif isinstance(el, ast.Name): + name = pymodule.get_attribute(el.id) + if isinstance(name, pynames.AssignedName): + for av in name.assignments: + if isinstance(av.ast_node, ast.Str): + result.add(av.ast_node.s) + elif isinstance(el, ast.IfExp): + stack.append(el.body) + stack.append(el.orelse) + return result + def remove_unused_imports(self): - can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule)) + can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule) | self._get_all_star_list(self.pymodule)) visitor = actions.RemovingVisitor( self.project, self._current_folder(), can_select) for import_statement in self.imports: diff --git a/ropetest/refactor/importutilstest.py b/ropetest/refactor/importutilstest.py index b51db2894..bb1ff1f86 100644 --- a/ropetest/refactor/importutilstest.py +++ b/ropetest/refactor/importutilstest.py @@ -3,6 +3,8 @@ except ImportError: import unittest +from textwrap import dedent + from rope.refactor.importutils import ImportTools, importinfo, add_import from ropetest import testutils @@ -937,6 +939,97 @@ def test_sorting_future_imports(self): 'from __future__ import devision\n\nimport os\n', self.import_tools.sort_imports(pymod)) + def test_organizing_imports_all_star(self): + code = expected = dedent('''\ + from package import some_name + + + __all__ = ["some_name"] + ''') + self.mod.write(code) + pymod = self.project.get_pymodule(self.mod) + self.assertEqual( + expected, + self.import_tools.organize_imports(pymod)) + + def test_organizing_imports_all_star_with_variables(self): + code = expected = dedent('''\ + from package import name_one, name_two + + + if something(): + foo = 'name_one' + else: + foo = 'name_two' + __all__ = [foo] + ''') + self.mod.write(code) + pymod = self.project.get_pymodule(self.mod) + self.assertEqual( + expected, + self.import_tools.organize_imports(pymod)) + + def test_organizing_imports_all_star_with_inline_if(self): + code = expected = dedent('''\ + from package import name_one, name_two + + + __all__ = ['name_one' if something() else 'name_two'] + ''') + self.mod.write(code) + pymod = self.project.get_pymodule(self.mod) + self.assertEqual( + expected, + self.import_tools.organize_imports(pymod)) + + @testutils.only_for_versions_higher('3') + def test_organizing_imports_all_star_tolerates_non_list_of_str_1(self): + code = expected = dedent('''\ + from package import name_one, name_two + + + foo = 'name_two' + __all__ = [bar, *abc] + mylist + __all__ = [foo, 'name_one', *abc] + __all__ = [it for it in mylist] + ''') + self.mod.write(code) + pymod = self.project.get_pymodule(self.mod) + self.assertEqual( + expected, + self.import_tools.organize_imports(pymod)) + + def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self): + code = expected = dedent('''\ + from package import name_one, name_two + + + foo = 'name_two' + __all__ = [foo, 3, 'name_one'] + __all__ = [it for it in mylist] + ''') + self.mod.write(code) + pymod = self.project.get_pymodule(self.mod) + self.assertEqual( + expected, + self.import_tools.organize_imports(pymod)) + + @testutils.time_limit(60) + def test_organizing_imports_all_star_no_infinite_loop(self): + code = expected = dedent('''\ + from package import name_one, name_two + + + foo = bar + bar = foo + __all__ = [foo, 'name_one', 'name_two'] + ''') + self.mod.write(code) + pymod = self.project.get_pymodule(self.mod) + self.assertEqual( + expected, + self.import_tools.organize_imports(pymod)) + def test_customized_import_organization(self): self.mod.write('import sys\nimport sys\n') pymod = self.project.get_pymodule(self.mod) diff --git a/ropetest/testutils.py b/ropetest/testutils.py index 23cf64e08..0b4e98ebe 100644 --- a/ropetest/testutils.py +++ b/ropetest/testutils.py @@ -92,3 +92,15 @@ def only_for_versions_higher(version): def skipNotPOSIX(): return unittest.skipIf(os.name != 'posix', 'This test works only on POSIX') + + +def time_limit(timeout): + if not any(procname in sys.argv[0] for procname in {'pytest', 'py.test'}): + # no-op when running tests without pytest + return lambda *args, **kwargs: lambda func: func + + # do a local import so we don't import pytest when running without pytest + import pytest + + # this prevents infinite loop/recursion from taking forever in CI + return pytest.mark.time_limit(timeout) diff --git a/setup.py b/setup.py index 6779d3ab9..c59558868 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ def get_version(): classifiers=classifiers, extras_require={ 'dev': [ - 'pytest' + 'pytest', + 'pytest-timeout', ] })