Add PixelUnshuffle #49334

Closed · wants to merge 10 commits
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h

@@ -551,6 +551,7 @@ _(aten, permute) \
 _(aten, pin_memory) \
 _(aten, pinverse) \
 _(aten, pixel_shuffle) \
+_(aten, pixel_unshuffle) \
 _(aten, poisson) \
 _(aten, polygamma) \
 _(aten, pow) \
83 changes: 72 additions & 11 deletions aten/src/ATen/native/PixelShuffle.cpp

@@ -14,12 +14,16 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
   TORCH_CHECK(self.dim() >= 3,
               "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
               self.dim(), " dimension(s)");
+  TORCH_CHECK(
+      upscale_factor > 0,
+      "pixel_shuffle expects a positive upscale_factor, but got ",
+      upscale_factor);
   // Format: (B1, ..., Bn), C, H, W
   int64_t c = self.size(-3);
   int64_t h = self.size(-2);
   int64_t w = self.size(-1);
   const auto NUM_NON_BATCH_DIMS = 3;
-  const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS;
+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;

   int64_t upscale_factor_squared = upscale_factor * upscale_factor;
   TORCH_CHECK(c % upscale_factor_squared == 0,
@@ -29,24 +33,81 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
   int64_t oh = h * upscale_factor;
   int64_t ow = w * upscale_factor;

-  // First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor).
-  // This allows shuffling to be done next by permuting dims.
-  std::vector<int64_t> expanded_shape(self.sizes().begin(), last_batch_dim);
-  expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
-  const auto input_expanded = self.reshape(expanded_shape);
+  // First, reshape to split the channels dim from c into 3 separate dims: (oc,
+  // upscale_factor, upscale_factor). This allows shuffling to be done next by
+  // permuting dims.
+  std::vector<int64_t> added_dims_shape(
+      self.sizes().begin(), self_sizes_batch_end);
+  added_dims_shape.insert(
+      added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
+  const auto input_reshaped = self.reshape(added_dims_shape);

   // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
-  std::vector<int64_t> permutation(self.sizes().begin(), last_batch_dim);
+  std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
   // std::iota is used to maintain the batch dims within the permutation.
-  // Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6.
-  std::iota(permutation.begin(), permutation.end(), -expanded_shape.size());
+  // Since 2 dims were added, the correct batch dim offsets are now:
+  // -added_dims_shape.size(), ..., -7, -6.
+  std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
   permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
                                          -3 /* 2nd upscale_factor */});
-  const auto input_permuted = input_expanded.permute(permutation);
+  const auto input_permuted = input_reshaped.permute(permutation);

   // Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
   // and (w, upscale_factor) -> a single dim (ow).
-  std::vector<int64_t> final_shape(self.sizes().begin(), last_batch_dim);
+  std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
   final_shape.insert(final_shape.end(), {oc, oh, ow});
   return input_permuted.reshape(final_shape);
 }
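For reference, a minimal usage sketch (not part of this diff) of the shape contract pixel_shuffle implements: channels shrink by upscale_factor^2 while height and width grow by upscale_factor.

#include <torch/torch.h>
#include <iostream>

int main() {
  // (B, C, H, W) = (2, 9, 4, 4); C must be divisible by r^2 for r = 3.
  auto x = torch::rand({2, 9, 4, 4});
  auto y = torch::pixel_shuffle(x, /*upscale_factor=*/3);
  std::cout << y.sizes() << "\n";  // [2, 1, 12, 12], i.e. (B, C/r^2, H*r, W*r)
  return 0;
}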


+Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
+  TORCH_CHECK(self.dim() >= 3,

[Review comment from a Collaborator]
Should there also be a check that downscale factor is > 0?

[Reply from the Contributor Author]
I like this, especially because the checks below of h % downscale_factor == 0 and w % downscale_factor == 0 may pass even with a negative downscale_factor, due to the specific way mod is implemented in C++.

For consistency, what do you think about me adding an upscale_factor check in pixel_shuffle in this PR?
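The point about C++ % can be seen in a standalone sketch (not part of the PR): integer division truncates toward zero, so the remainder for a negative divisor can still be 0, and the divisibility checks alone would not reject a negative factor.

#include <iostream>

int main() {
  int h = 4;
  int downscale_factor = -2;
  // C++ integer division truncates toward zero, so 4 % -2 == 0:
  // a divisibility check alone accepts this invalid negative factor.
  std::cout << (h % downscale_factor) << "\n";       // prints 0
  std::cout << (h % downscale_factor == 0) << "\n";  // prints 1
  return 0;
}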
"pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(), " dimension(s)");
TORCH_CHECK(
downscale_factor > 0,
"pixel_unshuffle expects a positive downscale_factor, but got ",
downscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
const auto NUM_NON_BATCH_DIMS = 3;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: constexpr

+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
+
+  TORCH_CHECK(h % downscale_factor == 0,
+              "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
+              " is not divisible by ", downscale_factor)
+  TORCH_CHECK(w % downscale_factor == 0,
+              "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
+              " is not divisible by ", downscale_factor)
+  int64_t downscale_factor_squared = downscale_factor * downscale_factor;
+  int64_t oc = c * downscale_factor_squared;
+  int64_t oh = h / downscale_factor;
+  int64_t ow = w / downscale_factor;
+
+  // First, reshape to split height dim into (oh, downscale_factor) dims and
+  // width dim into (ow, downscale_factor) dims. This allows unshuffling to be
+  // done next by permuting dims.
+  std::vector<int64_t> added_dims_shape(
+      self.sizes().begin(), self_sizes_batch_end);
+  added_dims_shape.insert(
+      added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
+  const auto input_reshaped = self.reshape(added_dims_shape);
+
+  // Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
+  std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
+  // std::iota is used to maintain the batch dims within the permutation.
+  // Since 2 dims were added, the correct batch dim offsets are now:
+  // -added_dims_shape.size(), ..., -7, -6.
+  std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
+  permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /* 2nd downscale_factor */,
+                                         -4 /* oh */, -2 /* ow */});
+  const auto input_permuted = input_reshaped.permute(permutation);
+
+  // Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
+  // resulting in height=oh and width=ow.
+  std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
+  final_shape.insert(final_shape.end(), {oc, oh, ow});
+  return input_permuted.reshape(final_shape);
+}
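As a round-trip sketch (assuming a build that includes this PR), pixel_unshuffle exactly inverts pixel_shuffle, since both are pure reshape/permute operations over the same elements:

#include <torch/torch.h>
#include <cassert>

int main() {
  auto x = torch::rand({1, 8, 3, 3});  // channels divisible by 2^2
  auto up = torch::pixel_shuffle(x, /*upscale_factor=*/2);         // -> (1, 2, 6, 6)
  auto down = torch::pixel_unshuffle(up, /*downscale_factor=*/2);  // -> (1, 8, 3, 3)
  assert(down.equal(x));  // exact inverse, no numeric tolerance needed
  return 0;
}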
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml

@@ -3261,6 +3261,9 @@
 - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
   use_c10_dispatcher: full

+- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
+  use_c10_dispatcher: full
+
 - func: channel_shuffle(Tensor self, int groups) -> Tensor
   use_c10_dispatcher: full
   dispatch:
5 changes: 5 additions & 0 deletions docs/source/nn.functional.rst

@@ -496,6 +496,11 @@ Vision functions

 .. autofunction:: pixel_shuffle

+:hidden:`pixel_unshuffle`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: pixel_unshuffle
+
 :hidden:`pad`
 ~~~~~~~~~~~~~~~~~~~~~~~
1 change: 1 addition & 0 deletions docs/source/nn.rst

@@ -299,6 +299,7 @@ Vision Layers
     :template: classtemplate.rst

     nn.PixelShuffle
+    nn.PixelUnshuffle
     nn.Upsample
     nn.UpsamplingNearest2d
     nn.UpsamplingBilinear2d
17 changes: 17 additions & 0 deletions test/cpp/api/functional.cpp

@@ -1487,6 +1487,23 @@ TEST_F(FunctionalTest, PixelShuffle) {
   ASSERT_TRUE(y.allclose(y_exp));
 }

+TEST_F(FunctionalTest, PixelUnshuffle) {
+  auto x = torch::tensor(
+      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
+      torch::kFloat);
+  auto y_exp = torch::tensor(
+      {{{{-17, 19}, {-1, 2}},
+        {{7, 14}, {-3, 1}},
+        {{0, -2}, {-12, 14}},
+        {{-15, 0}, {-3, 9}}}},
+      torch::kFloat);
+  auto y = F::pixel_unshuffle(x, 2);
+
+  ASSERT_EQ(y.ndimension(), 4);
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
+  ASSERT_TRUE(y.allclose(y_exp));
+}
+
 TEST_F(FunctionalTest, Softplus) {
   const auto size = 3;
   for (const auto beta : {0.5, 1.0, 2.0}) {
24 changes: 24 additions & 0 deletions test/cpp/api/modules.cpp

@@ -2761,6 +2761,24 @@ TEST_F(ModulesTest, PixelShuffle) {
   ASSERT_TRUE(y.allclose(y_exp));
 }

+TEST_F(ModulesTest, PixelUnshuffle) {
+  PixelUnshuffle module(/*downscale_factor=*/2);
+  auto x = torch::tensor(
+      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
+      torch::kFloat);
+  auto y_exp = torch::tensor(
+      {{{{-17, 19}, {-1, 2}},
+        {{7, 14}, {-3, 1}},
+        {{0, -2}, {-12, 14}},
+        {{-15, 0}, {-3, 9}}}},
+      torch::kFloat);
+  auto y = module(x);
+
+  ASSERT_EQ(y.ndimension(), 4);
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
+  ASSERT_TRUE(y.allclose(y_exp));
+}
+
 TEST_F(ModulesTest, Softplus) {
   const auto size = 3;
   for (const auto beta : {0.5, 1.0, 2.0}) {

@@ -4764,6 +4782,12 @@ TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
       "torch::nn::PixelShuffle(upscale_factor=5)");
 }

+TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
+  ASSERT_EQ(
+      c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
+      "torch::nn::PixelUnshuffle(downscale_factor=5)");
+}
+
 TEST_F(ModulesTest, PrettyPrintSoftplus) {
   ASSERT_EQ(c10::str(Softplus()),
       "torch::nn::Softplus(beta=1, threshold=20)");
1 change: 1 addition & 0 deletions test/cpp_api_parity/parity-tracker.md

@@ -125,6 +125,7 @@ torch::nn::CosineEmbeddingLoss|Yes|No
 torch::nn::MultiMarginLoss|Yes|No
 torch::nn::TripletMarginLoss|Yes|No
 torch::nn::PixelShuffle|Yes|No
+torch::nn::PixelUnshuffle|Yes|No
 torch::nn::Upsample|Yes|No
 torch::nn::DataParallel|No|No
 torch::nn::parallel::DistributedDataParallel|No|No
86 changes: 58 additions & 28 deletions test/test_nn.py

@@ -6897,8 +6897,9 @@ def test_noncontig_conv_grad_cuda(self, dtype=torch.float):
         output.backward(grad.contiguous())
         self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)

-    def test_pixel_shuffle(self):
-        def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True):
+    def test_pixel_shuffle_unshuffle(self):
+        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
+                                                 upscale_factor=None):
             # Function to imperatively ensure pixels are shuffled to the correct locations.
             # Used to validate the batch operations in pixel_shuffle.
             def _verify_pixel_shuffle(input, output, upscale_factor):
@@ -6911,7 +6912,7 @@ def _verify_pixel_shuffle(input, output, upscale_factor):
                             (c * upscale_factor ** 2)
                         self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])

-            upscale_factor = random.randint(2, 5)
+            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
             # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
             channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
             height = random.randint(5, 10)
@@ -6925,47 +6926,76 @@ def _verify_pixel_shuffle(input, output, upscale_factor):
                 batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                 input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
             ps = nn.PixelShuffle(upscale_factor)
+            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

-            if num_input_dims >= 3 and valid_channels_dim:
+            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
                 output = ps(input)
                 _verify_pixel_shuffle(input, output, upscale_factor)
                 output.backward(output.data)
                 self.assertEqual(input.data, input.grad.data)
+
+                # Ensure unshuffle properly inverts shuffle.
+                unshuffle_output = pus(output)
+                self.assertEqual(input, unshuffle_output)
             else:
                 self.assertRaises(RuntimeError, lambda: ps(input))

-        def test_pixel_shuffle_1D():
-            _test_pixel_shuffle_helper(num_input_dims=1)
-
-        def test_pixel_shuffle_2D():
-            _test_pixel_shuffle_helper(num_input_dims=2)
-
-        def test_pixel_shuffle_3D_with_valid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=3)
-
-        def test_pixel_shuffle_4D_with_valid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=4)
-
-        def test_pixel_shuffle_5D_with_valid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=5)
-
-        def test_pixel_shuffle_3D_with_invalid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False)
-
-        def test_pixel_shuffle_4D_with_invalid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False)
-
-        def test_pixel_shuffle_5D_with_invalid_channels_dim():
-            _test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False)
-
-        test_pixel_shuffle_1D()
-        test_pixel_shuffle_2D()
-        test_pixel_shuffle_3D_with_valid_channels_dim()
-        test_pixel_shuffle_4D_with_valid_channels_dim()
-        test_pixel_shuffle_5D_with_valid_channels_dim()
-        test_pixel_shuffle_3D_with_invalid_channels_dim()
-        test_pixel_shuffle_4D_with_invalid_channels_dim()
-        test_pixel_shuffle_5D_with_invalid_channels_dim()
+        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
+                                                    downscale_factor=None):
+            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
+            channels = random.randint(1, 4)
+            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
+            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
+            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
+            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
+
+            if num_input_dims == 1:
+                input = torch.rand(channels, requires_grad=True)
+            elif num_input_dims == 2:
+                input = torch.rand(height, width, requires_grad=True)
+            else:
+                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
+
+            pus = nn.PixelUnshuffle(downscale_factor)
+            self.assertRaises(RuntimeError, lambda: pus(input))
+
+        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
+            # For 1D - 2D, this is an error case.
+            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)
+
+            # Error cases for pixel_shuffle.
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
+            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)
+
+            # Error cases for pixel_unshuffle.
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
+
+        def test_pixel_shuffle_unshuffle_1D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
+
+        def test_pixel_shuffle_unshuffle_2D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
+
+        def test_pixel_shuffle_unshuffle_3D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
+
+        def test_pixel_shuffle_unshuffle_4D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
+
+        def test_pixel_shuffle_unshuffle_5D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
+
+        test_pixel_shuffle_unshuffle_1D()
+        test_pixel_shuffle_unshuffle_2D()
+        test_pixel_shuffle_unshuffle_3D()
+        test_pixel_shuffle_unshuffle_4D()
+        test_pixel_shuffle_unshuffle_5D()

     def test_elu_inplace_view(self):
         v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py

@@ -210,6 +210,7 @@ def gen_nn_functional(out: str) -> None:
         'celu_',
         'rrelu_',
         'pixel_shuffle',
+        'pixel_unshuffle',
         'channel_shuffle',
         'pdist',
         'cosine_similarity',
10 changes: 10 additions & 0 deletions torch/csrc/api/include/torch/nn/functional/pixelshuffle.h

@@ -16,6 +16,10 @@ inline Tensor pixel_shuffle(
       upscale_factor
   );
 }
+
+inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) {
+  return torch::pixel_unshuffle(input, downscale_factor);
+}
 } // namespace detail
 #endif /* DOXYGEN_SHOULD_SKIP_THIS */

@@ -36,6 +40,12 @@ inline Tensor pixel_shuffle(
   return detail::pixel_shuffle(input, options.upscale_factor());
 }

+inline Tensor pixel_unshuffle(
+    const Tensor& input,
+    const PixelUnshuffleFuncOptions& options) {
+  return detail::pixel_unshuffle(input, options.downscale_factor());
+}
+
 } // namespace functional
 } // namespace nn
 } // namespace torch
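A short usage sketch for the functional wrapper above (the F namespace alias and the exact option spelling follow the pattern of the existing pixel_shuffle wrapper; treat them as assumptions):

#include <torch/torch.h>

namespace F = torch::nn::functional;

int main() {
  auto input = torch::rand({1, 1, 4, 4});
  // Options-based overload; PixelUnshuffleFuncOptions carries downscale_factor.
  auto out = F::pixel_unshuffle(input, F::PixelUnshuffleFuncOptions(2));
  // out.sizes() == [1, 4, 2, 2]
  return 0;
}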