Example of wrapper to tweak a concrete class #635

Closed
mlprt opened this issue Jan 8, 2024 · 4 comments
Labels
question User queries

Comments

@mlprt

mlprt commented Jan 8, 2024

From the notes on the abstract-final pattern:

What about when you have an existing concrete class that you want to tweak just-a-little-bit? In this case, prefer composition over inheritance. Write a wrapper that forwards each method as appropriate.

Could you please give an example of how you would do this?

Assuming there's a way to do this without forwarding the methods one-by-one, I thought at first to write a class with no parents that 1) is composed of an instance of the wrapped module, 2) provides the tweaked method, and 3) implements __getattr__ to forward all the non-tweaked stuff to the wrapped module. However this clearly doesn't work: it doesn't forward magic methods like __call__, and once we call an untweaked method of the wrapped module, if the body of that method calls the method we're trying to tweak, it doesn't refer back to the wrapper, but calls the untweaked implementation in the wrapped class. And of course, the type of the wrapper instance doesn't check as that of the wrapped module.
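To make the failure modes above concrete, here is a minimal sketch using plain Python classes rather than `eqx.Module` (the names `Inner` and `GetattrWrapper` are hypothetical and only for illustration). It shows the three problems described: implicit special-method lookup bypasses `__getattr__`, forwarded methods call the untweaked implementation, and `isinstance` checks fail.

```python
class Inner:
    """Stand-in for the wrapped class (a plain class here, not an eqx.Module)."""

    def __call__(self, x):
        # Internally relies on `scale`, the method we want to tweak.
        return self.scale(x) + 1

    def scale(self, x):
        return 2 * x


class GetattrWrapper:
    """Hypothetical wrapper: tweaks `scale`, forwards everything else."""

    def __init__(self, inner):
        self._inner = inner

    def scale(self, x):  # the tweaked method
        return 10 * x

    def __getattr__(self, name):
        # Only consulted when normal attribute lookup fails.
        return getattr(self._inner, name)


w = GetattrWrapper(Inner())

# Calling the tweaked method directly works as hoped.
tweaked = w.scale(3)

# Explicit attribute access is forwarded, but the forwarded method is bound
# to the inner object, so it uses the untweaked `scale`: 2 * 3 + 1, not 10 * 3 + 1.
forwarded = w.__call__(3)

# Implicit special-method lookup goes through the type, not __getattr__.
try:
    w(3)
    call_failed = False
except TypeError:
    call_failed = True

# The wrapper is not an instance of the wrapped class.
is_inner = isinstance(w, Inner)
```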

Much of the internet advice about tweaking instances of Python classes assumes they are mutable, and the method can be re-assigned directly.

Thanks for writing Equinox. It's been a pleasure to use!

@patrick-kidger
Owner

Much of the internet advice about tweaking instances of Python classes assumes they are mutable, and the method can be re-assigned directly.

Just so I understand you correctly -- are you trying to tweak an existing class, but keep the original version around -- or are you trying to monkey-patch the existing class itself?

If the latter, then you should be able to do this in Equinox just like anything else in Python; Equinox classes are still mutable:

import equinox as eqx
eqx.nn.MLP.__call__ = lambda self, x: x

although this definitely comes under the banner of "you probably don't want to do this"!

If the former (the more normal case), then read on.

So the intention I had when writing these notes really was to manually forward every method. But it sounds like you happen to have a lot of methods, and would maybe prefer a slightly more heavyweight solution to what is a slightly more heavyweight problem :)
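For reference, manual forwarding might look something like the sketch below. This is an illustration only: `Engine` and `QuietEngine` are hypothetical names, and a frozen dataclass stands in for `eqx.Module` so the snippet runs with just the standard library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Engine:
    """Hypothetical existing concrete class that we want to tweak."""

    power: float

    def start(self):
        return f"starting at {self.power} kW"

    def stop(self):
        return "stopped"


@dataclass(frozen=True)
class QuietEngine:
    """Wrapper by composition: holds an Engine instead of subclassing it."""

    engine: Engine

    def start(self):
        # The tweaked method: reuse the original behaviour and adjust it.
        return self.engine.start() + " (quietly)"

    def stop(self):
        # An untweaked method, forwarded explicitly.
        return self.engine.stop()


quiet = QuietEngine(Engine(power=5.0))
started = quiet.start()
stopped = quiet.stop()
```

With two or three methods this stays manageable; with many methods the forwarding boilerplate grows quickly.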

In that case, I'd most recommend factoring out a common abstract base class. So right now you probably have code that looks something like:

class Foo(eqx.Module):
    def frobnicate(self):
        print("Foo!")

    def oblify(self):
        print("Kaboom!")

def accepts_foo(x: Foo):
    x.frobnicate()
    x.oblify()

And we'll refactor it into the following:

import abc

import equinox as eqx

class AbstractFoo(eqx.Module):
    @abc.abstractmethod
    def frobnicate(self):
        pass

    def oblify(self):
        print("Kaboom!")

class Foo(AbstractFoo):
    def frobnicate(self):
        print("Foo!")

class Bar(AbstractFoo):
    def frobnicate(self):
        print("Bar!")

def accepts_foo(x: AbstractFoo):
    x.frobnicate()
    x.oblify()
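A quick way to see how the refactored hierarchy behaves is sketched below. It is an assumption-laden stand-in: plain `abc.ABC` replaces `eqx.Module`, and the methods return strings instead of printing so the results can be inspected.

```python
import abc


class AbstractFoo(abc.ABC):  # stand-in for an abstract eqx.Module
    @abc.abstractmethod
    def frobnicate(self):
        pass

    def oblify(self):
        # Shared behaviour lives on the abstract base class.
        return "Kaboom!"


class Foo(AbstractFoo):
    def frobnicate(self):
        return "Foo!"


class Bar(AbstractFoo):
    def frobnicate(self):
        return "Bar!"


def accepts_foo(x: AbstractFoo):
    return x.frobnicate(), x.oblify()


# Both concrete classes are accepted wherever an AbstractFoo is expected.
foo_result = accepts_foo(Foo())
bar_result = accepts_foo(Bar())

# The abstract class itself cannot be instantiated.
try:
    AbstractFoo()
    abstract_instantiable = True
except TypeError:
    abstract_instantiable = False
```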

I hope that helps!

@patrick-kidger patrick-kidger added the question User queries label Jan 8, 2024
@mlprt
Author

mlprt commented Jan 9, 2024

Thanks! I think I've been leaning toward monkey-patching because I'm trying to avoid using an abstract base class.

I suppose I should be a bit more specific. I have an eqx.Module whose __call__ implements an empirical model of muscle contraction. It has lots of parameters which I'm implementing as module fields.

Here's a simplified version:

class MuscleModel(eqx.Module):
    beta: float
    omega: float
    rho: float
    # plus about 15 more floats/2-tuples of floats

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    def intermediate_1(self, state):
        return self.beta * state.something ** self.omega

    def intermediate_2(self, state):
        ...

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...

The issue is that there are a number of slightly different implementations of this model. Some would instantiate the full model with a full set of parameters. Others introduce various simplifications, which can be expressed by replacing or bypassing one or two of the intermediate methods.

I've suspected that the most correct thing to do would be to make MuscleModel an abstract base class, then write final classes for the full model and for each possible simplification. But that would mean something like 20 AbstractVars which would need to be implemented again each time.

So I guess I feel trapped between using a hacky but straightforward monkey-patch solution, versus doing something really long and ugly but more correct.

Oh, and thanks for making it really clear that classes can be mutated. For some reason I was trying to mutate the instances, and of course encountered FrozenInstanceError.
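The difference between mutating an instance and mutating the class can be sketched with a frozen standard-library dataclass, used here as a stand-in for `eqx.Module` (the `Model` class is hypothetical). The class-level patch is shown only to illustrate the distinction, not as a recommendation.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Model:
    """Hypothetical frozen model; a stand-in for an eqx.Module."""

    weight: float

    def __call__(self, x):
        return self.weight * x


m = Model(weight=2.0)
before = m(5)

# Assigning to an attribute of a frozen instance is rejected.
try:
    m.weight = 3.0
    instance_frozen = False
except dataclasses.FrozenInstanceError:
    instance_frozen = True

# An out-of-place update returns a new instance and leaves `m` untouched.
m2 = dataclasses.replace(m, weight=3.0)

# The class object itself is still mutable, so monkey-patching "works",
# and it silently affects every existing instance.
Model.__call__ = lambda self, x: x
after = m(5)
```

For real Equinox modules, out-of-place updates are usually done with `eqx.tree_at` rather than `dataclasses.replace`.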

@patrick-kidger
Owner

Okay! So I think I can see several possible solutions, from what you've said. (And first of all, monkey patching is not one of them. Let's avoid that unless we absolutely have to!)

Option 1: Just have one model class, with flags

from typing import Optional

class MuscleModel(eqx.Module):
    beta: Optional[float]
    omega: Optional[float]
    rho: float

    def __check_init__(self):
        if (self.beta is None) != (self.omega is None):
            raise ValueError("If either beta or omega is provided then both must be provided")

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    def intermediate_1(self, state):
        if self.beta is None:
            assert self.omega is None
            return None
        else:
            assert self.omega is not None
            return self.beta * state.something ** self.omega

    def intermediate_2(self, state):
        ...

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...

This is a little ugly in that it requires the consistency check; you could maybe consider replacing the two parameters with a single beta_omega: Optional[tuple[float, float]] so that you automatically have both being provided or neither being provided.

In this case we're using the availability of a parameter as a flag, but we could also use an explicit boolean if that made more sense for some reason.
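The `beta_omega` variant could be sketched roughly as follows. A frozen dataclass stands in for `eqx.Module`, and `State` is a hypothetical placeholder for whatever object carries `something`. Because the two values travel together, no consistency check is needed.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class State:
    """Hypothetical state object; only `something` is taken from the thread."""

    something: float


@dataclass(frozen=True)
class MuscleModel:
    rho: float
    # Either both values are present (as one tuple) or neither is.
    beta_omega: Optional[tuple[float, float]] = None

    def intermediate_1(self, state):
        if self.beta_omega is None:
            # The simplified model bypasses this intermediate.
            return None
        beta, omega = self.beta_omega
        return beta * state.something ** omega


simplified = MuscleModel(rho=1.0)
full = MuscleModel(rho=1.0, beta_omega=(2.0, 2.0))

skipped = simplified.intermediate_1(State(something=3.0))
computed = full.intermediate_1(State(something=3.0))
```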

Option 2: ABCs

class AbstractMuscleModel(eqx.Module):
    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    @abc.abstractmethod
    def intermediate_1(self, state):
        pass

    @abc.abstractmethod
    def intermediate_2(self, state):
        pass

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...

class FooMuscleModel(AbstractMuscleModel):
    beta: float
    omega: float

    def intermediate_1(self, state):
        return self.beta * state.something ** self.omega

    def intermediate_2(self, state):
        return None

Note that the ABC doesn't actually introduce any AbstractVars, as at least in this MWE it doesn't explicitly access any (they're only used in the methods of the subclass). So maybe the extra verbosity there can be skipped. (In particular, it sounds like you only need some parameters for some cases -- so you probably don't want to declare them all as AbstractVars on the ABC anyway.)

A variation on this btw is to explicitly declare all fields (non abstractly!) in the ABC, and have them be shared by all subclasses. The pattern recommended in the guide is that all fields (and the __init__ method) be declared in precisely one class in the hierarchy. That's usually the concrete subclass at the bottom, but it doesn't have to be.
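A rough sketch of that variation, with the fields declared once on the abstract class and only the method differing between subclasses. The assumptions are: a frozen dataclass plus `abc.ABC` stands in for `eqx.Module`, `State` is a hypothetical placeholder, and the "linearised" formula is invented purely to have a second implementation that uses the same fields.

```python
import abc
from dataclasses import dataclass


@dataclass(frozen=True)
class State:
    """Hypothetical state object."""

    something: float


@dataclass(frozen=True)
class AbstractMuscleModel(abc.ABC):
    # All fields are declared here, once, and shared by every subclass.
    beta: float
    omega: float

    def __call__(self, state):
        return self.intermediate_1(state)

    @abc.abstractmethod
    def intermediate_1(self, state):
        pass


class FullMuscleModel(AbstractMuscleModel):
    def intermediate_1(self, state):
        return self.beta * state.something ** self.omega


class LinearisedMuscleModel(AbstractMuscleModel):
    def intermediate_1(self, state):
        # Invented simplification: first-order expansion around something == 1.
        return self.beta * (1 + self.omega * (state.something - 1))


state = State(something=2.0)
full_value = FullMuscleModel(beta=2.0, omega=3.0)(state)
linear_value = LinearisedMuscleModel(beta=2.0, omega=3.0)(state)

# The abstract class still cannot be instantiated, even though it has fields.
try:
    AbstractMuscleModel(beta=2.0, omega=3.0)
    abstract_instantiable = True
except TypeError:
    abstract_instantiable = False
```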

Option 3: dependency inversion

class AbstractIntermediate(eqx.Module):
    @abc.abstractmethod
    def __call__(self, state):
        pass

class FooIntermediate(AbstractIntermediate):
    beta: float
    omega: float

    def __call__(self, state):
        return self.beta * state.something ** self.omega

class MuscleModel(eqx.Module):
    intermediate_1: AbstractIntermediate
    intermediate_2: AbstractIntermediate

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        muscle_tension = self.tension(intermediate_1, intermediate_2, input, state)
        return muscle_tension

    def tension(self, intermediate_1, intermediate_2, input, state):
        ...

MuscleModel(FooIntermediate(beta=2.0, omega=1.0), ...)

This last one is actually a variation of the Optional[tuple[float, float]] we saw in Option 1. An instance of FooIntermediate is basically the same as a tuple[float, float]: it just packages two floats together. The only difference is where we put the self.beta * state.something ** self.omega logic. If you're developing a library, then this comes with one potential advantage over Option 1: a downstream user can come along and implement some BrandNewIntermediate with their own __call__ logic -- without needing any changes to MuscleModel itself.
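The downstream-extension point can be sketched as follows. Again this is only an illustration: frozen dataclasses and `abc.ABC` stand in for `eqx.Module`, `State` is hypothetical, the body of `BrandNewIntermediate.__call__` is invented, and `tension` uses a placeholder formula because the thread elides the real one.

```python
import abc
from dataclasses import dataclass


@dataclass(frozen=True)
class State:
    """Hypothetical state object."""

    something: float


class AbstractIntermediate(abc.ABC):
    @abc.abstractmethod
    def __call__(self, state):
        pass


@dataclass(frozen=True)
class FooIntermediate(AbstractIntermediate):
    beta: float
    omega: float

    def __call__(self, state):
        return self.beta * state.something ** self.omega


@dataclass(frozen=True)
class BrandNewIntermediate(AbstractIntermediate):
    """Written by a downstream user; MuscleModel needs no changes."""

    offset: float

    def __call__(self, state):
        return state.something + self.offset  # invented logic


@dataclass(frozen=True)
class MuscleModel:
    intermediate_1: AbstractIntermediate
    intermediate_2: AbstractIntermediate

    def __call__(self, input, state):
        intermediate_1 = self.intermediate_1(state)
        intermediate_2 = self.intermediate_2(state)
        return self.tension(intermediate_1, intermediate_2, input, state)

    def tension(self, intermediate_1, intermediate_2, input, state):
        # Placeholder formula; the real one is not given in the thread.
        return input * (intermediate_1 + intermediate_2)


model = MuscleModel(
    FooIntermediate(beta=2.0, omega=1.0),
    BrandNewIntermediate(offset=0.5),
)
tension = model(1.0, State(something=3.0))
```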


Anyway, as you can see, you've got quite a few options! All of them are usually reasonable approaches. (And the fact that there are multiple ways of tackling problems like this serves to demonstrate why the abstract/final pattern is one I've never found restrictive.)

@mlprt
Author

mlprt commented Jan 10, 2024

Thanks for the detailed answer! My situation seems much clearer to me, now. I'm sure one of the approaches you've suggested will work.

Option 3 probably makes the most sense. I am writing a library and I'd like the implementation to be flexible.
