Jax + Ray[Core] #44087
Assigning P0 on the assumption that basic Jax + Ray doesn't work. If that turns out not to be the case, we can downgrade.
Can't repro. I made a Jax script that creates random matrices and takes their dot product, and had no problems:

```python
import ray
import jax.random as jr
import jax.numpy as jnp


@ray.remote
class RemoteEnv:
    def __init__(self):
        self.key = jr.PRNGKey(42)

    def create_matrix(self, shape):
        # Split the key so each call draws from a fresh PRNG stream.
        new_key, self.key = jr.split(self.key)
        return jr.normal(new_key, shape)

    def dot(self, mat1, mat2):
        return jnp.dot(mat1, mat2)


ray.init()
remote_env = RemoteEnv.remote()
mat1_ref = remote_env.create_matrix.remote((20, 50))
mat2_ref = remote_env.create_matrix.remote((50, 60))
result_ref = remote_env.dot.remote(mat1_ref, mat2_ref)
mat1, mat2, result = ray.get([mat1_ref, mat2_ref, result_ref])
for m in (mat1, mat2, result):
    print(m.shape)
    print(jnp.sum(m))
```
Can you share a repro script that prints that warning?
Hi @rynewang, I just tried your example and it prints the warning for me. To clarify: the code does run, but I'm not sure how to check for deadlocks. I tried on two machines (x86 Ubuntu and M1 macOS) and both printed the warning; I also tried Python 3.10 and 3.11, with the same result on both. Are you sure you're matching my Jax / Ray versions? EDIT: I'm running Jax on CPU here; maybe that explains the difference? An old (but related) issue: google/jax#1805
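One rough way to check for a hang is to run the script in a fresh interpreter with a timeout and see whether it finishes. The sketch below is a minimal example using only the standard library; `appears_to_hang` is a hypothetical helper name, and a timeout only shows that the script did not finish in time, not that it deadlocked (an intermittent deadlock may also not reproduce on a given run).

```python
import subprocess
import sys


def appears_to_hang(source: str, timeout_s: float = 60.0) -> bool:
    """Run `source` in a fresh Python interpreter.

    Returns True if it does not finish within `timeout_s` seconds
    (a possible deadlock), False if it exits in time.
    """
    try:
        # On timeout, subprocess.run kills the child and raises TimeoutExpired.
        subprocess.run([sys.executable, "-c", source], timeout=timeout_s, check=False)
    except subprocess.TimeoutExpired:
        return True
    return False
```

For example, `appears_to_hang(open("repro.py").read(), timeout_s=120)` would run a (hypothetical) `repro.py` holding the script above and report whether it finished within two minutes.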
Repro'd the warnings on my laptop. It turns out they come from creating the processes for the Ray cluster (raylets, GCS, ...). There should not be any deadlocks with Jax, since we execve the Ray binaries right after the fork. To get rid of the warning, you can move `import ray` and `ray.init()` to the top, before the Jax imports:

```python
import ray

# Start the Ray cluster processes (raylets, GCS, ...) before Jax is imported.
ray.init()

import jax.random as jr
import jax.numpy as jnp


@ray.remote
class RemoteEnv:
    def __init__(self):
        self.key = jr.PRNGKey(42)

    def create_matrix(self, shape):
        # Split the key so each call draws from a fresh PRNG stream.
        new_key, self.key = jr.split(self.key)
        return jr.normal(new_key, shape)

    def dot(self, mat1, mat2):
        return jnp.dot(mat1, mat2)


remote_env = RemoteEnv.remote()
mat1_ref = remote_env.create_matrix.remote((20, 50))
mat2_ref = remote_env.create_matrix.remote((50, 60))
result_ref = remote_env.dot.remote(mat1_ref, mat2_ref)
mat1, mat2, result = ray.get([mat1_ref, mat2_ref, result_ref])
for m in (mat1, mat2, result):
    print(m.shape)
    print(jnp.sum(m))
```
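Why the import order matters can be sketched with the standard library alone. Assumptions, all hypothetical: JAX's message is modeled as a pre-fork hook registered with `os.register_at_fork` when Jax is imported (`_jax_like_fork_hook`, `simulate_import_jax`), and `os.fork()` followed immediately by `os.execv` stands in for Ray launching its binaries. Forking before the hook exists produces no warning; forking after it does, even though the child execs right away and never touches the parent's threads. Unix-only.

```python
import os
import sys
import warnings

_FORK_MESSAGE = (
    "os.fork() was called. os.fork() is incompatible with multithreaded code, "
    "and JAX is multithreaded, so this will likely lead to a deadlock."
)
_hook_registered = False


def _jax_like_fork_hook() -> None:
    # Hypothetical stand-in for the pre-fork hook behind JAX's warning.
    warnings.warn(_FORK_MESSAGE, RuntimeWarning)


def simulate_import_jax() -> None:
    """Hypothetical stand-in for `import jax`: register the hook once per process."""
    global _hook_registered
    if not _hook_registered:
        os.register_at_fork(before=_jax_like_fork_hook)
        _hook_registered = True


def fork_then_exec() -> int:
    """Fork, then immediately exec a fresh interpreter in the child.

    Returns the child's exit code.
    """
    pid = os.fork()
    if pid == 0:
        try:
            # Replace the process image right away, so nothing inherited from
            # the parent's threads or locks is used after the fork.
            os.execv(sys.executable, [sys.executable, "-c", "pass"])
        finally:
            os._exit(127)  # reached only if execv itself failed
    _, status = os.waitpid(pid, 0)
    return os.waitstatus_to_exitcode(status)


def _count_fork_warnings(caught: list) -> int:
    return sum(
        issubclass(w.category, RuntimeWarning) and "os.fork() was called" in str(w.message)
        for w in caught
    )


def fork_warnings_by_import_order() -> tuple:
    """Return (warnings for a fork before the hook, warnings for a fork after it).

    Only meaningful on the first call in a process: at-fork hooks cannot be
    unregistered.
    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        fork_then_exec()          # like ray.init() before `import jax`
        before = _count_fork_warnings(caught)
        simulate_import_jax()     # like `import jax`
        fork_then_exec()          # like ray.init() after `import jax`
        after = _count_fork_warnings(caught) - before
    return before, after
```

Called once in a fresh interpreter, `fork_warnings_by_import_order()` should return `(0, 1)`: the warning depends only on whether the hook was registered before the fork, not on what the child does afterwards.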
Great, thanks for the quick help! Given the label, I'll keep the issue open, but feel free to close it if I'm wrong :-)
What happened + What you expected to happen
I'm trying to run Jax inside a Ray actor, but since Jax is multithreaded, Ray's use of `os.fork` triggers this warning:
RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
Is there some way to use the 'spawn' or 'forkserver' start method instead, or am I missing something else?
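If reordering the imports is inconvenient, another option is to filter just this one warning with Python's `warnings` module. This is a minimal sketch, not a Ray or Jax setting: `ignore_jax_fork_warning` is a hypothetical helper, and it is only reasonable when you are confident every fork in the process is immediately followed by exec (which, per the maintainer's reply in this thread, is the case for Ray's own processes). The filter is process-wide, so it would also hide the warning for any other fork in the same driver.

```python
import warnings

# Regex matched against the start of the warning message (case-insensitive).
_JAX_FORK_WARNING = r"os\.fork\(\) was called"


def ignore_jax_fork_warning() -> None:
    """Hide only the RuntimeWarning whose message starts with 'os.fork() was called'.

    Hypothetical helper; call it in the driver script before ray.init().
    """
    warnings.filterwarnings("ignore", message=_JAX_FORK_WARNING, category=RuntimeWarning)
```

Because `filterwarnings` inserts the new filter at the front of the filter list, it takes precedence over broader filters installed earlier, while unrelated `RuntimeWarning`s are still shown.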
Versions / Dependencies
ray 2.9.3
Jax 0.4.25
Reproduction script
Issue Severity
High: It blocks me from completing my task.