diff --git a/src/black/linegen.py b/src/black/linegen.py index cc8e41dfb2..6162c8753d 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -31,12 +31,12 @@ BRACKETS, CLOSING_BRACKETS, OPENING_BRACKETS, - RARROW, STANDALONE_COMMENT, STATEMENT, WHITESPACE, Visitor, ensure_visible, + get_annotation_type, is_arith_like, is_async_stmt_or_funcdef, is_atom_with_invisible_parens, @@ -1046,11 +1046,12 @@ def bracket_split_build_line( result.inside_brackets = True result.depth += 1 if leaves: - # Ensure a trailing comma for imports and standalone function arguments, but - # be careful not to add one after any comments or within type annotations. no_commas = ( + # Ensure a trailing comma for imports and standalone function arguments original.is_def + # Don't add one after any comments or within type annotations and opening_bracket.value == "(" + # Don't add one if there's already one there and not any( leaf.type == token.COMMA and ( @@ -1059,22 +1060,9 @@ def bracket_split_build_line( ) for leaf in leaves ) - # In particular, don't add one within a parenthesized return annotation. - # Unfortunately the indicator we're in a return annotation (RARROW) may - # be defined directly in the parent node, the parent of the parent ... - # and so on depending on how complex the return annotation is. - # This isn't perfect and there's some false negatives but they are in - # contexts were a comma is actually fine. - and not any( - node.prev_sibling.type == RARROW - for node in ( - leaves[0].parent, - getattr(leaves[0].parent, "parent", None), - ) - if isinstance(node, Node) and isinstance(node.prev_sibling, Leaf) - ) - # Except the false negatives above for PEP 604 unions where we - # can't add the comma. + # Don't add one inside parenthesized return annotations + and get_annotation_type(leaves[0]) != "return" + # Don't add one inside PEP 604 unions and not ( leaves[0].parent and leaves[0].parent.next_sibling diff --git a/src/black/nodes.py b/src/black/nodes.py index a8869cba23..c0dca6e578 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -3,7 +3,18 @@ """ import sys -from typing import Final, Generic, Iterator, List, Optional, Set, Tuple, TypeVar, Union +from typing import ( + Final, + Generic, + Iterator, + List, + Literal, + Optional, + Set, + Tuple, + TypeVar, + Union, +) if sys.version_info >= (3, 10): from typing import TypeGuard @@ -951,16 +962,21 @@ def is_number_token(nl: NL) -> TypeGuard[Leaf]: return nl.type == token.NUMBER -def is_part_of_annotation(leaf: Leaf) -> bool: - """Returns whether this leaf is part of type annotations.""" +def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]: + """Returns the type of annotation this leaf is part of, if any.""" ancestor = leaf.parent while ancestor is not None: if ancestor.prev_sibling and ancestor.prev_sibling.type == token.RARROW: - return True + return "return" if ancestor.parent and ancestor.parent.type == syms.tname: - return True + return "param" ancestor = ancestor.parent - return False + return None + + +def is_part_of_annotation(leaf: Leaf) -> bool: + """Returns whether this leaf is part of a type annotation.""" + return get_annotation_type(leaf) is not None def first_leaf(node: LN) -> Optional[Leaf]: