Skip to content

Commit

Permalink
Now running tree_check when initialising a Module
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 12, 2023
1 parent a09da95 commit 77ab1ce
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
41 changes: 40 additions & 1 deletion equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._caches import internal_lru_caches
from ._doc_utils import doc_repr
from ._pretty_print import tree_pformat
from ._tree import tree_equal
from ._tree import tree_check_internal, tree_equal


_P = ParamSpec("_P")
Expand Down Expand Up @@ -141,6 +141,11 @@ def _not_magic(k: str) -> bool:


_has_dataclass_init = weakref.WeakKeyDictionary()
_has_been_checked = weakref.WeakValueDictionary()


def _skip(node):
return isinstance(node, Module) and node is _has_been_checked.get(id(node), None)


_dummy_abstract = abc.abstractmethod(lambda self: 1)
Expand Down Expand Up @@ -272,13 +277,47 @@ def __call__(cls, *args, **kwargs):
else:
setattr(self, field.name, converter(getattr(self, field.name)))
object.__setattr__(self, "__class__", cls)
# Note that these checks only run during the initial creation, and not during
# unflattening.
for kls in cls.__mro__:
try:
check = kls.__dict__["__check_init__"]
except KeyError:
pass
else:
check(self)
try:
tree_check_internal(self, _skip)
except ValueError as e:
raise ValueError(
"As of Equinox v0.11.0, `equinox.Module`s now validate that there "
"aren't any repeated layers inside a module. This is because this was "
"previously a common bug.\n"
"As an example, something like this:\n"
"```\n`"
"class MyModule(eqx.Module):\n"
" linear1: eqx.nn.Linear\n"
" linear2: eqx.nn.Linear\n"
"\n"
" def __init__(self, ...):\n"
" linear = eqx.nn.Linear(...)\n"
" self.linear1 = linear\n"
" self.linear2 = linear\n"
"```\n"
"resulted in two independent linear layers after a gradient update had "
"happened.\n"
"An exception is being thrown now as this error been detected.\n"
"If you intended to share the layer, then use the new functionality "
"`eqx.nn.Shared`. If you intended to have two duplicate copies, then "
"please instantiate two separate layers. If it's easier, you can also "
"clone an existing layer by doing\n"
"```\n"
"layer = ...\n"
"leaves, treedef = jax.tree_util.tree_flatten(layer)\n"
"clone_layer = jax.tree_util.tree_unflatten(treedef, leaves)\n"
"```"
) from e
_has_been_checked[id(self)] = self
return self


Expand Down
19 changes: 14 additions & 5 deletions equinox/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ def is_leaf(node):
return jtu.tree_flatten(pytree, is_leaf=is_leaf)


def tree_check_internal(pytree, skip) -> None:
"""As `tree_check`, but can skips checking some nodes (typically those that have
alread been checked).
"""
all_nodes = {}
_tree_check(pytree, all_nodes, skip)


def tree_check(pytree: Any) -> None:
"""Checks if the PyTree is well-formed: does it have no self-references, and does
it have no duplicate layers.
Expand Down Expand Up @@ -389,13 +397,13 @@ def tree_check(pytree: Any) -> None:
A `ValueError` if the PyTree is not well-formed.
"""
all_nodes = {}
_tree_check(pytree, all_nodes)
_tree_check(pytree, all_nodes, skip=lambda _: False)


_leaf_treedef = jtu.tree_structure(0)


def _tree_check(node, all_nodes):
def _tree_check(node, all_nodes, skip):
subnodes, treedef = tree_flatten_one_level(node)
# We allow duplicate leaves and empty containers, so don't raise an error with those
if treedef != _leaf_treedef and treedef.num_leaves > 0:
Expand All @@ -422,7 +430,8 @@ def _tree_check(node, all_nodes):
except AttributeError:
# AttributeError: in case we cannot get __name__ for some weird reason.
type_string = "<unknown type>"
all_nodes[id(node)] = (True, type_string)
for subnode in subnodes:
_tree_check(subnode, all_nodes)
if not skip(node):
all_nodes[id(node)] = (True, type_string)
for subnode in subnodes:
_tree_check(subnode, all_nodes, skip)
all_nodes[id(node)] = (False, type_string)
23 changes: 23 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import functools as ft
from collections.abc import Callable
import gc
from typing import Any

import jax
Expand Down Expand Up @@ -540,3 +541,25 @@ class Abstract3(eqx.Module, strict=True):
@abc.abstractmethod
def foo(self):
pass


def test_tree_check_cache(getkey):
gc.collect()
has_been_checked = eqx._module._has_been_checked
num_checked = len(has_been_checked)
mlp = eqx.nn.MLP(2, 2, 2, 2, key=getkey())
# +4: one for `MLP`, and three for its `Linear` layers inside.
assert len(has_been_checked) == num_checked + 4
del mlp
gc.collect()
assert len(has_been_checked) == num_checked


def test_duplicate_layer_error(getkey):
class M(eqx.Module):
l1: eqx.nn.Linear
l2: eqx.nn.Linear

linear = eqx.nn.Linear(2, 2, key=getkey())
with pytest.raises(ValueError, match="As of Equinox v0.11.0"):
M(linear, linear)

0 comments on commit 77ab1ce

Please sign in to comment.