ConvLayer (and co) out data, fix time_dim_axis (#999)
albertz committed Mar 16, 2022
1 parent a07ba0e commit 9166451
Showing 2 changed files with 45 additions and 1 deletion.
returnn/tf/layers/basic.py: 8 additions & 1 deletion
@@ -5230,6 +5230,7 @@ def get_out_data_from_opts(
       assert n_out
       out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
     dim_tags.append(out_dim)
+    time_dim_axis = data.time_dim_axis
     feature_dim_axis = NotSpecified
     # Swap the dims if the input dim order doesn't fit the flag auto_use_channel_first.
     if auto_use_channel_first is NotSpecified:
@@ -5239,8 +5240,14 @@ def get_out_data_from_opts(
       if len([d for d in dim_tags if d.dimension]) > 1:
         feature_dim_axis = num_batch_dims
         dim_tags = dim_tags[:num_batch_dims] + dim_tags[-1:] + dim_tags[num_batch_dims:-1]
+        if time_dim_axis is not None and time_dim_axis >= num_batch_dims:
+          if time_dim_axis == len(dim_tags) - 1:
+            time_dim_axis = num_batch_dims
+          else:
+            time_dim_axis += 1
     out = Data(
-      name="%s_output" % name, dim_tags=dim_tags, feature_dim_axis=feature_dim_axis,
+      name="%s_output" % name, dim_tags=dim_tags,
+      time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis,
       batch=data.batch, beam=data.beam, control_flow_ctx=data.control_flow_ctx)
     if len(old_spatial_dim_tags) == len(filter_size):
       cls.set_output_dim_tags(
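The time_dim_axis handling added above only matters when auto_use_channel_first moves the channel dim from the last position to directly after the batch dims: axis indices between the batch dims and the end then shift right by one, and an index that pointed at the last axis maps to the channel's new position num_batch_dims. A minimal standalone sketch of that index arithmetic, under the assumptions just stated (the helper name and signature are made up for illustration, not RETURNN API):

def shifted_time_dim_axis(time_dim_axis, num_axes, num_batch_dims):
  """Hypothetical helper: axis index after moving the last axis to position num_batch_dims."""
  if time_dim_axis is None or time_dim_axis < num_batch_dims:
    return time_dim_axis  # unset, or a batch axis: unaffected by the swap
  if time_dim_axis == num_axes - 1:
    return num_batch_dims  # it pointed at the moved (last) axis
  return time_dim_axis + 1  # axes between the batch dims and the end shift right by one

# e.g. [B, T, W, C] -> [B, C, T, W]: the time axis moves from index 1 to index 2
assert shifted_time_dim_axis(1, num_axes=4, num_batch_dims=1) == 2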
tests/test_TFNetworkLayer.py: 37 additions & 0 deletions
@@ -4180,6 +4180,43 @@ def test_ConvLayer_get_out_data_from_opts_out_spatial_dims():
       feed_dict=make_feed_dict(net.extern_data))
 
 
+def test_ConvLayer_unrelated_dim():
+  from returnn.tf.util.data import batch_dim
+  time_dim = SpatialDim("time")
+  feat_dim = FeatureDim("input", 7)
+  other_dim = SpatialDim("other")
+  config = Config({"extern_data": {"data": {"dim_tags": [batch_dim, time_dim, other_dim, feat_dim]}}})
+  with make_scope() as session:
+    net = TFNetwork(config=config)
+    layer_desc = {
+      'name': "conv", "_name": "conv",
+      "network": net, "_network": net,
+      "from": "data",
+      'filter_size': [4],
+      "in_spatial_dims": [time_dim],
+      'strides': 3,
+      'padding': 'valid',
+      'n_out': 13,
+    }
+    ConvLayer.transform_config_dict(layer_desc, network=net, get_layer=net.get_layer)
+    conv_out = ConvLayer.get_out_data_from_opts(**layer_desc)
+    print("conv out:", conv_out)
+    dyn_axes = conv_out.get_dynamic_axes()
+    assert len(dyn_axes) == 2, "conv out: %r" % conv_out
+    assert conv_out.get_axis_from_description(other_dim) in dyn_axes
+    dyn_axes.remove(conv_out.get_axis_from_description(other_dim))
+    out_spatial_dim = conv_out.dim_tags[dyn_axes[0]]
+    assert out_spatial_dim not in net.extern_data.get_default_input_data().dim_tags
+    assert conv_out.time_dim_axis == dyn_axes[0]
+    with tf_compat.v1.variable_scope("conv"):
+      conv_layer = ConvLayer(output=conv_out, **layer_desc)
+    net.layers["conv"] = conv_layer
+    net.initialize_params(session)
+    session.run(
+      (conv_layer.output.placeholder, conv_layer.output.get_sequence_lengths()),
+      feed_dict=make_feed_dict(net.extern_data))
+
+
 def test_conv_layer_NCHW():
   with make_scope() as session:
     import numpy as np
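The new test above feeds input with dims [batch, time, other, feature=7] and convolves only over the time axis (in_spatial_dims=[time_dim], filter_size 4, stride 3, 'valid' padding), so the unrelated "other" dim must pass through unchanged while the time dim is replaced by a new dynamic spatial dim that carries the output's time_dim_axis. A small sketch of the expected length of that new dim under 'valid' padding (the helper below is illustrative only, not taken from the test):

def conv_out_len(in_len, filter_size=4, stride=3):
  """Output length of a 1D convolution with 'valid' (no) padding."""
  return (in_len - filter_size) // stride + 1

assert conv_out_len(10) == 3  # a length-10 time axis shrinks to length 3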
