Odd behavior for jax.tree_util.Partial and interactions with eqx.Module #480

Closed
bowlingmh opened this issue Sep 8, 2023 · 3 comments
Labels
bug Something isn't working

@bowlingmh
Contributor

bowlingmh commented Sep 8, 2023

Here's something pretty odd about jax.tree_util.Partial: it mimics the equality behavior of functools.partial, under which two partial objects wrapping the same function with the same arguments are not equal unless they are the same object.
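
For reference, the stdlib behavior being mimicked can be checked in a few lines. functools.partial does not define its own __eq__, so == falls back to object identity even when the wrapped function, positional arguments and keywords all match:

```python
import functools as ft

def f(x):
    return x

p = ft.partial(f, 3)
q = ft.partial(f, 3)

# No custom __eq__ on functools.partial, so == compares identity.
print(p == q)  # False
print(p == p)  # True

# The pieces each partial wraps do compare equal.
print(p.func is q.func, p.args == q.args, p.keywords == q.keywords)  # True True True
```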

This is a problem, because equinox wraps a module's methods so that accessing one on an instance returns a jax.tree_util.Partial, giving you the following very unexpected behavior:

```python
import equinox as eqx

class M(eqx.Module):
  def f(self, x):
    return x

m = M()
M.f == M.f # True
m.f == m.f # False
```

The same instance method isn't equal to itself! This is because each attribute access goes through a separate call to the wrapper that turns the class's method into a jax.tree_util.Partial, so every access produces a new object.
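
A minimal sketch of the mechanism, assuming the wrapper behaves like a descriptor that builds a new partial on every instance attribute access. The partial_method descriptor below is hypothetical (it is not equinox's actual wrapper) and uses plain functools.partial, but it reproduces the same True/False pattern. An ordinary bound method is included for contrast: Python's bound methods compare equal when they share the same __self__ and __func__.

```python
import functools as ft

class partial_method:
    """Hypothetical descriptor (NOT equinox's wrapper): every attribute
    access on an instance builds a brand-new partial object."""

    def __init__(self, fn):
        self.fn = fn

    def __get__(self, instance, owner):
        if instance is None:
            # Class access returns the same underlying function every time.
            return self.fn
        # Instance access binds `instance` as the first argument.
        return ft.partial(self.fn, instance)


class M:
    @partial_method
    def f(self, x):
        return x

    def g(self, x):  # ordinary method, for comparison
        return x


m = M()
print(M.f == M.f)  # True: same function object
print(m.f == m.f)  # False: two distinct partial objects
print(m.g == m.g)  # True: bound methods compare __self__ and __func__
print(m.f(5))      # 5
```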

One possible fix is to switch the wrapper to equinox's own Partial, which has sensible equality. However, you can't use functools.wraps on an equinox Partial, because it is a frozen dataclass and wraps tries to set attributes (such as __name__ and __doc__) on the wrapper object. The only simple fix is to drop the wraps call altogether, but that seems like a bad choice too. Maybe the Partial could be unfrozen while it is being wrapped?
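
The wraps problem can be reproduced with the standard library alone. In this sketch a stdlib frozen dataclass, FrozenPartial, stands in for equinox's Partial (an assumption for illustration; it is not eqx.Partial). functools.update_wrapper, which functools.wraps calls, assigns attributes such as __module__ with ordinary setattr, and the frozen dataclass rejects that. The hypothetical wraps_frozen helper shows one reading of the "unfreeze while wrapping" idea: it copies the metadata through object.__setattr__, which bypasses the frozen __setattr__.

```python
import dataclasses
import functools as ft

@dataclasses.dataclass(frozen=True)
class FrozenPartial:
    """Stand-in for a frozen-dataclass Partial (assumption, not eqx.Partial)."""
    func: object
    args: tuple

    def __call__(self, *more):
        return self.func(*self.args, *more)


def f(self, x):
    """Docstring that functools.wraps would copy."""
    return x


# update_wrapper does setattr(wrapper, "__module__", ...), which is rejected.
try:
    ft.update_wrapper(FrozenPartial(f, ("self",)), f)
    failed = False
except dataclasses.FrozenInstanceError:
    failed = True
print(failed)  # True


def wraps_frozen(wrapper, wrapped):
    """Hypothetical helper: copy metadata by bypassing the frozen __setattr__."""
    for attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        if hasattr(wrapped, attr):
            object.__setattr__(wrapper, attr, getattr(wrapped, attr))
    object.__setattr__(wrapper, "__wrapped__", wrapped)
    return wrapper


q = wraps_frozen(FrozenPartial(f, ("self",)), f)
print(q.__name__)  # f
print(q(42))       # 42
```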

Probably the better solution is for jax to change how __eq__ works on jax.tree_util.Partial, because the current behavior leads to other surprises too. Identity-based equality may be a fine decision for functools.partial, but jax.tree_util.Partial is PyTree compatible and so can be flattened. I would expect that if two flattened PyTrees are equal, then the unflattened PyTrees are equal as well. That doesn't hold for jax.tree_util.Partial: it flattens into its arguments (the leaves) and its function (kept in the tree structure), and for two equivalent jax.tree_util.Partial objects these compare equal even though the objects themselves do not.
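
A minimal sketch of the kind of __eq__ change suggested above, assuming structural equality is what's wanted: compare func, args and keywords, which are the pieces that flattening exposes. EqPartial is a hypothetical name and this is not jax's implementation; it subclasses functools.partial, as jax.tree_util.Partial itself does. Defining __eq__ also requires a matching __hash__, since Python otherwise sets __hash__ to None.

```python
import functools as ft

class EqPartial(ft.partial):
    """Hypothetical partial with structural equality (not jax's code):
    two instances are equal when func, args and keywords all match."""

    def __eq__(self, other):
        if not isinstance(other, EqPartial):
            return NotImplemented
        return (self.func, self.args, self.keywords) == (
            other.func, other.args, other.keywords)

    def __hash__(self):
        # Assumes all args and keyword values are hashable.
        return hash((self.func, self.args, frozenset(self.keywords.items())))


def f(x, scale=1):
    return x * scale


a = EqPartial(f, 3, scale=2)
b = EqPartial(f, 3, scale=2)
print(a == b, hash(a) == hash(b))     # True True
print(a == EqPartial(f, 4, scale=2))  # False
print(a())                            # 6
```

With multi-element array arguments, the tuple comparison above would try to take the truth value of an elementwise ==, which raises, so a real PyTree-aware version would likely need more care.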

Equinox's Partial doesn't have any of these problems. The script below shows that only jax.tree_util.Partial is inconsistent between the equality of its flattened and unflattened representations.

```python
import functools as ft
import jax.tree_util as jtu
import equinox as eqx

def f(x):
  return x

print('functools: ',
      ft.partial(f, 3) == ft.partial(f, 3),
      jtu.tree_flatten(ft.partial(f, 3)) == jtu.tree_flatten(ft.partial(f, 3)))
print('jax.tree_util: ',
      jtu.Partial(f, 3) == jtu.Partial(f, 3),
      jtu.tree_flatten(jtu.Partial(f, 3)) == jtu.tree_flatten(jtu.Partial(f, 3)))
print('equinox: ',
      eqx.Partial(f, 3) == eqx.Partial(f, 3),
      jtu.tree_flatten(eqx.Partial(f, 3)) == jtu.tree_flatten(eqx.Partial(f, 3)))
```

This gives the output:

```
functools:  False False
jax.tree_util:  False True
equinox:  True True
```

@patrick-kidger
Owner

Thanks for the report! This is quite the edge-case. Indeed, it looks like jax.tree_util.Partial has some odd behaviour on this front.

I've just written #485, which should fix this.

@bowlingmh
Contributor Author

Thanks!

@patrick-kidger
Owner

Closing as fixed in #485. This will be included in the next release (v0.11.0) of Equinox.
