Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom concatenate/stack functions #8

Merged
merged 3 commits into from
Dec 15, 2022

Conversation

calbach
Copy link
Contributor

@calbach calbach commented Dec 14, 2022

This is to support stacking / concatenating arrays within a sympy expression.

For discussion:

  • Does this belong in sympy2jax, or would it be better if we specified this as a custom_func (as a client)?

Copy link
Contributor Author

@calbach calbach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -34,7 +37,16 @@ def fn_(*args):
return fn_


def _args_as_array(fn):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "works", but looking for other ideas as it means clients cannot specify other args to stack/concatenate. Encapsulating the inputs as an array in sympy would require some more type handling there. Let me know if you have other ideas here.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me why this is necessary. Can't we just use stack/concatenate directly?

Copy link
Contributor Author

@calbach calbach Dec 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These interfaces take a sequence of arrays as the first argument. jnp.stack, for example: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html

jax.numpy.stack(arrays, axis=0, out=None, dtype=None)

IIUC to use this directly, I'd need to either already have an array, or to encode my input as an array/tuple/sequence within sympy - which I believe is not something that is supported in sympy2jax currently. e.g. stack(Array([x, y , z]))

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They accept arraylikes, not just arrays (i.e. bool/int/float/complex as well.) E.g. jnp.stack([1, 2]) works just fine.

That aside, most of the time the input should already be an array anyway (from earlier operations).

Note that what you're doing only works at the moment because jnp.array is doing the stack/concenate for you. A stack or concatenate op is the identity function on a single argument.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code here is not about changing the type to Array (that was just an implementation choice), it's for putting the arguments into the right position.

The 0th argument to jnp.stack needs to be an arraylike of values to concatenate. AFAIK I have no way of constructing a sequence like this in sympy2jax compatible sympy (let me know if that's incorrect and please share some example sympy that would achieve this).

In our case we have several state variables we're trying to stitch together and would like to emit as an array, e.g. state_x, state_y

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you're also talking about just combining the args. Okay:

def _single_args(fn):
  def _fn(*args):
    return fn(args)
  return _fn

_single_args(jnp.stack)

?

Indeed I don't think passing the other arguments to stack is important. (Some of them are holdovers from numpy and aren't implemented in JAX anyway.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Sorry - we were talking past eachother a bit. My original comment was lamenting the usage of the wrapper function in general (for combining the args).

You're right that the call to jnp.array was not needed - I wasn't really even looking at this. Removed.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than the one question, this LGTM!

I think it makes sense to include in sympy2jax so that folks don't need to think about how to implement this. But it could easily be done by a client, it's true.

@patrick-kidger
Copy link
Owner

Looks like the pre-commit checks failed.

FYI you can arrange to have these run automatically whenever you commit. They will either pass and commit, or fail and autoformat.

If the latter, you can then check you're happy with the autoformat, followed by git add -u; git commit to add the changes and commit for real.

(See CONTRIBUTING.MD.)

@calbach
Copy link
Contributor Author

calbach commented Dec 15, 2022

Fixed, I missed this repo was using isort. I'll try out the precommit hook next time.

@patrick-kidger patrick-kidger merged commit ccb4342 into patrick-kidger:main Dec 15, 2022
@patrick-kidger
Copy link
Owner

Merged! Ty.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants