# Copyright 2017 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import annotations

import collections
import collections.abc
import math
from typing import Any, Callable, Iterable, Iterator, MutableMapping, TypeVar

from pants.engine.internals import native_engine


def recursively_update(d: MutableMapping, d2: MutableMapping) -> None:
    """dict.update, but merges child dicts (d2 takes precedence where there's a conflict)."""
    for k, v in d2.items():
        if k in d:
            if isinstance(v, dict):
                recursively_update(d[k], v)
                continue
        d[k] = v
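

# A minimal usage sketch (illustrative only; the `_example_*` helpers in this
# file are not part of the original module): child dicts are merged
# recursively, while scalar values from the second dict replace those in the
# first.
def _example_recursively_update() -> None:
    d = {"a": {"x": 1}, "b": 2}
    recursively_update(d, {"a": {"y": 3}, "b": 4})
    assert d == {"a": {"x": 1, "y": 3}, "b": 4}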


_T = TypeVar("_T")


def assert_single_element(iterable: Iterable[_T]) -> _T:
    """Get the single element of `iterable`, or raise an error.

    :raise: :class:`StopIteration` if there is no element.
    :raise: :class:`ValueError` if there is more than one element.
    """
    it = iter(iterable)
    first_item = next(it)

    try:
        next(it)
    except StopIteration:
        return first_item

    raise ValueError(f"iterable {iterable!r} has more than one element.")
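

# A minimal usage sketch (illustrative only, not part of the original module):
# the three possible outcomes documented in the docstring above.
def _example_assert_single_element() -> None:
    assert assert_single_element([42]) == 42
    try:
        assert_single_element([])  # no element
    except StopIteration:
        pass
    try:
        assert_single_element([1, 2])  # more than one element
    except ValueError:
        pass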


def ensure_list(
    val: Any | Iterable[Any], *, expected_type: type[_T], allow_single_scalar: bool = False
) -> list[_T]:
    """Ensure that every element of an iterable is the expected type and convert the result to a
    list.

    If `allow_single_scalar` is True, a single value T will be wrapped into a `List[T]`.
    """
    if isinstance(val, expected_type):
        if not allow_single_scalar:
            raise ValueError(f"The value {val} must be wrapped in an iterable (e.g. a list).")
        return [val]
    if not isinstance(val, collections.abc.Iterable):
        raise ValueError(
            f"The value {val} (type {type(val)}) was not an iterable of {expected_type}."
        )
    result: list[_T] = []
    for i, x in enumerate(val):
        if not isinstance(x, expected_type):
            raise ValueError(
                f"Not all elements of the iterable have type {expected_type}. Encountered the "
                f"element {x} of type {type(x)} at index {i}."
            )
        result.append(x)
    return result
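

# A minimal usage sketch (illustrative only, not part of the original module):
# ensure_list validates element types and, only when explicitly allowed, wraps
# a bare scalar into a single-element list.
def _example_ensure_list() -> None:
    assert ensure_list([1, 2, 3], expected_type=int) == [1, 2, 3]
    assert ensure_list(1, expected_type=int, allow_single_scalar=True) == [1]
    try:
        ensure_list(1, expected_type=int)  # bare scalars are rejected by default
    except ValueError:
        pass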


def ensure_str_list(val: str | Iterable[str], *, allow_single_str: bool = False) -> list[str]:
    """Ensure that every element of an iterable is a string and convert the result to a list.

    If `allow_single_str` is True, a single `str` will be wrapped into a `List[str]`.
    """
    return ensure_list(val, expected_type=str, allow_single_scalar=allow_single_str)
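

# A minimal usage sketch (illustrative only, not part of the original module).
# Because `str` is itself iterable, ensure_list's scalar special case matters
# here: a bare string is either wrapped (with allow_single_str=True) or
# rejected, never split into an iterable of characters.
def _example_ensure_str_list() -> None:
    assert ensure_str_list(["a", "b"]) == ["a", "b"]
    assert ensure_str_list("ab", allow_single_str=True) == ["ab"]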


def partition_sequentially(
    items: Iterable[_T],
    *,
    key: Callable[[_T], str],
    size_target: int,
    size_max: int | None = None,
) -> Iterator[list[_T]]:
    """Stably partitions the given items into batches of around `size_target` items.

    The "stability" property refers to avoiding adjusting all batches when a single item is added,
    which could happen if the items were trivially windowed using `itertools.islice` and an
    item was added near the front of the list.

    Batches will optionally be capped to `size_max`, but note that this can weaken the stability
    properties of the bucketing, by forcing bucket boundaries to be created where they otherwise
    might not.
    """
    # To stably partition the arguments into ranges of approximately `size_target`, we sort them,
    # and create a new batch sequentially once we encounter an item hash prefixed with a threshold
    # of zeros.
    #
    # The hashes act like a (deterministic) series of rolls of an evenly distributed die. The
    # probability of a hash prefixed with Z zero bits is 1/2^Z, and so to break after N items on
    # average, we look for `Z == log2(N)` zero bits.
    #
    # Breaking on these deterministic boundaries reduces the chance that adding or removing items
    # causes multiple buckets to be recalculated. But when a `size_max` value is set, it's possible
    # for adding items to cause multiple sequential buckets to be affected.
    zero_prefix_threshold = math.log(max(1, size_target), 2)

    batch: list[_T] = []

    def emit_batch() -> list[_T]:
        assert batch
        result = list(batch)
        batch.clear()
        return result

    keyed_items = []
    for item in items:
        keyed_items.append((key(item), item))
    keyed_items.sort()

    for item_key, item in keyed_items:
        batch.append(item)
        prefix_zero_bits = native_engine.hash_prefix_zero_bits(item_key)
        if prefix_zero_bits >= zero_prefix_threshold or (size_max and len(batch) >= size_max):
            yield emit_batch()

    if batch:
        yield emit_batch()
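

# A minimal usage sketch (illustrative only, not part of the original module).
# It assumes the compiled `pants.engine.internals.native_engine` extension is
# importable, since batch boundaries come from its `hash_prefix_zero_bits`.
def _example_partition_sequentially() -> None:
    items = ["alpha", "beta", "gamma", "delta", "epsilon"]
    batches = list(
        partition_sequentially(items, key=lambda s: s, size_target=2, size_max=3)
    )
    # Every item lands in exactly one batch, and no batch exceeds size_max.
    assert sorted(s for batch in batches for s in batch) == sorted(items)
    assert all(len(batch) <= 3 for batch in batches)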