From 8db85a72e543ab9fe09c5f0cf3c882c23d3f7537 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Fri, 8 Aug 2025 13:22:38 -0400 Subject: [PATCH 1/3] fix: correctly infer the numpy-like module --- src/vector/backends/awkward.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/vector/backends/awkward.py b/src/vector/backends/awkward.py index 83633da7..6e049e83 100644 --- a/src/vector/backends/awkward.py +++ b/src/vector/backends/awkward.py @@ -649,11 +649,10 @@ class VectorAwkward: @property def lib(self): # type:ignore[no-untyped-def] - if ( - nplike := self.layout.backend.nplike # type:ignore[attr-defined] - ) is ak._nplikes.typetracer.TypeTracer.instance(): + nplike = self.layout.backend.nplike # type:ignore[attr-defined] + if nplike is ak._nplikes.typetracer.TypeTracer.instance(): return _lib(module=numpy, nplike=nplike) - return numpy + return _lib(module=nplike._module, nplike=nplike) def _wrap_result( self: AwkwardProtocol, From 06594bd30d835fb6d03bb7283eb37b697411e999 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Fri, 8 Aug 2025 13:33:36 -0400 Subject: [PATCH 2/3] add test --- pyproject.toml | 2 ++ tests/test_issues.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 17034323..428a3a20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ optional-dependencies.dev = [ "pytest-cov>=3", "pytest-doctestplus", "sympy", + "jax", ] optional-dependencies.docs = [ "awkward>=2", @@ -85,6 +86,7 @@ optional-dependencies.test = [ "pytest-doctestplus", ] optional-dependencies.test-extras = [ + "jax", "dask_awkward", "spark-parser", 'uncompyle6; python_version == "3.8"', diff --git a/tests/test_issues.py b/tests/test_issues.py index 099534be..86faee49 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -170,3 +170,18 @@ def test_issue_463(): for transform in "xyz", "xytheta", "xyeta", "rhophiz", "rhophitheta", "rhophieta": trv = getattr(v, "to_" + transform)() assert trv.deltaangle(trv) == 0.0 + + +def test_issue_621(): + _ = pytest.importorskip("awkward") + ak = pytest.importorskip("awkward") + vector.register_awkward() + ak.jax.register_and_check() + + a = b = ak.to_backend( + ak.zip({"x": [1], "y": [1], "z": [1], "t": [1]}, with_name="Momentum4D"), "jax" + ) + + # some computation that involves broadcast_and_apply in awkward + # enough to check if it computes at all + assert (a + b).mass From af899378cc7f7fbe664f9855350a45ffcf62540b Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Fri, 8 Aug 2025 13:40:13 -0400 Subject: [PATCH 3/3] fix typo --- tests/test_issues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_issues.py b/tests/test_issues.py index 86faee49..69726b84 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -173,7 +173,7 @@ def test_issue_463(): def test_issue_621(): - _ = pytest.importorskip("awkward") + _ = pytest.importorskip("jax") ak = pytest.importorskip("awkward") vector.register_awkward() ak.jax.register_and_check()