Reland "Add nn.CircularPad{*}d for consistency + fix no_batch_dim support (#106148)" #106632

Closed
45 changes: 29 additions & 16 deletions aten/src/ATen/native/PadNd.cpp
@@ -108,26 +108,39 @@ Tensor constant_pad_nd(const Tensor& self, IntArrayRef pad, const Scalar& value)
 
 Tensor _pad_circular_symint(const Tensor &self, c10::SymIntArrayRef padding) {
   const auto in_shape = self.sym_sizes();
-  const auto ndim = static_cast<int64_t>(in_shape.size()) - 2;
-  TORCH_CHECK(padding.size() + 4 == in_shape.size() * 2,
-              "Invalid padding size, expected ", ndim * 2, " but got ", padding.size());
+  const auto self_ndim = static_cast<int64_t>(in_shape.size());
+
+  // number of dimensions that are padded
+  const auto ndim_padded = padding.size() / 2;
+  // number of preceding non_padded dimensions (1 for no_batch_dim case or 2)
+  const auto ndim_nonpadded = self_ndim - ndim_padded;
+
+  TORCH_CHECK(ndim_nonpadded == 1 || ndim_nonpadded == 2,
+              "Invalid padding size, expected 1 or 2 non-padded dimensions, ",
+              "which would be equivalent to padding of length ",
+              (self_ndim - 1) * 2,
+              " or ",
+              (self_ndim - 2) * 2,
+              " respectively but got ",
+              padding.size());
 
   c10::SymDimVector out_shape(in_shape.size());
-  out_shape[0] = in_shape[0];
-  out_shape[1] = in_shape[1];
+  for (const auto i: c10::irange(ndim_nonpadded)) {
+    out_shape[i] = in_shape[i];
+  }
 
   // Get shape of padded tensor
-  for (const auto i : c10::irange(ndim)) {
-    const auto& pad_l = padding[2 * (ndim - i - 1) + 0];
-    const auto& pad_r = padding[2 * (ndim - i - 1) + 1];
-    const auto& size = in_shape[2 + i];
-    out_shape[2 + i] = size + pad_l + pad_r;
+  for (const auto i : c10::irange(ndim_padded)) {
+    const auto& pad_l = padding[2 * (ndim_padded - i - 1) + 0];
+    const auto& pad_r = padding[2 * (ndim_padded - i - 1) + 1];
+    const auto& size = in_shape[ndim_nonpadded + i];
+    out_shape[ndim_nonpadded + i] = size + pad_l + pad_r;
 
     TORCH_CHECK(
         pad_l <= size && pad_r <= size,
         "Padding value causes wrapping around more than once.");
     TORCH_CHECK(
-        out_shape[2 + i] >= 0,
+        out_shape[ndim_nonpadded + i] >= 0,
         "Negative padding value is resulting in an empty dimension");
   }
 
@@ -137,8 +150,8 @@ Tensor _pad_circular_symint(const Tensor &self, c10::SymIntArrayRef padding) {
   Tensor out_slice = out;
   Tensor in_slice = self;
   const SymInt zero = 0;
-  for (const auto i : c10::irange(ndim)) {
-    const auto dim = ndim - i + 1;
+  for (const auto i : c10::irange(ndim_padded)) {
+    const auto dim = ndim_padded - i + ndim_nonpadded - 1;
     const auto& pad_l = padding[2*i + 0];
     const auto& pad_r = padding[2*i + 1];
     out_slice = out_slice.slice_symint(dim, std::max(pad_l, zero), out_shape[dim] - std::max(pad_r, zero));
@@ -148,12 +161,12 @@ Tensor _pad_circular_symint(const Tensor &self, c10::SymIntArrayRef padding) {
 
   // The following steps first pad the beginning of the tensor (left side),
   // and then pad the end of the tensor (right side).
-  // Note: Corners will be written more than once when ndim > 1.
+  // Note: Corners will be written more than once when ndim_padded > 1.
   //
   // Only in cases where padding values are > 0 are when additional copying
   // is required.
-  for (const auto i : c10::irange(ndim)) {
-    const auto dim = ndim - i + 1;
+  for (const auto i : c10::irange(ndim_padded)) {
+    const auto dim = ndim_padded - i + ndim_nonpadded - 1;
     const auto& pad_l = padding[2*i + 0];
     const auto& pad_r = padding[2*i + 1];
 
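
For intuition, the shape bookkeeping introduced above can be mirrored in a few lines of Python. This is an illustrative sketch only (the helper name circular_pad_out_shape is hypothetical, not part of the PR); it reproduces the ndim_padded/ndim_nonpadded computation from the diff:

    def circular_pad_out_shape(in_shape, padding):
        # One (left, right) pair per padded trailing dimension.
        ndim_padded = len(padding) // 2
        # Preceding non-padded dimensions: 1 in the no_batch_dim case, else 2 (N, C).
        ndim_nonpadded = len(in_shape) - ndim_padded
        assert ndim_nonpadded in (1, 2), "expected 1 or 2 non-padded dimensions"

        out_shape = list(in_shape[:ndim_nonpadded])
        for i in range(ndim_padded):
            # `padding` is ordered from the last dimension backwards.
            pad_l = padding[2 * (ndim_padded - i - 1)]
            pad_r = padding[2 * (ndim_padded - i - 1) + 1]
            size = in_shape[ndim_nonpadded + i]
            assert pad_l <= size and pad_r <= size, "wraps around more than once"
            out_shape.append(size + pad_l + pad_r)
        return out_shape

    print(circular_pad_out_shape((3, 4), (2, 1)))     # unbatched (C, W) -> [3, 7]
    print(circular_pad_out_shape((1, 3, 4), (2, 1)))  # batched (N, C, W) -> [1, 3, 7]
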
3 changes: 3 additions & 0 deletions test/test_module_init.py
@@ -38,6 +38,9 @@ def build_constructor_arg_db():
         torch.nn.ConstantPad1d: ((2, 3.5), {}),
         torch.nn.ConstantPad2d: ((2, 3.5), {}),
         torch.nn.ConstantPad3d: ((2, 3.5), {}),
+        torch.nn.CircularPad1d: ((2,), {}),
+        torch.nn.CircularPad2d: ((2,), {}),
+        torch.nn.CircularPad3d: ((2,), {}),
         torch.nn.Conv1d: ((3, 3, 3), {}),
         torch.nn.Conv2d: ((3, 3, 3), {}),
         torch.nn.Conv3d: ((3, 3, 3), {}),
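
Each entry above maps a module class to the (args, kwargs) used to instantiate it. A rough sketch of how such an entry is consumed, assuming only the dict structure shown in the diff (the loop below is illustrative, not the actual harness in test_module_init.py):

    import torch

    # Hypothetical excerpt mirroring the entries added above.
    constructor_arg_db = {
        torch.nn.CircularPad1d: ((2,), {}),
        torch.nn.CircularPad2d: ((2,), {}),
        torch.nn.CircularPad3d: ((2,), {}),
    }

    for module_cls, (args, kwargs) in constructor_arg_db.items():
        m = module_cls(*args, **kwargs)  # e.g. torch.nn.CircularPad1d(2)
        print(m)
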
13 changes: 7 additions & 6 deletions torch/nn/functional.py
@@ -4425,12 +4425,13 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] =
     :math:`\text{padding\_front}, \text{padding\_back})`.
 
 Padding mode:
-    See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
-    :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
-    padding modes works. Constant padding is implemented for arbitrary dimensions.
-    Replicate and reflection padding are implemented for padding the last 3
-    dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D
-    or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
+    See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`,
+    :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d`
+    for concrete examples on how each of the padding modes works. Constant
+    padding is implemented for arbitrary dimensions. Circular, replicate and
+    reflection padding are implemented for padding the last 3 dimensions of a
+    4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor,
+    or the last dimension of a 2D or 3D input tensor.
 
 Note:
     When using the CUDA backend, this operation may induce nondeterministic
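
As a concrete check of the documented constraint, circular mode wraps values in from the opposite edge of each padded dimension; a minimal example through the functional API:

    import torch
    import torch.nn.functional as F

    x = torch.arange(4.0).reshape(1, 1, 4)  # (N, C, W)
    # Pad the last dimension by one element on each side, wrapping circularly.
    y = F.pad(x, (1, 1), mode='circular')
    print(y)  # tensor([[[3., 0., 1., 2., 3., 0.]]])
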
6 changes: 4 additions & 2 deletions torch/nn/modules/__init__.py
@@ -22,7 +22,8 @@
 from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
 from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
 from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
-    ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d
+    ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \
+    CircularPad1d, CircularPad2d, CircularPad3d
 from .sparse import Embedding, EmbeddingBag
 from .rnn import RNNBase, RNN, LSTM, GRU, \
     RNNCellBase, RNNCell, LSTMCell, GRUCell
@@ -62,5 +63,6 @@
     'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
     'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
     'LazyInstanceNorm1d', 'LazyInstanceNorm2d', 'LazyInstanceNorm3d',
-    'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
+    'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
+    'CircularPad1d', 'CircularPad2d', 'CircularPad3d'
 ]
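
With the import and __all__ changes above, the new classes should be reachable through the usual public namespaces:

    import torch.nn as nn
    from torch.nn.modules import CircularPad2d

    # torch.nn re-exports torch.nn.modules, so both names refer to the same class.
    assert nn.CircularPad2d is CircularPad2d
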
184 changes: 182 additions & 2 deletions torch/nn/modules/padding.py
@@ -9,8 +9,188 @@
 
 # TODO: grad_output size asserts in THNN
 
-__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d',
-           'ReflectionPad3d', 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d']
+__all__ = ['CircularPad1d', 'CircularPad2d', 'CircularPad3d', 'ConstantPad1d', 'ConstantPad2d',
+           'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
+           'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d']
+
+
+class _CircularPadNd(Module):
+    __constants__ = ['padding']
+    padding: Sequence[int]
+
+    def _check_input_dim(self, input):
+        raise NotImplementedError
+
+    def forward(self, input: Tensor) -> Tensor:
+        self._check_input_dim(input)
+        return F.pad(input, self.padding, 'circular')
+
+    def extra_repr(self) -> str:
+        return f'{self.padding}'
+
+
+class CircularPad1d(_CircularPadNd):
+    r"""Pads the input tensor using circular padding of the input boundary.
+
+    Tensor values at the beginning of the dimension are used to pad the end,
+    and values at the end are used to pad the beginning. If negative padding is
+    applied then the ends of the tensor get removed.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 2-`tuple`, uses
+            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
+
+    Shape:
+        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
+        >>> m = nn.CircularPad1d(2)
+        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
+        >>> input
+        tensor([[[0., 1., 2., 3.],
+                 [4., 5., 6., 7.]]])
+        >>> m(input)
+        tensor([[[2., 3., 0., 1., 2., 3., 0., 1.],
+                 [6., 7., 4., 5., 6., 7., 4., 5.]]])
+        >>> # using different paddings for different sides
+        >>> m = nn.CircularPad1d((3, 1))
+        >>> m(input)
+        tensor([[[1., 2., 3., 0., 1., 2., 3., 0.],
+                 [5., 6., 7., 4., 5., 6., 7., 4.]]])
+
+    """
+    padding: Tuple[int, int]
+
+    def __init__(self, padding: _size_2_t) -> None:
+        super().__init__()
+        self.padding = _pair(padding)
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError(
+                f"expected 2D or 3D input (got {input.dim()}D input)"
+            )
+
+
+class CircularPad2d(_CircularPadNd):
+    r"""Pads the input tensor using circular padding of the input boundary.
+
+    Tensor values at the beginning of the dimension are used to pad the end,
+    and values at the end are used to pad the beginning. If negative padding is
+    applied then the ends of the tensor get removed.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
+
+          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> m = nn.CircularPad2d(2)
+        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
+        >>> input
+        tensor([[[[0., 1., 2.],
+                  [3., 4., 5.],
+                  [6., 7., 8.]]]])
+        >>> m(input)
+        tensor([[[[4., 5., 3., 4., 5., 3., 4.],
+                  [7., 8., 6., 7., 8., 6., 7.],
+                  [1., 2., 0., 1., 2., 0., 1.],
+                  [4., 5., 3., 4., 5., 3., 4.],
+                  [7., 8., 6., 7., 8., 6., 7.],
+                  [1., 2., 0., 1., 2., 0., 1.],
+                  [4., 5., 3., 4., 5., 3., 4.]]]])
+        >>> # using different paddings for different sides
+        >>> m = nn.CircularPad2d((1, 1, 2, 0))
+        >>> m(input)
+        tensor([[[[5., 3., 4., 5., 3.],
+                  [8., 6., 7., 8., 6.],
+                  [2., 0., 1., 2., 0.],
+                  [5., 3., 4., 5., 3.],
+                  [8., 6., 7., 8., 6.]]]])
+
+    """
+    padding: Tuple[int, int, int, int]
+
+    def __init__(self, padding: _size_4_t) -> None:
+        super().__init__()
+        self.padding = _quadruple(padding)
+
+    def _check_input_dim(self, input):
+        if input.dim() != 3 and input.dim() != 4:
+            raise ValueError(
+                f"expected 3D or 4D input (got {input.dim()}D input)"
+            )
+
+
+class CircularPad3d(_CircularPadNd):
+    r"""Pads the input tensor using circular padding of the input boundary.
+
+    Tensor values at the beginning of the dimension are used to pad the end,
+    and values at the end are used to pad the beginning. If negative padding is
+    applied then the ends of the tensor get removed.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 6-`tuple`, uses
+            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
+            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
+            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
+
+    Shape:
+        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
+          where
+
+          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
+
+          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = nn.CircularPad3d(3)
+        >>> input = torch.randn(16, 3, 8, 320, 480)
+        >>> output = m(input)
+        >>> # using different paddings for different sides
+        >>> m = nn.CircularPad3d((3, 3, 6, 6, 1, 1))
+        >>> output = m(input)
+
+    """
+    padding: Tuple[int, int, int, int, int, int]
+
+    def __init__(self, padding: _size_6_t) -> None:
+        super().__init__()
+        self.padding = _ntuple(6)(padding)
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4 and input.dim() != 5:
+            raise ValueError(
+                f"expected 4D or 5D input (got {input.dim()}D input)"
+            )
+
+
 
 class _ConstantPadNd(Module):
     __constants__ = ['padding', 'value']
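
Finally, since no_batch_dim support is the behavioral half of this reland, a short sanity check that the unbatched path agrees with the batched one (expected shape follows from the CircularPad1d docstring example above):

    import torch
    import torch.nn as nn

    m = nn.CircularPad1d(2)

    # Unbatched (C, W) input -- the no_batch_dim case this PR fixes.
    x = torch.arange(8, dtype=torch.float).reshape(2, 4)
    y = m(x)
    print(y.shape)  # torch.Size([2, 8])

    # Should agree with the batched (N, C, W) path for N == 1.
    assert torch.equal(y, m(x.unsqueeze(0)).squeeze(0))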