Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolic expressions now support delayed binding to arguments. Fixes #93. #140

Merged
merged 9 commits into from
Nov 27, 2023

Conversation

patrick-kidger
Copy link
Owner

No description provided.

- A symbolic expression cannot be evaluated unless all of the axes sizes it refers to have already been processed. In practice this usually means that they should only be used in annotations for the return type, and only use axes declared in the arguments.
- Symbolic expressions are evaluated in two stages: they are first evaluated as f-strings using the arguments of the function, and second are evaluated using the processed axis sizes. The f-string evaluation means that they can use local variables by enclosing them with curly braces, e.g. `{variable}`, e.g.
```python
def full(size: int, fill: float) -> Float[Array, "{shape}"]:
Copy link
Sponsor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure this should be {size}?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed! I'll fix it now.
This does make me think, though -- handling a general shape is a little annoying with this change:

def full(shape: tuple[int, ...], fill: float) -> Float[Array, "{strshape}"]:
    strshape = " ".join(map(str, shape))
    return jnp.full(shape, fill)

I don't have a good answer for that right now.

Copy link
Sponsor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i thought about this for a minute and decided it was tough. maybe just a helper function? jaxtyping.strshape that just does the first line? tuple shapes are pretty common!

@danielward27
Copy link

Not sure if this is intended, but after reading the updated docs it isn't obvious to me why one of these would work, and the other would not

import equinox as eqx
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import ArrayLike, Shaped, jaxtyped


class SomeClass1(eqx.Module):
    dim: int

    @jaxtyped
    @typechecker
    def __call__(self, arr: Shaped[ArrayLike, "2 {self.dim}"]):
        return arr


class SomeClass2(eqx.Module):
    dim: int

    @jaxtyped
    @typechecker
    def __call__(self, arr: Shaped[ArrayLike, "{self._strshape}"]):
        return arr

    @property
    def _strshape(self):
        return f"2 {self.dim}"


dim = 5
x = jnp.ones((2, dim))
my_class1 = SomeClass1(dim)
my_class2 = SomeClass2(dim)


my_class1(x)  # This works
my_class2(x)  # This raises BeartypeCallHintParamViolation

@patrick-kidger
Copy link
Owner Author

Ah, that's an interesting edge-case. The reason is that delayed bindings of the form "{self.foo}" can only be used for a single name (whether that's a single axis e.g. batch or variadic axes e.g. *batch). Here you're trying to expand to two names: 2 5.

I suspect that's probably a limitation that can be lifted though! I'll see what I can do.

@patrick-kidger
Copy link
Owner Author

Having now looked at this a bit, I don't think this is something that can easily be arranged. Right now the code for separating out each axis runs when the annotation is first created. Meanwhile, the code for these kinds of f-strings is ran during the isinstance check.

So in order to support use-cases like the above, we would need to delay running the "axis handling" code until after the f-string handling code. This would (a) represent a small performance drop for unJIT'd code, and (b) be a fairly fiddly rewrite to arrange.

As such I think this probably won't be something supported, at least in the near future.

@danielward27
Copy link

I see, that seems reasonable. Thanks for having a look anyway! I guess this means objects with tuple[int, ...] shape attributes can't use their shapes in the parameter annotations too (unless you know the length of the shape beforehand, in which case you could manually splat it out).

@patrick-kidger patrick-kidger merged commit 64a5bbb into dim2axis Nov 27, 2023
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants