Skip to content

Commit

Permalink
Expose torch.export.constrain_as_{size,value} APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
gmagogsfm committed Aug 23, 2023
1 parent 137d96a commit 03824dc
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ torch.export
.. automodule:: torch.export
.. autofunction:: export
.. autofunction:: dynamic_dim
.. autofunction:: constrain_as_size
.. autofunction:: constrain_as_value

.. toctree::
:glob:
Expand Down
161 changes: 161 additions & 0 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,172 @@


__all__ = [
"constrain_as_size",
"constrain_as_value",
"dynamic_dim",
"export",
]


def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
"""
Hint `export()` about the constraint of an intermediate scalar value so that subsequent
branching behaviors that check on the range of aforementioned scalar value can be
soundly traced.
.. warning::
(Note that if the intermediate scalar value will be used as a shape,
call `constrain_as_size` API instead.)
For example, following program can not be traced soundly wihout using
`constrain_as_value` to give `export()` a hint about which branch to take::
def fn(x):
v = x.max().item()
if v > 1024:
return x
else:
return x * 2
`export()` would give following error::
torch._dynamo.exc.UserError: Consider annotating your code using
torch.export.constrain_as_size() or torch.export().constrain_as_value() APIs.
It appears that you're trying to get a value out of symbolic int/float whose value
is data-dependent (and thus we do not know the true value.) The expression we were
trying to evaluate is f0 > 1024 (unhinted: f0 > 1024).
Assuming the actual range of `v` can be between [10, 200], you can add a call to
`constrain_as_value` in the source code like this::
def fn(x):
v = x.max().item()
# Give export() a hint
torch.export.constrain_as_value(v, min=10, max=200)
if v > 1024:
return x
else:
return x * 2
With the additional hint, `export()` would be able to trace the program correctly by taking
the `else` branch, resulting in following graph::
graph():
%arg0_1 := placeholder[target=arg0_1]
# v = x.max().item()
%max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
%_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))
# Asserting 10 <= v <= 200
%ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 10))
%scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
%_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [10, 200].))
%le := call_function[target=operator.le](args = (%_local_scalar_dense, 200))
%scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
%_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [10, 200].))
%sym_constrain_range := call_function[target=torch.ops.aten.sym_constrain_range.default](
args = (%_local_scalar_dense,), kwargs = {min: 10, max: 200})
# Always taking `else` branch to multiply elements `x` by 2 due to hints above
%mul := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
return (mul,)
Args:
symbol: Intermediate scalar value (int-only now) to apply range constraint on.
min (Optional[int]): Minimum possible value of given symbol (inclusive)
max (Optional[int]): Maximum possible value of given symbol (inclusive)
Returns:
None
"""
from torch._export.constraints import constrain_as_value

return constrain_as_value(symbol, min, max)


def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None):
"""
Hint `export()` about the constraint of an intermediate scalar value that
represents shape of a tensor so that subsequent tensor constructors can be
traced correctly because many operators need to make assumption about range
of sizes.
For example, following program can not be traced soundly wihout using
`constrain_as_size` to give `export()` a hint about shape ranges::
def fn(x):
d = x.max().item()
return torch.ones(v)
`export()` would give following error::
torch._dynamo.exc.Unsupported: guard on data-dependent symbolic int/float
Assuming the actual range of `d` can be between [3, 10], you can add a call to
`constrain_as_size` in the source code like this::
def fn(x):
d = x.max().item()
torch.export.constrain_as_size(d, min=3, max=10)
return torch.ones(d)
With the additional hint, `export()` would be able to trace the program correctly by taking
the `else` branch, resulting in following graph::
graph():
%arg0_1 := placeholder[target=arg0_1]
# d = x.max().item()
%max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
%_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))
# Asserting 3 <= d <= 10
%ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 3))
%scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
%_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [3, 10].))
%le := call_function[target=operator.le](args = (%_local_scalar_dense, 10))
%scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
%_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [3, 10].))
%sym_constrain_range_for_size := call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](
args = (%_local_scalar_dense,), kwargs = {min: 3, max: 10})
# Constructing new tensor with d
%full := call_function[target=torch.ops.aten.full.default](
args = ([%_local_scalar_dense], 1),
kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
......
.. warning::
It is illegal to specify a range that contains 0 and 1. 0/1 values are always specialized
and can not be part of dynamic range.
Args:
symbol: Intermediate scalar value (int-only now) to apply range constraint on.
min (Optional[int]): Minimum possible value of given symbol (inclusive)
max (Optional[int]): Maximum possible value of given symbol (inclusive)
Returns:
None
"""

from torch._export.constraints import constrain_as_size

return constrain_as_size(symbol, min, max)


def dynamic_dim(t: torch.Tensor, index: int):
"""
`dynamic_dim` constructs a `Constraint` object that describes the dynamism of
Expand Down
2 changes: 2 additions & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def get_ignored_functions() -> Set[Callable]:
torch.empty_permuted,
torch.empty_strided,
torch.empty_quantized,
torch.export.constrain_as_size,
torch.export.constrain_as_value,
torch.export.dynamic_dim,
torch.export.export,
torch.eye,
Expand Down

0 comments on commit 03824dc

Please sign in to comment.