
ENH: Get sampling working using Apple Silicon GPU via jax backend #7332

Open · drbenvincent opened this issue May 23, 2024 · 32 comments
Labels: feature request, hackathon (Suitable for hackathon), jax, macOS (macOS related)

@drbenvincent commented May 23, 2024

It would be great to utilise the GPU on Apple Silicon chips. The lowest-resistance way of doing this is probably through the jax backend; see https://jax.readthedocs.io/en/latest/installation.html#apple-silicon-gpu-arm-based and the Apple docs "Accelerated JAX training on Mac".

I don't have the stats, but some sizeable portion of PyMC users run code on hardware with Apple Silicon, and this will increase over time as more people upgrade from Intel to Apple Silicon. Full utilisation of those chips (i.e. the GPU component) would likely unlock some speed-ups in sampling.

So far I have partial progress (ht to @twiecki). I have the following environment file, metal_test_env.yaml:

name: metal_test_env
channels:
  - conda-forge
dependencies:
  - blackjax
  - ipykernel
  - jax==0.4.26
  - jupyter
  - numpy
  - pip
  - pymc
  - python<3.11
  - pip:
    - jax-metal
    - jaxlib==0.4.26
    - ml-dtypes==0.2.0

NOTE: It seems that pinning python<3.11 is necessary at this point in time.

I build that with:

mamba env create -f metal_test_env.yaml
mamba activate metal_test_env

Then, in an ipython session, we can confirm that jax has detected the Apple Silicon GPU:

import jax
jax.print_environment_info()

gives

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716460927.826518 4468516 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1716460927.843089 4468516 service.cc:145] XLA service 0x1276ac990 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716460927.843114 4468516 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716460927.844725 4468516 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716460927.844746 4468516 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

The key line is: jax.devices (1 total, 1 local): [METAL(id=0)]
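A minimal programmatic check of the same thing (a sketch; the exact platform string reported by jax-metal is an assumption here):

import jax

print(jax.default_backend())             # expected to name the Metal backend when jax-metal is active
print(jax.devices())                     # expected to print [METAL(id=0)], matching the log above
assert "METAL" in str(jax.devices()[0])  # assumption: the Metal device repr contains "METAL"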

So the next step is to see if we can do sampling:

import numpy as np
import pymc as pm

x = np.random.normal(size=10)
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(nuts_sampler="blackjax")

which, as of now, results in this traceback:

Traceback
XlaRuntimeError                           Traceback (most recent call last)
Cell In[6], line 8
      6 mu = pm.Normal("mu", 0, 1)
      7 pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
----> 8 idata = pm.sample(nuts_sampler="blackjax")

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:688, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    684     if not isinstance(step, NUTS):
    685         raise ValueError(
    686             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    687         )
--> 688     return _sample_external_nuts(
    689         sampler=nuts_sampler,
    690         draws=draws,
    691         tune=tune,
    692         chains=chains,
    693         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    694         random_seed=random_seed,
    695         initvals=initvals,
    696         model=model,
    697         var_names=var_names,
    698         progressbar=progressbar,
    699         idata_kwargs=idata_kwargs,
    700         nuts_sampler_kwargs=nuts_sampler_kwargs,
    701         **kwargs,
    702     )
    704 if isinstance(step, list):
    705     step = CompoundStep(step)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    348 elif sampler in ("numpyro", "blackjax"):
    349     import pymc.sampling.jax as pymc_jax
--> 351     idata = pymc_jax.sample_jax_nuts(
    352         draws=draws,
    353         tune=tune,
    354         chains=chains,
    355         target_accept=target_accept,
    356         random_seed=random_seed,
    357         initvals=initvals,
    358         model=model,
    359         var_names=var_names,
    360         progressbar=progressbar,
    361         nuts_sampler=sampler,
    362         idata_kwargs=idata_kwargs,
    363         **nuts_sampler_kwargs,
    364     )
    365     return idata
    367 else:

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    564     raise ValueError(f"{nuts_sampler=} not recognized")
    566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
    568     model=model,
    569     target_accept=target_accept,
    570     tune=tune,
    571     draws=draws,
    572     chains=chains,
    573     chain_method=chain_method,
    574     progressbar=progressbar,
    575     random_seed=random_seed,
    576     initial_points=initial_points,
    577     nuts_kwargs=nuts_kwargs,
    578 )
    579 tic2 = datetime.now()
    581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:398, in _sample_blackjax_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    395 if chains == 1:
    396     initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
--> 398 logprob_fn = get_jaxified_logp(model)
    400 seed = jax.random.PRNGKey(random_seed)
    401 keys = jax.random.split(seed, chains)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:153, in get_jaxified_logp(model, negative_logp)
    151 if not negative_logp:
    152     model_logp = -model_logp
--> 153 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    155 def logp_fn_wrap(x):
    156     return logp_fn(*x)[0]

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:146, in get_jaxified_graph(inputs, outputs)
    143 mode.JAX.optimizer.rewrite(fgraph)
    145 # We now jaxify the optimized fgraph
--> 146 return jax_funcify(fgraph)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     44 @jax_funcify.register(FunctionGraph)
     45 def jax_funcify_FunctionGraph(
     46     fgraph,
   (...)
     49     **kwargs,
     50 ):
---> 51     return fgraph_to_python(
     52         fgraph,
     53         jax_funcify,
     54         type_conversion_fn=jax_typify,
     55         fgraph_name=fgraph_name,
     56         **kwargs,
     57     )

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/utils.py:742, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    737 input_storage = storage_map.setdefault(
    738     i, [None if not isinstance(i, Constant) else i.data]
    739 )
    740 if input_storage[0] is not None or isinstance(i, Constant):
    741     # Constants need to be assigned locally and referenced
--> 742     global_env[local_input_name] = type_conversion_fn(
    743         input_storage[0], variable=i, storage=input_storage, **kwargs
    744     )
    745     # TODO: We could attempt to use the storage arrays directly
    746     # E.g. `local_input_name = f"{local_input_name}[0]"`
    747 node_input_names.append(local_input_name)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:35, in jax_typify_ndarray(data, dtype, **kwargs)
     33 if len(data.shape) == 0:
     34     return data.item()
---> 35 return jnp.array(data, dtype=dtype)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2197, in array(object, dtype, copy, order, ndmin)
   2194 else:
   2195   raise TypeError(f"Unexpected input type for array: {type(object)}")
-> 2197 out_array: Array = lax_internal._convert_element_type(
   2198     out, dtype, weak_type=weak_type)
   2199 if ndmin > ndim(out_array):
   2200   out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/lax/lax.py:558, in _convert_element_type(operand, new_dtype, weak_type)
    556   return type_cast(Array, operand)
    557 else:
--> 558   return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    559                                      weak_type=bool(weak_type))

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:422, in Primitive.bind(self, *args, **params)
    419 def bind(self, *args, **params):
    420   assert (not config.enable_checks.value or
    421           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 422   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:425, in Primitive.bind_with_trace(self, trace, args, params)
    424 def bind_with_trace(self, trace, args, params):
--> 425   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    426   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:913, in EvalTrace.process_primitive(self, primitive, tracers, params)
    912 def process_primitive(self, primitive, tracers, params):
--> 913   return primitive.impl(*tracers, **params)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87   outs = fun(*args)
     88 finally:
     89   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 14 frame]

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
    233   return backend.compile(built_c, compile_options=options,
    234                          host_callbacks=host_callbacks)
    235 # Some backends don't have `host_callbacks` option yet
    236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
  "func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
  "func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()

For all I know the problem is on the jax side, and may require issues to be filed in that repo. But I think it makes sense to have a pymc issue to raise this goal as a priority and perhaps to coordinate any additional issues on the pymc or jax side.

@ricardoV94 (Member) commented May 23, 2024

It seems to fail at something very basic, trying to call jax.numpy.array(data, dtype=dtype) on a numpy array.

@ricardoV94 (Member)

Perhaps try with pytensor.config.floatX="float32" before defining the model, maybe it has zero support for float64

@ricardoV94 added the jax and macOS (macOS related) labels May 23, 2024
@drbenvincent (Author)

Perhaps try with pytensor.config.floatX="float32" before defining the model, maybe it has zero support for float64

Adding

import pytensor
pytensor.config.floatX="float32"

results in a different error

TypeError: true_fun and false_fun output must have identical types, got
Proposal(state=IntegratorState(position=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], momentum=['ShapedArray(float64[])'], logdensity='ShapedArray(float64[])', logdensity_grad=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']), energy='ShapedArray(float64[])', weight='ShapedArray(float64[])', sum_log_p_accept='ShapedArray(float64[])').

@drbenvincent (Author)

It seems to fail at something very basic, trying to call jax.numpy.array(data, dtype=dtype) on a numpy array.

This actually works fine:

>>> jax.numpy.array(x, dtype="float32")

Array([ 0.42651013,  1.9349691 ,  0.43221945, -0.24343772,  2.760918  ,
        1.2610279 , -1.5116365 ,  0.9801455 ,  0.5613332 ,  0.6750525 ],      dtype=float32)

@drbenvincent (Author)

It seems to fail at something very basic, trying to call jax.numpy.array(data, dtype=dtype) on a numpy array.

Doesn't work for float64, jax.numpy.array(x, dtype="float64")

gives

XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
  "func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
  "func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()

@ricardoV94 (Member)

It seems to fail at something very basic, trying to call jax.numpy.array(data, dtype=dtype) on a numpy array.

Doesn't work for float64, jax.numpy.array(x, dtype="float64")

Yes, that's the error you were getting at first.

@ricardoV94 (Member)

You're now getting errors deep inside numpyro, and have left PyMC/PyTensor.

Could you try something simpler first? model.compile_logp(mode="JAX")(model.initial_point())?
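(Spelled out as a self-contained snippet, assuming the same toy model and float32 setting from the earlier comments:)

import pytensor
pytensor.config.floatX = "float32"
import numpy as np
import pymc as pm

x = np.random.normal(size=10)
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)

# compile the model logp with the JAX backend and evaluate it at the initial point
print(model.compile_logp(mode="JAX")(model.initial_point()))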

@drbenvincent (Author) commented May 23, 2024

You're now getting errors deep inside numpyro, and have left PyMC/PyTensor.

Could you try something simpler first? model.compile_logp(mode="JAX")(model.initial_point())?

That fails with

XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<f32>) -> tensor<f64>, res_attrs = [{jax.result_info = "[0]", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<f32>):
  %0 = "mhlo.constant"() {value = dense<0.918938517> : tensor<10xf32>} : () -> tensor<10xf32>
  %1 = "mhlo.constant"() {value = dense<-5.000000e-01> : tensor<10xf32>} : () -> tensor<10xf32>
  %2 = "mhlo.constant"() {value = dense<0.918938517> : tensor<f32>} : () -> tensor<f32>
  %3 = "mhlo.constant"() {value = dense<-5.000000e-01> : tensor<f32>} : () -> tensor<f32>
  %4 = "mhlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %5 = "mhlo.constant"() {value = dense<[-0.0521879941, 1.64668274, -0.0849129111, -0.708814621, -1.28834987, -0.51588583, 0.643824816, 1.94189024, -1.18248034, 1.30991852]> : tensor<10xf32>} : () -> tensor<10xf32>
  %6 = "mhlo.reshape"(%arg0) : (tensor<f32>) -> tensor<1xf32>
  %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<10xf32>
  %8 = "mhlo.subtract"(%5, %7) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %9 = "mhlo.multiply"(%8, %8) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %10 = "mhlo.multiply"(%9, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %11 = "mhlo.subtract"(%10, %0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %12 = "mhlo.reduce"(%11, %4) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %21 = "mhlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%21) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<10xf32>, tensor<f32>) -> tensor<f32>
  %13 = "mhlo.multiply"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %14 = "mhlo.multiply"(%13, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %15 = "mhlo.subtract"(%14, %2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %16 = "mhlo.reshape"(%15) : (tensor<f32>) -> tensor<1xf32>
  %17 = "mhlo.reshape"(%12) : (tensor<f32>) -> tensor<1xf32>
  %18 = "mhlo.concatenate"(%16, %17) {dimension = 0 : i64} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2xf32>
  %19 = "mhlo.reduce"(%18, %4) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %21 = "mhlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%21) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
  %20 = "mhlo.convert"(%19) : (tensor<f32>) -> tensor<f64>
  "func.return"(%20) : (tensor<f64>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<f32>) -> tensor<f64>, res_attrs = [{jax.result_info = "[0]", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<f32>):
  %0 = "mhlo.constant"() {value = dense<0.918938517> : tensor<10xf32>} : () -> tensor<10xf32>
  %1 = "mhlo.constant"() {value = dense<-5.000000e-01> : tensor<10xf32>} : () -> tensor<10xf32>
  %2 = "mhlo.constant"() {value = dense<0.918938517> : tensor<f32>} : () -> tensor<f32>
  %3 = "mhlo.constant"() {value = dense<-5.000000e-01> : tensor<f32>} : () -> tensor<f32>
  %4 = "mhlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %5 = "mhlo.constant"() {value = dense<[-0.0521879941, 1.64668274, -0.0849129111, -0.708814621, -1.28834987, -0.51588583, 0.643824816, 1.94189024, -1.18248034, 1.30991852]> : tensor<10xf32>} : () -> tensor<10xf32>
  %6 = "mhlo.reshape"(%arg0) : (tensor<f32>) -> tensor<1xf32>
  %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<10xf32>
  %8 = "mhlo.subtract"(%5, %7) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %9 = "mhlo.multiply"(%8, %8) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %10 = "mhlo.multiply"(%9, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %11 = "mhlo.subtract"(%10, %0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
  %12 = "mhlo.reduce"(%11, %4) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %21 = "mhlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%21) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<10xf32>, tensor<f32>) -> tensor<f32>
  %13 = "mhlo.multiply"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %14 = "mhlo.multiply"(%13, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %15 = "mhlo.subtract"(%14, %2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %16 = "mhlo.reshape"(%15) : (tensor<f32>) -> tensor<1xf32>
  %17 = "mhlo.reshape"(%12) : (tensor<f32>) -> tensor<1xf32>
  %18 = "mhlo.concatenate"(%16, %17) {dimension = 0 : i64} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2xf32>
  %19 = "mhlo.reduce"(%18, %4) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %21 = "mhlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%21) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
  %20 = "mhlo.convert"(%19) : (tensor<f32>) -> tensor<f64>
  "func.return"(%20) : (tensor<f64>) -> ()
}) : () -> ()

Apply node that caused the error: Sum{axes=None}(MakeVector{dtype='float32'}.0)
Toposort index: 10
Inputs types: [TensorType(float32, shape=(2,))]
Inputs shapes: [()]
Inputs strides: [()]
Inputs values: [array(0., dtype=float32)]
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
    result = runner(coro)
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-32-e6b494d7709d>", line 1, in <module>
    model.compile_logp(mode="JAX")(model.initial_point())
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/model/core.py", line 637, in compile_logp
    return self.compile_fn(self.logp(vars=vars, jacobian=jacobian, sum=sum), **compile_kwargs)
  File "/Users/benjamv/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/model/core.py", line 751, in logp
    logp_scalar = pt.sum([pt.sum(factor) for factor in logp_factors])

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Well, there's more output than that, which I can share if you really want.

@twiecki (Member) commented May 23, 2024

function_type = (tensor<f32>) -> tensor<f64> seems like it's somehow converting to float64 somewhere.
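(A minimal repro sketch of that suspected cast; it assumes an explicit float32 -> float64 convert is what the Metal backend fails to legalize, as in the jax.numpy.array(x, dtype="float64") example above:)

import jax
jax.config.update("jax_enable_x64", True)  # allow float64 so the convert is actually emitted
import jax.numpy as jnp

x32 = jnp.ones(10, dtype=jnp.float32)
x64 = jnp.array(x32, dtype="float64")      # expected to hit the same "failed to legalize" XlaRuntimeError on METAL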

@ricardoV94 (Member) commented May 23, 2024

@drbenvincent try something even simpler, a model just with mu and do compile_logp(sum=False, mode="JAX")

@drbenvincent (Author)

@drbenvincent try something even simpler, a model just with mu and do compile_logp(sum=False, mode="JAX")

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)

model.compile_logp(sum=False, mode="JAX")

Does not error out :) Gives me <pymc.pytensorf.PointFunc at 0x11a009ba0>

@ricardoV94 (Member) commented May 23, 2024

You still need to evaluate it with (model.initial_point()):

import pytensor
pytensor.config.floatX = "float32"
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)

model.compile_logp(sum=False, mode="JAX")(model.initial_point())

@drbenvincent (Author)

It works

In [3]: model.compile_logp(sum=False, mode="JAX")(model.initial_point())
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716476754.467808 4768691 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1716476754.483546 4768691 service.cc:145] XLA service 0x600000fad200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716476754.483558 4768691 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716476754.484729 4768691 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716476754.484737 4768691 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
Out[3]: [Array(-0.9189385, dtype=float32)]

@ricardoV94 (Member)

Great. So the next question is, can you use numpyro to sample from that very simple model?

@drbenvincent (Author)

Looks like the answer is no, so far:

with model:
    idata = pm.sample(draws=10_000, nuts_sampler="numpyro", chains=1)

gives

XlaRuntimeError: INTERNAL: Unable to serialize MPS module
Full traceback
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:688, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    684     if not isinstance(step, NUTS):
    685         raise ValueError(
    686             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    687         )
--> 688     return _sample_external_nuts(
    689         sampler=nuts_sampler,
    690         draws=draws,
    691         tune=tune,
    692         chains=chains,
    693         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    694         random_seed=random_seed,
    695         initvals=initvals,
    696         model=model,
    697         var_names=var_names,
    698         progressbar=progressbar,
    699         idata_kwargs=idata_kwargs,
    700         nuts_sampler_kwargs=nuts_sampler_kwargs,
    701         **kwargs,
    702     )
    704 if isinstance(step, list):
    705     step = CompoundStep(step)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    348 elif sampler in ("numpyro", "blackjax"):
    349     import pymc.sampling.jax as pymc_jax
--> 351     idata = pymc_jax.sample_jax_nuts(
    352         draws=draws,
    353         tune=tune,
    354         chains=chains,
    355         target_accept=target_accept,
    356         random_seed=random_seed,
    357         initvals=initvals,
    358         model=model,
    359         var_names=var_names,
    360         progressbar=progressbar,
    361         nuts_sampler=sampler,
    362         idata_kwargs=idata_kwargs,
    363         **nuts_sampler_kwargs,
    364     )
    365     return idata
    367 else:

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    564     raise ValueError(f"{nuts_sampler=} not recognized")
    566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
    568     model=model,
    569     target_accept=target_accept,
    570     tune=tune,
    571     draws=draws,
    572     chains=chains,
    573     chain_method=chain_method,
    574     progressbar=progressbar,
    575     random_seed=random_seed,
    576     initial_points=initial_points,
    577     nuts_kwargs=nuts_kwargs,
    578 )
    579 tic2 = datetime.now()
    581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:484, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    481 if chains > 1:
    482     map_seed = jax.random.split(map_seed, chains)
--> 484 pmap_numpyro.run(
    485     map_seed,
    486     init_params=initial_points,
    487     extra_fields=(
    488         "num_steps",
    489         "potential_energy",
    490         "energy",
    491         "adapt_state.step_size",
    492         "accept_prob",
    493         "diverging",
    494     ),
    495 )
    497 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    498 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/numpyro/infer/mcmc.py:666, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    664 map_args = (rng_key, init_state, init_params)
    665 if self.num_chains == 1:
--> 666     states_flat, last_state = partial_map_fn(map_args)
    667     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    668 else:

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/numpyro/infer/mcmc.py:462, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    456 collection_size = self._collection_params["collection_size"]
    457 collection_size = (
    458     collection_size
    459     if collection_size is None
    460     else collection_size // self.thinning
    461 )
--> 462 collect_vals = fori_collect(
    463     lower_idx,
    464     upper_idx,
    465     sample_fn,
    466     init_val,
    467     transform=_collect_fn(collect_fields, remove_sites),
    468     progbar=self.progress_bar,
    469     return_last_val=True,
    470     thinning=self.thinning,
    471     collection_size=collection_size,
    472     progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    473     diagnostics_fn=diagnostics,
    474     num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    475 )
    476 states, last_val = collect_vals
    477 # Get first argument of type `HMCState`

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/numpyro/util.py:367, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    365 with tqdm.trange(upper) as t:
    366     for i in t:
--> 367         vals = jit(_body_fn)(i, vals)
    368         t.set_description(progbar_desc(i), refresh=False)
    369         if diagnostics_fn:

    [... skipping hidden 14 frame]

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
    233   return backend.compile(built_c, compile_options=options,
    234                          host_callbacks=host_callbacks)
    235 # Some backends don't have `host_callbacks` option yet
    236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: INTERNAL: Unable to serialize MPS module

@twiecki (Member) commented May 23, 2024

Can you try with chains=1, cores=1?
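(i.e. something like:)

with model:
    idata = pm.sample(draws=10_000, nuts_sampler="numpyro", chains=1, cores=1)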

@drbenvincent (Author)

Same error, I'm afraid.

@ricardoV94 (Member)

So it's still pretty broken

@twiecki (Member) commented May 24, 2024

Does installing jaxlib==0.4.26+metal help anything?

@drbenvincent (Author)

Does installing jaxlib==0.4.26+metal help anything?

I tried a bunch of things, with no luck yet. Specifically, on that suggestion I tried pip install jaxlib=="0.4.26+metal" but got:

ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.26+metal (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28)
ERROR: No matching distribution found for jaxlib==0.4.26+metal

@drbenvincent (Author)

FYI, I created a topic in the Pyro discourse, so we'll see if there are any plans for support from the pyro side of things.

@twiecki (Member) commented May 27, 2024

Upgrading jax and jaxlib and setting ENABLE_PJRT_COMPATIBILITY=1 allows using more recent jax versions.
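(One way to apply that from Python, assuming the variable needs to be in the environment before jax initializes its backend:)

import os
os.environ["ENABLE_PJRT_COMPATIBILITY"] = "1"   # set before importing/initializing jax

import jax
print(jax.__version__)                          # e.g. 0.4.28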

@twiecki (Member) commented May 27, 2024

Tried it myself with that setting, and the newer JAX version does not produce the previous error, but I'm still getting the MPS error.

>>> with model:
...     idata = pm.sample(draws=10_000, nuts_sampler="numpyro", chains=1)
...
  0%|                                                                                                                                               | 0/11000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 688, in sample
    return _sample_external_nuts(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 351, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 567, in sample_jax_nuts
    raw_mcmc_samples, sample_stats, library = sampler_fn(
                                              ^^^^^^^^^^^
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 484, in _sample_numpyro_nuts
    pmap_numpyro.run(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/numpyro/infer/mcmc.py", line 666, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/numpyro/infer/mcmc.py", line 462, in _single_chain_mcmc
    collect_vals = fori_collect(
                   ^^^^^^^^^^^^^
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/numpyro/util.py", line 367, in fori_collect
    vals = jit(_body_fn)(i, vals)
           ^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
>>> jax.__version__
'0.4.28'

@twiecki (Member) commented May 27, 2024

It does work with pm.sample() but gives a warning:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu]
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()

But seems to sample just fine so far.

@twiecki (Member) commented May 27, 2024

Same error with blackjax, but using vectorization gives:

>>> with model:    idata = pm.sample(draws=10_000, nuts_sampler="blackjax", cores=1, chains=1, nuts_sampler_kwargs={"chain_method":  "vectorized"})
...
Running window adaptation
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 688, in sample
    return _sample_external_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 351, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 567, in sample_jax_nuts
    raw_mcmc_samples, sample_stats, library = sampler_fn(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 413, in _sample_blackjax_nuts
    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 250, in _blackjax_inference_loop
    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py", line 334, in run
    last_state, info = jax.lax.scan(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 80, in wrapper_progress_bar
    _update_progress_bar(iter_num)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 46, in _update_progress_bar
    _ = lax.cond(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 48, in <lambda>
    lambda _: io_callback(_define_bar, None, iter_num),
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/jax/_src/callback.py", line 502, in io_callback
    out_flat = io_callback_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `EmitPythonCallback` not supported on METAL backend.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

@twiecki (Member) commented May 27, 2024

@junpenglao Is there a way to not have blackjax spawn/fork a new process?

@twiecki (Member) commented May 27, 2024

Current summary: it seems that running with chain_method="parallel" uses pmap, which calls os.fork(), which breaks Metal. Avoiding this by setting chain_method="vectorized" uses vmap, which works, but then we run into the lack of support for lax.scan used in blackjax https://github.com/blackjax-devs/blackjax/blob/360ac3bc56db2bb608b8456f446828fe8c349672/blackjax/adaptation/window_adaptation.py#L340. Apple issue: https://forums.developer.apple.com/forums/thread/750163
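(For illustration, the difference between the two chain methods in plain JAX; this is a sketch, not PyMC's actual code path:)

import jax
import jax.numpy as jnp

def one_chain(key):
    # stand-in for a single MCMC chain
    return jax.random.normal(key, (3,))

keys = jax.random.split(jax.random.PRNGKey(0), 4)

vectorized = jax.vmap(one_chain)(keys)    # "vectorized": one batched program on a single device
# parallel = jax.pmap(one_chain)(keys)    # "parallel": needs one XLA device per chain; Metal exposes only one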

@drbenvincent (Author) commented May 29, 2024

Upgrading jax and jaxlib and setting ENABLE_PJRT_COMPATIBILITY=1 allows using more recent jax versions.

A quick recap of the steps I'm taking here.

With an environment file:

name: metal_test_env
channels:
  - conda-forge
dependencies:
  - jupyterlab
  - numpy
  - numpyro
  - pip
  - pymc
  - python<3.11
  - pip:
    - blackjax
    - jax==0.4.28
    - jax-metal
    - jaxlib==0.4.28
    - ml-dtypes==0.2.0

I build that with:

conda env create -f metal_test_env.yaml
conda activate metal_test_env

Then run

import os
import jax
import pytensor
import pymc as pm
import numpy as np

os.environ['ENABLE_PJRT_COMPATIBILITY'] = '1'
pytensor.config.floatX = "float32"

jax.print_environment_info() # gives positive output, metal device is recognised

This works fine:

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)

model.compile_logp(sum=False, mode="JAX")(model.initial_point()) # returns something like [Array(-0.9189385, dtype=float32)]

pm.sample works, but (going by Activity Monitor) doesn't utilise the GPU:

with model:
    idata = pm.sample(draws=10_000, cores=1, chains=1)

And sampling with the numpyro backend still gives the MPS error:

with model:
    idata = pm.sample(draws=1_000, nuts_sampler="numpyro", chains=1, cores=1)

# XlaRuntimeError: INTERNAL: Unable to serialize MPS module

Trying with blackjax backend I get the same errors as you:

  • XlaRuntimeError: INTERNAL: Unable to serialize MPS module
  • ValueError: EmitPythonCallback not supported on METAL backend when setting nuts_sampler_kwargs={"chain_method": "vectorized"}

So, as far as I understand, nobody has yet been able to sample on the GPU.

@twiecki (Member) commented May 29, 2024 via email

@drbenvincent (Author)

Fingers crossed it just starts working randomly as the metal support for jax improves.

@twiecki (Member) commented May 29, 2024 via email

@fonnesbeck added the hackathon (Suitable for hackathon) label Jun 14, 2024
@twiecki (Member) commented Jun 18, 2024

Tried with the new jax-metal 0.1.0 and jax/jaxlib 0.4.30.

Setting ENABLE_PJRT_COMPATIBILITY=1
Code:

import pytensor
pytensor.config.floatX="float32"
import numpy as np
import pymc as pm

x = np.random.normal(size=10)
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(nuts_sampler="blackjax", cores=1, chains=1, nuts_sampler_kwargs={"chain_method": "vectorized"})

Gives:

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py:250, in _blackjax_inference_loop(seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs)
    242     raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")
    244 adapt = blackjax.window_adaptation(
    245     algorithm=algorithm,
    246     logdensity_fn=logprob_fn,
    247     target_acceptance_rate=target_accept,
    248     **adaptation_kwargs,
    249 )
--> 250 (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
    251 kernel = algorithm(logprob_fn, **tuned_params).step
    253 def _one_step(state, xs):

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py:334, in window_adaptation.<locals>.run(rng_key, position, num_steps)
    332 keys = jax.random.split(rng_key, num_steps)
    333 schedule = build_schedule(num_steps)
--> 334 last_state, info = jax.lax.scan(
    335     one_step_,
    336     (init_state, init_adaptation_state),
    337     (jnp.arange(num_steps), keys, schedule),
    338 )
    339 last_chain_state, last_warmup_state, *_ = last_state
    341 step_size, inverse_mass_matrix = adapt_final(last_warmup_state)

    [... skipping hidden 51 frame]

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/jax/_src/callback.py:457, in io_callback_lowering(ctx, callback, sharding, ordered, *args, **params)
    455   ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
    456 else:
--> 457   result, token, _ = mlir.emit_python_callback(
    458       ctx,
    459       _callback,
    460       None,
    461       list(args),
    462       ctx.avals_in,
    463       ctx.avals_out,
    464       has_side_effect=True,
    465       sharding=op_sharding,
    466   )
    467 return result

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py:2458, in emit_python_callback(ctx, callback, token, operands, operand_avals, result_avals, has_side_effect, sharding, operand_layouts, result_layouts)
   2456 platform = ctx.module_context.platforms[0]
   2457 if platform not in {"cpu", "cuda", "rocm", "tpu"}:
-> 2458   raise ValueError(
   2459       f"`EmitPythonCallback` not supported on {platform} backend.")
   2460 backend = ctx.module_context.backend
   2461 result_shapes = util.flatten(
   2462     [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])

ValueError: `EmitPythonCallback` not supported on METAL backend.

Running with numpyro:

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py:484, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    481 if chains > 1:
    482     map_seed = jax.random.split(map_seed, chains)
--> 484 pmap_numpyro.run(
    485     map_seed,
    486     init_params=initial_points,
    487     extra_fields=(
    488         "num_steps",
    489         "potential_energy",
    490         "energy",
    491         "adapt_state.step_size",
    492         "accept_prob",
    493         "diverging",
    494     ),
    495 )
    497 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    498 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/numpyro/infer/mcmc.py:666, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    664 map_args = (rng_key, init_state, init_params)
    665 if self.num_chains == 1:
--> 666     states_flat, last_state = partial_map_fn(map_args)
    667     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    668 else:

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/numpyro/infer/mcmc.py:462, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    456 collection_size = self._collection_params["collection_size"]
    457 collection_size = (
    458     collection_size
    459     if collection_size is None
    460     else collection_size // self.thinning
    461 )
--> 462 collect_vals = fori_collect(
    463     lower_idx,
    464     upper_idx,
    465     sample_fn,
    466     init_val,
    467     transform=_collect_fn(collect_fields, remove_sites),
    468     progbar=self.progress_bar,
    469     return_last_val=True,
    470     thinning=self.thinning,
    471     collection_size=collection_size,
    472     progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    473     diagnostics_fn=diagnostics,
    474     num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    475 )
    476 states, last_val = collect_vals
    477 # Get first argument of type `HMCState`

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/numpyro/util.py:367, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    365 with tqdm.trange(upper) as t:
    366     for i in t:
--> 367         vals = jit(_body_fn)(i, vals)
    368         t.set_description(progbar_desc(i), refresh=False)
    369         if diagnostics_fn:

    [... skipping hidden 14 frame]

File ~/micromamba/envs/pymc5/lib/python3.11/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
    233   return backend.compile(built_c, compile_options=options,
    234                          host_callbacks=host_callbacks)
    235 # Some backends don't have `host_callbacks` option yet
    236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: INTERNAL: Unable to serialize MPS module

Not sure why it's always trying to fork even when using "vectorized".
