
Embarrassing parallel inference #494

Closed
ramav87 opened this issue Dec 16, 2019 · 10 comments
Labels
question Further information is requested

Comments

@ramav87

ramav87 commented Dec 16, 2019

A problem I have been unable to solve in pymc3 and numpyro is the use of embarrassingly parallel for loops (e.g., with dask). Assume we have a list of observations; I am trying:

import dask

lazy_results = []
for observation in observation_list:
    lazy_result = dask.delayed(compute_nuts)(xvec, observation)
    lazy_results.append(lazy_result)

futures = dask.persist(*lazy_results)  # trigger computation in the background
results = dask.compute(*futures)

where the function compute_nuts runs MCMC and returns the trace. Is there any way to do this with numpyro at the moment? Even when I set JAX to use only the CPU, dask gives an assertion error:

/usr/local/lib/python3.6/dist-packages/numpyro/primitives.py in __exit__(self, *args, **kwargs)
     47 
     48     def __exit__(self, *args, **kwargs):
---> 49         assert _PYRO_STACK[-1] is self
     50         _PYRO_STACK.pop()
     51 

AssertionError:

This is obviously not surprising, but is there a workaround for this kind of task?

@fehiepsi
Member

fehiepsi commented Dec 18, 2019

Hi @ramav87, I am not sure whether dask is compatible with JAX. Could you test that first? You would probably need to use the multiprocessing start methods forkserver or spawn instead of fork, and limit JAX multithreading (see here and here).

Otherwise, you can use pmap to run the jobs in parallel. Because nesting pmap inside pmap does not work on CPU (IIRC), you would need to use the sequential chain method for each job.

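from jax import random, pmap
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Expose 4 CPU devices to JAX; this must run before any JAX computation.
numpyro.set_host_device_count(4)

def model():
    numpyro.sample("x", dist.Normal(0, 1))

def get_samples(i):
    # Each pmapped job runs its chains one after another, because
    # parallel chains would mean a pmap nested inside this pmap.
    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10, num_chains=2,
                chain_method='sequential', progress_bar=False)
    mcmc.run(random.PRNGKey(i))
    return mcmc.get_samples()

# One job per device, each seeded with a different integer.
samples = pmap(get_samples)(jnp.arange(4))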
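The same one-job-per-seed pattern can be tried without numpyro. A minimal JAX-only sketch, assuming a CPU-only run: the XLA flag is what creates the extra host devices (numpyro.set_host_device_count is understood to set this same flag), and `draw` is a hypothetical stand-in for `get_samples`.

```python
import os

# Request 4 virtual CPU devices; this only works before JAX initializes.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp
from jax import random


def draw(seed):
    # Hypothetical stand-in for get_samples: one independent job per seed.
    key = random.PRNGKey(seed)
    return random.normal(key, (10,))


n_jobs = jax.local_device_count()  # 4 if the flag above took effect
draws = jax.pmap(draw)(jnp.arange(n_jobs))
print(draws.shape)  # (n_jobs, 10)
```

Reading the device count back with `jax.local_device_count()` instead of hard-coding 4 keeps the sketch working even if JAX was already initialized and the flag was ignored.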

@fehiepsi fehiepsi added the question Further information is requested label Dec 18, 2019
@fehiepsi
Member

@ramav87 I think using pmap works for your case, but I am not sure whether there will be any performance issues, so feel free to open a separate issue if you observe something strange.

@ramav87
Author

ramav87 commented Dec 21, 2019

Yes, this works for me. Thank you so much!!!

@d-diaz
Contributor

d-diaz commented Nov 11, 2022

@fehiepsi, is there a way to adapt the pmap example you showed above to submit different shards or batches of data, using i as an index? I have lists of X and y data (e.g., shards_x, shards_y) that I want to pass to mcmc.run(X, y), but I can't figure out how to do this inside the get_samples function.

For example:

def get_samples(i):
    X = shards_x[i]
    y = shards_y[i]
    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10, num_chains=1,
                chain_method='sequential', progress_bar=False)
    mcmc.run(random.PRNGKey(i), X, y)
    return mcmc.get_samples()

pmap(get_samples)(np.arange(16))

throws a TracerIntegerConversionError:

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

I've tried working around this by setting up boolean masks to select rows from the dataframe, but then I start running into ConcretizationTypeErrors.
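The boolean-mask route runs into a related tracing limit: `x[mask]` has a length that depends on the mask's values, which JAX cannot know while tracing under `jit` (and likewise under `pmap`). A minimal sketch with hypothetical arrays `x` and `shard_id`, showing the failure and one fixed-shape alternative that keeps every row and zeroes out rows from other shards. The exact exception class differs between JAX versions, so the sketch catches broadly.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 8 rows, each tagged with the shard it belongs to.
x = jnp.arange(8.0)
shard_id = jnp.array([0, 1, 0, 1, 0, 1, 0, 1])


@jax.jit
def shard_sum_mask_index(x, shard_id, i):
    # x[mask] would have a value-dependent length, unknown while tracing.
    return x[shard_id == i].sum()


@jax.jit
def shard_sum_fixed_shape(x, shard_id, i):
    # Keep all 8 rows; rows from other shards contribute 0.
    return jnp.where(shard_id == i, x, 0.0).sum()


try:
    shard_sum_mask_index(x, shard_id, 0)
    mask_index_failed = False
except Exception:  # exact error class depends on the JAX version
    mask_index_failed = True

total = shard_sum_fixed_shape(x, shard_id, 0)
print(mask_index_failed, float(total))  # True 12.0
```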

@fehiepsi
Member

fehiepsi commented Nov 11, 2022

Does

def get_samples(i):
    X = shards_x[i]
    y = shards_y[i]
    return X, y, random.PRNGKey(i)

pmap(get_samples)(np.arange(16))

work for you?

@d-diaz
Contributor

d-diaz commented Nov 11, 2022

Nope.

data['shard'] = data.sample(frac=1).reset_index().index % len(jax.devices())
shards_x = [[data.loc[data.shard == i, col].values for col in MODEL_COVARS] for i in range(len(jax.devices()))]
shards_y = [data.loc[data.shard == i, 'DG_OBS'].values for i in range(len(jax.devices()))]

def get_samples(i):
    X = shards_x[i]
    y = shards_y[i]
    return X, y, random.PRNGKey(i)

jax.pmap(get_samples)(np.arange(len(jax.devices())))


---------------------------------------------------------------------------
TracerIntegerConversionError              Traceback (most recent call last)
Cell In [142], line 10
      7     y = shards_y[i]
      8     return X, y, random.PRNGKey(i)
---> 10 jax.pmap(get_samples)(np.arange(len(jax.devices())))

    [... skipping hidden 17 frame]

Cell In [142], line 6, in get_samples(i)
      5 def get_samples(i):
----> 6     X = shards_x[i]
      7     y = shards_y[i]
      8     return X, y, random.PRNGKey(i)

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/jax/core.py:547, in Tracer.__index__(self)
    546 def __index__(self):
--> 547   raise TracerIntegerConversionError(self)

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

As far as I can tell, JAX doesn't want me to use i as an index into a list.
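Under pmap, `i` is a tracer rather than a Python int, and a Python list lookup needs a concrete integer (it calls `__index__`), which is exactly what the traceback shows. Indexing a JAX array with a traced index is allowed, so stacking the shards first avoids the error. A minimal sketch with hypothetical equal-sized shards; note that `jnp.stack` requires equal shapes, while a `% n_devices` split gives shards whose sizes can differ by one row, so real data would need trimming or padding first.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_dev = jax.local_device_count()
# Hypothetical equal-sized shards: one (5,) array per device.
shards_list = [np.full(5, float(d)) for d in range(n_dev)]


def from_list(i):
    return shards_list[i]  # list lookup needs a concrete int -> error


try:
    jax.pmap(from_list)(jnp.arange(n_dev))
    list_failed = False
except jax.errors.TracerIntegerConversionError:
    list_failed = True

# Stack the shards into one array with a leading shard axis.
shards_arr = jnp.stack(shards_list)  # shape (n_dev, 5)


def from_array(i):
    return shards_arr[i]  # array indexing with a traced index is fine


picked = jax.pmap(from_array)(jnp.arange(n_dev))
print(list_failed, picked.shape)  # True (n_dev, 5)
```

This version closes over the whole stacked array, so every device receives all shards; passing `shards_arr` itself as the mapped argument, so that device d only gets row d, is usually the better fit.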

@d-diaz
Contributor

d-diaz commented Nov 11, 2022

I'm totally open to other methods to pmap the execution of mcmc on different shards, but am struggling to see how to do it.

This also fails regardless of whether the iterable I pass is a list, array, or DeviceArray:

def get_samples(i):
    X = shards_x[i]
    y = shards_y[i]
    return X, y, random.PRNGKey(i)

jax.pmap(get_samples, static_broadcasted_argnums=0)(jnp.arange(len(jax.devices())))

with

Cell In [158], line 10
      7     y = shards_y[i]
      8     return X, y, random.PRNGKey(i)
---> 10 jax.pmap(get_samples, static_broadcasted_argnums=0)(jnp.arange(len(jax.devices())))

ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'get_samples' while trying to hash an object of type <class 'jaxlib.xla_extension.DeviceArray'>, [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]. The error was:
TypeError: unhashable type: 'DeviceArray'
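`static_broadcasted_argnums` marks an argument as a compile-time constant that is broadcast unchanged to every device, not split across devices; JAX hashes it, which is why the DeviceArray is rejected. Values that should differ per device go through the mapped (non-static) arguments. A minimal sketch with a hypothetical `scale` function:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
data = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)


def scale(x, factor):
    # x: one row of `data` (mapped); factor: same Python value everywhere.
    return x * factor


# Argument 1 is static: it must be hashable (int, float, tuple, ...) and
# is broadcast to every device instead of being split along an axis.
out = jax.pmap(scale, static_broadcasted_argnums=1)(data, 2.0)
print(out.shape)  # (n_dev, 3)
```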

@fehiepsi
Member

fehiepsi commented Nov 14, 2022

Hi @d-diaz, I'm not sure what you are asking for. If you want to use pmap + MCMC, you can follow the code in this comment and replace the index i there with shards_x, shards_y, etc. I don't use static_broadcasted_argnums, so I'm not sure what you want to do with it. To pmap a function, the signature is simply pmap(f)(x), where x is a batch of values. By batch here, I mean an array with a batched shape like batch_size x 2 x 3, not a list of arrays with shape 2 x 3.

We don't support what the JAX team does not support, such as pmap over a list. It's also hard for us to understand the error you got without reproducible code. Please feel free to open a thread in our forum for discussion, and if you think there is an issue here, please open a separate issue on GitHub.
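Following the batch-versus-list distinction, a list of per-shard pytrees can be turned into one pytree whose leaves all gain a leading shard axis, which pmap then splits across devices. A minimal sketch, assuming equal-sized shards and mirroring the thread's layout (a list of covariate columns plus a response per shard); the data are synthetic and `per_shard` is a hypothetical placeholder computation.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_dev = jax.local_device_count()
rows_per_shard = 6

# Hypothetical per-shard data: (list of 2 covariate columns, response).
rng = np.random.default_rng(0)
shards = [
    ([rng.normal(size=rows_per_shard) for _ in range(2)],
     rng.normal(size=rows_per_shard))
    for _ in range(n_dev)
]

# List of pytrees -> one pytree of batched arrays (leading axis n_dev).
batched = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *shards)


def per_shard(data):
    cols, y = data
    X = jnp.stack(cols, axis=-1)  # (rows_per_shard, 2)
    return X.T @ y                # placeholder per-shard computation


out = jax.pmap(per_shard)(batched)
print(out.shape)  # (n_dev, 2)
```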

@d-diaz
Copy link
Contributor

d-diaz commented Dec 16, 2022

I'm trying to run MCMC chains on separate shards of data, with each data shard sent to a different device. The model in your comment doesn't take any inputs, so I can't figure out how I'm supposed to pass different shards of data to the model on each processor.

@fehiepsi
Member

Hi @d-diaz, in that code we pmap over the initial seed. For shards of data, you can do:

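def model(shard):
    ...

def get_samples(shard):
    ...  # build the MCMC object as in the earlier example
    mcmc.run(random.PRNGKey(0), shard)
    return mcmc.get_samples()

# shards: a batched array whose leading axis has one shard per device, not a list
pmap(get_samples)(shards)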


3 participants