Skip to content

Commit

Permalink
Merge pull request #944 from viourr/develop
Browse files Browse the repository at this point in the history
 keep-direct-and-as only inserts direct import if it was present in the original file
  • Loading branch information
timothycrosley committed May 9, 2019
2 parents 5fa7d94 + c52d83a commit 1bd5c99
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 25 deletions.
74 changes: 49 additions & 25 deletions isort/isort.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import copy
import itertools
import re
from collections import OrderedDict, namedtuple
from collections import OrderedDict, defaultdict, namedtuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

from isort import utils
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(self, file_contents: str, config: Dict[str, Any]) -> None:
self.out_lines = [] # type: List[str]
self.comments = {'from': {}, 'straight': {}, 'nested': {}, 'above': {'straight': {}, 'from': {}}} # type: CommentsDict
self.imports = OrderedDict() # type: OrderedDict[str, Dict[str, Any]]
self.as_map = {} # type: Dict[str, str]
self.as_map = defaultdict(list) # type: Dict[str, List[str]]

section_names = self.config['sections']
self.sections = namedtuple('Sections', section_names)(*[name for name in section_names]) # type: Any
Expand Down Expand Up @@ -264,18 +264,20 @@ def _add_straight_imports(self, straight_modules: Iterable[str], section: str, s
if module in self.remove_imports:
continue

import_definition = []
if module in self.as_map:
import_definition = ''
if self.config['keep_direct_and_as_imports']:
import_definition = "import {0}\n".format(module)
import_definition += "import {0} as {1}".format(module, self.as_map[module])
if self.config['keep_direct_and_as_imports'] and self.imports[section]['straight'][module]:
import_definition.append("import {0}".format(module))
import_definition.extend("import {0} as {1}".format(module, as_import)
for as_import in self.as_map[module])
else:
import_definition = "import {0}".format(module)
import_definition.append("import {0}".format(module))

comments_above = self.comments['above']['straight'].pop(module, None)
if comments_above:
section_output.extend(comments_above)
section_output.append(self._add_comments(self.comments['straight'].get(module), import_definition))
section_output.extend(self._add_comments(self.comments['straight'].get(module), idef)
for idef in import_definition)

def _add_from_imports(self, from_modules: Iterable[str], section: str, section_output: List[str], ignore_case: bool) -> None:
for module in from_modules:
Expand All @@ -292,14 +294,22 @@ def _add_from_imports(self, from_modules: Iterable[str], section: str, section_o

sub_modules = ['{0}.{1}'.format(module, from_import) for from_import in from_imports]
as_imports = {
from_import: "{0} as {1}".format(from_import, self.as_map[sub_module])
from_import: ["{0} as {1}".format(from_import, as_module)
for as_module in self.as_map[sub_module]]
for from_import, sub_module in zip(from_imports, sub_modules)
if sub_module in self.as_map
}
if self.config['combine_as_imports'] and not ("*" in from_imports and self.config['combine_star']):
if not self.config['no_inline_sort']:
for as_import in as_imports:
as_imports[as_import] = nsorted(as_imports[as_import])
for from_import in copy.copy(from_imports):
if from_import in as_imports:
from_imports[from_imports.index(from_import)] = as_imports.pop(from_import)
idx = from_imports.index(from_import)
if self.config['keep_direct_and_as_imports'] and self.imports[section]['from'][module][from_import]:
from_imports[(idx+1):(idx+1)] = as_imports.pop(from_import)
else:
from_imports[idx:(idx+1)] = as_imports.pop(from_import)

while from_imports:
comments = self.comments['from'].pop(module, ())
Expand All @@ -310,29 +320,35 @@ def _add_from_imports(self, from_modules: Iterable[str], section: str, section_o
import_statements = []
while from_imports:
from_import = from_imports.pop(0)
if from_import in as_imports:
from_comments = self.comments['straight'].get('{}.{}'.format(module, from_import))
import_statements.append(self._add_comments(from_comments,
self._wrap(import_start + as_imports[from_import])))
continue
single_import_line = self._add_comments(comments, import_start + from_import)
comment = self.comments['nested'].get(module, {}).pop(from_import, None)
if comment:
single_import_line += "{0} {1}".format(comments and ";" or self.config['comment_prefix'],
comment)
import_statements.append(self._wrap(single_import_line))
if from_import in as_imports:
if self.config['keep_direct_and_as_imports'] and self.imports[section]['from'][module][from_import]:
import_statements.append(self._wrap(single_import_line))
from_comments = self.comments['straight'].get('{}.{}'.format(module, from_import))
import_statements.extend(self._add_comments(from_comments,
self._wrap(import_start + as_import))
for as_import in nsorted(as_imports[from_import]))
else:
import_statements.append(self._wrap(single_import_line))
comments = None
import_statement = self.line_separator.join(import_statements)
else:
while from_imports and from_imports[0] in as_imports:
from_import = from_imports.pop(0)
as_imports[from_import] = nsorted(as_imports[from_import])
from_comments = self.comments['straight'].get('{}.{}'.format(module, from_import))
above_comments = self.comments['above']['from'].pop(module, None)
if above_comments:
section_output.extend(above_comments)

section_output.append(self._add_comments(from_comments,
self._wrap(import_start + as_imports[from_import])))
if self.config['keep_direct_and_as_imports'] and self.imports[section]['from'][module][from_import]:
section_output.append(self._add_comments(from_comments, self._wrap(import_start + from_import)))
section_output.extend(self._add_comments(from_comments, self._wrap(import_start + as_import))
for as_import in as_imports[from_import])

star_import = False
if "*" in from_imports:
Expand All @@ -342,7 +358,7 @@ def _add_from_imports(self, from_modules: Iterable[str], section: str, section_o
comments = None

for from_import in copy.copy(from_imports):
if from_import in as_imports:
if from_import in as_imports and not self.config['keep_direct_and_as_imports']:
continue
comment = self.comments['nested'].get(module, {}).pop(from_import, None)
if comment:
Expand All @@ -357,7 +373,10 @@ def _add_from_imports(self, from_modules: Iterable[str], section: str, section_o
comments = None

from_import_section = []
while from_imports and from_imports[0] not in as_imports:
while from_imports and (from_imports[0] not in as_imports or
(self.config['keep_direct_and_as_imports'] and
self.config['combine_as_imports'] and
self.imports[section]['from'][module][from_import])):
from_import_section.append(from_imports.pop(0))
if star_import:
import_statement = import_start + (", ").join(from_import_section)
Expand Down Expand Up @@ -876,15 +895,17 @@ def _parse(self) -> None:

imports = [item.replace("{|", "{ ").replace("|}", " }") for item in
self._strip_syntax(import_string).split()]
straight_import = True
if "as" in imports and (imports.index('as') + 1) < len(imports):
straight_import = False
while "as" in imports:
index = imports.index('as')
if import_type == "from":
module = imports[0] + "." + imports[index - 1]
self.as_map[module] = imports[index + 1]
self.as_map[module].append(imports[index + 1])
else:
module = imports[index - 1]
self.as_map[module] = imports[index + 1]
self.as_map[module].append(imports[index + 1])
if not self.config['combine_as_imports']:
self.comments['straight'][module] = comments
comments = []
Expand Down Expand Up @@ -921,8 +942,10 @@ def _parse(self) -> None:
self.import_index -= len(self.comments['above']['from'].get(import_from, []))

if import_from not in root:
root[import_from] = OrderedDict()
root[import_from].update((module, None) for module in imports)
root[import_from] = OrderedDict((module, straight_import) for module in imports)
else:
root[import_from].update((module, straight_import | root[import_from].get(module, False))
for module in imports)
else:
for module in imports:
if comments:
Expand Down Expand Up @@ -950,4 +973,5 @@ def _parse(self) -> None:
"WARNING: could not place module {0} of line {1} --"
" Do you need to define a default section?".format(import_from, line)
)
self.imports[placed_module][import_type][module] = None
straight_import |= self.imports[placed_module][import_type].get(module, False)
self.imports[placed_module][import_type][module] = straight_import
Loading

0 comments on commit 1bd5c99

Please sign in to comment.