Skip to content

Commit

Permalink
style: pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Nov 7, 2022
1 parent f59e68f commit de70891
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/awkward/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def arrayptr(x):
if isinstance(x, int):
return x
elif isinstance(self.nplike, ak.nplikes.Cupy):
return x.data
return x.data
else:
return x.ctypes.data

Expand Down
28 changes: 14 additions & 14 deletions tests-cuda/test_1809-cuda-jit.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import cupy
import numpy as np # noqa: F401
import pytest # noqa: F401

import awkward as ak # noqa: F401
import cupy

numba = pytest.importorskip("numba")

from numba import cuda, types # noqa: F401, E402
from numba import config, cuda, types # noqa: F401, E402
from numba.core.typing.typeof import typeof, typeof_impl # noqa: F401, E402

from numba import config
config.CUDA_LOW_OCCUPANCY_WARNINGS = False
config.CUDA_WARN_ON_IMPLICIT_COPY = False

Expand All @@ -21,22 +20,25 @@

ak.numba.register_and_check()


class ArrayViewArgHandler:
def prepare_args(self, ty, val, **kwargs):
print(repr(val), type(val))
if isinstance(val, ak.Array):
if isinstance(val, ak.Array):
return ty, val
elif isinstance(val, ak._connect.numba.arrayview.ArrayView):
return types.uint64, val._numbaview.lookup.arrayptrs
else:
return ty, val


array_view_arg_handler = ArrayViewArgHandler()

# FIXME: configure the blocks
# threadsperblock = 32
# blockspergrid = 128


@cuda.jit(extensions=[array_view_arg_handler])
def swallow(array):
pass
Expand All @@ -62,18 +64,22 @@ def digest2(array):
tmp = array[0]
return tmp, tmp, array[0]


def test_numpy_array_1d():
nparray = np.array([0, 1, 2, 3], dtype=int)
swallow[1, 1](nparray)


def test_to_numy_array_1d():
akarray = ak.Array([0, 1, 2, 3])
swallow[1, 1](ak.to_numpy(akarray))

#def test_array_1d():

# def test_array_1d():
# akarray = ak.Array([0, 1, 2, 3])
# swallow[1, 1](akarray))



def test_array_njit():
@numba.njit
def something(array):
Expand All @@ -91,9 +97,7 @@ def something(array):
if index > len(array):
return

akarray = ak.Array(
[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9], backend="cuda"
)
akarray = ak.Array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9], backend="cuda")
something[1, 1](ak.to_cupy(akarray))


Expand All @@ -104,9 +108,5 @@ def something(array):
if index > len(array):
return

akarray = ak.Array(
[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
)
akarray = ak.Array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
something[1, 1](ak.to_numpy(akarray))


0 comments on commit de70891

Please sign in to comment.