Skip to content

Commit

Permalink
Add PixelUnshuffle (#49334)
Browse files Browse the repository at this point in the history
Summary:
Adds an implementation of `torch.nn.PixelUnshuffle` as the inverse operation of `torch.nn.PixelShuffle`. This addresses #2456

Pull Request resolved: #49334

Test Plan:
```
# Unit tests.
python test/test_nn.py TestNN.test_pixel_shuffle_unshuffle

# Module test.
python test/test_nn.py TestNN.test_PixelUnshuffle

# C++ API tests.
build/bin/test_api

# C++ / python parity tests.
python test/test_cpp_api_parity.py

# JIT test.
python test/test_jit.py TestJitGeneratedFunctional.test_nn_pixel_unshuffle

# Override tests.
python test/test_overrides.py

# Type hint tests.
python test/test_type_hints.py
```

Screenshots of rendered docs:
<img width="876" alt="Screen Shot 2020-12-18 at 12 19 05 PM" src="https://user-images.githubusercontent.com/75754324/102642255-6b07bb00-412b-11eb-88fa-e53e7e8ba720.png">
<img width="984" alt="Screen Shot 2020-12-18 at 12 19 26 PM" src="https://user-images.githubusercontent.com/75754324/102642276-70fd9c00-412b-11eb-8548-445082a2db02.png">
<img width="932" alt="Screen Shot 2020-12-18 at 12 19 34 PM" src="https://user-images.githubusercontent.com/75754324/102642704-19abfb80-412c-11eb-9546-95bdd1c3cf22.png">
<img width="876" alt="Screen Shot 2020-12-22 at 12 51 36 PM" src="https://user-images.githubusercontent.com/75754324/102918259-986aa680-4454-11eb-99e7-a0b4c8b3e283.png">
<img width="869" alt="Screen Shot 2020-12-22 at 12 51 44 PM" src="https://user-images.githubusercontent.com/75754324/102918274-9ef91e00-4454-11eb-94bb-91b58aff47d3.png">

Reviewed By: mruberry

Differential Revision: D25401439

Pulled By: jbschlosser

fbshipit-source-id: 209d92ce7295e51699e83616d0c62170a7ce75c8
  • Loading branch information
jbschlosser authored and facebook-github-bot committed Dec 23, 2020
1 parent 461aafe commit 68d438c
Show file tree
Hide file tree
Showing 20 changed files with 371 additions and 56 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
Expand Up @@ -553,6 +553,7 @@ _(aten, permute) \
_(aten, pin_memory) \
_(aten, pinverse) \
_(aten, pixel_shuffle) \
_(aten, pixel_unshuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
Expand Down
83 changes: 72 additions & 11 deletions aten/src/ATen/native/PixelShuffle.cpp
Expand Up @@ -14,12 +14,16 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
TORCH_CHECK(self.dim() >= 3,
"pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(), " dimension(s)");
TORCH_CHECK(
upscale_factor > 0,
"pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
const auto NUM_NON_BATCH_DIMS = 3;
const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS;
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;

int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0,
Expand All @@ -29,24 +33,81 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
int64_t oh = h * upscale_factor;
int64_t ow = w * upscale_factor;

// First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor).
// This allows shuffling to be done next by permuting dims.
std::vector<int64_t> expanded_shape(self.sizes().begin(), last_batch_dim);
expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
const auto input_expanded = self.reshape(expanded_shape);
// First, reshape to split the channels dim from c into 3 separate dims: (oc,
// upscale_factor, upscale_factor). This allows shuffling to be done next by
// permuting dims.
std::vector<int64_t> added_dims_shape(
self.sizes().begin(), self_sizes_batch_end);
added_dims_shape.insert(
added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
const auto input_reshaped = self.reshape(added_dims_shape);

// Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
std::vector<int64_t> permutation(self.sizes().begin(), last_batch_dim);
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
// std::iota is used to maintain the batch dims within the permutation.
// Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6.
std::iota(permutation.begin(), permutation.end(), -expanded_shape.size());
// Since 2 dims were added, the correct batch dim offsets are now:
// -added_dims_shape.size(), ..., -7, -6.
std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
-3 /* 2nd upscale_factor */});
const auto input_permuted = input_expanded.permute(permutation);
const auto input_permuted = input_reshaped.permute(permutation);

// Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
// and (w, upscale_factor) -> a single dim (ow).
std::vector<int64_t> final_shape(self.sizes().begin(), last_batch_dim);
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
final_shape.insert(final_shape.end(), {oc, oh, ow});
return input_permuted.reshape(final_shape);
}


Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
TORCH_CHECK(self.dim() >= 3,
"pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(), " dimension(s)");
TORCH_CHECK(
downscale_factor > 0,
"pixel_unshuffle expects a positive downscale_factor, but got ",
downscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
constexpr auto NUM_NON_BATCH_DIMS = 3;
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;

TORCH_CHECK(h % downscale_factor == 0,
"pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
" is not divisible by ", downscale_factor)
TORCH_CHECK(w % downscale_factor == 0,
"pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
" is not divisible by ", downscale_factor)
int64_t downscale_factor_squared = downscale_factor * downscale_factor;
int64_t oc = c * downscale_factor_squared;
int64_t oh = h / downscale_factor;
int64_t ow = w / downscale_factor;

// First, reshape to split height dim into (oh, downscale_factor) dims and
// width dim into (ow, downscale_factor) dims. This allows unshuffling to be
// done next by permuting dims.
std::vector<int64_t> added_dims_shape(
self.sizes().begin(), self_sizes_batch_end);
added_dims_shape.insert(
added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
const auto input_reshaped = self.reshape(added_dims_shape);

// Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
// std::iota is used to maintain the batch dims within the permutation.
// Since 2 dims were added, the correct batch dim offsets are now:
// -added_dims_shape.size(), ..., -7, -6.
std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /*2nd downscale_factor */,
-4 /* oh */, -2 /* ow */});
const auto input_permuted = input_reshaped.permute(permutation);

// Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
// resulting in height=oh and width=ow.
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
final_shape.insert(final_shape.end(), {oc, oh, ow});
return input_permuted.reshape(final_shape);
}
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -3342,6 +3342,9 @@
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
use_c10_dispatcher: full

- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
use_c10_dispatcher: full

- func: channel_shuffle(Tensor self, int groups) -> Tensor
use_c10_dispatcher: full
dispatch:
Expand Down
5 changes: 5 additions & 0 deletions docs/source/nn.functional.rst
Expand Up @@ -496,6 +496,11 @@ Vision functions

.. autofunction:: pixel_shuffle

:hidden:`pixel_unshuffle`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pixel_unshuffle

:hidden:`pad`
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions docs/source/nn.rst
Expand Up @@ -299,6 +299,7 @@ Vision Layers
:template: classtemplate.rst

nn.PixelShuffle
nn.PixelUnshuffle
nn.Upsample
nn.UpsamplingNearest2d
nn.UpsamplingBilinear2d
Expand Down
17 changes: 17 additions & 0 deletions test/cpp/api/functional.cpp
Expand Up @@ -1487,6 +1487,23 @@ TEST_F(FunctionalTest, PixelShuffle) {
ASSERT_TRUE(y.allclose(y_exp));
}

TEST_F(FunctionalTest, PixelUnshuffle) {
auto x = torch::tensor(
{{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
torch::kFloat);
auto y_exp = torch::tensor(
{{{{-17, 19}, {-1, 2}},
{{7, 14}, {-3, 1}},
{{0, -2}, {-12, 14}},
{{-15, 0}, {-3, 9}}}},
torch::kFloat);
auto y = F::pixel_unshuffle(x, 2);

ASSERT_EQ(y.ndimension(), 4);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
ASSERT_TRUE(y.allclose(y_exp));
}

TEST_F(FunctionalTest, Softplus) {
const auto size = 3;
for (const auto beta : {0.5, 1.0, 2.0}) {
Expand Down
24 changes: 24 additions & 0 deletions test/cpp/api/modules.cpp
Expand Up @@ -2761,6 +2761,24 @@ TEST_F(ModulesTest, PixelShuffle) {
ASSERT_TRUE(y.allclose(y_exp));
}

TEST_F(ModulesTest, PixelUnshuffle) {
PixelUnshuffle module(/*downscale_factor=*/2);
auto x = torch::tensor(
{{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
torch::kFloat);
auto y_exp = torch::tensor(
{{{{-17, 19}, {-1, 2}},
{{7, 14}, {-3, 1}},
{{0, -2}, {-12, 14}},
{{-15, 0}, {-3, 9}}}},
torch::kFloat);
auto y = module(x);

ASSERT_EQ(y.ndimension(), 4);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
ASSERT_TRUE(y.allclose(y_exp));
}

TEST_F(ModulesTest, Softplus) {
const auto size = 3;
for (const auto beta : {0.5, 1.0, 2.0}) {
Expand Down Expand Up @@ -4764,6 +4782,12 @@ TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
"torch::nn::PixelShuffle(upscale_factor=5)");
}

TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
ASSERT_EQ(
c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
"torch::nn::PixelUnshuffle(downscale_factor=5)");
}

TEST_F(ModulesTest, PrettyPrintSoftplus) {
ASSERT_EQ(c10::str(Softplus()),
"torch::nn::Softplus(beta=1, threshold=20)");
Expand Down
1 change: 1 addition & 0 deletions test/cpp_api_parity/parity-tracker.md
Expand Up @@ -125,6 +125,7 @@ torch::nn::CosineEmbeddingLoss|Yes|No
torch::nn::MultiMarginLoss|Yes|No
torch::nn::TripletMarginLoss|Yes|No
torch::nn::PixelShuffle|Yes|No
torch::nn::PixelUnshuffle|Yes|No
torch::nn::Upsample|Yes|No
torch::nn::DataParallel|No|No
torch::nn::parallel::DistributedDataParallel|No|No
Expand Down
86 changes: 58 additions & 28 deletions test/test_nn.py
Expand Up @@ -6897,8 +6897,9 @@ def test_noncontig_conv_grad_cuda(self, dtype=torch.float):
output.backward(grad.contiguous())
self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)

def test_pixel_shuffle(self):
def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True):
def test_pixel_shuffle_unshuffle(self):
def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
upscale_factor=None):
# Function to imperatively ensure pixels are shuffled to the correct locations.
# Used to validate the batch operations in pixel_shuffle.
def _verify_pixel_shuffle(input, output, upscale_factor):
Expand All @@ -6911,7 +6912,7 @@ def _verify_pixel_shuffle(input, output, upscale_factor):
(c * upscale_factor ** 2)
self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])

upscale_factor = random.randint(2, 5)
upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
# If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
height = random.randint(5, 10)
Expand All @@ -6925,47 +6926,76 @@ def _verify_pixel_shuffle(input, output, upscale_factor):
batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
ps = nn.PixelShuffle(upscale_factor)
pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

if num_input_dims >= 3 and valid_channels_dim:
if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
output = ps(input)
_verify_pixel_shuffle(input, output, upscale_factor)
output.backward(output.data)
self.assertEqual(input.data, input.grad.data)

# Ensure unshuffle properly inverts shuffle.
unshuffle_output = pus(output)
self.assertEqual(input, unshuffle_output)
else:
self.assertRaises(RuntimeError, lambda: ps(input))

def test_pixel_shuffle_1D():
_test_pixel_shuffle_helper(num_input_dims=1)
def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
downscale_factor=None):
downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
channels = random.randint(1, 4)
# If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
# If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)

if num_input_dims == 1:
input = torch.rand(channels, requires_grad=True)
elif num_input_dims == 2:
input = torch.rand(height, width, requires_grad=True)
else:
batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)

pus = nn.PixelUnshuffle(downscale_factor)
self.assertRaises(RuntimeError, lambda: pus(input))

def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
# For 1D - 2D, this is an error case.
# For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)

def test_pixel_shuffle_2D():
_test_pixel_shuffle_helper(num_input_dims=2)
# Error cases for pixel_shuffle.
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)

def test_pixel_shuffle_3D_with_valid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=3)
# Error cases for pixel_unshuffle.
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)

def test_pixel_shuffle_4D_with_valid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=4)
def test_pixel_shuffle_unshuffle_1D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)

def test_pixel_shuffle_5D_with_valid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=5)
def test_pixel_shuffle_unshuffle_2D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)

def test_pixel_shuffle_3D_with_invalid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False)
def test_pixel_shuffle_unshuffle_3D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)

def test_pixel_shuffle_4D_with_invalid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False)
def test_pixel_shuffle_unshuffle_4D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)

def test_pixel_shuffle_5D_with_invalid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False)
def test_pixel_shuffle_unshuffle_5D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)

test_pixel_shuffle_1D()
test_pixel_shuffle_2D()
test_pixel_shuffle_3D_with_valid_channels_dim()
test_pixel_shuffle_4D_with_valid_channels_dim()
test_pixel_shuffle_5D_with_valid_channels_dim()
test_pixel_shuffle_3D_with_invalid_channels_dim()
test_pixel_shuffle_4D_with_invalid_channels_dim()
test_pixel_shuffle_5D_with_invalid_channels_dim()
test_pixel_shuffle_unshuffle_1D()
test_pixel_shuffle_unshuffle_2D()
test_pixel_shuffle_unshuffle_3D()
test_pixel_shuffle_unshuffle_4D()
test_pixel_shuffle_unshuffle_5D()

def test_elu_inplace_view(self):
v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
Expand Down
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
Expand Up @@ -210,6 +210,7 @@ def gen_nn_functional(out: str) -> None:
'celu_',
'rrelu_',
'pixel_shuffle',
'pixel_unshuffle',
'channel_shuffle',
'pdist',
'cosine_similarity',
Expand Down
10 changes: 10 additions & 0 deletions torch/csrc/api/include/torch/nn/functional/pixelshuffle.h
Expand Up @@ -16,6 +16,10 @@ inline Tensor pixel_shuffle(
upscale_factor
);
}

inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) {
return torch::pixel_unshuffle(input, downscale_factor);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

Expand All @@ -36,6 +40,12 @@ inline Tensor pixel_shuffle(
return detail::pixel_shuffle(input, options.upscale_factor());
}

inline Tensor pixel_unshuffle(
const Tensor& input,
const PixelUnshuffleFuncOptions& options) {
return detail::pixel_unshuffle(input, options.downscale_factor());
}

} // namespace functional
} // namespace nn
} // namespace torch

0 comments on commit 68d438c

Please sign in to comment.