Expose torch.export.constrain_as_{size,value} APIs #107735

Closed
2 changes: 2 additions & 0 deletions docs/source/export.rst
@@ -9,6 +9,8 @@ torch.export
.. automodule:: torch.export
.. autofunction:: export
.. autofunction:: dynamic_dim
.. autofunction:: constrain_as_size
.. autofunction:: constrain_as_value

.. toctree::
    :glob:
161 changes: 161 additions & 0 deletions torch/export/__init__.py
@@ -4,11 +4,172 @@


__all__ = [
    "constrain_as_size",
    "constrain_as_value",
    "dynamic_dim",
    "export",
]


def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
    """
    Hint `export()` about the range of an intermediate scalar value so that
    subsequent branching behavior that checks the range of that value can be
    soundly traced.

    .. warning::
        If the intermediate scalar value will be used as a tensor shape,
        call the `constrain_as_size` API instead.

    For example, the following program cannot be traced soundly without using
    `constrain_as_value` to give `export()` a hint about which branch to take::

        def fn(x):
            v = x.max().item()
            if v > 1024:
                return x
            else:
                return x * 2


    `export()` would give the following error::

        torch._dynamo.exc.UserError: Consider annotating your code using
        torch.export.constrain_as_size() or torch.export.constrain_as_value() APIs.
        It appears that you're trying to get a value out of symbolic int/float whose value
        is data-dependent (and thus we do not know the true value.) The expression we were
        trying to evaluate is f0 > 1024 (unhinted: f0 > 1024).

    Assuming the actual range of `v` is [10, 200], you can add a call to
    `constrain_as_value` in the source code like this::

        def fn(x):
            v = x.max().item()

            # Give export() a hint
            torch.export.constrain_as_value(v, min=10, max=200)

            if v > 1024:
                return x
            else:
                return x * 2

    With the additional hint, `export()` is able to trace the program correctly by
    always taking the `else` branch, resulting in the following graph::

        graph():
            %arg0_1 := placeholder[target=arg0_1]

            # v = x.max().item()
            %max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
            %_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))

            # Asserting 10 <= v <= 200
            %ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 10))
            %scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
            %_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [10, 200].))
            %le := call_function[target=operator.le](args = (%_local_scalar_dense, 200))
            %scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
            %_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [10, 200].))
            %sym_constrain_range := call_function[target=torch.ops.aten.sym_constrain_range.default](
                args = (%_local_scalar_dense,), kwargs = {min: 10, max: 200})

            # Always taking the `else` branch to multiply elements of `x` by 2, due to the hint above
            %mul := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
            return (mul,)


    Args:
        symbol: Intermediate scalar value (currently int only) to apply the range constraint on.
        min (Optional[int]): Minimum possible value of the given symbol (inclusive)
        max (Optional[int]): Maximum possible value of the given symbol (inclusive)

    Returns:
        None
    """
    from torch._export.constraints import constrain_as_value

    return constrain_as_value(symbol, min, max)


def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None):
    """
    Hint `export()` about the range of an intermediate scalar value that
    represents a tensor shape, so that subsequent tensor constructors can be
    traced correctly, since many operators need to make assumptions about the
    range of sizes.

    For example, the following program cannot be traced soundly without using
    `constrain_as_size` to give `export()` a hint about the shape range::

        def fn(x):
            d = x.max().item()
            return torch.ones(d)

    `export()` would give the following error::

        torch._dynamo.exc.Unsupported: guard on data-dependent symbolic int/float

    Assuming the actual range of `d` is [3, 10], you can add a call to
    `constrain_as_size` in the source code like this::

        def fn(x):
            d = x.max().item()
            torch.export.constrain_as_size(d, min=3, max=10)
            return torch.ones(d)

    With the additional hint, `export()` is able to trace the program correctly,
    resulting in the following graph::

        graph():
            %arg0_1 := placeholder[target=arg0_1]

            # d = x.max().item()
            %max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
            %_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))

            # Asserting 3 <= d <= 10
            %ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 3))
            %scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
            %_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [3, 10].))
            %le := call_function[target=operator.le](args = (%_local_scalar_dense, 10))
            %scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
            %_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [3, 10].))
            %sym_constrain_range_for_size := call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](
                args = (%_local_scalar_dense,), kwargs = {min: 3, max: 10})

            # Constructing a new tensor with d
            %full := call_function[target=torch.ops.aten.full.default](
                args = ([%_local_scalar_dense], 1),
                kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})

            ...


    .. warning::
        It is illegal to specify a range that contains 0 or 1. Values of 0 and 1
        are always specialized and cannot be part of a dynamic range.

    Args:
        symbol: Intermediate scalar value (currently int only) to apply the range constraint on.
        min (Optional[int]): Minimum possible value of the given symbol (inclusive)
        max (Optional[int]): Maximum possible value of the given symbol (inclusive)

    Returns:
        None

"""

from torch._export.constraints import constrain_as_size

return constrain_as_size(symbol, min, max)


def dynamic_dim(t: torch.Tensor, index: int):
"""
`dynamic_dim` constructs a `Constraint` object that describes the dynamism of
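The docstrings above each show a fragment in isolation; below is a minimal end-to-end sketch of how both new APIs might be exercised through `torch.export.export()`. The function names (`branchy`, `make_ones`), the example inputs, and the exact `export()` call signature are our assumptions for illustration, not part of this diff::

    import torch

    def branchy(x):
        v = x.max().item()
        # Constrain the data-dependent scalar so the branch below traces soundly.
        torch.export.constrain_as_value(v, min=10, max=200)
        if v > 1024:
            return x
        return x * 2

    def make_ones(x):
        d = x.max().item()
        # Constrain d as a size; a range containing 0 or 1 is rejected (see warning above).
        torch.export.constrain_as_size(d, min=3, max=10)
        return torch.ones(d)

    # Inputs are chosen so the runtime asserts baked into the exported graphs pass.
    ep_value = torch.export.export(branchy, (torch.randint(10, 201, (4,)),))
    ep_size = torch.export.export(make_ones, (torch.randint(3, 11, (4,)),))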
2 changes: 2 additions & 0 deletions torch/overrides.py
@@ -147,6 +147,8 @@ def get_ignored_functions() -> Set[Callable]:
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.export.constrain_as_size,
        torch.export.constrain_as_value,
        torch.export.dynamic_dim,
        torch.export.export,
        torch.eye,
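Because the two functions are added to `get_ignored_functions()` above, the `__torch_function__` override machinery will skip them. A quick sanity check of that membership (our sketch, not part of this diff)::

    import torch
    from torch.overrides import get_ignored_functions

    ignored = get_ignored_functions()
    assert torch.export.constrain_as_size in ignored
    assert torch.export.constrain_as_value in ignored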