From ea0f143c73ed37395788762933938e245473ddea Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 25 Apr 2024 17:42:22 +0200 Subject: [PATCH] RF pad TF PadLayer handle_dynamic_dims, new behavior version --- .../behavior_version.rst | 6 + returnn/frontend/_backend.py | 2 + returnn/frontend/array_.py | 59 ++++++++- returnn/tf/frontend_layers/_backend.py | 2 + returnn/tf/layers/basic.py | 125 +++++++++++++++--- returnn/torch/frontend/_backend.py | 18 +++ returnn/util/basic.py | 2 +- tests/test_rf_array.py | 45 ++++++- 8 files changed, 235 insertions(+), 24 deletions(-) diff --git a/docs/configuration_reference/behavior_version.rst b/docs/configuration_reference/behavior_version.rst index c2830adf0..9e5e1da9a 100644 --- a/docs/configuration_reference/behavior_version.rst +++ b/docs/configuration_reference/behavior_version.rst @@ -22,6 +22,12 @@ and not listing legacy/deprecated parameters. Version History --------------- +Behavior version 21 (2024-04-25) + +RF ``pad`` and TF ``PadLayer`` defaults changed: + +* ``handle_dynamic_dims``: False → True + Behavior version 20 (2024-01-05) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py index 1c737b747..89abc67fb 100644 --- a/returnn/frontend/_backend.py +++ b/returnn/frontend/_backend.py @@ -475,6 +475,7 @@ def pad( axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], out_dims: Sequence[Dim], + handle_dynamic_dims: bool, mode: str = "constant", value: Optional[Union[rf.RawTensorTypes, Tensor]] = None, ) -> Tensor: @@ -483,6 +484,7 @@ def pad( :param axes: :param padding: :param out_dims: + :param handle_dynamic_dims: :param mode: :param value: :return: padded tensor diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py index 479004e35..1e66922ce 100644 --- a/returnn/frontend/array_.py +++ b/returnn/frontend/array_.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import Optional, Union, Type, TypeVar, Sequence, Tuple +import logging import numpy from returnn.tensor import Tensor, Dim import returnn.frontend as rf @@ -385,6 +386,7 @@ def pad( out_dims: Optional[Sequence[Dim]] = None, mode: str = "constant", value: Optional[Union[rf.RawTensorTypes, Tensor]] = None, + handle_dynamic_dims: Optional[bool] = None, ) -> Tuple[Tensor, Sequence[Dim]]: """ Pad values left/right in the specified axes. @@ -392,9 +394,13 @@ def pad( :param source: :param axes: which axes to add padding to :param padding: list of (left, right) padding for each axis - :param out_dims: (optional) predefined out dim tags, otherwise will automatically create + :param out_dims: (optional) predefined out dims for each padded dim in axes. will automatically create if not given :param mode: 'constant', 'reflect', 'replicate' or 'circular' :param value: (optional) value to pad with in "constant" mode + :param handle_dynamic_dims: True: when doing right padding on a dynamic dim, value will be added after the seq end, + not at the end of the dimension. False: value will be added at the end of the dimension. + By default, in behavior version >=21, this is True, in older versions, this is False. + :return: padded tensor, out_dims. out dims are for each dim in axes """ assert len(axes) == len(padding) if not out_dims: @@ -405,13 +411,62 @@ def pad( assert not right.need_masking(), f"padding {padding} does not support dynamic right padding" # Note that even dynamic middle dims is not exactly correct... 
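         # If out_dims are not given: each out dim is the Dim sum left + middle + right for that axis.
         # For a dynamic middle dim, this yields a new dynamic dim with per-sequence sizes left + seq_len + right.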
out_dims = [left + middle + right for middle, (left, right) in zip(axes, padding)] + if handle_dynamic_dims is None: + handle_dynamic_dims = _pad_handle_dynamic_dims_default(axes, padding, mode=mode) # noinspection PyProtectedMember return ( - source._raw_backend.pad(source, axes=axes, padding=padding, out_dims=out_dims, mode=mode, value=value), + source._raw_backend.pad( + source, + axes=axes, + padding=padding, + out_dims=out_dims, + handle_dynamic_dims=handle_dynamic_dims, + mode=mode, + value=value, + ), out_dims, ) +_pad_handle_dynamic_dims_shown_warning = False + + +def _pad_handle_dynamic_dims_default( + pad_axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], *, mode: str +) -> bool: + """ + :param pad_axes: list of axes to pad + :param padding: list of (left, right) padding for each axis + :param mode: 'constant', 'reflect', 'replicate' or 'circular' + :return: True if dynamic dims should be handled as specified in the default behavior + """ + from returnn.util.basic import BehaviorVersion + + if BehaviorVersion.get() >= 21: + return True + + # Check whether not handling the dynamic dims is safe. Print a warning if not safe. + global _pad_handle_dynamic_dims_shown_warning + if not _pad_handle_dynamic_dims_shown_warning: + for middle, (left, right) in zip(pad_axes, padding): + middle: Dim + if not middle.need_masking() and (isinstance(left, int) or not left.need_masking()): + continue + if mode != "circular" and isinstance(right, int) and right == 0: + continue + + logging.getLogger("returnn.frontend").warning( + f"rf.pad applied on dynamic dim {middle} but handle_dynamic_dims=False used by default" + f" due to behavior version {BehaviorVersion.get()} < 21." + " Set handle_dynamic_dims explicitly to avoid the warning," + " or switch to a new behavior version >= 21." + " (This warning is only printed once.)" + ) + _pad_handle_dynamic_dims_shown_warning = True + break + return False + + def cum_concat_step( source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Optional[Dim] = None ) -> Tuple[Tensor, Dim]: diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py index a5ff9e3bc..082b5e52e 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -355,6 +355,7 @@ def pad( axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], out_dims: Sequence[Dim], + handle_dynamic_dims: bool, mode: str = "constant", value: Union[rf.RawTensorTypes, Tensor] = None, ) -> Tensor: @@ -367,6 +368,7 @@ def pad( "axes": axes, "padding": padding, "out_dims": out_dims, + "handle_dynamic_dims": handle_dynamic_dims, "mode": mode, "value": value, }, diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index ec8c79a85..30283ce1d 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1981,13 +1981,15 @@ def __init__( # Now we need to implement a similar logic as `returnn.tf.util.basic.nd_indices`, but more generic. 
idxs = [ ( - tf.reshape( - tf.range(pos_shape[i], dtype=pos_v.dtype), [1] * i + [pos_shape[i]] + [1] * (pos_ndim - i - 1) + ( + tf.reshape( + tf.range(pos_shape[i], dtype=pos_v.dtype), [1] * i + [pos_shape[i]] + [1] * (pos_ndim - i - 1) + ) + + tf.zeros_like(pos_v) ) - + tf.zeros_like(pos_v) + if i != replace_common_axis + else pos_v ) - if i != replace_common_axis - else pos_v for i in range(pos_ndim) ] nd_idxs = tf.stack(idxs, axis=-1) @@ -4173,13 +4175,28 @@ class PadLayer(_ConcatInputLayer): layer_class = "pad" - def __init__(self, axes, padding, out_dims=None, value=0, mode="constant", **kwargs): + def __init__( + self, + *, + axes: Union[Dim, str, Sequence[Union[Dim, str]]], + padding: Union[int, Tuple[int, int], Sequence[Tuple[int, int]]], + out_dims: Optional[Union[Dim, Sequence[Dim]]] = None, + handle_dynamic_dims: Optional[bool] = None, + value: Union[int, float] = 0, + mode: str = "constant", + **kwargs, + ): """ - :param Dim|str|list[Dim|str] axes: e.g. "F" etc. see :func:`Data.get_axes_from_description`. - :param list[(int,int)]|(int,int)|int padding: how much to pad left/right in each axis - :param Dim|list[Dim]|None out_dims: - :param int|float value: what constant value to pad, with mode=="constant" - :param str mode: "constant", "reflect", "symmetric" and "replication" + :param axes: e.g. "F" etc. see :func:`Data.get_axes_from_description`. + :param padding: how much to pad left/right in each axis + :param out_dims: + :param handle_dynamic_dims: True: when doing right padding on a dynamic dim, + value will be added after the seq end, + not at the end of the dimension. + False: value will be added at the end of the dimension. + By default, in behavior version >=21, this is True, in older versions, this is False. + :param value: what constant value to pad, with mode=="constant" + :param mode: "constant", "reflect", "symmetric" and "replication" """ out_dims # noqa # handled in get_out_data_from_opts super(PadLayer, self).__init__(**kwargs) @@ -4190,15 +4207,42 @@ def __init__(self, axes, padding, out_dims=None, value=0, mode="constant", **kwa paddings = [(0, 0)] * len(range(self.input_data.batch_ndim)) for i, a in enumerate(axes): paddings[a] = padding[i] - mode = mode.upper() + mode = mode.lower() + if handle_dynamic_dims is None: + handle_dynamic_dims = self._handle_dynamic_dims_default( + pad_axes=[self.input_data.dims[axis] for axis in axes_], + padding=padding, + mode=mode, + ) if all(sum(p) == 0 for p in padding): self.output.placeholder = self.input_data.placeholder - elif mode == "REPLICATION": + elif mode == "replication": self.output.placeholder = tf_util.pad_replicate(self.input_data.placeholder, axes, padding) else: self.output.placeholder = tf.pad( self.input_data.placeholder, paddings=paddings, mode=mode, constant_values=value ) + if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims: + if all(right == 0 for left, right in padding) and mode != "circular": + pass # no masking needed + else: + import returnn.frontend as rf + + if mode != "constant": + raise NotImplementedError( + f"pad: mode {mode} not implemented with dynamic dims and handle_dynamic_dims=True" + ) + for out_dim, middle_axis, (left, right) in zip(out_dims, axes, padding): + out_dim: Dim + middle = self.input_data.dims[middle_axis] + if middle.need_masking() or (isinstance(left, Dim) and left.need_masking()): + if isinstance(right, Dim) or right > 0: + mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext) + self.output.raw_tensor = tf_util.where_bc( + 
mask.copy_compatible_to(self.output, check_sparse=False, check_dtype=False).raw_tensor, + self.output.raw_tensor, + tf.convert_to_tensor(value, dtype=self.output.dtype), + ) @classmethod def _transform_padding(cls, padding, axes): @@ -4218,6 +4262,45 @@ def _transform_padding(cls, padding, axes): padding = [(padding, padding)] * len(axes) return padding + _handle_dynamic_dims_shown_warning = False + + @classmethod + def _handle_dynamic_dims_default( + cls, pad_axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], *, mode: str + ) -> bool: + """ + :param pad_axes: list of axes to pad + :param padding: list of (left, right) padding for each axis + :param mode: 'constant', 'reflect', 'replicate' or 'circular' + :return: True if dynamic dims should be handled as specified in the default behavior + """ + from returnn.util.basic import BehaviorVersion + + if BehaviorVersion.get() >= 21: + return True + + # Check whether not handling the dynamic dims is safe. Print a warning if not safe. + if not cls._handle_dynamic_dims_shown_warning: + import logging + + for middle, (left, right) in zip(pad_axes, padding): + middle: Dim + if not middle.need_masking() and (isinstance(left, int) or not left.need_masking()): + continue + if mode != "circular" and isinstance(right, int) and right == 0: + continue + + logging.getLogger("returnn.tf").warning( + f"PadLayer applied on dynamic dim {middle} but handle_dynamic_dims=False used by default" + f" due to behavior version {BehaviorVersion.get()} < 21." + " Set handle_dynamic_dims explicitly to avoid the warning," + " or switch to a new behavior version >= 21." + " (This warning is only printed once.)" + ) + cls._handle_dynamic_dims_shown_warning = True + break + return False + @classmethod def get_out_data_from_opts(cls, name, sources, axes, padding, out_dims=None, **kwargs): """ @@ -4929,14 +5012,16 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources rem_dim = None if not resolved_dims: resolved_dims = tuple( - Dim( - kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial, - description="%s_split_dims%i" % (name, i), - dimension=shape_dim, - auto_generated=True, + ( + Dim( + kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial, + description="%s_split_dims%i" % (name, i), + dimension=shape_dim, + auto_generated=True, + ) + if rem_dim is None or i != rem_dim_idx + else rem_dim ) - if rem_dim is None or i != rem_dim_idx - else rem_dim for i, shape_dim in enumerate(resolved_shape_dims) ) out_batch = data.batch diff --git a/returnn/torch/frontend/_backend.py b/returnn/torch/frontend/_backend.py index 6055419c2..ca0361a9e 100644 --- a/returnn/torch/frontend/_backend.py +++ b/returnn/torch/frontend/_backend.py @@ -439,6 +439,7 @@ def pad( axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], out_dims: Sequence[Dim], + handle_dynamic_dims: bool, mode: str = "constant", value: Optional[Union[rf.RawTensorTypes, Tensor]] = None, ) -> Tensor: @@ -465,6 +466,23 @@ def pad( assert value.dims == (), f"value {value} must be a scalar" value = value.raw_tensor out.raw_tensor = torch.nn.functional.pad(source.raw_tensor, pad=raw_pad, mode=mode, value=value) + if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims: + if all(right == 0 for right in raw_pad[1::2]) and mode != "circular": + pass # no masking needed + else: + if mode != "constant": + raise NotImplementedError( + f"pad: mode {mode} not implemented with dynamic dims 
and handle_dynamic_dims=True" + ) + for out_dim, middle, (left, right) in zip(out_dims, axes, padding): + if middle.need_masking() or (isinstance(left, Dim) and left.need_masking()): + if isinstance(right, Dim) or right > 0: + mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext) + out.raw_tensor = torch.where( + mask.copy_compatible_to(out, check_dtype=False, check_sparse=False).raw_tensor, + out.raw_tensor, + value, + ) return out @staticmethod diff --git a/returnn/util/basic.py b/returnn/util/basic.py index 9d44eeab8..b999a6011 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -203,7 +203,7 @@ class BehaviorVersion: See :ref:`behavior_version`. """ - _latest_behavior_version = 20 + _latest_behavior_version = 21 _behavior_version = None # type: typing.Optional[int] _min_behavior_version = 0 # type: int diff --git a/tests/test_rf_array.py b/tests/test_rf_array.py index 220e4afa3..7f88930ea 100644 --- a/tests/test_rf_array.py +++ b/tests/test_rf_array.py @@ -5,6 +5,7 @@ from __future__ import annotations from typing import Tuple import _setup_test_env # noqa +import numpy as np import returnn.frontend as rf from returnn.tensor import Tensor, Dim, TensorDict, batch_dim from rf_utils import run_model @@ -128,7 +129,7 @@ def test_pad_time(): ) class _Net(rf.Module): - def __call__(self, x: Tensor) -> Tuple[Tensor, Tuple[Dim, Dim]]: + def __call__(self, x: Tensor) -> Tuple[Tensor, Tuple[Dim]]: pack, (new_time,) = rf.pad(x, axes=[time_dim], padding=[(1, 0)], value=0) return pack, (new_time,) @@ -140,6 +141,48 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) +def test_pad_time_right(): + time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) + in_dim = Dim(7, name="in") + extern_data = TensorDict( + { + "data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"), + } + ) + + class _Net(rf.Module): + def __call__(self, x: Tensor) -> Tuple[Tensor, Tuple[Dim]]: + pack, (new_time,) = rf.pad(x, axes=[time_dim], padding=[(0, 1)], value=1) + return pack, (new_time,) + + # noinspection PyShadowingNames + def _forward_step(*, model: _Net, extern_data: TensorDict): + data = extern_data["data"] + data.mark_as_output("data", shape=(batch_dim, time_dim, in_dim)) + out, (new_time,) = model(data) + out.mark_as_default_output(shape=(batch_dim, new_time, in_dim)) + + res = run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) + data_: Tensor = res["data"] + out_: Tensor = res["output"] + assert data_.dims == (batch_dim, time_dim, in_dim) + new_time_dim = out_.dims[1] + assert out_.dims == (batch_dim, new_time_dim, in_dim) and new_time_dim != time_dim + assert new_time_dim == time_dim + 1 # math dim... not really necessary check here... 
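+    # The checks below verify the handle_dynamic_dims semantics per sequence:
+    # each per-sequence length grows by exactly 1, the original content is kept
+    # unchanged up to the old sequence end, and the pad value 1 appears directly
+    # after the sequence end, not at the end of the padded dimension.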
+ assert time_dim.dyn_size_ext.dims == new_time_dim.dyn_size_ext.dims == (batch_dim,) + batch_size = batch_dim.get_dim_value() + assert batch_size > 1 + assert len(set(time_dim.dyn_size_ext.raw_tensor)) > 1 # not all the same + for b in range(batch_size): + seq_len = time_dim.dyn_size_ext.raw_tensor[b] + new_seq_len = new_time_dim.dyn_size_ext.raw_tensor[b] + print(f"batch {b}, seq_len {seq_len}, new_seq_len {new_seq_len}") + assert new_seq_len == seq_len + 1 + np.testing.assert_allclose(data_.raw_tensor[b, :seq_len], out_.raw_tensor[b, :seq_len]) + print(out_.raw_tensor[b]) + assert all(out_.raw_tensor[b, seq_len] == 1.0) + + def test_gather(): time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) in_dim = Dim(7, name="in")
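
For illustration, a minimal NumPy sketch (not part of the patch) of the semantic difference made by the new handle_dynamic_dims default when right-padding a dynamic dim. The concrete numbers (a batch of 2 sequences with lengths [3, 2], pad value 9) are made up for this example:

import numpy as np

# Batch of 2 sequences with lengths [3, 2] inside a dim of size 4;
# right-pad the dynamic dim by 1 with constant value 9.
x = np.array([[1., 2., 3., 0.],
              [4., 5., 0., 0.]], dtype=np.float32)
seq_lens = np.array([3, 2])
value = 9.0

# Old default (handle_dynamic_dims=False): the value simply goes to the end of the dimension.
old = np.pad(x, ((0, 0), (0, 1)), constant_values=value)
# -> [[1., 2., 3., 0., 9.],
#     [4., 5., 0., 0., 9.]]

# New default (handle_dynamic_dims=True, behavior version >= 21):
# the value is written directly after each sequence's content (all later frames get it too),
# and the per-sequence lengths become seq_lens + 1 = [4, 3].
padded = np.pad(x, ((0, 0), (0, 1)), constant_values=value)
positions = np.arange(padded.shape[1])[None, :]  # like rf.range_over_dim(out_dim)
new = np.where(positions < seq_lens[:, None], padded, value)  # like where_bc / torch.where in the patch
# -> [[1., 2., 3., 9., 9.],
#     [4., 5., 9., 9., 9.]]
# Frames past the new per-sequence length are padding and carry no meaning.

This mirrors the masking that the patch adds in the TF and Torch backends: a mask range_over_dim(out_dim) < left + seq_len keeps the left padding plus the original sequence content, and everything else is overwritten with the pad value.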