Skip to content

Commit

Permalink
Bound methods now support equality against each other.
Browse files Browse the repository at this point in the history
Fixes #480.
  • Loading branch information
patrick-kidger committed Sep 29, 2023
1 parent c48fb7f commit c90333d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 11 deletions.
5 changes: 3 additions & 2 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ class XlaRuntimeError(Exception):
def _modify_traceback(e: Exception):
# Remove JAX's UnfilteredStackTrace, with its huge error messages.
e.__cause__ = None
# Remove _JitWrapper.__call__ and _JitWrapper._call from the traceback
tb = e.__traceback__ = e.__traceback__.tb_next.tb_next # pyright: ignore
# Remove _JitWrapper.__call__ and _JitWrapper._call and Method.__call__ from the
# traceback
tb = e.__traceback__ = e.__traceback__.tb_next.tb_next.tb_next # pyright: ignore
try:
# See https://github.com/google/jax/blob/69cd3ebe99ce12a9f22e50009c00803a095737c7/jax/_src/traceback_util.py#L190 # noqa: E501
jax.lib.xla_extension.replace_thread_exc_traceback(tb) # pyright: ignore
Expand Down
16 changes: 14 additions & 2 deletions equinox/_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import functools as ft
import inspect
import types
import weakref
from collections.abc import Callable
from typing import Any, cast, TYPE_CHECKING, TypeVar, Union
Expand Down Expand Up @@ -100,8 +101,7 @@ def __init__(self, method):
def __get__(self, instance, owner):
if instance is None:
return self.method
_method = ft.wraps(self.method)(jtu.Partial(self.method, instance))
delattr(_method, "__wrapped__")
_method = module_update_wrapper(BoundMethod(self.method, instance), self.method)
return _method


Expand Down Expand Up @@ -540,3 +540,15 @@ def __call__(self, *args, **kwargs):
The result of the wrapped function.
"""
return self.func(*self.args, *args, **kwargs, **self.keywords)


class BoundMethod(Module):
func: types.FunctionType = field(static=True)
instance: Module

def __call__(self, *args, **kwargs):
return self.func(self.instance, *args, **kwargs)

@property
def __wrapped__(self):
return self.func.__get__(self.instance, type(self.instance)) # pyright: ignore
5 changes: 3 additions & 2 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax.numpy as jnp
import pytest

import equinox as eqx
import equinox.internal as eqxi

from .helpers import shaped_allclose
Expand Down Expand Up @@ -245,11 +246,11 @@ class A(eqxi.Enumeration):

token = jnp.array(True)
A.a.error_if(token, False)
jax.jit(A.a.error_if)(token, False)
eqx.filter_jit(A.a.error_if)(token, False)
with pytest.raises(Exception):
A.a.error_if(token, True)
with pytest.raises(Exception):
jax.jit(A.a.error_if)(token, True)
eqx.filter_jit(A.a.error_if)(token, True)


def test_compile_time_eval():
Expand Down
10 changes: 6 additions & 4 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,13 @@ def _raises():
while tb is not None:
code_stack.append(tb.tb_frame.f_code)
tb = tb.tb_next
assert len(code_stack) == 3
one, two, three = code_stack
assert len(code_stack) == 4
one, two, three, four = code_stack
assert one.co_filename.endswith("test_errors.py")
assert one.co_name == "test_traceback_runtime_custom"
assert two.co_filename.endswith("equinox/_jit.py")
assert two.co_name == "__call__"
assert three.co_filename.endswith("equinox/_jit.py")
assert three.co_name == "_call"
assert three.co_filename.endswith("equinox/_module.py")
assert three.co_name == "__call__"
assert four.co_filename.endswith("equinox/_jit.py")
assert four.co_name == "_call"
18 changes: 17 additions & 1 deletion tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,29 @@ def f(self, b):
return self.a + b

m = MyModule(13)
assert isinstance(m.f, jtu.Partial)
assert isinstance(m.f, eqx.Module)
flat, treedef = jtu.tree_flatten(m.f)
assert len(flat) == 1
assert flat[0] == 13
assert jtu.tree_unflatten(treedef, flat)(2) == 15


def test_eq_method():
# Expected behaviour from non-Module methods
class A:
def f(self):
pass

a = A()
assert a.f == a.f

class B(eqx.Module):
def f(self):
pass

assert B().f == B().f


def test_init_subclass():
ran = []

Expand Down

0 comments on commit c90333d

Please sign in to comment.