diff --git a/aten/src/ATen/native/im2col_shape_check.h b/aten/src/ATen/native/im2col_shape_check.h index 03a92c9c31cc..78725bb81ff7 100644 --- a/aten/src/ATen/native/im2col_shape_check.h +++ b/aten/src/ATen/native/im2col_shape_check.h @@ -39,9 +39,9 @@ static inline void col2im_shape_check( int64_t ndim = input.ndimension(); // allow dim=0 only the batch dimension. TORCH_CHECK( - (ndim == 2 && input.size(1) != 0) || + (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) || (ndim == 3 && input.size(1) != 0 && input.size(2) != 0), - "2D or 3D (batch mode) tensor expected for input, but got input of size ", + "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ", input.sizes()); int64_t batch_dim = (ndim == 3) ? 0 : -1; @@ -158,10 +158,11 @@ static inline void im2col_shape_check( int64_t ndim = input.ndimension(); // allow dim=0 only the batch dimension. + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; TORCH_CHECK( - (ndim == 3 && input.size(1) != 0 && input.size(2) != 0) || - (ndim == 4 && input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0), - "3D or 4D (batch mode) expected for input, but got input of size ", + (ndim == 3 && input.size(0) && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", input.sizes()); int64_t dim_batch = 0; diff --git a/test/test_nn.py b/test/test_nn.py index 7184fabdfa9d..de24c674eda3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -10175,12 +10175,13 @@ def test_ReflectionPad_empty(self, device): inp = torch.randn(3, 0, 10, 10, device=device) mod(inp) + @onlyOnCPUAndCUDA def test_Unfold_empty(self, device): inp = torch.randn(0, 3, 3, 4, device=device) unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device) self._test_module_empty_input(unfold, inp, check_size=False) - with self.assertRaisesRegex(RuntimeError, '3D or 4D'): + with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'): inp = torch.randn(3, 0, 3, 4, device=device) unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device) unfold(inp)