Skip to content

Commit

Permalink
torch.nn.Unfold accepts 0-dim for batch size (#40689)
Browse files Browse the repository at this point in the history
Summary:
In partial completion of #12013

Allows specifying a tensor with 0-dim batch size for `torch.nn.Unfold()`.

Pull Request resolved: #40689

Reviewed By: zou3519

Differential Revision: D24441164

Pulled By: ngimel

fbshipit-source-id: 49cd53b9b23f2e221aecdb4b5fed19a234038063
  • Loading branch information
v0dro authored and facebook-github-bot committed Oct 22, 2020
1 parent c57c560 commit 982fa07
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Im2Col.cpp
Expand Up @@ -10,7 +10,7 @@
namespace at {
namespace native {
namespace {

static void im2col_out_cpu_template(
Tensor& output,
const Tensor& input_,
Expand Down
13 changes: 9 additions & 4 deletions aten/src/ATen/native/im2col_shape_check.h
Expand Up @@ -37,9 +37,11 @@ static inline void col2im_shape_check(
dilation_width);

int64_t ndim = input.ndimension();
// allow dim=0 only the batch dimension.
TORCH_CHECK(
input.numel() != 0 && (ndim == 2 || ndim == 3),
"Expected non-empty 2D or 3D input tensor, but got input of sizes",
(ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
(ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
"Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
input.sizes());

int64_t batch_dim = (ndim == 3) ? 0 : -1;
Expand Down Expand Up @@ -155,9 +157,12 @@ static inline void im2col_shape_check(

int64_t ndim = input.ndimension();

// allow dim=0 only the batch dimension.
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
TORCH_CHECK(
input.numel() != 0 && (ndim == 3 || ndim == 4),
"Expected non-empty 3D or 4D input tensor, but got input of size ",
(ndim == 3 && input.size(0) && valid_dims) ||
(ndim == 4 && valid_dims && input.size(3) != 0),
"Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
input.sizes());

int64_t dim_batch = 0;
Expand Down
11 changes: 11 additions & 0 deletions test/test_nn.py
Expand Up @@ -10196,6 +10196,17 @@ def test_ReflectionPad_empty(self, device):
inp = torch.randn(3, 0, 10, 10, device=device)
mod(inp)

@onlyOnCPUAndCUDA
def test_Unfold_empty(self, device):
inp = torch.randn(0, 3, 3, 4, device=device)
unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
self._test_module_empty_input(unfold, inp, check_size=False)

with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
inp = torch.randn(3, 0, 3, 4, device=device)
unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
unfold(inp)

@onlyCUDA
@dtypes(torch.float, torch.double)
@tf32_on_and_off(0.005)
Expand Down

0 comments on commit 982fa07

Please sign in to comment.