/
ak_where.py
128 lines (106 loc) · 4.99 KB
/
ak_where.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import awkward as ak
from awkward._backends.numpy import NumpyBackend
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
from awkward._nplikes.numpy_like import NumpyMetadata
# Public API of this module: only `where` is exported.
__all__ = ("where",)
# NumpyMetadata instance; the code below reads dtype constants from it
# (np.int8, np.int64) rather than from the `numpy` package directly.
np = NumpyMetadata.instance()
# NumpyBackend instance.
# NOTE(review): `cpu` is not referenced in the visible code — confirm whether
# it is still needed.
cpu = NumpyBackend.instance()
@ak._connect.numpy.implements("where")
@high_level_function()
def where(condition, *args, mergebool=True, highlevel=True, behavior=None, attrs=None):
    """
    Args:
        condition: Array-like data (anything #ak.to_layout recognizes) of booleans.
        x: Optional array-like data (anything #ak.to_layout recognizes) with the
            same length as `condition`. Passed positionally, after `condition`.
        y: Optional array-like data (anything #ak.to_layout recognizes) with the
            same length as `condition`. Passed positionally, after `x`.
        mergebool (bool, default is True): If True, boolean and numeric data
            can be combined into the same buffer, losing information about
            False vs `0` and True vs `1`; otherwise, they are kept in separate
            buffers with distinct types (using an #ak.contents.UnionArray).
        highlevel (bool, default is True): If True, return an #ak.Array;
            otherwise, return a low-level #ak.contents.Content subclass.
        behavior (None or dict): Custom #ak.behavior for the output array, if
            high-level.
        attrs (None or dict): Custom attributes for the output array, if
            high-level.

    This function can be called in two ways: with `condition` alone (the
    one-argument form) or with `condition`, `x`, and `y` (the three-argument
    form).

    The one-argument form is completely equivalent to NumPy's
    [nonzero](https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html)
    function.

    The three-argument form acts as a vectorized ternary operator:
    `condition`, `x`, and `y` must all have the same length and

        output[i] = x[i] if condition[i] else y[i]

    for all `i`. The structure of `x` and `y` do not need to be the same; if
    they are incompatible types, the output will have #ak.type.UnionType.

    Supplying `x` without `y` raises a ValueError; supplying more than three
    positional arguments raises a TypeError.
    """
    # Dispatch
    # NOTE(review): the yielded tuple is presumably consumed by the
    # `high_level_function` decorator to choose an implementation — confirm
    # against awkward._dispatch.
    yield (*args, condition)

    # Implementation
    n_optional = len(args)

    # One-argument form: condition only.
    if n_optional == 0:
        return _impl1(condition, mergebool, highlevel, behavior, attrs)

    # Three-argument form: condition, x, y.
    if n_optional == 2:
        x_data, y_data = args
        return _impl3(
            condition, x_data, y_data, mergebool, highlevel, behavior, attrs
        )

    # Anything else is a usage error.
    if n_optional == 1:
        raise ValueError("either both or neither of x and y should be given")

    raise TypeError(
        f"where() takes from 1 to 3 positional arguments but {len(args) + 1} were "
        "given"
    )
def _impl1(condition, mergebool, highlevel, behavior, attrs):
    """
    One-argument form of `where`: apply the backend's `nonzero` to `condition`.

    `condition` is unwrapped to a layout (records are not allowed and the
    primitive policy is "error"), converted to a backend array with
    `allow_missing=False`, and passed to the backend nplike's `nonzero`.
    Each array in the result is wrapped in a NumpyArray on the same backend
    and returned, in order, as a tuple.

    `mergebool` is accepted for signature symmetry with `_impl3` but is not
    used here.
    """
    with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
        layout = ctx.unwrap(condition, allow_record=False, primitive_policy="error")

    backend = layout.backend
    dense = layout.to_backend_array(allow_missing=False)
    positions = backend.nplike.nonzero(dense)

    # One output array per entry returned by `nonzero`.
    return tuple(
        ctx.wrap(
            ak.contents.NumpyArray(axis_positions, backend=backend),
            highlevel=highlevel,
        )
        for axis_positions in positions
    )
def _impl3(condition, x, y, mergebool, highlevel, behavior, attrs):
    """
    Three-argument form of `where`: pick from `x` or `y` according to `condition`.

    `x` and `y` are unwrapped with the "pass-through" primitive policy, and
    `condition` with the "error" policy; the three are then put on the same
    backend. They are broadcast together, and wherever the broadcast
    `condition` node is a NumpyArray the selection is expressed as a
    UnionArray over `[x, y]`, simplified with the requested `mergebool`.
    """
    with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
        # Unwrap in the order x, y, condition; `action` relies on that order.
        x_layout = ctx.unwrap(x, allow_record=False, primitive_policy="pass-through")
        y_layout = ctx.unwrap(y, allow_record=False, primitive_policy="pass-through")
        condition_layout = ctx.unwrap(
            condition, allow_record=False, primitive_policy="error"
        )
        layouts = ensure_same_backend(x_layout, y_layout, condition_layout)

    def as_content(value, backend, length):
        # Content nodes are used as they are; anything else is repeated
        # `length` times and wrapped in a NumpyArray.
        if isinstance(value, ak.contents.Content):
            return value
        return ak.contents.NumpyArray(
            backend.nplike.repeat(
                backend.nplike.asarray(value),
                backend.nplike.shape_item_as_index(length),
            )
        )

    def action(inputs, backend, **kwargs):
        x_node, y_node, condition_node = inputs

        # Only handle the level at which the condition is a flat NumpyArray;
        # returning None leaves every other level to the broadcasting code.
        if not isinstance(condition_node, ak.contents.NumpyArray):
            return None

        index_nplike = backend.index_nplike
        is_zero = index_nplike.asarray(condition_node.data) == 0

        # Tag is 0 where the condition is nonzero and 1 where it is zero,
        # matching the [x, y] order of the contents passed below.
        tags = ak.index.Index8(is_zero.view(np.int8))
        index = ak.index.Index64(
            index_nplike.arange(tags.length, dtype=np.int64),
            nplike=index_nplike,
        )

        contents = [
            as_content(x_node, backend, tags.length),
            as_content(y_node, backend, tags.length),
        ]
        selected = ak.contents.UnionArray.simplified(
            tags,
            index,
            contents,
            mergebool=mergebool,
        )
        return (selected,)

    broadcasted = ak._broadcasting.broadcast_and_apply(
        layouts, action, numpy_to_regular=True
    )
    return ctx.wrap(broadcasted[0], highlevel=highlevel)