Skip to content

Commit

Permalink
Do not import sympy within torch._prims_common (#112034)
Browse files Browse the repository at this point in the history
This is the first of a few PRs that avoid importing SymPy at import time.
The pitch here is that we (almost!) do not have SymPy on our API, so
this should be feasible.

This should speed-up torch imports by a good 15% as per
https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589

In this PR we just move a few global imports into local imports.
Pull Request resolved: #112034
Approved by: https://github.com/ezyang
  • Loading branch information
lezcano authored and pytorchmergebot committed Oct 26, 2023
1 parent d6724a5 commit c8a5bb4
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 36 deletions.
9 changes: 8 additions & 1 deletion torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
_safe_copy_out,
out_wrapper,
)
from torch.fx.experimental.symbolic_shapes import expect_true, guard_int
from torch.utils._pytree import tree_flatten, tree_map

DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
Expand Down Expand Up @@ -1230,6 +1229,10 @@ def split_with_sizes(
num_splits = len(split_sizes)
splits = []
start_idx = 0

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import expect_true

for i in range(num_splits):
length = split_sizes[i]
torch._check_is_size(
Expand Down Expand Up @@ -1264,6 +1267,10 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
assert dim_size == 0
return (self,)
chunks = (dim_size + split_size - 1) // split_size

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import guard_int

chunks = guard_int(chunks)
split_sizes = [split_size for i in range(chunks)]
split_sizes[-1] = split_size - (split_size * chunks - dim_size)
Expand Down
10 changes: 9 additions & 1 deletion torch/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
)

Expand All @@ -30,7 +31,14 @@

log = logging.getLogger(__name__)

import sympy

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.

import sympy


"""
torch._guards is the definitional source of truth for general purpose guard structures.
Expand Down
10 changes: 6 additions & 4 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
out_wrapper,
)
from torch._refs import _broadcast_shapes, _maybe_broadcast
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
constrain_range,
)
from torch.utils._pytree import tree_map


Expand Down Expand Up @@ -522,6 +518,9 @@ def make_dep_token(

@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import constrain_range

if isinstance(size, (SymFloat, SymBool)):
raise ValueError("Constraining SymFloat or Symbool is nyi")
constrain_range(size, min=min, max=max)
Expand All @@ -535,6 +534,9 @@ def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):

@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

if isinstance(size, (SymFloat, SymBool)):
raise ValueError("Constraining SymFloat or Symbool is nyi")
_constrain_range_for_size(size, min=min, max=max)
Expand Down
14 changes: 13 additions & 1 deletion torch/_prims_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import sympy

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.

import sympy

import torch
from torch import sym_float, sym_int, sym_max
Expand Down Expand Up @@ -1375,6 +1382,11 @@ def elementwise_dtypes(
args = tuple(x for x in _args if x is not None)

highest_type: type = bool

# Import sympy locally, as importing it eagerly at a module level is too slow
# See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
import sympy

for x in args:
if not isinstance(x, (Number, TensorLike, sympy.Symbol)):
msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
Expand Down
55 changes: 43 additions & 12 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@
import weakref
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from weakref import ReferenceType

import torch
Expand All @@ -26,13 +37,6 @@
)
from torch._subclasses.meta_utils import MetaConverter
from torch._utils import render_call
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimConstraint,
DimDynamic,
free_symbols,
is_symbolic,
)
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
Expand All @@ -46,6 +50,11 @@
from torch.utils._stats import count, count_label
from torch.utils.weak import WeakIdRef

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import DimConstraint, DimDynamic

DimList = List

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -321,8 +330,8 @@ def from_real_tensor(
ignore_subclass=False,
*,
source=None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
constraint_dims: Optional[DimList[DimConstraint]] = None,
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
memoized_only=False,
):
maybe_memo = self._get_memo(t)
Expand Down Expand Up @@ -541,6 +550,10 @@ def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
raise DynamicOutputShapeException(func)

output_size = fake_mode.shape_env.create_unbacked_symint()

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

_constrain_range_for_size(output_size)
# TODO: consider a memo
return repeats.new_empty(output_size)
Expand Down Expand Up @@ -582,6 +595,13 @@ def nonzero(fake_mode, func, arg):
# remember, the hypothesis is that if your later code works
# with N >= 2, it will work with N = 1 and N = 0.
maxval = sys.maxsize - 1

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
free_symbols,
)

if not free_symbols(arg.numel()):
# Don't upgrade the range if numel is less than two, since we then
# have an empty range which makes things go explodey. We also
Expand Down Expand Up @@ -611,6 +631,13 @@ def masked_select(fake_mode, func, self, mask):

# see nonzero for commentary
maxval = sys.maxsize - 1

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
free_symbols,
)

if not free_symbols(arg.numel()):
if arg.numel() >= 2:
maxval = int(arg.numel())
Expand Down Expand Up @@ -722,6 +749,7 @@ def conv(fake_mode, func, *args, **kwargs):
k = kwargs["weight"].ndim
batch = kwargs["input"].shape[0]

# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import has_hint

if not has_hint(batch):
Expand Down Expand Up @@ -1250,6 +1278,9 @@ def merge_devices(t):
# of the tensor to create the output Python list, and (2) creating unbacked
# symints for each element of the list.
def tolist(self):
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import is_symbolic

assert self.dim() == 1 and is_symbolic(self.shape[0])
shape_env = self.shape[0].node.shape_env
out = []
Expand Down Expand Up @@ -1819,8 +1850,8 @@ def from_tensor(
static_shapes=None,
ignore_subclass=False,
source: Optional[Source] = None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
constraint_dims: Optional[DimList[DimConstraint]] = None,
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
# Setting this flag will force FakeTensorMode to return `None` if attempting to convert a tensor we have not
# seen before.
memoized_only=False,
Expand Down
4 changes: 3 additions & 1 deletion torch/_subclasses/fake_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
tree_flatten_only,
UnsupportedFakeTensorException,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

Expand Down Expand Up @@ -80,6 +79,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
and torch.Tag.inplace_view not in func.tags
and torch.Tag.data_dependent_output not in func.tags
):
# Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
from torch.fx.experimental.symbolic_shapes import ShapeEnv

try:
# TODO: enable_python_dispatcher() here
with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
Expand Down
13 changes: 9 additions & 4 deletions torch/_subclasses/meta_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import warnings
import weakref
from typing import ContextManager, List, Optional
from typing import ContextManager, List, Optional, TYPE_CHECKING

import torch
from torch._C._functorch import (
Expand All @@ -13,14 +13,18 @@
)
from torch._guards import Source

from torch.fx.experimental.symbolic_shapes import DimConstraint, DimDynamic
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
)
from torch.utils.weak import WeakIdRef

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import DimConstraint, DimDynamic

DimList = List


Expand Down Expand Up @@ -180,8 +184,8 @@ def meta_tensor(
shape_env=None,
callback=lambda t: t(),
source: Optional[Source] = None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
constraint_dims: Optional[DimList[DimConstraint]] = None,
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
):
if source is None:
from torch._dynamo.source import ConstantSource
Expand Down Expand Up @@ -303,6 +307,7 @@ def sym_sizes_strides_storage_offset(t, src):
assert t._is_view()

from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import DimDynamic

if shape_env:
base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
Expand Down
32 changes: 25 additions & 7 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,28 @@
import dataclasses
import inspect
import io
import math
import pathlib
import sys
import typing
from enum import auto, Enum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import sympy
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility

from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

Expand All @@ -28,6 +35,11 @@
UnflattenFunc,
)

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint


__all__ = [
"Constraint",
Expand Down Expand Up @@ -109,13 +121,15 @@ class Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
"""

# NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
constraint_range: StrictMinMaxConstraint
constraint_range: "StrictMinMaxConstraint"
# Represent that `constraint_range` is shared with another _ConstraintTarget, which
# typically arises because of a specified equality with another dynamic dimension.
shared: Optional[_ConstraintTarget] = None
debug_name: Optional[str] = None

def _clone_with_range(self, lower=2, upper=sympy.oo):
def _clone_with_range(self, lower=2, upper=math.inf):
# Import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
from torch.utils._sympy.value_ranges import ValueRanges

constraint_range = StrictMinMaxConstraint(
Expand Down Expand Up @@ -184,6 +198,10 @@ def __eq__(self, other):
"A dynamic dim can be specified equal only to another dynamic dim. "
f"Equality with {type(other)} is not supported."
)

# import sympy locally
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & other.constraint_range.vr,
warn_only=False,
Expand Down

0 comments on commit c8a5bb4

Please sign in to comment.