Add PixelUnshuffle #49334
Closed
Commits (10 total; this view shows changes from 8 commits)
ae09559 [pytorch] Add PixelUnshuffle (#49334) (jbschlosser)
bd143a4 Addressing PR comments (doc changes + extra checks + renaming) (jbschlosser)
1bdd60c PixelUnshuffle: C++ parity & tests + lint + better py tests (jbschlosser)
e02d893 Linter strikes again (jbschlosser)
3ea84a5 Manually fix long lines (need to get flake8 running properly) (jbschlosser)
07c33ed Adding override func and interned string for pixel unshuffle (jbschlosser)
cbe6c1c Adding JIT test & gen_pyi entry for pixel unshuffle (jbschlosser)
f6cda98 Fix doc warning: horizontal line too short (jbschlosser)
7b0d728 Addressing PR comments (jbschlosser)
26c3b7a Adding implicit comment to PixelUnshuffleOptions (jbschlosser)
@@ -14,12 +14,16 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
   TORCH_CHECK(self.dim() >= 3,
               "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
               self.dim(), " dimension(s)");
+  TORCH_CHECK(
+      upscale_factor > 0,
+      "pixel_shuffle expects a positive upscale_factor, but got ",
+      upscale_factor);
   // Format: (B1, ..., Bn), C, H, W
   int64_t c = self.size(-3);
   int64_t h = self.size(-2);
   int64_t w = self.size(-1);
   const auto NUM_NON_BATCH_DIMS = 3;
-  const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS;
+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
 
   int64_t upscale_factor_squared = upscale_factor * upscale_factor;
   TORCH_CHECK(c % upscale_factor_squared == 0,
@@ -29,24 +33,81 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
   int64_t oh = h * upscale_factor;
   int64_t ow = w * upscale_factor;
 
-  // First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor).
-  // This allows shuffling to be done next by permuting dims.
-  std::vector<int64_t> expanded_shape(self.sizes().begin(), last_batch_dim);
-  expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
-  const auto input_expanded = self.reshape(expanded_shape);
+  // First, reshape to split the channels dim from c into 3 separate dims: (oc,
+  // upscale_factor, upscale_factor). This allows shuffling to be done next by
+  // permuting dims.
+  std::vector<int64_t> added_dims_shape(
+      self.sizes().begin(), self_sizes_batch_end);
+  added_dims_shape.insert(
+      added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
+  const auto input_reshaped = self.reshape(added_dims_shape);
 
   // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
-  std::vector<int64_t> permutation(self.sizes().begin(), last_batch_dim);
+  std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
   // std::iota is used to maintain the batch dims within the permutation.
-  // Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6.
-  std::iota(permutation.begin(), permutation.end(), -expanded_shape.size());
+  // Since 2 dims were added, the correct batch dim offsets are now:
+  // -added_dims_shape.size(), ..., -7, -6.
+  std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
   permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
                                          -3 /* 2nd upscale_factor */});
-  const auto input_permuted = input_expanded.permute(permutation);
+  const auto input_permuted = input_reshaped.permute(permutation);
 
   // Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
   // and (w, upscale_factor) -> a single dim (ow).
-  std::vector<int64_t> final_shape(self.sizes().begin(), last_batch_dim);
+  std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
   final_shape.insert(final_shape.end(), {oc, oh, ow});
   return input_permuted.reshape(final_shape);
 }
+
+
+Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
+  TORCH_CHECK(self.dim() >= 3,
+              "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
+              self.dim(), " dimension(s)");
+  TORCH_CHECK(
+      downscale_factor > 0,
+      "pixel_unshuffle expects a positive downscale_factor, but got ",
+      downscale_factor);
+  // Format: (B1, ..., Bn), C, H, W
+  int64_t c = self.size(-3);
+  int64_t h = self.size(-2);
+  int64_t w = self.size(-1);
+  const auto NUM_NON_BATCH_DIMS = 3;
    [Review comment on the line above] nit: constexpr
+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
+
+  TORCH_CHECK(h % downscale_factor == 0,
+              "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
+              " is not divisible by ", downscale_factor)
+  TORCH_CHECK(w % downscale_factor == 0,
+              "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
+              " is not divisible by ", downscale_factor)
+  int64_t downscale_factor_squared = downscale_factor * downscale_factor;
+  int64_t oc = c * downscale_factor_squared;
+  int64_t oh = h / downscale_factor;
+  int64_t ow = w / downscale_factor;
+
+  // First, reshape to split height dim into (oh, downscale_factor) dims and
+  // width dim into (ow, downscale_factor) dims. This allows unshuffling to be
+  // done next by permuting dims.
+  std::vector<int64_t> added_dims_shape(
+      self.sizes().begin(), self_sizes_batch_end);
+  added_dims_shape.insert(
+      added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
+  const auto input_reshaped = self.reshape(added_dims_shape);
+
+  // Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
+  std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
+  // std::iota is used to maintain the batch dims within the permutation.
+  // Since 2 dims were added, the correct batch dim offsets are now:
+  // -added_dims_shape.size(), ..., -7, -6.
+  std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
+  permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /*2nd downscale_factor */,
+                                         -4 /* oh */, -2 /* ow */});
+  const auto input_permuted = input_reshaped.permute(permutation);
+
+  // Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
+  // resulting in height=oh and width=ow.
+  std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
+  final_shape.insert(final_shape.end(), {oc, oh, ow});
+  return input_permuted.reshape(final_shape);
+}
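The reshape, permute, reshape sequence in the diff can also be read as plain index arithmetic. The following is a minimal sketch, not the PR's implementation: it assumes a single unbatched, contiguous (C, H, W) image stored in a flat `std::vector<float>`, and the names `pixel_unshuffle_chw` and `pixel_shuffle_chw` are hypothetical helpers introduced only for illustration. It uses no ATen calls.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference: unshuffle one contiguous (c, h, w) image by factor r.
// Output layout is (c * r * r, h / r, w / r).
std::vector<float> pixel_unshuffle_chw(const std::vector<float>& in,
                                       int64_t c, int64_t h, int64_t w, int64_t r) {
  assert(r > 0 && h % r == 0 && w % r == 0);
  assert(static_cast<int64_t>(in.size()) == c * h * w);
  const int64_t oh = h / r, ow = w / r;
  std::vector<float> out(in.size());
  for (int64_t ch = 0; ch < c; ++ch)
    for (int64_t i = 0; i < r; ++i)      // 1st downscale_factor dim (height offset)
      for (int64_t j = 0; j < r; ++j) {  // 2nd downscale_factor dim (width offset)
        // Collapse (c, r, r) into a single output channel index.
        const int64_t oc = (ch * r + i) * r + j;
        for (int64_t y = 0; y < oh; ++y)
          for (int64_t x = 0; x < ow; ++x)
            out[(oc * oh + y) * ow + x] = in[(ch * h + y * r + i) * w + x * r + j];
      }
  return out;
}

// Hypothetical reference: the inverse mapping, (c, h, w) -> (c / (r * r), h * r, w * r).
std::vector<float> pixel_shuffle_chw(const std::vector<float>& in,
                                     int64_t c, int64_t h, int64_t w, int64_t r) {
  assert(r > 0 && c % (r * r) == 0);
  assert(static_cast<int64_t>(in.size()) == c * h * w);
  const int64_t out_channels = c / (r * r), oh = h * r, ow = w * r;
  std::vector<float> out(in.size());
  for (int64_t oc = 0; oc < out_channels; ++oc)
    for (int64_t i = 0; i < r; ++i)
      for (int64_t j = 0; j < r; ++j) {
        // Split the input channel index back into (oc, i, j).
        const int64_t ic = (oc * r + i) * r + j;
        for (int64_t y = 0; y < h; ++y)
          for (int64_t x = 0; x < w; ++x)
            out[(oc * oh + y * r + i) * ow + x * r + j] = in[(ic * h + y) * w + x];
      }
  return out;
}
```

Under these assumptions, unshuffling a 1x4x4 image with factor 2 gives a 4x2x2 result, and shuffling that result with the same factor recovers the original values. Batch dims are omitted here; the PR handles them by leaving the leading dims in place within the permutation.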
Should there also be a check that downscale factor is > 0?
I like this, especially because the below checks of `h % downscale_factor == 0` and `w % downscale_factor == 0` may pass even with a negative `downscale_factor` due to the specific way mod is implemented in C++.

For consistency, what do you think about me adding an upscale_factor check in `pixel_shuffle` in this PR?
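To illustrate the point about mod: since C++11, integer division truncates toward zero and the result of `%` takes the sign of the dividend, so a divisibility test alone cannot reject a negative factor. The sketch below uses two hypothetical helpers (`divides_evenly` and `is_valid_downscale`, neither of which is part of the PR) to contrast the bare check with one that also requires a positive factor.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the bare divisibility test from the diff.
bool divides_evenly(int64_t size, int64_t factor) {
  assert(factor != 0);  // % by zero is undefined behavior
  return size % factor == 0;
}

// Hypothetical helper: positivity check first, then divisibility.
// Short-circuiting on factor > 0 also avoids evaluating % with a zero factor.
bool is_valid_downscale(int64_t size, int64_t factor) {
  return factor > 0 && size % factor == 0;
}
```

With these helpers, `divides_evenly(4, -2)` is true while `is_valid_downscale(4, -2)` is false, which is the case the explicit `downscale_factor > 0` check is meant to catch.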