
Keras layers do not track tf.Module (not conforming to SOLID principles) #47264

Closed
st-- opened this issue Feb 19, 2021 · 9 comments


st-- commented Feb 19, 2021

This is a reduction of tensorflow/probability#946 to an issue with TensorFlow by itself.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): no
  • OS Platform and Distribution: Linux Ubuntu 18.04
  • TensorFlow installed from: binary
  • TensorFlow version: v2.4.0-49-g85c8b2a817f 2.4.1
  • Python version: 3.7.9

Describe the current behavior
tf.keras.layers.Layer is a subclass of tf.Module, but I cannot replace a use of tf.Module with tf.keras.layers.Layer, which violates the Liskov substitution principle. This is a real issue for downstream projects, as demonstrated by tensorflow/probability#946.

Describe the expected behavior
I can substitute a tf.keras.layers.Layer anywhere I need a tf.Module. Specifically, layer.trainable_variables and layer.trainable_weights should discover all tf.Variable instances held by sub-Modules, not just by sub-Layers.

Ideally, .trainable_variables would have a consistent return type between Module and Layer.

Standalone code to reproduce the issue

import tensorflow as tf

module = tf.Module()
module.submodule = tf.Module()
module.submodule.var = tf.Variable(1.0)
assert module.trainable_variables == (module.submodule.var,)  # as expected

layer = tf.keras.layers.Layer()
assert isinstance(layer, tf.Module)  # passes
layer.sublayer = tf.keras.layers.Layer()
layer.sublayer.var = tf.Variable(1.0)
assert layer.trainable_variables == [layer.sublayer.var]  # acceptable

layer = tf.keras.layers.Layer()
layer.submodule = tf.Module()
layer.submodule.var = tf.Variable(1.0)
assert list(layer.trainable_variables) == [layer.submodule.var]  # FAILS

st-- commented Feb 19, 2021

We can resolve this with the following Layer subclass:

import itertools
from functools import wraps
from typing import Any, Callable, Optional, Sequence

import tensorflow as tf


def extend_and_filter(
    extend_method: Callable[..., Sequence], filter_method: Optional[Callable[..., Sequence]] = None,
) -> Callable[[Any], Any]:
    """
    This decorator calls a decorated method, and extends the result with another method
    on the same class. This method is called after the decorated function, with the same
    arguments as the decorated function. If specified, a second filter method can be applied
    to the extended list. Filter method should also be a method from the class.

    :param extend_method: Callable
        Accepts the same argument as the decorated method.
        The returned list from `extend_method` will be added to the
        decorated method's returned list.
    :param filter_method: Callable
        Takes in the extended list and filters it.
        Defaults to no filtering for `filter_method` equal to `None`.
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def wrapped(self, *args, **kwargs):  # type: ignore
            ret = f(self, *args, **kwargs)
            ret.extend(extend_method(self, *args, **kwargs))
            ret = filter_method(self, ret) if filter_method is not None else ret
            return ret

        return wrapped

    return decorator
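
# A minimal, hypothetical illustration of how `extend_and_filter` composes
# (not part of the fix itself; `Demo`, `base`, `extra`, and `dedup` are
# made-up names): `extra` extends the result of `base`, then `dedup`
# filters the combined list.
class Demo:
    def extra(self) -> list:
        return [2, 3]

    def dedup(self, items: list) -> list:
        return list(dict.fromkeys(items))  # drop duplicates, keep order

    @extend_and_filter(extra, dedup)
    def base(self) -> list:
        return [1, 2]


assert Demo().base() == [1, 2, 3]  # [1, 2] extended by [2, 3], then deduplicated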


class TrackableLayer(tf.keras.layers.Layer):
    """
    A tf.Layer that implements tracking of tf.Variables on the class's
    attributes that are tf.Modules.

    Currently, tf.Modules track the tf.Variables of their attributes that are
    also tf.Modules.  Similarly, tf.Layers track the tf.Variables of their
    attributes that are also tf.Layers.  However, despite the fact that
    tf.Layer inherits from tf.Module, they cannot track the tf.Variables of
    their attributes that are generic tf.Modules. This seems to be an issue
    that the TensorFlow authors seem to want to fix in the future.
    """

    @property
    def _submodules(self) -> Sequence[tf.Module]:
        """Return a list of tf.Module instances that are attributes on the class. Note
        this also include list or tuples of tf.Modules"""

        submodules = []

        def get_nested_submodules(*objs: Any) -> None:
            for o in objs:
                if isinstance(o, tf.Module):
                    submodules.append(o)

        for key, obj in self.__dict__.items():
            if isinstance(obj, tf.Module):
                submodules.append(obj)
            elif isinstance(obj, (list, tuple)):
                tf.nest.map_structure(get_nested_submodules, obj)
            elif isinstance(obj, dict):
                # pass a list: tf.nest does not treat a dict view as a structure
                tf.nest.map_structure(get_nested_submodules, list(obj.values()))

        return list(dict.fromkeys(submodules))  # remove duplicates, maintaining order (dicts are insertion-ordered)

    def submodule_variables(self) -> Sequence[tf.Variable]:
        """Return flat iterable of variables from the attributes that are tf.Modules"""
        return list(itertools.chain(*[module.variables for module in self._submodules]))

    def submodule_trainable_variables(self) -> Sequence[tf.Variable]:
        """Return flat iterable of trainable variables from attributes that are tf.Modules"""
        return list(itertools.chain(*[module.trainable_variables for module in self._submodules]))

    def submodule_non_trainable_variables(self) -> Sequence[tf.Variable]:
        """Return flat iterable of non trainable variables from attributes that are tf.Modules"""
        return [v for module in self._submodules for v in module.variables if not v.trainable]

    def _dedup_weights(self, weights):  # type: ignore
        """Deduplicate weights while maintaining order as much as possible."""
        # Re-expose the superclass method so that it is available in this
        # class's namespace at definition time (the decorators below need it)
        return super()._dedup_weights(weights)

    @property  # type: ignore
    @extend_and_filter(submodule_trainable_variables, _dedup_weights)
    def trainable_weights(self) -> Sequence[tf.Variable]:
        return super().trainable_weights

    @property  # type: ignore
    @extend_and_filter(submodule_non_trainable_variables, _dedup_weights)
    def non_trainable_weights(self) -> Sequence[tf.Variable]:
        return super().non_trainable_weights

    @property  # type: ignore
    @extend_and_filter(submodule_trainable_variables, _dedup_weights)
    def trainable_variables(self) -> Sequence[tf.Variable]:
        return super().trainable_variables

    @property  # type: ignore
    @extend_and_filter(submodule_variables, _dedup_weights)
    def variables(self) -> Sequence[tf.Variable]:
        return super().variables

With this class, the following works as expected:

layer = TrackableLayer()
layer.submodule = tf.Module()
layer.submodule.var = tf.Variable(1.0)
assert list(layer.trainable_variables) == [layer.submodule.var]  # now passes
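
Containers of modules are picked up as well; a quick check under the same setup (`modules` is just an illustrative attribute name):

layer = TrackableLayer()
layer.modules = [tf.Module(), tf.Module()]
layer.modules[0].var = tf.Variable(2.0)
assert layer.modules[0].var in layer.trainable_variables  # tracked via the list attribute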

What would it take to get this behaviour into TensorFlow itself? Could we simply move the above methods onto tf.keras.layers.Layer? I'd be happy to contribute a PR if that helps resolve this issue faster.


@amahendrakar

@st--,
I was able to reproduce the issue with TF v2.3 and TF v2.4. However, the issue seems to be fixed with the latest TF-nightly.

Please check the linked gist for reference. Thanks!
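
For anyone verifying locally, a sketch of the check against a nightly build (this is just the original repro with a version print added; install via pip install tf-nightly):

import tensorflow as tf

print(tf.__version__)  # should report a nightly build
layer = tf.keras.layers.Layer()
layer.submodule = tf.Module()
layer.submodule.var = tf.Variable(1.0)
print(layer.submodule.var in list(layer.trainable_variables))  # True if fixed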

amahendrakar added the comp:keras, Fixed in Nightly, stat:awaiting response, and TF 2.4 labels on Feb 22, 2021

st-- commented Feb 23, 2021

@amahendrakar ah, awesome - thanks for the quick response. I did upgrade to 2.4.1 to check it was still an issue, but hadn't thought of also checking nightly ... do you happen to know which PR/commit fixed this in the end? Would be curious to see how it actually got fixed in TF-core.

tensorflowbutler removed the stat:awaiting response label on Feb 25, 2021
@amahendrakar

> do you happen to know which PR/commit fixed this in the end? Would be curious to see how it actually got fixed in TF-core.

@st--,
Sorry for the delayed response. Since the modules are regularly being updated, we cannot pinpoint the exact commit/PR which fixed the issue.

Please feel free to close the issue if resolved. Thanks!

amahendrakar added the stat:awaiting response label on Mar 10, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler bot added the stale label on Mar 17, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.



nyngwang commented Jul 19, 2022

@amahendrakar My apologies for replying to an old issue, but if this issue is closed, shouldn't the documentation be updated as well? Specifically, I mean the following note on https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow:

> Note: tf.Module is the base class for both tf.keras.layers.Layer and tf.keras.Model, so everything you come across here also applies in Keras. For historical compatibility reasons Keras layers do not collect variables from modules, so your models should use only modules or only Keras layers. However, the methods shown below for inspecting variables are the same in either case.

A more detailed description, including code for testing, can be found at: https://stackoverflow.com/q/73033488/5290519.

(I'm using tf 2.9.1.)
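
For reference, the original repro can be re-run to check the current behavior (a print rather than an assert, since the docs note and the "Fixed in Nightly" label disagree):

import tensorflow as tf

layer = tf.keras.layers.Layer()
layer.submodule = tf.Module()
layer.submodule.var = tf.Variable(1.0)
print(tf.__version__, layer.submodule.var in list(layer.trainable_variables))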
