/
ak_to_backend.py
68 lines (48 loc) · 2.29 KB
/
ak_to_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
from awkward._backends.dispatch import regularize_backend
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._nplikes.numpy_like import NumpyMetadata
# Public API of this module: only the high-level function is exported.
__all__ = ("to_backend",)
# NOTE(review): `np` is not referenced by any code visible in this module;
# it looks like it is kept for consistency with sibling operation modules —
# confirm before removing.
np = NumpyMetadata.instance()
@high_level_function()
def to_backend(array, backend, *, highlevel=True, behavior=None, attrs=None):
    """
    Args:
        array: Array-like data (anything #ak.to_layout recognizes).
        backend (`"cpu"`, `"cuda"`, `"jax"`, or `"typetracer"`): If `"cpu"`, the
            array structure is recursively copied (if need be) to main memory
            for use with the default Numpy backend; if `"cuda"`, the structure
            is copied to the GPU(s) for use with CuPy. If `"jax"`, the
            structure is copied to the CPU for use with JAX. `"typetracer"` is
            also accepted as a target backend.
        highlevel (bool): If True, return an #ak.Array; otherwise, return
            a low-level #ak.contents.Content subclass.
        behavior (None or dict): Custom #ak.behavior for the output array, if
            high-level.
        attrs (None or dict): Custom attributes for the output array, if
            high-level.

    Converts an array from `"cpu"`, `"cuda"`, or `"jax"` kernels to `"cpu"`,
    `"cuda"`, `"jax"`, or `"typetracer"`.

    Any components that are already in the desired backend are viewed,
    rather than copied, so this operation can be an inexpensive way to ensure
    that an array is ready for a particular library.

    To use `"cuda"`, the `cupy` package must be installed, either with

        pip install cupy

    or

        conda install -c conda-forge cupy

    To use `"jax"`, the `jax` package must be installed, either with

        pip install jax

    or

        conda install -c conda-forge jax

    See #ak.kernels.
    """
    # Dispatch: this function is a generator; the `high_level_function`
    # decorator presumably consumes the yielded tuple of array-like arguments
    # to choose an implementation (its logic lives outside this module).
    yield (array,)

    # Implementation
    return _impl(array, backend, highlevel, behavior, attrs)
def _impl(array, backend, highlevel, behavior, attrs):
    """
    Move the layout underlying `array` onto `backend` and rewrap it.

    `highlevel`, `behavior` and `attrs` have the same meaning as in
    #ak.to_backend; `backend` is normalized with `regularize_backend` before
    the layout's own `to_backend` method is called.
    """
    with HighLevelContext(behavior=behavior, attrs=attrs) as context:
        # Records are accepted; unrecognized (non-array-like) inputs are not.
        source_layout = context.unwrap(array, allow_record=True, allow_unknown=False)

    # NOTE(review): wrapping happens after the `with` block has exited,
    # following the HighLevelContext unwrap-inside / wrap-after pattern used by
    # other awkward operations — confirm against awkward._layout.
    converted_layout = source_layout.to_backend(regularize_backend(backend))
    return context.wrap(converted_layout, highlevel=highlevel)