Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: replace nplike with backend in to_buffers #1942

Merged
merged 4 commits into from Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions dev/generate-tests.py
Expand Up @@ -679,7 +679,7 @@ def gencudakerneltests(specdict):
)

f.write(
"import cupy\nimport pytest\n\nimport awkward as ak\nimport awkward._connect.cuda as ak_cu\n\ncupy_nplike = ak.nplikes.Cupy.instance()\n\n"
"import cupy\nimport pytest\n\nimport awkward as ak\nimport awkward._connect.cuda as ak_cu\n\ncupy_backend = ak._backends.CupyBackend.instance()\n\n"
)
num = 1
if spec.tests == []:
Expand Down Expand Up @@ -728,7 +728,7 @@ def gencudakerneltests(specdict):
# )
# )
cuda_string = (
"funcC = cupy_nplike['"
"funcC = cupy_backend['"
+ spec.templatized_kernel_name
+ "', {}]\n".format(", ".join(dtypes))
)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/__init__.py
Expand Up @@ -4,7 +4,7 @@
from awkward._version import __version__

# NumPy-like alternatives
import awkward.nplikes
import awkward._nplikes
import awkward._typetracer
import awkward._backends

Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_backends.py
Expand Up @@ -5,8 +5,7 @@
import awkward_cpp

import awkward as ak
from awkward._typetracer import NoKernel, TypeTracer
from awkward.nplikes import (
from awkward._nplikes import (
Cupy,
CupyKernel,
Jax,
Expand All @@ -18,6 +17,7 @@
Singleton,
nplike_of,
)
from awkward._typetracer import NoKernel, TypeTracer
from awkward.typing import (
Any,
Callable,
Expand Down Expand Up @@ -155,7 +155,7 @@ def __getitem__(self, index: KernelKeyType) -> NoKernel:
return NoKernel(index)


def _backend_for_nplike(nplike: ak.nplikes.NumpyLike) -> Backend:
def _backend_for_nplike(nplike: ak._nplikes.NumpyLike) -> Backend:
# Currently there exists a one-to-one relationship between the nplike
# and the backend. In future, this might need refactoring
if isinstance(nplike, Numpy):
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_broadcasting.py
Expand Up @@ -30,8 +30,8 @@
from awkward.record import Record
from awkward.typing import Any, Callable, Dict, List, TypeAlias, Union

np = ak.nplikes.NumpyMetadata.instance()
numpy = ak.nplikes.Numpy.instance()
np = ak._nplikes.NumpyMetadata.instance()
numpy = ak._nplikes.Numpy.instance()

optiontypes = (IndexedOptionArray, ByteMaskedArray, BitMaskedArray, UnmaskedArray)
listtypes = (ListOffsetArray, ListArray, RegularArray)
Expand All @@ -48,7 +48,7 @@ def broadcast_pack(inputs: Sequence, isscalar: list[bool]) -> list:
nextinputs = []
for x in inputs:
if isinstance(x, Record):
index = ak.nplikes.nplike_of(*inputs).full(maxlen, x.at, dtype=np.int64)
index = ak._nplikes.nplike_of(*inputs).full(maxlen, x.at, dtype=np.int64)
nextinputs.append(RegularArray(x.array[index], maxlen, 1))
isscalar.append(True)
elif isinstance(x, Content):
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/_connect/cling.py
Expand Up @@ -8,8 +8,8 @@

import awkward as ak

np = ak.nplikes.NumpyMetadata.instance()
numpy = ak.nplikes.Numpy.instance()
np = ak._nplikes.NumpyMetadata.instance()
numpy = ak._nplikes.Numpy.instance()


cache = {}
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/jax/reducers.py
Expand Up @@ -5,7 +5,7 @@
import awkward as ak
from awkward._reducers import Reducer

np = ak.nplikes.NumpyMetadata.instance()
np = ak._nplikes.NumpyMetadata.instance()


class ArgMin(Reducer):
Expand Down
10 changes: 5 additions & 5 deletions src/awkward/_connect/jax/trees.py
Expand Up @@ -4,11 +4,11 @@
import jax

import awkward as ak
from awkward import _errors, contents, highlevel, nplikes, record
from awkward import _errors, _nplikes, contents, highlevel, record
from awkward.typing import Generic, TypeVar, Union

numpy = nplikes.Numpy.instance()
np = nplikes.NumpyMetadata.instance()
numpy = _nplikes.Numpy.instance()
np = _nplikes.NumpyMetadata.instance()


def find_all_buffers(
Expand All @@ -32,7 +32,7 @@ def replace_all_buffers(
backend: ak._backends.Backend,
):
def action(node, **kwargs):
jaxlike = nplikes.Jax.instance()
jaxlike = _nplikes.Jax.instance()
if isinstance(node, ak.contents.NumpyArray):
buffer = buffers.pop(0)
# JAX might give us non-buffers, so ignore them
Expand Down Expand Up @@ -96,7 +96,7 @@ def from_array_or_layout(cls, obj: T):
# layout = replace_all_buffers(
# layout,
# [create_placeholder_like(n) for n in buffers],
# nplike=nplikes.Numpy.instance(),
# nplike=_nplikes.Numpy.instance(),
# )

return buffers, AuxData(
Expand Down
10 changes: 5 additions & 5 deletions src/awkward/_connect/numba/arrayview.py
Expand Up @@ -8,7 +8,7 @@

import awkward as ak

np = ak.nplikes.NumpyMetadata.instance()
np = ak._nplikes.NumpyMetadata.instance()


def code_to_function(code, function_name, externals=None, debug=False):
Expand Down Expand Up @@ -871,7 +871,7 @@ def array_supported(dtype):
) or isinstance(dtype, (numba.types.NPDatetime, numba.types.NPTimedelta))


@numba.extending.overload(ak.nplikes.numpy.array)
@numba.extending.overload(ak._nplikes.numpy.array)
def overload_np_array(array, dtype=None):
if isinstance(array, ArrayViewType):
ndim = array.type.ndim
Expand Down Expand Up @@ -936,11 +936,11 @@ def array_impl(array, dtype=None):
"\n ".join(fill_array),
),
"array_impl",
{"numpy": ak.nplikes.numpy},
{"numpy": ak._nplikes.numpy},
)


@numba.extending.type_callable(ak.nplikes.numpy.asarray)
@numba.extending.type_callable(ak._nplikes.numpy.asarray)
def type_asarray(context):
def typer(arrayview):
if (
Expand All @@ -954,7 +954,7 @@ def typer(arrayview):
return typer


@numba.extending.lower_builtin(ak.nplikes.numpy.asarray, ArrayViewType)
@numba.extending.lower_builtin(ak._nplikes.numpy.asarray, ArrayViewType)
def lower_asarray(context, builder, sig, args):
rettype, (viewtype,) = sig.return_type, sig.args
(viewval,) = args
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/numba/builder.py
Expand Up @@ -7,7 +7,7 @@

import awkward as ak

numpy = ak.nplikes.Numpy.instance()
numpy = ak._nplikes.Numpy.instance()


dynamic_addrs = {}
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/_connect/numba/layout.py
Expand Up @@ -6,8 +6,8 @@

import awkward as ak

np = ak.nplikes.NumpyMetadata.instance()
numpy = ak.nplikes.Numpy.instance()
np = ak._nplikes.NumpyMetadata.instance()
numpy = ak._nplikes.Numpy.instance()


@numba.extending.typeof_impl.register(ak.contents.Content)
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/_connect/numpy.py
Expand Up @@ -26,10 +26,10 @@ def convert_to_array(layout, args, kwargs):

def _to_rectilinear(arg):
if isinstance(arg, tuple):
nplike = ak.nplikes.nplike_of(*arg)
nplike = ak._nplikes.nplike_of(*arg)
return tuple(nplike.to_rectilinear(x) for x in arg)
else:
nplike = ak.nplikes.nplike_of(arg)
nplike = ak._nplikes.nplike_of(arg)
nplike.to_rectilinear(arg)


Expand All @@ -39,7 +39,7 @@ def array_function(func, types, args, kwargs):
args = tuple(_to_rectilinear(x) for x in args)
kwargs = {k: _to_rectilinear(v) for k, v in kwargs.items()}
out = func(*args, **kwargs)
nplike = ak.nplikes.nplike_of(out)
nplike = ak._nplikes.nplike_of(out)
if isinstance(out, nplike.ndarray) and len(out.shape) != 0:
return ak.Array(out)
else:
Expand Down Expand Up @@ -167,7 +167,7 @@ def action(inputs, **ignore):
else:
args.append(x)

if isinstance(nplike, ak.nplikes.Jax):
if isinstance(nplike, ak._nplikes.Jax):
from awkward._connect.jax import get_jax_ufunc

jax_ufunc = get_jax_ufunc(ufunc)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/pyarrow.py
Expand Up @@ -7,7 +7,7 @@

import awkward as ak

np = ak.nplikes.NumpyMetadata.instance()
np = ak._nplikes.NumpyMetadata.instance()

try:
import pyarrow
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_connect/rdataframe/from_rdataframe.py
Expand Up @@ -29,8 +29,8 @@
"timedelta64": "std::difftime",
}

np = ak.nplikes.NumpyMetadata.instance()
numpy = ak.nplikes.Numpy.instance()
np = ak._nplikes.NumpyMetadata.instance()
numpy = ak._nplikes.Numpy.instance()


cppyy.add_include_path(
Expand Down Expand Up @@ -71,7 +71,7 @@ def form_dtype(form):
def empty_buffers(cpp_buffers_self, names_nbytes):
buffers = {}
for item in names_nbytes:
buffers[item.first] = ak.nplikes.numpy.empty(item.second)
buffers[item.first] = ak._nplikes.numpy.empty(item.second)
cpp_buffers_self.append(
item.first,
buffers[item.first].ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte)),
Expand Down
10 changes: 4 additions & 6 deletions src/awkward/_errors.py
Expand Up @@ -5,9 +5,9 @@
import warnings
from collections.abc import Mapping, Sequence

from awkward import nplikes
from awkward import _nplikes

np = nplikes.NumpyMetadata.instance()
np = _nplikes.NumpyMetadata.instance()


class PartialFunction:
Expand Down Expand Up @@ -128,7 +128,7 @@ class OperationErrorContext(ErrorContext):

def __init__(self, name, arguments):
if self.primary() is not None or all(
nplikes.nplike_of(x).is_eager for x in arguments
_nplikes.nplike_of(x).is_eager for x in arguments
):
# if primary is not None: we won't be setting an ErrorContext
# if all nplikes are eager: no accumulation of large arrays
Expand Down Expand Up @@ -185,9 +185,7 @@ class SlicingErrorContext(ErrorContext):
_width = 80 - 4

def __init__(self, array, where):
if self.primary() is not None or (
nplikes.nplike_of(array).is_eager and nplikes.nplike_of(where).is_eager
):
if self.primary() is not None or _nplikes.nplike_of(array, where).is_eager:
# if primary is not None: we won't be setting an ErrorContext
# if all nplikes are eager: no accumulation of large arrays
# --> in either case, delay string generation
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_lookup.py
Expand Up @@ -2,7 +2,7 @@

import awkward as ak

np = ak.nplikes.NumpyMetadata.instance()
np = ak._nplikes.NumpyMetadata.instance()


class Lookup:
Expand Down
10 changes: 5 additions & 5 deletions src/awkward/nplikes.py → src/awkward/_nplikes.py
Expand Up @@ -835,11 +835,11 @@ def ascontiguousarray(self, array, dtype=None):
def raw(self, array, nplike):
if isinstance(nplike, Jax):
return array
elif isinstance(nplike, ak.nplikes.Cupy):
cupy = ak.nplikes.Cupy.instance()
elif isinstance(nplike, ak._nplikes.Cupy):
cupy = ak._nplikes.Cupy.instance()
return cupy.asarray(array)
elif isinstance(nplike, ak.nplikes.Numpy):
numpy = ak.nplikes.Numpy.instance()
elif isinstance(nplike, ak._nplikes.Numpy):
numpy = ak._nplikes.Numpy.instance()
return numpy.asarray(array)
elif isinstance(nplike, ak._typetracer.TypeTracer):
return ak._typetracer.TypeTracerArray(dtype=array.dtype, shape=array.shape)
Expand Down Expand Up @@ -939,7 +939,7 @@ def nplike_of(*arrays, default: D = _UNSET) -> NumpyLike | D:
*arrays: iterable of possible array objects
default: default NumpyLike instance if no array objects found

Return the #ak.nplikes.NumpyLike that is best-suited to operating upon the given
Return the #ak._nplikes.NumpyLike that is best-suited to operating upon the given
iterable of arrays. Return an instance of the `default_cls` if no known array types
are found.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_prettyprint.py
Expand Up @@ -6,7 +6,7 @@

import awkward as ak

numpy = ak.nplikes.Numpy.instance()
numpy = ak._nplikes.Numpy.instance()


def half(integer):
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_reducers.py
Expand Up @@ -2,7 +2,7 @@

import awkward as ak

np = ak.nplikes.NumpyMetadata.instance()
np = ak._nplikes.NumpyMetadata.instance()


class Reducer:
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/_slicing.py
Expand Up @@ -3,7 +3,7 @@
import awkward as ak
from awkward.typing import Sequence

np = ak.nplikes.NumpyMetadata.instance()
np = ak._nplikes.NumpyMetadata.instance()


def headtail(oldtail):
Expand Down Expand Up @@ -62,7 +62,7 @@ def prepare_advanced_indexing(items):
)

# Then broadcast the index items
nplike = ak.nplikes.nplike_of(*broadcastable)
nplike = ak._nplikes.nplike_of(*broadcastable)
broadcasted = nplike.broadcast_arrays(*broadcastable)

# And re-assemble the index with the broadcasted items
Expand Down Expand Up @@ -381,7 +381,7 @@ def normalise_item_bool_to_int(item):
and issubclass(item.content.content.dtype.type, (bool, np.bool_))
):
if item.backend.nplike.known_data or item.backend.nplike.known_shape:
if isinstance(item.backend.nplike, ak.nplikes.Jax):
if isinstance(item.backend.nplike, ak._nplikes.Jax):
raise ak._errors.wrap_error(
"This slice is not supported for JAX differentiation."
)
Expand Down Expand Up @@ -450,7 +450,7 @@ def normalise_item_bool_to_int(item):
item.content.dtype.type, (bool, np.bool_)
):
if item.backend.nplike.known_data or item.backend.nplike.known_shape:
if isinstance(item.backend.nplike, ak.nplikes.Jax):
if isinstance(item.backend.nplike, ak._nplikes.Jax):
raise ak._errors.wrap_error(
"This slice is not supported for JAX differentiation."
)
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_typetracer.py
Expand Up @@ -5,10 +5,10 @@
import numpy

import awkward as ak
from awkward import index, nplikes
from awkward import _nplikes, index
from awkward.typing import TypeVar

np = nplikes.NumpyMetadata.instance()
np = _nplikes.NumpyMetadata.instance()


class NoError:
Expand Down Expand Up @@ -486,7 +486,7 @@ def copy(self):
return self


class TypeTracer(ak.nplikes.NumpyLike):
class TypeTracer(ak._nplikes.NumpyLike):
known_data = False
known_shape = False

Expand Down