forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RFC]: Create aten native op for constrain_range (pytorch#103346)
At high current implementation of constrains functions (constrain_as_**) will raise exception for the following code snippets: ``` def f(x): a = x.item() constrain_as_size(a, 4, 7) return torch.empty((a, 4)) inp = torch.tensor([5]) ep = torch._export.export(f, (inp,)) ``` The reason is because current constrain logic is: 1) Purely python so it won't survive AOT export (the full node is gone after AOT export since AOT export only maintains aten level op). 2) Utilize side effect to add range constraints for traced symbol's shape env ([code](https://github.com/pytorch/pytorch/blob/9591e52880fb09cb6753d97606e6efb956075f5b/torch/fx/experimental/symbolic_shapes.py#L370-L372)). 3) If runtime assertion is turned on (by default). [`_AddRuntimeAssertionsForConstraintsPass`](https://github.com/pytorch/pytorch/blob/9591e52880fb09cb6753d97606e6efb956075f5b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py#L98-L100) will try to append assertion node based on range constrains extracted from shape env of symbol during another interpretation round. 4). However, since 1), in the round of AOT export, range constraints logic won't run for symbols generated during this round. And later there is no range constrains information available for assertion round and caused issue. 5) As a result of above, it will failure at `torch.empty((a, 4))` (there is no constrains for `a` that it must be positive). The fix here is just to implement range constrain logic as a native aten op (CPU implementation as no-op) to make it be able to survive AOT export. **NOTE:** [Logic](https://github.com/pytorch/pytorch/blob/2d745b95d723641e575027bd4e2fff612f61cc8f/torch/fx/experimental/symbolic_shapes.py#L350-L365C15) within [`constrain_range`](https://github.com/pytorch/pytorch/blob/2d745b95d723641e575027bd4e2fff612f61cc8f/torch/fx/experimental/symbolic_shapes.py#LL313C74-L313C74) is split out as `constrain_range_int` to capture case when non `SymInt` is passed in and reused in the new `_constrain_range`. The reason is when non `SymInt` is provided: * If it directly calls `sym_constrain_range`, the C++ version will be called which will be no-op. * So in this case it calls `constrain_range_int` instead to be able to capture issue like user provides a input whose tensor's shape could be out of range during exporting, like the following for above code example: ``` ... inp = torch.tensor([10]) ep = torch._export.export(f, (inp,)) # immediately raise error ``` Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204) Pull Request resolved: pytorch#103346 Approved by: https://github.com/tugsbayasgalan
- Loading branch information
1 parent
df81448
commit b27c355
Showing
14 changed files
with
130 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||
|
||
#include <c10/core/Scalar.h> | ||
#include <c10/util/Optional.h> | ||
|
||
namespace at { | ||
namespace native { | ||
|
||
void sym_constrain_range_cpu( | ||
const Scalar& size, | ||
c10::optional<int64_t> min = c10::nullopt, | ||
c10::optional<int64_t> max = c10::nullopt) {} | ||
|
||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#define TORCH_ASSERT_NO_OPERATORS | ||
|
||
#include <c10/core/Scalar.h> | ||
#include <c10/util/Optional.h> | ||
|
||
namespace at { | ||
namespace native { | ||
|
||
void sym_constrain_range_cuda( | ||
const Scalar& size, | ||
c10::optional<int64_t> min = c10::nullopt, | ||
c10::optional<int64_t> max = c10::nullopt) {} | ||
|
||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters