Adding a test for generator support #171

Merged
merged 17 commits into from
Feb 20, 2024

Conversation

knyazer
Contributor

@knyazer knyazer commented Feb 12, 2024

Hey Patrick,

I was just looking through the issues and discovered #91, which states that generators do not work. I just checked, and they seem to work on the current version, but just in case I added a simple test to make sure that typechecking generators will continue to work in the future. It is not the original test, but I guess it is a close enough imitation that does not involve torch.

@patrick-kidger
Owner

LGTM!
I'm guessing this might still fail for the double-decorator @jaxtyped(typechecker=None) @beartype def foo?

@knyazer
Contributor Author

knyazer commented Feb 12, 2024

LGTM! I'm guessing this might still fail for the double-decorator @jaxtyped(typechecker=None) @beartype def foo?

Oh wow, you are actually right: it does in fact still fail when doing a double decorator, with torch tensors. I will actually have to fix this... Thanks for the observation!

@knyazer knyazer marked this pull request as draft February 12, 2024 18:29
@patrick-kidger
Owner

Do you mean that it fails for torch tensors specifically, but not for JAX etc arrays?

For what it's worth, if it's difficult for you then I'd also be happy to just not support this! It's a pretty uncommon use case and it may complicate our own code quite a lot. (?)

@knyazer
Contributor Author

knyazer commented Feb 12, 2024

Do you mean that it fails for torch tensors specifically, but not for JAX etc arrays?

For what it's worth, if it's difficult for you then I'd also be happy to just not support this! It's a pretty uncommon use case and it may complicate our own code quite a lot. (?)

For now I observed a bunch of weird ways it can fail, and I am not really sure how complex fixing it is going to be, so I will just try for a bit. And, if I give up, I will tell you in a week or so :)

e.g. right now one of the possible solutions that I can think of is to just replace annotations for generator outputs with dummy annotations, so that we just certainly do not typecheck anything. I reckon it is better in this case to guarantee that we don't typecheck outputs of generators (which is probably an expected behaviour, since doing it is sort of impossible), than to fail in some weird cases.

Anyway, I will see what I can do, and if I cannot fix it, then maybe it is not worth it. But I still want to try.

@knyazer knyazer force-pushed the test-generators branch 3 times, most recently from e0a8556 to df93af6 Compare February 13, 2024 23:36
@knyazer

This comment was marked as outdated.

@knyazer knyazer marked this pull request as ready for review February 14, 2024 00:00
@knyazer knyazer marked this pull request as draft February 14, 2024 09:22
Also guarded torch imports for better compatibility with
requirements.txt
@knyazer knyazer marked this pull request as ready for review February 14, 2024 11:20
@knyazer
Contributor Author

knyazer commented Feb 14, 2024

So, I implemented a solution that just replaces any output annotations for arbitrarily-deeply nested generators with Any. Plus, I added import guards for torch: it is not in requirements.txt, so probably this makes sense.

There are drawbacks: it might be unexpected for end users that generators are not typechecked, but since beartype itself silently ignores generator outputs, I guess we can adhere to that and do the same thing :)

Another issue is that if the custom typechecker the user employs uses generators for some reason, we might detect this generator and stop typechecking outputs. I find it very unlikely, though: I simply don't see how one would find any use for generators-as-decorators in this case. But I still wonder, WDYT?

I spent quite some time trying to do something with _shape_storage, like cleaning up the variadic memos after calls to generator functions. But the naive solution I implemented, while probably having more edge cases, is so much simpler that I don't think implementing additional memo cleanups makes sense: additional complexity is bad!

Anyway, @patrick-kidger, I wonder what do you think? You were right originally that the problem is a bit deep-rooted, but the solution I implemented is extremely simple, though I guess you might be able to find cases when it fails.


P.S. Sorry for the dumb things I wrote in the last comment, I am now officially stupid. I was kinda overwhelmed with the debugging, so I thought what I wrote was sensible, but after a good night's sleep it stopped making any sense at all. Thanks for allowing me to figure out my brain issues by myself!

while hasattr(wrp, "__wrapped__"):
    wrp = wrp.__wrapped__
    if inspect.isgeneratorfunction(wrp):
        fn.__annotations__["return"] = Any
Owner


Generator decorators

Your question about a typechecking decorator using generators for some reason -- I think this could be handled by de-indenting these two lines. Do all the unwrapping, then check for being a generator.
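To illustrate the "unwrap first, then check" ordering, here is a small sketch using only the standard library. The `passthrough` decorator and the two example functions are hypothetical and exist only for the demonstration; the point is that only the innermost function decides whether we are dealing with a generator.

```python
import functools
import inspect


def passthrough(fn):
    # Hypothetical decorator: functools.wraps sets wrapper.__wrapped__ = fn.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def is_generator_underneath(fn):
    # Do all of the unwrapping first...
    wrp = fn
    while hasattr(wrp, "__wrapped__"):
        wrp = wrp.__wrapped__
    # ...and only then ask whether the innermost function is a generator.
    return inspect.isgeneratorfunction(wrp)


@passthrough
@passthrough
def counter(n):
    for i in range(n):
        yield i


@passthrough
def plain(n):
    return n


print(is_generator_underneath(counter))  # True
print(is_generator_underneath(plain))    # False
```

The standard library's `inspect.unwrap` follows the same `__wrapped__` chain, so it could probably replace the manual loop.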

Comments on this approach

That said, there are two things I don't like about this approach:

  1. We're mutating an object (fn) that we don't own.
  2. One of the intermediate decorators might have already inspected the annotations and stored a copy somewhere else. (Indeed any sane type-checking decorator will have already done this, as it will want to store a copy of any destringified annotations somewhere. That said, I don't think any of the existing typechecking decorators actually does this..., and I have a horrible feeling that they also trample all over the existing __annotations__ dictionary.)

Another possibility -- that I also don't love

Here's one other possibility I'm contemplating, that I think should still be fairly lightweight. At this spot in the code, we have:

wrp = fn
while hasattr(wrp, "__wrapped__"):
    wrp = wrp.__wrapped__
is_generator = inspect.isgeneratorfunction(wrp)

and then where we currently have return fn(*args, **kwargs), we instead have:

out = fn(*args, **kwargs)
if is_generator:
    out = _jaxtyped_generator(out)
return out

where we define:

def _jaxtyped_generator(gen):
    while True:
        try:
            with jaxtyped("context"):
                val = next(gen)
            yield val
        except StopIteration:
            break

Actually, I think the above might need a little more fiddling to be sure that .send-ing values to generators will work, but you get the idea.

This will create a fresh jaxtyped context for every value that is next'd by a user, so we don't get consistency across those... but I think we agree that supporting that is probably much more effort than it's worth.

(Downside of this approach: if for some reason a user has a generatorfunction that, when decorated, is transformed into something that isn't iterable, then the above may misbehave horribly.)
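For what the "little more fiddling" for `.send` might look like, here is a rough sketch. It is not the code from this PR: `StepContext` is a hypothetical stand-in for the per-step `jaxtyped("context")` block, and the wrapper only forwards `send` and the return value, not `throw` or `close`.

```python
class StepContext:
    """Hypothetical stand-in for a fresh per-step checking context."""

    entered = 0  # counts how many times a context was opened

    def __enter__(self):
        StepContext.entered += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # never swallow exceptions


def wrap_generator(gen):
    """Drive `gen`, opening a fresh context around every resumption."""
    to_send = None
    while True:
        try:
            with StepContext():
                # gen.send(None) behaves like next(gen), including on the first step.
                val = gen.send(to_send)
        except StopIteration as stop:
            # Preserve the wrapped generator's return value.
            return stop.value
        # Hand the value to the caller and capture whatever they send back.
        to_send = yield val


def echo():
    received = yield "ready"
    while received != "stop":
        received = yield f"got {received}"
    return "done"


gen = wrap_generator(echo())
print(next(gen))      # ready
print(gen.send("a"))  # got a
```

Each resumption still gets its own context, so there is no consistency across yielded values, matching the trade-off described above.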

How much do we want to support?

That said, I'm definitely not sold on my suggestion above. I don't really want to support the full gamut of generators + simple coroutines + generator coroutines + async def coroutines + contextlib.contextmanager-wrapped functions + whatever else might be cooked up that I've forgotten.
(Handy reference for the first three.)

Python overloads functions to mean many different things, so it might be pretty difficult to support all of them.

WDYT? I'm honestly not sure of the best approach here.

@knyazer
Contributor Author

knyazer commented Feb 16, 2024

Well, considering the two ways, I certainly prefer the first one. The reason is simple: the second approach is probably going to lead to a lot of not-really-well-supported code that might fail in unexpected ways. Supporting all the possible ways Python can do generators is probably not feasible; there is simply too much technical debt.

Anyway, now I also see why you dislike my approach (aka dumb approach). But there is a modification that might be nice:

Modifying naive approach

  1. Well, certainly we want to avoid modifying the fn, this is sort of a big deal: we might accidentally break something in the hundreds of decorators nested by the user. So, the idea is quite simple: just modify only the jaxtyping things, since breaking jaxtyping is probably alright because it is going to be our problem. So, parse the return annotation of the generator, and somehow mark all the jaxtyping-type annotations.
  2. Still a problem.

About the actual solution: I just simply added a private property to the _MetaAbstractArray: I think this is the simplest way, but I also think that this is the ugliest way. Adding state to types sounds like an extremely bad idea in general :) But all the alternatives involve having an 'implicit' state, which is even worse.

Possible problems:

  1. The main problem is users doing some fancy-type-redefinitions, which are going to hide the original jaxtyping type from the annotation, while effectively being a jaxtyping type.
  2. Adding a state to the types is clearly a bad idea.
  3. Other decorators storing a copy of our annotations is also a problem, though when I think about it, I reckon the authors of both beartype and typeguard might have consciously made this decision. Simply because if we assume there are more annotation-processing decorators, we probably want to include them into consideration.

This is better than the original one, at least. But maybe it is still bad enough that you feel like not supporting generators at all is a better idea. I guess I would not mind that much just giving up, since it is the 'old' syntax, so we are officially allowed to just redirect users to the new syntax. Whereas if they have reasons to use the old syntax... Well, anyway, waiting for your opinion. The latest version of the code in the branch is approximately what I would have written: I still need to add support for async generators (or check that it already works), though I want to make sure that you are fine/not fine with the general approach before I do so.

@patrick-kidger
Owner

I think I like the current approach! It's not perfect for the reasons we've each highlighted, but I think it might also just be as good as it gets. I think my only minor nits would be to (a) prefer typing.get_{args, origin} over accessing the attributes directly, and (b) prefer issubclass(..., AbstractArray) over isinstance(..., _MetaAbstractArray).

@patrick-kidger patrick-kidger changed the base branch from main to dev February 17, 2024 01:36
@patrick-kidger
Owner

I've just updated the target branch to dev. I think the conflicts are trivial/minor.
Let me know once this is done and we'll merge + do a new release.

@knyazer
Contributor Author

knyazer commented Feb 17, 2024

Yeah, sure. Sorry it's taking so long

@knyazer
Contributor Author

knyazer commented Feb 17, 2024

While writing more tests and modifying the code according to your recommendations, I discovered that I had actually introduced a bug in the last version: the _skip_instancecheck field was a static field of the class (not of an instance, but of the actual class representing the type annotation). Thus, a test like this was failing:

def test_generators_dont_modify_same_annotations(typecheck):
    @jaxtyped(typechecker=None)
    @typecheck
    def g(x: Float[Array, "1"]) -> Iterator[Float[Array, "1"]]:
        yield x

    @jaxtyped(typechecker=typecheck)
    def m(x: Float[Array, "1"]) -> Iterator[Float[Array, "1"]]:
        return x

    with pytest.raises(ParamError): # fails (does not raise)
        next(g(jnp.zeros(2)))
    with pytest.raises(ParamError): # fails too
        m(jnp.zeros(2))

since all the annotations Float[Array, '1'] were marked as the ones to skip after the decorator was encountered. Very bad.

So, I went back to the old solution, with some slight modifications: now, we don't replace the whole annotation with typing.Any, but only the jaxtyping-subclassed annotations inside it. So, for example, typing.Iterator[Float[Array, '1']] becomes typing.Iterator[typing.Any]. To do that I used some very unstable tricks, and I was even surprised they worked at all, but it is probably better than ignoring this issue: I would guess we want to try as hard as possible to ensure we have no false positives, since they are the main reason (at least in my case) for people to stop using jaxtyping. Still, it is not perfect, and maybe I have introduced some even worse bugs with this fix.

Also added the pytest-asyncio module to test async generators; the tests pass even on the old version, but it is still better to have them, I guess.

Anyway, WDYT? I wonder if you might have come up with a nicer solution.

@knyazer
Contributor Author

knyazer commented Feb 17, 2024

BTW, now the code should be ready for merging, so if you like everything - feel free to merge :)

@patrick-kidger
Owner

Oh bother, that's a good observation. Yeah, that approach doesn't work then.

My concern with this approach is that we're back to mutating objects we don't own. E.g. setattr(ann, "__args__", tuple(new_args)) is actually assigning a brand-new tuple tuple(new_args), which will be a distinct Python object from the one we started with. If I've learnt anything doing jaxtyping, it's that Python's type annotations are very fragile and this kind of thing is going to break someone, somewhere, somehow...

I think I can see a way to revive the previous approach. The reason it didn't work is because we have a cache on _make_array (for performance reasons, so that we don't need to reperform the shape-string-parsing when having many Float[Array, "foo"]). This means that we're modifying the same jaxtyping object when setting our don't-actually-check flag.
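The aliasing problem can be reproduced in miniature with the standard library alone. This is only a sketch: `make_annotation` is a hypothetical stand-in for a cached annotation factory, not the real `_make_array`.

```python
import functools


@functools.lru_cache(maxsize=None)
def make_annotation(dim_str):
    # Hypothetical cached factory: identical arguments always return
    # the very same class object.
    return type(f"Float[{dim_str}]", (), {"_skip_instancecheck": False})


ann_in_g = make_annotation("1")  # e.g. the return annotation of g
ann_in_m = make_annotation("1")  # e.g. the parameter annotation of m

# Marking the annotation for one function...
ann_in_g._skip_instancecheck = True

# ...silently marks it everywhere, because both names refer to one object.
print(ann_in_g is ann_in_m)          # True
print(ann_in_m._skip_instancecheck)  # True
```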

I think all we need to do is to move this part of _make_array:

out = metaclass(
    name,
    (array_type, AbstractArray),
    dict(
        array_type=array_type,
        dtypes=dtypes,
        dims=dims,
        index_variadic=index_variadic,
        dim_str=dim_str,
    ),
)
if getattr(typing, "GENERATING_DOCUMENTATION", False):
    out.__module__ = "builtins"
else:
    out.__module__ = "jaxtyping"

to happen outside the cache. That is, we should cache the tuple of (metaclass, name, array_type, dtypes, ...etc), but then only actually call out = metaclass(...) outside the cache. This should be doable by wrapping _make_array.

This way we will get a fresh jaxtyping object on every __getitem__ call, that we can later feel free to mutate, whilst still preserving as much caching as possible.

WDYT?

@knyazer
Contributor Author

knyazer commented Feb 19, 2024

That is a nice suggestion. I don't think I would have come up with this under any circumstances; it requires either better debugging skills or better codebase knowledge than I possess. Thanks!

Besides, I figured out there was a bug that was hidden thanks to caching: equality between AbstractMetaArrays was not implemented (correctly), so things like Float[Array, ""] == Float[Array, ""] were true only because of caching, and without caching they were false. Funnily enough, there was only a single test that relied on equality of annotations, and only as a side effect, since it was comparing sets.

Well, I just made all the Dim* variations frozen dataclasses, that automatically implement the required __hash__ and __eq__ methods. The code seems to be complete, according to our discussion at least, so I wonder if you will be able to find any more places where we might have issues in the future :)
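As a rough illustration of why frozen dataclasses help here: once fresh classes are created on every call, equality has to be defined by content rather than identity. The names `NamedDim`, `MetaArray` and `make_annotation` below are hypothetical and much simpler than the real jaxtyping internals.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NamedDim:
    # With frozen=True (and the default eq=True) the dataclass gets
    # field-based __eq__ and __hash__ generated automatically.
    name: str
    broadcastable: bool = False


class MetaArray(type):
    # Compare annotation classes by their contents, not by identity.
    def __eq__(cls, other):
        if not isinstance(other, MetaArray):
            return NotImplemented
        return cls.dims == other.dims

    def __hash__(cls):
        return hash(cls.dims)


def make_annotation(*names):
    # Uncached on purpose: every call creates a distinct class object.
    dims = tuple(NamedDim(n) for n in names)
    return MetaArray("Float", (), {"dims": dims})


a = make_annotation("batch")
b = make_annotation("batch")
print(a is b)       # False
print(a == b)       # True
print(len({a, b}))  # 1
```

If `NamedDim` were an ordinary class without `__eq__`, the two `dims` tuples would hold distinct instances and compare unequal, which is the kind of bug the cache had been hiding.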

Owner

@patrick-kidger patrick-kidger left a comment


Alright, marvellous stuff!
This took us a lot of brainpower -- but I think this is a solution that should work. :D
Usual nits below, then let's merge this.

jaxtyping/_array_types.py
# annotations as not needing instance checks, while still being
# visible as original ones for the typechecker
def modify_annotation(ann):
    if isinstance(ann, _MetaAbstractArray):
Owner


Let's make this issubclass(ann, AbstractArray).
(Notice that we had to import an underscored object to get this. The convention I use throughout the Equinox+JAX ecosystem is that such objects are private to the file they're defined in.)

# visible as original ones for the typechecker
def modify_annotation(ann):
    if isinstance(ann, _MetaAbstractArray):
        setattr(ann, "_skip_instancecheck", True)
Owner


ann._skip_instancecheck = True ?

Comment on lines 326 to 328
    elif hasattr(ann, "__args__"):
        for sub_ann in get_args(ann):
            modify_annotation(sub_ann)
Owner


I think the hasattr check can be removed: get_args will always return a tuple. (If necessary an empty tuple.)

Comment on lines 329 to 330
    elif hasattr(ann, "__origin__"):
        modify_annotation(get_origin(ann))
Owner


In contrast this one should probably technically be

origin = get_origin(ann)
if origin is not None:
    modify_annotation(origin)

(IIRC __args__ and __origin__ are actually implementation details that we're not really supposed to touch.)

@knyazer
Contributor Author

knyazer commented Feb 19, 2024

Simplified the modify_annotation method, removed access to the protected field.

WDYT?

@patrick-kidger patrick-kidger merged commit 62b1296 into patrick-kidger:dev Feb 20, 2024
1 check passed
@patrick-kidger
Owner

Alright, great stuff! Just merged. Thank you as always for your efforts -- and I'm going to do a new jaxtyping release shortly :)

@patrick-kidger patrick-kidger mentioned this pull request Feb 25, 2024
Merged
patrick-kidger pushed a commit that referenced this pull request Feb 25, 2024
* Add a test for generators

* Remove output annotations from decorators

Also guarded torch imports for better compatibility with
requirements.txt

* Add flag to the main meta class to skip the typecheck

* Return to the old solution

* Make async tests work

* Minor adjustments/fixing typos

* Correct Python path for new tests

* Remove some jax-dependent code

* Implement equality for MetaArrays

* Make all Dim variations frozen dataclasses

* Shorten AbstractArray methods

* Final touches

* Removing get_origin use

* Update tests with @jaxtyp