Jax + Ray[Core] #44087

Closed
GJBoth opened this issue Mar 18, 2024 · 5 comments
Labels: bug, core, P1.5, stability

Comments


GJBoth commented Mar 18, 2024

What happened + What you expected to happen

I'm trying to run Jax inside a Ray actor, but since Jax is multithreaded it emits the following warning because of Ray's use of os.fork:

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.

Is there some way to use 'spawn' or 'forkserver' strategy, or am I missing something else?

Versions / Dependencies

ray 2.9.3
Jax 0.4.25

Reproduction script

import ray
import jax.random as jr

@ray.remote
class RemoteEnv(object):
    def __init__(self):
        key = jr.PRNGKey(42)
        x = jr.normal(key, (1,))

Issue Severity

High: It blocks me from completing my task.

GJBoth added the bug and triage labels Mar 18, 2024
edoakes added the core label Mar 18, 2024
jjyao removed their assignment Mar 18, 2024
jjyao added the P0 label and removed the triage label Mar 18, 2024
Contributor

jjyao commented Mar 18, 2024

Assigning P0 on the assumption that basic Jax + Ray doesn't work. If that turns out not to be the case, we can downgrade.

Contributor

rynewang commented Mar 20, 2024

Can't repro. I wrote a Jax script that creates random matrices and takes their dot product, and had no problems.

import ray
import jax.random as jr
import jax.numpy as jnp

@ray.remote
class RemoteEnv(object):
    def __init__(self):
        self.key = jr.PRNGKey(42)

    def create_matrix(self, shape):
        new_key, self.key = jr.split(self.key)
        return jr.normal(new_key, shape)

    def dot(self, mat1, mat2):
        result = jnp.dot(mat1, mat2)
        return result

import numpy as np
ray.init()
remote_env = RemoteEnv.remote()
mat1_ref = remote_env.create_matrix.remote((20,50))
mat2_ref = remote_env.create_matrix.remote((50, 60))
result_ref = remote_env.dot.remote(mat1_ref, mat2_ref)
mat1, mat2, result = ray.get([mat1_ref, mat2_ref, result_ref])

for m in (mat1, mat2, result):
    print(f"{m.shape}")
    print(jnp.sum(m))

result:

(20, 50)
-36.139606
(50, 60)
37.13751
(20, 60)
-41.57849

Can you share a repro script that prints that warning?

Author

GJBoth commented Mar 20, 2024

Hi @rynewang I just tried with your example and it prints the warning for me. Just to clarify: the code does run - but I'm not sure how to check for deadlocks.

I tried on two machines (x86 ubuntu and m1 osx) and both gave the error - I also tried with both python 3.10 and 3.11 and again, errors on both. Are you sure you're matching my jax / ray versions? EDIT: I'm running Jax on CPU here - maybe that explains the difference?

An old (but related) issue: google/jax#1805
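
On the question of how to check for deadlocks: one generic option is a watchdog that dumps every thread's traceback when a blocking call takes longer than expected. The sketch below uses only the standard library's faulthandler module. The names hang_watchdog and run_with_watchdog are hypothetical helpers, and time.sleep stands in for a blocking call such as ray.get(...). It only observes the process it runs in (here, the driver), so it is a starting point rather than a full deadlock detector.

```python
import faulthandler
import tempfile
import time
from contextlib import contextmanager


@contextmanager
def hang_watchdog(timeout_s, log_file):
    """Dump all thread tracebacks to log_file if the block outlives timeout_s."""
    faulthandler.dump_traceback_later(timeout_s, exit=False, file=log_file)
    try:
        yield
    finally:
        # Disarm the watchdog once the block has finished.
        faulthandler.cancel_dump_traceback_later()


def run_with_watchdog(work, timeout_s):
    """Run work() under the watchdog and return whatever the watchdog logged."""
    with tempfile.TemporaryFile(mode="w+") as log_file:
        with hang_watchdog(timeout_s, log_file):
            work()
        log_file.seek(0)
        return log_file.read()


# time.sleep is a stand-in for a call that might hang, e.g. ray.get(...).
fast_log = run_with_watchdog(lambda: time.sleep(0.01), timeout_s=5.0)
slow_log = run_with_watchdog(lambda: time.sleep(0.5), timeout_s=0.1)
print("fast call logged a dump:", bool(fast_log))
print("slow call logged a dump:", bool(slow_log))
```

A dump that shows a thread parked on a lock for the whole timeout is a hint that something is stuck; an empty log means the call returned in time.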

Contributor

rynewang commented Mar 20, 2024

Repro'd the warnings on my laptop.

The warnings turn out to come from the process creation for the Ray cluster (raylets, GCS, ...). This should not cause any deadlocks with Jax, since we execve the Ray binaries right after the fork.
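
To illustrate why fork followed immediately by exec is considered safe: after a fork, only the calling thread exists in the child, so locks held by other threads would stay locked forever, but exec replaces the child's whole memory image and discards that inherited state. The following is a minimal POSIX-only sketch of the pattern, not Ray's actual startup code; fork_then_exec is a hypothetical helper name.

```python
import os
import sys


def fork_then_exec(argv):
    """Fork, replace the child with a new program, return (exit code, stdout)."""
    read_fd, write_fd = os.pipe()
    pid = os.fork()
    if pid == 0:
        # Child: do as little as possible before exec.
        try:
            os.close(read_fd)
            os.dup2(write_fd, 1)  # route the new program's stdout into the pipe
            os.execv(argv[0], argv)  # on success this call never returns
        finally:
            os._exit(127)  # reached only if exec failed
    # Parent: read the child's output, then reap it.
    os.close(write_fd)
    with os.fdopen(read_fd, "rb") as pipe:
        output = pipe.read()
    _, status = os.waitpid(pid, 0)
    return os.waitstatus_to_exitcode(status), output.decode()


exit_code, output = fork_then_exec([sys.executable, "-c", "print('exec ok')"])
print(exit_code, output.strip())
```

The child here never touches Python-level locks or library state between fork and exec, which is the property the comment above relies on.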

To get rid of the warning, move ray.init() above any jax imports, like this (see line 2):

import ray
ray.init()

import jax.random as jr
import jax.numpy as jnp

@ray.remote
class RemoteEnv(object):
    def __init__(self):
        self.key = jr.PRNGKey(42)

    def create_matrix(self, shape):
        new_key, self.key = jr.split(self.key)
        return jr.normal(new_key, shape)

    def dot(self, mat1, mat2):
        result = jnp.dot(mat1, mat2)
        return result

import numpy as np

remote_env = RemoteEnv.remote()
mat1_ref = remote_env.create_matrix.remote((20,50))
mat2_ref = remote_env.create_matrix.remote((50, 60))
result_ref = remote_env.dot.remote(mat1_ref, mat2_ref)
mat1, mat2, result = ray.get([mat1_ref, mat2_ref, result_ref])

for m in (mat1, mat2, result):
    print(f"{m.shape}")
    print(jnp.sum(m))
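
Why the import order matters can be sketched with the standard library alone. The sketch assumes the warning is emitted from a hook registered with os.register_at_fork when the library is imported (an assumption about Jax's internals, not something confirmed in this thread). Such a hook only fires for forks that happen after registration. HOOK_MESSAGE, _warn_before_fork and fork_and_count_warnings are hypothetical names; POSIX only.

```python
import os
import warnings

HOOK_MESSAGE = "fork() called after the multithreaded library was loaded"


def _warn_before_fork():
    # Runs in the parent just before every later os.fork() call.
    warnings.warn(HOOK_MESSAGE, RuntimeWarning)


def fork_and_count_warnings():
    """Fork a short-lived child and count how many hook warnings fired."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pid = os.fork()
        if pid == 0:
            os._exit(0)  # child: exit immediately, skipping cleanup handlers
        os.waitpid(pid, 0)  # parent: reap the child
    return sum(HOOK_MESSAGE in str(w.message) for w in caught)


# Fork before the hook exists (like calling ray.init() before importing jax).
warnings_before_import = fork_and_count_warnings()

# Stand-in for the import: the library registers its hook.
os.register_at_fork(before=_warn_before_fork)

# Fork afterwards (like importing jax first and calling ray.init() later).
warnings_after_import = fork_and_count_warnings()

print(warnings_before_import, warnings_after_import)
```

The first fork produces no warning and the second produces one, which matches the behaviour seen when ray.init() is moved above the jax imports.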

jjyao added the P1.5 label and removed the P0 label Mar 20, 2024
Author

GJBoth commented Mar 22, 2024

Great, thanks for the quick help! Given the label I'll keep the issue open, but feel free to close it if I'm wrong :-)


4 participants