Keras layers do not track tf.Module (not conforming to SOLID principles) #47264
We can resolve this with the following `Layer` subclass:

```python
import itertools
from functools import wraps
from typing import Any, Callable, Optional, Sequence

import tensorflow as tf


def extend_and_filter(
    extend_method: Callable[..., Sequence],
    filter_method: Optional[Callable[..., Sequence]] = None,
) -> Callable[[Any], Any]:
    """
    Call the decorated method, then extend its returned list with the result
    of `extend_method`, which is called on the same class with the same
    arguments as the decorated method. If given, `filter_method` (also a
    method on the class) is applied to the extended list.

    :param extend_method: accepts the same arguments as the decorated method;
        its returned list is appended to the decorated method's returned list.
    :param filter_method: takes the extended list and filters it.
        `None` (the default) means no filtering.
    """
    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def wrapped(self, *args, **kwargs):  # type: ignore
            ret = f(self, *args, **kwargs)
            ret.extend(extend_method(self, *args, **kwargs))
            ret = filter_method(self, ret) if filter_method is not None else ret
            return ret
        return wrapped
    return decorator


class TrackableLayer(tf.keras.layers.Layer):
    """
    A tf.keras.layers.Layer that tracks the tf.Variables of its attributes
    that are tf.Modules.

    tf.Modules track the tf.Variables of their attributes that are themselves
    tf.Modules; likewise, Layers track the tf.Variables of their attributes
    that are themselves Layers. However, even though Layer inherits from
    tf.Module, a Layer does not track the tf.Variables of attributes that are
    plain tf.Modules. This appears to be something the TensorFlow authors
    want to fix in the future.
    """

    @property
    def _submodules(self) -> Sequence[tf.Module]:
        """Return the tf.Module instances found on the class's attributes.
        Lists, tuples, and dicts of tf.Modules are searched as well."""
        submodules = []

        def get_nested_submodules(*objs: Any) -> None:
            for o in objs:
                if isinstance(o, tf.Module):
                    submodules.append(o)

        for key, obj in self.__dict__.items():
            if isinstance(obj, tf.Module):
                submodules.append(obj)
            elif isinstance(obj, (list, tuple)):
                tf.nest.map_structure(get_nested_submodules, obj)
            elif isinstance(obj, dict):
                tf.nest.map_structure(get_nested_submodules, list(obj.values()))
        # dicts preserve insertion order (guaranteed since Python 3.7), so
        # this removes duplicates while maintaining order
        return list(dict.fromkeys(submodules))

    def submodule_variables(self) -> Sequence[tf.Variable]:
        """Return a flat list of variables from the attributes that are tf.Modules."""
        return list(itertools.chain(*[module.variables for module in self._submodules]))

    def submodule_trainable_variables(self) -> Sequence[tf.Variable]:
        """Return a flat list of trainable variables from the attributes that are tf.Modules."""
        return list(itertools.chain(*[module.trainable_variables for module in self._submodules]))

    def submodule_non_trainable_variables(self) -> Sequence[tf.Variable]:
        """Return a flat list of non-trainable variables from the attributes that are tf.Modules."""
        return [v for module in self._submodules for v in module.variables if not v.trainable]

    def _dedup_weights(self, weights):  # type: ignore
        """Deduplicate weights while maintaining order as much as possible."""
        # Re-declared here so the method is in this class's namespace and can
        # be referenced in the decorators below.
        return super()._dedup_weights(weights)

    @property  # type: ignore
    @extend_and_filter(submodule_trainable_variables, _dedup_weights)
    def trainable_weights(self) -> Sequence[tf.Variable]:
        return super().trainable_weights

    @property  # type: ignore
    @extend_and_filter(submodule_non_trainable_variables, _dedup_weights)
    def non_trainable_weights(self) -> Sequence[tf.Variable]:
        return super().non_trainable_weights

    @property  # type: ignore
    @extend_and_filter(submodule_trainable_variables, _dedup_weights)
    def trainable_variables(self) -> Sequence[tf.Variable]:
        return super().trainable_variables

    @property  # type: ignore
    @extend_and_filter(submodule_variables, _dedup_weights)
    def variables(self) -> Sequence[tf.Variable]:
        return super().variables
```

With this class, the following works as expected:

```python
layer = TrackableLayer()
layer.submodule = tf.Module()
layer.submodule.var = tf.Variable(1.0)
assert list(layer.trainable_variables) == [layer.submodule.var]  # now passes
```

What would it take to get this behaviour into TensorFlow itself? Can we simply move the above methods onto `tf.keras.layers.Layer` itself? I'd be happy to contribute a PR if that helps resolve this issue faster.
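The `extend_and_filter` decorator itself doesn't depend on TensorFlow at all; a minimal framework-free sketch (the `Basket` class and its methods are hypothetical, purely for illustration) shows how the decorated method, the extend method, and the filter method compose:

```python
from functools import wraps
from typing import Any, Callable, Optional, Sequence


def extend_and_filter(
    extend_method: Callable[..., Sequence],
    filter_method: Optional[Callable[..., Sequence]] = None,
) -> Callable[[Any], Any]:
    # Same decorator as above: call f, extend its result with
    # extend_method's result, then optionally filter the whole list.
    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def wrapped(self, *args, **kwargs):
            ret = f(self, *args, **kwargs)
            ret.extend(extend_method(self, *args, **kwargs))
            return filter_method(self, ret) if filter_method is not None else ret
        return wrapped
    return decorator


class Basket:
    # Hypothetical class: `items` returns its own list, extended with
    # `extra_items` and de-duplicated by `_dedup` -- the same shape as
    # TrackableLayer.trainable_variables above.
    def __init__(self):
        self.own = ["apple", "pear"]
        self.extra = ["pear", "plum"]

    def extra_items(self):
        return list(self.extra)

    def _dedup(self, items):
        # Order-preserving de-duplication, as in _dedup_weights.
        return list(dict.fromkeys(items))

    @property
    @extend_and_filter(extra_items, _dedup)
    def items(self):
        return list(self.own)


basket = Basket()
print(basket.items)  # ['apple', 'pear', 'plum']
```

Note that `extra_items` and `_dedup` are still plain functions at class-body decoration time, which is why they are passed unbound and called as `extend_method(self, ...)` inside the wrapper.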
There's a related workaround in Edward2: https://github.com/google/edward2/blob/5e18aec19af3925274808b55ff22f8a60a6cebdb/edward2/tensorflow/layers/utils.py#L38
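At its core, this kind of workaround is an attribute walk like the `_submodules` property above: gather every module found directly or nested in containers, then de-duplicate. A framework-free sketch (the `Module` stand-in and all names are hypothetical, not the Edward2 or TensorFlow API):

```python
class Module:
    # Hypothetical stand-in for tf.Module: anything carrying variables.
    def __init__(self, *variables):
        self.variables = list(variables)


def collect_submodules(obj):
    # Walk the instance's attributes and gather Module instances,
    # descending into lists, tuples, and dict values, then de-duplicate
    # while preserving order (mirrors TrackableLayer._submodules).
    found = []

    def visit(value):
        if isinstance(value, Module):
            found.append(value)
        elif isinstance(value, (list, tuple)):
            for item in value:
                visit(item)
        elif isinstance(value, dict):
            for item in value.values():
                visit(item)

    for value in vars(obj).values():
        visit(value)
    return list(dict.fromkeys(found))


class Holder:
    pass


h = Holder()
a, b = Module("w1"), Module("w2")
h.direct = a
h.nested = [b, a]  # `a` appears twice; de-duplication keeps the first
print([m.variables for m in collect_submodules(h)])  # [['w1'], ['w2']]
```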
@st--, Please check the linked gist for reference. Thanks!
@amahendrakar ah, awesome - thanks for the quick response. I did upgrade to 2.4.1 to check it was still an issue, but hadn't thought of also checking nightly ... do you happen to know which PR/commit fixed this in the end? I would be curious to see how it actually got fixed in TF-core.
@st--, Please feel free to close the issue if resolved. Thanks!
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.
Closing as stale. Please reopen if you'd like to work on this further.
@amahendrakar My apologies for replying to an old issue, but if this issue is closed, should the documentation be updated as well? Specifically, I mean the following quote on https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow:
A more detailed description, including code for testing, can be found at https://stackoverflow.com/q/73033488/5290519. (I'm using tf 2.9.1.)
This is a reduction of tensorflow/probability#946 to an issue with TensorFlow by itself.
System information
Describe the current behavior
`tf.keras.layers.Layer` is a subclass of `tf.Module`, but I cannot replace a use of `tf.Module` with `tf.keras.layers.Layer`. This violates the Liskov substitution principle, which is a real issue for downstream projects, as demonstrated by tensorflow/probability#946.
Describe the expected behavior
I can substitute a `tf.keras.layers.Layer` anywhere I need a `tf.Module`. Specifically, `layer.trainable_variables` and `layer.trainable_weights` discover all `Variable` instances that are in sub-`Module`s, not just in sub-`Layer`s. Ideally, `.trainable_variables` would have a consistent return type between `Module` and `Layer`.
Standalone code to reproduce the issue