Skip to content

Commit

Permalink
Include type-tracer tests in all ufunc tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 8, 2021
1 parent 3caa54f commit 5b0a273
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 53 deletions.
30 changes: 20 additions & 10 deletions src/awkward/_v2/_broadcasting.py
Expand Up @@ -59,14 +59,14 @@ def broadcast_pack(inputs, isscalar):
return nextinputs


def broadcast_unpack(x, isscalar):
def broadcast_unpack(x, isscalar, nplike):
if all(isscalar):
if len(x) == 0:
if not nplike.known_shape or len(x) == 0:
return x._getitem_nothing()._getitem_nothing()
else:
return x[0][0]
else:
if len(x) == 0:
if not nplike.known_shape or len(x) == 0:
return x._getitem_nothing()
else:
return x[0]
Expand Down Expand Up @@ -449,17 +449,20 @@ def continuation():
nextinputs = []
for x in inputs:
if isinstance(x, ListOffsetArray):
offsets = x.offsets
offsets = Index64(
nplike.empty((x.offsets.data.shape[0],), np.int64)
)
nextinputs.append(x.content)
elif isinstance(x, ListArray):
offsets = Index(
nplike.empty((x.starts.shape[0] + 1,), x.starts.dtype)
offsets = Index64(
nplike.empty((x.starts.data.shape[0] + 1,), np.int64)
)
nextinputs.append(x.content)
elif isinstance(x, RegularArray):
nextinputs.append(x.content)
else:
nextinputs.append(x)
assert offsets is not None

outcontent = apply_step(
nplike,
Expand Down Expand Up @@ -511,9 +514,15 @@ def continuation():
assert isinstance(outcontent, tuple)

if isinstance(offsets, Index):
return tuple(ListOffsetArray(offsets, x) for x in outcontent)
return tuple(
ListOffsetArray(offsets, x).toListOffsetArray64(False)
for x in outcontent
)
elif isinstance(starts, Index) and isinstance(stops, Index):
return tuple(ListArray(starts, stops, x) for x in outcontent)
return tuple(
ListArray(starts, stops, x).toListOffsetArray64(False)
for x in outcontent
)
else:
raise AssertionError(
"unexpected offsets, starts: {0}, {1}".format(
Expand Down Expand Up @@ -677,9 +686,10 @@ def broadcast_and_apply(
regular_to_jagged=False,
function_name=None,
):
nplike = ak.nplike.of(*inputs)
isscalar = []
out = apply_step(
ak.nplike.of(*inputs),
nplike,
broadcast_pack(inputs, isscalar),
action,
0,
Expand All @@ -696,4 +706,4 @@ def broadcast_and_apply(
},
)
assert isinstance(out, tuple)
return tuple(broadcast_unpack(x, isscalar) for x in out)
return tuple(broadcast_unpack(x, isscalar, nplike) for x in out)
42 changes: 34 additions & 8 deletions src/awkward/_v2/_connect/numpy.py
Expand Up @@ -195,12 +195,12 @@ def action(inputs, **ignore):
for x in inputs:
if isinstance(x, NumpyArray):
shape = x.shape
args.append(numpy.empty((0,), x.dtype))
args.append(numpy.empty((0,) + x.shape[1:], x.dtype))
else:
args.append(x)
assert shape is not None
dtype = getattr(ufunc, method)(*args, **kwargs).dtype
result = nplike.empty(shape, dtype)
tmp = getattr(ufunc, method)(*args, **kwargs)
result = nplike.empty((shape[0],) + tmp.shape[1:], tmp.dtype)

return (NumpyArray(result, nplike=nplike),)

Expand Down Expand Up @@ -239,11 +239,37 @@ def action(inputs, **ignore):

return None

out = ak._v2._broadcasting.broadcast_and_apply(
inputs, action, behavior, allow_records=False, function_name=ufunc.__name__
)
assert isinstance(out, tuple) and len(out) == 1
return ak._v2._util.wrap(out[0], behavior)
if sum(int(isinstance(x, ak._v2.contents.Content)) for x in inputs) == 1:
where = None
for i, x in enumerate(inputs):
if isinstance(x, ak._v2.contents.Content):
where = i
break
assert where is not None

nextinputs = list(inputs)

def unary_action(layout, **ignore):
nextinputs[where] = layout
result = action(tuple(nextinputs), **ignore)
if result is None:
return None
else:
assert isinstance(result, tuple) and len(result) == 1
return result[0]

out = inputs[where].recursively_apply(
unary_action, function_name=ufunc.__name__
)

else:
out = ak._v2._broadcasting.broadcast_and_apply(
inputs, action, behavior, allow_records=False, function_name=ufunc.__name__
)
assert isinstance(out, tuple) and len(out) == 1
out = out[0]

return ak._v2._util.wrap(out, behavior)


# def matmul_for_numba(lefts, rights, dtype):
Expand Down
37 changes: 8 additions & 29 deletions src/awkward/nplike.py
Expand Up @@ -15,38 +15,17 @@ def of(*arrays):
libs = set()
for array in arrays:
nplike = getattr(array, "nplike", None)
if isinstance(nplike, NumpyLike):
libs.add(nplike)
elif isinstance(array, numpy.ndarray):
ptr_lib = "cpu"
elif (
type(array).__module__.startswith("cupy.")
and type(array).__name__ == "ndarray"
):
ptr_lib = "cuda"
else:
ptr_lib = ak.operations.convert.kernels(array)
if ptr_lib is None:
pass
elif ptr_lib == "cpu":
libs.add("cpu")
elif ptr_lib == "cuda":
libs.add("cuda")
else:
raise ValueError(
"""structure mixes 'cpu' and 'cuda' buffers; use one of
if (
isinstance(array, ak._v2.highlevel.Array)
and isinstance(array.layout, ak._v2.contents.EmptyArray)
) or isinstance(array, ak._v2.contents.EmptyArray):
nplike = None

ak.to_kernels(array, 'cpu')
ak.to_kernels(array, 'cuda')
to obtain an unmixed array in main memory or the GPU(s)."""
+ ak._util.exception_suffix(__file__)
)
if nplike is not None:
libs.add(nplike)

if libs == set() or libs == set(["cpu"]):
if libs == set():
return Numpy.instance()
elif libs == set(["cuda"]):
return Cupy.instance()
elif len(libs) == 1:
return next(iter(libs))
else:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_0645-from-jax.py
Expand Up @@ -6,7 +6,9 @@
import numpy as np # noqa: F401
import awkward as ak # noqa: F401

pytest.mark.skip(reason="Top-down JAX tests disabled; to be replaced by bottom-up.")
pytestmark = pytest.mark.skip(
reason="Top-down JAX tests disabled; to be replaced by bottom-up."
)

jax = pytest.importorskip("jax")
jax.config.update("jax_platform_name", "cpu")
Expand Down
4 changes: 3 additions & 1 deletion tests/test_0645-jax-refcount.py
Expand Up @@ -8,7 +8,9 @@
import numpy as np # noqa: F401
import awkward as ak # noqa: F401

pytest.mark.skip(reason="Top-down JAX tests disabled; to be replaced by bottom-up.")
pytestmark = pytest.mark.skip(
reason="Top-down JAX tests disabled; to be replaced by bottom-up."
)

jax = pytest.importorskip("jax")
jax.config.update("jax_platform_name", "cpu")
Expand Down
4 changes: 3 additions & 1 deletion tests/test_0645-to-jax.py
Expand Up @@ -6,7 +6,9 @@
import numpy as np # noqa: F401
import awkward as ak # noqa: F401

pytest.mark.skip(reason="Top-down JAX tests disabled; to be replaced by bottom-up.")
pytestmark = pytest.mark.skip(
reason="Top-down JAX tests disabled; to be replaced by bottom-up."
)

jax = pytest.importorskip("jax")
jax.config.update("jax_platform_name", "cpu")
Expand Down
4 changes: 3 additions & 1 deletion tests/test_0793-jax-element-wise-ops.py
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import pytest

pytest.mark.skip(reason="Top-down JAX tests disabled; to be replaced by bottom-up.")
pytestmark = pytest.mark.skip(
reason="Top-down JAX tests disabled; to be replaced by bottom-up."
)

jax = pytest.importorskip("jax")
jax.config.update("jax_platform_name", "cpu")
Expand Down
24 changes: 24 additions & 0 deletions tests/v2/test_0086-nep13-ufunc.py
Expand Up @@ -7,10 +7,18 @@
import awkward as ak # noqa: F401


def tt(highlevel):
return ak._v2.highlevel.Array(highlevel.layout.typetracer)


def test_basic():
array = ak._v2.highlevel.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]])
assert ak.to_list(array + array) == [[2.2, 4.4, 6.6], [], [8.8, 11.0]]
assert (array + array).layout.form == (tt(array) + tt(array)).layout.form
assert ak.to_list(array * 2) == [[2.2, 4.4, 6.6], [], [8.8, 11.0]]
assert ak.to_list(2 * array) == [[2.2, 4.4, 6.6], [], [8.8, 11.0]]
assert (array * 2).layout.form == (tt(array) * 2).layout.form
assert (array * 2).layout.form == (2 * tt(array)).layout.form


def test_emptyarray():
Expand All @@ -19,6 +27,9 @@ def test_emptyarray():
assert ak.to_list(one + one) == []
assert ak.to_list(two + two) == []
assert ak.to_list(one + two) == []
assert (one + one).layout.form == (tt(one) + tt(one)).layout.form
assert (two + two).layout.form == (tt(two) + tt(two)).layout.form
assert (one + two).layout.form == (tt(one) + tt(two)).layout.form


def test_indexedarray():
Expand All @@ -30,6 +41,7 @@ def test_indexedarray():
one = ak._v2.highlevel.Array(ak._v2.contents.IndexedArray(index1, content))
two = ak._v2.highlevel.Array(ak._v2.contents.IndexedArray(index2, content))
assert ak.to_list(one + two) == [8.8, 8.8, 8.8, 8.8, 8.8]
assert (one + two).layout.form == (tt(one) + tt(two)).layout.form


def test_indexedoptionarray():
Expand All @@ -41,6 +53,7 @@ def test_indexedoptionarray():
one = ak._v2.highlevel.Array(ak._v2.contents.IndexedOptionArray(index1, content))
two = ak._v2.highlevel.Array(ak._v2.contents.IndexedOptionArray(index2, content))
assert ak.to_list(one + two) == [None, None, 8.8, None, 8.8]
assert (one + two).layout.form == (tt(one) + tt(two)).layout.form

uno = ak._v2.highlevel.Array(
ak._v2.contents.NumpyArray(np.array([2.2, 4.4, 4.4, 0.0, 8.8]))
Expand All @@ -49,7 +62,9 @@ def test_indexedoptionarray():
ak._v2.contents.NumpyArray(np.array([6.6, 4.4, 4.4, 8.8, 0.0]))
)
assert ak.to_list(uno + two) == [None, 8.8, 8.8, None, 8.8]
assert (uno + two).layout.form == (tt(uno) + tt(two)).layout.form
assert ak.to_list(one + dos) == [8.8, None, 8.8, 8.8, 8.8]
assert (one + dos).layout.form == (tt(one) + tt(dos)).layout.form


def test_regularize_shape():
Expand All @@ -64,18 +79,23 @@ def test_regulararray():
ak.to_list(array + array)
== (np.arange(2 * 3 * 5).reshape(2, 3, 5) * 2).tolist()
)
assert (array + array).layout.form == (tt(array) + tt(array)).layout.form
assert ak.to_list(array * 2) == (np.arange(2 * 3 * 5).reshape(2, 3, 5) * 2).tolist()
assert (array * 2).layout.form == (tt(array) * 2).layout.form
array2 = ak._v2.highlevel.Array(np.arange(2 * 1 * 5).reshape(2, 1, 5))
assert ak.to_list(array + array2) == ak.to_list(
np.arange(2 * 3 * 5).reshape(2, 3, 5) + np.arange(2 * 1 * 5).reshape(2, 1, 5)
)
assert (array + array2).layout.form == (tt(array) + tt(array2)).layout.form
array3 = ak._v2.highlevel.Array(np.arange(2 * 3 * 5).reshape(2, 3, 5).tolist())
assert ak.to_list(array + array3) == ak.to_list(
np.arange(2 * 3 * 5).reshape(2, 3, 5) + np.arange(2 * 3 * 5).reshape(2, 3, 5)
)
assert (array + array3).layout.form == (tt(array) + tt(array3)).layout.form
assert ak.to_list(array3 + array) == ak.to_list(
np.arange(2 * 3 * 5).reshape(2, 3, 5) + np.arange(2 * 3 * 5).reshape(2, 3, 5)
)
assert (array3 + array).layout.form == (tt(array3) + tt(array)).layout.form


def test_listarray():
Expand All @@ -95,6 +115,7 @@ def test_listarray():
[],
[110, 111],
]
assert (one + 100).layout.form == (tt(one) + 100).layout.form
assert ak.to_list(one + two) == [
[103, 104, 105, 106],
[200, 201, 202],
Expand All @@ -103,6 +124,7 @@ def test_listarray():
[],
[410, 411],
]
assert (one + two).layout.form == (tt(one) + tt(two)).layout.form
assert ak.to_list(two + one) == [
[103, 104, 105, 106],
[200, 201, 202],
Expand All @@ -111,6 +133,7 @@ def test_listarray():
[],
[410, 411],
]
assert (two + one).layout.form == (tt(two) + tt(one)).layout.form
assert ak.to_list(
one + np.array([100, 200, 300, 400, 500, 600])[:, np.newaxis]
) == [[103, 104, 105, 106], [200, 201, 202], [], [402, 403], [], [610, 611]]
Expand All @@ -125,6 +148,7 @@ def test_listarray():
[],
[110, 111],
]
assert (one + 100).layout.form == (tt(one) + 100).layout.form


def test_unionarray():
Expand Down
11 changes: 9 additions & 2 deletions tests/v2/test_1183-bugs-found-by-dask-project-2.py
Expand Up @@ -14,5 +14,12 @@ def test_example():
ttx = ak._v2.highlevel.Array(x.layout.typetracer)
tty = ak._v2.highlevel.Array(y.layout.typetracer)

assert (x + y).type == (ttx + tty).type
assert (x + np.sin(y)).type == (ttx + np.sin(tty)).type
assert (x + y).layout.form == (ttx + tty).layout.form
assert (x + np.sin(y)).layout.form == (ttx + np.sin(tty)).layout.form

x = ak._v2.highlevel.Array(
ak._v2.contents.ListArray(x.layout.starts, x.layout.stops, x.layout.content)
)
ttx = ak._v2.highlevel.Array(x.layout.typetracer)

assert (x + x).layout.form == (ttx + ttx).layout.form

0 comments on commit 5b0a273

Please sign in to comment.