diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 644d75c04c06..518e74b95d54 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -553,6 +553,7 @@ _(aten, permute) \
 _(aten, pin_memory) \
 _(aten, pinverse) \
 _(aten, pixel_shuffle) \
+_(aten, pixel_unshuffle) \
 _(aten, poisson) \
 _(aten, polygamma) \
 _(aten, pow) \
diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp
index e6301e682d77..20214470ba28 100644
--- a/aten/src/ATen/native/PixelShuffle.cpp
+++ b/aten/src/ATen/native/PixelShuffle.cpp
@@ -14,12 +14,16 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
   TORCH_CHECK(self.dim() >= 3,
               "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
               self.dim(), " dimension(s)");
+  TORCH_CHECK(
+      upscale_factor > 0,
+      "pixel_shuffle expects a positive upscale_factor, but got ",
+      upscale_factor);
   // Format: (B1, ..., Bn), C, H, W
   int64_t c = self.size(-3);
   int64_t h = self.size(-2);
   int64_t w = self.size(-1);
   const auto NUM_NON_BATCH_DIMS = 3;
-  const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS;
+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
 
   int64_t upscale_factor_squared = upscale_factor * upscale_factor;
   TORCH_CHECK(c % upscale_factor_squared == 0,
@@ -29,24 +33,81 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
   int64_t oh = h * upscale_factor;
   int64_t ow = w * upscale_factor;
 
-  // First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor).
-  // This allows shuffling to be done next by permuting dims.
-  std::vector<int64_t> expanded_shape(self.sizes().begin(), last_batch_dim);
-  expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
-  const auto input_expanded = self.reshape(expanded_shape);
+  // First, reshape to split the channels dim from c into 3 separate dims: (oc,
+  // upscale_factor, upscale_factor). This allows shuffling to be done next by
+  // permuting dims.
+  std::vector<int64_t> added_dims_shape(
+      self.sizes().begin(), self_sizes_batch_end);
+  added_dims_shape.insert(
+      added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
+  const auto input_reshaped = self.reshape(added_dims_shape);
 
   // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
-  std::vector<int64_t> permutation(self.sizes().begin(), last_batch_dim);
+  std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
   // std::iota is used to maintain the batch dims within the permutation.
-  // Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6.
-  std::iota(permutation.begin(), permutation.end(), -expanded_shape.size());
+  // Since 2 dims were added, the correct batch dim offsets are now:
+  // -added_dims_shape.size(), ..., -7, -6.
+  std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
   permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
                                          -3 /* 2nd upscale_factor */});
-  const auto input_permuted = input_expanded.permute(permutation);
+  const auto input_permuted = input_reshaped.permute(permutation);
 
   // Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
   // and (w, upscale_factor) -> a single dim (ow).
-  std::vector<int64_t> final_shape(self.sizes().begin(), last_batch_dim);
+  std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
+  final_shape.insert(final_shape.end(), {oc, oh, ow});
+  return input_permuted.reshape(final_shape);
+}
+
+
+Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
+  TORCH_CHECK(self.dim() >= 3,
+              "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
+              self.dim(), " dimension(s)");
+  TORCH_CHECK(
+      downscale_factor > 0,
+      "pixel_unshuffle expects a positive downscale_factor, but got ",
+      downscale_factor);
+  // Format: (B1, ..., Bn), C, H, W
+  int64_t c = self.size(-3);
+  int64_t h = self.size(-2);
+  int64_t w = self.size(-1);
+  constexpr auto NUM_NON_BATCH_DIMS = 3;
+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
+
+  TORCH_CHECK(h % downscale_factor == 0,
+              "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
+              " is not divisible by ", downscale_factor)
+  TORCH_CHECK(w % downscale_factor == 0,
+              "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
+              " is not divisible by ", downscale_factor)
+  int64_t downscale_factor_squared = downscale_factor * downscale_factor;
+  int64_t oc = c * downscale_factor_squared;
+  int64_t oh = h / downscale_factor;
+  int64_t ow = w / downscale_factor;
+
+  // First, reshape to split height dim into (oh, downscale_factor) dims and
+  // width dim into (ow, downscale_factor) dims. This allows unshuffling to be
+  // done next by permuting dims.
+  std::vector<int64_t> added_dims_shape(
+      self.sizes().begin(), self_sizes_batch_end);
+  added_dims_shape.insert(
+      added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
+  const auto input_reshaped = self.reshape(added_dims_shape);
+
+  // Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
+  std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
+  // std::iota is used to maintain the batch dims within the permutation.
+  // Since 2 dims were added, the correct batch dim offsets are now:
+  // -added_dims_shape.size(), ..., -7, -6.
+  std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
+  permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /* 2nd downscale_factor */,
+                                         -4 /* oh */, -2 /* ow */});
+  const auto input_permuted = input_reshaped.permute(permutation);
+
+  // Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
+  // resulting in height=oh and width=ow.
+  std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
   final_shape.insert(final_shape.end(), {oc, oh, ow});
   return input_permuted.reshape(final_shape);
 }
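For readers of the ATen kernel above, here is a minimal Python sketch of the same reshape -> permute -> reshape sequence, restricted to a plain 4-D (N, C, H, W) input. The helper name `pixel_unshuffle_reference` is illustrative and not part of this patch; it assumes the operator is exposed as `torch.pixel_unshuffle`, as registered in native_functions.yaml below.

```python
import torch

def pixel_unshuffle_reference(x, r):
    # Mirrors the ATen implementation for a 4-D input (no extra batch dims).
    n, c, h, w = x.shape
    oc, oh, ow = c * r * r, h // r, w // r
    x = x.reshape(n, c, oh, r, ow, r)   # split H -> (oh, r) and W -> (ow, r)
    x = x.permute(0, 1, 3, 5, 2, 4)     # move both r dims next to the channel dim
    return x.reshape(n, oc, oh, ow)     # collapse (c, r, r) -> oc

x = torch.randn(1, 1, 12, 12)
assert torch.allclose(pixel_unshuffle_reference(x, 3), torch.pixel_unshuffle(x, 3))
```

`pixel_shuffle` follows the mirrored pattern: it splits the channel dim into (oc, r, r) and moves the two r dims next to H and W instead.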
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9c0053f40b7e..48692e792ae3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3342,6 +3342,9 @@
 - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
   use_c10_dispatcher: full
 
+- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
+  use_c10_dispatcher: full
+
 - func: channel_shuffle(Tensor self, int groups) -> Tensor
   use_c10_dispatcher: full
   dispatch:
diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst
index 416121cec8d6..17b0e0a80b36 100644
--- a/docs/source/nn.functional.rst
+++ b/docs/source/nn.functional.rst
@@ -496,6 +496,11 @@ Vision functions
 
 .. autofunction:: pixel_shuffle
 
+:hidden:`pixel_unshuffle`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: pixel_unshuffle
+
 :hidden:`pad`
 ~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 4e3e8437b88b..74f7994447a1 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -299,6 +299,7 @@ Vision Layers
     :template: classtemplate.rst
 
     nn.PixelShuffle
+    nn.PixelUnshuffle
     nn.Upsample
     nn.UpsamplingNearest2d
     nn.UpsamplingBilinear2d
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index 707c1bfd7ac0..d4f353f5607f 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1487,6 +1487,23 @@ TEST_F(FunctionalTest, PixelShuffle) {
   ASSERT_TRUE(y.allclose(y_exp));
 }
 
+TEST_F(FunctionalTest, PixelUnshuffle) {
+  auto x = torch::tensor(
+      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
+      torch::kFloat);
+  auto y_exp = torch::tensor(
+      {{{{-17, 19}, {-1, 2}},
+        {{7, 14}, {-3, 1}},
+        {{0, -2}, {-12, 14}},
+        {{-15, 0}, {-3, 9}}}},
+      torch::kFloat);
+  auto y = F::pixel_unshuffle(x, 2);
+
+  ASSERT_EQ(y.ndimension(), 4);
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
+  ASSERT_TRUE(y.allclose(y_exp));
+}
+
 TEST_F(FunctionalTest, Softplus) {
   const auto size = 3;
   for (const auto beta : {0.5, 1.0, 2.0}) {
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 14ed92f9fb0d..f24f8b42a19b 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2761,6 +2761,24 @@ TEST_F(ModulesTest, PixelShuffle) {
   ASSERT_TRUE(y.allclose(y_exp));
 }
 
+TEST_F(ModulesTest, PixelUnshuffle) {
+  PixelUnshuffle module(/*downscale_factor=*/2);
+  auto x = torch::tensor(
+      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
+      torch::kFloat);
+  auto y_exp = torch::tensor(
+      {{{{-17, 19}, {-1, 2}},
+        {{7, 14}, {-3, 1}},
+        {{0, -2}, {-12, 14}},
+        {{-15, 0}, {-3, 9}}}},
+      torch::kFloat);
+  auto y = module(x);
+
+  ASSERT_EQ(y.ndimension(), 4);
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
+  ASSERT_TRUE(y.allclose(y_exp));
+}
+
 TEST_F(ModulesTest, Softplus) {
   const auto size = 3;
   for (const auto beta : {0.5, 1.0, 2.0}) {
@@ -4764,6 +4782,12 @@ TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
       "torch::nn::PixelShuffle(upscale_factor=5)");
 }
 
+TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
+  ASSERT_EQ(
+      c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
+      "torch::nn::PixelUnshuffle(downscale_factor=5)");
+}
+
 TEST_F(ModulesTest, PrettyPrintSoftplus) {
   ASSERT_EQ(c10::str(Softplus()),
             "torch::nn::Softplus(beta=1, threshold=20)");
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 66931b6f9316..55d3f33f32b2 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -125,6 +125,7 @@ torch::nn::CosineEmbeddingLoss|Yes|No
 torch::nn::MultiMarginLoss|Yes|No
 torch::nn::TripletMarginLoss|Yes|No
 torch::nn::PixelShuffle|Yes|No
+torch::nn::PixelUnshuffle|Yes|No
 torch::nn::Upsample|Yes|No
 torch::nn::DataParallel|No|No
 torch::nn::parallel::DistributedDataParallel|No|No
diff --git a/test/test_nn.py b/test/test_nn.py
index 78aab89611b6..1d63be6e3075 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6897,8 +6897,9 @@ def test_noncontig_conv_grad_cuda(self, dtype=torch.float):
         output.backward(grad.contiguous())
         self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)
 
-    def test_pixel_shuffle(self):
-        def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True):
+    def test_pixel_shuffle_unshuffle(self):
+        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
+                                                 upscale_factor=None):
             # Function to imperatively ensure pixels are shuffled to the correct locations.
             # Used to validate the batch operations in pixel_shuffle.
             def _verify_pixel_shuffle(input, output, upscale_factor):
@@ -6911,7 +6912,7 @@ def _verify_pixel_shuffle(input, output, upscale_factor):
                                           (c * upscale_factor ** 2)
                             self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
 
-            upscale_factor = random.randint(2, 5)
+            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
             # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
             channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
             height = random.randint(5, 10)
@@ -6925,47 +6926,76 @@ def _verify_pixel_shuffle(input, output, upscale_factor):
                 batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                 input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
             ps = nn.PixelShuffle(upscale_factor)
+            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
 
-            if num_input_dims >= 3 and valid_channels_dim:
+            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
                 output = ps(input)
                 _verify_pixel_shuffle(input, output, upscale_factor)
                 output.backward(output.data)
                 self.assertEqual(input.data, input.grad.data)
+
+                # Ensure unshuffle properly inverts shuffle.
+                unshuffle_output = pus(output)
+                self.assertEqual(input, unshuffle_output)
             else:
                 self.assertRaises(RuntimeError, lambda: ps(input))
 
-        def test_pixel_shuffle_1D():
-            _test_pixel_shuffle_helper(num_input_dims=1)
+        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
+                                                    downscale_factor=None):
+            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
+            channels = random.randint(1, 4)
+            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
+            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
+            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
+            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
+
+            if num_input_dims == 1:
+                input = torch.rand(channels, requires_grad=True)
+            elif num_input_dims == 2:
+                input = torch.rand(height, width, requires_grad=True)
+            else:
+                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
+
+            pus = nn.PixelUnshuffle(downscale_factor)
+            self.assertRaises(RuntimeError, lambda: pus(input))
+
+        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
+            # For 1D - 2D, this is an error case.
+            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)
 
-        def test_pixel_shuffle_2D():
-            _test_pixel_shuffle_helper(num_input_dims=2)
+            # Error cases for pixel_shuffle.
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)
 
-        def test_pixel_shuffle_3D_with_valid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=3)
+            # Error cases for pixel_unshuffle.
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
 
-        def test_pixel_shuffle_4D_with_valid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=4)
+        def test_pixel_shuffle_unshuffle_1D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
 
-        def test_pixel_shuffle_5D_with_valid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=5)
+        def test_pixel_shuffle_unshuffle_2D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
 
-        def test_pixel_shuffle_3D_with_invalid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False)
+        def test_pixel_shuffle_unshuffle_3D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
 
-        def test_pixel_shuffle_4D_with_invalid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False)
+        def test_pixel_shuffle_unshuffle_4D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
 
-        def test_pixel_shuffle_5D_with_invalid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False)
+        def test_pixel_shuffle_unshuffle_5D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
 
-        test_pixel_shuffle_1D()
-        test_pixel_shuffle_2D()
-        test_pixel_shuffle_3D_with_valid_channels_dim()
-        test_pixel_shuffle_4D_with_valid_channels_dim()
-        test_pixel_shuffle_5D_with_valid_channels_dim()
-        test_pixel_shuffle_3D_with_invalid_channels_dim()
-        test_pixel_shuffle_4D_with_invalid_channels_dim()
-        test_pixel_shuffle_5D_with_invalid_channels_dim()
+        test_pixel_shuffle_unshuffle_1D()
+        test_pixel_shuffle_unshuffle_2D()
+        test_pixel_shuffle_unshuffle_3D()
+        test_pixel_shuffle_unshuffle_4D()
+        test_pixel_shuffle_unshuffle_5D()
 
     def test_elu_inplace_view(self):
         v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index d2073bec9a27..7ad514d5d067 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -210,6 +210,7 @@ def gen_nn_functional(out: str) -> None:
         'celu_',
         'rrelu_',
         'pixel_shuffle',
+        'pixel_unshuffle',
         'channel_shuffle',
         'pdist',
         'cosine_similarity',
diff --git a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h
index 7ea98bf07d99..32161d04d806 100644
--- a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h
+++ b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h
@@ -16,6 +16,10 @@ inline Tensor pixel_shuffle(
     upscale_factor
   );
 }
+
+inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) {
+  return torch::pixel_unshuffle(input, downscale_factor);
+}
 } // namespace detail
 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
 
@@ -36,6 +40,12 @@ inline Tensor pixel_shuffle(
   return detail::pixel_shuffle(input, options.upscale_factor());
 }
 
+inline Tensor pixel_unshuffle(
+    const Tensor& input,
+    const PixelUnshuffleFuncOptions& options) {
+  return detail::pixel_unshuffle(input, options.downscale_factor());
+}
+
 } // namespace functional
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h
index 98d4be45e04a..08278ea2162e 100644
--- a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h
+++ b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h
@@ -12,12 +12,13 @@ namespace nn {
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 /// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
-/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn
-/// about the exact behavior of this module.
+/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an
+/// upscale factor. See
+/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn about
+/// the exact behavior of this module.
 ///
-/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn what
-/// constructor arguments are supported for this module.
+/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn
+/// what constructor arguments are supported for this module.
 ///
 /// Example:
 /// ```
 ///
@@ -44,5 +45,42 @@ struct TORCH_API PixelShuffleImpl : public torch::nn::Cloneable<PixelShuffleImpl> {
 /// to learn about PyTorch's module storage semantics.
 TORCH_MODULE(PixelShuffle);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelUnshuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Reverses the PixelShuffle operation by rearranging elements in a tensor of
+/// shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
+/// :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. See
+/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelUnshuffle to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
+/// what constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// PixelUnshuffle model(PixelUnshuffleOptions(1));
+/// ```
+struct TORCH_API PixelUnshuffleImpl
+    : public torch::nn::Cloneable<PixelUnshuffleImpl> {
+  explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_);
+
+  /// Pretty prints the `PixelUnshuffle` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
+  Tensor forward(const Tensor& input);
+
+  void reset() override;
+
+  /// The options with which this `Module` was constructed.
+  PixelUnshuffleOptions options;
+};
+
+/// A `ModuleHolder` subclass for `PixelUnshuffleImpl`.
+/// See the documentation for `PixelUnshuffleImpl` class to learn what methods
+/// it provides, and examples of how to use `PixelUnshuffle` with
+/// `torch::nn::PixelUnshuffleOptions`. See the documentation for `ModuleHolder`
+/// to learn about PyTorch's module storage semantics.
+TORCH_MODULE(PixelUnshuffle);
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h
index e72e6931e49a..e28e0053e98b 100644
--- a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h
+++ b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h
@@ -21,6 +21,20 @@ struct TORCH_API PixelShuffleOptions {
   TORCH_ARG(int64_t, upscale_factor);
 };
 
+/// Options for the `PixelUnshuffle` module.
+///
+/// Example:
+/// ```
+/// PixelUnshuffle model(PixelUnshuffleOptions(5));
+/// ```
+struct TORCH_API PixelUnshuffleOptions {
+  /* implicit */ PixelUnshuffleOptions(int64_t downscale_factor)
+      : downscale_factor_(downscale_factor) {}
+
+  /// Factor to decrease spatial resolution by
+  TORCH_ARG(int64_t, downscale_factor);
+};
+
 namespace functional {
 /// Options for `torch::nn::functional::pixel_shuffle`.
 ///
@@ -33,6 +47,18 @@ namespace functional {
 /// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2));
 /// ```
 using PixelShuffleFuncOptions = PixelShuffleOptions;
+
+/// Options for `torch::nn::functional::pixel_unshuffle`.
+///
+/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
+/// what arguments are supported.
+///
+/// Example:
+/// ```
+/// namespace F = torch::nn::functional;
+/// F::pixel_unshuffle(x, F::PixelUnshuffleFuncOptions(2));
+/// ```
+using PixelUnshuffleFuncOptions = PixelUnshuffleOptions;
 } // namespace functional
 } // namespace nn
diff --git a/torch/csrc/api/src/nn/modules/pixelshuffle.cpp b/torch/csrc/api/src/nn/modules/pixelshuffle.cpp
index dd2d34655979..7062b07fe5d7 100644
--- a/torch/csrc/api/src/nn/modules/pixelshuffle.cpp
+++ b/torch/csrc/api/src/nn/modules/pixelshuffle.cpp
@@ -21,5 +21,19 @@ Tensor PixelShuffleImpl::forward(
   return F::detail::pixel_shuffle(input, options.upscale_factor());
 }
 
+PixelUnshuffleImpl::PixelUnshuffleImpl(const PixelUnshuffleOptions& options_)
+    : options(options_) {}
+
+void PixelUnshuffleImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::PixelUnshuffle(downscale_factor="
+         << options.downscale_factor() << ")";
+}
+
+void PixelUnshuffleImpl::reset() {}
+
+Tensor PixelUnshuffleImpl::forward(const Tensor& input) {
+  return F::detail::pixel_unshuffle(input, options.downscale_factor());
+}
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 24bfecb49ed5..2563d4b0ba29 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2799,7 +2799,7 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N
 pixel_shuffle(input, upscale_factor) -> Tensor
 
 Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
-tensor of shape :math:`(*, C, H \times r, W \times r)`.
+tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.
 
 See :class:`~torch.nn.PixelShuffle` for details.
 
@@ -2815,6 +2815,27 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N
     torch.Size([1, 1, 12, 12])
 """)
 
+pixel_unshuffle = _add_docstr(torch.pixel_unshuffle, r"""
+pixel_unshuffle(input, downscale_factor) -> Tensor
+
+Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
+tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
+:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.
+
+See :class:`~torch.nn.PixelUnshuffle` for details.
+
+Args:
+    input (Tensor): the input tensor
+    downscale_factor (int): factor to decrease spatial resolution by
+
+Examples::
+
+    >>> input = torch.randn(1, 1, 12, 12)
+    >>> output = torch.nn.functional.pixel_unshuffle(input, 3)
+    >>> print(output.size())
+    torch.Size([1, 9, 4, 4])
+""")
+
 channel_shuffle = _add_docstr(torch.channel_shuffle, r"""
 channel_shuffle(input, groups) -> Tensor
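A quick way to see the channel ordering implied by the docstring above and by the expected values in the C++ tests: within each input channel, the r x r block offsets are laid out row-major across the output channels. The snippet below is illustrative only and not part of the patch; it assumes the Python functional API added here.

```python
import torch
import torch.nn.functional as F

r = 2
x = torch.arange(16.).reshape(1, 1, 4, 4)
y = F.pixel_unshuffle(x, r)
# Output channel c * r**2 + i * r + j holds the pixels at offset (i, j)
# inside each r x r block of input channel c.
for i in range(r):
    for j in range(r):
        assert torch.equal(y[0, i * r + j], x[0, 0, i::r, j::r])
```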
math:: + W_{out} = W_{in} \times \text{upscale\_factor} Examples:: @@ -50,3 +54,53 @@ def forward(self, input: Tensor) -> Tensor: def extra_repr(self) -> str: return 'upscale_factor={}'.format(self.upscale_factor) + + +class PixelUnshuffle(Module): + r"""Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et. al (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + Examples:: + + >>> pixel_unshuffle = nn.PixelUnshuffle(3) + >>> input = torch.randn(1, 1, 12, 12) + >>> output = pixel_unshuffle(input) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + __constants__ = ['downscale_factor'] + downscale_factor: int + + def __init__(self, downscale_factor: int) -> None: + super(PixelUnshuffle, self).__init__() + self.downscale_factor = downscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: + return 'downscale_factor={}'.format(self.downscale_factor) diff --git a/torch/overrides.py b/torch/overrides.py index d23e34831bdd..6c193b273344 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -706,6 +706,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.pdist: lambda input, p=2: -1, torch.pinverse: lambda input, rcond=1e-15: -1, torch.pixel_shuffle: lambda input, upscale_factor: -1, + torch.pixel_unshuffle: lambda input, downscale_factor: -1, torch.poisson: lambda input, generator=None: -1, torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1, torch.polygamma: lambda input, n, out=None: -1, diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 8b9a5072e50f..c588f69c2875 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -2516,6 +2516,12 @@ def fractional_max_pool3d_test(test_case): cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', input_size=(1, 9, 4, 4), ), + dict( + module_name='PixelUnshuffle', + constructor_args=(3,), + cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', + input_size=(1, 1, 12, 12), + ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 4a91394d53c5..2acc380579e5 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -140,6 +140,7 @@ ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), ('cosine_embedding_loss', (S, S), ((S, S), 
non_differentiable(torch.rand(S,))),), ('pixel_shuffle', (1, 9, 4, 4), (3,),), + ('pixel_unshuffle', (1, 1, 12, 12), (3,),), ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),), ('pad', (3, 3, 4, 2), ([1, 1],),), ('pairwise_distance', (S, S), ((S, S),),),
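Finally, a rough usage sketch of the batch-dimension handling exercised by the new tests (3-D through 5-D inputs) and of the shuffle/unshuffle round trip, assuming the Python API added by this patch; the tensor shapes are arbitrary examples.

```python
import torch
import torch.nn.functional as F

# Any number of leading batch dims is accepted, mirroring the 3D-5D test coverage.
x3d = torch.randn(4, 6, 6)            # (C, H, W)
x5d = torch.randn(2, 3, 4, 6, 6)      # (N1, N2, C, H, W)
print(F.pixel_unshuffle(x3d, 3).shape)   # torch.Size([36, 2, 2])
print(F.pixel_unshuffle(x5d, 3).shape)   # torch.Size([2, 3, 36, 2, 2])

# Round trip with the module API: PixelShuffle inverts PixelUnshuffle.
pus, ps = torch.nn.PixelUnshuffle(3), torch.nn.PixelShuffle(3)
assert torch.allclose(ps(pus(x5d)), x5d)
```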