-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Symbolic expressions now support delayed binding to arguments. Fixes #93. #140
Conversation
65e5e98
to
a39f01f
Compare
6f15a41
to
6c127c2
Compare
a39f01f
to
61ddd77
Compare
f28fbe6
to
d00cd5b
Compare
6431fa6
to
0cfdf31
Compare
d00cd5b
to
a28d631
Compare
docs/api/array.md
Outdated
- A symbolic expression cannot be evaluated unless all of the axes sizes it refers to have already been processed. In practice this usually means that they should only be used in annotations for the return type, and only use axes declared in the arguments. | ||
- Symbolic expressions are evaluated in two stages: they are first evaluated as f-strings using the arguments of the function, and second are evaluated using the processed axis sizes. The f-string evaluation means that they can use local variables by enclosing them with curly braces, e.g. `{variable}`, e.g. | ||
```python | ||
def full(size: int, fill: float) -> Float[Array, "{shape}"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pretty sure this should be {size}
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed! I'll fix it now.
This does make me think, though -- handling a general shape
is a little annoying with this change:
def full(shape: tuple[int, ...], fill: float) -> Float[Array, "{strshape}"]:
strshape = " ".join(map(str, shape))
return jnp.full(shape, fill)
I don't have a good answer for that right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah i thought about this for a minute and decided it was tough. maybe just a helper function? jaxtyping.strshape
that just does the first line? tuple shapes are pretty common!
a28d631
to
b4140ef
Compare
0cfdf31
to
6f94f0e
Compare
Not sure if this is intended, but after reading the updated docs it isn't obvious to me why one of these would work, and the other would not import equinox as eqx
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import ArrayLike, Shaped, jaxtyped
class SomeClass1(eqx.Module):
dim: int
@jaxtyped
@typechecker
def __call__(self, arr: Shaped[ArrayLike, "2 {self.dim}"]):
return arr
class SomeClass2(eqx.Module):
dim: int
@jaxtyped
@typechecker
def __call__(self, arr: Shaped[ArrayLike, "{self._strshape}"]):
return arr
@property
def _strshape(self):
return f"2 {self.dim}"
dim = 5
x = jnp.ones((2, dim))
my_class1 = SomeClass1(dim)
my_class2 = SomeClass2(dim)
my_class1(x) # This works
my_class2(x) # This raises BeartypeCallHintParamViolation |
Ah, that's an interesting edge-case. The reason is that delayed bindings of the form I suspect that's probably a limitation that can be lifted though! I'll see what I can do. |
Having now looked at this a bit, I don't think this is something that can easily be arranged. Right now the code for separating out each axis runs when the annotation is first created. Meanwhile, the code for these kinds of f-strings is ran during the So in order to support use-cases like the above, we would need to delay running the "axis handling" code until after the f-string handling code. This would (a) represent a small performance drop for unJIT'd code, and (b) be a fairly fiddly rewrite to arrange. As such I think this probably won't be something supported, at least in the near future. |
I see, that seems reasonable. Thanks for having a look anyway! I guess this means objects with |
b4140ef
to
0f34fba
Compare
6f94f0e
to
59d732e
Compare
0f34fba
to
4815860
Compare
59d732e
to
7447383
Compare
…his avoids edge-case crash when using pytree-path dependent sizes
…no longer include Float[np.bool, ...].
… from args and kwargs to just arguments.
No description provided.