
Best practices to convert torch.nn.Module to eqx.Module #396

Open
emilemathieu opened this issue Jun 15, 2023 · 12 comments
Labels
question User queries

Comments

@emilemathieu

Thanks for the great package!

I was wondering whether there was some documentation regarding the best practice for converting torch.nn.Module to eqx.Module?

In particular

  • It is quite clear that register_parameter would be replaced by an attribute, e.g. weights: Array
  • How should one handle register_buffer?
  • And add_module(name, intertwiner_basis)? Especially when the name is not known in advance, e.g. f"module_{variable}"

Thanks a lot!

(for context, I'm looking at porting escnn to JAX, cf. QUVA-Lab/escnn#55)

@emilemathieu emilemathieu changed the title Best practice to convert torch.nn.Module to eqx.Module Best practices to convert torch.nn.Module to eqx.Module Jun 15, 2023
@patrick-kidger
Owner

Hey there!

Buffers are most simply handled by storing them as an array (just like a parameter) and then calling jax.lax.stop_gradient(self.my_buffer) when you access them in __call__.

Alternatively, you can follow the freeze parameter example, which involves passing them through jax.grad as a nondifferentiable argument.
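A minimal sketch of the stop_gradient approach (the class and attribute names here are hypothetical, not part of Equinox):

import equinox as eqx
import jax
from jaxtyping import Array

class Scaler(eqx.Module):
    weight: Array        # learnt parameter
    running_mean: Array  # "buffer": stored like a parameter, but gradients are stopped on access

    def __call__(self, x):
        mean = jax.lax.stop_gradient(self.running_mean)  # no gradient flows into the buffer
        return self.weight * (x - mean)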


Modules are typically stored as attributes just like parameters, e.g. see the source code for eqx.nn.MLP.

If you need dynamically-named parameters then you can store those in a dictionary, and then store the dictionary on the parent module. Modules themselves are not variadically-sized.

(Note that this dynamism should only happen at __init__ time. At call time it is strongly recommended not to mutate your model, as it's easy to lose track of the changes when flattening/unflattening across JIT/grad/etc boundaries.)
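For the dynamically-named case, a rough sketch of building the dict once at __init__ time (the names and layer sizes below are made up):

import equinox as eqx
import jax

class Parent(eqx.Module):
    submodules: dict

    def __init__(self, names, key):
        keys = jax.random.split(key, len(names))
        # Build the dynamically-named children once, at __init__ time.
        self.submodules = {
            f"module_{name}": eqx.nn.Linear(4, 4, key=k)
            for name, k in zip(names, keys)
        }

    def __call__(self, x):
        for layer in self.submodules.values():
            x = layer(x)
        return x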

@patrick-kidger patrick-kidger added the question User queries label Jun 15, 2023
@emilemathieu
Author

Thanks @patrick-kidger for your answer! :)

I'd still have a question on how best to handle the following scenario, where I have a linear layer whose matrix M is given in a fixed basis B with coefficients W (to learn).
With eqx.tree_inference(layer, True) I can change the value of layer.inference, but I'd also like to store the matrix M = B @ W at evaluation time.
I feel that the two options are

  • returning a new layer when calling layer = layer.eval(), passing M as an (optional) argument
  • passing M as state (as in the batch norm layer)

I would be really keen to hear your opinion, as neither of the two options seems ideal :)

class Linear(eqx.Module):
  W: Array
  B: Array
  M: Array
  inference: bool

  def __init__(self, ...):
    self.B = ...
    self.W = ...
    self.inference = inference

  def __call__(self, x):
    if self.inference:
      return self.M @ x
    else:
      return self.B @ self.W @ x

  def eval(self):
    self.M = self.B @ self.W

@patrick-kidger
Owner

patrick-kidger commented Jun 18, 2023

I think doing the conversion before inference time probably makes most sense. Here's an example of training a linear layer with a symmetric weight matrix:

#
# Train-time: resolve on-the-fly.
#

class Symmetric(eqx.Module):
    array: Array

    def get(self):
        return 0.5 * (self.array + self.array.T)

is_symmetric = lambda x: isinstance(x, Symmetric)

@eqx.filter_jit
def train_step(model, ...):
    model = jax.tree_util.tree_map(lambda x: x.get() if is_symmetric(x) else x, model, is_leaf=is_symmetric)
    ...  # compute gradients, update, etc.

model = eqx.nn.Linear(...)
model = eqx.tree_at(lambda m: m.weight, model, replace_fn=Symmetric)
for _ in range(steps):
    model = train_step(model, ...)

#
# Inference time: perform conversion.
#

inference_model = eqx.tree_inference(model, True)
inference_model = jax.tree_util.tree_map(lambda x: x.get() if is_symmetric(x) else x, inference_model, is_leaf=is_symmetric)
inference_model(...)  # evaluate

Doing some kind of train->inference conversion is pretty common -- e.g. quantisation, pruning, absorbing adjacent batchnorm and linear layers into a single linear transformation, etc. etc.


Also, note that I don't do something like self.M = self.B @ self.W. You can't assign to eqx.Modules outside of __init__ -- much like tuples, they are immutable. You should use eqx.tree_at to create an out-of-place update instead.

This is a deliberate design choice, as it helps to reason about changes in the presence of jit, grad, etc.
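For instance, a sketch of that out-of-place update, reusing the Linear layer from the comment above (and assuming M already exists as an array leaf):

import equinox as eqx

# Out-of-place equivalent of `self.M = self.B @ self.W`:
layer = eqx.tree_at(lambda m: m.M, layer, replace=layer.B @ layer.W)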

@emilemathieu
Author

emilemathieu commented Jun 19, 2023

Thanks, that's really useful; I wasn't aware of eqx.tree_at!

I eventually implemented something like

class Linear(eqx.Module):
  ...
  def eval(self):
    new = eqx.tree_inference(self, True)
    matrix = self.B @ self.W
    return eqx.tree_at(lambda m: m.matrix, new, replace=matrix)

model = Linear(...)
model = model.eval()

@emilemathieu
Author

@patrick-kidger
In my eqx.Module I have some Array attributes which I want to learn (i.e. parameters) and other Array attributes which aren't, and which I'm only setting at evaluation time with eqx.tree_at(lambda m: m.matrix, new, replace=matrix).
Would there be a way, similar to params, static = eqx.partition(model, eqx.is_array), to filter out the non-parameter Array attributes?

@patrick-kidger
Owner

Yep, this is totally possible. First of all, if you just want to have non-learnt arrays then call lax.stop_gradient after access:

def __call__(self, ...):
    buffer = lax.stop_gradient(self.buffer)
    # now use `buffer` wherever you want to use it.

If you need to do something more complicated with filtering, then you can use a wrapper class:

class FooArray(eqx.Module):
    array: Array

class Model(eqx.Module):
    foo: FooArray

    def __init__(self, ...):
        self.foo = FooArray(some_array)
    ...

model = Model(...)
is_foo = lambda x: isinstance(x, FooArray)
has_foo, no_foo = eqx.partition(model, is_foo, is_leaf=is_foo)

Here's a fully-fledged example for creating a linear transformation with a symmetric matrix.

@emilemathieu
Author

Thanks @patrick-kidger!
FYI, I've been working on porting escnn to JAX & Equinox, as the only JAX-supported equivariant NN library is e3nn_jax, which only supports O(3), whilst escnn supports many subgroups (and e3nn_jax does not so far support Equinox).
See for instance the EquivariantModule and Linear classes and an MNIST example with this escnn_jax library.

If that's something you're interested in and/or have any suggestions/remarks I'd be keen on hearing them :)

@patrick-kidger
Owner

Thanks! I've just had a quick look.

@emilemathieu
Author

This should probably be an iteration-over-pytree, not just layers only? Also note that you're technically doing O(n^2) work by calling tree_inference at multiple tree depths.

Regarding this, I completely agree. How can I achieve this? With something like the following?

is_layer = lambda m: isinstance(m, eqx.Module)
new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)

Is new = eqx.tree_at(...) needed?

@emilemathieu
Author

@patrick-kidger would you have an idea, by any chance, whether it's usually better/faster in JAX when 'filling in an array' to (1) create an empty array, iterate, and fill values with arr = arr.at[...].set(...), or (2) create an empty list arr = [], iterate whilst appending values, and finally concatenate them?

@emilemathieu
Author

emilemathieu commented Jun 28, 2023

Also, to handle both stateful and stateless modules I found myself adding something like

for layer in self.layers:
  if "state" in inspect.signature(layer).parameters:
      x, state = layer(x, state)
  else:
      x = layer(x)

Is there any way around this? I could wrap the stateless module with

def state_wrapper(layer: eqx.Module):
    if "state" in inspect.signature(layer).parameters:
        return layer
    else:
        return lambda x, state: (layer(x), state)

or something like eqx.nn.Lambda.

Would it be worth adding to eqx.nn.Sequential an optional state: eqx.nn.State = None argument, to be passed along and returned if not None? Then one could wrap everything into an eqx.nn.Sequential layer and simply call it once.

@patrick-kidger
Owner

Regarding this, I completely agree. How can I achieve this? With something like the following?
is_layer = lambda m: isinstance(m, eqx.Module)
new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)
Is new = eqx.tree_at(...) needed?

I'd recommend against this. Equinox modules are really just pytrees like any other, so it's not appropriate to special-case them. Moreover, what if there is some non-E3NNModule that doesn't implement a .train method at all?

In the spirit of nominative subtyping, I would instead recommend the following pattern:

# Declare that this method should exist
class E3NNModule(eqx.Module):
    @abc.abstractmethod
    def train(self, mode):
        ...

# Now go looking for such layers, knowing that the train method must exist.
is_layer = lambda m: isinstance(m, E3NNModule)
new = jax.tree_util.tree_map(lambda m: m.train(mode) if is_layer(m) else m, self, is_leaf=is_layer)

# On your concrete classes, go ahead and provide an implementation.
class SomeModule(E3NNModule):
    def train(self, mode):
        ...

If you have nested E3NNModules inside each other, make sure that the .train methods of the wrapper module call the .train method of the wrapped module.

Incidentally the above is exactly the sort of thing I do very widely across my JAX libraries -- I'm a big fan of using ABCs to explicitly declare what interfaces are available.

would you have an idea, by any chance, whether it's usually better/faster in JAX when 'filling in an array' to (1) create an empty array, iterate, and fill values with arr = arr.at[...].set(...), or (2) create an empty list arr = [], iterate whilst appending values, and finally concatenate them?

I would recommend (2). JAX's heuristics for in-place updates are sometimes not great.
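A minimal sketch of approach (2), with a made-up per-row computation:

import jax.numpy as jnp

def compute_row(i):  # hypothetical per-index computation
    return jnp.arange(4) * i

rows = []
for i in range(10):
    rows.append(compute_row(i))  # collect pieces in a Python list
out = jnp.stack(rows)            # one stack/concatenate at the end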

Also, to handle both stateful and stateless modules I found myself adding something like

Hmm, these aren't really designed to be used interchangeably. After all, one could easily define a module with a completely arbitrary custom signature; it's not as if the only two valid ones are (x,) and (x, state). Conversely, if someone defines a module with signature (x, foo, state) -- one that just so happens to use the name state -- then your check will trigger incorrectly.

Stateful layers are pretty unusual -- in particular batchnorm is used very infrequently. What's your use case?

def state_wrapper(layer: eqx.Module):
    if "state" in inspect.signature(layer).parameters:
        return layer
    else:
        return lambda x, state: (layer(x), state)

Note that this snippet is dangerous. The else branch does not return a pytree, so any parameters inside layer will not actually be trained.
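One way to avoid that pitfall (just a sketch, not an existing Equinox API) is to wrap the stateless layer in a small eqx.Module, so that its parameters remain leaves of the pytree:

import equinox as eqx

class StatePassthrough(eqx.Module):
    layer: eqx.Module  # the wrapped stateless layer stays part of the pytree

    def __call__(self, x, state):
        return self.layer(x), state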
