Skip to content

Commit

Permalink
Merge b96f9ca into 7b11f04
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmv committed Oct 15, 2019
2 parents 7b11f04 + b96f9ca commit 561f8cd
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 39 deletions.
146 changes: 107 additions & 39 deletions black.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,18 @@ def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
yield from self.line()
yield from self.visit_default(leaf)

def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
# Check if it's a docstring
if prev_siblings_are(
leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
) and is_multiline_string(leaf):
prefix = " " * self.current_line.depth
docstring = fix_docstring(leaf.value[3:-3], prefix)
leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
normalize_string_quotes(leaf)

yield from self.visit_default(leaf)

def __attrs_post_init__(self) -> None:
"""You are in a twisty little maze of passages."""
v = self.visit_stmt
Expand Down Expand Up @@ -2095,6 +2107,22 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
return None


def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
"""Return if the `node` and its previous siblings match types against the provided
list of tokens; the provided `node`has its type matched against the last element in
the list. `None` can be used as the first element to declare that the start of the
list is anchored at the start of its parent's children."""
if not tokens:
return True
if tokens[-1] is None:
return node is None
if not node:
return False
if node.type != tokens[-1]:
return False
return prev_siblings_are(node.prev_sibling, tokens[:-1])


def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
"""Return the child of `ancestor` that contains `descendant`."""
node: Optional[LN] = descendant
Expand Down Expand Up @@ -3634,52 +3662,65 @@ def _fixup_ast_constants(
return node


def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
def _stringify_ast(
node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""

def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
node = _fixup_ast_constants(node)

node = _fixup_ast_constants(node)
yield f"{' ' * depth}{node.__class__.__name__}("

yield f"{' ' * depth}{node.__class__.__name__}("
for field in sorted(node._fields):
# TypeIgnore has only one field 'lineno' which breaks this comparison
type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
if sys.version_info >= (3, 8):
type_ignore_classes += (ast.TypeIgnore,)
if isinstance(node, type_ignore_classes):
break

for field in sorted(node._fields):
# TypeIgnore has only one field 'lineno' which breaks this comparison
type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
if sys.version_info >= (3, 8):
type_ignore_classes += (ast.TypeIgnore,)
if isinstance(node, type_ignore_classes):
break
try:
value = getattr(node, field)
except AttributeError:
continue

try:
value = getattr(node, field)
except AttributeError:
continue
yield f"{' ' * (depth+1)}{field}="

if isinstance(value, list):
for item in value:
# Ignore nested tuples within del statements, because we may insert
# parentheses and they change the AST.
if (
field == "targets"
and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
):
for item in item.elts:
yield from _stringify_ast(item, depth + 2)
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from _stringify_ast(item, depth + 2)

yield f"{' ' * (depth+1)}{field}="

if isinstance(value, list):
for item in value:
# Ignore nested tuples within del statements, because we may insert
# parentheses and they change the AST.
if (
field == "targets"
and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
):
for item in item.elts:
yield from _v(item, depth + 2)
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from _v(item, depth + 2)

elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
yield from _v(value, depth + 2)
elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
yield from _stringify_ast(value, depth + 2)

else:
# Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing
if (
isinstance(node, ast.Constant)
and field == "value"
and isinstance(value, str)
):
normalized = re.sub(r"\n[ \t]+", "\n ", value)
else:
yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"
normalized = value
yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"

yield f"{' ' * depth}) # /{node.__class__.__name__}"
yield f"{' ' * depth}) # /{node.__class__.__name__}"


def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""

try:
src_ast = parse_ast(src)
Expand All @@ -3699,8 +3740,8 @@ def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[st
f"This invalid output might be helpful: {log}"
) from None

src_ast_str = "\n".join(_v(src_ast))
dst_ast_str = "\n".join(_v(dst_ast))
src_ast_str = "\n".join(_stringify_ast(src_ast))
dst_ast_str = "\n".join(_stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
Expand Down Expand Up @@ -4064,5 +4105,32 @@ def patched_main() -> None:
main()


def fix_docstring(docstring: str, prefix: str) -> str:
# https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
if not docstring:
return ""
# Convert tabs to spaces (following the normal Python rules)
# and split into a list of lines:
lines = docstring.expandtabs().splitlines()
# Determine minimum indentation (first line doesn't count):
indent = sys.maxsize
for line in lines[1:]:
stripped = line.lstrip()
if stripped:
indent = min(indent, len(line) - len(stripped))
# Remove indentation (first line is special):
trimmed = [lines[0].strip()]
if indent < sys.maxsize:
last_line_idx = len(lines) - 2
for i, line in enumerate(lines[1:]):
stripped_line = line[indent:].rstrip()
if stripped_line or i == last_line_idx:
trimmed.append(prefix + stripped_line)
else:
trimmed.append("")
# Return a single string:
return "\n".join(trimmed)


if __name__ == "__main__":
patched_main()
93 changes: 93 additions & 0 deletions tests/data/docstring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
class MyClass:
"""Multiline
class docstring
"""

def method(self):
"""Multiline
method docstring
"""
pass


def foo():
"""This is a docstring with
some lines of text here
"""
return


def bar():
'''This is another docstring
with more lines of text
'''
return


def baz():
'''"This" is a string with some
embedded "quotes"'''
return


def troz():
'''Indentation with tabs
is just as OK
'''
return


def zort():
"""Another
multiline
docstring
"""
pass

# output

class MyClass:
"""Multiline
class docstring
"""

def method(self):
"""Multiline
method docstring
"""
pass


def foo():
"""This is a docstring with
some lines of text here
"""
return


def bar():
"""This is another docstring
with more lines of text
"""
return


def baz():
'''"This" is a string with some
embedded "quotes"'''
return


def troz():
"""Indentation with tabs
is just as OK
"""
return


def zort():
"""Another
multiline
docstring
"""
pass
8 changes: 8 additions & 0 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,14 @@ def test_string_quotes(self) -> None:
black.assert_equivalent(source, not_normalized)
black.assert_stable(source, not_normalized, mode=mode)

@patch("black.dump_to_file", dump_to_stderr)
def test_docstring(self) -> None:
source, expected = read_data("docstring")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, black.FileMode())

@patch("black.dump_to_file", dump_to_stderr)
def test_slices(self) -> None:
source, expected = read_data("slices")
Expand Down

0 comments on commit 561f8cd

Please sign in to comment.