
New mp_ctx defaults in 4.3.0+ (#6218) for M1 Macs could break parallel sampling for M1 Mac users using JAX #6362

Closed
digicosmos86 opened this issue Dec 1, 2022 · 7 comments · Fixed by #6363
Labels
bug jax macOS macOS related

Comments

@digicosmos86
Contributor

Describe the issue:

Problem:

In #6218, mp_ctx is set to fork by default if the user's system is an M1 Mac. However, I was running some JAX functions wrapped in aesara Ops, following this tutorial. When `cores` was set to 2, sampling would get stuck at 0.00%. Setting `cores` to 1, setting mp_ctx to multiprocessing.get_context("forkserver"), or downgrading to pymc 4.2.2 solves this issue.

Cause:

According to this post, JAX is internally multithreaded and does not work with the fork start method in multiprocessing. So JAX functions wrapped in aesara Ops do not work with parallel sampling in 4.3.0+ on M1 Macs, because mp_ctx is set to fork and does not change unless a context object is passed. Even passing the string "forkserver" as mp_ctx does not help, because a str argument is overridden with "fork" on M1 Macs. This makes the problem very difficult to debug.
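To make the override concrete, here is a minimal sketch that models the 4.3.0 default-resolution logic as a pure function. It is reconstructed from the code quoted in this issue rather than copied from PyMC, and the name `resolve_start_method_430` and its `system`/`processor` parameters (stand-ins for `platform.system()` and `platform.processor()`) are hypothetical.

```python
from typing import Optional


def resolve_start_method_430(
    mp_ctx: Optional[str], system: str, processor: str
) -> Optional[str]:
    """Model of the 4.3.0 behaviour: on macOS, a str or None is replaced.

    Returns the start-method name that would be handed to
    multiprocessing.get_context(); None means "platform default".
    """
    if mp_ctx is None or isinstance(mp_ctx, str):
        if system == "Darwin":
            # ARM Macs get "fork", Intel Macs get "forkserver",
            # regardless of which string the user passed in.
            mp_ctx = "fork" if processor == "arm" else "forkserver"
    return mp_ctx


# A user on an M1 Mac asks for "forkserver" but silently gets "fork".
print(resolve_start_method_430("forkserver", "Darwin", "arm"))  # fork
# On Linux the user's string is respected.
print(resolve_start_method_430("forkserver", "Linux", "x86_64"))  # forkserver
```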

Solution:

  1. Change the behavior of mp_ctx so that it at least does not force the context to fork when a str is passed on M1 Macs. Users could also be warned when their system is an M1 Mac.
  2. Potentially update this tutorial to reflect the changes in 4.3.0 (Set start method to fork for MacOs ARM devices #6218).
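A minimal sketch of solution 1: respect any string the user passes, only fill in a macOS default when `mp_ctx` is None, and warn when the ARM default of "fork" is chosen. The function name `resolve_start_method`, its `system`/`processor` parameters (stand-ins for `platform.system()` and `platform.processor()`), and the use of `warnings.warn` are illustrative assumptions, not PyMC's actual code.

```python
import warnings
from typing import Optional


def resolve_start_method(
    mp_ctx: Optional[str], system: str, processor: str
) -> Optional[str]:
    """Only choose a macOS default when the user did not pass anything."""
    if mp_ctx is None and system == "Darwin":
        if processor == "arm":
            mp_ctx = "fork"
            # Explain why parallel sampling might hang with
            # multithreaded libraries such as JAX.
            warnings.warn(
                "mp_ctx defaulted to 'fork' on macOS/ARM; libraries that start "
                "threads (e.g. JAX) may hang. Pass mp_ctx='forkserver' to override.",
                UserWarning,
            )
        else:
            mp_ctx = "forkserver"
    # A user-supplied str is returned unchanged.
    return mp_ctx


print(resolve_start_method("forkserver", "Darwin", "arm"))  # forkserver
```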

Reproducible code example:

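# WFPT_LAN is a subclass of pm.Distribution that uses a JAX function wrapped in an aesara Op as its log-likelihood function
# WFPT_LAN and the observed data obs_angle are defined elsewhere and not shown here
import pymc as pm

with pm.Model() as m_angle: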
    v = pm.Uniform("v", -3.0, 3.0)
    a = 1.5
    z = 0.5
    t = 0.5
    theta = 0.3

    rt = WFPT_LAN(
        name="rt",
        v=v,
        a=a,
        z=z,
        t=t,
        theta=theta,
        observed=obs_angle["rts"][:, 0] * obs_angle["choices"][:, 0]
    )

    trace_angle_nuts = pm.sample(cores=2, chains=2, draws=500, tune=500)

Error message:

Sampling just gets stuck when `cores` is set to 2. Setting `mp_ctx=multiprocessing.get_context("forkserver")` solves the problem.

INFO:pymc:Auto-assigning NUTS sampler...
INFO:pymc:Initializing NUTS using jitter+adapt_diag...
INFO:pymc:Multiprocess sampling (2 chains in 2 jobs)
INFO:pymc:NUTS: [v]
[INFO/worker_chain_0] child process calling self.run()
[INFO/worker_chain_1] child process calling self.run()

 0.00% [0/2000 00:00<? Sampling 2 chains, 0 divergences]
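For reference, the workaround mentioned above amounts to passing a context object instead of a string, which the 4.3.0 default-resolution branch leaves untouched. A minimal sketch; the `pm.sample` call is left as a comment because it needs the model from the example above, and "forkserver" is only available on Unix-like systems (`get_context` raises `ValueError` elsewhere).

```python
import multiprocessing

# An explicit context object skips the str/None default-resolution branch.
ctx = multiprocessing.get_context("forkserver")
print(ctx.get_start_method())  # forkserver

# Inside the model context from the reproducible example:
# trace_angle_nuts = pm.sample(cores=2, chains=2, draws=500, tune=500, mp_ctx=ctx)
```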

PyMC version information:

aesara==2.8.7, aesara==2.8.6
pymc==4.3.0, pymc==4.4.0

pymc==4.2.2 does not have the same issue.

Context for the issue:

No response

@digicosmos86 digicosmos86 changed the title New mp_ctx defaults in 4.3.0+ (#6218) for M1 Macs could break parallel sampling New mp_ctx defaults in 4.3.0+ (#6218) for M1 Macs could break parallel sampling for M1 max users using Jax Dec 1, 2022
@digicosmos86 digicosmos86 changed the title New mp_ctx defaults in 4.3.0+ (#6218) for M1 Macs could break parallel sampling for M1 max users using Jax New mp_ctx defaults in 4.3.0+ (#6218) for M1 Macs could break parallel sampling for M1 Mac users using JAX Dec 1, 2022
@ricardoV94
Member

ricardoV94 commented Dec 2, 2022

Other users couldn't sample in parallel with the older default method. Can you sample in parallel with any of the mp_ctx methods?

If not, you may be stuck with sequential sampling (cores=1) or have to rely on sample_numpyro_nuts for parallel/vectorized sampling.

https://www.pymc.io/projects/docs/en/stable/api/generated/pymc.sampling.jax.sample_numpyro_nuts.html#pymc.sampling.jax.sample_numpyro_nuts

@digicosmos86
Contributor Author

I am not asking to revert to the older defaults. I know my case is a fringe one, but I think some debugging messages or warnings would be helpful, because the current code is a bit opaque about two things:

  1. The mp_ctx parameter of pm.sample accepts a "multiprocessing.context.BaseContext" object according to the documentation. However, it also implicitly accepts a str.
  2. If the user does supply a str to mp_ctx, the code silently overrides it with "fork" if the user is on an M1 Mac, without producing any warning.
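The two forms in point 1 can be told apart with an `isinstance` check against `multiprocessing.context.BaseContext`, the base class of the objects returned by `multiprocessing.get_context()`. A minimal sketch of a validation helper that makes both accepted forms explicit; the name `check_mp_ctx` is hypothetical and not part of PyMC.

```python
import multiprocessing
from multiprocessing.context import BaseContext
from typing import Optional, Union


def check_mp_ctx(mp_ctx: Optional[Union[str, BaseContext]]) -> str:
    """Report which kind of mp_ctx value was passed; reject anything else."""
    if mp_ctx is None:
        return "default"
    if isinstance(mp_ctx, str):
        # A start-method name such as "fork", "spawn" or "forkserver".
        return "str"
    if isinstance(mp_ctx, BaseContext):
        # Already a context object, e.g. multiprocessing.get_context("spawn").
        return "context"
    raise TypeError(
        f"mp_ctx must be None, a str or a BaseContext, got {type(mp_ctx)!r}"
    )


print(check_mp_ctx("forkserver"))  # str
print(check_mp_ctx(multiprocessing.get_context("spawn")))  # context
```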

So I think adding some warnings or debug logs would be helpful here. Something like:

# Excerpt from the body of pm.sample; mp_ctx is the user-supplied argument
import logging
import multiprocessing
import platform

logger = logging.getLogger("pymc")

if mp_ctx is None or isinstance(mp_ctx, str):
    # Closes issue https://github.com/pymc-devs/pymc/issues/3849
    # Related issue https://github.com/pymc-devs/pymc/issues/5339
    if isinstance(mp_ctx, str):
        logger.warning("A str was passed to mp_ctx. We recommend passing a multiprocessing context created with the multiprocessing.get_context() method.")

    if platform.system() == "Darwin":
        if platform.processor() == "arm":
            mp_ctx = "fork"
            logger.debug("mp_ctx is set to 'fork' for MacOS with ARM architecture. This might cause unexpected behavior with JAX, which is inherently multithreaded.")
        else:
            mp_ctx = "forkserver"

    mp_ctx = multiprocessing.get_context(mp_ctx)

Happy to submit a PR if you think this is a good idea.

BTW, setting mp_ctx to `multiprocessing.get_context("forkserver")` does solve this problem, but it took me 4 days to figure out why. That's why I think some debug messages would have been very helpful.

@ricardoV94
Member

Oh, we shouldn't override the user-specified str; we should only specify a default when it's None!

@ricardoV94
Member

ricardoV94 commented Dec 2, 2022

Something like:

import multiprocessing
import platform

if mp_ctx is None or isinstance(mp_ctx, str):
    if mp_ctx is None and platform.system() == "Darwin":
        # Closes issue https://github.com/pymc-devs/pymc/issues/3849
        # Related issue https://github.com/pymc-devs/pymc/issues/5339
        if platform.processor() == "arm":
            mp_ctx = "fork"
        else:
            mp_ctx = "forkserver"
    # A user-supplied str is passed through unchanged
    mp_ctx = multiprocessing.get_context(mp_ctx)
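Because this branch depends on `platform.system()` and `platform.processor()`, it can be exercised on any Linux or macOS machine by patching those two functions with `unittest.mock`. A minimal sketch, with the snippet above wrapped in a hypothetical `resolve_mp_ctx` function so it can be called repeatedly (the "fork" context is not available on Windows):

```python
import multiprocessing
import platform
from unittest import mock


def resolve_mp_ctx(mp_ctx=None):
    # The proposed fix: only pick a macOS default when mp_ctx is None.
    if mp_ctx is None or isinstance(mp_ctx, str):
        if mp_ctx is None and platform.system() == "Darwin":
            if platform.processor() == "arm":
                mp_ctx = "fork"
            else:
                mp_ctx = "forkserver"
        mp_ctx = multiprocessing.get_context(mp_ctx)
    return mp_ctx


# Pretend to be an Apple Silicon Mac, whatever the host actually is.
with mock.patch("platform.system", return_value="Darwin"), \
        mock.patch("platform.processor", return_value="arm"):
    print(resolve_mp_ctx().get_start_method())  # fork
    print(resolve_mp_ctx("spawn").get_start_method())  # spawn
```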

@digicosmos86
Contributor Author

That sounds good to me too 👍. I still think adding a debug log message when defaulting to fork would be helpful for users following this awesome official tutorial (you are the one who wrote it @ricardoV94 😄).

@ricardoV94
Member


Yeah, debug logs sound fine, since this logic is all pretty hacky anyway.

@digicosmos86
Contributor Author

Great! I opened PR #6363.
