Skip to content

Commit

Permalink
Merge 08a99ca into cad4138
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol committed Jul 26, 2019
2 parents cad4138 + 08a99ca commit f23c2cd
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 30 deletions.
96 changes: 71 additions & 25 deletions black.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import asyncio
from concurrent.futures import Executor, ProcessPoolExecutor
from contextlib import contextmanager
Expand Down Expand Up @@ -141,6 +142,7 @@ class Feature(Enum):
# set for every version of python.
ASYNC_IDENTIFIERS = 6
ASYNC_KEYWORDS = 7
ASSIGNMENT_EXPRESSIONS = 8


VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
Expand Down Expand Up @@ -175,6 +177,7 @@ class Feature(Enum):
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
Feature.ASYNC_KEYWORDS,
Feature.ASSIGNMENT_EXPRESSIONS,
},
}

Expand Down Expand Up @@ -2863,6 +2866,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
check_lpar = True

if check_lpar:
if is_walrus_assignment(child):
continue
if child.type == syms.atom:
if maybe_make_parens_invisible_in_atom(child, parent=node):
lpar = Leaf(token.LPAR, "")
Expand Down Expand Up @@ -3017,18 +3022,24 @@ def is_empty_tuple(node: LN) -> bool:
)


def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
"""Returns `wrapped` if `node` is of the shape ( wrapped ).
Parenthesis can be optional. Returns None otherwise"""
if len(node.children) != 3:
return None
lpar, wrapped, rpar = node.children
if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
return None

return wrapped


def is_one_tuple(node: LN) -> bool:
"""Return True if `node` holds a tuple with one element, with or without parens."""
if node.type == syms.atom:
if len(node.children) != 3:
return False

lpar, gexp, rpar = node.children
if not (
lpar.type == token.LPAR
and gexp.type == syms.testlist_gexp
and rpar.type == token.RPAR
):
gexp = unwrap_singleton_parenthesis(node)
if gexp is None or gexp.type != syms.testlist_gexp:
return False

return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
Expand All @@ -3040,6 +3051,12 @@ def is_one_tuple(node: LN) -> bool:
)


def is_walrus_assignment(node: LN) -> bool:
"""Return True iff `node` is of the shape ( test := test )"""
inner = unwrap_singleton_parenthesis(node)
return inner is not None and inner.type == syms.namedexpr_test


def is_yield(node: LN) -> bool:
"""Return True if `node` holds a `yield` or `yield from` expression."""
if node.type == syms.yield_expr:
Expand Down Expand Up @@ -3198,6 +3215,9 @@ def get_features_used(node: Node) -> Set[Feature]:
if "_" in n.value: # type: ignore
features.add(Feature.NUMERIC_UNDERSCORES)

elif n.type == token.COLONEQUAL:
features.add(Feature.ASSIGNMENT_EXPRESSIONS)

elif (
n.type in {syms.typedargslist, syms.arglist}
and n.children
Expand Down Expand Up @@ -3479,32 +3499,58 @@ def __str__(self) -> str:
return ", ".join(report) + "."


def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
for feature_version in (7, 6):
try:
return ast3.parse(src, feature_version=feature_version)
except SyntaxError:
continue
def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
filename = "<unknown>"
if sys.version_info >= (3, 8):
# TODO: support Python 4+ ;)
for minor_version in range(sys.version_info[1], 4, -1):
try:
return ast.parse(src, filename, feature_version=(3, minor_version))
except SyntaxError:
continue
else:
for feature_version in (7, 6):
try:
return ast3.parse(src, filename, feature_version=feature_version)
except SyntaxError:
continue

return ast27.parse(src)


def _fixup_ast_constants(
node: Union[ast.AST, ast3.AST, ast27.AST]
) -> Union[ast.AST, ast3.AST, ast27.AST]:
"""Map ast nodes deprecated in 3.8 to Constant."""
# casts are required until this is released:
# https://github.com/python/typeshed/pull/3142
if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
return cast(ast.AST, ast.Constant(value=node.s))
elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
return cast(ast.AST, ast.Constant(value=node.n))
elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
return cast(ast.AST, ast.Constant(value=node.value))
return node


def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""

def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""

node = _fixup_ast_constants(node)

yield f"{' ' * depth}{node.__class__.__name__}("

for field in sorted(node._fields):
# TypeIgnore has only one field 'lineno' which breaks this comparison
if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
if sys.version_info >= (3, 8):
type_ignore_classes += (ast.TypeIgnore,)
if isinstance(node, type_ignore_classes):
break

# Ignore str kind which is case sensitive / and ignores unicode_literals
if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
continue

try:
value = getattr(node, field)
except AttributeError:
Expand All @@ -3518,15 +3564,15 @@ def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
# parentheses and they change the AST.
if (
field == "targets"
and isinstance(node, (ast3.Delete, ast27.Delete))
and isinstance(item, (ast3.Tuple, ast27.Tuple))
and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
):
for item in item.elts:
yield from _v(item, depth + 2)
elif isinstance(item, (ast3.AST, ast27.AST)):
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from _v(item, depth + 2)

elif isinstance(value, (ast3.AST, ast27.AST)):
elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
yield from _v(value, depth + 2)

else:
Expand Down
8 changes: 5 additions & 3 deletions blib2to3/Grammar.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ assert_stmt: 'assert' test [',' test]

compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
while_stmt: 'while' test ':' suite ['else' ':' suite]
for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]
try_stmt: ('try' ':' suite
Expand All @@ -91,6 +91,7 @@ testlist_safe: old_test [(',' old_test)+ [',']]
old_test: or_test | old_lambdef
old_lambdef: 'lambda' [varargslist] ':' old_test

namedexpr_test: test [':=' test]
test: or_test ['if' or_test 'else' test] | lambdef
or_test: and_test ('or' and_test)*
and_test: not_test ('and' not_test)*
Expand All @@ -111,8 +112,8 @@ atom: ('(' [yield_expr|testlist_gexp] ')' |
'{' [dictsetmaker] '}' |
'`' testlist1 '`' |
NAME | NUMBER | STRING+ | '.' '.' '.')
listmaker: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
testlist_gexp: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
listmaker: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
testlist_gexp: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
lambdef: 'lambda' [varargslist] ':' test
trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
subscriptlist: subscript (',' subscript)* [',']
Expand All @@ -137,6 +138,7 @@ arglist: argument (',' argument)* [',']
# multiple (test comp_for) arguments are blocked; keyword unpackings
# that precede iterable unpackings are blocked; etc.
argument: ( test [comp_for] |
test ':=' test |
test '=' test |
'**' test |
'*' test )
Expand Down
1 change: 1 addition & 0 deletions blib2to3/pgen2/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def report(self):
// DOUBLESLASH
//= DOUBLESLASHEQUAL
-> RARROW
:= COLONEQUAL
"""

opmap = {}
Expand Down
3 changes: 2 additions & 1 deletion blib2to3/pgen2/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
AWAIT = 56
ASYNC = 57
ERRORTOKEN = 58
N_TOKENS = 59
COLONEQUAL = 59
N_TOKENS = 60
NT_OFFSET = 256
#--end constants--

Expand Down
2 changes: 1 addition & 1 deletion blib2to3/pgen2/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _combinations(*l):
# recognized as two instances of =).
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
r"//=?", r"->",
r"[+\-*/%&@|^=<>]=?",
r"[+\-*/%&@|^=<>:]=?",
r"~")

Bracket = '[][(){}]'
Expand Down
1 change: 1 addition & 0 deletions blib2to3/pygram.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class python_symbols(Symbols):
import_stmt: int
lambdef: int
listmaker: int
namedexpr_test: int
not_test: int
old_comp_for: int
old_comp_if: int
Expand Down
41 changes: 41 additions & 0 deletions tests/data/pep_572.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
(a := 1)
(a := a)
if (match := pattern.search(data)) is None:
pass
[y := f(x), y ** 2, y ** 3]
filtered_data = [y for x in data if (y := f(x)) is None]
(y := f(x))
y0 = (y1 := f(x))
foo(x=(y := f(x)))


def foo(answer=(p := 42)):
pass


def foo(answer: (p := 42) = 5):
pass


lambda: (x := 1)
(x := lambda: 1)
(x := lambda: (y := 1))
lambda line: (m := re.match(pattern, line)) and m.group(1)
x = (y := 0)
(z := (y := (x := 0)))
(info := (name, phone, *rest))
(x := 1, 2)
(total := total + tax)
len(lines := f.readlines())
foo(x := 3, cat="vector")
foo(cat=(category := "vector"))
if any(len(longline := l) >= 100 for l in lines):
print(longline)
if env_base := os.environ.get("PYTHONUSERBASE", None):
return env_base
if self._is_special and (ans := self._check_nans(context=context)):
return ans
foo(b := 2, a=1)
foo(b := 2, a=1)
foo((b := 2), a=1)
foo(c=(b := 2), a=1)
17 changes: 17 additions & 0 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,23 @@ def test_expression(self) -> None:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, black.FileMode())

@patch("black.dump_to_file", dump_to_stderr)
def test_pep_572(self) -> None:
source, expected = read_data("pep_572")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, black.FileMode())
if sys.version_info >= (3, 8):
black.assert_equivalent(source, actual)

def test_pep_572_version_detection(self) -> None:
source, _ = read_data("pep_572")
root = black.lib2to3_parse(source)
features = black.get_features_used(root)
self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
versions = black.detect_target_versions(root)
self.assertIn(black.TargetVersion.PY38, versions)

def test_expression_ff(self) -> None:
source, expected = read_data("expression")
tmp_file = Path(black.dump_to_file(source))
Expand Down

0 comments on commit f23c2cd

Please sign in to comment.