
Script not compatible with version of jax > 0.2.8 and not executable #3

Open
BMP-TUD opened this issue Mar 9, 2023 · 0 comments


BMP-TUD commented Mar 9, 2023

Dear Sirs,

Thanks for creating the symder package; I read your publication with great interest. However, when I tried to reproduce your study, I downloaded the git repository. Your README states that the code should work with JAX >= 0.2.8. After a rather painful process of trial and error, I found that the code only loads all packages with exactly the package versions you describe, in particular for jax, jaxlib, and optax, and that it is no longer compatible with any NumPy version newer than yours (1.19.2). I would therefore suggest stating this explicitly in the repository README. It is also important to mention that these versions are only compatible with Python up to 3.9.x, not above!
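Until the README is updated, a small pre-flight check could catch the two constraints above before any script is run. The sketch below is hypothetical (the function names are mine, not part of symder) and checks only what is stated here: Python no newer than 3.9 and NumPy no newer than 1.19.2. The exact jax, jaxlib, and optax pins from the README are not repeated in this issue, so the sketch leaves them out.

```python
import re
import sys

# Upper bounds taken from this issue: Python up to 3.9.x and NumPy no newer
# than 1.19.2. The jax/jaxlib/optax pins are deliberately not checked here.
MAX_PYTHON = (3, 9)
MAX_NUMPY = (1, 19, 2)


def parse_version(text: str) -> tuple:
    """Turn '1.19.2' into (1, 19, 2); suffixes like 'rc1' are dropped."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_environment(python_version, numpy_version: str) -> list:
    """Return a list of problems; an empty list means no known conflict."""
    problems = []
    major, minor = python_version[0], python_version[1]
    if (major, minor) > MAX_PYTHON:
        problems.append(f"Python {major}.{minor} is newer than 3.9")
    if parse_version(numpy_version) > MAX_NUMPY:
        problems.append(f"NumPy {numpy_version} is newer than 1.19.2")
    return problems


if __name__ == "__main__":
    try:
        import numpy
    except ImportError:
        print("NumPy is not installed")
    else:
        for problem in check_environment(sys.version_info, numpy.__version__):
            print("WARNING:", problem)
```

Running it inside the conda environment before `rossler_model.py` would print one warning per violated constraint and nothing otherwise.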

So much for the compatibility issues I had. Another problem I am currently running into occurs when I execute rossler_model.py. When I run it, the script reports that a dataset has been generated; however, when the get_model function is executed, I get the following error:

```
Loading dataset from file: ./data/rossler.npz
Traceback (most recent call last):

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/symder_test/symder/rossler_model.py:238
    model_apply, model_init, model_args = get_model(

  File ~/symder_test/symder/rossler_model.py:49 in get_model
    scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden)))

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:2897 in ones
    return lax.full(shape, 1, dtype)

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/_src/lax/lax.py:1452 in full
    return broadcast(fill_value, shape)

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/_src/lax/lax.py:688 in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/_src/lax/lax.py:701 in broadcast_in_dim
    return broadcast_in_dim_p.bind(

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/core.py:271 in bind
    out = top_trace.process_primitive(self, tracers, params)

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/core.py:595 in process_primitive
    return primitive.impl(*tracers, **params)

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/jax/_src/lax/lax.py:3257 in _broadcast_in_dim_impl
    if xla.type_is_device_array(operand) and np.all(

  File <__array_function__ internals>:200 in all

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/numpy/core/fromnumeric.py:2515 in all
    return _wrapreduction(a, np.logical_and, 'all', axis, None, out,

  File ~/miniconda3/envs/symder/lib/python3.9/site-packages/numpy/core/fromnumeric.py:86 in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

TypeError: an integer is required (got type _NoValueType)
```
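The traceback passes through JAX and NumPy internals only after `jnp.ones(num_hidden)` is called in `get_model`. To check whether the failure comes from the JAX/NumPy combination rather than from symder itself, that call can be run on its own in the same environment. The sketch below does this and also reports the installed versions; the helper names are my own, and it assumes only that the packages are installed under their usual PyPI names.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names=("numpy", "jax", "jaxlib", "optax")) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def try_jnp_ones(n: int = 3):
    """Call jnp.ones(n) outside symder.

    Returns None on success, otherwise the exception raised: ImportError if
    JAX cannot be imported, TypeError for a failure like the one above.
    """
    try:
        import jax.numpy as jnp
    except ImportError as exc:
        return exc
    try:
        jnp.ones(n)
    except TypeError as exc:
        return exc
    return None


if __name__ == "__main__":
    print(installed_versions())
    error = try_jnp_ones(3)
    print("jnp.ones(3) OK" if error is None else f"jnp.ones(3) failed: {error!r}")
```

If `jnp.ones(3)` fails with the same `TypeError` here, the problem lies in the installed package versions rather than in `rossler_model.py`.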

I am not really sure why this error is raised; maybe you could help me out?
Thanks in advance for your help,

Best wishes,
Bartosz
