Do not import sympy within torch._prims_common #112034

Closed · wants to merge 6 commits
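
The recurring change in this diff is to defer sympy-heavy imports: names needed only for type annotations move under a TYPE_CHECKING guard, and names needed only at call time are imported inside the function body, so importing torch._prims_common (and the other touched modules) no longer imports sympy eagerly. A minimal sketch of both idioms, using illustrative names rather than ones taken from the diff:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pylance), never at runtime,
    # so importing this module does not import sympy.
    import sympy


def as_sympy_integer(x: int) -> "sympy.Integer":
    # Deferred import: the cost of importing sympy is paid on the first call,
    # not when this module is imported.
    import sympy

    return sympy.Integer(x)

The trade-off is that the import cost moves from module import time to the first call of each affected function, which keeps a cold import of torch fast for code paths that never touch sympy.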
9 changes: 8 additions & 1 deletion torch/_decomp/decompositions.py
@@ -21,7 +21,6 @@
_safe_copy_out,
out_wrapper,
)
from torch.fx.experimental.symbolic_shapes import expect_true, guard_int
from torch.utils._pytree import tree_flatten, tree_map

DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
@@ -1230,6 +1229,10 @@ def split_with_sizes(
num_splits = len(split_sizes)
splits = []
start_idx = 0

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import expect_true

for i in range(num_splits):
length = split_sizes[i]
torch._check_is_size(
@@ -1264,6 +1267,10 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
assert dim_size == 0
return (self,)
chunks = (dim_size + split_size - 1) // split_size

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import guard_int

chunks = guard_int(chunks)
split_sizes = [split_size for i in range(chunks)]
split_sizes[-1] = split_size - (split_size * chunks - dim_size)
10 changes: 9 additions & 1 deletion torch/_guards.py
@@ -22,6 +22,7 @@
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
)

@@ -30,7 +31,14 @@

log = logging.getLogger(__name__)

import sympy

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.

import sympy


"""
torch._guards is the definitional source of truth for general purpose guard structures.
10 changes: 6 additions & 4 deletions torch/_meta_registrations.py
@@ -30,10 +30,6 @@
out_wrapper,
)
from torch._refs import _broadcast_shapes, _maybe_broadcast
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
constrain_range,
)
from torch.utils._pytree import tree_map


@@ -522,6 +518,9 @@ def make_dep_token(

@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import constrain_range

if isinstance(size, (SymFloat, SymBool)):
raise ValueError("Constraining SymFloat or Symbool is nyi")
constrain_range(size, min=min, max=max)
@@ -535,6 +534,9 @@ def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):

@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

if isinstance(size, (SymFloat, SymBool)):
raise ValueError("Constraining SymFloat or Symbool is nyi")
_constrain_range_for_size(size, min=min, max=max)
14 changes: 13 additions & 1 deletion torch/_prims_common/__init__.py
@@ -17,10 +17,17 @@
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import sympy

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.

import sympy

import torch
from torch import sym_float, sym_int, sym_max
@@ -1375,6 +1382,11 @@ def elementwise_dtypes(
args = tuple(x for x in _args if x is not None)

highest_type: type = bool

# Import sympy locally, as importing it eagerly at a module level is too slow
# See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
import sympy

for x in args:
if not isinstance(x, (Number, TensorLike, sympy.Symbol)):
msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
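
The comment added above points to the dev-discuss post on what happens when you import torch; the motivation is that importing sympy eagerly adds noticeable time to every import of torch. A small sketch for measuring that cost locally (absolute numbers vary by machine and sympy version; assumes sympy is installed):

import importlib
import time

start = time.perf_counter()
importlib.import_module("sympy")
elapsed = time.perf_counter() - start
print(f"importing sympy took {elapsed:.2f}s")

CPython's -X importtime flag (python -X importtime -c "import torch") gives a per-module breakdown of the same information.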
55 changes: 43 additions & 12 deletions torch/_subclasses/fake_tensor.py
@@ -8,7 +8,18 @@
import weakref
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from weakref import ReferenceType

import torch
@@ -26,13 +37,6 @@
)
from torch._subclasses.meta_utils import MetaConverter
from torch._utils import render_call
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimConstraint,
DimDynamic,
free_symbols,
is_symbolic,
)
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
@@ -46,6 +50,11 @@
from torch.utils._stats import count, count_label
from torch.utils.weak import WeakIdRef

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import DimConstraint, DimDynamic

DimList = List

log = logging.getLogger(__name__)
@@ -321,8 +330,8 @@ def from_real_tensor(
ignore_subclass=False,
*,
source=None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
constraint_dims: Optional[DimList[DimConstraint]] = None,
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
memoized_only=False,
):
maybe_memo = self._get_memo(t)
@@ -541,6 +550,10 @@ def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
raise DynamicOutputShapeException(func)

output_size = fake_mode.shape_env.create_unbacked_symint()

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

_constrain_range_for_size(output_size)
# TODO: consider a memo
return repeats.new_empty(output_size)
@@ -582,6 +595,13 @@ def nonzero(fake_mode, func, arg):
# remember, the hypothesis is that if your later code works
# with N >= 2, it will work with N = 1 and N = 0.
maxval = sys.maxsize - 1

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
free_symbols,
)

if not free_symbols(arg.numel()):
# Don't upgrade the range if numel is less than two, since we then
# have an empty range which makes things go explodey. We also
@@ -611,6 +631,13 @@ def masked_select(fake_mode, func, self, mask):

# see nonzero for commentary
maxval = sys.maxsize - 1

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
free_symbols,
)

if not free_symbols(arg.numel()):
if arg.numel() >= 2:
maxval = int(arg.numel())
@@ -722,6 +749,7 @@ def conv(fake_mode, func, *args, **kwargs):
k = kwargs["weight"].ndim
batch = kwargs["input"].shape[0]

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import has_hint

if not has_hint(batch):
@@ -1250,6 +1278,9 @@ def merge_devices(t):
# of the tensor to create the output Python list, and (2) creating unbacked
# symints for each element of the list.
def tolist(self):
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import is_symbolic

assert self.dim() == 1 and is_symbolic(self.shape[0])
shape_env = self.shape[0].node.shape_env
out = []
@@ -1819,8 +1850,8 @@ def from_tensor(
static_shapes=None,
ignore_subclass=False,
source: Optional[Source] = None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
constraint_dims: Optional[DimList[DimConstraint]] = None,
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
# Setting this flag will force FakeTensorMode to return `None` if attempting to convert a tensor we have not
# seen before.
memoized_only=False,
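
The annotations quoted as strings above ("Optional[DimList[DimDynamic]]", "Optional[DimList[DimConstraint]]") are forward references: DimDynamic and DimConstraint now exist only under the TYPE_CHECKING guard, so they are undefined at runtime and the annotations must not be evaluated there. A minimal illustration of the idiom, using a hypothetical heavy_module in place of symbolic_shapes:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from heavy_module import HeavyType  # hypothetical; stands in for the slow import


def configure(value: "Optional[HeavyType]" = None) -> None:
    # The string annotation is never evaluated at runtime, so HeavyType only
    # needs to exist for the type checker.
    ...

An alternative with the same effect is from __future__ import annotations, which makes every annotation in the module a lazy string; this PR quotes the individual annotations instead.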
4 changes: 3 additions & 1 deletion torch/_subclasses/fake_utils.py
@@ -10,7 +10,6 @@
tree_flatten_only,
UnsupportedFakeTensorException,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

@@ -80,6 +79,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
and torch.Tag.inplace_view not in func.tags
and torch.Tag.data_dependent_output not in func.tags
):
# Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
from torch.fx.experimental.symbolic_shapes import ShapeEnv

try:
# TODO: enable_python_dispatcher() here
with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
13 changes: 9 additions & 4 deletions torch/_subclasses/meta_utils.py
@@ -1,7 +1,7 @@
import contextlib
import warnings
import weakref
from typing import ContextManager, List, Optional
from typing import ContextManager, List, Optional, TYPE_CHECKING

import torch
from torch._C._functorch import (
@@ -13,14 +13,18 @@
)
from torch._guards import Source

from torch.fx.experimental.symbolic_shapes import DimConstraint, DimDynamic
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
)
from torch.utils.weak import WeakIdRef

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features.
# Do not import unconditionally, as they import sympy and importing sympy is very slow.
from torch.fx.experimental.symbolic_shapes import DimConstraint, DimDynamic

DimList = List


@@ -180,8 +184,8 @@ def meta_tensor(
shape_env=None,
callback=lambda t: t(),
source: Optional[Source] = None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
constraint_dims: Optional[DimList[DimConstraint]] = None,
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
):
if source is None:
from torch._dynamo.source import ConstantSource
@@ -303,6 +307,7 @@ def sym_sizes_strides_storage_offset(t, src):
assert t._is_view()

from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import DimDynamic

if shape_env:
base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
32 changes: 25 additions & 7 deletions torch/export/__init__.py
@@ -3,21 +3,28 @@
import dataclasses
import inspect
import io
import math
import pathlib
import sys
import typing
from enum import auto, Enum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import sympy
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility

from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

@@ -28,6 +35,11 @@
UnflattenFunc,
)

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features.
# Do not import unconditionally, as they import sympy and importing sympy is very slow.
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint


__all__ = [
"Constraint",
@@ -109,13 +121,15 @@ class Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
"""

# NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
constraint_range: StrictMinMaxConstraint
constraint_range: "StrictMinMaxConstraint"
# Represent that `constraint_range` is shared with another _ConstraintTarget, which
# typically arises because of a specified equality with another dynamic dimension.
shared: Optional[_ConstraintTarget] = None
debug_name: Optional[str] = None

def _clone_with_range(self, lower=2, upper=sympy.oo):
def _clone_with_range(self, lower=2, upper=math.inf):
# Import locally, as these modules pull in sympy and importing sympy at a module level is slow
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.value_ranges import ValueRanges

constraint_range = StrictMinMaxConstraint(
Expand Down Expand Up @@ -184,6 +198,10 @@ def __eq__(self, other):
"A dynamic dim can be specified equal only to another dynamic dim. "
f"Equality with {type(other)} is not supported."
)

# Import locally to avoid pulling in sympy at a module level
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & other.constraint_range.vr,
warn_only=False,
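
One detail in this last file: because sympy is no longer imported at module level in torch/export/__init__.py, the default upper bound of _clone_with_range changes from sympy.oo to math.inf. sympy treats its infinity and the float infinity as equal, so building the range inside the function from a math.inf default should behave the same; a quick sketch of that assumption (requires sympy to be installed):

import math

import sympy

# sympy's infinity compares equal to the float infinity, so code that later
# builds sympy-backed ranges from this default sees the same unbounded value.
assert sympy.oo == math.inf

The actual conversion happens inside _clone_with_range itself, where StrictMinMaxConstraint and ValueRanges are now imported lazily.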