AbstractVar's that are also static #472

Closed
fhchl opened this issue Sep 5, 2023 · 10 comments
Labels
feature New feature

Comments

@fhchl

fhchl commented Sep 5, 2023

I'd like to indicate to downstream users that a static instance attribute of an abstract base class should be overridden by its subclasses, e.g. using AbstractVar. However, it's not possible to combine the two right now, as AbstractVars do not allow default values.

Is there any chance one could do something along the following lines?

import equinox as eqx

class Foo(eqx.Module):
    # Proposed syntax (not currently supported):
    a: eqx.AbstractVar[int] = eqx.field(static=True)
    # or
    a: int = eqx.field(static=True, abstract=True)
@patrick-kidger
Owner

Hmm. So right now this is best handled by doing:

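import equinox as eqx
from equinox import AbstractVar

class AbstractFoo(eqx.Module):
    a: AbstractVar[int]  # only promises that `self.a` exists at runtime

class ConcreteFoo(AbstractFoo):
    a: int = eqx.field(static=True)  # the concrete class decides it is static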

as the intention of a foo: AbstractVar[int] is just to be able to communicate that "self.foo will exist at runtime", nothing more.

Doing something like this is actually surprisingly complicated. (Inside baseball: abstract variables are a lower-level piece of functionality, below both (a) dataclasses and (b) pytrees. This can basically be reduced to how this dataclass-making function should handle abstract variables.) Off the top of my head I'm not sure how we might implement this, but I'd welcome a PR / further discussion.

@fhchl
Author

fhchl commented Sep 6, 2023

I tried some simple (but a tad convoluted) ideas in that pull request. The main idea is to keep an additional list of AbstractVars that were declared static. When the abstract attribute is finally made concrete, we add an eqx.field(default=..., static=True).

This allows the following:

from dataclasses import fields
import equinox as eqx

class Abstract(eqx.Module):
    a: eqx.AbstractVar[int] = eqx.field(static=True)

class Concrete(Abstract):
    a: int = 2

assert fields(Concrete())[0].metadata["static"] == True  # works

What do you think?

@patrick-kidger
Owner

patrick-kidger commented Sep 6, 2023

Thanks for writing the PR!
Taking a look at it: right now, my main issue with this approach is that the AbstractVar machinery is now coupled to the static-field machinery. I think it's important that these stay separate.

I think what I'd probably suggest instead is to have AbstractVar allow all kinds of dataclass fields (i.e. don't treat them as a default value). Then, in addition to the "are these hints compatible" check (currently done e.g. here), simply also perform an "are these fields compatible" check, with "compatible" basically meaning that they have the exact same metadata etc.

I think that should be a much smaller change.

That would result in the syntax:

class Abstract(eqx.Module):
    a: eqx.AbstractVar[int] = eqx.field(static=True)

class Concrete(Abstract):
    a: int = eqx.field(default=2, static=True)
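
The "are these fields compatible" check suggested above could be sketched with plain `dataclasses` fields. This is only an illustration, not Equinox's implementation: `fields_compatible` is a hypothetical helper, `metadata={"static": True}` stands in for `eqx.field(static=True)`, and the choice of which attributes to compare (`metadata`, `init`, `repr`) is an assumption.

```python
import dataclasses


def fields_compatible(abstract_field, concrete_field):
    """Hypothetical helper: True if two dataclass fields agree on their options."""
    # The default value is deliberately ignored: the concrete field is
    # expected to supply one, while the abstract declaration is not.
    return (
        dict(abstract_field.metadata) == dict(concrete_field.metadata)
        and abstract_field.init == concrete_field.init
        and abstract_field.repr == concrete_field.repr
    )


# Stand-ins for the abstract declaration and two candidate concrete fields.
static = dataclasses.field(metadata={"static": True})
static_with_default = dataclasses.field(default=2, metadata={"static": True})
plain = dataclasses.field(default=2)

print(fields_compatible(static, static_with_default))  # True
print(fields_compatible(static, plain))  # False
```

Under this scheme a concrete `a: int = 2` would be rejected rather than silently made static, which matches the intent of keeping the two mechanisms separate.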

@fhchl
Author

fhchl commented Sep 8, 2023

Aha! Then in your example, Concrete.a would be required to have a field with the correct flags specified?

My original dream was a mechanism that raises an error if users forget to specify the value of a certain (static) attribute without them needing to think much about fields at all. But for that case, it might just be easier to have a __post_init__ method on the parent that does such a check.

@patrick-kidger
Owner

I think having a: int be silently static, just because the parent is static, is going to be a footgun.

Moreover, that's not true without the AbstractVar annotation:

import dataclasses
import equinox as eqx

class A(eqx.Module):
    x: int = eqx.field(static=True)

class B(A):
    x: int

[field] = dataclasses.fields(B)
print(field.metadata)  # doesn't include static

@fhchl
Author

fhchl commented Sep 11, 2023

Interesting! I didn't notice that if a subclass declares the same field with a type hint, the field metadata is cleared. I had the wrong assumption that B1 and B2 in the following example are basically the same class. 🤦🏼

class A(eqx.Module):
    x: int = eqx.field(static=True)

class B1(A):
    x: int = 1

class B2(A):
    x = 1

If static fields are not inherited in that way and the field declared with AbstractVar always needs to be implemented via a type declaration in the subclass, then yeah, this is not the right approach.
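
The difference between the two subclasses can be reproduced with the standard library alone, since it comes from how dataclass fields are collected. This is a minimal sketch using plain `dataclasses` rather than `eqx.Module`, with `metadata={"static": True}` as a stand-in for `eqx.field(static=True)`.

```python
import dataclasses


@dataclasses.dataclass
class A:
    # Stand-in for `eqx.field(static=True)`: mark the field via metadata.
    x: int = dataclasses.field(metadata={"static": True})


@dataclasses.dataclass
class B1(A):
    # Re-annotating `x` creates a brand-new field; A's metadata is not carried over.
    x: int = 1


@dataclasses.dataclass
class B2(A):
    # No annotation: this is only a class attribute, so A's field is inherited unchanged.
    x = 1


b1_field = dataclasses.fields(B1)[0]
b2_field = dataclasses.fields(B2)[0]
print(dict(b1_field.metadata))  # {}
print(dict(b2_field.metadata))  # {'static': True}
```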

I will go with checking for those variables in __post_init__ then. The downside is that subclasses with their own __init__ must explicitly call super().__post_init__() at the end of it.
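
The downside mentioned here can be seen with plain `dataclasses`, which follow the same rule: a hand-written `__init__` replaces the generated one, so `__post_init__` is not called automatically. The class names and the `hasattr` check below are hypothetical, chosen only to illustrate the pattern.

```python
import dataclasses


@dataclasses.dataclass
class Parent:
    def __post_init__(self):
        # Invariant: every concrete subclass must provide `a`.
        if not hasattr(self, "a"):
            raise TypeError(f"{type(self).__name__} must define `a`")


@dataclasses.dataclass
class Good(Parent):
    a: int = 2


@dataclasses.dataclass
class Forgetful(Parent):
    b: int = 0  # forgot to declare `a`


@dataclasses.dataclass
class CustomInit(Parent):
    b: int = 0

    def __init__(self, b):
        # A hand-written __init__ replaces the generated one, so
        # __post_init__ never runs unless it is called explicitly here.
        self.b = b


Good()  # passes the check

try:
    Forgetful()
    forgetful_raised = False
except TypeError:
    forgetful_raised = True

# No error is raised: the invariant is silently skipped.
unchecked = CustomInit(b=1)
```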

At least, I learned a bit about equinox's internals :)

@patrick-kidger
Owner

Hmm. FWIW something like this is fairly common: one has an abstract class, and wishes to check some invariants after initialisation. However, if a downstream class implements __init__ or __post_init__ then those invariants go silently unchecked. In fact, I just checked and exactly this issue has existed in Diffrax for years: patrick-kidger/diffrax#308

Maybe we should add a new method -- call it __after_init__ or __eqx_post_init__ or __check__ or something -- which is (a) called automatically regardless of whether __init__ or __post_init__ is defined, and (b) automatically called for every superclass. It'd be easy enough to add this to type(Module).__call__.

Of course (b) does mean you can't override this in a subclass, but arguably that's kind of the point.
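
A rough sketch of the mechanism described in (a) and (b), written against plain Python classes rather than Equinox. `CheckedMeta` and the hook name `__check__` are hypothetical placeholders, and this is not how Equinox itself is implemented.

```python
class CheckedMeta(type):
    """Hypothetical metaclass: runs every `__check__` hook after construction."""

    def __call__(cls, *args, **kwargs):
        # Normal construction: __new__ followed by __init__.
        instance = super().__call__(*args, **kwargs)
        # Walk the MRO from base to derived and run each class's own hook.
        for klass in reversed(cls.__mro__):
            hook = klass.__dict__.get("__check__")
            if hook is not None:
                hook(instance)
        return instance


class Base(metaclass=CheckedMeta):
    def __check__(self):
        if self.a < 0:
            raise ValueError("`a` must be non-negative")


class Child(Base):
    def __init__(self, a):
        # Custom __init__ that does not call any parent hook explicitly.
        self.a = a

    def __check__(self):
        # Defining a hook here adds a check; it does not replace Base's.
        pass


Child(1)  # both hooks run and pass

try:
    Child(-1)
    raised = False
except ValueError:
    raised = True
```

Because each hook is looked up in the class's own `__dict__` instead of through normal attribute lookup, the check in `Base` still runs even though `Child` defines a hook with the same name.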

@patrick-kidger
Owner

Alright, I've just written #492. This adds a new __check_init__ method for this purpose. WDYT?

@fhchl
Author

fhchl commented Sep 12, 2023

I like it! Solves my usecase elegantly.

PS: I made a comment on the documentation here

@patrick-kidger
Owner

patrick-kidger commented Sep 12, 2023

Great! In that case I'm going to close this issue, and this will appear in the upcoming release (v0.11.0) of Equinox.

Thanks for the discussion, this has turned out to be really useful.
