Adding a test for generator support #171
LGTM!
Oh wow, you are actually right: it does in fact still fail when doing a double decorator, with torch tensors. I will actually have to fix this... Thanks for the observation!
Do you mean that it fails for torch tensors specifically, but not for JAX (etc.) arrays? For what it's worth, if it's difficult for you then I'd also be happy to just not support this! It's a pretty uncommon use case and it may complicate our own code quite a lot. (?)
For now I have observed a bunch of weird ways it can fail, and I am not really sure how complex fixing it is going to be, so I will just try for a bit. And, if I give up, I will tell you in a week or so :) E.g. right now one of the possible solutions that I can think of is to just replace annotations for generator outputs with dummy annotations, so that we certainly do not typecheck anything. I reckon it is better in this case to guarantee that we don't typecheck outputs of generators (which is probably the expected behaviour, since doing it is sort of impossible) than to fail in some weird cases. Anyway, I will see what I can do, and if I cannot fix it, then maybe it is not worth it. But I still want to try.
Also guarded torch imports for better compatibility with requirements.txt
So, I implemented a solution that just replaces any output annotations for arbitrarily-deeply nested generators with

There are drawbacks: it might be unexpected for end users that generators are not typechecked, but since beartype silently ignores generator outputs, I guess we can adhere to that and do the same thing :) Another issue is that if the custom typechecker the user employs uses generators for some reason, we might detect this generator and stop typechecking outputs. I find that very unlikely: I simply don't see how one would find any use for generators-as-decorators in this case. But I still wonder, WDYT?

I spent quite some time trying to do something with

Anyway, @patrick-kidger, I wonder what you think. You were right originally that the problem is a bit deep-rooted, but the solution I implemented is extremely simple, though I guess you might be able to find cases where it fails.

P.S. Sorry for the dumb things I wrote in the last comment, I am now officially stupid. I was kinda overwhelmed with the debugging, so I thought what I wrote was sensible, but after a good night's sleep it stopped making any sense at all. Thanks for allowing me to figure out my brain issues by myself!
jaxtyping/_decorator.py
Outdated
while hasattr(wrp, "__wrapped__"):
    wrp = wrp.__wrapped__
    if inspect.isgeneratorfunction(wrp):
        fn.__annotations__["return"] = Any
Generator decorators
Your question about a typechecking decorator using generators for some reason -- I think this could be handled by de-indenting these two lines. Do all the unwrapping, then check for being a generator.
Comments on this approach
That said, there are two things I don't like about this approach:
- We're mutating an object (fn) that we don't own.
- One of the intermediate decorators might have already inspected the annotations and stored a copy somewhere else. (Indeed any sane type-checking decorator will have already done this, as it will want to store a copy of any destringified annotations somewhere. That said, I don't think any of the existing typechecking decorators actually does this..., and I have a horrible feeling that they also trample all over the existing __annotations__ dictionary.)
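A minimal sketch of the second concern. The `snapshotting_checker` decorator is hypothetical (it is not part of jaxtyping or any real typechecker); it only copies the annotations at decoration time, as a checker reasonably might. A later in-place rewrite of `__annotations__` is then invisible to that copy:

```python
from typing import Any, Iterator


def snapshotting_checker(fn):
    # Hypothetical typechecking decorator: stores its own copy of the
    # annotations at the moment it is applied.
    fn._seen_annotations = dict(fn.__annotations__)
    return fn


@snapshotting_checker
def gen(x: int) -> Iterator[int]:
    yield x


# An outer decorator later rewrites the return annotation in place...
gen.__annotations__["return"] = Any

# ...but the inner decorator's copy still holds the original annotation.
stale = gen._seen_annotations["return"]
current = gen.__annotations__["return"]
```

So whether the rewrite has any effect depends on when each decorator reads the annotations, which is exactly what makes the mutation fragile.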
Another possibility -- that I also don't love
Here's one other possibility I'm contemplating, that I think should still be fairly lightweight. At this spot in the code, we have:
wrp = fn
while hasattr(wrp, "__wrapped__"):
    wrp = wrp.__wrapped__
is_generator = inspect.isgeneratorfunction(wrp)

and then where we currently have return fn(*args, **kwargs), we instead have:

out = fn(*args, **kwargs)
if is_generator:
    out = _jaxtyped_generator(out)
return out

where we define:

def _jaxtyped_generator(gen):
    while True:
        try:
            with jaxtyped("context"):
                val = next(gen)
            yield val
        except StopIteration:
            break

Actually, I think the above might need a little more fiddling to be sure that .send-ing values to generators will work, but you get the idea.
This will create a fresh jaxtyped context for every value that is next'd by a user, so we don't get consistency across those... but I think we agree that supporting that is probably much more effort than it's worth.
(Downside of this approach: if for some reason a user has a generatorfunction that, when decorated, is transformed into something that isn't iterable, then the above may misbehave horribly.)
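A rough, self-contained sketch of what the extra ".send" fiddling could look like. `FreshContext` and `wrap_generator` are hypothetical stand-ins (the real per-step context would be the jaxtyped one); the sketch forwards sent values and the generator's return value, but makes no attempt to forward `.throw()` or `.close()`:

```python
entered = []  # records each context entry, for illustration only


class FreshContext:
    # Hypothetical stand-in for the per-step typechecking context.
    def __enter__(self):
        entered.append(len(entered))
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # never swallow exceptions


def wrap_generator(gen):
    # Drive `gen` one step at a time, opening a new context per step and
    # forwarding whatever the caller passes in via .send().
    sent = None
    while True:
        try:
            with FreshContext():
                val = gen.send(sent)  # send(None) behaves like next()
        except StopIteration as stop:
            return stop.value
        sent = yield val


def echo():
    received = yield "ready"
    while received is not None:
        received = yield received * 2
    return "done"


wrapped = wrap_generator(echo())
first = next(wrapped)
second = wrapped.send(21)
```

Each step of the inner generator runs inside its own context, so nothing is shared between steps, matching the trade-off described above.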
How much do we want to support?
That said, I'm definitely not sold on my suggestion above. I don't really want to support the full gamut of generators + simple coroutines + generator coroutines + async def coroutines + contextlib.contextmanager-wrapped functions + whatever else might be cooked up that I've forgotten.
(Handy reference for the first three.)
Python overloads functions to mean many different things, so it might be pretty difficult to support all of them.
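To make the variety concrete, here is a small sketch that classifies a few of these callables with the standard `inspect` helpers. The `describe` function and the example functions are made up for illustration. Note that a `contextlib.contextmanager`-wrapped function only shows up as a generator function after following `__wrapped__`:

```python
import contextlib
import inspect


def plain_generator():
    yield 1


async def coroutine():
    return 1


async def async_generator():
    yield 1


@contextlib.contextmanager
def managed():
    yield 1


def describe(fn):
    # Follow the __wrapped__ chain set by functools.wraps, then classify.
    inner = inspect.unwrap(fn)
    if inspect.isasyncgenfunction(inner):
        return "async generator"
    if inspect.iscoroutinefunction(inner):
        return "coroutine"
    if inspect.isgeneratorfunction(inner):
        return "generator"
    return "plain function"


kinds = {
    f.__name__: describe(f)
    for f in (plain_generator, coroutine, async_generator, managed)
}
# Without unwrapping, the decorated context manager looks like a plain function.
managed_looks_like_generator = inspect.isgeneratorfunction(managed)
```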
WDYT? I'm honestly not sure of the best approach here.
Well, considering the two ways, I certainly prefer the first one. The reason is simple: the second approach will probably lead to a lot of not-really-well-supported code that might fail in unexpected ways. Supporting all possible ways Python can do generators is probably not feasible; there is simply too much technical debt. Anyway, now I also see why you dislike my approach (aka the dumb approach). But there is a modification that might be nice:
Modifying naive approach
About the actual solution: I just simply added a private property to the
Possible problems:
This is better than the original one, at least. But maybe it is still bad enough that you feel like not supporting generators at all is a better idea. I guess I am not that much just giving up, since it is the 'old' syntax, so we are officially allowed to just redirect users to the new syntax. Whereas if they have reasons to use the old syntax... Well, anyway, waiting for your opinion. The latest version of the code in the branch is approximately what I would have written: I still need to add support/check that it works for async generators, though I want to make sure that you are fine/not fine with the general approach before I do so.
I think I like the current approach! It's not perfect for the reasons we've each highlighted, but I think it might also just be as good as it gets. I think my only minor nits would be to (a) prefer
I've just updated the target branch to
Yeah, sure. Sorry it's taking so long
While writing more tests and modifying the code according to your recommendations I discovered that I have actually introduced a bug in the last version: the

def test_generators_dont_modify_same_annotations(typecheck):
    @jaxtyped(typechecker=None)
    @typecheck
    def g(x: Float[Array, "1"]) -> Iterator[Float[Array, "1"]]:
        yield x

    @jaxtyped(typechecker=typecheck)
    def m(x: Float[Array, "1"]) -> Iterator[Float[Array, "1"]]:
        return x

    with pytest.raises(ParamError):  # fails (does not raise)
        next(g(jnp.zeros(2)))

    with pytest.raises(ParamError):  # fails too
        m(jnp.zeros(2))

since all the annotations

So, I rewrote the thing to the old solution, with some slight modifications: now, we don't replace the whole annotation with

Also added

Anyway, WDYT? I wonder if you might have come up with a nicer solution.
BTW, now the code should be ready for merging, so if you like everything - feel free to merge :)
Oh bother, that's a good observation. Yeah, that approach doesn't work then.

My concern with this approach is that we're back to mutating objects we don't own. E.g.

I think I can see a way to revive the previous approach. The reason it didn't work is because we have a cache on

I think all we need to do is to move this part of jaxtyping/jaxtyping/_array_types.py (lines 541 to 555 in 8de8c0b) to happen outside the cache. That is, we should cache the tuple of

This way we will get a fresh jaxtyping object on every

WDYT?
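A generic sketch of the idea of moving object construction outside a cache. This is an assumption-laden illustration, not jaxtyping's actual code: `parse_dims` and `make_annotation` are hypothetical names, and the parsing is a placeholder. The expensive, immutable part is cached, while a new class is built on every call, so a flag set on one annotation cannot leak onto another:

```python
import functools


@functools.lru_cache(maxsize=None)
def parse_dims(dim_str):
    # Pure parsing step: safe to cache because the result is an immutable tuple.
    return tuple(dim_str.split())


def make_annotation(dim_str):
    # Construction happens outside the cache: a brand-new class per call.
    dims = parse_dims(dim_str)
    return type("ExampleArray", (), {"dims": dims})


first = make_annotation("batch channels")
second = make_annotation("batch channels")

# Marking one annotation does not affect the other, even though both
# share the cached parse result.
first._skip_instancecheck = True
```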
That is a nice suggestion. I don't think I would have come up with this under any circumstances; it requires either better debugging skills or better codebase knowledge than I possess. Thanks!

Besides, I figured out there was a bug that was hidden thanks to caching: the equality between AbstractMetaArrays was not implemented (correctly), so things like

Well, I just made all the Dim* variations frozen dataclasses, which automatically implement the required
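For reference, a small sketch of what a frozen dataclass provides out of the box. `ExampleDim` is a made-up stand-in for the Dim* classes and its fields are assumptions; the point is only that `frozen=True` (with the default `eq=True`) yields field-based `__eq__` and `__hash__`:

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class ExampleDim:
    # Hypothetical fields, chosen only for illustration.
    name: str
    broadcastable: bool


a = ExampleDim("batch", False)
b = ExampleDim("batch", False)

# Distinct objects that compare equal and hash equal, so they collapse
# to a single entry in sets and dict keys.
same_value = a == b
unique = {a, b}

# Instances are immutable: attribute assignment raises.
try:
    a.name = "other"
    mutated = True
except FrozenInstanceError:
    mutated = False
```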
Alright, marvellous stuff!
This took us a lot of brainpower -- but I think this is a solution that should work. :D
Usual nits below, then let's merge this.
jaxtyping/_decorator.py
Outdated
# annotations as not needing instance checks, while still being
# visible as original ones for the typechecker
def modify_annotation(ann):
    if isinstance(ann, _MetaAbstractArray):
Let's make this issubclass(ann, AbstractArray).
(Notice that we had to import an underscored object to get this. The convention I use throughout the Equinox+JAX ecosystem is that such objects are private to the file they're defined in.)
jaxtyping/_decorator.py
Outdated
# visible as original ones for the typechecker
def modify_annotation(ann):
    if isinstance(ann, _MetaAbstractArray):
        setattr(ann, "_skip_instancecheck", True)
ann._skip_instancecheck = True ?
jaxtyping/_decorator.py
Outdated
elif hasattr(ann, "__args__"):
    for sub_ann in get_args(ann):
        modify_annotation(sub_ann)
I think the hasattr check can be removed: get_args will always return a tuple. (If necessary an empty tuple.)
jaxtyping/_decorator.py
Outdated
elif hasattr(ann, "__origin__"):
    modify_annotation(get_origin(ann))
In contrast this one should probably technically be

origin = get_origin(ann)
if origin is not None:
    modify_annotation(origin)

(IIRC __args__ and __origin__ are actually implementation details that we're not really supposed to touch.)
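A short sketch of the behaviour these two review comments rely on, using only the public `typing.get_args` and `typing.get_origin` helpers. The `leaf_annotations` walker is a hypothetical example, not the function from the PR:

```python
from typing import Iterator, List, Tuple, get_args, get_origin


def leaf_annotations(ann):
    # Recursively collect the innermost annotations. No hasattr check is
    # needed: get_args returns an empty tuple for non-generic annotations.
    args = get_args(ann)
    if not args:
        return [ann]
    leaves = []
    for sub_ann in args:
        leaves.extend(leaf_annotations(sub_ann))
    return leaves


nested = Iterator[Tuple[int, List[str]]]
leaves = leaf_annotations(nested)

# get_origin returns None for a plain class, so it must be checked
# before being passed on.
plain_origin = get_origin(int)
list_origin = get_origin(List[int])
plain_args = get_args(int)
```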
Simplified the

WDYT?
Alright, great stuff! Just merged. Thank you as always for your efforts -- and I'm going to do a new jaxtyping release shortly :)
* Add a test for generators
* Remove output annotations from decorators
  Also guarded torch imports for better compatibility with requirements.txt
* Add flag to the main meta class to skip the typecheck
* Return to the old solution
* Make async tests work
* Minor adjustments/fixing typos
* Correct Python path for new tests
* Remove some jax-dependent code
* Implement equality for MetaArrays
* Make all Dim variations frozen dataclasses
* Shorten AbstractArray methods
* Final touches
* Removing get_origin use
* Update tests with @jaxtyp
Hey Patrick,
I was just looking through the issues, and discovered that there is #91, which states that generators do not work. I just checked, and seemingly they work on the current version, but just in case I added a simple test to make sure that typechecking generators will continue to work in the future. It is not the original test, but I guess it is a close enough imitation that does not involve torch.