diff --git a/equinox/_jit.py b/equinox/_jit.py index 633fe99a..997c6f2e 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -92,8 +92,9 @@ class XlaRuntimeError(Exception): def _modify_traceback(e: Exception): # Remove JAX's UnfilteredStackTrace, with its huge error messages. e.__cause__ = None - # Remove _JitWrapper.__call__ and _JitWrapper._call from the traceback - tb = e.__traceback__ = e.__traceback__.tb_next.tb_next # pyright: ignore + # Remove _JitWrapper.__call__ and _JitWrapper._call and Method.__call__ from the + # traceback + tb = e.__traceback__ = e.__traceback__.tb_next.tb_next.tb_next # pyright: ignore try: # See https://github.com/google/jax/blob/69cd3ebe99ce12a9f22e50009c00803a095737c7/jax/_src/traceback_util.py#L190 # noqa: E501 jax.lib.xla_extension.replace_thread_exc_traceback(tb) # pyright: ignore diff --git a/equinox/_module.py b/equinox/_module.py index 7293fecc..69b26d77 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -1,6 +1,7 @@ import dataclasses import functools as ft import inspect +import types import weakref from collections.abc import Callable from typing import Any, cast, TYPE_CHECKING, TypeVar, Union @@ -100,8 +101,7 @@ def __init__(self, method): def __get__(self, instance, owner): if instance is None: return self.method - _method = ft.wraps(self.method)(jtu.Partial(self.method, instance)) - delattr(_method, "__wrapped__") + _method = module_update_wrapper(BoundMethod(self.method, instance), self.method) return _method @@ -540,3 +540,15 @@ def __call__(self, *args, **kwargs): The result of the wrapped function. """ return self.func(*self.args, *args, **kwargs, **self.keywords) + + +class BoundMethod(Module): + func: types.FunctionType = field(static=True) + instance: Module + + def __call__(self, *args, **kwargs): + return self.func(self.instance, *args, **kwargs) + + @property + def __wrapped__(self): + return self.func.__get__(self.instance, type(self.instance)) # pyright: ignore diff --git a/tests/test_enum.py b/tests/test_enum.py index 71e37af6..305ea9d6 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import pytest +import equinox as eqx import equinox.internal as eqxi from .helpers import shaped_allclose @@ -245,11 +246,11 @@ class A(eqxi.Enumeration): token = jnp.array(True) A.a.error_if(token, False) - jax.jit(A.a.error_if)(token, False) + eqx.filter_jit(A.a.error_if)(token, False) with pytest.raises(Exception): A.a.error_if(token, True) with pytest.raises(Exception): - jax.jit(A.a.error_if)(token, True) + eqx.filter_jit(A.a.error_if)(token, True) def test_compile_time_eval(): diff --git a/tests/test_errors.py b/tests/test_errors.py index 18354973..031f2cd7 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -181,11 +181,13 @@ def _raises(): while tb is not None: code_stack.append(tb.tb_frame.f_code) tb = tb.tb_next - assert len(code_stack) == 3 - one, two, three = code_stack + assert len(code_stack) == 4 + one, two, three, four = code_stack assert one.co_filename.endswith("test_errors.py") assert one.co_name == "test_traceback_runtime_custom" assert two.co_filename.endswith("equinox/_jit.py") assert two.co_name == "__call__" - assert three.co_filename.endswith("equinox/_jit.py") - assert three.co_name == "_call" + assert three.co_filename.endswith("equinox/_module.py") + assert three.co_name == "__call__" + assert four.co_filename.endswith("equinox/_jit.py") + assert four.co_name == "_call" diff --git a/tests/test_module.py b/tests/test_module.py index 5a9ed129..35c0e725 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -183,13 +183,29 @@ def f(self, b): return self.a + b m = MyModule(13) - assert isinstance(m.f, jtu.Partial) + assert isinstance(m.f, eqx.Module) flat, treedef = jtu.tree_flatten(m.f) assert len(flat) == 1 assert flat[0] == 13 assert jtu.tree_unflatten(treedef, flat)(2) == 15 +def test_eq_method(): + # Expected behaviour from non-Module methods + class A: + def f(self): + pass + + a = A() + assert a.f == a.f + + class B(eqx.Module): + def f(self): + pass + + assert B().f == B().f + + def test_init_subclass(): ran = []