Catch RecursionError for delayed assattr nodes #1660

Closed

Conversation

jacobtylerwalls
Member

Steps

  • For new features or bug fixes, add a ChangeLog entry describing what your PR does.
  • Write a good description of what the PR does.

Description

Type of Changes

✓ 🐛 Bug fix

Related Issue

Closes #1646

This isn't "unreleased" as I first thought -- we just started tripping over it in the primer after a recent astroid PR.

Member

@Pierre-Sassoulas left a comment

Seems like something we might want to release ASAP, as it fixes a crash and there won't be a lot of candidates in 2.11.7 anyway. What do you think?

@DanielNoord
Collaborator

@jacobtylerwalls Do you understand why @decorators.path_wrapper on ImportFrom._infer doesn't catch this? That should be the place to handle infinite recursion in inferences, but apparently it fails here?
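
For context, the guard being referred to works roughly like a visited-set around an _infer generator. A minimal sketch of the idea only, with made-up names and a plain set standing in for astroid's inference context:

def path_wrapper_sketch(infer_func):
    # Hypothetical guard: track nodes already on the inference path and
    # stop instead of recursing into them a second time.
    def wrapped(node, path=None):
        path = set() if path is None else path
        if node in path:
            return  # already inferring this node: break the cycle
        path.add(node)
        yield from infer_func(node, path)
    return wrapped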

@coveralls

coveralls commented Jun 24, 2022

Pull Request Test Coverage Report for Build 2576737403

  • 1 of 1 (100.0%) changed or added relevant line in 1 file are covered.
  • 76 unchanged lines in 2 files lost coverage.
  • Overall coverage increased (+0.03%) to 92.339%

Files with Coverage Reduction    New Missed Lines    %
astroid/inference.py             22                  94.61%
astroid/nodes/node_classes.py    54                  95.42%

Totals (Coverage Status)
Change from base Build 2560834104: 0.03%
Covered Lines: 9426
Relevant Lines: 10208

💛 - Coveralls

@jacobtylerwalls
Member Author

jacobtylerwalls commented Jun 24, 2022

@DanielNoord I think it's because it calls do_import_module() and thus the recursion has a different root cause. (Unlike with garden-variety inference, here, during "importing" of a module, we have delayed assattr nodes that then also trigger do_import_module.)

@DanielNoord
Collaborator

@DanielNoord I think it's because it calls do_import_module() and thus the recursion has a different root cause. (Unlike with garden-variety inference, here, during "importing" of a module, we have delayed assattr nodes that then also trigger do_import_module.)

Hm, so shouldn't this be fixed there then? By stopping do_import_module if the module is already being imported by a "higher callee"?

@jacobtylerwalls
Member Author

Sure, but I'm unaware of a way to track that. I figured by targeting the "delayed assattr" branch I was getting as close to the source as feasible. But I'm happy to follow your lead if you have an idea.
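
Roughly what "targeting the 'delayed assattr' branch" amounts to, as a hedged sketch only; the function name and the loop body below are stand-ins, not the actual patch:

def visit_delayed_assattr_sketch(node):
    # Inferring the owner of a delayed AssignAttr can blow the interpreter's
    # recursion limit on deeply nested import chains, so give up on this one
    # node instead of crashing the whole build.
    try:
        for inferred in node.expr.infer():
            ...  # attach the attribute to each inferred owner
    except RecursionError:
        pass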

@DanielNoord
Collaborator

@jacobtylerwalls What's the file in pandas that we crash on? I can't seem to crash locally on .../pandas_dev/pandas/pandas/core/arrays/_mixins.py.

@jacobtylerwalls
Member Author

Make sure to launch pylint from the pylint project root. There's no failure if you launch from the pandas package root. (I think the primer tool is unrealistically linting projects as if they are namespace packages by putting them in a big uber-folder that's not a python package.)

@DanielNoord
Collaborator

I'm probably being stupid, but this doesn't crash for me:

cd /tmp
git clone https://github.com/PyCQA/pylint.git
cd pylint
virtualenv venv
source venv/bin/activate
pip install -r requirements_test.txt
python tests/primer/__main__.py prepare --clone
pylint tests/.pylint_primer_tests/pandas-dev/pandas/pandas/core/arrays/_mixins.py

@jacobtylerwalls
Member Author

Did you install astroid bleeding edge?

@DanielNoord
Collaborator

Did you install astroid bleeding edge?

I didn't, but I just did:

❯ pylint --version
pylint 2.15.0-dev0
astroid 2.12.0-dev0
Python 3.11.0b3 (main, Jun 12 2022, 16:13:11) [Clang 13.0.0 (clang-1300.0.29.30)]

Still no crash 😓

@jacobtylerwalls
Member Author

Aha! Replace everything in that file with just this line (which is found in the file):

from pandas.core.algorithms import take

@DanielNoord
Collaborator

DanielNoord commented Jun 25, 2022

❯ pylint --version
pylint 2.15.0-dev0
astroid 2.12.0-dev0
Python 3.11.0b3 (main, Jun 12 2022, 16:13:11) [Clang 13.0.0 (clang-1300.0.29.30)]
❯ which pylint
/private/tmp/pylint/venv/bin/pylint
❯ pylint tests/.pylint_primer_tests/pandas-dev/pandas/pandas/core/arrays/_mixins.py
************* Module pandas.core.arrays._mixins
tests/.pylint_primer_tests/pandas-dev/pandas/pandas/core/arrays/_mixins.py:1:0: W0611: Unused take imported from pandas.core.algorithms (unused-import)

------------------------------------------------------------------
Your code has been rated at 0.00/10 (previous run: 8.19/10, -8.19)
❯ cat tests/.pylint_primer_tests/pandas-dev/pandas/pandas/core/arrays/_mixins.py

from pandas.core.algorithms import take

😅

@jacobtylerwalls
Member Author

Huh. That recipe reproduces just fine for me, following your steps exactly (and installing astroid main with -e). I'm on Python 3.10; that's the only difference.

@DanielNoord
Collaborator

Huh. That recipe reproduces just fine for me, following your steps exactly (and installing astroid main with -e). I'm on Python 3.10; that's the only difference.

Oh wow, this doesn't actually crash on 3.11. I should have tested on 3.10 immediately.

@DanielNoord
Collaborator

I have been trying to get a reproducer, but it is proving very difficult. Please don't merge for now; even if we decide that the current approach is correct, having a test would be good.

@jacobtylerwalls
Member Author

I feel your pain! Thanks for giving it a go.

@DanielNoord
Collaborator

pandas.zip

After lots of trial and error I created a package that crashes with minimal code: linting pandas/core/algorithms.py on 3.10 should crash.
However, I was interested to see which module we are recursing on, to see if I could create a stop like in #1392.
Turns out, we're not actually recursing incorrectly 😅

The following diff already fixes the issue:

diff --git a/astroid/builder.py b/astroid/builder.py
index 24caa0c6e..09ccfaada 100644
--- a/astroid/builder.py
+++ b/astroid/builder.py
@@ -144,6 +144,9 @@ class AstroidBuilder(raw_building.InspectBuilder):
         self, module: nodes.Module, builder: rebuilder.TreeRebuilder, encoding: str
     ) -> nodes.Module:
         """Handles encoding and delayed nodes after a module has been built"""
+        import sys
+
+        sys.setrecursionlimit(1500)
         module.file_encoding = encoding
         self._manager.cache_module(module)
         # post tree building steps after we stored the module in the cache:

Normally this limit is 1000, but for pandas we need it to be slightly higher.

As to why this doesn't crash on 3.11: the only thing I can think of is that the detection of recursion improved on 3.11?

I think for the pandas package it would be enough to add sys.setrecursionlimit(1500) to the init-hook and then run pylint. It feels a bit weird to stop execution of the program due to Python not identifying our complicated recursion pattern (with hundreds of decorators) correctly...
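
For reference, the workaround suggested above boils down to a one-liner; the target path is just the file named earlier in this comment:

pylint --init-hook "import sys; sys.setrecursionlimit(1500)" pandas/core/algorithms.py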

@jacobtylerwalls
Member Author

Turns out, we're not actually recursing incorrectly 😅

Yes, sorry, I failed to mention this.

@jacobtylerwalls
Member Author

I think for the pandas package it would be enough to add sys.setrecursionlimit(1500) to the init-hook and then run pylint.

Yes, we could do that in the pylint primer launcher. But do you think this PR is still a good guard in case folks run into this themselves?

@Pierre-Sassoulas
Member

Pierre-Sassoulas commented Jun 27, 2022

There's a performance implication to recursing 500+ times; I think adding sys.setrecursionlimit(1500) to the init-hook should be a user decision. Not crashing is sufficient for a hot fix, IMO. Maybe long term we need to raise a pylint error like "genuinely-large-recursion-fail" to warn users that they can either use a higher recursion limit themselves and suffer slowness, or keep some false positives and a fast pylint.

@DanielNoord
Collaborator

3.11 fixes this because of this I think.

Yes, we could do that in the pylint primer launcher. But do you think this PR is still a good guard in case folks run into this themselves?

Yeah, although we should probably add something about this in the documentation of init-hook or somewhere else.

I do wonder if this is the correct place though. In any case, the current comment is incorrect: it's not that we're trying to import the same thing, we just have a lot of different calls to the same functions.

astroid/builder.py (review comments: outdated, resolved)
ChangeLog (review comments: outdated, resolved)
@DanielNoord self-requested a review June 28, 2022 14:59
@DanielNoord
Collaborator

I'll probably have time to revisit this this weekend. I'd like to add a test + do one final check on whether this is the best place to catch the recursion. I completely agree that we should catch this in astroid and handle it better than we currently do, but I can see somebody wondering 3-4 years from now why there is a comment about infinite recursion on infer_import_from() when this method doesn't even call infer_import_from() 😄

Obviously that individual should recognise `.infer()`, but you get my point.

@Pierre-Sassoulas
Member

Let's release astroid 2.11.7 after merging this :)

@DanielNoord
Collaborator

DanielNoord commented Jul 2, 2022

@jacobtylerwalls I think I'm going mad. Could you retry with latest main: has this issue been resolved?

I can't reproduce anymore in my local environment, and pylint main runs also seem to be fine: https://github.com/PyCQA/pylint/runs/7160221551?check_suite_focus=true

Edit:
The unblocking of main occurred here;
Last failure:
https://github.com/PyCQA/pylint/runs/7002534787?check_suite_focus=true
Next run is okay:
https://github.com/PyCQA/pylint/runs/7004479986?check_suite_focus=true

Checking out astroid to db6509b crashes, while the next commit 7db01c1 fixes this.
Both on current main of pandas (with the change to _mixins.py) as well as their commit that originally crashed:
faacb72b96b7d6604e5eff13f0e224cf741ff8d8.

Edit2:
I'm not sure what the best course of action is. Clearly we can run into recursion issues within astroid, but seeing that they seem to be resolved on main and will be less prominent on 3.11 due to the new calling mechanism, I'm not sure we should add an except here right now. This is also because I'm still not sure this is the best place: the issue occurs due to sequential importing of modules, so it seems the actual recursion should be caught somewhere in infer_import_from, which could then (for example) return a dummy nodes.Module object.
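
The alternative floated in that last sentence, sketched very loosely; the function name and the placeholder construction are illustrative rather than astroid's real signatures:

from astroid import nodes

def infer_import_from_sketch(node, context=None):
    # If resolving the import recurses too deeply, yield a placeholder
    # module instead of letting the RecursionError propagate.
    try:
        yield node.do_import_module(node.modname)
    except RecursionError:
        yield nodes.Module(name=node.modname, doc=None)  # dummy stand-in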

@DanielNoord removed their request for review July 2, 2022 09:43
@jacobtylerwalls
Member Author

jacobtylerwalls commented Jul 2, 2022

Thanks for digging! My bisect shows that the problem most recently went away as of 213b2a6 instead, which makes more sense.

Clearly we can run into recursion issues within astroid, but seeing that they seem to be resolved on main and will be less prominent on 3.11 due to the new calling mechanism, I'm not sure we should add an except here right now.

I agree. I'll close this PR and the pandas issue. Let's keep an 👂 out for future issues. There is pylint-dev/pylint#7011 anyway.

@jacobtylerwalls
Member Author

@DanielNoord this came up just now when running the old primer locally over keras against astroid 2.12.x:

First, please verify that the bug is not already filled: https://github.com/PyCQA/pylint/issues/

Issue title:
Crash ```` (if possible, be more specific about what made pylint crash)
Content:
When parsing the following file:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for variable store."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import threading

import numpy
import tensorflow as tf
from absl.testing import parameterized

from keras import models
from keras import regularizers
from keras.engine import base_layer
from keras.engine import input_layer as input_layer_module
from keras.engine import training as training_module
from keras.layers import core
from keras.legacy_tf_layers import core as core_layers
from keras.legacy_tf_layers import variable_scope_shim
from keras.testing_infra import test_combinations

# isort: off
from tensorflow.python.framework import (
    test_util as tf_test_utils,
)
from tensorflow.python.ops import variable_scope


def run_inside_wrap_function_in_eager_mode(graph_function):
    """Decorator to execute the same graph code in eager and graph modes.

    In graph mode, we just execute the graph_function passed as argument. In
    eager mode, we wrap the function using wrap_function and then execute the
    wrapped result.

    Args:
      graph_function: python function containing graph code to be wrapped

    Returns:
      decorated function
    """

    def wrap_and_execute(self):
        store = variable_scope_shim._EagerVariableStore()
        with variable_scope.with_variable_store(store):
            # use the original function
            graph_function(self)

    return wrap_and_execute


class VariableScopeTest(tf.test.TestCase):
    def tearDown(self):
        gc.collect()
        # This will only contain uncollectable garbage, i.e. reference cycles
        # involving objects with __del__ defined.
        self.assertEqual(0, len(gc.garbage))

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testGetVar(self):
        vs = variable_scope._get_default_variable_store()
        v = vs.get_variable("v", [1])
        v1 = vs.get_variable("v", [1])
        self.assertIs(v, v1)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testNameExists(self):
        vs = variable_scope._get_default_variable_store()
        # No check by default, so we can both create and get existing names.
        v = vs.get_variable("v", [1])
        v1 = vs.get_variable("v", [1])
        self.assertIs(v, v1)

        self.assertIsNot(v, vs.get_variable("u", [1], reuse=False))

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testNamelessStore(self):
        vs = variable_scope._get_default_variable_store()
        vs.get_variable("v1", [2])
        vs.get_variable("v2", [2])
        expected_names = ["%s:0" % name for name in ["v1", "v2"]]
        self.assertEqual(
            set(expected_names), set(v.name for v in vs._vars.values())
        )

    # TODO(mihaimaruseac): Not converted to use wrap_function because of
    # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
    # type '<type 'NoneType'>'
    @tf_test_utils.run_in_graph_and_eager_modes
    def testVarScopeInitializer(self):
        init = tf.compat.v1.constant_initializer(0.3)
        with tf.compat.v1.variable_scope("tower0") as tower:
            with tf.compat.v1.variable_scope("foo", initializer=init):
                v = tf.compat.v1.get_variable("v", [])
                self.evaluate(tf.compat.v1.variables_initializer([v]))
                self.assertAllClose(self.evaluate(v.value()), 0.3)
            with tf.compat.v1.variable_scope(tower, initializer=init):
                w = tf.compat.v1.get_variable("w", [])
                self.evaluate(tf.compat.v1.variables_initializer([w]))
                self.assertAllClose(self.evaluate(w.value()), 0.3)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeConstraint(self):
        constraint = lambda x: 0.0 * x
        with tf.compat.v1.variable_scope("tower1") as tower:
            with tf.compat.v1.variable_scope("foo", constraint=constraint):
                v = tf.compat.v1.get_variable("v", [])
                self.assertIsNotNone(v.constraint)
            with tf.compat.v1.variable_scope(tower, constraint=constraint):
                w = tf.compat.v1.get_variable("w", [])
                self.assertIsNotNone(w.constraint)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeDType(self):
        with tf.compat.v1.variable_scope("tower2") as tower:
            with tf.compat.v1.variable_scope("foo", dtype=tf.float16):
                v = tf.compat.v1.get_variable("v", [])
                self.assertEqual(v.dtype.base_dtype, tf.float16)
            with tf.compat.v1.variable_scope(tower, dtype=tf.float16):
                w = tf.compat.v1.get_variable("w", [])
                self.assertEqual(w.dtype.base_dtype, tf.float16)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testInitFromNonTensorValue(self):
        v = tf.compat.v1.get_variable("v4", initializer=4, dtype=tf.int32)
        self.evaluate(tf.compat.v1.variables_initializer([v]))
        self.assertAllClose(self.evaluate(v.value()), 4)

        w = tf.compat.v1.get_variable(
            "w4", initializer=numpy.array([1, 2, 3]), dtype=tf.int64
        )
        self.evaluate(tf.compat.v1.variables_initializer([w]))
        self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])

        # A quirk to be revisited?
        error = ValueError if tf.executing_eagerly() else TypeError
        with self.assertRaises(error):
            tf.compat.v1.get_variable("x4", initializer={})

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testInitFromNonInitializer(self):
        # Test various dtypes with zeros initializer as following:
        types = [
            tf.int8,
            tf.uint8,
            tf.int16,
            tf.uint16,
            tf.int32,
            tf.int64,
            tf.bool,
        ]

        # Use different variable_name to distinguish various dtypes
        for (i, dtype) in enumerate(types):
            x = tf.compat.v1.get_variable(
                name="xx%d" % i, shape=(3, 4), dtype=dtype
            )
            y = tf.compat.v1.get_variable(
                name="yy%d" % i,
                shape=(3, 4),
                dtype=dtype,
                initializer=tf.compat.v1.zeros_initializer(dtype=dtype),
            )

            self.evaluate(tf.compat.v1.global_variables_initializer())
            self.assertAllEqual(
                self.evaluate(x.value()), self.evaluate(y.value())
            )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeRegularizer(self):
        init = tf.compat.v1.constant_initializer(0.3)

        def regularizer1(v):
            return tf.reduce_mean(v) + 0.1

        def regularizer2(v):
            return tf.reduce_mean(v) + 0.2

        with tf.compat.v1.variable_scope(
            "tower3", regularizer=regularizer1
        ) as tower:
            with tf.compat.v1.variable_scope("foo", initializer=init):
                v = tf.compat.v1.get_variable("v", [])
                self.evaluate(tf.compat.v1.variables_initializer([v]))
            with tf.compat.v1.variable_scope(tower, initializer=init) as vs:
                tf.compat.v1.get_variable("u", [])
                vs.set_regularizer(regularizer2)
                tf.compat.v1.get_variable("w", [])
                # Next 3 variable not regularized to test disabling
                # regularization.
                tf.compat.v1.get_variable(
                    "x", [], regularizer=tf.compat.v1.no_regularizer
                )
                with tf.compat.v1.variable_scope(
                    "baz", regularizer=tf.compat.v1.no_regularizer
                ):
                    tf.compat.v1.get_variable("y", [])
                vs.set_regularizer(tf.compat.v1.no_regularizer)
                tf.compat.v1.get_variable("z", [])

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testInitializeFromValue(self):
        init = tf.constant(0.1)
        w = tf.compat.v1.get_variable("v", initializer=init)
        self.evaluate(tf.compat.v1.variables_initializer([w]))
        self.assertAllClose(self.evaluate(w.value()), 0.1)

        with self.assertRaisesRegex(ValueError, "shape"):
            # We disallow explicit shape specification when initializer is
            # constant.
            tf.compat.v1.get_variable("u", [1], initializer=init)

        with tf.compat.v1.variable_scope("foo", initializer=init):
            # Constant initializer can be passed through scopes if needed.
            v = tf.compat.v1.get_variable("v")
            self.evaluate(tf.compat.v1.variables_initializer([v]))
            self.assertAllClose(self.evaluate(v.value()), 0.1)

        # Check that non-float32 initializer creates a non-float32 variable.
        init = tf.constant(1, dtype=tf.int32)
        t = tf.compat.v1.get_variable("t", initializer=init)
        self.assertEqual(t.dtype.base_dtype, tf.int32)

        # Raise error if `initializer` dtype and `dtype` are not identical.
        with self.assertRaisesRegex(ValueError, "don't match"):
            tf.compat.v1.get_variable("s", initializer=init, dtype=tf.float64)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeGetOrCreateReuse(self):
        with self.cached_session():

            def test_value(value):
                x = tf.constant(value)
                with tf.compat.v1.variable_scope(
                    "testVarScopeGetOrCreateReuse_bar",
                    reuse=tf.compat.v1.AUTO_REUSE,
                ):
                    _ = tf.compat.v1.assign(
                        tf.compat.v1.get_variable("var", []), x
                    )
                with tf.compat.v1.variable_scope(
                    "testVarScopeGetOrCreateReuse_bar",
                    reuse=tf.compat.v1.AUTO_REUSE,
                ):
                    _ = tf.compat.v1.get_variable("var", [])
                self.assertEqual(value, self.evaluate(x))

            test_value(42.0)  # Variable is created.
            test_value(13.0)  # Variable is reused hereafter.
            test_value(17.0)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeGetOrCreateReuseIgnoreFalse(self):
        with self.cached_session():

            def test_value(value):
                x = tf.constant(value)
                with tf.compat.v1.variable_scope(
                    "testVarScopeGetOrCreateReuse_bar", reuse=False
                ):
                    _ = tf.compat.v1.assign(
                        tf.compat.v1.get_variable("var", []), x
                    )
                # We need to ignore reuse=False in the shim, because the code is
                # expected to get rerun each time the user calls the shim.
                with tf.compat.v1.variable_scope(
                    "testVarScopeGetOrCreateReuse_bar", reuse=False
                ):
                    _ = tf.compat.v1.get_variable("var", [])
                self.assertEqual(value, self.evaluate(x))

            test_value(42.0)  # Variable is created.
            test_value(13.0)  # Variable is reused hereafter.
            test_value(17.0)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScope(self):
        with self.cached_session():
            with tf.name_scope("testVarOpScope1"):
                with tf.compat.v1.variable_scope("tower", "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name, "tower/w:0"
                    )

            with tf.name_scope("testVarOpScope2"):
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name, "default/w:0"
                    )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name, "default_1/w:0"
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope(None, "defaultScope1"):
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "defaultScope1/layer/w:0",
                    )
            with tf.compat.v1.variable_scope(None, "defaultScope1"):
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "defaultScope1_1/layer/w:0",
                    )
            with tf.compat.v1.variable_scope(None, "defaultScope"):
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "defaultScope/layer/w:0",
                    )
            with tf.compat.v1.variable_scope(None, "defaultScope1"):
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "defaultScope1_2/layer/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeUniqueNamesWithJump(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("default") as default:
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "default/layer/w:0",
                    )
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "default/layer_1/w:0",
                    )
                with tf.compat.v1.variable_scope(default):
                    pass
                # No matter the jump in the middle, unique numbering continues.
                with tf.compat.v1.variable_scope(None, "layer"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "default/layer_2/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeReuse(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer") as outer:
                with tf.compat.v1.variable_scope("tower", "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/tower/w:0",
                    )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

            with tf.compat.v1.variable_scope(outer, reuse=True) as outer:
                with tf.compat.v1.variable_scope("tower", "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/tower/w:0",
                    )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeGetVar(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("root"):
                with tf.compat.v1.variable_scope("towerA") as tower_a:
                    va = tf.compat.v1.get_variable("v", [1])
                    self.assertEqual(va.name, "root/towerA/v:0")

                with tf.compat.v1.variable_scope(tower_a, reuse=True):
                    va2 = tf.compat.v1.get_variable("v", [1])
                    self.assertIs(va2, va)

                with tf.compat.v1.variable_scope("towerB"):
                    vb = tf.compat.v1.get_variable("v", [1])
                    self.assertEqual(vb.name, "root/towerB/v:0")

                with tf.compat.v1.variable_scope("towerA", reuse=True):
                    va2 = tf.compat.v1.get_variable("v", [1])
                    self.assertIs(va2, va)

                with tf.compat.v1.variable_scope("foo"):
                    with tf.compat.v1.variable_scope("bar"):
                        v = tf.compat.v1.get_variable("v", [1])
                        self.assertEqual(v.name, "root/foo/bar/v:0")
                        with tf.compat.v1.variable_scope(tower_a, reuse=True):
                            va3 = tf.compat.v1.get_variable("v", [1])
                            self.assertIs(va, va3)

                with self.assertRaises(ValueError) as exc:
                    with tf.compat.v1.variable_scope(tower_a, reuse=True):
                        tf.compat.v1.get_variable("v", [2])  # Different shape.
                self.assertEqual("shape" in str(exc.exception), True)

                with self.assertRaises(ValueError) as exc:
                    with tf.compat.v1.variable_scope(tower_a, reuse=True):
                        tf.compat.v1.get_variable("v", [1], dtype=tf.int32)
                self.assertEqual("dtype" in str(exc.exception), True)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeOuterScope(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer") as outer:
                pass
            with tf.compat.v1.variable_scope(outer):
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                )
                with tf.compat.v1.variable_scope("default"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

            with tf.compat.v1.variable_scope(outer, reuse=True):
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                )
                with tf.compat.v1.variable_scope("default", reuse=True):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarScopeNestedOuterScope(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer") as outer:
                with tf.compat.v1.variable_scope(outer):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                    )
                with tf.compat.v1.variable_scope("default"):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

                with tf.compat.v1.variable_scope(outer, reuse=True):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                    )
                with tf.compat.v1.variable_scope("default", reuse=True):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeReuseParam(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer") as outer:
                with tf.compat.v1.variable_scope("tower", "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/tower/w:0",
                    )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

            with tf.compat.v1.variable_scope(outer) as outer:
                with tf.compat.v1.variable_scope(
                    "tower", "default", reuse=True
                ):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/tower/w:0",
                    )
                outer.reuse_variables()
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeReuseError(self):
        with self.cached_session():
            with self.assertRaises(ValueError):
                with tf.compat.v1.variable_scope(None, "default", reuse=True):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/tower/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeOuterScope(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer") as outer:
                pass
            with tf.compat.v1.variable_scope(outer, "default", []):
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

            with tf.compat.v1.variable_scope(outer, "default", reuse=True):
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                )
                outer.reuse_variables()
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVarOpScopeNestedOuterScope(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer") as outer:
                with tf.compat.v1.variable_scope(outer, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                    )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

            with tf.compat.v1.variable_scope(outer, "default", reuse=True):
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "outer/w:0"
                )
                with tf.compat.v1.variable_scope(None, "default", []):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testBasicWhenAuxiliaryNameScopeIsFalse(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope(
                "scope", auxiliary_name_scope=False
            ) as scope:
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "scope/w:0"
                )
            with tf.compat.v1.variable_scope(scope, auxiliary_name_scope=False):
                self.assertEqual(
                    tf.compat.v1.get_variable("w1", []).name, "scope/w1:0"
                )

            with tf.compat.v1.variable_scope("outer"):
                with tf.compat.v1.variable_scope(
                    "inner", auxiliary_name_scope=False
                ) as inner:
                    self.assertEqual(inner.original_name_scope, "outer/")
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/inner/w:0",
                    )
                with tf.compat.v1.variable_scope(
                    inner, auxiliary_name_scope=False
                ) as inner1:
                    self.assertEqual(inner1.original_name_scope, "outer/")
                    self.assertEqual(
                        tf.compat.v1.get_variable("w1", []).name,
                        "outer/inner/w1:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
        with self.cached_session():
            with tf.compat.v1.variable_scope(
                None, default_name="default", auxiliary_name_scope=False
            ):
                self.assertEqual(
                    tf.compat.v1.get_variable("w", []).name, "default/w:0"
                )

            with tf.compat.v1.variable_scope("outer"):
                with tf.compat.v1.variable_scope(
                    None, default_name="default", auxiliary_name_scope=False
                ) as inner:
                    self.assertEqual(inner.original_name_scope, "outer/")
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/default/w:0",
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
        with self.cached_session():
            root_scope = tf.compat.v1.get_variable_scope()
            with tf.compat.v1.variable_scope(
                root_scope, auxiliary_name_scope=False
            ):
                self.assertEqual(tf.compat.v1.get_variable("w", []).name, "w:0")

            with tf.compat.v1.variable_scope("outer"):
                with tf.compat.v1.variable_scope(
                    root_scope, auxiliary_name_scope=False
                ) as inner:
                    self.assertEqual(inner.original_name_scope, "")
                    self.assertEqual(
                        tf.compat.v1.get_variable("w1", []).name, "w1:0"
                    )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testAuxiliaryNameScopeIsInvalid(self):
        with self.cached_session():
            with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"):
                with tf.compat.v1.variable_scope(
                    None, default_name="scope", auxiliary_name_scope="invalid"
                ):
                    pass

            with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"):
                with tf.compat.v1.variable_scope(
                    "scope", auxiliary_name_scope="invalid"
                ):
                    pass

            with tf.compat.v1.variable_scope("scope") as scope:
                pass
            with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"):
                with tf.compat.v1.variable_scope(
                    scope, auxiliary_name_scope="invalid"
                ):
                    pass

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testReuseScopeWithoutNameScopeCollision(self):
        # GitHub issue: #13429
        with self.cached_session():
            with tf.compat.v1.variable_scope("outer"):
                with tf.compat.v1.variable_scope("inner") as inner:
                    pass

            with tf.compat.v1.variable_scope(
                inner, auxiliary_name_scope=False
            ) as scope:
                with tf.name_scope(scope.original_name_scope):
                    self.assertEqual(
                        tf.compat.v1.get_variable("w", []).name,
                        "outer/inner/w:0",
                    )

            with tf.compat.v1.variable_scope("another"):
                with tf.compat.v1.variable_scope(
                    inner, auxiliary_name_scope=False
                ) as scope1:
                    with tf.name_scope(scope1.original_name_scope):
                        self.assertEqual(
                            tf.compat.v1.get_variable("w1", []).name,
                            "outer/inner/w1:0",
                        )

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testGetVarWithDevice(self):
        g = tf.Graph()
        varname_type = []

        def device_func(op):
            if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
                varname_type.append((op.name, op.get_attr("dtype")))
            return "/device:GPU:0"

        with g.as_default():
            with tf.compat.v1.device(device_func):
                _ = tf.compat.v1.get_variable("x", (100, 200))
                _ = tf.compat.v1.get_variable(
                    "y", dtype=tf.int64, initializer=numpy.arange(73)
                )
        self.assertEqual(varname_type[0], ("x", tf.float32))
        self.assertEqual(varname_type[1], ("y", tf.int64))

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testGetVariableWithRefDtype(self):
        v = tf.compat.v1.get_variable("v", shape=[3, 4], dtype=tf.float32)
        # Ensure it is possible to do get_variable with a _ref dtype passed in.
        _ = tf.compat.v1.get_variable("w", shape=[5, 6], dtype=v.dtype)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testGetVariableWithInitializerWhichTakesNoArgs(self):
        v = tf.compat.v1.get_variable("foo", initializer=lambda: [2])
        self.assertEqual(v.name, "foo:0")

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testGetVariableWithInitializerWhichTakesOptionalArgs(self):
        v = tf.compat.v1.get_variable("foo", initializer=lambda x=True: [2])
        self.assertEqual(v.name, "foo:0")

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testTwoGraphs(self):
        def f():
            g1 = tf.Graph()
            g2 = tf.Graph()
            with g1.as_default():
                with g2.as_default():
                    with tf.compat.v1.variable_scope("_"):
                        pass

        self.assertRaisesRegex(
            ValueError, "'_' is not a valid (?:root )?scope name", f
        )


class VariableScopeWithCustomGetterTest(tf.test.TestCase):
    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testNonCallableGetterFails(self):
        with self.assertRaisesRegex(
            ValueError, r"custom_getter .* not callable:"
        ):
            with tf.compat.v1.variable_scope("scope0", custom_getter=3):
                tf.compat.v1.get_variable("name0")
        with self.assertRaisesRegex(
            ValueError, r"custom_getter .* not callable:"
        ):
            tf.compat.v1.get_variable("name0", custom_getter=3)

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testNoSideEffectsWithIdentityCustomGetter(self):
        called = [0]

        def custom_getter(getter, *args, **kwargs):
            called[0] += 1
            return getter(*args, **kwargs)

        with tf.compat.v1.variable_scope(
            "scope", custom_getter=custom_getter
        ) as scope:
            v = tf.compat.v1.get_variable("v", [1])
        with tf.compat.v1.variable_scope(scope, reuse=True):
            v2 = tf.compat.v1.get_variable("v", [1])
        with tf.compat.v1.variable_scope("new_scope") as new_scope:
            v3 = tf.compat.v1.get_variable("v3", [1])
        with tf.compat.v1.variable_scope(
            new_scope, reuse=True, custom_getter=custom_getter
        ):
            v4 = tf.compat.v1.get_variable("v3", [1])

        self.assertIs(v, v2)
        self.assertIs(v3, v4)
        self.assertEqual(3, called[0])  # skipped one in the first new_scope

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testSynchronizationAndAggregationWithCustomGetter(self):
        called = [0]
        synchronization = tf.VariableSynchronization.AUTO
        aggregation = tf.compat.v1.VariableAggregation.NONE

        def custom_getter(getter, *args, **kwargs):
            called[0] += 1

            # Verify synchronization and aggregation kwargs are as expected.
            self.assertEqual(kwargs["synchronization"], synchronization)
            self.assertEqual(kwargs["aggregation"], aggregation)
            return getter(*args, **kwargs)

        with tf.compat.v1.variable_scope("scope", custom_getter=custom_getter):
            tf.compat.v1.get_variable("v", [1])
        self.assertEqual(1, called[0])

        with tf.compat.v1.variable_scope("scope", custom_getter=custom_getter):
            synchronization = tf.VariableSynchronization.ON_READ
            aggregation = tf.compat.v1.VariableAggregation.MEAN
            tf.compat.v1.get_variable(
                "v1",
                [1],
                synchronization=synchronization,
                aggregation=aggregation,
            )

        self.assertEqual(2, called[0])

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVariableCreator(self):
        variable_names = []

        def creator_a(next_creator, **kwargs):
            variable_names.append(kwargs.get("name", ""))
            return next_creator(**kwargs)

        def creator_b(next_creator, **kwargs):
            kwargs["name"] = "forced_name"
            return next_creator(**kwargs)

        with tf.variable_creator_scope(creator_a):
            with tf.variable_creator_scope(creator_b):
                tf.compat.v1.Variable(1.0, name="one_name")

        self.assertEqual(variable_names[0], "forced_name")

        called = [False]

        def creater_c(next_creator, **kwargs):
            called[0] = True
            self.assertEqual(
                kwargs["synchronization"], tf.VariableSynchronization.ON_WRITE
            )
            self.assertEqual(
                kwargs["aggregation"], tf.compat.v1.VariableAggregation.MEAN
            )
            return next_creator(**kwargs)

        with tf.variable_creator_scope(creater_c):
            tf.compat.v1.get_variable(
                "v",
                [],
                synchronization=tf.VariableSynchronization.ON_WRITE,
                aggregation=tf.compat.v1.VariableAggregation.MEAN,
            )
        self.assertTrue(called[0])

    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testVariableCreatorNestingError(self):
        def creator(next_creator, **kwargs):
            return next_creator(**kwargs)

        # Save the state so we can clean up at the end.
        graph = tf.compat.v1.get_default_graph()
        old_creator_stack = graph._variable_creator_stack

        try:
            scope = tf.variable_creator_scope(creator)
            scope.__enter__()
            with tf.variable_creator_scope(creator):
                with self.assertRaises(RuntimeError):
                    scope.__exit__(None, None, None)
        finally:
            graph._variable_creator_stack = old_creator_stack


class VariableScopeMultithreadedTest(tf.test.TestCase):
    @tf_test_utils.run_in_graph_and_eager_modes
    @run_inside_wrap_function_in_eager_mode
    def testReenterMainScope(self):
        def thread_fn(graph, main_thread_scope):
            with graph.as_default():
                # Variable created with main scope will have prefix "main".
                with tf.compat.v1.variable_scope(main_thread_scope):
                    with tf.compat.v1.variable_scope("foo"):
                        v = tf.compat.v1.get_variable("v", [])
                        self.assertEqual("main/foo/v:0", v.name)

                # Variable created outside main scope will not have prefix
                # "main".
                with tf.compat.v1.variable_scope("bar"):
                    v = tf.compat.v1.get_variable("v", [])
                    self.assertEqual("bar/v:0", v.name)

        graph = tf.compat.v1.get_default_graph()
        with tf.compat.v1.variable_scope("main") as main_thread_scope:
            thread = threading.Thread(
                target=thread_fn, args=(graph, main_thread_scope)
            )
            thread.start()
            thread.join()


class CompatV1TemplateScaleByY(base_layer.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        def my_op(x, scalar_name):
            var1 = tf.compat.v1.get_variable(
                scalar_name,
                shape=[],
                regularizer=regularizers.L2(),
                initializer=tf.compat.v1.constant_initializer(1.5),
            )
            return x * var1

        self.scale_by_y = tf.compat.v1.make_template(
            "scale_by_y", my_op, scalar_name="y"
        )

    @variable_scope_shim.track_tf1_style_variables
    def call(self, inputs):
        with tf.compat.v1.variable_scope("foo"):
            return self.scale_by_y(inputs)


class VariableScopeModule(tf.Module):
    """Module that uses the shim."""

    @variable_scope_shim.track_tf1_style_variables
    def __call__(self, *args, **kwargs):
        with self.name_scope:
            return self.forward_pass(*args, **kwargs)

    def get_compat_v1_regularization_losses(self):
        """Dict w/ regularization losses from
        `get_variable`&`compat.v1.layers`."""
        return {
            name: regularizer()
            for name, regularizer in self._tf1_style_var_store._regularizers.items()  # noqa: E501
        }


@test_combinations.generate(test_combinations.combine(mode=["eager"]))
class TF1VariableScopeLayerTest(tf.test.TestCase, parameterized.TestCase):
    def test_get_variable(self):
        # Test the shim when using `get_variable` (and regularizers) directly

        class WrappedDenseLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs, training=None):
                out = inputs
                with tf.compat.v1.variable_scope("dense_one"):
                    # The weights are created with a `regularizer`,
                    # so the layer should track their regularization losses
                    kernel = tf.compat.v1.get_variable(
                        shape=[out.shape[-1], self.units],
                        regularizer=regularizers.L2(),
                        initializer=tf.compat.v1.ones_initializer(),
                        name="kernel",
                    )
                    bias = tf.compat.v1.get_variable(
                        shape=[
                            self.units,
                        ],
                        initializer=tf.compat.v1.zeros_initializer(),
                        name="bias",
                    )
                    out = tf.matmul(out, kernel)
                    out = tf.nn.bias_add(out, bias)
                with tf.compat.v1.variable_scope("nested_scope"):
                    with tf.compat.v1.variable_scope("dense_two"):
                        kernel = tf.compat.v1.get_variable(
                            shape=[out.shape[-1], self.units],
                            regularizer=regularizers.L2(),
                            initializer=tf.compat.v1.ones_initializer(),
                            name="kernel",
                        )
                        bias = tf.compat.v1.get_variable(
                            shape=[
                                self.units,
                            ],
                            initializer=tf.compat.v1.zeros_initializer(),
                            name="bias",
                        )
                        out = tf.matmul(out, kernel)
                        out = tf.nn.bias_add(out, bias)
                return out

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, regularization losses, + variables were
        # made
        self.assertEqual(
            weights.keys(),
            {
                "dense_one/bias:0",
                "dense_one/kernel:0",
                "nested_scope/dense_two/bias:0",
                "nested_scope/dense_two/kernel:0",
            },
        )
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
        self.assertAllEqual(tf.add_n(layer.losses), 1.5)

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        weights["nested_scope/dense_two/kernel:0"].assign(
            tf.ones(shape=(10, 10)) * 2
        )
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
        self.assertAllEqual(tf.add_n(layer.losses), 6)

    def test_compat_v1_layer(self):
        # Test the shim when using `compat.v1` layers

        class WrappedDenseLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs, training=None):
                out = core_layers.dense(
                    inputs,
                    self.units,
                    name="dense_one",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                    kernel_regularizer="l2",
                )
                with tf.compat.v1.variable_scope("nested_scope"):
                    out = core_layers.dense(
                        out,
                        self.units,
                        name="dense_two",
                        kernel_initializer=tf.compat.v1.ones_initializer(),
                        kernel_regularizer="l2",
                    )
                return out

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, losses, + variables were made
        self.assertEqual(
            weights.keys(),
            {
                "dense_one/bias:0",
                "dense_one/kernel:0",
                "nested_scope/dense_two/bias:0",
                "nested_scope/dense_two/kernel:0",
            },
        )
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
        self.assertAllEqual(tf.add_n(layer.losses), 1.5)

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        weights["nested_scope/dense_two/kernel:0"].assign(
            tf.ones(shape=(10, 10)) * 2
        )
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
        self.assertAllEqual(tf.add_n(layer.losses), 6)

    def test_shim_exporting(self):
        class WrappedDenseLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs, training=None):
                out = core_layers.dense(
                    inputs,
                    self.units,
                    name="dense_one",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                    kernel_regularizer="l2",
                )
                with tf.compat.v1.variable_scope("nested_scope"):
                    out = core_layers.dense(
                        out,
                        self.units,
                        name="dense_two",
                        kernel_initializer=tf.compat.v1.ones_initializer(),
                        kernel_regularizer="l2",
                    )
                return out

        layer = WrappedDenseLayer(10)
        layer(tf.ones(shape=(5, 5)))

        tmp_dir = self.get_temp_dir()

        # Try exporting the layer directly
        tf.saved_model.save(layer, tmp_dir)

        # Try exporting the layer nested in a functional model
        # This is where saving reflection gets tricky due to
        # trying to replace the passed training arg in training=True
        # and training=False modes
        inp = input_layer_module.Input(shape=(5, 5))
        outs = layer(inp)
        model = models.Model(inp, outs)
        tf.saved_model.save(model, tmp_dir)

    def test_variable_store_scope_get_variable(self):
        # Test the module shim when using `get_variable` (and regularizers)
        # directly

        class WrappedDenseLayer(tf.Module):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units
                self._variable_store = variable_scope_shim._EagerVariableStore()

            def get_compat_v1_regularization_losses(self):
                """Dict w/ regularization losses from `get_variable`."""
                return {
                    name: regularizer()
                    for name, regularizer in self._variable_store._regularizers.items()  # noqa: E501
                }

            def __call__(self, inputs, training=None):
                with self._variable_store.scope():
                    out = inputs
                    with tf.compat.v1.variable_scope("dense_one"):
                        # The weights are created with a `regularizer`,
                        # so the layer should track their regularization losses
                        kernel = tf.compat.v1.get_variable(
                            shape=[out.shape[-1], self.units],
                            regularizer=regularizers.L2(),
                            initializer=tf.compat.v1.ones_initializer(),
                            name="kernel",
                        )
                        bias = tf.compat.v1.get_variable(
                            shape=[
                                self.units,
                            ],
                            initializer=tf.compat.v1.zeros_initializer(),
                            name="bias",
                        )
                        out = tf.matmul(out, kernel)
                        out = tf.nn.bias_add(out, bias)
                    with tf.compat.v1.variable_scope("nested_scope"):
                        with tf.compat.v1.variable_scope("dense_two"):
                            kernel = tf.compat.v1.get_variable(
                                shape=[out.shape[-1], self.units],
                                regularizer=regularizers.L2(),
                                initializer=tf.compat.v1.ones_initializer(),
                                name="kernel",
                            )
                            bias = tf.compat.v1.get_variable(
                                shape=[
                                    self.units,
                                ],
                                initializer=tf.compat.v1.zeros_initializer(),
                                name="bias",
                            )
                            out = tf.matmul(out, kernel)
                            out = tf.nn.bias_add(out, bias)
                    return out

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, regularization losses, + variables were
        # made
        self.assertEqual(
            weights.keys(),
            {
                "dense_one/bias:0",
                "dense_one/kernel:0",
                "nested_scope/dense_two/bias:0",
                "nested_scope/dense_two/kernel:0",
            },
        )
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
        self.assertAllEqual(
            tf.add_n(layer.get_compat_v1_regularization_losses().values()), 1.5
        )

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        weights["nested_scope/dense_two/kernel:0"].assign(
            tf.ones(shape=(10, 10)) * 2
        )
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
        self.assertAllEqual(
            tf.add_n(layer.get_compat_v1_regularization_losses().values()), 6
        )

    def test_module_get_variable(self):
        # Test the module shim when using `get_variable` (and regularizers)
        # directly

        class WrappedDenseLayer(VariableScopeModule):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def forward_pass(self, inputs, training=None):
                out = inputs
                with tf.compat.v1.variable_scope("dense_one"):
                    # The weights are created with a `regularizer`,
                    # so the layer should track their regularization losses
                    kernel = tf.compat.v1.get_variable(
                        shape=[out.shape[-1], self.units],
                        regularizer=regularizers.L2(),
                        initializer=tf.compat.v1.ones_initializer(),
                        name="kernel",
                    )
                    bias = tf.compat.v1.get_variable(
                        shape=[
                            self.units,
                        ],
                        initializer=tf.compat.v1.zeros_initializer(),
                        name="bias",
                    )
                    out = tf.matmul(out, kernel)
                    out = tf.nn.bias_add(out, bias)
                with tf.compat.v1.variable_scope("nested_scope"):
                    with tf.compat.v1.variable_scope("dense_two"):
                        kernel = tf.compat.v1.get_variable(
                            shape=[out.shape[-1], self.units],
                            regularizer=regularizers.L2(),
                            initializer=tf.compat.v1.ones_initializer(),
                            name="kernel",
                        )
                        bias = tf.compat.v1.get_variable(
                            shape=[
                                self.units,
                            ],
                            initializer=tf.compat.v1.zeros_initializer(),
                            name="bias",
                        )
                        out = tf.matmul(out, kernel)
                        out = tf.nn.bias_add(out, bias)
                return out

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, regularization losses, + variables were
        # made
        self.assertEqual(
            weights.keys(),
            {
                "dense_one/bias:0",
                "dense_one/kernel:0",
                "nested_scope/dense_two/bias:0",
                "nested_scope/dense_two/kernel:0",
            },
        )
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
        self.assertAllEqual(
            tf.add_n(layer.get_compat_v1_regularization_losses().values()), 1.5
        )

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        weights["nested_scope/dense_two/kernel:0"].assign(
            tf.ones(shape=(10, 10)) * 2
        )
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
        self.assertAllEqual(
            tf.add_n(layer.get_compat_v1_regularization_losses().values()), 6
        )

    def test_module_compat_v1_layer(self):
        # Test the module shim when using `compat.v1` layers

        class WrappedDenseLayer(VariableScopeModule):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def forward_pass(self, inputs, training=None):
                out = core_layers.dense(
                    inputs,
                    self.units,
                    name="dense_one",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                    kernel_regularizer="l2",
                )
                with tf.compat.v1.variable_scope("nested_scope"):
                    out = core_layers.dense(
                        out,
                        self.units,
                        name="dense_two",
                        kernel_initializer=tf.compat.v1.ones_initializer(),
                        kernel_regularizer="l2",
                    )
                return out

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, losses, + variables were made
        self.assertEqual(
            weights.keys(),
            {
                "dense_one/bias:0",
                "dense_one/kernel:0",
                "nested_scope/dense_two/bias:0",
                "nested_scope/dense_two/kernel:0",
            },
        )
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
        self.assertAllEqual(
            tf.add_n(layer.get_compat_v1_regularization_losses().values()), 1.5
        )

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        weights["nested_scope/dense_two/kernel:0"].assign(
            tf.ones(shape=(10, 10)) * 2
        )
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
        self.assertAllEqual(
            tf.add_n(layer.get_compat_v1_regularization_losses().values()), 6
        )

    def test_shim_nesting(self):
        # Test that nesting the shim in itself works

        class NestedLayer(base_layer.Layer):
            def __init__(self, units, name, *args, **kwargs):
                super().__init__(*args, name=name, **kwargs)
                self.units = units

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                out = inputs
                with tf.compat.v1.variable_scope(self.name):
                    # The weights are created with a `regularizer`,
                    # so the layer should track their regularization losses
                    kernel = tf.compat.v1.get_variable(
                        shape=[out.shape[-1], self.units],
                        regularizer=regularizers.L2(1.0),
                        initializer=tf.compat.v1.ones_initializer(),
                        name="kernel",
                    )
                    bias = tf.compat.v1.get_variable(
                        shape=[
                            self.units,
                        ],
                        initializer=tf.compat.v1.initializers.zeros,
                        name="bias",
                    )
                    out = tf.linalg.matmul(out, kernel)
                    out = tf.compat.v1.nn.bias_add(out, bias)
                return out

        class WrappedDenseLayer(base_layer.Layer):
            def __init__(self, units, **kwargs):
                super().__init__(**kwargs)
                self.units = units
                self.dense_layer_a = None
                self.dense_layer_b = None

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                # Only create the nested tf.variable/module/layer/model if it
                # has not already been created!
                if not self.dense_layer_a:
                    self.dense_layer_a = NestedLayer(
                        self.units * 2, "dense_one"
                    )
                out = self.dense_layer_a(inputs)
                if not self.dense_layer_b:
                    self.dense_layer_b = NestedLayer(self.units, "dense_two")
                out = self.dense_layer_b(out)
                return out

        layer = WrappedDenseLayer(5)
        out = layer(tf.ones(shape=(1, 3)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, losses, + variables were made
        # (Specifically: no double-counting of any weights or reg. losses
        # between nested components!)
        self.assertEqual(
            {var.name for var in layer.trainable_weights},
            {
                "dense_one/bias:0",
                "dense_one/kernel:0",
                "dense_two/bias:0",
                "dense_two/kernel:0",
            },
        )
        self.assertEqual(
            {var.name for var in layer.dense_layer_a.weights},
            {"dense_one/bias:0", "dense_one/kernel:0"},
        )
        self.assertEqual(
            {var.name for var in layer.dense_layer_b.weights},
            {"dense_two/bias:0", "dense_two/kernel:0"},
        )
        self.assertAllEqual(out, tf.ones(shape=(1, 5)) * 30)
        self.assertAllEqual(tf.add_n(layer.dense_layer_a.losses), 30)
        self.assertAllEqual(tf.add_n(layer.dense_layer_b.losses), 50)
        self.assertAllEqual(tf.add_n(layer.losses), 80)

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(3, 10)) * 2)
        weights["dense_two/kernel:0"].assign(tf.ones(shape=(10, 5)) * 2)
        out = layer(tf.ones(shape=(1, 3)))
        self.assertAllEqual(out, tf.ones(shape=(1, 5)) * 120)
        self.assertAllEqual(tf.add_n(layer.losses), 320)

    def test_compat_v1_make_template_in_shim_eager(self):
        # Test the shim when using `compat.v1.make_template`
        # Verify it works correctly in eager
        layer = CompatV1TemplateScaleByY()
        for _ in range(3):
            # Use multiple calls to verify that no new weights get created
            self.assertAllEqual(
                layer(tf.ones(shape=(2, 3))), tf.constant(1.5, shape=(2, 3))
            )
        self.assertAllEqual(
            {var.name: var.numpy() for var in layer.weights},
            {"foo/scale_by_y/y:0": 1.5},
        )
        self.assertAllEqual(
            tf.add_n(layer.losses), regularizers.L2()(layer.weights[0])
        )

    def test_compat_v1_make_template_in_shim_tf_function(self):
        # Test the shim when using `compat.v1.make_template`
        # Verify it works correctly in a tf.function
        # when made outside the function
        layer = CompatV1TemplateScaleByY()

        @tf.function
        def foo(x):
            return layer(x), tf.add_n(layer.losses)

        for _ in range(3):
            # Use multiple calls to verify that no new weights get created
            out, loss = foo(tf.ones(shape=(2, 3)))
            self.assertAllEqual(out, tf.constant(1.5, shape=(2, 3)))
            self.assertAllEqual(loss, regularizers.L2()(layer.weights[0]))
        self.assertAllEqual(
            {var.name: var.numpy() for var in layer.weights},
            {"foo/scale_by_y/y:0": 1.5},
        )

    def test_compat_v1_make_template_in_trace_in_shim(self):
        # Test the shim when using `compat.v1.make_template`
        # Verify it works correctly when the make_template/layer/shim
        # is created on the first tf.function trace!
        layers = {}

        @tf.function
        def bar(x):
            if "layer" not in layers:
                layers["layer"] = CompatV1TemplateScaleByY()
            layer = layers["layer"]
            return layer(x), tf.add_n(layer.losses)

        for _ in range(3):
            # Use multiple calls to verify that no new weights get created
            out, loss = bar(tf.ones(shape=(2, 3)))
            self.assertAllEqual(out, tf.constant(1.5, shape=(2, 3)))
            self.assertAllEqual(
                loss, regularizers.L2()(layers["layer"].weights[0])
            )
        self.assertAllEqual(
            {var.name: var.numpy() for var in layers["layer"].weights},
            {"foo/scale_by_y/y:0": 1.5},
        )

    def test_only_track_get_variable(self):
        # Test the shim does not try tracking or reusing variables
        # that were not created by get_variable. These variables/modules/layers
        # need to be tracked separately

        class WrappedDenseLayer(base_layer.Layer):
            def __init__(self, units, **kwargs):
                super().__init__(**kwargs)
                self.units = units
                self._dense_model = None

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                dense_layer = core.Dense(
                    self.units,
                    name="dense",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                    kernel_regularizer="l2",
                )
                return dense_layer(inputs)

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 5)

        self.assertEmpty(layer.weights)

    def test_embedded_keras_model(self):
        # Test the shim when embedding a Keras model inside of it
        # And assigning the model to an attribute

        class WrappedDenseLayer(base_layer.Layer):
            def __init__(self, units, **kwargs):
                super().__init__(**kwargs)
                self.units = units
                self._dense_model = None

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                if not self._dense_model:
                    inp = input_layer_module.Input(shape=inputs.shape)
                    dense_layer = core.Dense(
                        self.units,
                        name="dense",
                        kernel_initializer=tf.compat.v1.ones_initializer(),
                        kernel_regularizer="l2",
                    )
                    self._dense_model = training_module.Model(
                        inputs=inp, outputs=dense_layer(inp)
                    )
                return self._dense_model(inputs)

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, losses, + variables were made
        self.assertEqual(weights.keys(), {"dense/bias:0", "dense/kernel:0"})
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 5)
        self.assertAllEqual(tf.add_n(layer.losses), 0.5)

        # Verify reuse by updating the variables then re-running
        weights["dense/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 10)
        self.assertAllEqual(tf.add_n(layer.losses), 2)

    def test_embedded_keras_model_in_module(self):
        # Test the module shim when embedding a Keras model inside of it
        # And assigning the model to an attribute

        class WrappedDenseLayer(VariableScopeModule):
            def __init__(self, units, **kwargs):
                super().__init__(**kwargs)
                self.units = units
                self._dense_model = None

            def forward_pass(self, inputs):
                if not self._dense_model:
                    inp = input_layer_module.Input(shape=inputs.shape)
                    dense_layer = core.Dense(
                        self.units,
                        name="dense",
                        kernel_initializer=tf.compat.v1.ones_initializer(),
                        kernel_regularizer="l2",
                    )
                    self._dense_model = training_module.Model(
                        inputs=inp, outputs=dense_layer(inp)
                    )
                return self._dense_model(inputs)

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, losses, + variables were made
        self.assertEqual(weights.keys(), {"dense/bias:0", "dense/kernel:0"})
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 5)

        # The module shim will only track regularization losses made by
        # compat.v1.layers and compat.v1.get_variable. Other regularization
        # losses must be tracked by separate user-created mechanisms.
        self.assertEmpty(layer.get_compat_v1_regularization_losses())

        # Verify reuse by updating the variables then re-running
        weights["dense/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 10)

        # The module shim will only track regularization losses made by
        # compat.v1.layers and compat.v1.get_variable. Other regularization
        # losses must be tracked by separate user-created mechanisms.
        self.assertEmpty(layer.get_compat_v1_regularization_losses())

    def test_training_arg(self):
        # Test the shim when passing in a Keras `training` arg

        class TrainingCheckLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs, training=None):
                if training:
                    out = core_layers.dense(
                        inputs, self.units, name="dense_training"
                    )
                else:
                    out = core_layers.dense(
                        inputs, self.units, name="dense_no_training"
                    )
                return out

        layer = TrainingCheckLayer(10)
        layer(tf.ones(shape=(5, 5)), training=True)
        weights = {x.name: x for x in layer.variables}

        # Verify the correct variables were made
        self.assertEqual(
            weights.keys(), {"dense_training/bias:0", "dense_training/kernel:0"}
        )

        layer = TrainingCheckLayer(10)
        layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct variables were made
        self.assertEqual(
            weights.keys(),
            {"dense_no_training/bias:0", "dense_no_training/kernel:0"},
        )

    def test_incorrect_decoration(self):
        # Raise an error if you incorrectly decorate a method
        # that is not a method of a Module, layer, or model:
        @variable_scope_shim.track_tf1_style_variables
        def foo(x):
            return x * 2

        with self.assertRaisesRegex(ValueError, "does not extend"):
            foo(tf.ones(shape=(4, 4)))


class GetOrCreateLayerTest(tf.test.TestCase, parameterized.TestCase):
    @test_combinations.generate(test_combinations.combine(mode=["eager"]))
    def test_get_or_create_layer_with_regularizer_eager(self):
        class NestedLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def build_model(self):
                inp = input_layer_module.Input(shape=(5, 5))
                dense_layer = core.Dense(
                    10,
                    name="dense",
                    kernel_regularizer="l2",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                )
                model = training_module.Model(
                    inputs=inp, outputs=dense_layer(inp)
                )
                return model

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                # enter a variable scope to check module key naming
                with tf.compat.v1.variable_scope("test_scope"):
                    model = variable_scope_shim.get_or_create_layer(
                        "dense_model", self.build_model
                    )
                    return model(inputs)

        layer = NestedLayer(10)
        x = tf.ones(shape=(5, 5))

        out1 = layer(tf.expand_dims(x, 0))

        model1 = layer.submodules[0]._layers["test_scope/dense_model"]

        out2 = layer(tf.expand_dims(x, 0))
        # Verify model produces same output on successive calls with same input
        self.assertAllEqual(out1, out2)

        # Verify the model used on subsequent calls is the same
        model2 = layer.submodules[0]._layers["test_scope/dense_model"]
        self.assertIs(model1, model2)

        # Verify that stored layer computes outputs and losses correctly
        weights = {x.name: x for x in layer.variables}
        self.assertEqual(weights.keys(), {"dense/bias:0", "dense/kernel:0"})
        self.assertAllEqual(out2, tf.ones(shape=(1, 5, 10)) * 5)
        self.assertAllEqual(layer.losses, [0.5])

    @test_combinations.generate(test_combinations.combine(mode=["eager"]))
    def test_get_or_create_layer_no_regularizer_eager(self):
        class NestedLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def build_model(self):
                inp = input_layer_module.Input(shape=(5, 5))
                dense_layer = core.Dense(
                    10,
                    name="dense",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                )
                model = training_module.Model(
                    inputs=inp, outputs=dense_layer(inp)
                )
                return model

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                # enter a variable scope to check module key naming
                with tf.compat.v1.variable_scope("test_scope"):
                    model = variable_scope_shim.get_or_create_layer(
                        "dense_model", self.build_model
                    )
                    return model(inputs)

        layer = NestedLayer(10)
        x = tf.ones(shape=(5, 5))

        out1 = layer(tf.expand_dims(x, 0))

        model1 = layer.submodules[0]._layers["test_scope/dense_model"]

        out2 = layer(tf.expand_dims(x, 0))
        # Verify model produces same output on successive calls with same input
        self.assertAllEqual(out1, out2)

        # Verify the model used on subsequent calls is the same
        model2 = layer.submodules[0]._layers["test_scope/dense_model"]
        self.assertIs(model1, model2)

        # Verify that stored layer computes outputs and losses correctly
        weights = {x.name: x for x in layer.variables}
        self.assertEqual(weights.keys(), {"dense/bias:0", "dense/kernel:0"})
        self.assertAllEqual(out2, tf.ones(shape=(1, 5, 10)) * 5)
        self.assertAllEqual(layer.losses, [0.0])

    @test_combinations.generate(test_combinations.combine(mode=["eager"]))
    def test_get_or_create_layer_tf_function(self):
        class NestedLayer(base_layer.Layer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def build_model(self):
                inp = input_layer_module.Input(shape=(5, 5))
                dense_layer = core.Dense(
                    10,
                    name="dense",
                    kernel_regularizer="l2",
                )
                model = training_module.Model(
                    inputs=inp, outputs=dense_layer(inp)
                )
                return model

            @variable_scope_shim.track_tf1_style_variables
            def call(self, inputs):
                model = variable_scope_shim.get_or_create_layer(
                    "dense_model", self.build_model
                )
                return model(inputs)

        layer = NestedLayer(10)

        @tf.function
        def foo(x):
            return layer(x), tf.add_n(layer.losses)

        # Verify inner model is reused
        out1, loss1 = foo(tf.ones(shape=(5, 5)))
        out2, loss2 = foo(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out1, out2)
        self.assertAllEqual(loss1, loss2)

    @tf_test_utils.run_deprecated_v1
    def test_get_or_create_layer_graph(self):
        class NestedLayer(object):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def build_model(self):
                inp = input_layer_module.Input(shape=(5, 5))
                dense_layer = core.Dense(
                    10,
                    name="dense",
                    kernel_regularizer="l2",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                )
                model = training_module.Model(
                    inputs=inp, outputs=dense_layer(inp)
                )
                return model

            def __call__(self, inputs):
                model = variable_scope_shim.get_or_create_layer(
                    "dense_model", self.build_model
                )
                return model(inputs)

        with self.cached_session():
            layer = NestedLayer(10)
            x = tf.ones(shape=(5, 5))

            out1 = layer(tf.expand_dims(x, 0))
            self.evaluate(tf.compat.v1.global_variables_initializer())

            # verify output
            self.assertEqual(out1.shape, tf.TensorShape([1, 5, 10]))
            self.assertAllEqual(out1, tf.ones(shape=(1, 5, 10)) * 5)

            # verify variables are tracked
            weights = {var.name for var in tf.compat.v1.trainable_variables()}
            self.assertEqual(weights, {"dense/bias:0", "dense/kernel:0"})


if __name__ == "__main__":
    tf.test.main()

pylint crashed with a AstroidError and with the following stacktrace:

Traceback (most recent call last):
  File "/Users/.../pylint/pylint/lint/pylinter.py", line 731, in _check_file
    check_astroid_module(ast_node)
  File "/Users/.../pylint/pylint/lint/pylinter.py", line 950, in check_astroid_module
    retval = self._check_astroid_module(
  File "/Users/.../pylint/pylint/lint/pylinter.py", line 1000, in _check_astroid_module
    walker.walk(node)
  File "/Users/.../pylint/pylint/utils/ast_walker.py", line 93, in walk
    self.walk(child)
  File "/Users/.../pylint/pylint/utils/ast_walker.py", line 93, in walk
    self.walk(child)
  File "/Users/.../pylint/pylint/utils/ast_walker.py", line 93, in walk
    self.walk(child)
  File "/Users/.../pylint/pylint/utils/ast_walker.py", line 90, in walk
    callback(astroid)
  File "/Users/.../pylint/pylint/extensions/redefined_variable_type.py", line 100, in visit_assign
    _type = node_type(node.value)
  File "/Users/.../pylint/pylint/checkers/utils.py", line 1371, in node_type
    for var_type in node.infer():
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 169, in infer
    yield from self._infer(context=context, **kwargs)
  File "/Users/.../astroid/astroid/decorators.py", line 149, in raise_if_nothing_inferred
    yield from generator
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/inference.py", line 264, in infer_call
    yield from callee.infer_call_result(caller=self, context=callcontext)
  File "/Users/.../astroid/astroid/bases.py", line 296, in infer_call_result
    for res in node.infer_call_result(caller, context):
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 1754, in infer_call_result
    yield from returnnode.value.infer(context)
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/bases.py", line 159, in _infer_stmts
    for inf in stmt.infer(context=context):  # type: ignore[union-attr]
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/bases.py", line 159, in _infer_stmts
    for inf in stmt.infer(context=context):  # type: ignore[union-attr]
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/inference.py", line 264, in infer_call
    yield from callee.infer_call_result(caller=self, context=callcontext)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 1754, in infer_call_result
    yield from returnnode.value.infer(context)
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/bases.py", line 159, in _infer_stmts
    for inf in stmt.infer(context=context):  # type: ignore[union-attr]
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/bases.py", line 159, in _infer_stmts
    for inf in stmt.infer(context=context):  # type: ignore[union-attr]
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/inference.py", line 264, in infer_call
    yield from callee.infer_call_result(caller=self, context=callcontext)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 1754, in infer_call_result
    yield from returnnode.value.infer(context)
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/inference.py", line 264, in infer_call
    yield from callee.infer_call_result(caller=self, context=callcontext)
  File "/Users/.../astroid/astroid/bases.py", line 296, in infer_call_result
    for res in node.infer_call_result(caller, context):
  File "/Users/.../astroid/astroid/bases.py", line 296, in infer_call_result
    for res in node.infer_call_result(caller, context):
  File "/Users/.../astroid/astroid/bases.py", line 296, in infer_call_result
    for res in node.infer_call_result(caller, context):
  [Previous line repeated 913 more times]
  File "/Users/.../astroid/astroid/bases.py", line 293, in infer_call_result
    for node in self._proxied.igetattr("__call__", context):
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 2661, in igetattr
    inferred._proxied.getattr("__get__", context)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 2572, in getattr
    values += self._metaclass_lookup_attribute(name, context)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 2592, in _metaclass_lookup_attribute
    metaclass = self.metaclass(context=context)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 2875, in metaclass
    return self._find_metaclass(context=context)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 2856, in _find_metaclass
    klass = self.declared_metaclass(context=context)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 2830, in declared_metaclass
    for baseobj in base.infer(context=context):
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/inference.py", line 343, in infer_attribute
    for owner in self.expr.infer(context):
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/bases.py", line 159, in _infer_stmts
    for inf in stmt.infer(context=context):  # type: ignore[union-attr]
  File "/Users/.../astroid/astroid/nodes/node_ng.py", line 182, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/Users/.../astroid/astroid/decorators.py", line 139, in raise_if_nothing_inferred
    yield next(generator)
  File "/Users/.../astroid/astroid/decorators.py", line 108, in wrapped
    for res in _func(node, context, **kwargs):
  File "/Users/.../astroid/astroid/inference.py", line 319, in infer_import_from
    module = self.do_import_module()
  File "/Users/.../astroid/astroid/nodes/_base_nodes.py", line 148, in do_import_module
    return mymodule.import_module(
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 519, in import_module
    absmodname = self.relative_to_absolute_name(modname, level)
  File "/Users/.../astroid/astroid/nodes/scoped_nodes/scoped_nodes.py", line 549, in relative_to_absolute_name
    if self.absolute_import_activated() and level is None:
RecursionError: maximum recursion depth exceeded

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/.../pylint/pylint/lint/pylinter.py", line 688, in _check_files
    self._check_file(get_ast, check_astroid_module, file)
  File "/Users/.../pylint/pylint/lint/pylinter.py", line 733, in _check_file
    raise astroid.AstroidError from e
astroid.exceptions.AstroidError

Do you think we should launch the primer with keras (i.e. the old one, for now) with a higher recursion limit?

@DanielNoord
Copy link
Collaborator

@DanielNoord this came up just now when running the old primer locally over keras against astroid 2.12.x:

Oh no, that looks awfully similar...

What happens if we catch RecursionError in infer_import_from and raise InferenceError?
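
Roughly something like this, as a sketch (only the widened except clause is the actual idea; the surrounding lines just paraphrase the shape of ImportFrom inference around do_import_module(), not the real astroid source):

from astroid.exceptions import AstroidBuildingError, InferenceError

# Sketch of the proposal, not the actual astroid implementation:
# only the RecursionError handling is new.
def infer_import_from(self, context=None, asname=True):
    try:
        module = self.do_import_module()
    except (AstroidBuildingError, RecursionError) as exc:
        # Report the import as uninferable instead of letting the error
        # escape and abort the whole pylint run.
        raise InferenceError(node=self, context=context) from exc
    ...  # the rest of the existing implementation resolves the name in `module`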

Do you think we should launch the primer with keras (i.e. the old one, for now) with a higher recursion limit?

Yeah. I think it makes sense to no longer run the old one and start relying only on the new one with a higher timeout.

@jacobtylerwalls
Copy link
Member Author

I'm currently getting a timing on my fork for moving music21 to the new primer. Would you be able to do the same for keras?

@jacobtylerwalls
Copy link
Member Author

Eventually we can speed the jobs up by actually fixing the underlying issues preventing us from using --jobs.

@DanielNoord
Copy link
Collaborator

I'll probably be gone when it is finished, but you can see it here:
https://github.com/DanielNoord/pylint/actions/runs/2644864891

@jacobtylerwalls
Copy link
Member Author

What happens if we catch RecursionError in infer_import_from and raise InferenceError?

I'll give it a go. keras still doesn't lint with raising the recursion limit to 25000. 50000 just segfaults.
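
(For reference, the limit bump presumably went through the standard mechanism, something like the following; where exactly it was set is not shown here.)

import sys

# Raise CPython's recursion limit before running pylint on keras.
# 25000 matches the experiment above; pushing it much higher (e.g. 50000)
# overflows the C stack and segfaults instead of raising RecursionError.
sys.setrecursionlimit(25000)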

@DanielNoord
Copy link
Collaborator

Keras also failed with a 70-minute timeout, so let's not add it for now.

@jacobtylerwalls
Copy link
Member Author

What happens if we catch RecursionError in infer_import_from and raise InferenceError?

Doesn't work, since the recursion still happens while processing the InferenceError. If we just remove the try and let AstroidBuildingError bubble up, then we get a parse error in pylint, which is not much better than the current AstroidError (it still interrupts linting).

@jacobtylerwalls
Copy link
Member Author

I appreciate any help you can provide! Try:

pylint pylint/tests/.pylint_primer_tests/keras-team/keras/keras/legacy_tf_layers/variable_scope_shim_test.py

@jacobtylerwalls
Copy link
Member Author

music21 passed on the 3.7 job in 82 minutes; 3.10 only took 45.

@DanielNoord
Copy link
Collaborator

diff --git a/astroid/decorators.py b/astroid/decorators.py
index 03b345867..b6f0cc63b 100644
--- a/astroid/decorators.py
+++ b/astroid/decorators.py
@@ -145,6 +145,8 @@ def raise_if_nothing_inferred(func, instance, args, kwargs):
         raise InferenceError(
             "StopIteration raised without any error information."
         ) from error
+    except RecursionError as error:
+        raise InferenceError("") from error
 
     yield from generator

This works. I don't think this would be such a strange place to add this?
We'd only need to create a test and add a little comment about why this can occur (something about the large number of decorators and recursive calls).
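
Something like this could serve as the regression test, a minimal sketch that skips the keras import cycle and instead feeds the decorator a generator that blows the stack (raise_if_nothing_inferred and InferenceError are real astroid names; the test function itself is hypothetical):

import pytest

from astroid import decorators
from astroid.exceptions import InferenceError


@decorators.raise_if_nothing_inferred
def _infer_recursing(node, context=None):
    # Stand-in for an _infer implementation that recurses until the
    # interpreter gives up, as in the keras example above.
    raise RecursionError("maximum recursion depth exceeded")
    yield  # unreachable, but makes this a generator function like real _infer methods


def test_recursion_error_is_converted_to_inference_error():
    with pytest.raises(InferenceError):
        next(_infer_recursing(None))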

@jacobtylerwalls
Copy link
Member Author

Let's do it. Would you like to open the PR? I think you created a test package earlier in #1660 (comment); does it still serve as a test?

@DanielNoord
Copy link
Collaborator

No, that test package no longer crashes πŸ˜“

So, we'll need to create a new test. I'm not sure I'll be able to do that tonight as I have another deadline as well. If you could, narrowing the keras example down to an MRE would be very helpful! Otherwise I'll probably get to it sometime next week.

@DanielNoord
Copy link
Collaborator

keras.zip

This is a fairly small test package of 6 kB. Not sure if we want to change any names within it, but I think we can use it to create a test for the recursion crash.

@jacobtylerwalls
Copy link
Member Author

Thanks for distilling a test case. Yeah, we don't need to change every name, but I would suggest replacing "keras".

Successfully merging this pull request may close these issues.

Infinite loop when inferring from token import *
4 participants