Skip to content

Commit

Permalink
Add func_def_split transformer
Browse files Browse the repository at this point in the history
Splits function definitions while respecting trailing commas and prioritising parts that are longer than line length.
  • Loading branch information
sumezulike committed Mar 10, 2024
1 parent 47d963e commit e960543
Showing 1 changed file with 194 additions and 2 deletions.
196 changes: 194 additions & 2 deletions src/black/linegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import replace
from enum import Enum, auto
from functools import partial, wraps
from typing import Collection, Iterator, List, Optional, Set, Union, cast
from typing import Collection, Iterator, List, Optional, Set, Tuple, Union, cast

from black.brackets import (
COMMA_PRIORITY,
Expand Down Expand Up @@ -591,7 +591,10 @@ def transform_line(
else:
transformers = []
elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
transformers = [left_hand_split]
if Preview.function_def_wrap_order in mode:
transformers = [func_def_split, left_hand_split]
else:
transformers = [left_hand_split]
else:

def _rhs(
Expand Down Expand Up @@ -753,6 +756,195 @@ def left_hand_split(
yield result


def _get_func_def_parts(
line: Line,
) -> Tuple[List[Leaf], List[Leaf], List[Leaf], List[Leaf], List[Leaf]]:
start_leaves: List[Leaf] = []
type_param_leaves: List[Leaf] = []
func_param_leaves: List[Leaf] = []
return_type_leaves: List[Leaf] = []
end_leaves: List[Leaf] = []
current_leaves = start_leaves
type_param_open_bracket: Optional[Leaf] = None
func_param_open_bracket: Optional[Leaf] = None

for leaf in line.leaves:
if current_leaves is start_leaves and leaf.type == token.LSQB:
current_leaves = type_param_leaves
type_param_open_bracket = leaf
elif current_leaves is start_leaves and leaf.type == token.LPAR:
current_leaves = func_param_leaves
func_param_open_bracket = leaf
elif current_leaves is end_leaves and leaf.type == token.RARROW:
current_leaves = return_type_leaves

current_leaves.append(leaf)

if (
leaf.type == token.RSQB
and leaf.opening_bracket is type_param_open_bracket
and isinstance(type_param_open_bracket, Leaf)
):
current_leaves = start_leaves # changes to func_param_leaves on next leaf
if (
leaf.type == token.RPAR
and leaf.opening_bracket is func_param_open_bracket
and isinstance(func_param_open_bracket, Leaf)
):
current_leaves = end_leaves
if current_leaves is return_type_leaves and leaf.type == token.COLON:
current_leaves = end_leaves

return (
start_leaves,
type_param_leaves,
func_param_leaves,
return_type_leaves,
end_leaves,
)


def _partial_line(leaves: List[Leaf], line: Line) -> Line:
new_line = Line(mode=line.mode, depth=line.depth)
leaves_to_track = get_leaves_inside_matching_brackets(leaves)
for leaf in leaves:
new_line.append(
leaf, preformatted=True, track_bracket=id(leaf) in leaves_to_track
)
return new_line


def func_def_split(
line: Line, _features: Collection[Feature], mode: Mode
) -> Iterator[Line]:
"""Split a function definition line, either at the PEP 695 type parameters,
the function arguments, or the return type.
Split position priority in descending order:
- function parameters with magic trailing comma
- return type with magic trailing comma
- type parameters with magic trailing comma
- over-long return type
- over-long type parameters
- function parameters
E.g. (assuming a sufficiently small line length)
```
def f[T](a): ...
```
is formatted as
```
def f[T](
a
): ...
```
but
```
def f[T,](a): ...
```
becomes
```
def f[
T,
](a): ...
```
because the trailing comma gives the type parameters a higher priority.
"""

def _split(split_leaves: List[Leaf]) -> Iterator[Line]:
if split_leaves is func_param_leaves:
head_leaves = [*start_leaves, *type_param_leaves, func_param_leaves[0]]
body_leaves = func_param_leaves[1:-1]
tail_leaves = [func_param_leaves[-1], *return_type_leaves, *end_leaves]
opening_bracket = func_param_leaves[0]
elif split_leaves is type_param_leaves:
head_leaves = [*start_leaves, type_param_leaves[0]]
body_leaves = type_param_leaves[1:-1]
tail_leaves = [
type_param_leaves[-1],
*func_param_leaves,
*return_type_leaves,
*end_leaves,
]
opening_bracket = type_param_leaves[0]
elif split_leaves is return_type_leaves:
yield from right_hand_split(line, mode, _features)
return
else:
raise ValueError("Invalid argument for split_leaves")

head = bracket_split_build_line(
head_leaves, line, opening_bracket, component=_BracketSplitComponent.head
)
body = bracket_split_build_line(
body_leaves, line, opening_bracket, component=_BracketSplitComponent.body
)
tail = bracket_split_build_line(
tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail
)
bracket_split_succeeded_or_raise(head, body, tail)
for result in (head, body, tail):
if result:
yield result

(
start_leaves,
type_param_leaves,
func_param_leaves,
return_type_leaves,
end_leaves,
) = _get_func_def_parts(line)

possible_split_leaves = []
if len(func_param_leaves) > 2:
possible_split_leaves.append(func_param_leaves)
if type_param_leaves:
possible_split_leaves.append(type_param_leaves)
if return_type_leaves:
possible_split_leaves.append(return_type_leaves)

if len(possible_split_leaves) == 1:
# only one option, just delegate
yield from left_hand_split(line, _features, mode)
return

# multiple possible splits, look for trailing commas

func_param_line = _partial_line(func_param_leaves, line)
type_param_line = _partial_line(type_param_leaves, line)
return_type_line = _partial_line(return_type_leaves, line)

if len(func_param_leaves) > 2 and func_param_line.magic_trailing_comma is not None:
yield from _split(func_param_leaves)
return

if return_type_leaves and return_type_line.magic_trailing_comma is not None:
yield from _split(return_type_leaves)
return

if type_param_leaves and type_param_line.magic_trailing_comma is not None:
yield from _split(type_param_leaves)
return

# no trailing commas, look for overly long parts

if return_type_leaves and not is_line_short_enough(return_type_line, mode=mode):
yield from _split(return_type_leaves)
return

if type_param_leaves and not is_line_short_enough(type_param_line, mode=mode):
yield from _split(type_param_leaves)
return

# otherwise just start with function parameters

if len(func_param_leaves) > 2:
yield from _split(func_param_leaves)
return

raise CannotSplit("No good way to split funcdef")


def right_hand_split(
line: Line,
mode: Mode,
Expand Down

0 comments on commit e960543

Please sign in to comment.