
Custom behaviors plus jax leading to lookup in wrong spot #2603

Closed
alexander-held opened this issue Aug 2, 2023 · 5 comments · Fixed by #3025
Labels
autodiff Issue related to auto-differentiation bug The problem described is something that must be fixed

Comments

@alexander-held
Member

Version of Awkward Array

ce63bf2

Description and code to reproduce

This is a partner issue to CoffeaTeam/coffea#874, since this may be more of an Awkward problem than a coffea one. I am trying to combine custom behaviors (defined by coffea) with the jax backend of Awkward. The reproducer below results in:

AttributeError: module 'jax.numpy' has no attribute '_mass2_kernel'

Reproducer:

import awkward as ak
from coffea.nanoevents.methods import candidate
import numpy as np
import uproot

ak.jax.register_and_check()
ak.behavior.update(candidate.behavior)

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

with uproot.open(ttbar_file) as f:
    arr = f["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
                              "Electron_mass", "Electron_charge"])

px = arr.Electron_pt * np.cos(arr.Electron_phi)
py = arr.Electron_pt * np.sin(arr.Electron_phi)
pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)

evtfilter = ak.num(arr["Electron_pt"]) >= 2

els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
              "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
els = ak.to_backend(els, "jax")

(els[:, 0] + els[:, 1]).mass

Using the "Momentum4D" behavior from vector (after vector.register_awkward()) works. Skipping the backend conversion to jax also makes this work.

Full trace
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 32
    28 els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
    29               "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
    30 els = ak.to_backend(els, "jax")
---> 32 (els[:, 0] + els[:, 1]).mass

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/highlevel.py:1097, in Array.__getattr__(self, where)
1061 """
1062 Args:
1063     where (str): Attribute name to lookup
(...)
1094 *assigned* as attributes. See #ak.Array.__setitem__ for more.
1095 """
1096 if hasattr(type(self), where):
-> 1097     return super().__getattribute__(where)
1098 else:
1099     if where in self._layout.fields:

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/coffea/nanoevents/methods/vector.py:531, in LorentzVector.mass(self)
    525 @property
    526 def mass(self):
    527     r"""Invariant mass (+, -, -, -)
    528 
    529     :math:`\sqrt{t^2-x^2-y^2-z^2}`
    530     """
--> 531     return numpy.sqrt(self.mass2)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/highlevel.py:1097, in Array.__getattr__(self, where)
1061 """
1062 Args:
1063     where (str): Attribute name to lookup
(...)
1094 *assigned* as attributes. See #ak.Array.__setitem__ for more.
1095 """
1096 if hasattr(type(self), where):
-> 1097     return super().__getattribute__(where)
1098 else:
1099     if where in self._layout.fields:

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/coffea/nanoevents/methods/vector.py:523, in LorentzVector.mass2(self)
    520 @property
    521 def mass2(self):
    522     """Squared `mass`"""
--> 523     return _mass2_kernel(self.t, self.x, self.y, self.z)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/highlevel.py:1349, in Array.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
1347 name = f"{type(ufunc).__module__}.{ufunc.__name__}.{method!s}"
1348 with ak._errors.OperationErrorContext(name, inputs, kwargs):
-> 1349     return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_connect/numpy.py:459, in array_ufunc(ufunc, method, inputs, kwargs)
    450     out = ak._do.recursively_apply(
    451         inputs[where],
    452         unary_action,
(...)
    455         allow_records=False,
    456     )
    458 else:
--> 459     out = ak._broadcasting.broadcast_and_apply(
    460         inputs, action, behavior, allow_records=False, function_name=ufunc.__name__
    461     )
    462     assert isinstance(out, tuple) and len(out) == 1
    463     out = out[0]

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:1022, in broadcast_and_apply(inputs, action, behavior, depth_context, lateral_context, allow_records, left_broadcast, right_broadcast, numpy_to_regular, regular_to_jagged, function_name, broadcast_parameters_rule)
1020 backend = backend_of(*inputs)
1021 isscalar = []
-> 1022 out = apply_step(
1023     backend,
1024     broadcast_pack(inputs, isscalar),
1025     action,
1026     0,
1027     depth_context,
1028     lateral_context,
1029     behavior,
1030     {
1031         "allow_records": allow_records,
1032         "left_broadcast": left_broadcast,
1033         "right_broadcast": right_broadcast,
1034         "numpy_to_regular": numpy_to_regular,
1035         "regular_to_jagged": regular_to_jagged,
1036         "function_name": function_name,
1037         "broadcast_parameters_rule": broadcast_parameters_rule,
1038     },
1039 )
1040 assert isinstance(out, tuple)
1041 return tuple(broadcast_unpack(x, isscalar, backend) for x in out)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:1001, in apply_step(backend, inputs, action, depth, depth_context, lateral_context, behavior, options)
    999     return result
1000 elif result is None:
-> 1001     return continuation()
1002 else:
1003     raise AssertionError(result)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:974, in apply_step.<locals>.continuation()
    972 # Any non-string list-types?
    973 elif any(x.is_list and not is_string_like(x) for x in contents):
--> 974     return broadcast_any_list()
    976 # Any RecordArrays?
    977 elif any(x.is_record for x in contents):

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:622, in apply_step.<locals>.broadcast_any_list()
    619         nextinputs.append(x)
    620         nextparameters.append(NO_PARAMETERS)
--> 622 outcontent = apply_step(
    623     backend,
    624     nextinputs,
    625     action,
    626     depth + 1,
    627     copy.copy(depth_context),
    628     lateral_context,
    629     behavior,
    630     options,
    631 )
    632 assert isinstance(outcontent, tuple)
    633 parameters = parameters_factory(nextparameters, len(outcontent))

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:987, in apply_step(backend, inputs, action, depth, depth_context, lateral_context, behavior, options)
    980     else:
    981         raise ValueError(
    982             "cannot broadcast: {}{}".format(
    983                 ", ".join(repr(type(x)) for x in inputs), in_function(options)
    984             )
    985         )
--> 987 result = action(
    988     inputs,
    989     depth=depth,
    990     depth_context=depth_context,
    991     lateral_context=lateral_context,
    992     continuation=continuation,
    993     behavior=behavior,
    994     backend=backend,
    995     options=options,
    996 )
    998 if isinstance(result, tuple) and all(isinstance(x, Content) for x in result):
    999     return result

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_connect/numpy.py:400, in array_ufunc.<locals>.action(inputs, **ignore)
    397         args.append(x)
    399 # Give backend a chance to change the ufunc implementation
--> 400 impl = backend.prepare_ufunc(ufunc)
    402 # Invoke ufunc
    403 result = impl(*args, **kwargs)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_backends/jax.py:50, in JaxBackend.prepare_ufunc(self, ufunc)
    47 def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
    48     from awkward._connect.jax import get_jax_ufunc
---> 50     return get_jax_ufunc(ufunc)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_connect/jax/__init__.py:8, in get_jax_ufunc(ufunc)
    7 def get_jax_ufunc(ufunc):
----> 8     return getattr(jax.numpy, ufunc.__name__)

File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
    51   warnings.warn(message, DeprecationWarning, stacklevel=2)
    52   return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.numpy' has no attribute '_mass2_kernel'

This error occurred while calling

    numpy._mass2_kernel.__call__(
        <Array [192.54099, 132.60043, ..., 142.34727] type='5 * float32'>
        <Array [5.5301285, -46.949707, ..., -58.96562] type='5 * float32'>
        <Array [-70.93436, -12.467135, ..., -31.510773] type='5 * float32'>
        <Array [156.38907, -75.47587, ..., -115.080734] type='5 * float32'>
    )
@alexander-held alexander-held added the bug (unverified) The problem described would be a bug, but needs to be triaged label Aug 2, 2023
@jpivarski jpivarski added the autodiff Issue related to auto-differentiation label Oct 2, 2023
@alexander-held
Member Author

alexander-held commented Oct 5, 2023

I made some progress understanding what causes this to happen. Here is a significantly simplified reproducer:

import awkward as ak
import numba
import numpy as np

behavior = {}

ak.jax.register_and_check()

USE_JAX = True  # set to False to run this successfully
input_arr = ak.Array([1.0], backend=("jax" if USE_JAX else "cpu"))


@numba.vectorize(
    [
        numba.float32(numba.float32, numba.float32),
        numba.float64(numba.float64, numba.float64),
    ]
)
def _some_kernel(x, y):
    return x * x + y * y


@ak.mixin_class(behavior)
class SomeClass:
    @property
    def some_kernel(self):
        return _some_kernel(self.x, self.y)


ak.behavior.update(behavior)

arr = ak.zip({"x": input_arr, "y": input_arr}, with_name="SomeClass")

arr.some_kernel  # crashes with Jax

This results in

AttributeError: module 'jax.numpy' has no attribute '_some_kernel'

This error occurred while calling

    numpy._some_kernel.__call__(
        <Array [1.0] type='1 * float32'>
        <Array [1.0] type='1 * float32'>
    )

The code runs successfully with USE_JAX = False. It also works when the @numba.vectorize decorator is removed from the kernel. I suspect Numba and JAX are simply incompatible here. If that is the case and this setup is not expected to work, perhaps the error message could be improved to say so.
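The observation that removing `@numba.vectorize` fixes the example is consistent with the name-based lookup in the first trace (`getattr(jax.numpy, ufunc.__name__)`): without the decorator, `x * x + y * y` on Awkward arrays runs as separate `multiply` and `add` ufunc calls, and `jax.numpy` has functions with those names. The sketch below assumes that dispatch behavior and writes the kernel explicitly with standard NumPy ufuncs. It runs on plain NumPy arrays so it needs neither Awkward nor JAX; `some_kernel_numpy` is a hypothetical name.

```python
import numpy as np


def some_kernel_numpy(x, y):
    # Same arithmetic as _some_kernel, built only from standard NumPy ufuncs.
    # On an Awkward array each call is dispatched on its own, and "multiply"
    # and "add" are names that jax.numpy also provides.
    return np.add(np.multiply(x, x), np.multiply(y, y))


x = np.array([1.0, 2.0])
y = np.array([3.0, 4.0])
print(some_kernel_numpy(x, y))  # [10. 20.]
```

Inside the mixin, `some_kernel` would then return `some_kernel_numpy(self.x, self.y)`. This gives up Numba compilation on the CPU backend in exchange for relying only on ufuncs that the JAX backend can map by name.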

@agoose77
Collaborator

agoose77 commented Oct 5, 2023

Right: at the moment, users can't override ufuncs for JAX, so Numba ufuncs raise exceptions. Numba functions wouldn't be differentiable via JAX; we'd need to substitute a JAX implementation.

@jpivarski jpivarski added this to Unprioritized in Finalization Jan 19, 2024
@jpivarski
Member

@Saransh-cpp, this is another one that you should self-assign (anything with label autodiff, actually).

@jpivarski jpivarski removed this from Unprioritized in Finalization Jan 19, 2024
@Saransh-cpp Saransh-cpp self-assigned this Jan 20, 2024
@Saransh-cpp Saransh-cpp added bug The problem described is something that must be fixed and removed bug (unverified) The problem described would be a bug, but needs to be triaged labels Feb 12, 2024
@Saransh-cpp
Collaborator

The coffea issue will be solved once coffea's own vector module is removed and scikit-hep/vector is recommended to users (see CoffeaTeam/coffea#874 (comment)).

For the issue on the Awkward end, I am a bit confused about what we want the ideal behavior to look like:

  • Do we want users to refrain from using JAX and Numba together, or would we like to support that combination (if it is possible)? Or do we just want better error handling here?
  • Do we want to recommend JAX's jit mechanism to users who plan to differentiate their functions? I have not tried jax.jit with Awkward, but it might just work.

Thanks!

@jpivarski
Member

jax.jit will not work in Awkward; that was something we determined very early on. Looking at it, it was clear that it would never work, because so many of the Awkward kernels need to determine a new buffer's length from an old buffer's values, and that is forbidden in JAX. JAX users find it hard enough not to be able to apply a boolean mask in compiled JAX (because the output array length depends on how many True values are in the mask), but Awkward has to do that sort of thing a lot.
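The data-dependent-length problem can be seen without JAX. In the NumPy sketch below (illustrative values), the length of a masked array and the length of a jagged array's flattened content are both read from array *values*, not from input shapes. A tracing compiler such as `jax.jit` needs those lengths before any values exist. The offsets layout is a simplified assumption modeled on Awkward's `ListOffsetArray`.

```python
import numpy as np

x = np.array([1.0, -2.0, 3.0, -4.0, 5.0])
mask = x > 0

# Boolean masking: the result length equals the number of True values,
# so it is known only after looking at the data.
selected = x[mask]
print(selected.shape)  # (3,)

# Jagged layout in the style of an offsets array: per-list counts become
# offsets, and the content buffer's length is the *value* of the last offset.
counts = np.array([2, 0, 3])
offsets = np.concatenate([[0], np.cumsum(counts)])
print(offsets)  # [0 2 2 5]
content_length = int(offsets[-1])
print(content_length)  # 5
```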

So with JAX's JIT compilation off the table, the alternative of compiling in Numba is still there, but Numba does not propagate derivatives through its compiled code. From January 2022 until January 2023 (while I was following it), @ludgerpaehler was trying to compile through Numba using Enzyme, an autograd tool for LLVM code. I don't know the current state of that project, but it would allow us to connect JAX's non-JITted autograd with Numba's JITted autograd. Users already have to switch programming models between non-JIT and JIT, but in principle, it's possible to preserve derivatives across that boundary.
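Enzyme would differentiate the compiled code automatically. A manual alternative, which is a different technique, is to pair an opaque compiled kernel with a hand-written vector-Jacobian product. That is the kind of pairing JAX's `jax.custom_vjp` is meant to express. The NumPy sketch below only shows the math for the `mass2` kernel from the first trace (t*t - x*x - y*y - z*z) and checks it against central finite differences. The function names are hypothetical, and wiring this into JAX or Awkward is not shown.

```python
import numpy as np


def mass2_fwd(t, x, y, z):
    # Stand-in for an opaque compiled kernel (e.g. Numba) that autodiff
    # cannot trace through.
    return t * t - x * x - y * y - z * z


def mass2_vjp(t, x, y, z, g):
    # Hand-written vector-Jacobian product: the partial derivatives of
    # t^2 - x^2 - y^2 - z^2 are (2t, -2x, -2y, -2z), scaled by the cotangent g.
    return 2 * t * g, -2 * x * g, -2 * y * g, -2 * z * g


def finite_difference_grad(f, args, eps=1e-6):
    # Central differences, one argument at a time, for checking the VJP.
    grads = []
    for i in range(len(args)):
        up = list(args)
        down = list(args)
        up[i] = up[i] + eps
        down[i] = down[i] - eps
        grads.append((f(*up) - f(*down)) / (2 * eps))
    return grads


args = (np.array([10.0]), np.array([1.0]), np.array([2.0]), np.array([3.0]))
analytic = mass2_vjp(*args, g=np.ones(1))
numeric = finite_difference_grad(mass2_fwd, args)
print(all(np.allclose(a, n) for a, n in zip(analytic, numeric)))  # True
```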

4 participants