Example of wrapper to tweak a concrete class #635
Just so I understand you correctly: are you trying to tweak an existing class but keep the original version around, or are you trying to monkey-patch the existing class itself? If the latter, then you should be able to do this in Equinox just like anything else in Python; Equinox classes are still mutable:

```python
import equinox as eqx

eqx.nn.MLP.__call__ = lambda self, x: x
```

although this definitely comes under the banner of "you probably don't want to do this"!

If the former (the more normal case), then read on. The intention I had when writing these notes really was to manually forward every method. But it sounds like you happen to have a lot of methods, and would maybe prefer a slightly more heavyweight solution to what is a slightly more heavyweight problem :) In that case, I'd most recommend factoring out a common abstract base class. So right now you probably have code that looks something like:

```python
import equinox as eqx


class Foo(eqx.Module):
    def frobnicate(self):
        print("Foo!")

    def oblify(self):
        print("Kaboom!")


def accepts_foo(x: Foo):
    x.frobnicate()
    x.oblify()
```

And we'll refactor it into the following:

```python
import abc

import equinox as eqx


class AbstractFoo(eqx.Module):
    # Each concrete subclass must provide its own frobnicate.
    @abc.abstractmethod
    def frobnicate(self):
        pass

    # Shared behaviour lives on the ABC, so subclasses inherit it.
    def oblify(self):
        print("Kaboom!")


class Foo(AbstractFoo):
    def frobnicate(self):
        print("Foo!")


class Bar(AbstractFoo):
    def frobnicate(self):
        print("Bar!")


def accepts_foo(x: AbstractFoo):
    x.frobnicate()
    x.oblify()
```

I hope that helps!
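To make the "tweak a concrete class" idea concrete, here is a minimal sketch combining the abstract base class above with composition. It makes a few assumptions to stay dependency-free and checkable: a frozen `dataclasses.dataclass` stands in for `eqx.Module`, methods return strings instead of printing, `oblify` is (hypothetically) changed to call `self.frobnicate()`, and `LoudFoo` is a made-up name for a wrapper that holds a `Foo` and tweaks one method.

```python
import abc
import dataclasses


# Frozen dataclasses stand in for eqx.Module so this runs without Equinox.
@dataclasses.dataclass(frozen=True)
class AbstractFoo(abc.ABC):
    @abc.abstractmethod
    def frobnicate(self) -> str: ...

    # Shared concrete method: it calls whatever frobnicate `self` provides.
    def oblify(self) -> str:
        return self.frobnicate() + " Kaboom!"


@dataclasses.dataclass(frozen=True)
class Foo(AbstractFoo):
    def frobnicate(self) -> str:
        return "Foo!"


@dataclasses.dataclass(frozen=True)
class LoudFoo(AbstractFoo):
    """Hypothetical wrapper that tweaks a Foo by composition."""

    inner: Foo

    def frobnicate(self) -> str:
        # The tweak: reuse the wrapped implementation, then modify its result.
        return self.inner.frobnicate().upper()


def accepts_foo(x: AbstractFoo) -> str:
    return x.oblify()


print(accepts_foo(Foo()))  # Foo! Kaboom!
print(accepts_foo(LoudFoo(inner=Foo())))  # FOO! Kaboom!
```

Because `oblify` lives on the ABC, `LoudFoo` inherits it instead of forwarding it, and the `self.frobnicate()` call inside it dispatches to the tweaked version. `LoudFoo` is also an instance of `AbstractFoo`, so it type-checks anywhere `AbstractFoo` is accepted.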
Thanks! I think I've been leaning toward monkey-patching because I'm trying to avoid using an abstract base class. I suppose I should be a bit more specific. I have an … Here's a simplified version:

```python
import equinox as eqx


class MuscleModel(eqx.Module):
    beta: float
    omega: float
    rho: float
    # plus about 15 more floats/2-tuples of floats

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    def intermediate_1(self, state):
        return self.beta * state.something ** self.omega

    def intermediate_2(self, state):
        ...

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...
```

The issue is that there are a number of slightly different implementations of this model. Some would instantiate the full model with a full set of parameters. Others introduce various simplifications, which can be expressed by replacing or bypassing one or two of the intermediate methods. I've suspected that the most correct thing to do would be to make … So I guess I feel trapped between a hacky but straightforward monkey-patch solution and something really long and ugly but more correct.

Oh, and thanks for making it really clear that classes can be mutated. For some reason I was trying to mutate the instances, and of course encountered …
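The difference between mutating the class and mutating an instance can be illustrated without Equinox. In this sketch a frozen dataclass stands in for `eqx.Module` (an assumption: the exact exception Equinox raises may differ), and `intermediate_1` takes a plain float in place of the `state` object.

```python
import dataclasses


# Stand-in for eqx.Module (assumption: like Equinox modules, instances are
# frozen once __init__ has run).
@dataclasses.dataclass(frozen=True)
class MuscleModel:
    beta: float
    omega: float

    # Simplified: takes a float instead of the `state` object used above.
    def intermediate_1(self, something: float) -> float:
        return self.beta * something ** self.omega


model = MuscleModel(beta=2.0, omega=1.0)

# Mutating the *instance* fails: frozen dataclasses reject attribute assignment.
try:
    model.intermediate_1 = lambda something: 0.0
except dataclasses.FrozenInstanceError:
    print("instance is frozen")

# Mutating the *class* works, but it changes every existing instance too.
original = MuscleModel.intermediate_1
MuscleModel.intermediate_1 = lambda self, something: 0.0
print(model.intermediate_1(3.0))  # 0.0
MuscleModel.intermediate_1 = original  # undo the patch
print(model.intermediate_1(3.0))  # 6.0
```

The class-level patch succeeds, but it silently affects every instance, including ones created before the patch, which is one reason to avoid it inside a library.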
Okay! So I think I can see several possible solutions, from what you've said. (And first of all, monkey patching is not one of them. Let's avoid that unless we absolutely have to!)

Option 1: Just have one model class, with flags

```python
from typing import Optional

import equinox as eqx


class MuscleModel(eqx.Module):
    beta: Optional[float]
    omega: Optional[float]
    rho: float

    def __check_init__(self):
        # beta and omega must be provided together, or not at all.
        if (self.beta is None) != (self.omega is None):
            raise ValueError("If either beta or omega is provided then both must be provided")

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    def intermediate_1(self, state):
        if self.beta is None:
            # Simplified model: this term is bypassed.
            assert self.omega is None
            return None
        else:
            assert self.omega is not None
            return self.beta * state.something ** self.omega

    def intermediate_2(self, state):
        ...

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...
```

This is a little ugly in that it requires the consistency check; you could maybe consider replacing the parameter with a single … In this case we're using the availability of a parameter as a flag, but we could also use an explicit boolean if that made more sense for some reason.

Option 2: ABCs

```python
import abc

import equinox as eqx


class AbstractMuscleModel(eqx.Module):
    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    @abc.abstractmethod
    def intermediate_1(self, state):
        pass

    @abc.abstractmethod
    def intermediate_2(self, state):
        pass

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...


class FooMuscleModel(AbstractMuscleModel):
    beta: float
    omega: float

    def intermediate_1(self, state):
        return self.beta * state.something ** self.omega

    def intermediate_2(self, state):
        # This variant bypasses the second intermediate term.
        return None
```

Note that the ABC doesn't actually introduce any …

A variation on this, btw, is to explicitly declare all fields (non-abstractly!) in the ABC, and have them be shared by all subclasses. The pattern recommended in the guide is that all fields (and the …

Option 3: dependency inversion

```python
import abc

import equinox as eqx


class AbstractIntermediate(eqx.Module):
    @abc.abstractmethod
    def __call__(self, state):
        pass


class FooIntermediate(AbstractIntermediate):
    beta: float
    omega: float

    def __call__(self, state):
        return self.beta * state.something ** self.omega


class MuscleModel(eqx.Module):
    # Each intermediate is a swappable component rather than a method.
    intermediate_1: AbstractIntermediate
    intermediate_2: AbstractIntermediate

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...


MuscleModel(FooIntermediate(beta=2.0, omega=1.0), ...)
```

This last one is actually a variation of the …

Anyway, as you can see, you've got quite a few options! All of them are usually reasonable approaches. (And the fact that there are multiple ways of tackling problems like this serves to demonstrate why the abstract/final pattern is one I've never found restrictive.)
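A runnable sketch of Option 3, under a few assumptions: frozen dataclasses stand in for `eqx.Module`, `State` is a hypothetical container with a single `something` field, `NoIntermediate` is a hypothetical variant that bypasses a term, and since the real `tension` is elided above it is replaced by a made-up placeholder formula.

```python
import abc
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class State:
    # Hypothetical state; `something` mirrors `state.something` above.
    something: float


@dataclasses.dataclass(frozen=True)
class AbstractIntermediate(abc.ABC):
    @abc.abstractmethod
    def __call__(self, state: State) -> Optional[float]: ...


@dataclasses.dataclass(frozen=True)
class FooIntermediate(AbstractIntermediate):
    beta: float
    omega: float

    def __call__(self, state: State) -> Optional[float]:
        return self.beta * state.something ** self.omega


@dataclasses.dataclass(frozen=True)
class NoIntermediate(AbstractIntermediate):
    """Hypothetical 'simplified' variant that bypasses the term entirely."""

    def __call__(self, state: State) -> Optional[float]:
        return None


@dataclasses.dataclass(frozen=True)
class MuscleModel:
    intermediate_1: AbstractIntermediate
    intermediate_2: AbstractIntermediate

    def __call__(self, input: float, state: State) -> float:
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        return self.tension(intermediate_1, intermediate_2, input, state)

    def tension(self, intermediate_1, intermediate_2, input, state) -> float:
        # Placeholder formula (hypothetical): the real tension is elided above.
        # A bypassed intermediate (None) simply contributes nothing.
        total = sum(x for x in (intermediate_1, intermediate_2) if x is not None)
        return input * total


full = MuscleModel(FooIntermediate(beta=2.0, omega=1.0), FooIntermediate(beta=1.0, omega=2.0))
simplified = MuscleModel(FooIntermediate(beta=2.0, omega=1.0), NoIntermediate())
print(full(1.0, State(something=3.0)))  # 15.0
print(simplified(1.0, State(something=3.0)))  # 6.0
```

Each simplified model is then just a different choice of components passed to the same `MuscleModel`, so a library user can plug in their own `AbstractIntermediate` subclass without touching `MuscleModel` itself.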
Thanks for the detailed answer! My situation seems much clearer to me now. I'm sure one of the approaches you've suggested will work. Option 3 probably makes the most sense: I am writing a library and I'd like the implementation to be flexible.
From the notes on the abstract-final pattern:
Could you please give an example of how you would do this?
Assuming there's a way to do this without forwarding the methods one by one, I first thought to write a class with no parents that 1) is composed of an instance of the wrapped module, 2) provides the tweaked method, and 3) implements `__getattr__` to forward all the non-tweaked stuff to the wrapped module. However, this clearly doesn't work: it doesn't forward magic methods like `__call__`, and once we call an untweaked method of the wrapped module, if the body of that method calls the method we're trying to tweak, it doesn't refer back to the wrapper, but calls the untweaked implementation in the wrapped class. And of course, the wrapper instance doesn't type-check as an instance of the wrapped module.

Much of the internet advice about tweaking instances of Python classes assumes they are mutable, so that the method can be reassigned directly.
Thanks for writing Equinox. It's been a pleasure to use!
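The three failure modes of the `__getattr__` approach described above are easy to reproduce in plain Python. In this sketch `Concrete` and `Wrapper` are hypothetical names, and a frozen dataclass stands in for an Equinox module.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Concrete:
    scale: float

    def tweak_me(self) -> str:
        return "original"

    def untweaked(self) -> str:
        # `self` here is the wrapped object, never the wrapper.
        return "untweaked -> " + self.tweak_me()

    def __call__(self) -> str:
        return "called"


class Wrapper:
    """Hypothetical __getattr__-based wrapper, as described in the question."""

    def __init__(self, wrapped: Concrete):
        self._wrapped = wrapped

    def tweak_me(self) -> str:
        return "tweaked"

    def __getattr__(self, name):
        # Only runs when normal lookup fails, so it forwards `scale` and
        # `untweaked`, but not `tweak_me`.
        return getattr(self._wrapped, name)


w = Wrapper(Concrete(scale=2.0))
print(w.tweak_me())  # tweaked
print(w.scale)  # 2.0
print(w.untweaked())  # untweaked -> original
print(isinstance(w, Concrete))  # False
print(callable(w))  # False: w() raises TypeError
```

An explicit `w.__call__()` does go through `__getattr__` and returns `"called"`; only the implicit `w()` skips it, because Python looks special methods up on the type rather than on the instance.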