-
Notifications
You must be signed in to change notification settings - Fork 21.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move ConstantPadNd into ATen #10885
Closed
Closed
Move ConstantPadNd into ATen #10885
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
10a60b9
Forward function for ConstantPadNd
wdhorton 9a132d4
Add onnx definition
wdhorton 976d1e8
Fix onnx definition
wdhorton 8168691
Fix onnx signature
wdhorton 869d2f2
Rename to self
wdhorton d2e577c
Backward function.
wdhorton 38e841d
remove import
wdhorton c8340bd
Use .size(i)
wdhorton efa3604
Use emplace_back
wdhorton c6272c5
Change value to f for onnx
wdhorton 1dabe46
Remove Variable wrapper
wdhorton d5f4832
Make error message longer
wdhorton ee7d419
Add check that pad length is even, and improve error message
wdhorton c2473dd
Update with optimizations.
wdhorton 1fbe788
Add value to constant_pad_nd call in backward.
wdhorton dcee49c
semicolons
wdhorton 31ee9e3
fix derivatives.yaml
wdhorton 4c1a4c7
Add to aten_interned_strings.h
wdhorton f2109c4
Use at:: not at::native in backward.
wdhorton File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#include "ATen/ATen.h" | ||
|
||
namespace at { namespace native { | ||
|
||
Tensor constant_pad_nd(const Tensor& self, IntList pad, Scalar value) { | ||
AT_CHECK(pad.size() % 2 == 0, "Length of pad must be even but instead it equals ", | ||
pad.size()); | ||
|
||
auto input_sizes = self.sizes(); | ||
auto l_inp = self.dim(); | ||
|
||
auto l_pad = pad.size() / 2; | ||
auto l_diff = l_inp - l_pad; | ||
AT_CHECK(l_inp >= l_pad, "Length of pad should be no more than twice the number of " | ||
"dimensions of the input. Pad length is ", pad.size(), "while the input has ", | ||
l_inp, "dimensions."); | ||
|
||
std::vector<int64_t> new_shape; | ||
|
||
bool all_pads_non_positive = true; | ||
|
||
auto c_input = self; | ||
for (int i = l_diff; i < l_inp; i++) { | ||
auto pad_idx = 2 * (l_inp - i - 1); | ||
if (pad[pad_idx] < 0) { | ||
c_input = c_input.narrow(i, -pad[pad_idx], c_input.size(i) + pad[pad_idx]); | ||
} else if (pad[pad_idx] != 0) { | ||
all_pads_non_positive = false; | ||
} | ||
if (pad[pad_idx + 1] < 0) { | ||
c_input = c_input.narrow(i, 0, c_input.size(i) + pad[pad_idx + 1]); | ||
} else if (pad[pad_idx + 1] != 0) { | ||
all_pads_non_positive = false; | ||
} | ||
} | ||
|
||
// if none of the pads are positive we can optimize and just return the result | ||
// of calling .narrow() on the input | ||
if (all_pads_non_positive) { | ||
return c_input; | ||
} | ||
|
||
|
||
for (int i = 0; i < l_diff; i ++) { | ||
new_shape.emplace_back(input_sizes[i]); | ||
} | ||
|
||
for (int i = 0; i < l_pad; i++) { | ||
auto pad_idx = pad.size() - ((i + 1) * 2); | ||
auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]; | ||
AT_CHECK(new_dim > 0, "The input size ", input_sizes[l_diff + i], ", plus negative padding ", | ||
pad[pad_idx], " and ", pad[pad_idx + 1], "resulted in a negative output size, " | ||
"which is invalid. Check dimension ", l_diff + i, "of your input."); | ||
new_shape.emplace_back(new_dim); | ||
} | ||
|
||
auto output = at::empty(new_shape, self.options()); | ||
output.fill_(value); | ||
|
||
auto c_output = output; | ||
for (int i = l_diff; i < l_inp; i++) { | ||
auto pad_idx = 2 * (l_inp - i - 1); | ||
if (pad[pad_idx] > 0) { | ||
c_output = c_output.narrow(i, pad[pad_idx], c_output.size(i) - pad[pad_idx]); | ||
} | ||
if (pad[pad_idx + 1] > 0) { | ||
c_output = c_output.narrow(i, 0, c_output.size(i) - pad[pad_idx + 1]); | ||
} | ||
} | ||
c_output.copy_(c_input); | ||
return output; | ||
} | ||
|
||
}} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
Sorry, something went wrong.