Skip to content

Commit

Permalink
test_subnetwork_deep_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 17, 2022
1 parent 9166451 commit 01d6984
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions tests/test_TFNetworkLayer.py
Expand Up @@ -6024,6 +6024,92 @@ def test_subnetwork_unused_output():
network.construct_from_dict(net_dict)


def test_subnetwork_deep_stack():
# https://github.com/rwth-i6/returnn/issues/993
if not util.PY3:
raise unittest.SkipTest("test case needs python 3") # __qualname__, __func__, etc
with make_scope() as session:
import better_exchook
import traceback
import types
import collections

max_depth = 5

def _find_net_construct_from_dict_in_stack(stack):
found_last = None
for i, frame in enumerate(stack):
assert isinstance(frame, types.FrameType)
func = better_exchook.get_func_from_code_object(frame.f_code)
if func == TFNetwork.construct_from_dict:
found_last = i
if found_last is not None:
return found_last
raise Exception("construct_from_dict not found in stack")

def _extract_stack():
stack = [frame for frame, _ in traceback.walk_stack(None)]
stack = stack[:_find_net_construct_from_dict_in_stack(stack) + 1]
return stack

def _top_eval(source, **_):
print("Top eval func stack:")
stack = _extract_stack()
# LayerBase.transform_config_dict would only increase if this is not flat net construction,
# see https://github.com/rwth-i6/returnn/issues/992.
# For reference, output without fixes: https://gist.github.com/albertz/a1a20710e4d08292e35ba6910ca132e8
func_counter = collections.Counter()
for frame in stack:
assert isinstance(frame, types.FrameType)
func = better_exchook.get_func_from_code_object(frame.f_code)
if not func:
print(" Warning: unexpected code object: %s" % frame.f_code.co_name)
continue
func_counter[func] += 1
print(" ", func.__qualname__)
print("Stack depth:", len(stack))
print("Num functions:", len(func_counter))
for func, count in func_counter.most_common(10):
print("Most common func %s: count %i" % (func.__qualname__, count))
"""
For reference, without flat construction (#992), before the fix of #993, I get:
Most common func Subnetwork.get_sub_layer_func.<locals>.wrapped_get_layer: count 30
Most common func Subnetwork.get_layer_func.<locals>.wrapped_get_layer: count 30
Most common func TFNetwork.construct_layer: count 12
Most common func TFNetwork.construct_layer.<locals>.get_layer: count 11
Most common func TFNetwork.add_layer: count 6
Most common func LayerBase.transform_config_dict.<locals>.<listcomp>: count 6
Most common func LayerBase.transform_config_dict: count 6
Most common func CopyLayer.transform_config_dict: count 6
Most common func Subnetwork.construct_layer: count 5
Most common func SubnetworkLayer.transform_config_dict: count 5
"""
print("Num TFNetwork.construct_layer:", func_counter[TFNetwork.construct_layer])
print("Num LayerBase.transform_config_dict:", func_counter[LayerBase.transform_config_dict.__func__])
# Before the fix, the stack depth was 125; with the improved variant (even not flat) it is 65.
# We don't check for the exact number to allow for some variation of future changes.
# However, we want to avoid that it becomes too deep in any case.
assert len(stack) <= 70
return source(0)

def _create_subnet_layer_dict(depth):
if depth >= max_depth:
return {"class": "eval", "from": "base:" * depth + "data", "eval": _top_eval}
return {
'class': 'subnetwork',
'from': [],
'subnetwork': {
'sub%i' % depth: _create_subnet_layer_dict(depth + 1),
'output': {'class': 'copy', 'from': 'sub%i' % depth}}}

net_dict = {
'sub': _create_subnet_layer_dict(0),
'output': {'class': 'copy', 'from': 'sub'}}
config = Config({"extern_data": {"data": {"dim": 3}}})
network = TFNetwork(config=config)
network.construct_from_dict(net_dict)


def test_extra_search():
class Callbacks:
history = []
Expand Down

0 comments on commit 01d6984

Please sign in to comment.