Skip to content

Commit

Permalink
update error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
v0dro committed Oct 9, 2020
1 parent ff801b4 commit 81680d8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
11 changes: 6 additions & 5 deletions aten/src/ATen/native/im2col_shape_check.h
Expand Up @@ -39,9 +39,9 @@ static inline void col2im_shape_check(
int64_t ndim = input.ndimension();
// allow dim=0 only the batch dimension.
TORCH_CHECK(
(ndim == 2 && input.size(1) != 0) ||
(ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
(ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
"2D or 3D (batch mode) tensor expected for input, but got input of size ",
"Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
input.sizes());

int64_t batch_dim = (ndim == 3) ? 0 : -1;
Expand Down Expand Up @@ -158,10 +158,11 @@ static inline void im2col_shape_check(
int64_t ndim = input.ndimension();

// allow dim=0 only the batch dimension.
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
TORCH_CHECK(
(ndim == 3 && input.size(1) != 0 && input.size(2) != 0) ||
(ndim == 4 && input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0),
"3D or 4D (batch mode) expected for input, but got input of size ",
(ndim == 3 && input.size(0) && valid_dims) ||
(ndim == 4 && valid_dims && input.size(3) != 0),
"Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
input.sizes());

int64_t dim_batch = 0;
Expand Down
3 changes: 2 additions & 1 deletion test/test_nn.py
Expand Up @@ -10175,12 +10175,13 @@ def test_ReflectionPad_empty(self, device):
inp = torch.randn(3, 0, 10, 10, device=device)
mod(inp)

@onlyOnCPUAndCUDA
def test_Unfold_empty(self, device):
inp = torch.randn(0, 3, 3, 4, device=device)
unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
self._test_module_empty_input(unfold, inp, check_size=False)

with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
inp = torch.randn(3, 0, 3, 4, device=device)
unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
unfold(inp)
Expand Down

0 comments on commit 81680d8

Please sign in to comment.