Add context to help fix `Arrays that are not fully addressable` errors #37

Open
dlwh opened this issue Sep 7, 2023 · 5 comments

dlwh commented Sep 7, 2023

    raise RuntimeError(
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Running operations on `Array`s that are not fully addressable by this process (i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It’s very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context manager.

This one doesn't show up inside jit (by construction), so it's a bit harder to intercept. Maybe just a FAQ entry?
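
Since the error is raised outside jit, one possible way to add context is to check arrays before operating on them. A rough sketch, not anything Levanter or Haliax currently ships: `find_non_addressable` is a hypothetical helper name, and it relies only on `jax.Array.is_fully_addressable` and the `jax.tree_util` path utilities. In a single-process run it should always return an empty list.

```python
import jax
import jax.numpy as jnp


def find_non_addressable(tree):
    """Return key paths of leaves whose data is not fully addressable by this process.

    Hypothetical helper, intended to be called outside of jit.
    """
    flagged = []
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        # Tracers have no concrete placement, so skip them.
        if isinstance(leaf, jax.core.Tracer):
            continue
        if isinstance(leaf, jax.Array) and not leaf.is_fully_addressable:
            flagged.append(jax.tree_util.keystr(path))
    return flagged


# Illustrative pytree; the names are made up for this example.
state = {"params": {"w": jnp.ones((4, 4))}, "step": jnp.zeros(())}
flagged = find_non_addressable(state)
print(flagged)
```

Reporting the key paths would at least tell the user which part of the state is sharded across processes before the eager operation fails.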

dlwh commented Sep 7, 2023

also

  File "/home/dlwh/levanter/src/levanter/trainer.py", line 204, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/dlwh/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 333, in f
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "/home/dlwh/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 679, in _array_mlir_constant_handler
    return mlir.ir_constants(val._value,
  File "/home/dlwh/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 524, in _value
    raise RuntimeError("Fetching value for `jax.Array` that spans "
RuntimeError: Fetching value for `jax.Array` that spans non-addressable devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` for this use case.

This one is because arrays snuck into a closure that shouldn't have.
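
A minimal single-process sketch of the closure pattern and the usual fix, which is to pass the array as an argument. The names `weights`, `scale_closure` and `scale_explicit` are made up for illustration. The assumption here is that closed-over arrays appear in `.consts` of the result of `jax.make_jaxpr`, which holds on the JAX versions from around the time of this issue and may differ on others. `make_jaxpr` only traces, so it can serve as a check before anything gets lowered.

```python
import jax
import jax.numpy as jnp

weights = jnp.arange(4.0)  # stands in for a (possibly sharded) model array


def scale_closure(x):
    # `weights` is captured from the enclosing scope, so tracing treats it as a constant.
    return x * weights


def scale_explicit(x, w):
    # `w` is a traced argument, so no array value has to be baked into the program.
    return x * w


x = jnp.ones(4)
out_closure = jax.jit(scale_closure)(x)
out_explicit = jax.jit(scale_explicit)(x, weights)

# Trace only (no lowering) and list whatever arrays each function closed over.
captured_closure = jax.make_jaxpr(scale_closure)(x).consts
captured_explicit = jax.make_jaxpr(scale_explicit)(x, weights).consts
print(len(captured_closure), len(captured_explicit))
```

Both versions compute the same result on one host. The difference only matters when the captured array spans non-addressable devices, because a constant has to be fetched by value and an argument does not.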

@ASKabalan

I am having the same issue.

Is there any way we can debug this error in a jitted function?

dlwh commented Jul 1, 2024 via email

@ASKabalan

I created this discussion with an MWE:
jax-ml/jax#22212
It is not inside a JIT in this example, but in my code I call this `shard_map` from a jitted function.

dlwh commented Jul 1, 2024

Try again with JAX nightly? Some improvements were just made.
