[RFC]: Create aten native op for constrain_range (pytorch#103346)
At a high level, the current implementation of the constraint functions (`constrain_as_*`) raises an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```

The reason is that the current constraint logic:
1) Is purely Python, so it does not survive AOT export (the call node is gone after AOT export, since AOT export only keeps aten-level ops).
2) Relies on a side effect to add range constraints to the traced symbol's shape env ([code](https://github.com/pytorch/pytorch/blob/9591e52880fb09cb6753d97606e6efb956075f5b/torch/fx/experimental/symbolic_shapes.py#L370-L372)).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](https://github.com/pytorch/pytorch/blob/9591e52880fb09cb6753d97606e6efb956075f5b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py#L98-L100) tries to append assertion nodes based on the range constraints extracted from each symbol's shape env during another interpretation round.
4) However, because of 1), the range-constraint logic never runs for symbols generated during the AOT export round, so no range-constraint information is available for the assertion round, which causes the failure.
5) As a result, export fails at `torch.empty((a, 4))`: there is no constraint recording that `a` must be positive. A rough sketch of the problem is shown below.
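As an illustration (hypothetical graph printouts for the snippet above, not literal export output):

```
# Traced graph before AOT export: the constraint is a plain Python call, so the
# only record of the range [4, 7] is a side effect on the tracing shape_env.
#   a = torch.ops.aten._local_scalar_dense.default(x)   # from x.item()
#   constrain_as_size(a, 4, 7)                          # pure Python, leaves no node behind
#   torch.ops.aten.empty.memory_format([a, 4])
#
# After AOT export only aten-level ops remain, so the range information for the
# symbol created in that round is missing and torch.empty((a, 4)) cannot prove a >= 0:
#   a = torch.ops.aten._local_scalar_dense.default(x)
#   torch.ops.aten.empty.memory_format([a, 4])          # error: no constraint on `a`
```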

The fix here is to implement the range-constraint logic as a native aten op (with no-op CPU and CUDA kernels) so that it survives AOT export.
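With the new op the constraint becomes a real aten node in the graph, so it survives AOT export and the runtime-assertion pass can see it. A sketch of the resulting behavior, mirroring the new test added below (the graph line in the comment is illustrative):

```
import torch
from torch._export.constraints import constrain_as_size

def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

ep = torch._export.export(f, (torch.tensor([5]),))  # now exports cleanly
print(ep(torch.tensor([6])).shape)                  # torch.Size([6, 4])

# The exported graph carries the constraint as an aten op, roughly:
#   torch.ops.aten.sym_constrain_range.default(a, min=4, max=7)
# so ep(torch.tensor([30])) raises a RuntimeError from the injected runtime
# assertion (value outside the inline constraint [4, 7]).
```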

**NOTE:**
The [logic](https://github.com/pytorch/pytorch/blob/2d745b95d723641e575027bd4e2fff612f61cc8f/torch/fx/experimental/symbolic_shapes.py#L350-L365C15) within [`constrain_range`](https://github.com/pytorch/pytorch/blob/2d745b95d723641e575027bd4e2fff612f61cc8f/torch/fx/experimental/symbolic_shapes.py#LL313C74-L313C74) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and it is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* Calling `sym_constrain_range` directly would dispatch to the C++ implementation, which is a no-op.
* Calling `constrain_range_int` instead lets us catch cases where the user provides an input whose tensor shape is out of range during export (a simplified sketch of this dispatch follows the snippet below). For example, continuing the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raises an error
```
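A simplified sketch of the resulting dispatch in `constrain_as_value`, paraphrasing the `torch/_export/constraints.py` change further down (not the verbatim implementation):

```
from typing import Optional

import torch
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import constrain_range_int

def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
    if not isinstance(symbol, SymInt):
        # Concrete int (e.g. a symbol that resolved to a static value): check the
        # range eagerly in Python so an out-of-range example input fails at export.
        constrain_range_int(symbol, min=min, max=max)
    else:
        # SymInt: emit the new aten op so the constraint survives AOT export.
        torch.sym_constrain_range(symbol, min, max)
    return symbol
```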

Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)

Pull Request resolved: pytorch#103346
Approved by: https://github.com/tugsbayasgalan
XuanQi authored and pytorchmergebot committed Jun 16, 2023
1 parent df81448 commit b27c355
Showing 14 changed files with 130 additions and 23 deletions.
15 changes: 15 additions & 0 deletions aten/src/ATen/native/Constraints.cpp
@@ -0,0 +1,15 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <c10/core/Scalar.h>
#include <c10/util/Optional.h>

namespace at {
namespace native {

void sym_constrain_range_cpu(
const Scalar& size,
c10::optional<int64_t> min = c10::nullopt,
c10::optional<int64_t> max = c10::nullopt) {}

} // namespace native
} // namespace at
15 changes: 15 additions & 0 deletions aten/src/ATen/native/cuda/Constraints.cu
@@ -0,0 +1,15 @@
#define TORCH_ASSERT_NO_OPERATORS

#include <c10/core/Scalar.h>
#include <c10/util/Optional.h>

namespace at {
namespace native {

void sym_constrain_range_cuda(
const Scalar& size,
c10::optional<int64_t> min = c10::nullopt,
c10::optional<int64_t> max = c10::nullopt) {}

} // namespace native
} // namespace at
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -177,6 +177,11 @@

- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()

- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
dispatch:
CPU: sym_constrain_range_cpu
CUDA: sym_constrain_range_cuda

- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
variants: method

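The CPU and CUDA kernels above are intentionally empty; this yaml entry is what registers the op and exposes it to Python. A small introspection sketch (the printed schema formatting is approximate):

```
import torch

# The op lives in the aten namespace with the schema declared above.
print(torch.ops.aten.sym_constrain_range.default._schema)
# aten::sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
```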
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -1244,6 +1244,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/ChanelShuffle.cpp",
"aten/src/ATen/native/Col2Im.cpp",
"aten/src/ATen/native/PadNd.cpp",
"aten/src/ATen/native/Constraints.cpp",
"aten/src/ATen/native/Convolution.cpp",
"aten/src/ATen/native/ConvolutionMM2d.cpp",
"aten/src/ATen/native/ConvolutionMM3d.cpp",
39 changes: 39 additions & 0 deletions test/dynamo/test_export.py
@@ -22,6 +22,7 @@
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfRocm

@@ -2408,6 +2409,44 @@ def f(x, y):
buffer = io.BytesIO()
torch.save(gm, buffer)

def test_export_with_inline_constraints(self):
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
return torch.empty((a, 4))

with self.assertRaisesRegex(
torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
) as cm:
torch._export.export(f, (torch.tensor([20]),))

ep = torch._export.export(f, (torch.tensor([5]),))
self.assertEqual(ep(torch.tensor([6])).shape, (6, 4))

FileCheck().check_count(
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
).run(ep.graph_module.code)

with self.assertRaisesRegex(
RuntimeError,
r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
) as cm:
ep(torch.tensor([30]))

def test_export_with_inline_constraints_complex(self):
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
empty = torch.empty((a, 4))

return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)

ep = torch._export.export(f, (torch.tensor([6]),))
self.assertEqual(ep(torch.tensor([5])).shape, (10, 5))
FileCheck().check_count(
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
).run(ep.graph_module.code)

def test_export_dynamic_dim_not_1(self):
x = torch.randn([1, 1, 1])

1 change: 0 additions & 1 deletion test/export/test_export.py
@@ -68,7 +68,6 @@ def forward(self, x):

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestDynamismExpression(TestCase):
@unittest.expectedFailure
def test_export_inline_constraints(self):

def f(x):
6 changes: 3 additions & 3 deletions test/export/test_passes.py
@@ -268,9 +268,9 @@ def forward(self, x):
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)

# 2 constraints for b
self.assertEqual(num_assert, 2)
self.assertEqual(num_scalar_tensor, 2)
# TODO: De-duplicate assertions for same symbol.
self.assertEqual(num_assert, 4)
self.assertEqual(num_scalar_tensor, 4)

with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
ep(torch.tensor([1, 1, 0, 0, 0]))
20 changes: 17 additions & 3 deletions torch/_export/constraints.py
@@ -1,9 +1,19 @@
from typing import Optional
from typing import Optional, Callable, Union

import torch
from torch import SymInt, SymFloat
from torch._dynamo import allow_in_graph
from torch.fx.experimental.symbolic_shapes import constrain_range
from torch.fx.experimental.symbolic_shapes import constrain_range_int
from torch.utils._sympy.value_ranges import ValueRangeError

# The `Scalar` type used in native_functions.yaml is translated to `Union[Number, _complex]`,
# which can cause type errors since `SymInt` or `SymFloat` will be passed in.
# Manually specify the type explicitly here.
sym_constrain_range: Callable[
[Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]],
None,
] = torch.sym_constrain_range # type: ignore[assignment]


# TODO: we want to hide this min/max stuff under some abstraction similar to
# DynamicDim
Expand All @@ -13,7 +23,11 @@ def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = N
Add min/max constraint on the intermediate symbol at tracing time
"""

constrain_range(symbol, min=min, max=max)
if not isinstance(symbol, SymInt):
constrain_range_int(symbol, min=min, max=max)
else:
sym_constrain_range(symbol, min, max)

return symbol


7 changes: 6 additions & 1 deletion torch/_meta_registrations.py
@@ -30,7 +30,7 @@
out_wrapper,
)
from torch._refs import _broadcast_shapes

from torch.fx.experimental.symbolic_shapes import constrain_range
from torch.utils._pytree import tree_map


@@ -311,6 +311,11 @@ def assert_async_meta(val, assert_msg):
return


@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min, max):
constrain_range(size, min=min, max=max)


# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
assert (
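This meta kernel is what ties the new op back to the old behavior: when the graph is traced or re-interpreted under fake tensor mode, hitting a `sym_constrain_range` node runs `constrain_range`, which records the min/max bounds for the symbol in the shape env. That is the same side effect the pure-Python path produced, but now it is attached to a node that survives AOT export. Roughly (illustrative, not the literal code path):

```
# During fake-mode tracing, a graph node such as
#     torch.ops.aten.sym_constrain_range.default(a, 4, 7)
# dispatches to the meta kernel above, which calls
#     constrain_range(a, min=4, max=7)
# so the symbol backing `a` is bounded to [4, 7] in shape_env.var_to_range, and
# the runtime-assertion pass can later read that range back out.
```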
2 changes: 1 addition & 1 deletion torch/_subclasses/fake_tensor.py
@@ -1488,7 +1488,7 @@ def wrap(e, device=None):
nonlocal common_device
nonlocal has_scalar_only_inputs

if common_device is None:
if isinstance(e, torch.Tensor) and common_device is None:
(
common_device,
has_scalar_only_inputs,
39 changes: 25 additions & 14 deletions torch/fx/experimental/symbolic_shapes.py
@@ -348,20 +348,7 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
if max is None:
max = sympy.oo
if not isinstance(a, SymInt):
if not (min <= a <= max):
raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")

if (
(fake_mode := detect_fake_mode()) is not None and
getattr(fake_mode, "shape_env", None) is not None
):
# If we are tracing with a fake mode then add this integer to the
# shape_env's var_to_range
sym_integer = sympy.Integer(a)
shape_env = fake_mode.shape_env
_constrain_symbol_range(shape_env, sym_integer, min, max)
shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()

constrain_range_int(a, min=min, max=max)
return

if isinstance(a.node.expr, sympy.Integer):
@@ -376,6 +363,30 @@
# SymInt).
_constrain_symbol_range(a.node.shape_env, a.node.expr, min, max)

def constrain_range_int(a, *, min, max):
"""
Constrain range on concrete int value.
This can happen in the following scenarios:
- Eager mode execution and real int value is provided.
- During tracing the traced symbol is resolved as a static integer (see
PR #101655 for more details).
"""

assert not isinstance(a, SymInt)
if not (min <= a <= max):
raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")

if (
(fake_mode := detect_fake_mode()) is not None and
getattr(fake_mode, "shape_env", None) is not None
):
# If we are tracing with a fake mode then add this integer to the
# shape_env's var_to_range
sym_integer = sympy.Integer(a)
shape_env = fake_mode.shape_env
_constrain_symbol_range(shape_env, sym_integer, min, max)
shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()

def constrain_unify(a, b):
"""
Given two SymInts, constrain them so that they must be equal. NB:
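A quick usage sketch of the new `constrain_range_int` helper on a concrete value, outside any tracing context (based on the code above):

```
from torch.fx.experimental.symbolic_shapes import constrain_range_int
from torch.utils._sympy.value_ranges import ValueRangeError

try:
    constrain_range_int(10, min=4, max=7)  # concrete int outside [4, 7]
except ValueRangeError as e:
    print(e)  # Invalid value 10 for range [4:7]
```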
1 change: 1 addition & 0 deletions torch/fx/node.py
@@ -35,6 +35,7 @@
torch._assert_async,
_ops.aten._assert_async.msg,
_ops.aten.copy_.default,
_ops.aten.sym_constrain_range.default,
_ops.profiler._record_function_enter,
_ops.profiler._record_function_enter_new,
_ops.profiler._record_function_exit}
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -185,6 +185,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.sym_max,
torch.sym_min,
torch.sym_not,
torch.sym_constrain_range,
torch.tril_indices,
torch.triu_indices,
torch.vander,
1 change: 1 addition & 0 deletions torchgen/native_function_generation.py
@@ -73,6 +73,7 @@
"qscheme", # returns a QScheme
"record_stream", # no return
"sparse_dim", # returns an int
"sym_constrain_range", # no return
"_nested_tensor_storage_offsets", # returns a vector of ints
"_chunk_grad_outputs_efficient_attention", # returns a bool
"_fused_sdp_choice", # returns an int
