From 01d6984eeb2db2a05d473c9d52301eddcd991d10 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 16 Mar 2022 11:50:34 +0100 Subject: [PATCH] test_subnetwork_deep_stack #993 --- tests/test_TFNetworkLayer.py | 86 ++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 621c1116b..dc724adaa 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -6024,6 +6024,92 @@ def test_subnetwork_unused_output(): network.construct_from_dict(net_dict) +def test_subnetwork_deep_stack(): + # https://github.com/rwth-i6/returnn/issues/993 + if not util.PY3: + raise unittest.SkipTest("test case needs python 3") # __qualname__, __func__, etc + with make_scope() as session: + import better_exchook + import traceback + import types + import collections + + max_depth = 5 + + def _find_net_construct_from_dict_in_stack(stack): + found_last = None + for i, frame in enumerate(stack): + assert isinstance(frame, types.FrameType) + func = better_exchook.get_func_from_code_object(frame.f_code) + if func == TFNetwork.construct_from_dict: + found_last = i + if found_last is not None: + return found_last + raise Exception("construct_from_dict not found in stack") + + def _extract_stack(): + stack = [frame for frame, _ in traceback.walk_stack(None)] + stack = stack[:_find_net_construct_from_dict_in_stack(stack) + 1] + return stack + + def _top_eval(source, **_): + print("Top eval func stack:") + stack = _extract_stack() + # LayerBase.transform_config_dict would only increase if this is not flat net construction, + # see https://github.com/rwth-i6/returnn/issues/992. + # For reference, output without fixes: https://gist.github.com/albertz/a1a20710e4d08292e35ba6910ca132e8 + func_counter = collections.Counter() + for frame in stack: + assert isinstance(frame, types.FrameType) + func = better_exchook.get_func_from_code_object(frame.f_code) + if not func: + print(" Warning: unexpected code object: %s" % frame.f_code.co_name) + continue + func_counter[func] += 1 + print(" ", func.__qualname__) + print("Stack depth:", len(stack)) + print("Num functions:", len(func_counter)) + for func, count in func_counter.most_common(10): + print("Most common func %s: count %i" % (func.__qualname__, count)) + """ + For reference, without flat construction (#992), before the fix of #993, I get: + Most common func Subnetwork.get_sub_layer_func..wrapped_get_layer: count 30 + Most common func Subnetwork.get_layer_func..wrapped_get_layer: count 30 + Most common func TFNetwork.construct_layer: count 12 + Most common func TFNetwork.construct_layer..get_layer: count 11 + Most common func TFNetwork.add_layer: count 6 + Most common func LayerBase.transform_config_dict..: count 6 + Most common func LayerBase.transform_config_dict: count 6 + Most common func CopyLayer.transform_config_dict: count 6 + Most common func Subnetwork.construct_layer: count 5 + Most common func SubnetworkLayer.transform_config_dict: count 5 + """ + print("Num TFNetwork.construct_layer:", func_counter[TFNetwork.construct_layer]) + print("Num LayerBase.transform_config_dict:", func_counter[LayerBase.transform_config_dict.__func__]) + # Before the fix, the stack depth was 125; with the improved variant (even not flat) it is 65. + # We don't check for the exact number to allow for some variation of future changes. + # However, we want to avoid that it becomes too deep in any case. + assert len(stack) <= 70 + return source(0) + + def _create_subnet_layer_dict(depth): + if depth >= max_depth: + return {"class": "eval", "from": "base:" * depth + "data", "eval": _top_eval} + return { + 'class': 'subnetwork', + 'from': [], + 'subnetwork': { + 'sub%i' % depth: _create_subnet_layer_dict(depth + 1), + 'output': {'class': 'copy', 'from': 'sub%i' % depth}}} + + net_dict = { + 'sub': _create_subnet_layer_dict(0), + 'output': {'class': 'copy', 'from': 'sub'}} + config = Config({"extern_data": {"data": {"dim": 3}}}) + network = TFNetwork(config=config) + network.construct_from_dict(net_dict) + + def test_extra_search(): class Callbacks: history = []