Skip to content

Commit

Permalink
FC preparation for int_oo in PyTorch (#3947)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3947

This is needed for pytorch/pytorch#127693 . This code is written so it is compatible before and after this PR.

Reviewed By: mergennachin, clee2000

Differential Revision: D58465158

fbshipit-source-id: ca0f2a79eb07e78ff2887f78eb62ff38eeea3ede
  • Loading branch information
ezyang authored and facebook-github-bot committed Jun 12, 2024
1 parent 8f08b8b commit d308fd5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
4 changes: 2 additions & 2 deletions exir/passes/sym_shape_eval_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,6 @@ def call(self, graph_module: GraphModule):
"Please use export's constrain_as_size() or constrain_as_value() apis and set a concrete upper bound to resolve this."
)

spec.shape = concrete_shape # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[Optional[int]]`
spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[Optional[int]]`
spec.shape = concrete_shape
spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[int]`
return PassResult(graph_module, True)
15 changes: 11 additions & 4 deletions exir/sym_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
return int(output)


def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]:
def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int:
"""
Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's
upper bound can not be evaluated to valid integer according to the constraints in shape_env.
Expand All @@ -41,17 +41,24 @@ def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]:
expr = node.expr
var_range: ValueRanges = bound_sympy(expr, shape_env.var_to_range)
upper_bound = var_range.upper
# This import is needed temporarily until we update the pinned torch version.

try:
from torch.utils._sympy.numbers import int_oo # @manual # pyre-ignore
except ImportError:
int_oo = None

if isinstance(upper_bound, sympy.Integer):
concrete_upper = int(var_range.upper)
assert isinstance(
concrete_upper, int
), f"Expect upper bound to be a concrete int but got {concrete_upper}"
return concrete_upper
elif isinstance(upper_bound, sympy.oo):
return None
elif int_oo is not None and upper_bound is int_oo: # pyre-ignore
return int_oo # pyre-ignore
else:
raise RuntimeError(
f"Expect upper bound to be sympy.Integer or sympy.oo. but got {upper_bound}"
f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}"
)


Expand Down

0 comments on commit d308fd5

Please sign in to comment.