diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 691c89a64..af08972a6 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -6236,12 +6236,24 @@ def reduce(cls, input_data, mode, axes=None, keep_dims=False, enforce_batch_dim_
         x_ = tf_util.where_bc(mask, x_, replacement_value, name="x_masked_axis_%i" % axis)
         if f == tf.reduce_mean:
+          tag = x.dim_tags[axis]
+          assert tag.dyn_size_ext is not None  # checked above
+          size_all = tf.shape(x.placeholder)[axis]
+          size_actual = tag.dyn_size_ext
+          while any(d not in out_data.dim_tags for d in size_actual.dim_tags):
+            # We have some axis (e.g. B) which is not in the output.
+            # We need to remove this.
+            # https://github.com/rwth-i6/returnn/issues/1242
+            i, d = [(i, d) for i, d in enumerate(size_actual.dim_tags) if d not in out_data.dim_tags][0]
+            assert not d.is_dynamic()  # not implemented
+            size_all *= d.get_dim_value()
+            s = tf.reduce_sum(size_actual.placeholder, axis=i)
+            size_actual = size_actual.copy_template_excluding_axis(i)
+            size_actual.placeholder = s
           seq_len_bc = (
-            x.dim_tags[axis].dyn_size_ext
-            .copy_compatible_to(out_data, check_sparse=False, check_dtype=False)
-            .placeholder)
+            size_actual.copy_compatible_to(out_data, check_sparse=False, check_dtype=False).placeholder)
           seq_len_bc = tf.maximum(seq_len_bc, 1)  # avoid nan
-          correction_factor_ = tf.cast(tf.shape(x.placeholder)[axis], tf.float32) / tf.cast(seq_len_bc, tf.float32)
+          correction_factor_ = tf.cast(size_all, tf.float32) / tf.cast(seq_len_bc, tf.float32)
           correction_factor = tf_util.optional_mul(correction_factor, correction_factor_)
     if mode in arg_funcs:
       assert len(axes) == 1, "For argmax/argmin, only one reduction axis is supported"
diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py
index 7f671aa48..69a5ced8d 100644
--- a/tests/test_TFNetworkLayer.py
+++ b/tests/test_TFNetworkLayer.py
@@ -10567,6 +10567,29 @@ def test_reduce_mean_batch_time():
     numpy.testing.assert_allclose(ref, v, rtol=1e-5)
 
 
+def test_ReduceLayer_mean_btf():
+  # https://github.com/rwth-i6/returnn/issues/1242
+  net_dict = {
+    "output": {"class": "reduce", "mode": "mean", "from": "data", "axis": ["B", "T", "F"]}
+  }
+  config = Config(dict(
+    extern_data={"data": {"shape": (None, 4)}}
+  ))
+  with make_scope() as session:
+    network = TFNetwork(config=config)
+    network.construct_from_dict(net_dict)
+    in_ = network.extern_data.get_default_input_data()
+    out = network.get_default_output_layer().output
+    in_v, seq_len, out_v = session.run(
+      (in_.placeholder, in_.get_sequence_lengths(), out.placeholder),
+      feed_dict=make_feed_dict(network.extern_data))
+    n_batch = in_v.shape[0]
+    assert n_batch == seq_len.shape[0]
+    for b in range(n_batch):
+      in_v[b, seq_len[b]:, :] = numpy.nan
+    numpy.testing.assert_almost_equal(out_v, numpy.nanmean(in_v))
+
+
 def test_automatic_seq_lengths():
   with make_scope() as session:
     n_out = 5