From 77e058f055696d6ea9c0f314d36d8a5b7ae7ef9d Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 08:47:15 -0800 Subject: [PATCH 001/221] [DTensor] Made `_Partial`, `Replicate` frozen dataclasses (#113919) This is part of the larger stack to work toward being able to cache hashes for `DTensorSpec`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113919 Approved by: https://github.com/wanchaol --- torch/distributed/_tensor/placement_types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 3ddda3b58f4b..f4bd2ceb056c 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -275,6 +275,7 @@ def __str__(self) -> str: return f"S({self.dim})" +@dataclass(frozen=True) class Replicate(Placement): # replicate placement def __eq__(self, other: object) -> bool: @@ -315,6 +316,7 @@ def _replicate_tensor( return tensor +@dataclass(frozen=True) class _Partial(Placement): # This is a default partial placement with element-wise reduce op # when doing reduction it follows the contract of `_to_replicate` @@ -323,9 +325,7 @@ class _Partial(Placement): # # We can implement custom reductions as needed by subclassing this # class and override those contracts. - - def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM): - self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op + reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM def _to_replicate( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int From b41ad7d69540738eeb9132a0a53a28325f30766b Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 08:47:15 -0800 Subject: [PATCH 002/221] [DTensor] Used new placements for neg dim in `redistribute` (#113924) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113924 Approved by: https://github.com/wanchaol ghstack dependencies: #113919 --- torch/distributed/_tensor/api.py | 6 ++++-- torch/distributed/_tensor/placement_types.py | 4 ++++ torch/distributed/_tensor/redistribute.py | 4 ++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index fa2bcddc9659..505a3d7dfa4d 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -420,14 +420,16 @@ def redistribute( if placements is None: raise RuntimeError("placements is needed for redistribute!") - for placement in placements: + placements = list(placements) + for i, placement in enumerate(placements): if placement.is_partial(): raise RuntimeError( "Can not redistribute to _Partial, _Partial is for internal use only!" ) elif isinstance(placement, Shard) and placement.dim < 0: # normalize shard dim to be positive - placement.dim += self.ndim + placements[i] = Shard(placement.dim + self.ndim) + placements = tuple(placements) # Early return the original DTensor if the placements are the same. if self._spec.placements == placements: diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index f4bd2ceb056c..458963a124c8 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -384,6 +384,10 @@ class DTensorSpec: # tensor meta will only be set during sharding propagation tensor_meta: Optional[TensorMeta] = None + def __post_init__(self): + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + def __hash__(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding # propagation results. We only need to consider the mesh, placements, shape diff --git a/torch/distributed/_tensor/redistribute.py b/torch/distributed/_tensor/redistribute.py index c1ee9990a08f..9ba008d5c697 100644 --- a/torch/distributed/_tensor/redistribute.py +++ b/torch/distributed/_tensor/redistribute.py @@ -182,12 +182,12 @@ def forward( # type: ignore[override] ctx, input: "dtensor.DTensor", device_mesh: DeviceMesh, - placements: List[Placement], + placements: Tuple[Placement, ...], ): current_spec = input._spec ctx.current_spec = current_spec target_spec = DTensorSpec( - device_mesh, tuple(placements), tensor_meta=input._spec.tensor_meta + device_mesh, placements, tensor_meta=input._spec.tensor_meta ) local_tensor = input._local_tensor From f4ffd46c081b8719024edd6383bec1cb126d9904 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 08:47:21 -0800 Subject: [PATCH 003/221] [DTensor] Used new placements for neg dim in `from_local` (#114134) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114134 Approved by: https://github.com/wanchaol ghstack dependencies: #113919, #113924 --- torch/distributed/_tensor/api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 505a3d7dfa4d..8999f9a4eed7 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -343,6 +343,14 @@ def from_local( # set default placements to replicated if not specified if placements is None: placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) # `from_local` is differentiable, and the gradient of the dist tensor this function # created should flow back the gradients to the local_tensor, so we call an autograd From e2095a04ae755ae8724a5f052b62a0c8ea707c13 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 08:47:22 -0800 Subject: [PATCH 004/221] [DTensor] Ensured `grad_placements` was tuple (#113925) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113925 Approved by: https://github.com/wanchaol ghstack dependencies: #113919, #113924, #114134 --- torch/distributed/_tensor/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 8999f9a4eed7..ae6b15421727 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -390,6 +390,8 @@ def to_local( .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned will depend on if the `DTensor` requires_grad or not. """ + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) return _ToTorchTensor.apply( self, grad_placements, True ) # pyre-ignore[16]: autograd func From c39c69953f0b60efc4b5a6acf0c6d51ce750c316 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 08:47:22 -0800 Subject: [PATCH 005/221] [DTensor] Used new placements for neg dim in `distribute_tensor` (#113930) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113930 Approved by: https://github.com/wanchaol ghstack dependencies: #113919, #113924, #114134, #113925 --- torch/distributed/_tensor/api.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index ae6b15421727..1235d16ba6cd 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -586,12 +586,14 @@ def distribute_tensor( local_tensor = tensor # distribute the tensor according to the placements. + placements = list(placements) for idx, placement in enumerate(placements): if placement.is_shard(): placement = cast(Shard, placement) if placement.dim < 0: # normalize shard placement dim - placement.dim += tensor.ndim + placement = Shard(placement.dim + tensor.ndim) + placements[idx] = placement local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx) elif placement.is_replicate(): placement = cast(Replicate, placement) @@ -600,6 +602,7 @@ def distribute_tensor( raise RuntimeError( f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" ) + placements = tuple(placements) assert local_tensor is not None, "distributing a tensor should not be None" # detach the local tensor passed to DTensor since after the construction @@ -607,7 +610,7 @@ def distribute_tensor( return DTensor( local_tensor.detach().requires_grad_(tensor.requires_grad), device_mesh, - tuple(placements), + placements, shape=tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, From 4b07fca7d7f761dee3191c301024a290861e2587 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Mon, 20 Nov 2023 10:01:45 -0800 Subject: [PATCH 006/221] [export] Allow shifted constraint ranges in dynamo._export (#114024) Summary: Previously, when we had two dynamic shape symbols `s0` and `s1` bound by the relationship `s1 == s0 + 1`, even when the range constraints were set in accordance with the relationship (e.g., to `[2, 1024]` for `s0` and to `[3, 1025]` for `s1`), `torch._dynamo.export` raised an error saying that the constraint is violated. Here we add a range check between the expression and the constraint and, if the ranges match, don't declare the constraint violated. We also add a flag to disable the dim constraint solver in `torch._dynamo.export` (not set by default for BC), passed down from the `torch._export.aot_compile`. This is because, even for simple constraints like `s1 == s0 + 1`, the solver claims that the constraint is too complex and the dimension `s0` must be specialized. The new flag is not exposed as a part of the public API (i.e., the one without `_`s in the module names). Both changes are required to unblock PT2 compilation of an internal model with AOT Inductor. Test Plan: ``` $ python test/inductor/test_aot_inductor.py -k test_shifted_constraint_ranges s... ---------------------------------------------------------------------- Ran 4 tests in 53.247s OK (skipped=1) ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/114024 Approved by: https://github.com/zhxchen17 --- test/inductor/test_aot_inductor.py | 68 ++++++++++++++++++++++-- torch/_dynamo/eval_frame.py | 7 ++- torch/_export/__init__.py | 15 +++++- torch/fx/experimental/symbolic_shapes.py | 11 +++- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index fa7411dd5cf2..c217fd7d751e 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -60,7 +60,14 @@ class AOTInductorModelRunner: @classmethod - def compile(cls, model, example_inputs, options=None, constraints=None): + def compile( + cls, + model, + example_inputs, + options=None, + constraints=None, + disable_constraint_solver=False, + ): # The exact API is subject to change so_path = torch._export.aot_compile( model, @@ -68,6 +75,7 @@ def compile(cls, model, example_inputs, options=None, constraints=None): options=options, constraints=constraints, remove_runtime_assertions=True, + disable_constraint_solver=disable_constraint_solver, ) return so_path @@ -111,9 +119,21 @@ def optimized(*args): return optimized @classmethod - def run(cls, device, model, example_inputs, options=None, constraints=None): + def run( + cls, + device, + model, + example_inputs, + options=None, + constraints=None, + disable_constraint_solver=False, + ): so_path = AOTInductorModelRunner.compile( - model, example_inputs, options=options, constraints=constraints + model, + example_inputs, + options=options, + constraints=constraints, + disable_constraint_solver=disable_constraint_solver, ) optimized = AOTInductorModelRunner.load(device, so_path, example_inputs) return optimized(example_inputs) @@ -146,6 +166,7 @@ def check_model( example_inputs, options=None, constraints=None, + disable_constraint_solver=False, ): with torch.no_grad(), config.patch( "aot_inductor.abi_compatible", self.abi_compatible @@ -158,7 +179,12 @@ def check_model( torch.manual_seed(0) actual = AOTInductorModelRunner.run( - self.device, model, example_inputs, options, constraints + self.device, + model, + example_inputs, + options, + constraints, + disable_constraint_solver, ) self.assertTrue(same(actual, expected)) @@ -1205,6 +1231,36 @@ def forward(self, x): ] self.check_model(Model(), (a,), constraints=constraints) + def test_shifted_constraint_ranges(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + ): + torch._check(y.size(0) == x.size(0) + 1) + return x.sum(0) + y.sum(0) + + a = torch.randn((4, 5), device=self.device) + b = torch.randn((5, 5), device=self.device) + + constraints = [ + torch._export.dynamic_dim(a, 0) >= 2, + torch._export.dynamic_dim(a, 0) <= 1024, + torch._export.dynamic_dim(b, 0) >= 3, + torch._export.dynamic_dim(b, 0) <= 1025, + ] + + self.check_model( + Model(), + (a, b), + constraints=constraints, + disable_constraint_solver=True, + ) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -1240,6 +1296,10 @@ class AOTInductorTestABICompatibleCpu(TestCase): "test_sdpa": TestFailure(("abi_compatible_cpu",)), "test_sdpa_2": TestFailure(("abi_compatible_cpu",)), "test_simple_dynamic": TestFailure(("abi_compatible_cpu",)), + # error: could not find s0 + "test_shifted_constraint_ranges": TestFailure( + ("abi_compatible_cpu",), is_skip=True + ), }, ) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index b8c95fb168d8..4c0234d106e1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -46,6 +46,7 @@ ) from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.nn.parallel.distributed import DistributedDataParallel + from ..fx import GraphModule from .backends.registry import CompilerFn, lookup_backend @@ -1181,6 +1182,7 @@ def export( constraints: Optional[List[Constraint]] = None, assume_static_by_default: bool = False, same_signature: bool = True, + disable_constraint_solver: bool = False, **extra_kwargs, ) -> Callable[..., ExportResult]: """ @@ -1206,6 +1208,8 @@ def export( same_signature (bool): If True, rewrite the returned graph's signature to be the same as f. + disable_constraint_solver (bool): Whether the dim constraint solver must be disabled. + Returns: A function that given args and kwargs, returns a tuple of (graph, guards) Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. @@ -1341,7 +1345,8 @@ def result_capturing_wrapper(*graph_inputs): remove_from_cache(f) if ( - (shape_env := getattr(fake_mode, "shape_env", None)) is not None + not disable_constraint_solver + and (shape_env := getattr(fake_mode, "shape_env", None)) is not None and (dim_constraints := shape_env.dim_constraints) is not None and not skipfiles.check(call_to_inspect) ): diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index e2a301bbc927..13e366e6d396 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -488,6 +488,7 @@ def _export_to_torch_ir( constraints: Optional[List[Constraint]] = None, *, preserve_module_call_signature: Tuple[str, ...] = (), + disable_constraint_solver: bool = False, ) -> torch.fx.GraphModule: """ Traces either an nn.Module's forward function or just a callable with PyTorch @@ -515,6 +516,7 @@ def _export_to_torch_ir( constraints=constraints, assume_static_by_default=True, tracing_mode="symbolic", + disable_constraint_solver=disable_constraint_solver, )( *args, **kwargs, @@ -604,7 +606,7 @@ def _export( args, kwargs, constraints, - preserve_module_call_signature=preserve_module_call_signature + preserve_module_call_signature=preserve_module_call_signature, ) params_buffers: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {} @@ -924,6 +926,7 @@ def aot_compile( dynamic_shapes: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, ) -> str: """ Note: this function is not stable yet @@ -950,6 +953,8 @@ def aot_compile( options: A dictionary of options to control inductor + disable_constraint_solver: Whether the dim constraint solver must be disabled. + Returns: Path to the generated shared library """ @@ -966,7 +971,13 @@ def aot_compile( # We want to export to Torch IR here to utilize the pre_grad passes in # inductor, which run on Torch IR. - gm = _export_to_torch_ir(f, args, kwargs, constraints) + gm = _export_to_torch_ir( + f, + args, + kwargs, + constraints, + disable_constraint_solver=disable_constraint_solver + ) flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {}) with torch.no_grad(): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 192b33c1e6b4..a6e05906a4ed 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2542,7 +2542,16 @@ def track_symint(source, val, constraint=None): else: constraint_violated = False if isinstance(constraint, StrictMinMaxConstraint): - constraint_violated = True + # try inferring the ranges of the expr s + sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols} + if all(vr is not None for vr in sym_vrs.values()): + expr_vr = bound_sympy(s, sym_vrs) + if (expr_vr != constraint.vr): + # the expr and constrain ranges don't match + constraint_violated = True + else: + # some of the free symbols in s don't have ranges + constraint_violated = True elif isinstance(constraint, RelaxedUnspecConstraint): if s.is_number: i = int(s) From ae00d9623e89168e4147a71f2e15378584630b9b Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Mon, 20 Nov 2023 10:01:46 -0800 Subject: [PATCH 007/221] [inductor] Add ABI shim function for torch.scatter (#114027) Summary: Scatter fallback calls `at::scatter` in the C++ wrapper codegen. This doesn't work in the ABI compatibility mode, as the latter requires a shim function. One is added in this PR. Test Plan: ``` $ python test/inductor/test_aot_inductor.py -k test_scatter_fallback s... ---------------------------------------------------------------------- Ran 4 tests in 52.713s OK (skipped=1) ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/114027 Approved by: https://github.com/chenyang78, https://github.com/desertfire ghstack dependencies: #114024 --- test/inductor/test_aot_inductor.py | 23 +++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 3 +++ torch/csrc/inductor/aoti_torch/c/shim.h | 7 ++++++ .../csrc/inductor/aoti_torch/shim_common.cpp | 16 +++++++++++++ 4 files changed, 49 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index c217fd7d751e..14b9368df302 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1261,6 +1261,27 @@ def forward( disable_constraint_solver=True, ) + def test_scatter_fallback(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + inp: torch.Tensor, + index: torch.Tensor, + src: torch.Tensor, + ): + return torch.scatter(inp, 1, index, src) + + inputs = ( + torch.ones((3, 5), device=self.device, dtype=torch.int64), + torch.tensor([[0, 1, 2, 0]], device=self.device, dtype=torch.int64), + torch.zeros((2, 5), device=self.device, dtype=torch.int64), + ) + + self.check_model(Model(), inputs) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -1300,6 +1321,8 @@ class AOTInductorTestABICompatibleCpu(TestCase): "test_shifted_constraint_ranges": TestFailure( ("abi_compatible_cpu",), is_skip=True ), + # the test segfaults + "test_scatter_fallback": TestFailure(("abi_compatible_cpu",), is_skip=True), }, ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0bbceea258de..9648051e31c0 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1781,6 +1781,9 @@ def generate_scatter_fallback( self, output, inputs, kernel, fn, src_is_tensor, reduce, kwargs ): # TODO: support other overload for cpp wrapper and remove the below assertions + if V.graph.aot_mode and config.aot_inductor.abi_compatible: + # call the ABI shim function instead of the ATen one + kernel = kernel.replace("at::", "aoti_torch_") line = f"{kernel}({output}, {','.join(map(str, inputs))}" if fn == "aten.scatter_": if src_is_tensor: diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index cf2a0bdcbc5f..0922096750d6 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -244,6 +244,13 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( AOTI_TORCH_EXPORT AOTITorchError aoti_check_inf_and_nan(AtenTensorHandle tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out( + AtenTensorHandle out, + AtenTensorHandle self, + int64_t dim, + AtenTensorHandle index, + AtenTensorHandle src); + #ifdef USE_CUDA struct CUDAStreamGuardOpaque; diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index b17f2a14c2b3..c1cf7e2b6332 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #endif @@ -408,6 +409,21 @@ AOTITorchError aoti_check_inf_and_nan(AtenTensorHandle tensor) { }); } +AOTITorchError aoti_torch_scatter_out( + AtenTensorHandle out, + AtenTensorHandle self, + int64_t dim, + AtenTensorHandle index, + AtenTensorHandle src) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out); + at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self); + at::Tensor* index_tensor = tensor_handle_to_tensor_pointer(index); + at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src); + at::scatter_out(*out_tensor, *self_tensor, dim, *index_tensor, *src_tensor); + }); +} + // ProxyExecutor AOTITorchError aoti_torch_proxy_executor_call_function( AOTIProxyExecutorHandle proxy_executor, From 7afceb9f64debad7f950ee336135f6e15a9a7738 Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 20 Nov 2023 23:03:33 +0000 Subject: [PATCH 008/221] [AOTI] add float support of triton (#114014) Summary: As the title Test Plan: buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --exact 'caffe2/test/dynamo:test_dynamo - test_functions.py::DefaultsTests::test_triton_kernel_None_arg' --print-passing-details Differential Revision: D51421325 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114014 Approved by: https://github.com/oulgen, https://github.com/aakhundov --- test/dynamo/test_functions.py | 2 ++ torch/_inductor/codegen/triton_utils.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index f0c3ebb4381b..4c9d70622d22 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2006,6 +2006,7 @@ def pass_kernel( n_elements, dummy_None, dummy_empty, + dummy_float, BLOCK_SIZE: "tl.constexpr", RANDOM_SIZE: "tl.constexpr", ): @@ -2020,6 +2021,7 @@ def call_triton(output): n_elements, None, torch.empty_like(output), + 3.1415926, RANDOM_SIZE=0, ) return output diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 87490fdae356..7fb23d9a6f55 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -34,6 +34,8 @@ def signature_of(arg: Union[TensorArg, SizeArg], *, size_dtype: str) -> str: # From triton/runtime/jit.py # `None` is nullptr. Implicitly convert to *i8. return "*i8" + elif isinstance(arg.expr, float): + return "fp32" if size_dtype == "tl.int32": return "i32" elif size_dtype == "tl.int64": @@ -73,6 +75,8 @@ def is_aligned( return False if x.expr is None: return False + if isinstance(x.expr, float): + return False return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) raise NotImplementedError(f"unhandled {type(x)}: {x}") From 2ca1119d532af0ba385c7b5944b954c9385b4901 Mon Sep 17 00:00:00 2001 From: voznesenskym Date: Mon, 20 Nov 2023 11:47:53 -0800 Subject: [PATCH 009/221] Add Stateful/Stateless symbolic contexts, use fresh fake mode for dynamo backends (#113926) The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor. This PR is the result of *a lot* of back and forth with @ezyang and @eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same: 1) We cache source->symbol in shape_env 2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification 3) We create a new fake mode for backends (from https://github.com/pytorch/pytorch/pull/113605/files) This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't). We went back to the drawing board here, but with a few concessions: 1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons 2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (@ezyang did this) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113926 Approved by: https://github.com/ezyang, https://github.com/eellison --- docs/source/conf.py | 4 +- test/dynamo/test_export.py | 4 +- test/dynamo/test_subclasses.py | 10 +- test/test_dynamic_shapes.py | 4 +- test/test_fake_tensor.py | 4 +- torch/_dynamo/backends/distributed.py | 4 +- torch/_dynamo/eval_frame.py | 4 +- torch/_dynamo/output_graph.py | 11 ++ torch/_dynamo/utils.py | 14 +++ torch/_dynamo/variables/builder.py | 48 +++++--- torch/_functorch/aot_autograd.py | 9 +- torch/_guards.py | 3 + torch/_subclasses/fake_tensor.py | 16 +-- torch/_subclasses/meta_utils.py | 26 ++-- torch/fx/experimental/symbolic_shapes.py | 144 +++++++++++++++++------ 15 files changed, 217 insertions(+), 88 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 308d2e77863e..031ed72c03f9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2871,7 +2871,9 @@ "ShapeGuardPrinter", "StrictMinMaxConstraint", "SymDispatchMode", - "CreateSymbolicPolicy", + "SymbolicContext", + "StatelessSymbolicContext", + "StatefulSymbolicContext", # torch.fx.experimental.unification.match "Dispatcher", "VarDispatcher", diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index b53ab7f6f266..3f6f906cf9b8 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -28,8 +28,8 @@ from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, - FreshCreateSymbolicPolicy, ShapeEnv, + StatelessSymbolicContext, ) from torch.testing._internal import common_utils @@ -3249,7 +3249,7 @@ def test_symbool_guards( ) as fake_mode: fake_x = fake_mode.from_tensor( x, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())], ), ) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index cadc164dd283..53fb7328e529 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -15,8 +15,8 @@ from torch.fx.experimental.symbolic_shapes import ( DimDynamic, - FreshCreateSymbolicPolicy, ShapeEnv, + StatelessSymbolicContext, ) from torch.nested._internal.nested_tensor import ( jagged_from_list, @@ -337,13 +337,13 @@ def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): ) as fake_mode: x_fake = fake_mode.from_tensor( x, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[dim_dynamic for i in range(x.dim())] ), ) x1_fake = fake_mode.from_tensor( x1, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[dim_dynamic for i in range(x.dim())] ), ) @@ -373,7 +373,7 @@ def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): for inp in inps: fake_inp = fake_mode.from_tensor( inp, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( [dim_dynamic for i in range(x.dim())] ), ) @@ -708,7 +708,7 @@ def test_recompilation( ) as fake_mode: fake_inp = fake_mode.from_tensor( x, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] ), ) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index daf293b43d00..bf843587af50 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -27,7 +27,7 @@ GuardOnDataDependentSymNode, ShapeEnv, is_symbolic, - FreshCreateSymbolicPolicy, + StatelessSymbolicContext, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -137,7 +137,7 @@ def create_symbolic_tensor(name, arg, shape_env): shape_env.create_symbolic_sizes_strides_storage_offset( arg, source=ConstantSource(name), - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims ), diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 0b9f895f0a64..14a596508824 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -15,7 +15,7 @@ DynamicOutputShapeException, UnsupportedOperatorException, ) -from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, FreshCreateSymbolicPolicy +from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, StatelessSymbolicContext from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.common_device_type import ops from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes @@ -541,7 +541,7 @@ def test_same_shape_env_preserved(self): mode1 = FakeTensorMode(shape_env=shape_env) t1 = mode1.from_tensor( torch.randn(10), - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None] ) diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index 90cb21c26351..adc68bb30bff 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -409,7 +409,9 @@ def run_node(self, n: Node) -> Any: if isinstance(arg, torch.Tensor) and not isinstance( arg, torch._subclasses.FakeTensor ): - new_args.append(fake_mode.from_tensor(arg)) + new_args.append( + torch._dynamo.utils.to_fake_tensor(arg, fake_mode) + ) else: new_args.append(arg) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 4c0234d106e1..7fff2c3392fc 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -42,7 +42,7 @@ from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, - FreshCreateSymbolicPolicy, + StatelessSymbolicContext, ) from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.nn.parallel.distributed import DistributedDataParallel @@ -903,7 +903,7 @@ def __init__( # TODO(zhxchen17) Also preserve all the user constraints here. arg.node.meta["val"] = fake_mode.from_tensor( flat_args[i], - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[ DimDynamic.DYNAMIC if d in flat_args_dynamic_dims[i] diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index e3a33c1503a8..b577b9ea94aa 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1065,6 +1065,17 @@ def compile_and_call_fx_graph(self, tx, rv, root): "%s", LazyString(lambda: self.get_graph_sizes_log_str(name)) ) self.call_cleanup_hooks() + old_fake_mode = self.tracing_context.fake_mode + if not self.export: + # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting + backend_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=old_fake_mode.shape_env, + ) + # TODO(voz): Ostensibily, this should be scoped and + # restore back to old_fake_mode, but doing so currently violates + # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode + self.tracing_context.fake_mode = backend_fake_mode + with self.restore_global_state(): compiled_fn = self.call_user_compiler(gm) compiled_fn = disable(compiled_fn) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index aa0719a3ab56..ba876a0fbb82 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2295,3 +2295,17 @@ def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bo isinstance(vt, UserDefinedObjectVariable) and hasattr(vt.value, "__torch_function__") ) + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + ) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index be66d51c0f4d..a139efd3e166 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -26,11 +26,11 @@ from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, - CreateSymbolicPolicy, DimConstraint, DimDynamic, - FreshCreateSymbolicPolicy, RelaxedUnspecConstraint, + StatefulSymbolicContext, + SymbolicContext, ) from torch.fx.immutable_collections import immutable_list from torch.nested._internal.nested_tensor import NestedTensor @@ -1564,23 +1564,33 @@ def __eq__(self, other: object) -> bool: # Performs automatic dynamic dim determination. -# Returns a CreateSymbolicPolicy -def _automatic_dynamic(e, tx, name, static_shapes) -> CreateSymbolicPolicy: +# Returns a SymbolicContext +def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext: + name = source.name() + prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) + source_to_symint_node_cache = ( + prior_policy.source_to_symint_node_cache if prior_policy else None + ) + if static_shapes: - return FreshCreateSymbolicPolicy( + return StatefulSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * e.dim(), constraint_sizes=[None] * e.dim(), + tensor_source=source, + source_to_symint_node_cache=source_to_symint_node_cache, ) # We preserve the dynamism of inputs. For example, when users call # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. if any(isinstance(s, SymInt) for s in e.size()): - return FreshCreateSymbolicPolicy( + return StatefulSymbolicContext( dynamic_sizes=[ DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC for s in e.size() ], constraint_sizes=[None] * e.dim(), + tensor_source=source, + source_to_symint_node_cache=source_to_symint_node_cache, ) # Prep for automatic dynamic @@ -1699,7 +1709,7 @@ def update_dim2constraint(dim, constraint_range, debug_name): # Now, figure out if the dim is dynamic/duck/static if constraint_dim is not None or marked_dynamic or marked_weak_dynamic: # NB: We could assert static_shapes is False here, but it - # seems better to allow the user to override policy in this + # seems better to allow the user to override symbolic_context in this # case dynamic = DimDynamic.DYNAMIC elif static_shapes or config.assume_static_by_default or marked_static: @@ -1711,12 +1721,15 @@ def update_dim2constraint(dim, constraint_range, debug_name): tx.output.frame_state[name] = frame_state_entry - return FreshCreateSymbolicPolicy( + return StatefulSymbolicContext( dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims, + tensor_source=source, + source_to_symint_node_cache=source_to_symint_node_cache, ) +# See note [Tensor Fakification and Symbol Caching] def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor: bool): if ( type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) @@ -1728,31 +1741,36 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor e, is_tensor, guard_source=source.guard_source() ) - policy = None + symbolic_context = None if not e.is_nested: # TODO: We should probably support this for nested tensors too - policy = _automatic_dynamic(e, tx, source.name(), static_shapes) + symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) + + if symbolic_context: + tx.output.tracing_context.tensor_to_context[e] = symbolic_context log.debug( "wrap_to_fake %s %s %s %s", source.name(), tuple(e.shape), - policy.dynamic_sizes if policy is not None else None, - policy.constraint_sizes if policy is not None else None, + symbolic_context.dynamic_sizes if symbolic_context is not None else None, + symbolic_context.constraint_sizes if symbolic_context is not None else None, ) fake_e = wrap_fake_exception( lambda: tx.fake_mode.from_tensor( e, source=source, - policy=policy, + symbolic_context=symbolic_context, ) ) - # TODO: just store the whole policy here + # TODO: just store the whole symbolic_context here tx.output.tracked_fakes.append( TrackedFake( fake_e, source, - policy.constraint_sizes if policy is not None else None, + symbolic_context.constraint_sizes + if symbolic_context is not None + else None, ) ) tx.output.tracked_fakes_id_to_source[id(e)].append(source) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index c09ab6ba9b94..cefb2e826a8c 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -4381,14 +4381,17 @@ def convert(idx, x): if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs) return x - # TODO: Ensure that this codepath is never exercised from - # Dynamo + + if ( idx < aot_config.num_params_buffers and config.static_weight_shapes ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo return fake_mode.from_tensor(x, static_shapes=True) - return fake_mode.from_tensor(x, static_shapes=False) + + return torch._dynamo.utils.to_fake_tensor(x, fake_mode) return [convert(idx, x) for idx, x in enumerate(flat_args)] diff --git a/torch/_guards.py b/torch/_guards.py index fe3a10d663b7..69912b15313d 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -29,6 +29,7 @@ import torch from torch.utils import _pytree as pytree from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakTensorKeyDictionary log = logging.getLogger(__name__) @@ -618,6 +619,8 @@ def __init__(self, fake_mode): # ints that are known to be size-like and may have 0/1 entries that we # must not specialize on. self.force_unspec_int_unbacked_size_like = False + # See note [Tensor Fakification and Symbol Caching] + self.tensor_to_context = WeakTensorKeyDictionary() @staticmethod @contextmanager diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index b36bc4c5bf8b..e4e676c9d8be 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -312,7 +312,7 @@ def from_real_tensor( shape_env=None, *, source=None, - policy=None, + symbolic_context=None, memoized_only=False, ): maybe_memo = self._get_memo(t) @@ -348,7 +348,7 @@ def mk_fake_tensor(make_meta_t): shape_env=shape_env, callback=mk_fake_tensor, source=source, - policy=policy, + symbolic_context=symbolic_context, ) if out is NotImplemented: raise UnsupportedFakeTensorException("meta converter nyi") @@ -383,7 +383,7 @@ def __call__( make_constant=False, shape_env=None, source=None, - policy=None, + symbolic_context=None, memoized_only=False, ): return self.from_real_tensor( @@ -392,7 +392,7 @@ def __call__( make_constant, shape_env=shape_env, source=source, - policy=policy, + symbolic_context=symbolic_context, memoized_only=memoized_only, ) @@ -1855,7 +1855,7 @@ def from_tensor( *, static_shapes=None, source: Optional[Source] = None, - policy=None, + symbolic_context=None, # Setting this flag will force FakeTensorMode to return `None` if attempting to convert a tensor we have not # seen before. memoized_only=False, @@ -1864,14 +1864,16 @@ def from_tensor( if static_shapes is None: static_shapes = self.static_shapes if static_shapes: - assert policy is None, "cannot set both static_shapes and policy" + assert ( + symbolic_context is None + ), "cannot set both static_shapes and symbolic_context" shape_env = None return self.fake_tensor_converter( self, tensor, shape_env=shape_env, source=source, - policy=policy, + symbolic_context=symbolic_context, memoized_only=memoized_only, ) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 1ff2a156379d..8db8f94b1b41 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Import the following modules during type checking to enable code intelligence features, # Do not import unconditionally, as they import sympy and importing sympy is very slow - from torch.fx.experimental.symbolic_shapes import CreateSymbolicPolicy + from torch.fx.experimental.symbolic_shapes import SymbolicContext DimList = List @@ -184,7 +184,7 @@ def meta_tensor( shape_env=None, callback=lambda t: t(), source: Optional[Source] = None, - policy: Optional["CreateSymbolicPolicy"] = None, + symbolic_context: Optional["SymbolicContext"] = None, ): from torch._subclasses.fake_tensor import FakeTensor @@ -250,10 +250,10 @@ def sym_sizes_strides_storage_offset( # the wrapper tensor and any inner tensors. # We can revisit this if this assumption does not hold # for any important subclasses later. - policy=policy, + symbolic_context=symbolic_context, ) else: - assert policy is None + assert symbolic_context is None return (t.size(), t.stride(), t.storage_offset()) # see expired-storages @@ -315,22 +315,22 @@ def sym_sizes_strides_storage_offset( from torch._dynamo.source import AttrSource from torch.fx.experimental.symbolic_shapes import ( DimDynamic, - FreshCreateSymbolicPolicy, + StatelessSymbolicContext, ) if shape_env and not t.is_nested and not t._base.is_nested: - base_policy = FreshCreateSymbolicPolicy( + base_symbolic_context = StatelessSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * t._base.dim(), constraint_sizes=[None] * t._base.dim(), ) else: - base_policy = None + base_symbolic_context = None base = self.meta_tensor( t._base, shape_env, callback, source=AttrSource(source, "_base"), - policy=base_policy, + symbolic_context=base_symbolic_context, ) def is_c_of_r(complex_dtype, real_dtype): @@ -620,7 +620,7 @@ def empty_create(inner_t, inner_src): shape_env, callback, source=AttrSource(source, "grad"), - policy=policy, + symbolic_context=symbolic_context, ) torch._C._set_conj(r, t.is_conj()) torch._C._set_neg(r, t.is_neg()) @@ -637,7 +637,7 @@ def __call__( *, callback=lambda t: t(), source=None, - policy=None, + symbolic_context=None, ): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now @@ -682,7 +682,7 @@ def __call__( shape_env=shape_env, callback=callback, source=source, - policy=policy, + symbolic_context=symbolic_context, ) out = torch._to_functional_tensor(fake_t) torch._mirror_autograd_meta_to(fake_t, out) @@ -700,7 +700,7 @@ def __call__( shape_env=shape_env, callback=callback, source=source, - policy=policy, + symbolic_context=symbolic_context, ) return _wrap_functional_tensor(fake_t, current_level()) self.miss += 1 @@ -712,7 +712,7 @@ def __call__( shape_env=shape_env, callback=callback, source=source, - policy=policy, + symbolic_context=symbolic_context, ) if type(t) is torch.nn.Parameter: # NB: Cannot directly use Parameter constructor diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a6e05906a4ed..e7e9573afd61 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -63,7 +63,7 @@ class GuardOnDataDependentSymNode(RuntimeError): "guard_int", "guard_float", "guard_scalar", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", - "has_free_symbols", "sym_eq", "CreateSymbolicPolicy", "FreshCreateSymbolicPolicy", + "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", "StatefulSymbolicContext" ] # FX node metadata keys for symbolic shape FX graph. @@ -721,8 +721,14 @@ def render(self): def is_equal(self, source1, source2): return self._find(source1) == self._find(source2) + +def _assert_symbol_context(symbolic_context): + assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" + assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" + + @dataclass(frozen=True) -class CreateSymbolicPolicy: +class SymbolicContext: """ Data structure specifying how we should create symbols in ``create_symbolic_sizes_strides_storage_offset``; e.g., should @@ -736,20 +742,67 @@ class CreateSymbolicPolicy: @dataclass(frozen=True) -class FreshCreateSymbolicPolicy(CreateSymbolicPolicy): +class StatelessSymbolicContext(SymbolicContext): """ Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via - a policy determination as given by ``DimDynamic`` and ``DimConstraint``. + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. This will cause fresh symbols to be allocated """ dynamic_sizes: DimList[DimDynamic] constraint_sizes: DimList[DimConstraint] = None - # TODO: add storage offset and stride policy + # TODO: add storage offset and stride symbolic_context def __post_init__(self): if self.constraint_sizes is None: object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes)) + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read if we cache + hit. + + It is the cache owners responsibility to maintain the lifecycle of the cache + w/r/t different shape_envs, clearing, etc. + """ + tensor_source: Source = None + source_to_symint_node_cache : Dict["TensorPropertySource", SymInt] = None + + def __post_init__(self): + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.source_to_symint_node_cache: + object.__setattr__(self, 'source_to_symint_node_cache', {}) + + def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: if isinstance(val, (int, float, bool)): return False @@ -1922,20 +1975,20 @@ def _update_version_counter(self): def _produce_dyn_sizes(self, ex_size: Sequence[int], source: Source, - policy: CreateSymbolicPolicy + symbolic_context: SymbolicContext ) -> List[sympy.Expr]: - return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, policy) + return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context) def _produce_dyn_sizes_from_int_tuple(self, tensor_size: Tuple[int], source: Source, - policy: CreateSymbolicPolicy, + symbolic_context: SymbolicContext, ) -> List[sympy.Expr]: assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}" from torch._dynamo.source import TensorPropertySource, TensorProperty - assert isinstance(policy, FreshCreateSymbolicPolicy) - dynamic_dims = policy.dynamic_sizes - constraint_dims = policy.constraint_sizes + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes + constraint_dims = symbolic_context.constraint_sizes size = [] for i, val in enumerate(tensor_size): size.append(self.create_symbol( @@ -1948,7 +2001,7 @@ def create_symbolic_sizes_strides_storage_offset( ex: torch.Tensor, source: Source, *, - policy: Optional[CreateSymbolicPolicy] = None, + symbolic_context: Optional[SymbolicContext] = None, ): """ Returns a list of symbolic sizes and strides for the given tensor. @@ -2010,7 +2063,7 @@ def maybe_specialize_sym_int_with_hint(maybe_sym) -> int: ex_storage_offset, [_is_dim_dynamic(ex, i) for i in range(ex.dim())], source, - policy=policy, + symbolic_context=symbolic_context, ) @record_shapeenv_event() @@ -2022,12 +2075,12 @@ def _create_symbolic_sizes_strides_storage_offset( is_dim_dynamic: Sequence[bool], source: Source, *, - policy: Optional[CreateSymbolicPolicy] = None, + symbolic_context: Optional[SymbolicContext] = None, ): dim = len(ex_size) # Reimplement the legacy behavior - if policy is None: + if symbolic_context is None: constraint_dims = [None] * dim dynamic_dims = [] for i in range(dim): @@ -2041,13 +2094,14 @@ def _create_symbolic_sizes_strides_storage_offset( r = DimDynamic.DUCK dynamic_dims.append(r) dynamic_dims = [DimDynamic.DUCK] * dim - policy = FreshCreateSymbolicPolicy(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims) - - assert isinstance(policy, FreshCreateSymbolicPolicy) - constraint_dims = policy.constraint_sizes - dynamic_dims = policy.dynamic_sizes - - # TODO: make this configurable from outside policy; we made a policy + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_dims = symbolic_context.constraint_sizes + dynamic_dims = symbolic_context.dynamic_sizes + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context # decision here where if all sizes are static, we are going to # specialize all of the inner strides/offset too. We don't have to # do this. @@ -2058,7 +2112,7 @@ def _create_symbolic_sizes_strides_storage_offset( assert len(constraint_dims) == dim from torch._dynamo.source import TensorPropertySource, TensorProperty - size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, policy) + size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context) stride: List[Optional[sympy.Expr]] = [None] * len(size) for i, val in enumerate(ex_stride): if val in (0, 1): @@ -2096,7 +2150,12 @@ def _create_symbolic_sizes_strides_storage_offset( assert all(x is not None for x in stride) sym_sizes = [ - self.create_symintnode(sym, hint=hint, source=TensorPropertySource(source, TensorProperty.SIZE, i)) + self.create_symintnode( + sym, + hint=hint, + source=TensorPropertySource(source, TensorProperty.SIZE, i), + symbolic_context=symbolic_context + ) for i, (sym, hint) in enumerate(zip(size, ex_size)) ] sym_stride = [] @@ -2105,14 +2164,17 @@ def _create_symbolic_sizes_strides_storage_offset( # we computed assert stride_expr is not None sym_stride.append(self.create_symintnode( - stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i) - )) - sym_storage_offset = self.create_symintnode(self.create_symbol( - ex_storage_offset, - TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), - dynamic_dim=DimDynamic.DYNAMIC, - constraint_dim=None, - ), hint=ex_storage_offset, source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)) + stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i), + symbolic_context=symbolic_context)) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=DimDynamic.DYNAMIC, + constraint_dim=None, + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), symbolic_context=symbolic_context) return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset # If you know what the current hint value of the SymInt to be created @@ -2125,7 +2187,10 @@ def create_symintnode( *, hint: Optional[int], source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, ): + source_name = source.name() if source else None + if self._translation_validation_enabled and source is not None: # Create a new symbol for this source. symbol = self._create_symbol_for_source(source) @@ -2139,11 +2204,20 @@ def create_symintnode( else: fx_node = None + # see note [Tensor Fakification and Symbol Caching] + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + if source_name in symbolic_context.source_to_symint_node_cache: + return symbolic_context.source_to_symint_node_cache[source_name] + if isinstance(sym, sympy.Integer): if hint is not None: assert int(sym) == hint - return int(sym) - return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + out = int(sym) + else: + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.source_to_symint_node_cache[source_name] = out + return out @record_shapeenv_event() def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim): @@ -2238,7 +2312,7 @@ def create_symbol( assert isinstance(source, Source), f"{type(source)} {source}" assert not (positive and val < 0), f"positive set for negative value: {val}" # It's always sound to allocate a symbol as DYNAMIC. If the user - # constrained the symbol, force the policy to DYNAMIC, because our + # constrained the symbol, force the symbolic_context to DYNAMIC, because our # constraint code will do weird stuff if, e.g., it's duck shaped if constraint_dim is not None: dynamic_dim = DimDynamic.DYNAMIC From c1d9d4a2b51c9fd351cfc16d16df0d7f5829d893 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 20 Nov 2023 23:08:40 +0000 Subject: [PATCH 010/221] checkpoint_sequential warns if use_reentrant not passed explicitly (#114158) Use warning text for deprecation message. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114158 Approved by: https://github.com/albanD --- test/test_autograd.py | 24 +++++++++++++++++++++++- torch/utils/checkpoint.py | 15 +++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 40ebf5b86a4f..e7dddeba57ae 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -27,7 +27,7 @@ from torch.autograd.function import once_differentiable from torch.autograd.profiler import (profile, record_function, emit_nvtx, emit_itt) from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg) -from torch.utils.checkpoint import checkpoint +from torch.utils.checkpoint import checkpoint, checkpoint_sequential from torch.testing import make_tensor from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import ( @@ -5989,6 +5989,28 @@ def test_checkpoint_warns_if_use_reentrant_not_passed_explcitly(self): str(w[0].message) ) + def test_checkpoint_sequential_warns_if_use_reentrant_not_passed_explcitly(self): + a = torch.randn(3, requires_grad=True) + modules_list = [ + torch.nn.Linear(3, 3), + torch.nn.Linear(3, 3), + torch.nn.Linear(3, 3) + ] + + # Passing explicitly should not warn + with warnings.catch_warnings(record=True) as w: + checkpoint_sequential(modules_list, 3, a, use_reentrant=False) + self.assertEqual(len(w), 0) + + # Not passing explicitly warns + with warnings.catch_warnings(record=True) as w: + checkpoint_sequential(modules_list, 3, a) + self.assertEqual(len(w), 1) + self.assertIn( + "please pass in use_reentrant=True or use_reentrant=False explicitly", + str(w[0].message) + ) + def test_checkpoint_detects_non_determinism(self): def save_3_tensors(x): out = x.sin().exp() diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 25f5583cce42..42f1ef0a35ec 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -495,13 +495,13 @@ def checkpoint( return ret -def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs): +def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): r"""Checkpoint a sequential model to save memory. Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will not store - the intermediate activations. The inputs of each checkpointed segment will + the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass. .. warning:: @@ -539,6 +539,17 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwar >>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var) """ + if use_reentrant is None: + warnings.warn( + "torch.utils.checkpoint.checkpoint_sequential: please pass in " + "use_reentrant=True or use_reentrant=False explicitly. The default " + "value of use_reentrant will be updated to be False in the future. " + "To maintain current behavior, pass use_reentrant=True. It is " + "recommended that you use use_reentrant=False. Refer to docs for " + "more details on the differences between the two variants." + ) + use_reentrant = True + # Hack for keyword-only parameter in a python 2.7-compliant way preserve = kwargs.pop("preserve_rng_state", True) if kwargs: From 4182092febfe11b7a655730d2f89f3d2d1fa28fd Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Mon, 20 Nov 2023 23:16:18 +0000 Subject: [PATCH 011/221] [reland][HigherOrderOp] remove _deprecated_global_ns (#113813) This is a reland of #112757. Cannot land original one internally because internal diff is not in sync with OSS due to issues in dealing with two export repos (executorch and pytorch) using the ghimport-ghexport approach. Will try the web UI of import and export instead of ghimport and ghexport flow. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113813 Approved by: https://github.com/angelayi --- torch/_higher_order_ops/map.py | 2 +- torch/_ops.py | 20 ++++---------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 8e578c0e1cd0..78cfc71bc562 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -34,7 +34,7 @@ def __call__(self, xs, *args): return map_wrapper(xs, *args) -map = MapWrapper("map", _deprecated_global_ns=True) +map = MapWrapper("map") map_impl = HigherOrderOperator("map_impl") dummy_aot_config = AOTConfig( diff --git a/torch/_ops.py b/torch/_ops.py index 29694c0bc7d1..828c605fc924 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -231,7 +231,6 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}") -_global_higher_order_ops = {} _higher_order_ops = {} _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [ @@ -245,25 +244,19 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] class HigherOrderOperator(OperatorBase): - # _deprecated_global_ns: Whether or not the HigherOrderOperator appears as: - # (True) torch.ops.{name} - # (False) torch.ops.higher_order.{name} + # The HigherOrderOperator will appear as torch.ops.higher_order.{name} # # If you're creating a new HigherOrderOperator, please do not change the # default. Adding operators to the global torch.ops namespace is a bad # practice due to name collisions. - def __init__(self, name, *, _deprecated_global_ns=False): + def __init__(self, name): super().__init__() self._name = name # Make _OPNamespace not scream, this whole name based association needs a good hard look self.__name__ = name - if _deprecated_global_ns: - _global_higher_order_ops[name] = self - self._ns = None - else: - _higher_order_ops[name] = self - self._ns = "higher_order" + _higher_order_ops[name] = self + self._ns = "higher_order" # For a normal HigherOrderOperator instance, we will change its __module__ from torch._ops to # torch._ops.higher_order. @@ -871,9 +864,6 @@ class _Ops(types.ModuleType): def __init__(self): super().__init__("torch.ops") self.loaded_libraries = set() - self._global_higher_order_op_namespace = _PyOpNamespace( - "torch.ops", _global_higher_order_ops - ) self._higher_order_op_namespace = _PyOpNamespace( "torch.ops.higher_order", _higher_order_ops ) @@ -881,8 +871,6 @@ def __init__(self): def __getattr__(self, name): # Check if the name is a HigherOrderOperator - if name in self._global_higher_order_op_namespace._ops: - return getattr(self._global_higher_order_op_namespace, name) if name == "higher_order": return self._higher_order_op_namespace From e4a88d958114d01a226e502c7b621e7bdbdc5d9f Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 20 Nov 2023 23:35:36 +0000 Subject: [PATCH 012/221] Convert SymInts to SymFloats with SymPy (#113683) Fixes #109365 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113683 Approved by: https://github.com/ezyang, https://github.com/lezcano --- test/dynamo/test_functions.py | 8 ++++++++ test/test_ops.py | 2 +- torch/fx/experimental/sym_node.py | 14 ++++++++++++-- torch/fx/experimental/symbolic_shapes.py | 17 ++++++++++------- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 4c9d70622d22..cb2a1ae3a4dd 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1252,6 +1252,14 @@ def test_partials_lambda(x): triple = functools.partial(multiply, y=3) return triple(x) + def test_pow_int(self): + def fn(a, b): + return torch.pow(a, b) + + x = torch.ones(2, 2) + opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn) + self.assertEqual(opt_fn(x, 2), fn(x, 2)) + def test_tensor_size_indexed_by_symint(self): def fn(x, y): index = x.shape[-1] diff --git a/test/test_ops.py b/test/test_ops.py index bc9f4a74322a..4ce8e59058ce 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -263,7 +263,7 @@ def get_opoverloadpacket_from_dispatch(kernel): def test_numpy_ref(self, device, dtype, op): if ( TEST_WITH_TORCHINDUCTOR and - op.formatted_name == 'signal_windows_exponential' and + op.formatted_name in ('signal_windows_exponential', 'signal_windows_bartlett') and dtype == torch.float64 and 'cuda' in device ): # noqa: E121 raise unittest.SkipTest("XXX: raises tensor-likes are not close.") diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 852389e3cde6..08985d0efb17 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -343,7 +343,9 @@ def guard_int(self, file, line): def guard_float(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred - r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + r = self.shape_env.evaluate_expr( + self.expr, self.hint, fx_node=self.fx_node, expect_rational=False + ) try: return float(r) except Exception: @@ -628,6 +630,14 @@ def _sympy_abs(a): return sympy.Abs(a) +def _sympy_sym_float(a): + # Cannot use sympy.Float(a) here, coz it expects python literals + # Multiply by 1.0 to cast to float. This is needed when the input + # is a SymInt which has the assumption that it is integer and + # SymPy will otherwise assume that return value cannot be a float. + return a * 1.0 + + magic_methods = { **reflectable_magic_methods, "sym_not": lambda a: ~a, @@ -638,7 +648,7 @@ def _sympy_abs(a): "le": _sympy_le, "ge": _sympy_ge, "floor": _sympy_floor, - "sym_float": lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals + "sym_float": _sympy_sym_float, "ceil": _sympy_ceil, "neg": lambda a: -a, "sym_min": _sympy_min, diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e7e9573afd61..7f056ec9d5a7 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -3029,7 +3029,8 @@ def get_shape_groups(self): @_lru_cache def _maybe_evaluate_static( - self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False + self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False, + expect_rational=True, ) -> "Optional[sympy.Expr]": """ Tries to evaluate expr without introducing guards @@ -3121,10 +3122,10 @@ def replace(expr, repl): # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - _assert_bound_is_rational(new_expr, out) - - if out.is_singleton(): - return out.lower + if expect_rational: + _assert_bound_is_rational(new_expr, out) + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -3450,7 +3451,8 @@ def _log_guard(self, prefix: str, g): @lru_cache(256) @record_shapeenv_event(save_tracked_fakes=True) - def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None): + def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, + expect_rational=True): """ Given an expression, evaluates it, adding guards if necessary """ @@ -3510,7 +3512,8 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None): expr = orig_expr - static_expr = self._maybe_evaluate_static(expr) + static_expr = self._maybe_evaluate_static(expr, + expect_rational=expect_rational) if static_expr is not None: self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr) # NB: don't test float as there may be precision issues From e7f12b1eb0cedfd20dcb41ea35e21e9a71e3390a Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 20 Nov 2023 15:41:31 -0500 Subject: [PATCH 013/221] Print the index and summary of the SampleInput that failed an OpInfo test (#99444) Related to the Reproducible Testing BE project. Goal is to print out the sample input that failed an OpInfo test. Crazy idea: to avoid requiring widespread changes across tests that use OpInfo sample inputs, return a new special iterator type from `OpInfo.sample_inputs()`, etc. that tracks the most recent item seen. If a test fails later on, print out this info to identify the sample that failed the test. This solves the problem that the test framework currently has no concept of which sample input is being operated on. This PR contains the following changes: * New `TrackedInputIter` that wraps a sample inputs func iterator and tracks the most recent input seen in a `TrackedInput` structure * The information is stored in a dictionary on the test function itself, mapping `full test ID -> most recent TrackedInput` * To determine the test function that is being run, we do some stack crawling hackery in `extract_test_fn_and_id()` * Above applies only when one of the following is called: `OpInfo.sample_inputs()`, `OpInfo.error_inputs()`, `OpInfo.reference_inputs()`, and `OpInfo.conjugate_sample_inputs()`. This could easily be extended to `ModuleInfo`s and the sparse sample input funcs as well Example output when a sample input causes a failure: ``` ====================================================================== ERROR: test_foo_add_cpu_uint8 (__main__.TestFakeTensorCPU) ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 911, in test_wrapper return test(*args, **kwargs) File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 1097, in only_fn return fn(slf, *args, **kwargs) File "/home/jbschlosser/branches/reproducible_testing/test/test_ops.py", line 2211, in test_foo self.fail('Example failure') AssertionError: Example failure The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_utils.py", line 2436, in wrapper method(*args, **kwargs) File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 414, in instantiated_test result = test(self, **param_kwargs) File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 917, in test_wrapper raise Exception( Exception: Caused by sample input at index 2: SampleInput(input=Tensor[size=(5, 1), device="cpu", dtype=torch.uint8], args=TensorList[Tensor[size=(5,), device="cpu", dtype=torch.uint8]], kwargs={}, broadcasts_input=True, name='') To execute this test, run the following from the base repo dir: python test/test_ops.py -k test_foo_add_cpu_uint8 This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ---------------------------------------------------------------------- ``` This notably doesn't print the actual `SampleInput` values, as that's hard without fully reproducible random sample generation. I went down this path for a while and it seems infeasible without adding an untenable amount of overhead to set the random seed per SampleInput (see https://github.com/pytorch/pytorch/issues/86694#issuecomment-1614943708 for more details). For now, I am settling for at least spitting out the index and some metadata of the `SampleInput`, as it seems better than nothing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99444 Approved by: https://github.com/janeyx99 --- test/test_testing.py | 8 +- torch/testing/_internal/common_device_type.py | 27 ++++++- torch/testing/_internal/common_utils.py | 76 +++++++++++++++++++ torch/testing/_internal/opinfo/core.py | 32 ++++++-- 4 files changed, 129 insertions(+), 14 deletions(-) diff --git a/test/test_testing.py b/test/test_testing.py index feb408773f4c..542601d7ed97 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -12,7 +12,7 @@ import subprocess import sys import unittest.mock -from typing import Any, Callable, Iterator, List, Tuple, Generator +from typing import Any, Callable, Iterator, List, Tuple import torch @@ -2397,19 +2397,19 @@ class TestOpInfoSampleFunctions(TestCase): def test_opinfo_sample_generators(self, device, dtype, op): # Test op.sample_inputs doesn't generate multiple samples when called samples = op.sample_inputs(device, dtype) - self.assertIsInstance(samples, Generator) + self.assertIsInstance(samples, Iterator) @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one) def test_opinfo_reference_generators(self, device, dtype, op): # Test op.reference_inputs doesn't generate multiple samples when called samples = op.reference_inputs(device, dtype) - self.assertIsInstance(samples, Generator) + self.assertIsInstance(samples, Iterator) @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) def test_opinfo_error_generators(self, device, op): # Test op.error_inputs doesn't generate multiple inputs when called samples = op.error_inputs(device) - self.assertIsInstance(samples, Generator) + self.assertIsInstance(samples, Iterator) instantiate_device_type_tests(TestOpInfoSampleFunctions, globals()) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 96b7817b5c4a..4b550e95187d 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -15,7 +15,8 @@ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, \ _TestParametrizer, compose_parametrize_fns, dtype_name, \ - TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo + TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \ + get_tracked_input, PRINT_REPRO_ON_FAILURE from torch.testing._internal.common_cuda import _get_torch_cuda_version, \ TEST_CUSPARSE_GENERIC, TEST_HIPSPARSE_GENERIC, _get_torch_rocm_version from torch.testing._internal.common_dtype import get_all_dtypes @@ -796,6 +797,12 @@ class OpDTypes(Enum): torch.bool ) +def _serialize_sample(sample_input): + # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. + if getattr(sample_input, "summary", None) is not None: + return sample_input.summary() + return str(sample_input) + # Decorator that defines the OpInfos a test template should be instantiated for. # # Example usage: @@ -905,7 +912,23 @@ def _parametrize_test(self, test, generic_cls, device_cls): try: @wraps(test) def test_wrapper(*args, **kwargs): - return test(*args, **kwargs) + try: + return test(*args, **kwargs) + except unittest.SkipTest as e: + raise e + except Exception as e: + tracked_input = get_tracked_input() + if PRINT_REPRO_ON_FAILURE and tracked_input is not None: + raise Exception( + f"Caused by {tracked_input.type_desc} " + f"at index {tracked_input.index}: " + f"{_serialize_sample(tracked_input.val)}") from e + raise e + + # Initialize info for the last input seen. This is useful for tracking + # down which inputs caused a test failure. Note that TrackedInputIter is + # responsible for managing this. + test.tracked_input = None decorator_fn = partial(op.get_decorators, generic_cls.__name__, test.__name__, device_cls.device_type, dtype) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 30f0311ba7b3..5149261f9935 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -36,6 +36,7 @@ from collections.abc import Mapping, Sequence from contextlib import closing, contextmanager from copy import deepcopy +from dataclasses import dataclass from enum import Enum from functools import partial, wraps from itertools import product, chain @@ -237,6 +238,81 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +# Tries to extract the current test function by crawling the stack. +# If unsuccessful, return None. +def extract_test_fn() -> Optional[Callable]: + try: + stack = inspect.stack() + for frame_info in stack: + frame = frame_info.frame + if "self" not in frame.f_locals: + continue + self_val = frame.f_locals["self"] + if isinstance(self_val, unittest.TestCase): + test_id = self_val.id() + test_name = test_id.split('.')[2] + test_fn = getattr(self_val, test_name).__func__ + return test_fn + except Exception: + pass + return None + +# Contains tracked input data useful for debugging purposes +@dataclass +class TrackedInput: + index: int + val: Any + type_desc: str + +# Attempt to pull out tracked input information from the test function. +# A TrackedInputIter is used to insert this information. +def get_tracked_input() -> Optional[TrackedInput]: + test_fn = extract_test_fn() + if test_fn is None: + return None + if not hasattr(test_fn, "tracked_input"): + return None + return test_fn.tracked_input + +# Wraps an iterator and tracks the most recent value the iterator produces +# for debugging purposes. Tracked values are stored on the test function. +class TrackedInputIter: + def __init__(self, child_iter, input_type_desc, callback=lambda x: x): + self.child_iter = enumerate(child_iter) + # Input type describes the things we're tracking (e.g. "sample input", "error input"). + self.input_type_desc = input_type_desc + # Callback is run on each iterated thing to get the thing to track. + self.callback = callback + self.test_fn = extract_test_fn() + + def __iter__(self): + return self + + def __next__(self): + try: + input_idx, input_val = next(self.child_iter) + self._set_tracked_input( + TrackedInput( + index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc + ) + ) + return input_val + except StopIteration as e: + self._clear_tracked_input() + raise e + + def _set_tracked_input(self, tracked_input: TrackedInput): + if self.test_fn is None: + return + if not hasattr(self.test_fn, "tracked_input"): + return + self.test_fn.tracked_input = tracked_input + + def _clear_tracked_input(self): + if self.test_fn is not None and hasattr(self.test_fn, "tracked_input"): + self.test_fn.tracked_input = None + self.test_fn = None + class _TestParametrizer: """ Decorator class for parametrizing a test function, yielding a set of new tests spawned diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index fc0fbf95864f..23b6e89e4a21 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -29,6 +29,7 @@ noncontiguous_like, TEST_WITH_ROCM, torch_to_numpy_dtype_dict, + TrackedInputIter, ) from torch.testing._internal.opinfo import utils @@ -207,7 +208,6 @@ def _repr_helper(self, formatter): f"input={formatter(self.input)}", f"args={formatter(self.args)}", f"kwargs={formatter(self.kwargs)}", - f"output_process_fn_grad={self.output_process_fn_grad}", f"broadcasts_input={self.broadcasts_input}", f"name={repr(self.name)}", ] @@ -227,8 +227,15 @@ def formatter(arg): # by Tensor[TensorShape] # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] if isinstance(arg, torch.Tensor): - shape = str(tuple(arg.shape)).replace("(", "").replace(")", "") - return f"Tensor[{shape}]" + shape = str(tuple(arg.shape)) + dtype = str(arg.dtype) + device = str(arg.device) + contiguity_suffix = "" + # NB: sparse CSR tensors annoyingly return is_sparse=False + is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr + if not is_sparse and not arg.is_contiguous(): + contiguity_suffix = ", contiguous=False" + return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]' elif isinstance(arg, dict): return {k: formatter(v) for k, v in arg.items()} elif is_iterable_of_tensors(arg): @@ -1155,7 +1162,7 @@ def conjugate(tensor): else: sample.input[0] = conjugate(sample.input[0]) - return tuple(conj_samples) + return TrackedInputIter(iter(conj_samples), "conjugate sample input") def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): """ @@ -1174,7 +1181,7 @@ def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): samples_list.extend(conj_samples) samples = tuple(samples_list) - return samples + return TrackedInputIter(iter(samples), "sample input") def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): """ @@ -1185,18 +1192,27 @@ def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): the sample inputs. """ if self.reference_inputs_func is None: - return self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + samples = self.sample_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter(iter(samples), "sample input") if kwargs.get("include_conjugated_inputs", False): raise NotImplementedError - return self.reference_inputs_func(self, device, dtype, requires_grad, **kwargs) + references = self.reference_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter(iter(references), "reference input") def error_inputs(self, device, **kwargs): """ Returns an iterable of ErrorInputs. """ - return self.error_inputs_func(self, device, **kwargs) + errs = self.error_inputs_func(self, device, **kwargs) + return TrackedInputIter( + iter(errs), "error input", callback=lambda e: e.sample_input + ) def error_inputs_sparse(self, device, layout, **kwargs): """ From d70857bd9e425ffc5b1a32639fb6651462eeef97 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 21 Nov 2023 00:45:48 +0000 Subject: [PATCH 014/221] [pytorch][lite interpreter] add tracer run under inference guard (#114003) Summary: This can change the ops called under the hood. Its not safe to always call because of on device training. Test Plan: ci Differential Revision: D51440119 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114003 Approved by: https://github.com/Jack-Khuu --- torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp index e665713e7db9..585747c14d82 100644 --- a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp +++ b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp @@ -343,6 +343,16 @@ TracerResult trace_run(const std::vector& input_module_paths) { << "ModelTracer encountered an error while attempting to run the model in FBGEMM mode" << ex.what() << "\n Skipping FBGEMM execution" << std::endl; } + try { + at::globalContext().setQEngine(at::QEngine::QNNPACK); + c10::InferenceMode guard(true); + run_model( + input_module_path, root_ops, enabled_backends, called_kernel_tags); + } catch (std::exception& ex) { + std::cerr + << "ModelTracer encountered an error while attempting to run the model under an inference guard" + << ex.what() << "\n Skipping inference guard execution" << std::endl; + } } call_dependent_methods(root_ops); From fb25fd6f865ed0532caf710ca130b6cc23a772a8 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 13:44:21 -0800 Subject: [PATCH 015/221] [DTensor] Replaced neg dim normalization with assert in helper (#114141) This is a replacement for https://github.com/pytorch/pytorch/pull/113922. I think we can still leave the check for negative shard dimension in `compute_local_shape_and_global_offset` and replace the normalization logic with an assert. This should provide us a stack trace to see which user-facing API did not normalize the dim as expected. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114141 Approved by: https://github.com/wanchaol ghstack dependencies: #113919, #113924, #114134, #113925, #113930 --- torch/distributed/_tensor/_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/_tensor/_utils.py index 10d9b11c51b1..04b714c7789a 100644 --- a/torch/distributed/_tensor/_utils.py +++ b/torch/distributed/_tensor/_utils.py @@ -145,8 +145,10 @@ def compute_global_tensor_info( if placement.is_shard(): shard_placement = cast(Shard, placement) if shard_placement.dim < 0: - # normalize shard dim to be positive - shard_placement.dim += len(tensor_shape) + raise AssertionError( + "Shard placements should have negative dims normalized in " + f"the user-facing APIs: {shard_placement}" + ) shard_dim = shard_placement.dim assert ( From 3e49621f3b4652b8e7782aa8dafb28f9d985598b Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 20 Nov 2023 13:44:22 -0800 Subject: [PATCH 016/221] [DTensor] Cached hash for `DTensorSpec` (#113915) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Overview** Generally, I think we can try to freeze as many of these classes used in DTensor sharding propagation as possible so that we can cache hashes. This PR targets hashing `DTensorSpec`, which turns out to be relatively expensive. **Details** It looks like `tensor_meta` is only updated in `_wrap_output_spec_tensor_meta`, which only runs if the propagation was not cached: https://github.com/pytorch/pytorch/blob/ae94c7e491e22f58d3df66571c1a568e51d70acd/torch/distributed/_tensor/sharding_prop.py#L137 https://github.com/pytorch/pytorch/blob/ae94c7e491e22f58d3df66571c1a568e51d70acd/torch/distributed/_tensor/sharding_prop.py#L153 In that case, I think we can cache the hash for the `DTensorSpec` and only update it when one of the hashed attributes changes, which we only really expect to happen for `tensor_meta`. To ensure correctness, we need that all hashed attributes are immutable. - `DeviceMesh` caches its hash: https://github.com/pytorch/pytorch/blob/a9134fa99a8986adf478a12db2ea5729d24554db/torch/distributed/_device_mesh.py#L181 - This PR makes each `Placement` a frozen `dataclass`, making them immutable (relying on the fact that they do not have references to any mutable objects). - `TensorMeta` is a `NamedTuple` of `torch.Size`, `Tuple[int, ...]`, and `torch.dtype`, so it is immutable: https://github.com/pytorch/pytorch/blob/9916d8a9eaaf2c05c131f2a2dbe9eabeeaa9dffc/torch/distributed/_tensor/placement_types.py#L369-L375 **Example** For some simple small GPT model: Before: 0.125 ms Screenshot 2023-11-16 at 10 08 05 PM After: 0.048 ms Screenshot 2023-11-16 at 10 08 47 PM The overall Adam CPU step time decreases from 7.647 ms to 6.451 ms. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113915 Approved by: https://github.com/wanchaol ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141 --- torch/distributed/_tensor/placement_types.py | 25 +++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 458963a124c8..7c1c61c21bf2 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass -from typing import cast, List, NamedTuple, Optional, Tuple +from typing import Any, cast, List, NamedTuple, Optional, Tuple import torch import torch.distributed._functional_collectives as funcol @@ -28,10 +28,10 @@ def is_partial(self) -> bool: return isinstance(self, _Partial) +@dataclass(frozen=True) class Shard(Placement): # shard placement, shard on a dim - def __init__(self, dim): - self.dim = dim + dim: int def _split_tensor( self, @@ -387,8 +387,16 @@ class DTensorSpec: def __post_init__(self): if not isinstance(self.placements, tuple): self.placements = tuple(self.placements) + self._hash = self._hash_impl() - def __hash__(self) -> int: + def __setattr__(self, attr: str, value: Any): + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh` or `placements` to change) + if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): + self._hash = self._hash_impl() + + def _hash_impl(self): # hashing and equality check for DTensorSpec are used to cache the sharding # propagation results. We only need to consider the mesh, placements, shape # dtype and stride. @@ -404,8 +412,13 @@ def __hash__(self) -> int: self.tensor_meta.dtype, ) ) - else: - return hash((self.mesh, self.placements)) + return hash((self.mesh, self.placements)) + + def __hash__(self) -> int: + # We eagerly cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. + return self._hash def __eq__(self, __o: object) -> bool: if not ( From 6ec344b08fa5460bcd9a3d4d0e2b7dd4d9eb8d28 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Mon, 20 Nov 2023 18:18:21 +0000 Subject: [PATCH 017/221] Fix empty cpu tensor output in cudagraph (#114144) We can ignore empty cpu tensors Differential Revision: [D51472324](https://our.internmc.facebook.com/intern/diff/D51472324) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114144 Approved by: https://github.com/davidberard98 --- test/inductor/test_cudagraph_trees.py | 13 +++++++++++++ torch/_inductor/cudagraph_trees.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index d950a12048cf..50b006599c7f 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -808,6 +808,19 @@ def foo(x): # didnt do additional recordings self.assertTrue(self.get_manager().new_graph_id().id == 2) + def test_empty_cpu_tensor(self): + def foo(x): + return x @ x, torch.tensor([]) + + foo_opt = torch.compile(foo) + x = torch.rand([4], device="cuda") + + for _ in range(3): + out_opt = foo_opt(x) + self.assertEqual(foo(x), out_opt) + + self.assertTrue(self.get_manager().new_graph_id().id == 1) + def test_output_alias(self): inp = torch.rand([20, 20], device="cuda") diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 6abef96e0921..73a0deb2772b 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -1137,7 +1137,7 @@ def _add_first_outputs( continue torch._check( - o.is_cuda, + o.is_cuda or o.untyped_storage().data_ptr() == 0, lambda: ( "Expected all cuda outputs in cuda graph recording. Non cuda output " f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" From 585332fb8d5131c6483d94bd3dbef5a7aac75ad9 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 21 Nov 2023 01:29:00 +0000 Subject: [PATCH 018/221] [ProcessGroupNCCL] Fix avoid-record-stream warning for P2P (#114168) I have been seen below warning even though I did not set `TORCH_NCCL_AVOID_RECORD_STREAMS` to 1. ``` Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator()) ``` Turns out that `TORCH_WARN_ONCE` is unconditional, so the original code below would print out both the value of `avoidRecordStreams_` and the error message: ``` TORCH_WARN_ONCE( avoidRecordStreams_, "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " "collectives."); ``` That's also where the "0" in the message came from. Cc: @eqy Pull Request resolved: https://github.com/pytorch/pytorch/pull/114168 Approved by: https://github.com/eqy, https://github.com/fduwjj, https://github.com/H-Huang --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index f6b47774316a..a2f2a7f86353 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2492,10 +2492,11 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // when it's safe to release the input back to the allocator, // and the present call has no way to know it's not an isend. // Therefore, we warn and fall back to the typical recordStream logic: - TORCH_WARN_ONCE( - avoidRecordStreams_, - "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " - "collectives."); + if (avoidRecordStreams_) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " + "collectives."); + } // Bump sequence number, updated in collective() as well seq_++; From 85ce8a602b6187533330d23af578d78b3f2321ce Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 21 Nov 2023 01:29:30 +0000 Subject: [PATCH 019/221] Pin pywavelets to 1.4.1 (scikit-image dependency) (#114146) This is to prevent pip from pulling in 1.22.4 and fails Docker image builds, for example, https://github.com/pytorch/pytorch/actions/runs/6923861547/job/18842791777 The new package was released on Nov 17th https://pypi.org/project/PyWavelets/1.5.0/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/114146 Approved by: https://github.com/malfet, https://github.com/kit1980 --- .ci/docker/requirements-ci.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index d2e28d8fdfb3..25be26621985 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -292,3 +292,9 @@ tensorboard==2.13.0 #Description: Also included in .ci/docker/requirements-docs.txt #Pinned versions: #test that import: test_tensorboard + +pywavelets==1.4.1 +#Description: This is a requirement of scikit-image, we need to pin +# it here because 1.5.0 conflicts with numpy 1.21.2 used in CI +#Pinned versions: 1.4.1 +#test that import: From 77f16eb00cb8b4d4ead76f284d0de0c6e8f63764 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 20 Nov 2023 15:10:29 -0300 Subject: [PATCH 020/221] Fix prod double backward when there are 2+ zeros (#113969) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113969 Approved by: https://github.com/albanD --- torch/csrc/autograd/FunctionsManual.cpp | 2 +- torch/testing/_internal/common_methods_invocations.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index d8d76bfe8b51..8a63e7ceb3a8 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -825,7 +825,7 @@ Tensor prod_backward( Tensor zero_idx = (input == 0).nonzero(); if (zero_idx.sym_numel() == 0) { return grad * (result / input).conj(); - } else if (zero_idx.size(0) > 1) { + } else if (!at::GradMode::is_enabled() && zero_idx.size(0) > 1) { return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a0a82e168375..33595ad51bb2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6000,6 +6000,7 @@ def prod_single_zero(): yield SampleInput(make_arg((3, 0)), args=(1,)) yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True}) + yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad)) # test zero scalar tensor zero = make_arg(()) From e8996055a91ba91b512890848817ffe98cfbad1d Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 21 Nov 2023 01:35:54 +0000 Subject: [PATCH 021/221] [iOS][PTMCoreMLCompiler] update other deprecated function (#114177) Summary: old way was deprecated Test Plan: ci Reviewed By: kirklandsign Differential Revision: D51172622 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114177 Approved by: https://github.com/kirklandsign --- torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm index 9fbaff4dbc69..e72feca35351 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm @@ -11,7 +11,7 @@ @implementation PTMCoreMLCompiler static NSString *gVersionExtension = @"version"; + (void)setCacheDirectory:(const std::string&)dir { - gCacheDirectory = [NSString stringWithCString:dir.c_str()]; + gCacheDirectory = [NSString stringWithCString:dir.c_str() encoding:NSUTF8StringEncoding]; } + (nonnull NSString *)cacheDirectory { From 81f93991d3be078fada6dcaf939c2fb0bdb67e63 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 21 Nov 2023 01:36:49 +0000 Subject: [PATCH 022/221] Update merge rule to allow pytorchbot to land ExecuTorch hash update (#114180) The bot cannot merge the hash update PR otherwise, for example https://github.com/pytorch/pytorch/pull/114008#issuecomment-1818032181. I also need to move ExecuTorch jobs in trunk to pull to match the rule without the need to add `ciflow/trunk` label. The test job takes less than 20 minutes to finish atm on `2xlarge`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114180 Approved by: https://github.com/seemethere, https://github.com/ZainRizvi, https://github.com/malfet --- .github/merge_rules.yaml | 12 ++++++++++++ .github/workflows/pull.yml | 20 ++++++++++++++++++++ .github/workflows/trunk.yml | 20 -------------------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index c6c776cc90f5..f7d62cfdd6b6 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -85,6 +85,18 @@ - Lint - pull +- name: OSS CI /pytorchbot / Executorch + patterns: + - .ci/docker/ci_commit_pins/executorch.txt + approved_by: + - pytorchbot + ignore_flaky_failures: false + mandatory_checks_name: + - EasyCLA + - Lint + - pull / linux-jammy-py3-clang12-executorch / build + - pull / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, linux.2xlarge) + - name: OSS CI / pytorchbot / XLA patterns: - .github/ci_commit_pins/xla.txt diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index b773190f1b3e..cbf8d06a6db0 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -393,3 +393,23 @@ jobs: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.test-matrix }} + + linux-jammy-py3-clang12-executorch-build: + name: linux-jammy-py3-clang12-executorch + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-py3-clang12-executorch + docker-image-name: pytorch-linux-jammy-py3-clang12-executorch + test-matrix: | + { include: [ + { config: "executorch", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + ]} + + linux-jammy-py3-clang12-executorch-test: + name: linux-jammy-py3-clang12-executorch + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-py3-clang12-executorch-build + with: + build-environment: linux-jammy-py3-clang12-executorch + docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 6f7db8718de3..13d864084eb6 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -175,26 +175,6 @@ jobs: { config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge.nonephemeral" }, ]} - linux-jammy-py3-clang12-executorch-build: - name: linux-jammy-py3-clang12-executorch - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-jammy-py3-clang12-executorch - docker-image-name: pytorch-linux-jammy-py3-clang12-executorch - test-matrix: | - { include: [ - { config: "executorch", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - ]} - - linux-jammy-py3-clang12-executorch-test: - name: linux-jammy-py3-clang12-executorch - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-py3-clang12-executorch-build - with: - build-environment: linux-jammy-py3-clang12-executorch - docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} - linux-focal-rocm5_7-py3_8-build: name: linux-focal-rocm5.7-py3.8 uses: ./.github/workflows/_linux-build.yml From a911b4db9d82238a1d423e2b4c0a3d700217f0c1 Mon Sep 17 00:00:00 2001 From: voznesenskym Date: Mon, 20 Nov 2023 11:47:55 -0800 Subject: [PATCH 023/221] AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554) This should be enough to get @voznesenskym 's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything properly no-op out. Main changes are: (1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break) (2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call. (3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`. (4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same). I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation() ` and (new) `has_data_mutation()`, which can more accurately distinguish between data-mutation vs. `set_()` calls vs. metadata-mutation **This PR is still silently correct in one case though**, which I'd like to discuss more. In particular, this example: ``` def f(x): x_view = x.view(-1) x.set_(torch.ones(2)) x_view.mul_(2) return ``` If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases: (a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input (b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like: ``` def functionalized_f(x): x_view = x.view(-1) # set_() desugars into a no-op; later usages of x will use x_output x_output = torch.ones(2) # functionalize the mutation on x_view x_view_updated = x.mul(2) x_updated = x_view_updated.view(x.shape) # x experienced TWO TYPES of mutations; a data mutation and a metatadata mutation # We need to return both updated tensors in our graph return x_updated, x_output def runtime_wrapper(x): x_data_mutation_result, x_set_mutation_result = compiled_graph(x) # First, perform the data mutation on x's old storage x.copy_(x_data_mutation_result) # Then, swap out the storage of x with the new storage x.set_(x_set_mutation_result) ``` There are two things that make this difficult to do though: (1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated. (2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying. It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554 Approved by: https://github.com/ezyang ghstack dependencies: #113926 --- aten/src/ATen/FunctionalTensorWrapper.cpp | 29 +++ aten/src/ATen/FunctionalTensorWrapper.h | 14 ++ aten/src/ATen/FunctionalizeFallbackKernel.cpp | 25 +++ test/functorch/test_aotdispatch.py | 126 +++++++++++-- torch/_dynamo/variables/tensor.py | 8 + torch/_functorch/aot_autograd.py | 171 ++++++++++++------ .../python_torch_functions_manual.cpp | 51 ++++++ torchgen/gen_functionalization_type.py | 6 +- 8 files changed, 358 insertions(+), 72 deletions(-) diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 7a6c5c41632e..5ab225467766 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -232,6 +232,35 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) { mutation_counter_++; } +bool FunctionalTensorWrapper::has_data_mutation() { + // Current tensor's data was mutated if its storage saw any mutations. + return functional_storage_impl()->generation() > 0; +} + +void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) { + // self.set_(src) will cause self to have all of the tensor properties of self. + value_ = other->value_; + generation_ = other->generation_; + view_metas_ = other->view_metas_; + // FREEZE the old storage, preventing mutations to it. + // this is a huge pain to handle properly in all cases, so we ban it. + functional_storage_impl()->freeze(); + // Unsafely swap out the storage with other's storage, + // disconnecting `self` with its view chain + storage_ = other->storage_; + /// explicitly mark the tensor as having its storage changed from set_() + // Otherwise, we don't actually have a 100% accurate way to check this. + // (We could check if the updated value has a new storage than the original value, + // but this won't also let us uniquely determine if the tensor **also** + // experienced a data mutation). + was_storage_changed_ = true; + + auto sizes_ = value_.sym_sizes(); + auto strides_ = value_.sym_strides(); + auto storage_offset_ = value_.sym_storage_offset(); + set_sizes_and_strides(sizes_, strides_, storage_offset_); +} + void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { // Note [resize_() in functionalization pass] // resize_() is a special operator in functionalization because it can reallocate its underlying storage. diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index 3d899038c1e7..7b22ceeb01a6 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -122,6 +122,18 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // tensor by replaying the views off of the alias. void mutate_view_meta(at::functionalization::ViewMeta meta); + // Custom implementation of self.set_(src) + void set__impl(const FunctionalTensorWrapper* other); + + // Returns whether the current tensor's data was ever mutated + bool has_data_mutation(); + // + // Returns whether the current FunctionalTensorWrapper + // experienced a set_() call. + bool was_storage_changed() { + return was_storage_changed_; + } + // The functionalization pass can be used to remove mutations. // It does so by replacing any mutation op with it's corresponding // out-of-place op, followed by a call to replace_(). e.g: @@ -195,6 +207,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { uint64_t mutation_hidden_from_autograd_counter_ = 0; bool has_metadata_mutation_ = false; bool is_multi_output_view_ = false; + // Did the tensor experience a set_() call. + bool was_storage_changed_ = false; size_t generation_ = 0; std::vector view_metas_; diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index 3e9e234db45a..783a925d6983 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -299,6 +299,28 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt return out; } +static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) { + // error case + TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src), + "set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed"); + + TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src), + "set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue"); + + // nop case + if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) { + at::AutoDispatchSkipFunctionalize guard; + return self.set_(src); + } + + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src)); + auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); + auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src); + self_impl->set__impl(src_impl); + return self; +} + TORCH_LIBRARY_IMPL(_, Functionalize, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>()); } @@ -310,4 +332,7 @@ TORCH_LIBRARY_IMPL(aten, Functionalize, m) { m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy)); m.impl("_to_copy", TORCH_FN(_to_copy_functionalize)); m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize)); + // The overloads of set_() that take in a storage should never + // appear with torch.compile, because dynamo graph breaks + m.impl("set_.source_Tensor", TORCH_FN(set__functionalize)); } diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 5da8308205c4..ab310c247abe 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -455,6 +455,97 @@ def forward(self, primals_1): mul_1 = torch.ops.aten.mul.Tensor(mul, 3) return [mul, mul_1]""") + def test_input_mutation_set__input_mutation(self): + def f(a): + b = torch.arange(9, dtype=a.dtype).reshape(3, 3) + with torch.no_grad(): + a.set_(b) + return a * b + inp = [torch.ones(3, 3, requires_grad=True)] + self.verify_aot_autograd(f, inp, test_mutation=True) + inp = [torch.ones(3, 3, requires_grad=False)] + self.verify_aot_autograd(f, inp, test_mutation=True) + + def test_set__steals_view_chain(self): + def f(a, b): + a_ = a.mul(2) + b_ = b.mul(2) + b_slice = b_[1].view(3, 3) + # a_clone should inherit the view chain from b_slice + a_.set_(b_slice) + # Also mutates b_, + a_.view(-1).mul_(2) + return a_ * b_slice + inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 9, requires_grad=False)] + self.verify_aot_autograd(f, inp) + + def test_set__and_data_mutation_good(self): + def f(a, b): + # The data mutation happens *after* the set_(). This is ok (see the graph below) + with torch.no_grad(): + a.set_(b) + b.mul_(2) + return a + b + inp = [torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True)] + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) + inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 3, requires_grad=False)] + self.verify_aot_autograd(f, inp, test_mutation=True) + # Important things to note: + # - "return a.set_(b)" desugars into "return b" + # - Both a and b are recorded as experiencing mutations, + # which is why we see "b_updated" (output of the mul) twice in the graph outputs. + # a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage). + # - the runtime epilogue for a is "a.set_(mul)" + # - the runtime epilogue for b is "b.copy_(mul)" + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + clone = torch.ops.aten.clone.default(primals_2); primals_2 = None + mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None + add = torch.ops.aten.add.Tensor(mul, mul) + return [mul, mul, add]""") + + # This is a (hopefully) extremely rare case that is difficult to handle, + # so we ban it. + def test_set__and_data_mutation_bad(self): + def f(a): + a_view = a.view(-1) + tmp = torch.ones(3, 3, requires_grad=True) + # Now, any mutations on either tmp + # will be tracked as graph input mutations. + with torch.no_grad(): + a.set_(tmp) + # BAD: a_view is now detached from every graph input, + # so we won't recognize that this caused an input mutation! + a_view.mul_(2) + return a + tmp + inp = [torch.ones(3, 3, requires_grad=True)] + with self.assertRaisesRegex(RuntimeError, "cannot mutate tensors with frozen storage"): + self.verify_aot_autograd(f, inp, test_mutation=True) + + def test_input_mutation_set__nop(self): + def f(a): + b = torch.arange(9, dtype=a.dtype) + a_old = torch.ops.aten.alias.default(a) + with torch.no_grad(): + a.set_(b) + a.set_(a_old) + return a + b.reshape(3, 3) + inp = [torch.ones(3, 3, requires_grad=True)] + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) + inp = [torch.ones(3, 3, requires_grad=False)] + self.verify_aot_autograd(f, inp, test_mutation=True) + # Things to note: + # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b") + # - There is only **1** graph output. We properly realized that the two set_() calls + # undo each other, and so effectively no inputs are mutated. + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + alias = torch.ops.aten.alias.default(primals_1); primals_1 = None + view = torch.ops.aten.view.default(arange, [3, 3]); arange = None + add = torch.ops.aten.add.Tensor(alias, view); alias = view = None + return [add]""") + def test_input_mutation_simple_with_none_and_nontensor(self): # Tensor, None, int def f(a, b, c): @@ -1624,10 +1715,9 @@ def inp_callable(req_grad): # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base. self.assertExpectedInline(fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): - view = torch.ops.aten.view.default(primals_1, [4]); primals_1 = None - t = torch.ops.aten.t.default(view); view = None - add = torch.ops.aten.add.Tensor(t, primals_2); primals_2 = None - return [t, add]""") + t = torch.ops.aten.t.default(primals_1); primals_1 = None + add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None + return [add]""") def test_input_mutation_aliases_and_none_require_gradients(self): def f(a, b, c): @@ -1666,7 +1756,7 @@ def test_input_mutation_aliases_bases_out_of_order(self): # So we don't need to do the base construction / deconstruction def f(a, b, c, d): b.add_(1) - d.t_() + d.unsqueeze_(0) return a + c + d, b.view(-1) def inp_callable(req_grad): @@ -1695,11 +1785,11 @@ def forward(self, primals_1, primals_2, primals_3): as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) - t_1 = torch.ops.aten.t.default(as_strided_3); as_strided_3 = None - add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_3, 0); as_strided_3 = None + add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None as_strided_11 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) - view_1 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None - return [as_strided_scatter, add_2, view_1, t_1]""") # noqa: B950 + view_2 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None + return [as_strided_scatter, add_2, view_2, unsqueeze_1]""") # noqa: B950 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_synthetic_base_base_attribute_is_none(self): @@ -1937,7 +2027,7 @@ def f(x, y): def test_dupe_arg_torture(self): def f(x, y): x.t_() - y.t_() + y.unsqueeze_(0) return x + y x = torch.randn(3, 3, requires_grad=True).clone() @@ -1998,8 +2088,8 @@ def test_invalid_dupe_fake(self, counter): def _test_invalid_dupe(self, counter, fake): class F(torch.nn.Module): def forward(self, x, y): - x.t_() - y.t_() + x.unsqueeze_(0) + y.unsqueeze_(0) return (x + y,) x = torch.randn(3, 3, requires_grad=True).clone() @@ -2018,6 +2108,8 @@ def forward(self, x, y): fxy = aot_module_simplified(F(), (x, y), nop) fxy(x, y) + x = torch.randn(3, 3, requires_grad=True).clone() + y = torch.randn(3, 3, requires_grad=True).clone() fxy(x, x) # is ok! if fake: @@ -2025,9 +2117,13 @@ def forward(self, x, y): else: fxx = aot_module_simplified(F(), (x, x), nop) + x = torch.randn(3, 3, requires_grad=True).clone() + y = torch.randn(3, 3, requires_grad=True).clone() fxx(x, x) # Note This should not raise! Once we have guards in place here, # we will have this working correctly, as it should recompile. + x = torch.randn(3, 3, requires_grad=True).clone() + y = torch.randn(3, 3, requires_grad=True).clone() self.assertExpectedRaisesInline( AssertionError, lambda: fxx(x, y), """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 @@ -2648,7 +2744,7 @@ def fn(p, x): x.t_() return (x * 2,) mod = TestMod(fn) - inp = torch.randn(2) + inp = torch.randn(2, 4) with self.assertRaisesRegex( RuntimeError, "Found an input that received a metadata mutation" ): @@ -3357,7 +3453,7 @@ def f(a, b): def test_aot_dispatch_input_metadata_mutation(self): def f(a, b): a.t_() - b.t_() + b.unsqueeze_(0) return a + b b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) @@ -3402,7 +3498,7 @@ def f(a, b): def test_aot_dispatch_input_data_and_metadata_mutation(self): def f(a, b): a.t_() - b.t_() + b.unsqueeze_(0) a.mul_(2) b.mul_(3) return a + b diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 5fb93f9337e0..a87eec6cfbbf 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -605,6 +605,14 @@ def has_bool_key(v): elif name in ("resize_", "resize_as_"): # Handling resizing in its full generality is difficult. unimplemented(f"Tensor.{name}") + elif name == "set_" and len(args) > 1: + # torch.Tensor.set_() has several overloads. + # aten::set_.source_Tensor(Tensor) gets special handling + # in AOTAutograd and functionalization, because it is the most common + # overload and is used by FSDP. + # graph-breaking on aten::set_source_Tensor_storage_offset for now, + # unless we find that we need to make it work. + unimplemented("Tensor.set_.source_Tensor_storage_offset") elif ( name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs ): diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index cefb2e826a8c..0057902bbe83 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -490,9 +490,19 @@ class InputAliasInfo: mutates_data: bool mutates_metadata: bool mutations_hidden_from_autograd: bool + # This can only happen from a call to aten.set_() on a graph input. + mutates_storage_metadata: bool requires_grad: bool mutation_type: MutationType + def __post_init__(self): + if self.mutates_storage_metadata: + # For convenience, we guarantee that this is always true. + # In practice, If we call .set_(), then at runtime there is no need + # to additionally fix up the tensor metadata, since our runtime + # call to inp.set_(updated_inp) will already have the right metadata + assert self.mutates_metadata + @dataclasses.dataclass class SubclassCreationMeta: @@ -949,65 +959,85 @@ def is_fun(t): # t here is either # (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) # (2) A traceable tensor subclass that holds a FunctionalTensor -def has_metadata_mutation(t): +# (3) Not a tensor +def has_data_mutation(t): if is_traceable_wrapper_subclass(t): attrs, _ = t.__tensor_flatten__() # A tensor subclass was updated if any of its inner elements were updated - return any(has_metadata_mutation(getattr(t, attr)) for attr in attrs) + return any(has_data_mutation(getattr(t, attr)) for attr in attrs) else: - assert isinstance(t, FunctionalTensor) - return torch._functionalize_has_metadata_mutation(t.elem) + if isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_has_data_mutation(t.elem) + return False def are_all_mutations_hidden_from_autograd(t): if is_traceable_wrapper_subclass(t): attrs, _ = t.__tensor_flatten__() # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd. return all(are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs) - else: + elif isinstance(t, torch.Tensor): assert isinstance(t, FunctionalTensor) return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem) - -# new_arg and arg here are either: -# (1) both a FakeTensor -# (2) both a traceable tensor subclass that holds a FakeTensor -# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. -# When we run functionalization and wrap our inputs into FunctionalTensors, -# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed -# -# Normally it would be enough just to check if arg is new_arg, which is normally enough for functionalization -# to confirm that inputs were not mutated when running the user's model with functionalization on. -# But when we have subclass inputs, we can't rely on that: -# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs -# a brand new subclass instance: we are calling __tensor_unflatten__, and going -# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor)) -def was_updated(arg, new_arg): - if is_traceable_wrapper_subclass(arg): - assert is_traceable_wrapper_subclass(new_arg) - attrs, _ = arg.__tensor_flatten__() - new_attrs, _ = new_arg.__tensor_flatten__() - assert attrs == new_attrs - # A tensor subclass was updated if any of its inner elements were updated - return any(was_updated(getattr(arg, attr), getattr(new_arg, attr)) for attr in attrs) else: - return arg is not new_arg - -# new_arg and arg here are either: -# (1) both a FakeTensor -# (2) both a traceable tensor subclass that holds a FakeTensor -# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. -# When we run functionalization and wrap our inputs into FunctionalTensors, -# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed, -# but shares storage with the old input -def was_metadata_updated(arg, new_arg): - if is_traceable_wrapper_subclass(arg): - assert is_traceable_wrapper_subclass(new_arg) - attrs, _ = arg.__tensor_flatten__() - new_attrs, _ = new_arg.__tensor_flatten__() - assert attrs == new_attrs + return False + +# f_arg here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +# Assumption: arg promises to be the "original" tensor wrapped by f_arg +# Note: "storage mutations" coming from set_() are a type of metadata mutation. So: +# - check_only_storage_mutation=True: only return true if there was a storage mutation +# - check_only_storage_mutation=Flse: return true if there was any metadata mutation (including a storage mutation) +def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool): + if is_traceable_wrapper_subclass(f_arg): + attrs, _ = f_arg.__tensor_flatten__() # A tensor subclass was updated if any of its inner elements were updated - return any(was_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) for attr in attrs) + f_inner_ts = [getattr(f_arg, attr) for attr in attrs] + inner_ts = [getattr(arg, attr) for attr in attrs] + return any(has_metadata_mutation(f_inner_t, inner_t, check_only_storage_mutation=check_only_storage_mutation) + for f_inner_t, inner_t in zip(f_inner_ts, inner_ts)) else: - return arg is not new_arg and StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()) + if not isinstance(f_arg, torch.Tensor): + assert not isinstance(arg, torch.Tensor) + return False + assert isinstance(f_arg, FunctionalTensor) + assert isinstance(arg, FakeTensor) + + arg_after = torch._from_functional_tensor(f_arg.elem) + # This is true if the current tensor experienced at least one set_() call + maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem) + # However, multiple set_() calls can cancel out. So we also check whether the + # storage of the tensor has changed. + # Note: if an input experienced two set_() calls that cancel out, **and** + # it experiences an data mutation, we pessimistically think that the set_() + # call is necessary here. We could in theory fix this, but this will + # hopefully never happen in user code, and is not needed for fsdp. + same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(arg_after.untyped_storage()) + has_storage_metadata_mutation = maybe_storage_changed and not same_storages + if check_only_storage_mutation: + return has_storage_metadata_mutation + + # storage metadata mutation is a type of metadata mutation, so return true if we saw one + if has_storage_metadata_mutation: + return True + + maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem) + # This is true if the current tensor experienced at least one metadata mutation. + # So if false, we know there was no metadata mutation + if not maybe_metadata_mutated: + return False + + # However, multi metadata mutations can cancel out. + # So we also check if the concrete sizes/strides on the tensor have changed. + same_sizes = arg.shape == arg_after.shape + same_strides = arg.stride() == arg_after.stride() + same_offsets = arg.storage_offset() == arg_after.storage_offset() + has_metadata_mutation_ = maybe_metadata_mutated and not (same_sizes and same_strides and same_offsets) + # We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call. + return has_metadata_mutation_ + def _get_hints(exprs): """ @@ -1135,18 +1165,27 @@ def inner(*flat_args): new_arg = arg else: new_arg = from_fun(f_arg) - if was_updated(arg, new_arg): - if was_metadata_updated(arg, new_arg): - mutates_data = False - mutates_metadata = True - else: - mutates_data = True - mutates_metadata = has_metadata_mutation(f_arg) - mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(f_arg) - else: + mutates_metadata = has_metadata_mutation(f_arg, arg, check_only_storage_mutation=False) + mutates_storage_metadata = has_metadata_mutation(f_arg, arg, check_only_storage_mutation=True) + mutates_data = has_data_mutation(f_arg) + mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(f_arg) + + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + # TODO: discuss this in the PR. Both supporting this, and detecting + erroring out, + # seem painful to get working. + if mutates_storage_metadata: mutates_data = False - mutates_metadata = False - mutations_hidden_from_autograd = False requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad @@ -1155,6 +1194,7 @@ def inner(*flat_args): mutates_data=mutates_data, mutates_metadata=mutates_metadata, mutations_hidden_from_autograd=mutations_hidden_from_autograd, + mutates_storage_metadata=mutates_storage_metadata, requires_grad=requires_grad, mutation_type=_get_mutation_type( keep_input_mutations, @@ -1683,6 +1723,8 @@ def maybe_to_fresh_input(idx, t, meta): # Make sure the primal we pass to autograd.grad() # sees the tensor before the mutation return t.clone() + # No need to do anything for meta.input_info[idx].mutates_storage_metadata, + # Because autograd doesn't support set_() if meta.input_info[idx] and meta.input_info[idx].mutates_metadata: # Make sure the primal we pass to autograd.grad() # sees the tensor before the metadata mutation @@ -2660,7 +2702,8 @@ def create_synthetic_base_metadata( # mutations, they will be hidden from the rest of aot autograd. mutates_data=mutates_data, mutates_metadata=mutates_metadata, - mutations_hidden_from_autograd=mutations_hidden_from_autograd, + mutations_hidden_from_autograd=all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices), + mutates_storage_metadata=False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_storage_metadata, is_leaf=any_leaf, requires_grad=requires_grad, mutation_type=mutation_type, @@ -3210,6 +3253,22 @@ def runtime_wrapper(*args): continue original_inpt = args[inpt_idx] updated_inpt = updated_inputs[i] + if meta.mutates_storage_metadata: + # mutates_storage_metadata means our input saw a x.set_(y) call. + # What if x **also** saw a data and/or a metadata mutation? + # (1) If the [meta]data mutation occurred after the set_(), + # then there is no need to copy_() the data. + # When we perform x.set_(x_updated), we are guaranteed that + # x_updated already has the final version of the data/metadata + # (2) If a data mutation occurred before the set_(). + # This case seems very difficult to support. + # TODO: discuss on the PR and decide if we want to tr to + # either support it, or detect and ban it. + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + original_inpt.set_(updated_inpt) + continue if meta.mutates_metadata and not meta.mutates_data: if trace_joint: assert isinstance(updated_inpt, TensorAlias) diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index d9b8e6787058..913d29b4b69c 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -462,6 +462,48 @@ static PyObject* THPVariable__is_functional_tensor( END_HANDLE_TH_ERRORS } +static PyObject* THPVariable__functionalize_was_storage_changed( + PyObject* self, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser( + {"_functionalize_was_storage_changed(Tensor t)"}, /*traceable=*/true); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + auto self_ = r.tensor(0); + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_)); + auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_); + if (wrapper->was_storage_changed()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +static PyObject* THPVariable__functionalize_has_data_mutation( + PyObject* self, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser( + {"_functionalize_has_data_mutation(Tensor t)"}, /*traceable=*/true); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + auto self_ = r.tensor(0); + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_)); + auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_); + if (wrapper->has_data_mutation()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + static PyObject* THPVariable__functionalize_has_metadata_mutation( PyObject* self, PyObject* args, @@ -741,6 +783,15 @@ static PyMethodDef torch_functions_manual[] = { THPVariable__functionalize_is_multi_output_view), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"_functionalize_has_data_mutation", + castPyCFunctionWithKeywords(THPVariable__functionalize_has_data_mutation), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_functionalize_was_storage_changed", + castPyCFunctionWithKeywords( + THPVariable__functionalize_was_storage_changed), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, {"_functionalize_enable_reapply_views", castPyCFunctionWithKeywords( THPVariable__functionalize_enable_reapply_views), diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index d918bfb562fb..c39fc3e3e3bf 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -716,7 +716,11 @@ def emit_registration_helper(f: NativeFunction) -> str: return view_str elif isinstance(g, NativeFunctionsGroup): - fns = list(g.functions()) + # Gets a hand-written functionalization kernel + if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor": + fns = [] + else: + fns = list(g.functions()) else: if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION: return [] From 4812a62ca0ba66f7d72691666e550da27ecb7a3d Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Sat, 18 Nov 2023 14:46:30 -0800 Subject: [PATCH 024/221] [inductor] Delete more type-ignores in dependencies.py (#114013) A couple of type hints were wrong Pull Request resolved: https://github.com/pytorch/pytorch/pull/114013 Approved by: https://github.com/eellison --- torch/_inductor/dependencies.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index dbf0b1910d90..7effbdcc3974 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -212,7 +212,7 @@ def reads_and_writes(self): class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] def __init__(self, var_ranges: VarRanges, normalize: bool): super().__init__() - self._reads: Set[MemoryDep] = set() + self._reads: Set[Dep] = set() self._writes: Set[MemoryDep] = set() self._index_exprs: Set[IndexExprDep] = set() self._var_ranges: VarRanges = var_ranges @@ -220,14 +220,14 @@ def __init__(self, var_ranges: VarRanges, normalize: bool): def canonicalize( self, index: sympy.Expr - ) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]: + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: if not self._normalize: sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] var_names = tuple( k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1 ) sizes = tuple(v for v in sizes if v != 1) - return index, var_names, sizes # type: ignore[return-value] + return index, var_names, sizes # Try to further simplify the indexes even if simplify_loops didn't # convert it to the simplest form because of the interference from @@ -240,7 +240,7 @@ def canonicalize( # if k in free_symbols } index_vars = [*var_ranges.keys()] - sizes = [*var_ranges.values()] # type: ignore[assignment] + sizes = tuple(var_ranges.values()) new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( index_vars, sizes, @@ -261,10 +261,10 @@ def canonicalize( # downstream users won't. Normalize this away. new_vars.pop() new_sizes.pop() - return index, tuple(new_vars), tuple(new_sizes) # type: ignore[return-value] + return index, tuple(new_vars), tuple(new_sizes) def load(self, name: str, index: sympy.Expr) -> str: - self._reads.add(MemoryDep(name, *self.canonicalize(index))) # type: ignore[call-arg] + self._reads.add(MemoryDep(name, *self.canonicalize(index))) return f"load({name}, {sympy_str(index)})" def load_seed(self, name: str, index: int): @@ -272,14 +272,14 @@ def load_seed(self, name: str, index: int): return self.load(name, sympy.Integer(index)) def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str: - self._writes.add(MemoryDep(name, *self.canonicalize(index))) # type: ignore[call-arg] + self._writes.add(MemoryDep(name, *self.canonicalize(index))) return f"store({name}, {sympy_str(index)}, {value}, {mode})" def store_reduction(self, name: str, index, value) -> str: return self.store(name, index, f"store_reduction({value})") def index_expr(self, index: sympy.Expr, dtype) -> str: - self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) # type: ignore[call-arg] + self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) return f"index_expr({sympy_str(index)}, {dtype})" def bucketize( @@ -290,7 +290,7 @@ def bucketize( indexing_dtype: torch.dtype, right: bool, ): - self._reads.add(StarDep(offsets_name)) # type: ignore[arg-type] + self._reads.add(StarDep(offsets_name)) return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" @@ -376,7 +376,7 @@ def extract_read_writes( ) -def extract_input_node_reduction_ranges( # noqa: F722 +def extract_input_node_reduction_ranges( input_node: "torch._inductor.ir.TensorBox", ) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]: """ From 87925789ae1509fd04dc9105d4fc8e00d8ad544a Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Sat, 18 Nov 2023 18:01:16 -0800 Subject: [PATCH 025/221] Make V.graph properly typed (#114025) Previously it lacked a type hint and so was treated as an Any type. This resulted in a lot of untyped code downstream as V.graph is referenced in many places in inductor code. I've typed it properly now as GraphLowering, and fixed the numerous type errors this surfaced. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114025 Approved by: https://github.com/eellison ghstack dependencies: #114013 --- torch/_inductor/autotune_process.py | 22 ++++++++++++----- torch/_inductor/codegen/triton.py | 37 ++++++++++++++++++----------- torch/_inductor/codegen/wrapper.py | 8 ++++--- torch/_inductor/graph.py | 20 +++++++--------- torch/_inductor/ir.py | 25 +++++++++++++------ torch/_inductor/lowering.py | 6 +++-- torch/_inductor/scheduler.py | 7 +++--- torch/_inductor/sizevars.py | 4 ++-- torch/_inductor/virtualized.py | 9 ++++--- 9 files changed, 87 insertions(+), 51 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 4365fd642e87..7bf1f572238d 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -12,7 +12,17 @@ from ctypes import byref, c_size_t, c_void_p from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue -from typing import Any, Callable, Dict, List, Optional, Sequence, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + TYPE_CHECKING, + Union, +) import torch from torch import multiprocessing @@ -331,8 +341,8 @@ def benchmark( class TensorMeta: device: torch.device dtype: torch.dtype - sizes: List[int] - strides: List[int] + sizes: torch._prims_common.ShapeType + strides: torch._prims_common.StrideType offset: int @classmethod @@ -390,7 +400,7 @@ def __init__( kernel_name: str, input_tensor_meta: Union[TensorMeta, List[TensorMeta]], output_tensor_meta: Union[TensorMeta, List[TensorMeta]], - extra_args: Dict[str, Any], + extra_args: Iterable[Any], ): # the kernel name defined in the module self.kernel_name = kernel_name @@ -478,7 +488,7 @@ def __init__( kernel_name: str, input_tensor_meta: Union[TensorMeta, List[TensorMeta]], output_tensor_meta: Union[TensorMeta, List[TensorMeta]], - extra_args: Dict[str, Any], + extra_args: Iterable[Any], module_path: str, # the path of the module defining the triton kernel module_cache_key: str, grid: List[int], @@ -525,7 +535,7 @@ def __init__( kernel_name: str, input_tensor_meta: Union[TensorMeta, List[TensorMeta]], output_tensor_meta: Union[TensorMeta, List[TensorMeta]], - extra_args: Dict[str, Any], + extra_args: Iterable[Any], source_code: str, ): super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index ee0fb32e0519..d178edae9813 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -9,7 +9,7 @@ import math import operator import os -from typing import Any, Counter, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Counter, Dict, Iterable, List, Optional, Set, Tuple, Union import sympy @@ -811,8 +811,8 @@ class TritonKernel(Kernel): def __init__( self, *groups, - index_dtype, - mutations=None, + index_dtype: str, + mutations: Optional[Set[str]] = None, pid_cache=None, reduction_hint=ReductionHint.DEFAULT, min_elem_per_thread=0, @@ -821,21 +821,21 @@ def __init__( pid_cache = {} super().__init__() self.numels = [V.graph.sizevars.simplify(s) for s in groups] - self.mutations = mutations + self.mutations: Set[str] = mutations if mutations is not None else set() self.range_trees: List[IterationRangesRoot] = [] - self.range_tree_nodes = {} + self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} self.iter_vars_count = itertools.count() self.inside_reduction = self.numels[-1] != 1 self.body = IndentedBuffer() self.indexing_code = IndentedBuffer() self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment] - self.outside_loop_vars = set() + self.outside_loop_vars: Set[Any] = set() self.reduction_hint = reduction_hint - self.index_dtype = index_dtype + self.index_dtype: str = index_dtype self.min_elem_per_thread = min_elem_per_thread - self.last_usage = set() + self.last_usage: Set[str] = set() - self.persistent_reduction = self.should_use_persistent_reduction() + self.persistent_reduction: bool = self.should_use_persistent_reduction() self.no_x_dim = ( self.reduction_hint == ReductionHint.INNER and self.persistent_reduction @@ -857,7 +857,7 @@ def simplify_indexing(index: sympy.Expr): self.simplify_indexing = simplify_indexing - def should_use_persistent_reduction(self): + def should_use_persistent_reduction(self) -> bool: """ Heuristic to set self.persistent_reduction and add guards if needed. @@ -1057,6 +1057,7 @@ def is_broadcasted(self, index: sympy.Expr): # Non-iterated variables, e.g. strides continue entry = self.range_tree_nodes[symbol] + assert isinstance(entry.parent, IterationRangesRoot) index_numels[entry.parent.index] *= entry.length # If the index variables only iterate over a subset of the kernel @@ -1490,6 +1491,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): value, ) + dim: int + root_op: str + def final_reduction(value): use_helper = reduction_type in {"any", "max", "min", "prod"} module = "triton_helpers" if use_helper else "tl" @@ -1915,8 +1919,11 @@ def codegen_kernel(self, name=None): mutated_args.add(self.args.output_buffers[mutation]) mutated_args = sorted(mutated_args) + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype + ) triton_meta = { - "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "signature": triton_meta_signature, "device": V.graph.scheduler.current_device.index, "device_type": V.graph.scheduler.current_device.type, "constants": {}, @@ -1932,7 +1939,7 @@ def codegen_kernel(self, name=None): if tree.prefix != "r" or self.inside_reduction: sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) signature.append(sizearg) - triton_meta["signature"][len(argdefs)] = signature_of( + triton_meta_signature[len(argdefs)] = signature_of( sizearg, size_dtype=self.index_dtype ) argdefs.append(f"{tree.prefix}numel") @@ -2391,7 +2398,9 @@ def reduction_hint(node): return node.node.data.reduction_hint @staticmethod - def can_use_32bit_indexing(numel: sympy.Expr, buffers: Iterable[ir.Buffer]) -> bool: + def can_use_32bit_indexing( + numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] + ) -> bool: int_max = torch.iinfo(torch.int32).max size_hint = V.graph.sizevars.size_hint has_hint = V.graph.sizevars.shape_env.has_hint @@ -2437,7 +2446,7 @@ def select_index_dtype(node_schedule, numel, reduction_numel): buffer_names.update(node.used_buffer_names()) # Get buffers objects - def _get_buffer(name: str) -> ir.Buffer: + def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: if name in V.graph.name_to_buffer: return V.graph.name_to_buffer[name] elif name in V.graph.graph_inputs: diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 9648051e31c0..b66381a6f1f8 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -6,7 +6,7 @@ import os import re from itertools import chain, count -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import sympy from sympy import Expr @@ -359,7 +359,7 @@ def __init__(self): self.write_constant(name, hashed) self.allocated = set() - self.freed = set() + self.freed: Set[str] = set() # maps from reusing buffer to reused buffer self.reuses = dict() @@ -653,7 +653,9 @@ def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}" ) - def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]): + def codegen_inputs( + self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox] + ): """Assign all symbolic shapes to locals""" @functools.lru_cache(None) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 1027fbf7c854..021741129ef9 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -107,6 +107,8 @@ def is_magic_method(op): class GraphLowering(torch.fx.Interpreter): + graph_outputs: List[ir.IRNode] + def symbolic_sizes_strides(self, ex: torch.Tensor): """ Support dynamic shapes and dynamic strides by assigning variables @@ -196,11 +198,10 @@ def __init__( self.sizevars = SizeVarAllocator(shape_env) self.graph_inputs: Dict[str, TensorBox] = {} self.graph_inputs_original: Dict[str, InputBuffer] = {} - self.graph_outputs: Optional[List[ir.IRNode]] = None self.device_types: Set[str] = set() self.device_idxs: Set[int] = set() self.cuda = False - self.buffers: List[ir.ComputedBuffer] = [] + self.buffers: List[ir.Buffer] = [] self.constants: Dict[str, torch.Tensor] = {} self.constant_reprs: Dict[str, str] = {} self.removed_buffers: Set[str] = set() @@ -208,25 +209,25 @@ def __init__( self.mutated_buffers: Set[str] = set() self.never_reuse_buffers: Set[str] = set() self.inplaced_to_remove: Set[str] = set() - self.wrapper_code: Optional[WrapperCodeGen] = None + self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment] # See `ProxyExecutor Design Note` in ir.py for more details self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] self.extern_node_serializer: Optional[ Callable[[List[ir.ExternKernelNode]], Any] ] = extern_node_serializer - self.current_node: Optional[torch.fx.Node] = None + self.current_node: torch.fx.Node = None # type: ignore[assignment] self.num_static_inputs = num_static_inputs self.lists: Dict[str, List[str]] = {} self.mutated_inputs: Set[str] = set() self.mutated_input_idxs: List[int] = [] - self.name_to_buffer: Dict[str, ir.ComputedBuffer] = {} + self.name_to_buffer: Dict[str, ir.Buffer] = {} self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list) self.creation_time = time.time() self.name = "GraphLowering" self.cpp_wrapper = cpp_wrapper self.aot_mode = aot_mode self.graph_id = graph_id - self.scheduler = None + self.scheduler: "torch._inductor.scheduler.Scheduler" = None # type: ignore[assignment] self.nodes_prefer_channels_last = ( self.find_nodes_prefer_channels_last() if self.layout_opt else set() ) @@ -447,7 +448,7 @@ def get_numel(self, buffer_name: str): def run(self, *args): return super().run(*args) - def register_buffer(self, buffer: ir.ComputedBuffer): + def register_buffer(self, buffer: ir.Buffer): name = f"buf{len(self.buffers)}" self.buffers.append(buffer) self.name_to_buffer[name] = buffer @@ -533,7 +534,7 @@ def allocate(name): ) ) - def constant_name(self, name: str, device_override: torch.device): + def constant_name(self, name: str, device_override: Optional[torch.device]): """ We AOT copy constants to the devices they are needed on. If device_override doesn't match the constant's device, then @@ -970,10 +971,8 @@ def codegen(self): self.init_wrapper_code() self.scheduler = Scheduler(self.buffers) - assert self.scheduler is not None # mypy can't figure this out V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) self.scheduler.codegen() - assert self.wrapper_code is not None return self.wrapper_code.generate(self.is_inference) def count_bytes(self): @@ -1049,7 +1048,6 @@ def compile_to_fn(self): return self.compile_to_module().call def get_output_names(self): - assert self.graph_outputs is not None return [ node.get_name() for node in self.graph_outputs diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 3df69422ccb9..5daac542d1d1 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -201,6 +201,7 @@ def ir_node_to_tensor(x, guard_shape=True): if x is None: return None + shape_fn: Callable[[Expr], Union[int, Expr]] if not guard_shape: shape_fn = V.graph.sizevars.size_hint else: @@ -314,6 +315,9 @@ def realize(self): """ raise NotImplementedError(f"realize NYI on {type(self)}") + def codegen_reference(self, writer=None): + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of # the code dynamically check for defined attributes. @@ -326,7 +330,7 @@ def realize(self): has_exceeded_max_reads: Callable[[], bool] make_loader: Callable[[], Callable[[Any], Any]] make_indexer: Callable[[], Callable[[Any], Any]] - mark_reuse: Callable[[List[Any]], None] + mark_reuse: Callable[[int], None] realize_hint: Callable[[], None] @@ -2573,7 +2577,7 @@ def __post_init__(self): def make_indexer(self): return self.layout.make_indexer() - def get_name(self): + def get_name(self) -> str: assert self.name return self.name @@ -2768,19 +2772,22 @@ class InputBuffer(Buffer): class ConstantBuffer(InputBuffer): - override_device = None + override_device: Optional[torch.device] = None def make_loader(self): def loader(index): indexer = self.layout.make_indexer() return ops.load( - V.graph.constant_name(self.name, self.override_device), indexer(index) + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), ) return loader def constant_to_device(self, device): - return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout) + return ConstantBuffer( + V.graph.constant_name(self.get_name(), device), self.layout + ) class NoneAsConstantBuffer(IRNode): @@ -3380,7 +3387,7 @@ def process_kernel(cls, kernel, *args, **kwargs): is_arg_tensor = [] tensor_args = [] - non_tensor_args = [] + non_tensor_args: List[Any] = [] for arg in args_flat: is_arg_tensor.append(isinstance(arg, IRNode)) if is_arg_tensor[-1]: @@ -4200,8 +4207,9 @@ class ExternKernelNode: } -@dataclasses.dataclass class FallbackKernel(ExternKernelAlloc): + args_default_value: List[Dict[str, Any]] + def __init__( self, layout, @@ -6037,6 +6045,9 @@ def __getattr__(self, name): def realize(self): return self.data.realize() + def codegen_reference(self, writer=None): + return self.data.codegen_reference(writer) + @property def layout(self): return self.data.layout # type: ignore[attr-defined] diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 61c8d4e61ef4..2dac15949d2e 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -5,7 +5,7 @@ import warnings from collections import defaultdict from collections.abc import Iterable -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import sympy @@ -811,6 +811,8 @@ def repeat(x, repeats): if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): return expand(x, new_size) + x_loader: Callable[[Any], Any] + def inner_fn(index): assert len(index) == len(repeats) index = list(index) @@ -3578,7 +3580,7 @@ def constant_pad_nd(x, padding, fill_value=0): n = len(sizes) - len(bounds) # if padding is a complicated expression, hoist it - bounds_precomp = [] + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] for l, h in bounds: l_precomp = ( V.graph.sizevars.lookup_precomputed_size(l) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a40bd9209538..8fdfbceeb3ee 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -491,6 +491,7 @@ def is_materialized(buf, snodes): for buf_name in reads | writes: buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]]) + buf: Union[ir.Buffer, ir.TensorBox] if buf_name in V.graph.name_to_buffer: buf = V.graph.name_to_buffer[buf_name] elif buf_name in V.graph.graph_inputs: @@ -504,7 +505,7 @@ def get_buf_elems(buf): # Kind of a lazy way to get the MultiOutput nodes corresponding to # a MultiOutputLayout if isinstance(buf.layout, MultiOutputLayout): - users = self.scheduler.name_to_node[buf.name].users + users = self.scheduler.name_to_node[buf.get_name()].users buf_elems = sum(get_buf_elems(user.node.node) for user in users) else: buf_elems = get_buf_elems(buf) @@ -1216,7 +1217,7 @@ def __init__(self, nodes): self.debug_draw_graph() # used during codegen: - self.current_device = None + self.current_device: torch.device = None # type: ignore[assignment] self.buffer_names_to_free = set() # fx graph node to the position it appears in the graph @@ -2012,7 +2013,7 @@ def free_buffers(self): V.graph.wrapper_code.codegen_free(node.node) elif name in V.graph.graph_inputs: storage = V.graph.graph_inputs[name].data - assert storage.is_input_buffer() + assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() V.graph.wrapper_code.codegen_free(storage.data) self.buffer_names_to_free.clear() diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 1839bd7b7aea..d62fc7d75852 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -1,7 +1,7 @@ import functools import itertools import logging -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import sympy from sympy import Expr @@ -393,7 +393,7 @@ def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int: def size_hints( self, - exprs: List[Expr], + exprs: Iterable[Expr], *, fallback: Optional[int] = None, ) -> Tuple[int, ...]: diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 7a8f748ce84a..5ad062fc979b 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from itertools import chain from threading import local -from typing import Any, Callable, Union +from typing import Any, Callable, TYPE_CHECKING, Union from unittest.mock import patch import sympy @@ -15,6 +15,9 @@ from .utils import reduction_num_outputs, sympy_str, sympy_symbol +if TYPE_CHECKING: + from torch._inductor.graph import GraphLowering + threadlocal = local() @@ -287,7 +290,7 @@ class _V: set_ops_handler: Callable[[Any], Any] = _ops._set_handler get_ops_handler: Callable[[], Any] = _ops._get_handler - set_graph_handler: Callable[[Any], Any] = _graph._set_handler + set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler get_real_inputs: Callable[[], Any] = _real_inputs._get_handler set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler @@ -306,7 +309,7 @@ def ops(self) -> _MockHandler: return _ops._get_handler() @property - def graph(self): + def graph(self) -> GraphLowering: """The graph currently being generated""" return _graph._get_handler() From 36869463e08a8aeb2ce74aa54920b2f8859bc123 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:51:54 -0800 Subject: [PATCH 026/221] [DTensor] add forward layer norm test (#114174) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114174 Approved by: https://github.com/fduwjj, https://github.com/wanchaol --- test/distributed/_tensor/test_math_ops.py | 41 ++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index c5e90988cf75..4b9ff5a63b90 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] +import copy import itertools import torch -from torch.distributed._tensor import DeviceMesh, distribute_tensor +from torch.distributed._tensor import DeviceMesh, distribute_module, distribute_tensor from torch.distributed._tensor.placement_types import Replicate, Shard from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -150,6 +151,44 @@ def test_full_shard_math_ops(self): actual_local_res = actual_rs.to_local() self.assertEqual(actual_local_res, expect_rs) + @with_comms + def test_layer_norm(self): + device_mesh = self.build_device_mesh() + + # NLP example from pytorch docs + # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html + batch, sentence_length, embedding_dim = 20, 5, 10 + x = torch.rand(batch, sentence_length, embedding_dim, device=self.device_type) + norm_shape_idx_list = list(range(x.ndim)) + shard_dims = [-1, 0, 1, 2] + test_config_list = list(itertools.product(shard_dims, norm_shape_idx_list)) + + # normalized shape is a torch.Size object + for shard_dim, norm_idx in test_config_list: + normalized_shape = x.shape[norm_idx:] + layer_norm = torch.nn.LayerNorm(normalized_shape) + layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type) + + def _replicate_fn(name, module, device_mesh): + for name, param in module.named_parameters(): + if name in ["weight", "bias"]: + param_dist = torch.nn.Parameter( + distribute_tensor(param, device_mesh, [Replicate()]) + ) + module.register_parameter(name, param_dist) + + layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn) + + x_local = x + x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + + y_local = layer_norm_local(x_local) + y_dist = layer_norm_dist(x_dist).redistribute( + device_mesh, placements=[Replicate()] + ) + + self.assertEqual(y_local, y_dist.to_local()) + if __name__ == "__main__": run_tests() From b09bd364025660b06280fac43e034358101525dc Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 20 Nov 2023 10:58:11 -0800 Subject: [PATCH 027/221] [dtensor] add test for adamw (#114149) This PR add tests for adamw optimizers Pull Request resolved: https://github.com/pytorch/pytorch/pull/114149 Approved by: https://github.com/XilunWu --- test/distributed/_tensor/test_optimizers.py | 42 +++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/distributed/_tensor/test_optimizers.py b/test/distributed/_tensor/test_optimizers.py index e1add4719309..bbfd99846c97 100644 --- a/test/distributed/_tensor/test_optimizers.py +++ b/test/distributed/_tensor/test_optimizers.py @@ -112,3 +112,45 @@ def test_adam_1d_sharding(self): # on different ranks inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) + + @with_comms + def test_adamw_1d_sharding(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + # TODO: add fused_adamw support + adamw_configs = [ + {"lr": 0.1}, + {"lr": 0.1, "weight_decay": 0.05}, + {"lr": 0.1, "weight_decay": 0.05, "foreach": True}, + { + "lr": 0.1, + "betas": (0.6, 0.66), + "eps": 1e-6, + "weight_decay": 0.05, + "amsgrad": True, + "foreach": True, + }, + { + "lr": 0.1, + "betas": (0.6, 0.66), + "eps": 1e-6, + "weight_decay": 0.05, + "maximize": True, + "amsgrad": True, + "foreach": True, + }, + ] + + for config in adamw_configs: + mod = MLPModule(self.device_type) + opt = torch.optim.AdamW(mod.parameters(), **config) + + dist_mod = distribute_module( + deepcopy(mod), mesh, shard_fn, input_fn, output_fn + ) + dist_opt = torch.optim.AdamW(dist_mod.parameters(), **config) + + # use ones to make sure the single machine model have the same input + # on different ranks + inp = torch.ones(8, 10, device=self.device_type) + self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) From 9b50611002277a6c78b2d0b9a3561e879b391ed6 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 20 Nov 2023 10:58:14 -0800 Subject: [PATCH 028/221] [dtensor] add test for SGD optimizer (#114150) as titled Pull Request resolved: https://github.com/pytorch/pytorch/pull/114150 Approved by: https://github.com/XilunWu ghstack dependencies: #114149 --- test/distributed/_tensor/test_optimizers.py | 40 +++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/distributed/_tensor/test_optimizers.py b/test/distributed/_tensor/test_optimizers.py index bbfd99846c97..e6d8cbc40873 100644 --- a/test/distributed/_tensor/test_optimizers.py +++ b/test/distributed/_tensor/test_optimizers.py @@ -154,3 +154,43 @@ def test_adamw_1d_sharding(self): # on different ranks inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) + + @with_comms + def test_sgd_1d_sharding(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + sgd_configs = [ + {"lr": 0.1}, + {"lr": 0.1, "momentum": 0.05}, + {"lr": 0.1, "momentum": 0.05, "foreach": True}, + {"lr": 0.1, "momentum": 0.06, "dampening": 0.07, "foreach": True}, + { + "lr": 0.1, + "momentum": 0.08, + "weight_decay": 0.05, + "nesterov": True, + "maximize": True, + }, + { + "lr": 0.1, + "momentum": 0.08, + "weight_decay": 0.05, + "nesterov": True, + "maximize": True, + "foreach": True, + }, + ] + + for config in sgd_configs: + mod = MLPModule(self.device_type) + opt = torch.optim.SGD(mod.parameters(), **config) + + dist_mod = distribute_module( + deepcopy(mod), mesh, shard_fn, input_fn, output_fn + ) + dist_opt = torch.optim.SGD(dist_mod.parameters(), **config) + + # use ones to make sure the single machine model have the same input + # on different ranks + inp = torch.ones(8, 10, device=self.device_type) + self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) From bcd310a7adc8623cd793e8241372afd094d6e6e9 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 20 Nov 2023 14:04:37 -0800 Subject: [PATCH 029/221] [dtensor] enable adagrad foreach support (#114151) This PR enables the adagrad foreach mode support Pull Request resolved: https://github.com/pytorch/pytorch/pull/114151 Approved by: https://github.com/XilunWu ghstack dependencies: #114149, #114150 --- test/distributed/_tensor/test_optimizers.py | 54 +++++++++++++++++++ .../distributed/_tensor/ops/pointwise_ops.py | 12 ++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/test/distributed/_tensor/test_optimizers.py b/test/distributed/_tensor/test_optimizers.py index e6d8cbc40873..5bd78dae6b92 100644 --- a/test/distributed/_tensor/test_optimizers.py +++ b/test/distributed/_tensor/test_optimizers.py @@ -194,3 +194,57 @@ def test_sgd_1d_sharding(self): # on different ranks inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) + + @with_comms + def test_adagrad_1d_sharding(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + adagrad_configs = [ + {"lr": 0.1}, + {"lr": 0.1, "lr_decay": 0.05}, + {"lr": 0.1, "lr_decay": 0.02, "weight_decay": 0.05}, + { + "lr": 0.1, + "lr_decay": 0.02, + "weight_decay": 0.05, + "initial_accumulator_value": 0.03, + }, + { + "lr": 0.1, + "lr_decay": 0.02, + "weight_decay": 0.05, + "initial_accumulator_value": 0.03, + "eps": 1e-6, + }, + { + "lr": 0.1, + "lr_decay": 0.02, + "weight_decay": 0.05, + "initial_accumulator_value": 0.03, + "eps": 1e-6, + "maximize": True, + }, + { + "lr": 0.1, + "lr_decay": 0.02, + "weight_decay": 0.05, + "initial_accumulator_value": 0.03, + "eps": 1e-6, + "maximize": True, + "foreach": True, + }, + ] + + for config in adagrad_configs: + mod = MLPModule(self.device_type) + opt = torch.optim.Adagrad(mod.parameters(), **config) + + dist_mod = distribute_module( + deepcopy(mod), mesh, shard_fn, input_fn, output_fn + ) + dist_opt = torch.optim.Adagrad(dist_mod.parameters(), **config) + + # use ones to make sure the single machine model have the same input + # on different ranks + inp = torch.ones(8, 10, device=self.device_type) + self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index 6c16fa10f199..8c7ef7a925c7 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -520,14 +520,22 @@ def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> Strategy # TODO: add all for_each ops for_each_ops = [ - aten._foreach_addcmul_.Scalar, + aten._foreach_addcdiv_.Scalar, aten._foreach_addcdiv_.ScalarList, + aten._foreach_addcdiv_.Tensor, + aten._foreach_addcmul_.Scalar, + aten._foreach_addcmul_.ScalarList, + aten._foreach_addcmul_.Tensor, aten._foreach_div_.ScalarList, aten._foreach_lerp_.Scalar, aten._foreach_maximum_.List, + aten._foreach_mul.Scalar, + aten._foreach_mul.List, aten._foreach_mul_.Scalar, - aten._foreach_neg_.default, + aten._foreach_mul_.ScalarList, + aten._foreach_mul_.List, aten._foreach_neg.default, + aten._foreach_neg_.default, aten._foreach_reciprocal_.default, aten._foreach_sub_.Scalar, aten._foreach_sqrt.default, From bbc39b7bb48d28d67e3253a89cc82df3687ddd1b Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 20 Nov 2023 14:04:38 -0800 Subject: [PATCH 030/221] [dtensor] enable RMSprop optimizer foreach support (#114152) as titled Pull Request resolved: https://github.com/pytorch/pytorch/pull/114152 Approved by: https://github.com/XilunWu ghstack dependencies: #114149, #114150, #114151 --- test/distributed/_tensor/test_optimizers.py | 59 +++++++++++++++++++ .../distributed/_tensor/ops/pointwise_ops.py | 1 + 2 files changed, 60 insertions(+) diff --git a/test/distributed/_tensor/test_optimizers.py b/test/distributed/_tensor/test_optimizers.py index 5bd78dae6b92..4d8b2d822d53 100644 --- a/test/distributed/_tensor/test_optimizers.py +++ b/test/distributed/_tensor/test_optimizers.py @@ -248,3 +248,62 @@ def test_adagrad_1d_sharding(self): # on different ranks inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) + + @with_comms + def test_RMSprop_1d_sharding(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + RMSprop_configs = [ + {"lr": 0.1}, + {"lr": 0.1, "alpha": 0.85}, + {"lr": 0.1, "alpha": 0.88, "eps": 1e-6}, + {"lr": 0.1, "alpha": 0.88, "eps": 1e-6, "weight_decay": 0.05}, + { + "lr": 0.1, + "alpha": 0.88, + "eps": 1e-6, + "weight_decay": 0.05, + "momentum": 0.9, + }, + { + "lr": 0.1, + "alpha": 0.88, + "eps": 1e-6, + "weight_decay": 0.05, + "momentum": 0.9, + "centered": True, + }, + { + "lr": 0.1, + "alpha": 0.88, + "eps": 1e-6, + "weight_decay": 0.05, + "momentum": 0.9, + "centered": True, + "maximize": True, + }, + { + "lr": 0.1, + "alpha": 0.88, + "eps": 1e-6, + "weight_decay": 0.05, + "momentum": 0.9, + "centered": True, + "maximize": True, + "foreach": True, + }, + ] + + for config in RMSprop_configs: + mod = MLPModule(self.device_type) + opt = torch.optim.RMSprop(mod.parameters(), **config) + + dist_mod = distribute_module( + deepcopy(mod), mesh, shard_fn, input_fn, output_fn + ) + dist_opt = torch.optim.RMSprop(dist_mod.parameters(), **config) + + # use ones to make sure the single machine model have the same input + # on different ranks + inp = torch.ones(8, 10, device=self.device_type) + self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index 8c7ef7a925c7..368af29a9934 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -523,6 +523,7 @@ def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> Strategy aten._foreach_addcdiv_.Scalar, aten._foreach_addcdiv_.ScalarList, aten._foreach_addcdiv_.Tensor, + aten._foreach_addcmul.Scalar, aten._foreach_addcmul_.Scalar, aten._foreach_addcmul_.ScalarList, aten._foreach_addcmul_.Tensor, From e76c54bd87bff96796bc63779c0bbbe9a5cc803d Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Tue, 21 Nov 2023 03:39:41 +0000 Subject: [PATCH 031/221] [vision hash update] update the pinned vision hash (#113217) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113217 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index baf7bff3873a..ccc45b79e862 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -4433680aa57439ed684f9854fac3443b76e03c03 +893b4abdc0c9df36c241c58769810f69e35dab48 From f67696f45e90261f9566cc38253db55583b46c37 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 21 Nov 2023 03:46:24 +0000 Subject: [PATCH 032/221] Update TorchFix to 0.2.0 (#114190) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114190 Approved by: https://github.com/malfet --- .flake8 | 3 +++ .github/workflows/docker-builds.yml | 2 ++ .lintrunner.toml | 2 +- requirements-flake8.txt | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.flake8 b/.flake8 index 8341a1e28b47..bca578ce563e 100644 --- a/.flake8 +++ b/.flake8 @@ -26,6 +26,9 @@ ignore = # TorchFix codes that don't make sense for PyTorch itself: # removed and deprecated PyTorch functions. TOR001,TOR101, + # TODO(kit1980): fix all TOR102 issues + # `torch.load` without `weights_only` parameter is unsafe + TOR102, per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 1651f0042639..68da02743656 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -6,6 +6,7 @@ on: paths: - .ci/docker/** - .github/workflows/docker-builds.yml + - .lintrunner.toml push: branches: - main @@ -14,6 +15,7 @@ on: paths: - .ci/docker/** - .github/workflows/docker-builds.yml + - .lintrunner.toml schedule: - cron: 1 3 * * 3 diff --git a/.lintrunner.toml b/.lintrunner.toml index 856e1b909031..d2ccb509c2dc 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -48,7 +48,7 @@ init_command = [ 'mccabe==0.7.0', 'pycodestyle==2.10.0', 'pyflakes==3.0.1', - 'torchfix==0.1.1', + 'torchfix==0.2.0', ] diff --git a/requirements-flake8.txt b/requirements-flake8.txt index f0d5e6f600f6..dc289b4e036f 100644 --- a/requirements-flake8.txt +++ b/requirements-flake8.txt @@ -8,4 +8,4 @@ flake8-pyi==20.5.0 mccabe==0.6.1 pycodestyle==2.6.0 pyflakes==2.2.0 -torchfix==0.1.1 +torchfix==0.2.0 From dc65f6c601bf02b75b868f76f8d153d1fe0fa2f6 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 21 Nov 2023 03:50:19 +0000 Subject: [PATCH 033/221] [c10d] Remove deprecated multi-gpu-per-thread APIs (#114156) As of today, PyTorch Distributed's preferred programming model is one device per thread, as exemplified by the APIs in its document. The multi-GPU functions (which stand for multiple GPUs per CPU thread) have been deprecated for three versions. Removing them now before 2.2 release. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114156 Approved by: https://github.com/albanD, https://github.com/fduwjj, https://github.com/H-Huang --- docs/source/conf.py | 5 - docs/source/distributed.rst | 71 +--- .../distributed/c10d/ProcessGroupNCCL.cpp | 7 +- torch/csrc/distributed/c10d/init.cpp | 2 +- torch/distributed/distributed_c10d.py | 348 +----------------- torch/testing/_internal/common_distributed.py | 9 - .../_internal/distributed/distributed_test.py | 227 ------------ 7 files changed, 12 insertions(+), 657 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 031ed72c03f9..dcd3c7694674 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -505,17 +505,14 @@ "all_gather", "all_gather_coalesced", "all_gather_into_tensor", - "all_gather_multigpu", "all_gather_object", "all_reduce", "all_reduce_coalesced", - "all_reduce_multigpu", "all_to_all", "all_to_all_single", "barrier", "batch_isend_irecv", "broadcast", - "broadcast_multigpu", "broadcast_object_list", "destroy_process_group", "gather", @@ -543,9 +540,7 @@ "new_subgroups_by_enumeration", "recv", "reduce", - "reduce_multigpu", "reduce_scatter", - "reduce_scatter_multigpu", "reduce_scatter_tensor", "scatter", "scatter_object_list", diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index be2756487df8..8c71e8ddef26 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -483,72 +483,11 @@ Multi-GPU collective functions ------------------------------ .. warning:: - The multi-GPU functions will be deprecated. If you must use them, please revisit our documentation later. - -If you have more than one GPU on each node, when using the NCCL and Gloo backend, -:func:`~torch.distributed.broadcast_multigpu` -:func:`~torch.distributed.all_reduce_multigpu` -:func:`~torch.distributed.reduce_multigpu` -:func:`~torch.distributed.all_gather_multigpu` and -:func:`~torch.distributed.reduce_scatter_multigpu` support distributed collective -operations among multiple GPUs within each node. These functions can potentially -improve the overall distributed training performance and be easily used by -passing a list of tensors. Each Tensor in the passed tensor list needs -to be on a separate GPU device of the host where the function is called. Note -that the length of the tensor list needs to be identical among all the -distributed processes. Also note that currently the multi-GPU collective -functions are only supported by the NCCL backend. - -For example, if the system we use for distributed training has 2 nodes, each -of which has 8 GPUs. On each of the 16 GPUs, there is a tensor that we would -like to all-reduce. The following code can serve as a reference: - -Code running on Node 0 - -:: - - import torch - import torch.distributed as dist - - dist.init_process_group(backend="nccl", - init_method="file:///distributed_test", - world_size=2, - rank=0) - tensor_list = [] - for dev_idx in range(torch.cuda.device_count()): - tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx)) - - dist.all_reduce_multigpu(tensor_list) - -Code running on Node 1 - -:: - - import torch - import torch.distributed as dist - - dist.init_process_group(backend="nccl", - init_method="file:///distributed_test", - world_size=2, - rank=1) - tensor_list = [] - for dev_idx in range(torch.cuda.device_count()): - tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx)) - - dist.all_reduce_multigpu(tensor_list) - -After the call, all 16 tensors on the two nodes will have the all-reduced value -of 16 - -.. autofunction:: broadcast_multigpu - -.. autofunction:: all_reduce_multigpu - -.. autofunction:: reduce_multigpu - -.. autofunction:: all_gather_multigpu - -.. autofunction:: reduce_scatter_multigpu + The multi-GPU functions (which stand for multiple GPUs per CPU thread) are + deprecated. As of today, PyTorch Distributed's preferred programming model + is one device per thread, as exemplified by the APIs in this document. If + you are a backend developer and want to support multiple devices per thread, + please contact PyTorch Distributed's maintainers. .. _distributed-launch: diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index a2f2a7f86353..c51acc467f16 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2311,18 +2311,13 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Bump collective counter seq_++; - // Currently, the API permits two scenarios where inputs.size() and + // Currently, the API permits one scenario where inputs.size() and // outputs.size() are > 0. // 1. If the call was a _coalesced call, all inputs must be on the same // device. // The group of nccl calls applies the collective separately to each input, // but the group as a whole should be efficient, and might even execute as // a single fused kernel. - // 2. If the call was a _multigpu call, all inputs must be on different - // devices. - // The nccl group applies the collective across them (eg, if the collective - // is an allreduce, the output on each device contains contributions summed - // across `inputs' tensors). const auto devices = getDeviceList(inputs); const bool inputs_same_dev = (devices.size() == 1); const auto key = getKeyFromDevices(devices); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index f4fad2ee1df1..3a7798b433a7 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -705,7 +705,7 @@ Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex ten The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. They are used in specifying strategies for reduction collectives, e.g., -:func:`reduce`, :func:`all_reduce_multigpu`, etc. +:func:`reduce`. This class does not support ``__members__`` property.)"); diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 0c5a7628e21c..63f6c48d35f3 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -42,17 +42,17 @@ __all__ = [ 'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced', - 'all_gather_multigpu', 'all_gather_object', 'all_reduce', - 'all_reduce_coalesced', 'all_reduce_multigpu', 'all_to_all', + 'all_gather_object', 'all_reduce', + 'all_reduce_coalesced', 'all_to_all', 'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast', - 'broadcast_multigpu', 'broadcast_object_list', 'destroy_process_group', + 'broadcast_object_list', 'destroy_process_group', 'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank', 'get_world_size', 'group', 'init_process_group', 'irecv', 'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available', 'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available', 'isend', 'monitored_barrier', 'new_group', 'new_subgroups', - 'new_subgroups_by_enumeration', 'recv', 'reduce', 'reduce_multigpu', - 'reduce_scatter', 'reduce_scatter_multigpu', 'scatter', + 'new_subgroups_by_enumeration', 'recv', 'reduce', + 'reduce_scatter', 'scatter', 'scatter_object_list', 'send', 'supports_complex', 'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions', 'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore', @@ -1851,66 +1851,6 @@ def batch_isend_irecv(p2p_op_list): return reqs -@_exception_logger -def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0): - """ - Broadcasts the tensor to the whole group with multiple GPU tensors per node. - - ``tensor`` must have the same number of elements in all the GPUs from - all processes participating in the collective. each tensor in the list must - be on a different GPU - - Only nccl and gloo backend are currently supported - tensors should only be GPU tensors - - Args: - tensor_list (List[Tensor]): Tensors that participate in the collective - operation. If ``src`` is the rank, then the specified ``src_tensor`` - element of ``tensor_list`` (``tensor_list[src_tensor]``) will be - broadcast to all other tensors (on different GPUs) in the src process - and all tensors in ``tensor_list`` of other non-src processes. - You also need to make sure that ``len(tensor_list)`` is the same - for all the distributed processes calling this function. - - src (int): Source rank. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - async_op (bool, optional): Whether this op should be an async op - src_tensor (int, optional): Source tensor rank within ``tensor_list`` - - Returns: - Async work handle, if async_op is set to True. - None, if not async_op or if not part of the group - - """ - warnings.warn( - "torch.distributed.broadcast_multigpu will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions" - ) - - if _rank_not_in_group(group): - _warn_not_in_group("broadcast_multigpu") - return - - opts = BroadcastOptions() - opts.rootRank = src - opts.rootTensor = src_tensor - opts.asyncOp = async_op - - if group is None or group is GroupMember.WORLD: - default_pg = _get_default_group() - work = default_pg.broadcast(tensor_list, opts) - else: - group_src_rank = get_group_rank(group, src) - opts.rootRank = group_src_rank - work = group.broadcast(tensor_list, opts) - if async_op: - return work - else: - work.wait() - - @_exception_logger def broadcast(tensor, src, group=None, async_op=False): """ @@ -1954,68 +1894,6 @@ def broadcast(tensor, src, group=None, async_op=False): else: work.wait() -@_exception_logger -def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, group=None, async_op=False): - r""" - Reduces the tensor data across all machines in a way that all get the final result. - - This function reduces a number of tensors on every node, - while each tensor resides on different GPUs. - Therefore, the input tensor in the tensor list needs to be GPU tensors. - Also, each tensor in the tensor list needs to reside on a different GPU. - - After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise - identical in all processes. - - Complex tensors are supported. - - Only nccl and gloo backend is currently supported - tensors should only be GPU tensors - - Args: - tensor_list (List[Tensor]): List of input and output tensors of - the collective. The function operates in-place and requires that - each tensor to be a GPU tensor on different GPUs. - You also need to make sure that ``len(tensor_list)`` is the same for - all the distributed processes calling this function. - op (optional): One of the values from - ``torch.distributed.ReduceOp`` - enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on. If - ``None``, the default process group will be used. - async_op (bool, optional): Whether this op should be an async op - - Returns: - Async work handle, if async_op is set to True. - None, if not async_op or if not part of the group - - """ - warnings.warn( - "torch.distributed.all_reduce_multigpu will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions" - ) - - if _rank_not_in_group(group): - return - - tensor_list = [ - t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list - ] - - opts = AllreduceOptions() - opts.reduceOp = op - if group is None: - default_pg = _get_default_group() - work = default_pg.allreduce(tensor_list, opts) - else: - work = group.allreduce(tensor_list, opts) - - if async_op: - return work - else: - work.wait() - @_exception_logger def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2159,69 +2037,6 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() -@_exception_logger -def reduce_multigpu( - tensor_list, dst, op=ReduceOp.SUM, group=None, async_op=False, dst_tensor=0 -): - """ - Reduces the tensor data on multiple GPUs across all machines. - - Each tensor in ``tensor_list`` should reside on a separate GPU. - - Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst`` - is going to receive the final result. - - Only nccl backend is currently supported - tensors should only be GPU tensors - - Args: - tensor_list (List[Tensor]): Input and output GPU tensors of the - collective. The function operates in-place. - You also need to make sure that ``len(tensor_list)`` is the same for - all the distributed processes calling this function. - dst (int): Destination rank - op (optional): One of the values from - ``torch.distributed.ReduceOp`` - enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - async_op (bool, optional): Whether this op should be an async op - dst_tensor (int, optional): Destination tensor rank within - ``tensor_list`` - - Returns: - Async work handle, if async_op is set to True. - None, otherwise - - """ - warnings.warn( - "torch.distributed.reduce_multigpu will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions" - ) - - if _rank_not_in_group(group): - _warn_not_in_group("reduce_multigpu") - return - - opts = ReduceOptions() - opts.reduceOp = op - opts.rootRank = dst - opts.rootTensor = dst_tensor - - if group is None or group is GroupMember.WORLD: - default_pg = _get_default_group() - work = default_pg.reduce(tensor_list, opts) - else: - group_dst_rank = get_group_rank(group, dst) - opts.rootRank = group_dst_rank - work = group.reduce(tensor_list, opts) - - if async_op: - return work - else: - work.wait() - @_exception_logger def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2267,83 +2082,6 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() -@_exception_logger -def all_gather_multigpu( - output_tensor_lists, input_tensor_list, group=None, async_op=False -): - """ - Gathers tensors from the whole group in a list. - - Each tensor in ``tensor_list`` should reside on a separate GPU - - Only nccl backend is currently supported - tensors should only be GPU tensors - - Complex tensors are supported. - - Args: - output_tensor_lists (List[List[Tensor]]): Output lists. It should - contain correctly-sized tensors on each GPU to be used for output - of the collective, e.g. ``output_tensor_lists[i]`` contains the - all_gather result that resides on the GPU of - ``input_tensor_list[i]``. - - Note that each element of ``output_tensor_lists`` has the size of - ``world_size * len(input_tensor_list)``, since the function all - gathers the result from every single GPU in the group. To interpret - each element of ``output_tensor_lists[i]``, note that - ``input_tensor_list[j]`` of rank k will be appear in - ``output_tensor_lists[i][k * world_size + j]`` - - Also note that ``len(output_tensor_lists)``, and the size of each - element in ``output_tensor_lists`` (each element is a list, - therefore ``len(output_tensor_lists[i])``) need to be the same - for all the distributed processes calling this function. - - input_tensor_list (List[Tensor]): List of tensors(on different GPUs) to - be broadcast from current process. - Note that ``len(input_tensor_list)`` needs to be the same for - all the distributed processes calling this function. - - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - async_op (bool, optional): Whether this op should be an async op - - Returns: - Async work handle, if async_op is set to True. - None, if not async_op or if not part of the group - - """ - warnings.warn( - "torch.distributed.all_gather_multigpu will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions" - ) - - if _rank_not_in_group(group): - _warn_not_in_group("all_gather_multigpu") - return - - output_tensor_lists = [ - [t if not t.is_complex() else torch.view_as_real(t) for t in l] - for l in output_tensor_lists - ] - input_tensor_list = [ - t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list - ] - - if group is None: - default_pg = _get_default_group() - work = default_pg.allgather(output_tensor_lists, input_tensor_list) - else: - work = group.allgather(output_tensor_lists, input_tensor_list) - - if async_op: - return work - else: - work.wait() - - def _object_to_tensor(obj, device): f = io.BytesIO() _pickler(f).dump(obj) @@ -3235,77 +2973,6 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): work.wait() -@_exception_logger -def reduce_scatter_multigpu( - output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, group=None, async_op=False -): - """ - Reduce and scatter a list of tensors to the whole group. - - Only nccl backend is currently supported. - - Each tensor in ``output_tensor_list`` should reside on a separate GPU, as - should each list of tensors in ``input_tensor_lists``. - - Args: - output_tensor_list (List[Tensor]): Output tensors (on different GPUs) - to receive the result of the operation. - - Note that ``len(output_tensor_list)`` needs to be the same for all - the distributed processes calling this function. - - input_tensor_lists (List[List[Tensor]]): Input lists. It should - contain correctly-sized tensors on each GPU to be used for input of - the collective, e.g. ``input_tensor_lists[i]`` contains the - reduce_scatter input that resides on the GPU of - ``output_tensor_list[i]``. - - Note that each element of ``input_tensor_lists`` has the size of - ``world_size * len(output_tensor_list)``, since the function - scatters the result from every single GPU in the group. To - interpret each element of ``input_tensor_lists[i]``, note that - ``output_tensor_list[j]`` of rank k receives the reduce-scattered - result from ``input_tensor_lists[i][k * world_size + j]`` - - Also note that ``len(input_tensor_lists)``, and the size of each - element in ``input_tensor_lists`` (each element is a list, - therefore ``len(input_tensor_lists[i])``) need to be the same for - all the distributed processes calling this function. - - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - async_op (bool, optional): Whether this op should be an async op. - - Returns: - Async work handle, if async_op is set to True. - None, if not async_op or if not part of the group. - - """ - warnings.warn( - "torch.distributed.reduce_scatter_multigpu will be deprecated. If you must " - "use it, please revisit our documentation later at " - "https://pytorch.org/docs/master/distributed.html#multi-gpu-collective-functions" - ) - - if _rank_not_in_group(group): - _warn_not_in_group("reduce_scatter_multigpu") - return - - opts = ReduceScatterOptions() - opts.reduceOp = op - - if group is None: - default_pg = _get_default_group() - work = default_pg.reduce_scatter(output_tensor_list, input_tensor_lists, opts) - else: - work = group.reduce_scatter(output_tensor_list, input_tensor_lists, opts) - - if async_op: - return work - else: - work.wait() - - @_exception_logger def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -4299,7 +3966,6 @@ def _get_process_group_store(pg: ProcessGroup) -> Store: # This ops are not friently to TorchDynamo. So, we decide to disallow these ops # in FX graph, allowing them to run them on eager, with torch.compile. dynamo_unsupported_distributed_c10d_ops = [ - all_reduce_multigpu, recv, all_gather_object, all_gather_coalesced, @@ -4311,14 +3977,10 @@ def _get_process_group_store(pg: ProcessGroup) -> Store: gather, broadcast_object_list, barrier, - reduce_multigpu, scatter, scatter_object_list, reduce, - reduce_scatter_multigpu, all_gather, - broadcast_multigpu, - all_gather_multigpu, reduce_scatter, all_gather_into_tensor, broadcast, diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index c9cec3e269e4..8cbca096b500 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -458,15 +458,6 @@ def init_multigpu_helper(world_size: int, backend: str): nGPUs = torch.cuda.device_count() visible_devices = range(nGPUs) - if backend == "nccl": - # This is a hack for a known NCCL issue using multiprocess - # in conjunction with multiple threads to manage different GPUs which - # may cause ncclCommInitRank to fail. - # http://docs.nvidia.com/deeplearning/sdk/nccl-release-notes/rel_2.1.4.html#rel_2.1.4 - # It slows down the performance of collective operations. - # Without this setting NCCL might throw unhandled error. - os.environ["NCCL_MAX_NRINGS"] = "1" - # If rank is less than or equal to number of available GPU's # then each rank can be mapped to corresponding GPU. nGPUs_per_process = 1 diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index ddd7f6513618..5c007dcc98fe 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -4162,233 +4162,6 @@ def test_barrier_full_group(self): group, group_id, rank = self._init_full_group_test() self._test_barrier_helper(group, group_id, rank) - def _test_broadcast_multigpu_helper(self, group, group_id, rank, rank_to_GPU): - for src in group: - expected_tensor = _build_tensor(src + 1) - tensors = [ - _build_tensor(src + 1, -1).cuda(device=i) for i in rank_to_GPU[rank] - ] - if rank == src: - tensors[0] = expected_tensor.cuda(device=rank_to_GPU[rank][0]) - - dist.broadcast_multigpu(tensors, src, group_id) - for tensor in tensors: - self.assertEqual(tensor, expected_tensor) - self._barrier() - - @skip_but_pass_in_sandcastle_if( - BACKEND == "mpi", "MPI doesn't support broadcast multigpu" - ) - @skip_but_pass_in_sandcastle_if( - BACKEND == "nccl", "NCCL broadcast multigpu skipped" - ) - @skip_if_no_gpu - def test_broadcast_multigpu(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - self._test_broadcast_multigpu_helper(group, group_id, rank, rank_to_GPU) - - def _test_all_reduce_multigpu_helper( - self, - group, - group_id, - rank, - rank_to_GPU, - op, - master_value, - worker_value, - expected_value, - dtype=torch.float, - ): - for src in group: - curr_value = master_value if rank == src else worker_value - tensors = [ - _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i) - for i in rank_to_GPU[rank] - ] - self.call_dist_op( - ":all_reduce", - False, - dist.all_reduce_multigpu, - tensors, - op, - group_id, - ) - expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) - for tensor in tensors: - self.assertEqual(tensor, expected_tensor) - - self._barrier() - - @skip_but_pass_in_sandcastle_if( - BACKEND == "mpi", "MPI doesn't support broadcast multigpu" - ) - @skip_but_pass_in_sandcastle_if( - BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL" - ) - @skip_but_pass_in_sandcastle_if( - BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" - ) - @skip_if_no_gpu - def test_all_reduce_multigpu(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - self._test_all_reduce_multigpu_helper( - group, - group_id, - rank, - rank_to_GPU, - dist.ReduceOp.SUM, - 2, - 10, - (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), - ) - - @skip_but_pass_in_sandcastle_if( - BACKEND == "mpi", "MPI doesn't support broadcast multigpu" - ) - @skip_but_pass_in_sandcastle_if( - BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL" - ) - @skip_but_pass_in_sandcastle_if( - BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" - ) - @skip_if_no_gpu - def test_all_reduce_multigpu_complex(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - self._test_all_reduce_multigpu_helper( - group, - group_id, - rank, - rank_to_GPU, - dist.ReduceOp.SUM, - complex(2, 3), - complex(10, 11), - (complex(2, 3) + complex(10, 11) * (len(group) - 1)) - * len(rank_to_GPU[0]), - dtype=torch.cfloat, - ) - - def _test_reduce_multigpu_helper( - self, - group, - group_id, - rank, - rank_to_GPU, - op, - master_value, - worker_value, - expected_value, - ): - for src in group: - tensor_value = master_value if rank == src else worker_value - tensors = [ - _build_tensor(src + 1, tensor_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - self.call_dist_op( - ":reduce", - False, - dist.reduce_multigpu, - tensors, - src, - op, - group_id, - expect_event=len(tensors) == 1, - tensor_shapes=[tensors[0].shape], - ) - if rank == src: - expected_tensor = _build_tensor(src + 1, expected_value) - self.assertEqual(tensors[0], expected_tensor) - - self._barrier() - - @skip_but_pass_in_sandcastle_if( - BACKEND != "nccl", "Only Nccl backend supports reduce multigpu" - ) - @skip_if_no_gpu - def test_reduce_multigpu(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - device_id = rank_to_GPU[rank][0] - torch.cuda.set_device(device_id) - self._test_reduce_multigpu_helper( - group, - group_id, - rank, - rank_to_GPU, - dist.ReduceOp.SUM, - 2, - 10, - (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), - ) - - def _test_all_gather_multigpu_helper( - self, group, group_id, rank, rank_to_GPU, dtype=torch.float - ): - for dest in group: - tensors = [ - _build_tensor(dest + 1, dtype=dtype).cuda(device=i) - for i in rank_to_GPU[rank] - ] - - # construct expected output along with - # a place holder to receive all gather results - output_tensors = [] - expected_output = [] - output_per_gpu = ( - [_build_tensor(dest + 1, -1, dtype=dtype)] - * len(rank_to_GPU[0]) - * len(group) - ) - expected_per_gpu = ( - [_build_tensor(dest + 1, dtype=dtype)] - * len(rank_to_GPU[0]) - * len(group) - ) - for gpu in rank_to_GPU[rank]: - output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) - expected_output.append( - [t.cuda(device=gpu) for t in expected_per_gpu] - ) - self.call_dist_op( - ":all_gather", - False, - dist.all_gather_multigpu, - output_tensors, - tensors, - group_id, - expect_event=len(expected_output) == 1, - ) - self.assertEqual(output_tensors, expected_output) - - self._barrier() - - @skip_but_pass_in_sandcastle_if( - BACKEND != "nccl", "Only Nccl backend supports allgather multigpu" - ) - @skip_if_no_gpu - def test_all_gather_multigpu(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - device_id = rank_to_GPU[rank][0] - torch.cuda.set_device(device_id) - self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU) - - @skip_but_pass_in_sandcastle_if( - BACKEND != "nccl", "Only Nccl backend supports allgather multigpu" - ) - @skip_if_no_gpu - def test_all_gather_multigpu_complex(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - device_id = rank_to_GPU[rank][0] - torch.cuda.set_device(device_id) - self._test_all_gather_multigpu_helper( - group, group_id, rank, rank_to_GPU, dtype=torch.cfloat - ) - def _model_step(self, model): for param in model.parameters(): if param.grad is not None: From 18e1a37c4e637370ad3ceee91cecffcfeb4b79b6 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 20 Nov 2023 13:50:17 -0800 Subject: [PATCH 034/221] [ao] updating embedding_bag support for fx and eager (#107623) Summary: our docs were saying dynamic embedding bag wasn't supported but it actually is (at least at the same level as embeddings were) it just wasn't previously tested/listed. Test Plan: python test/test_quantization.py -k "test_embedding" Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/107623 Approved by: https://github.com/jerryzh168 --- docs/source/quantization.rst | 5 ++-- .../eager/test_quantize_eager_ptq.py | 27 ++++++++++++++++--- test/quantization/fx/test_quantize_fx.py | 12 +++++++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 3d6408bf4722..1b4ae6ed200c 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -115,7 +115,6 @@ for a more comprehensive overview of the tradeoffs between these quantization types. Operator coverage varies between dynamic and static quantization and is captured in the table below. -Note that for FX quantization, the corresponding functionals are also supported. +---------------------------+-------------------+--------------------+ | |Static | Dynamic | @@ -135,7 +134,7 @@ Note that for FX quantization, the corresponding functionals are also supported. |nn.EmbeddingBag | Y (activations | | | | are in fp32) | Y | +---------------------------+-------------------+--------------------+ -|nn.Embedding | Y | N | +|nn.Embedding | Y | Y | +---------------------------+-------------------+--------------------+ | nn.MultiheadAttention | Y (through | Not supported | | | custom modules) | | @@ -881,7 +880,7 @@ Note that for FX Graph Mode Quantization, the corresponding functionals are also |nn.EmbeddingBag | Y (activations | | | | are in fp32) | Y | +---------------------------+-------------------+--------------------+ -|nn.Embedding | Y | N | +|nn.Embedding | Y | Y | +---------------------------+-------------------+--------------------+ |nn.MultiheadAttention |Not Supported | Not supported | +---------------------------+-------------------+--------------------+ diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py index c14f8068d800..e6e3327b7cdf 100644 --- a/test/quantization/eager/test_quantize_eager_ptq.py +++ b/test/quantization/eager/test_quantize_eager_ptq.py @@ -1475,7 +1475,7 @@ def checkHooksIsPresent(model): checkHooksIsPresent(model) @skipIfNoFBGEMM - def test_embedding_ops_dynamic(self): + def test_embedding_bag_dynamic(self): class EmbeddingBagWithLinear(torch.nn.Module): def __init__(self): super().__init__() @@ -1496,9 +1496,30 @@ def forward(self, indices, offsets, linear_in): q_model = quantize_dynamic(model, qconfig_dict) q_model(indices, offsets, torch.randn(5, 5)) - self.assertTrue('QuantizedEmbedding' in str(q_model)) - self.assertTrue('DynamicQuantizedLinear' in str(q_model)) + self.assertTrue('QuantizedEmbeddingBag' in str(q_model.emb)) + self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc)) + + @skipIfNoFBGEMM + def test_embedding_ops_dynamic(self): + class EmbeddingWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding( + num_embeddings=10, embedding_dim=12, scale_grad_by_freq=False) + self.fc = torch.nn.Linear(5, 5) + def forward(self, indices, linear_in): + return self.emb(indices), self.fc(linear_in) + model = EmbeddingWithLinear().eval() + qconfig_dict = { + torch.nn.Embedding : float_qparams_weight_only_qconfig, + torch.nn.Linear: default_dynamic_qconfig + } + indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + q_model = quantize_dynamic(model, qconfig_dict) + self.assertTrue('QuantizedEmbedding' in str(q_model.emb)) + self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc)) + q_model(indices, torch.randn(5, 5)) if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 49c79cd9663e..19ff89fa11de 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -8636,12 +8636,24 @@ def forward(self, indices): indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) example_inputs = (indices,) quantized_node = ns.call_module(nnq.Embedding) + + # check dynamic quant + self.checkGraphModeFxOp( + model, + example_inputs, + QuantType.DYNAMIC, + quantized_node, + custom_qconfig_dict={"": qconfig_type} + ) + model = M().eval() + configs = [ (qconfig_type, ns.call_module(nnq.Embedding)), (None, ns.call_module(nn.Embedding)), (default_qconfig, ns.call_module(nn.Embedding)), ] + # check static quantization for qconfig, node in configs: qconfig_dict = {"": qconfig} m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) From 7ea184d7e33369610492ff0936369ea00f2b3580 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 20 Nov 2023 23:31:19 -0500 Subject: [PATCH 035/221] Handle item() on boolean tensor (#114157) This needs some special handling because we don't actually allocate boolean symbols in sympy; we allocate an integer indicator variable. See comment for more details. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/114157 Approved by: https://github.com/ydwu4 --- .../test_torchinductor_dynamic_shapes.py | 8 +++++++ torch/_inductor/ir.py | 22 +++++++++++++++++-- torch/_inductor/scheduler.py | 4 ++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 53a4d24d86a1..459059a7434c 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -206,6 +206,14 @@ def f(x): f(torch.tensor([3], device=device)) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_item_bool_nobreak(self, device): + @torch.compile(fullgraph=True) + def f(x): + return x.item() + + f(torch.tensor([True], device=device)) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_item_zeros_nobreak(self, device): @torch.compile(fullgraph=True) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 5daac542d1d1..f1956373bed5 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4175,16 +4175,34 @@ def get_reads(self): def should_allocate(self): return False + # TODO: handle bools carefully def __init__(self, sym, data): super().__init__(None, NoneLayout(torch.device("cpu")), [data]) # type: ignore[arg-type] - self.sym = sym + if isinstance(sym, sympy.Symbol): + self.sym = sym + self.is_bool = False + else: + # Special case for boolean. For Reasons(TM), we don't represent + # boolean variables directly in sympy; instead, we generate an + # indicator integer variable which we then convert to a boolean by + # testing i0 == 1. We have to identify the underlying indicator + # variable, and then bind i0 to the appropriate integer value + # based on the runtime boolean. + assert isinstance(sym, sympy.Eq), sym + assert isinstance(sym.args[0], sympy.Symbol), sym + assert sym.args[1] == 1, sym + self.sym = sym.args[0] + self.is_bool = True def get_unbacked_symbol_defs(self): return {self.sym} def codegen(self, wrapper): (data,) = (t.codegen_reference() for t in self.inputs) - wrapper.writeline(f"{self.sym} = {data}.item()") + if self.is_bool: + wrapper.writeline(f"{self.sym} = 1 if {data}.item() else 0") + else: + wrapper.writeline(f"{self.sym} = {data}.item()") # No one should ever use this buffer, but for uniformity # define the variable and assign it None wrapper.writeline(f"{self.get_name()} = None") diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 8fdfbceeb3ee..7b8755a8699b 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1348,6 +1348,7 @@ def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): # unbacked symbols don't follow ordinary buffer dependencies, so # we track their def/uses separately for s in node.node.get_unbacked_symbol_defs(): + assert isinstance(s, sympy.Symbol) # Pick the first definer as canonical. There may be multiple # because if a MultiOutputLayout buffer propagates an unbacked # symint to multiple outputs, they will all claim to def it. @@ -1402,6 +1403,9 @@ def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): for node in V.graph.graph_outputs: if isinstance(node, ir.ShapeAsConstantBuffer): for s in free_unbacked_symbols(node.shape): + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" node_name = unbacked_symbol_to_origin_node[s].node.name log.debug( "scheduling output %s for unbacked symint %s", node_name, s From 99af534e932380fac654e05a02b70e141ae89ad3 Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 20 Nov 2023 23:24:58 +0000 Subject: [PATCH 036/221] [docs][jit] Mention dynamic-shapes settings in jit/OVERVIEW.md (#113964) Document torch._C._jit_set_fusion_strategy, which can control how many static-shape compilation attempts are made before falling back to dynamic shapes, before falling back to uncompiled graph execution. Would be good to keep all the graph executor settings documented in one place. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113964 Approved by: https://github.com/eellison --- torch/csrc/jit/OVERVIEW.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index 17aff9dc4825..2b5374741b1f 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -53,6 +53,7 @@ Sections start with a reference to the source file where the code related to the - [Interpreter](#interpreter) - [Graph Executor](#graph-executor) - [Specialization](#specialization) + - [Dynamic Shapes Options](#dynamic-shapes-options) - [Pre-derivative Optimization](#pre-derivative-optimization) - [Required Passes](#required-passes) - [Derivative Preserving Optimization](#derivative-preserving-optimization) @@ -942,6 +943,22 @@ The executor *specializes* the `Graph` for the particular set of inputs. Special The ArgumentSpec object is used as a key into a cache that holds pre-optimized Code objects (held in an ExecutionPlan object). On a cache hit, an InterpreterState is created and the Code in the cache is run. +### Dynamic Shapes Options ### + +In the "Specialization" section above, it is mentioned that "rank, but not size" is specialized on. This is partially true; size is sometimes specialized on because this specialization can sometimes produce more efficient code. By default, static shapes are specialized initially; if more shapes are observed then eventually the graph executor will generate a dynamic-shape version that doesn't depend on specific input shapes. + +To control these settings, you can use `torch._C._jit_set_fusion_strategy()`; it takes as an argument a list of tuples in the format `(type, number)` where `type` is a string in `{"DYNAMIC" ,"STATIC"}` and `number` is an integer. + +For example: +``` +torch._C._jit_set_fusion_strategy([ + ("STATIC", 2), + ("DYNAMIC", 20), +]) +``` + +This will make two attempts to generate static-shape graphs, and after that fall back to generating dynamic-shape graphs. If for some reason compilation keeps occuring (even with dynamic-shape graphs - e.g. this could happen if ranks or dtypes vary), after 20 compilation attempts the graph executor will fall back to running the graph without any attempts to compile it. + ### Pre-derivative Optimization ### On a code cache miss, we generate a new optimized `Graph` on the fly (`compileSpec`). It starts by creating a copy of the initial `Graph` and setting the input types to the specialized `Tensor` types observed in this specialization. TensorType inputs to the `Graph` will get refined with types that know the device, number of dimensions, and requires grad state. From e122c90d3cd810d53cf12a8d2f7120fc5a2af462 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Tue, 21 Nov 2023 06:31:11 +0000 Subject: [PATCH 037/221] [executorch hash update] update the pinned executorch hash (#114008) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/_update-commit-hash.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114008 Approved by: https://github.com/pytorchbot, https://github.com/huydhn --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index d1b62b926904..d60ac4ac9b62 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -9682172576d5d9a10f3162ad91e0a32b384a3b7c +f5e4a1e74daa5397362d87d1a8cb81f09446d34f From 8f8722e3f1d1121345b80428117be89209ccd772 Mon Sep 17 00:00:00 2001 From: Pavan Balaji Date: Tue, 21 Nov 2023 07:23:42 +0000 Subject: [PATCH 038/221] [nccl-pg] Avoid using NCCL_ prefix for non-NCCL env variables (#114077) NCCL_ prefix should only be used for NCCL library's environment variables. We currently use a few environment variables in PyTorch with the NCCL_ prefix that are the NCCL library does not understand. This patch renames such environment variables to use the TORCH_NCCL_ prefix instead. We still maintain the old NCCL_ variables, but throw a warning when they are used. The following env changes have been made: `NCCL_BLOCKING_WAIT` -> `TORCH_NCCL_BLOCKING_WAIT` `NCCL_ENABLE_TIMING` -> `TORCH_NCCL_ENABLE_TIMING` `NCCL_DESYNC_DEBUG` -> `TORCH_NCCL_DESYNC_DEBUG` `NCCL_ASYNC_ERROR_HANDLING` -> `TORCH_NCCL_ASYNC_ERROR_HANDLING` `ENABLE_NCCL_HEALTH_CHECK` -> `TORCH_ENABLE_NCCL_HEALTH_CHECK` `NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK` -> `TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/114077 Approved by: https://github.com/fduwjj --- test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 14 +++--- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 7 ++- .../distributed/c10d/ProcessGroupNCCL.cpp | 50 ++++++++++--------- .../distributed/c10d/ProcessGroupNCCL.hpp | 25 +++++++--- torch/csrc/distributed/c10d/init.cpp | 4 +- torch/csrc/distributed/c10d/logger.cpp | 11 +++- 6 files changed, 66 insertions(+), 45 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 906f5112d814..736363994083 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -269,7 +269,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { } void TearDown() override { - ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); } std::vector tensors_; @@ -281,7 +281,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { return; } - ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); auto options = c10d::ProcessGroupNCCL::Options::create(); options->timeout = std::chrono::milliseconds(1000); ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options); @@ -309,7 +309,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { return; } - ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); auto options = c10d::ProcessGroupNCCL::Options::create(); options->timeout = std::chrono::milliseconds(3000); ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options); @@ -395,7 +395,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { int heartBeatIntervalInSec = 2; std::string timeInterval = std::to_string(heartBeatIntervalInSec); - ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); ASSERT_TRUE( setenv( c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(), @@ -453,7 +453,7 @@ class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest { void SetUp() override { ProcessGroupNCCLErrorsTest::SetUp(); std::string timeInterval = std::to_string(heartBeatIntervalInSec); - ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); ASSERT_TRUE( setenv( c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(), @@ -461,12 +461,12 @@ class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest { 1) == 0); ASSERT_TRUE( setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0); - ASSERT_TRUE(setenv(c10d::NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0); // We cannot capture the exception thrown in watchdog thread without making // lots of changes to the code. So we don't let the watchdog throw // exception. ASSERT_TRUE( - setenv(c10d::NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0); + setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0); options_ = c10d::ProcessGroupNCCL::Options::create(); // Set a super short watchdog timeout. options_->timeout = std::chrono::milliseconds(100); diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 70aed2a0b0c6..61e9753988ea 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -41,7 +41,10 @@ class NCCLTestBase { c10::intrusive_ptr opts = c10::make_intrusive(); opts->timeout = pgTimeout_; - setenv(c10d::ENABLE_NCCL_HEALTH_CHECK[0].c_str(), "1", /* overwrite */ 1); + setenv( + c10d::TORCH_ENABLE_NCCL_HEALTH_CHECK[0].c_str(), + "1", + /* overwrite */ 1); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts))); } @@ -749,7 +752,7 @@ class ProcessGroupNCCLTest : public ::testing::Test { void TearDown() override { // Reset NCCL_BLOCKING_WAIT environment variable after each run. - ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); } bool skipTest() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index c51acc467f16..f4691c815c61 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -841,7 +841,7 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( rank_, "] Work ", (*this), - " timed out in blocking wait (NCCL_BLOCKING_WAIT=1)."); + " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1)."); LOG(ERROR) << exceptionMsg; break; } @@ -953,10 +953,10 @@ ProcessGroupNCCL::ProcessGroupNCCL( ValueError, at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); - blockingWait_ = getCvarBool(NCCL_BLOCKING_WAIT, false); + blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); asyncErrorHandling_ = static_cast( - getCvarInt(NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); - desyncDebug_ = getCvarBool(NCCL_DESYNC_DEBUG, false) || + getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); + desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || (dist_debug_level_ >= DebugLevel::Detail); heartbeat_ = 1ULL; monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, false)); @@ -964,42 +964,44 @@ ProcessGroupNCCL::ProcessGroupNCCL( getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 2 /*2 Mins*/); #ifdef ENABLE_NCCL_ERROR_CHECKING enableTiming_.store( - getCvarBool(NCCL_ENABLE_TIMING, false) || desyncDebug_ || + getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_ || getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) > 0); #endif avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); #ifdef NCCL_HAS_COMM_REGISTER useTensorRegisterAllocatorHook_ = - getCvarBool(NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); + getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: expandable_segments()) { useTensorRegisterAllocatorHook_ = false; LOG(INFO) << "[Rank " << rank_ - << "] disables NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + << "] disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; } #endif if (blockingWait_) { if (asyncErrorHandling_ != NoHandling || desyncDebug_) { - LOG(INFO) << "[Rank " << rank_ << "] NCCL_BLOCKING_WAIT and " - << "NCCL_ASYNC_ERROR_HANDLING|NCCL_DESYNC_DEBUG" - << "should not both be enabled. " - << "Only NCCL_BLOCKING_WAIT is being used in this process."; + LOG(INFO) + << "[Rank " << rank_ << "] TORCH_NCCL_BLOCKING_WAIT and " + << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG" + << "should not both be enabled. " + << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process."; asyncErrorHandling_ = NoHandling; desyncDebug_ = false; } } else { if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { - LOG(INFO) << "[Rank " << rank_ - << "] NCCL_DESYNC_DEBUG and NCCL_ASYNC_ERROR_HANDLING " - << "must both be enabled. " - << "Enabling NCCL_ASYNC_ERROR_HANDLING."; + LOG(INFO) + << "[Rank " << rank_ + << "] TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " + << "must both be enabled. " + << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING."; asyncErrorHandling_ = SkipCleanUp; } } - if (getCvarBool(ENABLE_NCCL_HEALTH_CHECK, false)) { + if (getCvarBool(TORCH_ENABLE_NCCL_HEALTH_CHECK, false)) { // Perform health check by initializing dummy communicators and destroying // them. This will help indicate any NCCL-related issues prior to the first // collective. @@ -1021,16 +1023,16 @@ ProcessGroupNCCL::ProcessGroupNCCL( LOG(INFO) << "[Rank " << rank_ << "] ProcessGroupNCCL initialization options: " << "NCCL version: " << getNcclVersion() - << ", NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ - << ", NCCL_DESYNC_DEBUG: " << desyncDebug_ - << ", NCCL_ENABLE_TIMING: " << enableTiming_.load() - << ", NCCL_BLOCKING_WAIT: " << blockingWait_ + << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ + << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() + << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ << ", TIMEOUT(ms): " << options_->timeout.count() << ", USE_HIGH_PRIORITY_STREAM: " << options_->is_high_priority_stream << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug #ifdef NCCL_HAS_COMM_REGISTER - << ", NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " + << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " << useTensorRegisterAllocatorHook_ #endif << ", TORCH_NCCL_ENABLE_MONITORING: " @@ -1163,7 +1165,7 @@ void ProcessGroupNCCL::registerOnCompletionHook( ValueError, enableTiming_.load(), "ProcessGroupNCCL OnCompletion hook requires recording start and end " - "events which require setting NCCL_ENABLE_TIMING environment variable. " + "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. " "This is only available for NCCL version >= 2.4."); onCompletionHook_ = std::move(hook); onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this); @@ -1550,11 +1552,11 @@ void ProcessGroupNCCL::workCleanupLoop() { dumpingDebugInfo->detach(); } } catch (const std::exception& e) { - LOG(ERROR) << "Failed to retrieve NCCL_DESYNC_DEBUG report. " + LOG(ERROR) << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " << " Please file an issue. Error: " << e.what(); } catch (...) { LOG(ERROR) - << "Failed to rerieve NCCL_DESYNC_DEBUG report with unknown error." + << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." << " Please file an issue."; } } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index c1cd3c374fae..79c986e7512f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -28,23 +28,31 @@ namespace c10d { // Environment variable which controls whether we perform a NCCL healt check // which ensures communicators are healthy at the beginning of init. -static std::vector ENABLE_NCCL_HEALTH_CHECK = { +static std::vector TORCH_ENABLE_NCCL_HEALTH_CHECK = { + "TORCH_ENABLE_NCCL_HEALTH_CHECK", "ENABLE_NCCL_HEALTH_CHECK"}; // Environment variable which controls whether or not wait() is blocking or // non-blocking. -static std::vector NCCL_BLOCKING_WAIT = {"NCCL_BLOCKING_WAIT"}; +static std::vector TORCH_NCCL_BLOCKING_WAIT = { + "TORCH_NCCL_BLOCKING_WAIT", + "NCCL_BLOCKING_WAIT"}; // Environment variable which controls whether or not we perform Async Error // Handling with NCCL. -static std::vector NCCL_ASYNC_ERROR_HANDLING = { +static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { + "TORCH_NCCL_ASYNC_ERROR_HANDLING", "NCCL_ASYNC_ERROR_HANDLING"}; // Environment Variable to control whether Desync Debug is enabled. -// This variable must be set together with NCCL_ASYNC_ERROR_HANDLING. -static std::vector NCCL_DESYNC_DEBUG = {"NCCL_DESYNC_DEBUG"}; +// This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING. +static std::vector TORCH_NCCL_DESYNC_DEBUG = { + "TORCH_NCCL_DESYNC_DEBUG", + "NCCL_DESYNC_DEBUG"}; -static std::vector NCCL_ENABLE_TIMING = {"NCCL_ENABLE_TIMING"}; +static std::vector TORCH_NCCL_ENABLE_TIMING = { + "TORCH_NCCL_ENABLE_TIMING", + "NCCL_ENABLE_TIMING"}; static std::vector TORCH_NCCL_ENABLE_MONITORING = { "TORCH_NCCL_ENABLE_MONITORING"}; @@ -87,8 +95,9 @@ static std::vector TORCH_NCCL_AVOID_RECORD_STREAMS = { // If set, ProcessGroupNCCL registers postAlloc and preFree hooks to cuda cache // allocator so that whenever a tensor is allocated or freed, ProcessGroupNCCL // can register/deregister the tensor on all available NCCL communicators. -static std::vector NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK = { - "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"}; +static std::vector TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK = + {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK", + "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"}; // ProcessGroupNCCL implements NCCL bindings for c10d. // diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 3a7798b433a7..b4119666d0bc 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2505,7 +2505,7 @@ Example:: ``fut.wait()`` will return after synchronizing the appropriate NCCL streams with PyTorch's current device streams to ensure we can have asynchronous CUDA execution and it does not wait for the entire operation to complete on GPU. Note that - ``CUDAFuture`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. + ``CUDAFuture`` does not support ``TORCH_NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. In addition, if a callback function was added by ``fut.then()``, it will wait until ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. @@ -2534,7 +2534,7 @@ Example:: .. warning :: This API only works for NCCL backend for now and must set - NCCL_ENABLE_TIMING environment variable. + TORCH_NCCL_ENABLE_TIMING environment variable. )") .def( "boxed", diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 703b896ed8ff..552529282a42 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -13,6 +13,13 @@ namespace c10d { +static std::vector TORCH_NCCL_BLOCKING_WAIT = { + "TORCH_NCCL_BLOCKING_WAIT", + "NCCL_BLOCKING_WAIT"}; +static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + "NCCL_ASYNC_ERROR_HANDLING"}; + // Logs runtime stats to configured destination. Note that since data collection // only runs every ddp_runtime_logging_sample_rate iterations, the actual // training iterations recorded will be like 10, @@ -79,9 +86,9 @@ void Logger::set_env_variables() { ddp_logging_data_->strs_map["nccl_socket_ifname"] = getCvarString({"NCCL_SOCKET_IFNAME"}, "N/A"); ddp_logging_data_->strs_map["nccl_blocking_wait"] = - getCvarString({"NCCL_BLOCKING_WAIT"}, "N/A"); + getCvarString(TORCH_NCCL_BLOCKING_WAIT, "N/A"); ddp_logging_data_->strs_map["nccl_async_error_handling"] = - getCvarString({"NCCL_ASYNC_ERROR_HANDLING"}, "N/A"); + getCvarString(TORCH_NCCL_ASYNC_ERROR_HANDLING, "N/A"); ddp_logging_data_->strs_map["nccl_debug"] = getCvarString({"NCCL_DEBUG"}, "N/A"); ddp_logging_data_->strs_map["nccl_nthreads"] = From 2c0474c02d3ac04a429504225d7f1a6536d3b9e6 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 21 Nov 2023 11:51:57 +0800 Subject: [PATCH 039/221] [inductor cpp] vectorize embedding lookup (#114062) For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar. This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed. For the added UT `test_embedding_vec`, the generated code before this PR is: ```c++ extern "C" void kernel(const long* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast(0L); x0(128L); x0+=static_cast(1L)) { #pragma GCC ivdep for(long x1=static_cast(0L); x1(128L); x1+=static_cast(1L)) { auto tmp0 = in_ptr0[static_cast(x0)]; auto tmp5 = in_ptr2[static_cast(x1 + (128L*x0))]; auto tmp1 = decltype(tmp0)(tmp0 + 64); auto tmp2 = tmp0 < 0; auto tmp3 = tmp2 ? tmp1 : tmp0; TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L") auto tmp4 = in_ptr1[static_cast(x1 + (128L*tmp3))]; auto tmp6 = decltype(tmp4)(tmp4 + tmp5); out_ptr0[static_cast(x1 + (128L*x0))] = tmp6; } } } } } ``` After this PR, we have: ```c++ extern "C" void kernel(const long* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast(0L); x0(128L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(128L); x1+=static_cast(16L)) { auto tmp0 = in_ptr0[static_cast(x0)]; auto tmp5 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x1 + (128L*x0))); auto tmp1 = decltype(tmp0)(tmp0 + 64); auto tmp2 = tmp0 < 0; auto tmp3 = tmp2 ? tmp1 : tmp0; TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L") auto tmp4 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x1 + (128L*tmp3))); auto tmp6 = tmp4 + tmp5; tmp6.store(out_ptr0 + static_cast(x1 + (128L*x0))); } } } } } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114062 Approved by: https://github.com/jansel ghstack dependencies: #113950 --- test/inductor/test_cpu_repro.py | 18 ++ torch/_inductor/codegen/cpp.py | 291 ++++++++++++++++++--------- torch/_inductor/codegen/cpp_prefix.h | 6 + 3 files changed, 224 insertions(+), 91 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index f5ce2369bfdb..06dff1d34e31 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1299,6 +1299,7 @@ def test_cpu_vec_cosim(self): cpp_op_list.append(k) diff = [ + "constant", "index_expr", "signbit", "isinf", @@ -2612,6 +2613,23 @@ def forward(self, x): x = torch.randn(1, 39, 1, 18, 17) self.common(m, (x,)) + def test_embedding_vec(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(64, 128) + + def forward(self, idx, x): + return self.emb(idx) + x + + idx = torch.randint(0, 64, (4, 32)) + x = torch.randn(4, 32, 128) + m = M().eval() + with torch.no_grad(): + metrics.reset() + self.common(m, (idx, x)) + assert metrics.generated_cpp_vec_kernel_count == 1 + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 90f55ed8b55a..f08c51a0a7b8 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -7,7 +7,7 @@ import re import sys from copy import copy, deepcopy -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Union import sympy @@ -133,6 +133,19 @@ ] +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial @@ -436,6 +449,54 @@ def get_current_node_opt_ctx() -> OptimizationContext: return get_opt_ctx(V.interpreter.current_node) +class CppCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges): + super().__init__(name, bounds) + self.is_vec = False + self.dtype = None + self.dependent_itervars: Set[sympy.Symbol] = set() + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[1] is index + self._set_dependent_itervars(args[1]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + if ( + hasattr(V.interpreter, "current_node") + and get_current_node_opt_ctx() is not None + ): + self.dtype = get_current_node_opt_ctx().dtype + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. + """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) + elif s.name in V.kernel.cse.varname_map: + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @@ -672,22 +733,20 @@ def mod(a, b): @staticmethod def constant(val, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype if dtype in DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, constants # must be promoted as well dtype = torch.float32 - if val == float("inf"): - return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" - elif val == float("-inf"): - return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" - elif math.isnan(val): - return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" - elif val is True or val is False: - return ops.to_dtype(str(val).lower(), dtype) - return ops.to_dtype(repr(val), dtype) + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) @staticmethod def index_expr(expr, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype) @staticmethod @@ -704,19 +763,7 @@ def masked(mask, body, other): V.kernel.compute.splice(code) # Use the lambda's return type as the type of other - type = f"decltype({body_var}())" - - if other == float("-inf"): - other_code = f"-std::numeric_limits<{type}>::infinity()" - elif other == float("inf"): - other_code = f"std::numeric_limits<{type}>::infinity()" - elif isinstance(other, bool): - other_code = f"static_cast<{type}>({str(other).lower()})" - elif math.isnan(other): - other_code = f"std::numeric_limits<{type}>::quiet_NaN()" - else: - other_code = f"static_cast<{type}>({repr(other)})" - + other_code = value_to_cpp(other, f"decltype({body_var}())") return f"{mask} ? {body_var}() : {other_code}" @staticmethod @@ -794,6 +841,54 @@ def sign(x): class CppVecOverrides(CppOverrides): """Map element-wise ops to aten vectorization C++""" + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for "masked") assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. + # + # Inputs to ops.masked are handled separately in its own function due to + # the need of recurive handling of masked body. + def wrapper(*args, **kwargs): + has_scalar = any( + not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable) + ) + has_vector = any( + arg.is_vec for arg in args if isinstance(arg, CppCSEVariable) + ) + new_args = list(args) + if has_scalar and has_vector: + # broadcast scalar args to vector if needed + new_args = [] + for arg in args: + if isinstance(arg, CppCSEVariable) and not arg.is_vec: + assert isinstance(V.kernel, CppVecKernel) + new_arg = V.kernel.broadcast(arg) + new_args.append(new_arg) + else: + new_args.append(arg) + if has_vector: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr( + scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] + ) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(cls).items(): + if getattr(method, "__class__", None) == staticmethod and name != "masked": + setattr(self, name, wrap(method.__func__)) + return self + @staticmethod def add(a, b): return f"{a} + {b}" @@ -1006,28 +1101,6 @@ def acosh(x): vec_one = f"decltype({x})(1)" return f"({x} + ({x}*{x} - {vec_one}).sqrt()).log()" - @staticmethod - def constant(val, dtype): - opt_ctx: OptimizationContext = get_current_node_opt_ctx() - assert opt_ctx - proposed_dtype = opt_ctx.dtype - assert proposed_dtype in [ - torch.float, - torch.int32, - ] - if val == float("inf"): - quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()" - elif val == float("-inf"): - quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()" - elif math.isnan(val): - quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()" - elif val is True or val is False: - quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})" - else: - quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})" - - return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})" - @staticmethod def relu(x): bug = config.cpp.inject_relu_bug_TESTING_ONLY @@ -1159,32 +1232,24 @@ def masked(mask, body, other): code.writeline(";") V.kernel.compute.splice(code) - if other == float("-inf"): - other_code = ( - "at::vec::Vectorized(-std::numeric_limits::infinity())" - ) - elif other == float("inf"): - other_code = ( - "at::vec::Vectorized(std::numeric_limits::infinity())" - ) - elif math.isnan(other): - other_code = ( - "at::vec::Vectorized(std::numeric_limits::quiet_NaN())" + other_code = value_to_cpp(other, "float") + other_code_vec = f"at::vec::Vectorized({other_code})" + + if result.is_vec: + type = f"decltype({var}())" + float_mask = f"to_float_mask({new_mask})" + csevar = V.kernel.cse.generate( + V.kernel.compute, + f"{type}::blendv({other_code_vec}, {var}(), {float_mask})", ) else: - other_code = f"at::vec::Vectorized({other!r})" - type = f"decltype({var}())" - float_mask = f"to_float_mask({new_mask})" - return f"{type}::blendv({other_code}, {var}(), {float_mask})" - - @staticmethod - def index_expr(expr, dtype): - assert dtype == torch.int64 - opt_ctx: OptimizationContext = get_current_node_opt_ctx() - assert opt_ctx - assert opt_ctx.dtype == torch.int32 - assert opt_ctx.is_most_inner_loop_irrevelant - return f"at::vec::Vectorized(static_cast({cexpr(V.kernel.rename_indexing(expr))}))" + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {var}() : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. + csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar class CppKernel(Kernel): @@ -1242,7 +1307,9 @@ def load(self, name: str, index: sympy.Expr): line = f"{var}[{cexpr_index(index)}]" if V.graph.get_dtype(name) in [torch.float16]: line = f"static_cast({line})" - return self.cse.generate(self.loads, line) + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + return csevar def store(self, name, index, value, mode=None): assert "buf" in name @@ -1472,6 +1539,9 @@ def write_to_suffix(self): self.reduction_suffix.splice(self.stores) (self.loads, self.compute, self.stores, self.cse) = prior + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + class CppVecKernel(CppKernel): overrides = CppVecOverrides # type: ignore[assignment] @@ -1506,7 +1576,11 @@ def load(self, name: str, index: sympy.Expr): non_contiguous = ( not is_broadcast and stride_at(tiling_var, index) != 1 - or "tmp" in f"{index}" + or any( + self.cse.varname_map[s.name].depends_on(tiling_var) + for s in index.free_symbols + if s.name.startswith("tmp") + ) ) var_expr = ( f"{var}[{cexpr_index(index)}]" @@ -1515,13 +1589,9 @@ def load(self, name: str, index: sympy.Expr): ) loadbuf = "tmpbuf" if non_contiguous else var_expr if is_broadcast: - # should always be broadcast as float for vectorization since we always use float to compute - if is_mask: - loadbuf = f"flag_to_float_scalar({loadbuf})" - if dtype in DTYPE_LOWP_FP: - line = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({loadbuf})" - else: - line = f"at::vec::Vectorized(static_cast({loadbuf}))" + csevar = super().load(name, index) + csevar.dtype = dtype + return csevar elif dtype in [torch.uint8] and opt_ctx.is_load_uint8_as_float: line = ( f"masked_load({loadbuf}, {load_mask})" @@ -1563,7 +1633,11 @@ def load(self, name: str, index: sympy.Expr): tmpbufdefine += f"tmpbuf[{inner}] = {rhs};" line = f"([&]() {{ {tmpbufdeclare} {tmpbufdefine} return {line}; }})()" - return self.cse.generate(self.loads, line) + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar def get_vec_store_line(self, value, var, index, dtype): """ @@ -1572,6 +1646,11 @@ def get_vec_store_line(self, value, var, index, dtype): :param var: buffer to store into. :index: index into the `var`. """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value tiling_var = self.itervars[self.tiling_idx] assert index.has(tiling_var) var_expr = f"{var} + {cexpr_index(index)}" @@ -1600,6 +1679,10 @@ def get_vec_store_line(self, value, var, index, dtype): def store(self, name, index, value, mode=None): assert "buf" in name assert mode is None + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) opt_ctx: OptimizationContext = get_current_node_opt_ctx() var = self.args.output(name) index = self.rename_indexing(index) @@ -1622,6 +1705,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): } assert dtype == torch.float assert src_dtype == torch.float + assert isinstance(value, CppCSEVariable) and value.is_vec, value vec_ns = "at::vec" vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" @@ -1740,6 +1824,26 @@ def store_reduction(self, name, index, value): ] self.reduction_suffix.writelines(store_lines) + def broadcast(self, scalar_var: CppCSEVariable): + assert ( + not scalar_var.is_vec + and self.itervars[self.tiling_idx] not in scalar_var.dependent_itervars + ) + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"to_float_mask({scalar_var.name})" + ) + else: + vec_var = self.cse.generate( + self.compute, + f"at::vec::Vectorized<{DTYPE_TO_CPP[scalar_var.dtype]}>({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + class CppTile2DKernel(CppVecKernel): """ @@ -1849,7 +1953,11 @@ def load(self, name: str, index: sympy.Expr): line = f"at::vec::Vectorized::loadu_one_fourth({loadbuf})" else: line = f"at::vec::Vectorized::loadu({loadbuf})" - return self.cse.generate(self.loads, line) + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar else: new_index = self.scale_index_with_offset( index, @@ -1950,10 +2058,6 @@ def disable_vec(self, msg=None): schedule_log.debug("Disabled vectorization: %s", msg) self.simd_vec = False - def could_vec(self, name: str, index: sympy.Expr): - assert self.itervars is not None - return len(self.itervars) > 0 - def is_mask(self, name: str, users: Dict[torch.fx.Node, None]): load_type = V.graph.get_dtype(name) if load_type == torch.bool: @@ -2036,6 +2140,10 @@ def load(self, name: str, index: sympy.Expr): var = self.cse.newvar() + if len(self.itervars) == 0: + self.disable_vec("not a loop") + return var + if load_dtype in [torch.bool, torch.uint8] and not ( opt_ctx.is_load_as_mask or opt_ctx.is_load_uint8_as_float ): @@ -2046,18 +2154,21 @@ def load(self, name: str, index: sympy.Expr): return var if ( - load_dtype not in self.load_supported_dtypes - ) and not self.is_load_integer_scalar_tensor(name, index): + (load_dtype not in self.load_supported_dtypes) + and not self.is_load_integer_scalar_tensor(name, index) + and index.has(self.itervars[self.tiling_idx]) + ): self.disable_vec(f"{load_dtype} not supported by load") return var - index = self.rename_indexing(index) - if self.simd_vec and not self.could_vec(name, index): - self.disable_vec(f"not a loop: {index}") return var def store(self, name, index, value, mode=None): with RecordOptimizationContext(__name__) as node_ctx: + if len(self.itervars) == 0: + self.disable_vec("not a loop") + return self.simd_vec + store_dtype = V.graph.get_dtype(name) opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() @@ -2085,8 +2196,6 @@ def store(self, name, index, value, mode=None): if index.is_number: self.disable_vec(f"constant store index: {index}") - if self.simd_vec and not self.could_vec(name, index): - self.disable_vec(f"not a loop: {index}") return self.simd_vec def reduction(self, dtype, src_dtype, reduction_type, value): diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 1a532029cdeb..23f72218a0cc 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -401,4 +401,10 @@ template <> inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { return src; } + +inline at::vec::Vectorized to_float_mask(int src) { + float mask; + *(uint32_t*)&mask = src ? 0xFFFFFFFF : 0; + return at::vec::Vectorized(mask); +} #endif From 6c597ef015c5676d1e32f43842ca4be077e7c989 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Tue, 21 Nov 2023 07:48:11 +0000 Subject: [PATCH 040/221] [PyTorch] Fix attr cleanup after constant folding (#113957) Summary: Two nodes can point to the same attribute via node.target. This makes sure, - we don't try to delete already deleted attribute, i.e. delete attr only once - we do delete all the nodes pointing to the attribute Test Plan: ``` buck run fbcode//mode/dev-nosan fbcode//executorch/backends/xnnpack/test:test_xnnpack_passes -- executorch.backends.xnnpack.test.passes.test_batch_norm_fusion.TestBatchNormFusion.test_q8_batch_norm_fusion ``` Differential Revision: D51419442 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113957 Approved by: https://github.com/Skylion007 --- torch/_inductor/constant_folding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index cd478a4cb252..9d9e166ff27c 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -178,7 +178,8 @@ def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = erased_params = [] for node in gm.graph.nodes: if node.op == "get_attr" and len(node.users) == 0: - delattr(gm, node.target) + if hasattr(gm, node.target): + delattr(gm, node.target) erased_params.append(node) for node in erased_params: From 5f0d72124e28d1b2f16324ddb4ecf6c2463cb202 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 21 Nov 2023 08:58:54 +0000 Subject: [PATCH 041/221] Revert "Print the index and summary of the SampleInput that failed an OpInfo test (#99444)" This reverts commit e7f12b1eb0cedfd20dcb41ea35e21e9a71e3390a. Reverted https://github.com/pytorch/pytorch/pull/99444 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause memory leak on CUDA job https://hud.pytorch.org/pytorch/pytorch/commit/e7f12b1eb0cedfd20dcb41ea35e21e9a71e3390a ([comment](https://github.com/pytorch/pytorch/pull/99444#issuecomment-1820491298)) --- test/test_testing.py | 8 +- torch/testing/_internal/common_device_type.py | 27 +------ torch/testing/_internal/common_utils.py | 76 ------------------- torch/testing/_internal/opinfo/core.py | 32 ++------ 4 files changed, 14 insertions(+), 129 deletions(-) diff --git a/test/test_testing.py b/test/test_testing.py index 542601d7ed97..feb408773f4c 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -12,7 +12,7 @@ import subprocess import sys import unittest.mock -from typing import Any, Callable, Iterator, List, Tuple +from typing import Any, Callable, Iterator, List, Tuple, Generator import torch @@ -2397,19 +2397,19 @@ class TestOpInfoSampleFunctions(TestCase): def test_opinfo_sample_generators(self, device, dtype, op): # Test op.sample_inputs doesn't generate multiple samples when called samples = op.sample_inputs(device, dtype) - self.assertIsInstance(samples, Iterator) + self.assertIsInstance(samples, Generator) @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one) def test_opinfo_reference_generators(self, device, dtype, op): # Test op.reference_inputs doesn't generate multiple samples when called samples = op.reference_inputs(device, dtype) - self.assertIsInstance(samples, Iterator) + self.assertIsInstance(samples, Generator) @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) def test_opinfo_error_generators(self, device, op): # Test op.error_inputs doesn't generate multiple inputs when called samples = op.error_inputs(device) - self.assertIsInstance(samples, Iterator) + self.assertIsInstance(samples, Generator) instantiate_device_type_tests(TestOpInfoSampleFunctions, globals()) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 4b550e95187d..96b7817b5c4a 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -15,8 +15,7 @@ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, \ _TestParametrizer, compose_parametrize_fns, dtype_name, \ - TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \ - get_tracked_input, PRINT_REPRO_ON_FAILURE + TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo from torch.testing._internal.common_cuda import _get_torch_cuda_version, \ TEST_CUSPARSE_GENERIC, TEST_HIPSPARSE_GENERIC, _get_torch_rocm_version from torch.testing._internal.common_dtype import get_all_dtypes @@ -797,12 +796,6 @@ class OpDTypes(Enum): torch.bool ) -def _serialize_sample(sample_input): - # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. - if getattr(sample_input, "summary", None) is not None: - return sample_input.summary() - return str(sample_input) - # Decorator that defines the OpInfos a test template should be instantiated for. # # Example usage: @@ -912,23 +905,7 @@ def _parametrize_test(self, test, generic_cls, device_cls): try: @wraps(test) def test_wrapper(*args, **kwargs): - try: - return test(*args, **kwargs) - except unittest.SkipTest as e: - raise e - except Exception as e: - tracked_input = get_tracked_input() - if PRINT_REPRO_ON_FAILURE and tracked_input is not None: - raise Exception( - f"Caused by {tracked_input.type_desc} " - f"at index {tracked_input.index}: " - f"{_serialize_sample(tracked_input.val)}") from e - raise e - - # Initialize info for the last input seen. This is useful for tracking - # down which inputs caused a test failure. Note that TrackedInputIter is - # responsible for managing this. - test.tracked_input = None + return test(*args, **kwargs) decorator_fn = partial(op.get_decorators, generic_cls.__name__, test.__name__, device_cls.device_type, dtype) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 5149261f9935..30f0311ba7b3 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -36,7 +36,6 @@ from collections.abc import Mapping, Sequence from contextlib import closing, contextmanager from copy import deepcopy -from dataclasses import dataclass from enum import Enum from functools import partial, wraps from itertools import product, chain @@ -238,81 +237,6 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper -# Tries to extract the current test function by crawling the stack. -# If unsuccessful, return None. -def extract_test_fn() -> Optional[Callable]: - try: - stack = inspect.stack() - for frame_info in stack: - frame = frame_info.frame - if "self" not in frame.f_locals: - continue - self_val = frame.f_locals["self"] - if isinstance(self_val, unittest.TestCase): - test_id = self_val.id() - test_name = test_id.split('.')[2] - test_fn = getattr(self_val, test_name).__func__ - return test_fn - except Exception: - pass - return None - -# Contains tracked input data useful for debugging purposes -@dataclass -class TrackedInput: - index: int - val: Any - type_desc: str - -# Attempt to pull out tracked input information from the test function. -# A TrackedInputIter is used to insert this information. -def get_tracked_input() -> Optional[TrackedInput]: - test_fn = extract_test_fn() - if test_fn is None: - return None - if not hasattr(test_fn, "tracked_input"): - return None - return test_fn.tracked_input - -# Wraps an iterator and tracks the most recent value the iterator produces -# for debugging purposes. Tracked values are stored on the test function. -class TrackedInputIter: - def __init__(self, child_iter, input_type_desc, callback=lambda x: x): - self.child_iter = enumerate(child_iter) - # Input type describes the things we're tracking (e.g. "sample input", "error input"). - self.input_type_desc = input_type_desc - # Callback is run on each iterated thing to get the thing to track. - self.callback = callback - self.test_fn = extract_test_fn() - - def __iter__(self): - return self - - def __next__(self): - try: - input_idx, input_val = next(self.child_iter) - self._set_tracked_input( - TrackedInput( - index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc - ) - ) - return input_val - except StopIteration as e: - self._clear_tracked_input() - raise e - - def _set_tracked_input(self, tracked_input: TrackedInput): - if self.test_fn is None: - return - if not hasattr(self.test_fn, "tracked_input"): - return - self.test_fn.tracked_input = tracked_input - - def _clear_tracked_input(self): - if self.test_fn is not None and hasattr(self.test_fn, "tracked_input"): - self.test_fn.tracked_input = None - self.test_fn = None - class _TestParametrizer: """ Decorator class for parametrizing a test function, yielding a set of new tests spawned diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 23b6e89e4a21..fc0fbf95864f 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -29,7 +29,6 @@ noncontiguous_like, TEST_WITH_ROCM, torch_to_numpy_dtype_dict, - TrackedInputIter, ) from torch.testing._internal.opinfo import utils @@ -208,6 +207,7 @@ def _repr_helper(self, formatter): f"input={formatter(self.input)}", f"args={formatter(self.args)}", f"kwargs={formatter(self.kwargs)}", + f"output_process_fn_grad={self.output_process_fn_grad}", f"broadcasts_input={self.broadcasts_input}", f"name={repr(self.name)}", ] @@ -227,15 +227,8 @@ def formatter(arg): # by Tensor[TensorShape] # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] if isinstance(arg, torch.Tensor): - shape = str(tuple(arg.shape)) - dtype = str(arg.dtype) - device = str(arg.device) - contiguity_suffix = "" - # NB: sparse CSR tensors annoyingly return is_sparse=False - is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr - if not is_sparse and not arg.is_contiguous(): - contiguity_suffix = ", contiguous=False" - return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]' + shape = str(tuple(arg.shape)).replace("(", "").replace(")", "") + return f"Tensor[{shape}]" elif isinstance(arg, dict): return {k: formatter(v) for k, v in arg.items()} elif is_iterable_of_tensors(arg): @@ -1162,7 +1155,7 @@ def conjugate(tensor): else: sample.input[0] = conjugate(sample.input[0]) - return TrackedInputIter(iter(conj_samples), "conjugate sample input") + return tuple(conj_samples) def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): """ @@ -1181,7 +1174,7 @@ def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): samples_list.extend(conj_samples) samples = tuple(samples_list) - return TrackedInputIter(iter(samples), "sample input") + return samples def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): """ @@ -1192,27 +1185,18 @@ def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): the sample inputs. """ if self.reference_inputs_func is None: - samples = self.sample_inputs_func( - self, device, dtype, requires_grad, **kwargs - ) - return TrackedInputIter(iter(samples), "sample input") + return self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) if kwargs.get("include_conjugated_inputs", False): raise NotImplementedError - references = self.reference_inputs_func( - self, device, dtype, requires_grad, **kwargs - ) - return TrackedInputIter(iter(references), "reference input") + return self.reference_inputs_func(self, device, dtype, requires_grad, **kwargs) def error_inputs(self, device, **kwargs): """ Returns an iterable of ErrorInputs. """ - errs = self.error_inputs_func(self, device, **kwargs) - return TrackedInputIter( - iter(errs), "error input", callback=lambda e: e.sample_input - ) + return self.error_inputs_func(self, device, **kwargs) def error_inputs_sparse(self, device, layout, **kwargs): """ From 1efff12a884dfb82bbea425f4fd62ad8925085dd Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Tue, 21 Nov 2023 09:06:33 +0000 Subject: [PATCH 042/221] [pytorch-vulkan] BinaryOps auto convert int tensors into float (#114145) Summary: Some model has hardcoded int constant tensors for some binary operations. Test Plan: ``` yipjustin@yipjustin-mbp fbsource % buck2 run -c pt.has_backtraces=1 --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -- --gtest_filter="*" ... [ OK ] VulkanAPITest.linear_3d_flat (0 ms) [ RUN ] VulkanAPITest.linear_3d_small [ OK ] VulkanAPITest.linear_3d_small (0 ms) [ RUN ] VulkanAPITest.linear_3d_large [ OK ] VulkanAPITest.linear_3d_large (0 ms) [ RUN ] VulkanAPITest.linear_4d_flat [ OK ] VulkanAPITest.linear_4d_flat (0 ms) [ RUN ] VulkanAPITest.linear_4d_small [ OK ] VulkanAPITest.linear_4d_small (0 ms) [ RUN ] VulkanAPITest.linear_4d_large [ OK ] VulkanAPITest.linear_4d_large (0 ms) [ RUN ] VulkanAPITest.lstm_success [ OK ] VulkanAPITest.lstm_success (5 ms) [ RUN ] VulkanAPITest.lstm_mclareninputs_success [ OK ] VulkanAPITest.lstm_mclareninputs_success (21 ms) [ RUN ] VulkanAPITest.lstm_prepack_success [ OK ] VulkanAPITest.lstm_prepack_success (8 ms) [ RUN ] VulkanAPITest.querypool_flushed_shader_log xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp:8108: Skipped QueryPool is not available [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms) [----------] 414 tests from VulkanAPITest (5690 ms total) [----------] Global test environment tear-down [==========] 414 tests from 1 test suite ran. (5690 ms total) [ PASSED ] 413 tests. [ SKIPPED ] 1 test, listed below: [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log YOU HAVE 9 DISABLED TESTS ``` Full Paste: P885827407 Differential Revision: D51452935 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114145 Approved by: https://github.com/SS-JIA --- aten/src/ATen/native/vulkan/ops/BinaryOp.cpp | 39 +++++++++++- aten/src/ATen/native/vulkan/ops/Mm.cpp | 10 --- aten/src/ATen/test/vulkan_api_test.cpp | 65 +++++++++++++++++++- 3 files changed, 100 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp index 4bd6611b2288..c197da25a6dc 100644 --- a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp +++ b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp @@ -66,6 +66,39 @@ Tensor binary_op_scalar( return convert(v_output); } +Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { + // Similar to binary_op_scalar where tensors is mapped to float, we + // also map known integer types (but not quant types) tensor to float. + + // Such conversion can only to be done before moving to vulkan, since vulkan + // doesn't yet support integer types. + Tensor other = other_arg; + if (!other.is_vulkan()) { + switch (other.scalar_type()) { + case at::kByte: + case at::kChar: + case at::kShort: + case at::kInt: + case at::kLong: + case at::kDouble: + other = other.to(kFloat); + break; + case at::kFloat: + // No op for expected type. + break; + default: + TORCH_CHECK( + false, + "binary_op_tensor, doesn't support type %s", + other.scalar_type()); + break; + } + other = other.vulkan(); + } + + return other; +} + Tensor& binary_op_scalar_( Tensor& self_arg, const Scalar& other, @@ -127,7 +160,8 @@ Tensor binary_op_tensor( const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); const vTensor& v_self = convert(self); - const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + Tensor other = binary_op_preprocess_other_arg(other_arg); + const vTensor& v_other = convert(other); vTensor v_output{ @@ -301,7 +335,8 @@ Tensor& binary_op_tensor_( vTensor& v_self = convert(self_arg); - const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + Tensor other = binary_op_preprocess_other_arg(other_arg); + const vTensor& v_other = convert(other); const double alpha = alpha_arg ? alpha_arg->to() : 1.0; diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index c08d41db3c84..19bef2d559cf 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -91,16 +91,6 @@ vTensor pack_weights(const Tensor& weight_arg, const bool use_batch = false) { // Rest of the logic are either quantized or batched. - bool quantized = false; - switch (weight_arg.scalar_type()) { - case at::kQInt8: - case at::kQUInt8: - quantized = true; - break; - default: - break; - } - api::Context* const context = api::context(); const Tensor weight = weight_arg.contiguous(); diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 1dc7a0ad323d..e4fcf5462af2 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -489,6 +489,47 @@ TEST_F(VulkanAPITest, add_zero_dim) { test_add({2, 6, 5, 6}, {}, 1.5f); } +void test_add_other_cpu_int( + const at::IntArrayRef input_shape, + const at::IntArrayRef other_shape, + float alpha) { + const auto in_cpu = + at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto other_cpu = + (at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)) * 100) + .to(at::kInt); + + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::add(in_cpu, other_cpu, alpha); + const auto out_vulkan = at::add(in_vulkan, other_cpu, alpha); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, add_other_cpu_int) { + test_add_other_cpu_int({2, 3}, {2, 3}, 1.0f); + test_add_other_cpu_int({11, 7, 139, 109}, {11, 7, 139, 109}, 2.1f); +} + +TEST_F(VulkanAPITest, add_broadcast0_other_cpu_int) { + test_add_other_cpu_int({3, 5, 179, 221}, {3, 5, 1, 1}, 1.8f); +} + +TEST_F(VulkanAPITest, add_other_cpu_unsupported_type_should_fail) { + const auto in_cpu = at::rand({2,2,2}, at::device(at::kCPU).dtype(at::kFloat)); + + const auto other_cpu = + at::zeros({2, 2, 2}, at::device(at::kCPU).dtype(at::kComplexFloat)); + + EXPECT_THROW(at::add(in_cpu.vulkan(), other_cpu.vulkan(), 1.0f), ::c10::Error); +} + TEST_F(VulkanAPITest, add_) { auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -501,7 +542,7 @@ TEST_F(VulkanAPITest, add_) { const auto check = almostEqual(a_cpu, a_vulkan.cpu()); if (!check) { - showRtol(b_cpu, b_vulkan.cpu()); + showRtol(a_cpu, a_vulkan.cpu()); } ASSERT_TRUE(check); @@ -519,12 +560,32 @@ TEST_F(VulkanAPITest, add_broadcast0_) { const auto check = almostEqual(a_cpu, a_vulkan.cpu()); if (!check) { - showRtol(b_cpu, b_vulkan.cpu()); + showRtol(a_cpu, a_vulkan.cpu()); } ASSERT_TRUE(check); } +TEST_F(VulkanAPITest, add_other_cpu_int_) { + std::vector input_shape{12, 17, 29, 33}; + const auto in_cpu = + at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto other_cpu = + (at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)) * 100) + .to(at::kInt); + + const auto in_vulkan = in_cpu.vulkan(); + + float alpha = -8.31f; + in_cpu.add(other_cpu, alpha); + in_vulkan.add(other_cpu, alpha); + + const auto check = almostEqual(in_cpu, in_vulkan.cpu()); + if (!check) { + showRtol(in_cpu, in_vulkan.cpu()); + } +} + TEST_F(VulkanAPITest, add_broadcast1_) { auto a_cpu = at::rand({3, 8, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); From dd6ef0877ebe3d6b0645870b0a6517905469097e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 21 Nov 2023 09:21:20 +0000 Subject: [PATCH 043/221] Revert "[inductor cpp] vectorize embedding lookup (#114062)" This reverts commit 2c0474c02d3ac04a429504225d7f1a6536d3b9e6. Reverted https://github.com/pytorch/pytorch/pull/114062 on behalf of https://github.com/huydhn due to Sorry for reverting your change, please help fix lint and reland it https://hud.pytorch.org/pytorch/pytorch/commit/2c0474c02d3ac04a429504225d7f1a6536d3b9e6 ([comment](https://github.com/pytorch/pytorch/pull/114062#issuecomment-1820526515)) --- test/inductor/test_cpu_repro.py | 18 -- torch/_inductor/codegen/cpp.py | 291 +++++++++------------------ torch/_inductor/codegen/cpp_prefix.h | 6 - 3 files changed, 91 insertions(+), 224 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 06dff1d34e31..f5ce2369bfdb 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1299,7 +1299,6 @@ def test_cpu_vec_cosim(self): cpp_op_list.append(k) diff = [ - "constant", "index_expr", "signbit", "isinf", @@ -2613,23 +2612,6 @@ def forward(self, x): x = torch.randn(1, 39, 1, 18, 17) self.common(m, (x,)) - def test_embedding_vec(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.Embedding(64, 128) - - def forward(self, idx, x): - return self.emb(idx) + x - - idx = torch.randint(0, 64, (4, 32)) - x = torch.randn(4, 32, 128) - m = M().eval() - with torch.no_grad(): - metrics.reset() - self.common(m, (idx, x)) - assert metrics.generated_cpp_vec_kernel_count == 1 - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index f08c51a0a7b8..90f55ed8b55a 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -7,7 +7,7 @@ import re import sys from copy import copy, deepcopy -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import sympy @@ -133,19 +133,6 @@ ] -def value_to_cpp(value, cpp_type): - if value == float("-inf"): - return f"-std::numeric_limits<{cpp_type}>::infinity()" - elif value == float("inf"): - return f"std::numeric_limits<{cpp_type}>::infinity()" - elif isinstance(value, bool): - return f"static_cast<{cpp_type}>({str(value).lower()})" - elif math.isnan(value): - return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" - else: - return f"static_cast<{cpp_type}>({repr(value)})" - - def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial @@ -449,54 +436,6 @@ def get_current_node_opt_ctx() -> OptimizationContext: return get_opt_ctx(V.interpreter.current_node) -class CppCSEVariable(CSEVariable): - def __init__(self, name, bounds: ValueRanges): - super().__init__(name, bounds) - self.is_vec = False - self.dtype = None - self.dependent_itervars: Set[sympy.Symbol] = set() - - def update_on_args(self, name, args, kwargs): - if name == "load": - # args[1] is index - self._set_dependent_itervars(args[1]) - else: - # propagate relevant itervars and is_vec from args - self.dependent_itervars.update( - *[ - arg.dependent_itervars - for arg in args - if isinstance(arg, CppCSEVariable) - ] - ) - if name == "index_expr": - self._set_dependent_itervars(args[0]) - if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): - self.is_vec = True - if ( - hasattr(V.interpreter, "current_node") - and get_current_node_opt_ctx() is not None - ): - self.dtype = get_current_node_opt_ctx().dtype - - def _set_dependent_itervars(self, index: sympy.Expr): - """ - Set the relevant itervars for this variable based on the `index` expression. - This includes the itervars directly used in the `index` as well as relevant itervars - of other cse variables used in the `index`. - """ - for s in index.free_symbols: - if s in V.kernel.itervars: - self.dependent_itervars.add(s) - elif s.name in V.kernel.cse.varname_map: - self.dependent_itervars.update( - V.kernel.cse.varname_map[s.name].dependent_itervars - ) - - def depends_on(self, itervar: sympy.Symbol): - return itervar in self.dependent_itervars - - class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @@ -733,20 +672,22 @@ def mod(a, b): @staticmethod def constant(val, dtype): - opt_ctx: OptimizationContext = get_current_node_opt_ctx() - assert opt_ctx and opt_ctx.dtype is not None - dtype = opt_ctx.dtype if dtype in DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, constants # must be promoted as well dtype = torch.float32 - return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + if val == float("inf"): + return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif val == float("-inf"): + return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif math.isnan(val): + return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" + elif val is True or val is False: + return ops.to_dtype(str(val).lower(), dtype) + return ops.to_dtype(repr(val), dtype) @staticmethod def index_expr(expr, dtype): - opt_ctx: OptimizationContext = get_current_node_opt_ctx() - assert opt_ctx and opt_ctx.dtype is not None - dtype = opt_ctx.dtype return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype) @staticmethod @@ -763,7 +704,19 @@ def masked(mask, body, other): V.kernel.compute.splice(code) # Use the lambda's return type as the type of other - other_code = value_to_cpp(other, f"decltype({body_var}())") + type = f"decltype({body_var}())" + + if other == float("-inf"): + other_code = f"-std::numeric_limits<{type}>::infinity()" + elif other == float("inf"): + other_code = f"std::numeric_limits<{type}>::infinity()" + elif isinstance(other, bool): + other_code = f"static_cast<{type}>({str(other).lower()})" + elif math.isnan(other): + other_code = f"std::numeric_limits<{type}>::quiet_NaN()" + else: + other_code = f"static_cast<{type}>({repr(other)})" + return f"{mask} ? {body_var}() : {other_code}" @staticmethod @@ -841,54 +794,6 @@ def sign(x): class CppVecOverrides(CppOverrides): """Map element-wise ops to aten vectorization C++""" - def __new__(cls, *args, **kargs): - self = super().__new__(cls) - - def wrap(func): - # `CppVecKernel` generates both scalar ops and vector ops according to - # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` - # (except for "masked") assume the inputs are vectors. We wrap the ops in - # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to - # `CppOverrides` when all inputs are scalars. - # - # Inputs to ops.masked are handled separately in its own function due to - # the need of recurive handling of masked body. - def wrapper(*args, **kwargs): - has_scalar = any( - not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable) - ) - has_vector = any( - arg.is_vec for arg in args if isinstance(arg, CppCSEVariable) - ) - new_args = list(args) - if has_scalar and has_vector: - # broadcast scalar args to vector if needed - new_args = [] - for arg in args: - if isinstance(arg, CppCSEVariable) and not arg.is_vec: - assert isinstance(V.kernel, CppVecKernel) - new_arg = V.kernel.broadcast(arg) - new_args.append(new_arg) - else: - new_args.append(arg) - if has_vector: - return func(*new_args, **kwargs) - else: - # fallback to scalar ops - scalar_ops = super(CppVecOverrides, self) - scalar_func = getattr( - scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] - ) - assert scalar_func is not None - return scalar_func(*args, **kwargs) - - return wrapper - - for name, method in vars(cls).items(): - if getattr(method, "__class__", None) == staticmethod and name != "masked": - setattr(self, name, wrap(method.__func__)) - return self - @staticmethod def add(a, b): return f"{a} + {b}" @@ -1101,6 +1006,28 @@ def acosh(x): vec_one = f"decltype({x})(1)" return f"({x} + ({x}*{x} - {vec_one}).sqrt()).log()" + @staticmethod + def constant(val, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx + proposed_dtype = opt_ctx.dtype + assert proposed_dtype in [ + torch.float, + torch.int32, + ] + if val == float("inf"): + quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()" + elif val == float("-inf"): + quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()" + elif math.isnan(val): + quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()" + elif val is True or val is False: + quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})" + else: + quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})" + + return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})" + @staticmethod def relu(x): bug = config.cpp.inject_relu_bug_TESTING_ONLY @@ -1232,24 +1159,32 @@ def masked(mask, body, other): code.writeline(";") V.kernel.compute.splice(code) - other_code = value_to_cpp(other, "float") - other_code_vec = f"at::vec::Vectorized({other_code})" - - if result.is_vec: - type = f"decltype({var}())" - float_mask = f"to_float_mask({new_mask})" - csevar = V.kernel.cse.generate( - V.kernel.compute, - f"{type}::blendv({other_code_vec}, {var}(), {float_mask})", + if other == float("-inf"): + other_code = ( + "at::vec::Vectorized(-std::numeric_limits::infinity())" ) - else: - csevar = V.kernel.cse.generate( - V.kernel.compute, f"{mask} ? {var}() : {other_code}" + elif other == float("inf"): + other_code = ( + "at::vec::Vectorized(std::numeric_limits::infinity())" + ) + elif math.isnan(other): + other_code = ( + "at::vec::Vectorized(std::numeric_limits::quiet_NaN())" ) - # `result` is explicitly added to the args for correct propagation - # of relevant itervars and vectorization status. - csevar.update_on_args("masked", (mask, body, other, result), {}) - return csevar + else: + other_code = f"at::vec::Vectorized({other!r})" + type = f"decltype({var}())" + float_mask = f"to_float_mask({new_mask})" + return f"{type}::blendv({other_code}, {var}(), {float_mask})" + + @staticmethod + def index_expr(expr, dtype): + assert dtype == torch.int64 + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx + assert opt_ctx.dtype == torch.int32 + assert opt_ctx.is_most_inner_loop_irrevelant + return f"at::vec::Vectorized(static_cast({cexpr(V.kernel.rename_indexing(expr))}))" class CppKernel(Kernel): @@ -1307,9 +1242,7 @@ def load(self, name: str, index: sympy.Expr): line = f"{var}[{cexpr_index(index)}]" if V.graph.get_dtype(name) in [torch.float16]: line = f"static_cast({line})" - csevar = self.cse.generate(self.loads, line) - csevar.update_on_args("load", (name, index), {}) - return csevar + return self.cse.generate(self.loads, line) def store(self, name, index, value, mode=None): assert "buf" in name @@ -1539,9 +1472,6 @@ def write_to_suffix(self): self.reduction_suffix.splice(self.stores) (self.loads, self.compute, self.stores, self.cse) = prior - def create_cse_var(self, *args, **kwargs): - return CppCSEVariable(*args, **kwargs) - class CppVecKernel(CppKernel): overrides = CppVecOverrides # type: ignore[assignment] @@ -1576,11 +1506,7 @@ def load(self, name: str, index: sympy.Expr): non_contiguous = ( not is_broadcast and stride_at(tiling_var, index) != 1 - or any( - self.cse.varname_map[s.name].depends_on(tiling_var) - for s in index.free_symbols - if s.name.startswith("tmp") - ) + or "tmp" in f"{index}" ) var_expr = ( f"{var}[{cexpr_index(index)}]" @@ -1589,9 +1515,13 @@ def load(self, name: str, index: sympy.Expr): ) loadbuf = "tmpbuf" if non_contiguous else var_expr if is_broadcast: - csevar = super().load(name, index) - csevar.dtype = dtype - return csevar + # should always be broadcast as float for vectorization since we always use float to compute + if is_mask: + loadbuf = f"flag_to_float_scalar({loadbuf})" + if dtype in DTYPE_LOWP_FP: + line = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({loadbuf})" + else: + line = f"at::vec::Vectorized(static_cast({loadbuf}))" elif dtype in [torch.uint8] and opt_ctx.is_load_uint8_as_float: line = ( f"masked_load({loadbuf}, {load_mask})" @@ -1633,11 +1563,7 @@ def load(self, name: str, index: sympy.Expr): tmpbufdefine += f"tmpbuf[{inner}] = {rhs};" line = f"([&]() {{ {tmpbufdeclare} {tmpbufdefine} return {line}; }})()" - csevar = self.cse.generate(self.loads, line) - csevar.update_on_args("load", (name, index), {}) - assert isinstance(csevar, CppCSEVariable) - csevar.is_vec = True - return csevar + return self.cse.generate(self.loads, line) def get_vec_store_line(self, value, var, index, dtype): """ @@ -1646,11 +1572,6 @@ def get_vec_store_line(self, value, var, index, dtype): :param var: buffer to store into. :index: index into the `var`. """ - # when value's type is str (e.g., welford reduction), caller should make sure - # it is a vector - assert isinstance(value, str) or ( - isinstance(value, CppCSEVariable) and value.is_vec - ), value tiling_var = self.itervars[self.tiling_idx] assert index.has(tiling_var) var_expr = f"{var} + {cexpr_index(index)}" @@ -1679,10 +1600,6 @@ def get_vec_store_line(self, value, var, index, dtype): def store(self, name, index, value, mode=None): assert "buf" in name assert mode is None - assert isinstance(value, CppCSEVariable), value - if not value.is_vec: - # this happens when we store a scalar into a vectorized buffer like "fill" - value = self.broadcast(value) opt_ctx: OptimizationContext = get_current_node_opt_ctx() var = self.args.output(name) index = self.rename_indexing(index) @@ -1705,7 +1622,6 @@ def reduction(self, dtype, src_dtype, reduction_type, value): } assert dtype == torch.float assert src_dtype == torch.float - assert isinstance(value, CppCSEVariable) and value.is_vec, value vec_ns = "at::vec" vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" @@ -1824,26 +1740,6 @@ def store_reduction(self, name, index, value): ] self.reduction_suffix.writelines(store_lines) - def broadcast(self, scalar_var: CppCSEVariable): - assert ( - not scalar_var.is_vec - and self.itervars[self.tiling_idx] not in scalar_var.dependent_itervars - ) - if scalar_var.dtype == torch.bool: - vec_var = self.cse.generate( - self.compute, f"to_float_mask({scalar_var.name})" - ) - else: - vec_var = self.cse.generate( - self.compute, - f"at::vec::Vectorized<{DTYPE_TO_CPP[scalar_var.dtype]}>({scalar_var.name})", - ) - assert isinstance(vec_var, CppCSEVariable) - vec_var.dtype = scalar_var.dtype - vec_var.dependent_itervars = scalar_var.dependent_itervars - vec_var.is_vec = True - return vec_var - class CppTile2DKernel(CppVecKernel): """ @@ -1953,11 +1849,7 @@ def load(self, name: str, index: sympy.Expr): line = f"at::vec::Vectorized::loadu_one_fourth({loadbuf})" else: line = f"at::vec::Vectorized::loadu({loadbuf})" - csevar = self.cse.generate(self.loads, line) - csevar.update_on_args("load", (name, index), {}) - assert isinstance(csevar, CppCSEVariable) - csevar.is_vec = True - return csevar + return self.cse.generate(self.loads, line) else: new_index = self.scale_index_with_offset( index, @@ -2058,6 +1950,10 @@ def disable_vec(self, msg=None): schedule_log.debug("Disabled vectorization: %s", msg) self.simd_vec = False + def could_vec(self, name: str, index: sympy.Expr): + assert self.itervars is not None + return len(self.itervars) > 0 + def is_mask(self, name: str, users: Dict[torch.fx.Node, None]): load_type = V.graph.get_dtype(name) if load_type == torch.bool: @@ -2140,10 +2036,6 @@ def load(self, name: str, index: sympy.Expr): var = self.cse.newvar() - if len(self.itervars) == 0: - self.disable_vec("not a loop") - return var - if load_dtype in [torch.bool, torch.uint8] and not ( opt_ctx.is_load_as_mask or opt_ctx.is_load_uint8_as_float ): @@ -2154,21 +2046,18 @@ def load(self, name: str, index: sympy.Expr): return var if ( - (load_dtype not in self.load_supported_dtypes) - and not self.is_load_integer_scalar_tensor(name, index) - and index.has(self.itervars[self.tiling_idx]) - ): + load_dtype not in self.load_supported_dtypes + ) and not self.is_load_integer_scalar_tensor(name, index): self.disable_vec(f"{load_dtype} not supported by load") return var + index = self.rename_indexing(index) + if self.simd_vec and not self.could_vec(name, index): + self.disable_vec(f"not a loop: {index}") return var def store(self, name, index, value, mode=None): with RecordOptimizationContext(__name__) as node_ctx: - if len(self.itervars) == 0: - self.disable_vec("not a loop") - return self.simd_vec - store_dtype = V.graph.get_dtype(name) opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() @@ -2196,6 +2085,8 @@ def store(self, name, index, value, mode=None): if index.is_number: self.disable_vec(f"constant store index: {index}") + if self.simd_vec and not self.could_vec(name, index): + self.disable_vec(f"not a loop: {index}") return self.simd_vec def reduction(self, dtype, src_dtype, reduction_type, value): diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 23f72218a0cc..1a532029cdeb 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -401,10 +401,4 @@ template <> inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { return src; } - -inline at::vec::Vectorized to_float_mask(int src) { - float mask; - *(uint32_t*)&mask = src ? 0xFFFFFFFF : 0; - return at::vec::Vectorized(mask); -} #endif From 8ec59d3553db4e3dae991db9aa3251558265dfd3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 21 Nov 2023 10:05:15 +0000 Subject: [PATCH 044/221] Revert "[dynamo] report guard failure user stack, fix incorrectly skipping interesting files (#114053)" This reverts commit 826ab0e32d558415d5d682842417fd16b2223739. Reverted https://github.com/pytorch/pytorch/pull/114053 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/114053#issuecomment-1820584281)) --- test/dynamo/test_aot_autograd.py | 32 +++++++++++++++---------------- test/dynamo/test_logging.py | 32 ------------------------------- test/dynamo/test_misc.py | 33 ++++++++++++++------------------ torch/_dynamo/guards.py | 9 +++------ 4 files changed, 33 insertions(+), 73 deletions(-) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 169c4b60452b..07af08acb36f 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -302,9 +302,9 @@ def guard_fail_fn(failure): fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F()) compare_equal_outs_and_grads(self, F(), fxy, (x, y)) compare_equal_outs_and_grads(self, F(), fxy, (x, z)) - self.assertIn( - """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""", + self.assertExpectedInline( failure_reason, + """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""", ) # Reset failure reason @@ -421,7 +421,7 @@ def guard_fail_fn(failure): fxx(x3, x3) fxx(x4, y4) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['x'] is L['y']""", failure_reason) + self.assertExpectedInline(failure_reason, """L['x'] is L['y']""") @patch("torch._functorch.config.debug_assert", True) def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self): @@ -456,9 +456,9 @@ def guard_fail_fn(failure): f(a1, a1, a1, a1, 2, 2) f(a2, b2, b2, b2, 2, 2) self.assertEqual(cc.frame_count, 2) - self.assertIn( - """L['a'] is L['b']""", + self.assertExpectedInline( failure_reason, + """L['a'] is L['b']""", ) torch._dynamo.reset() @@ -474,7 +474,7 @@ def guard_fail_fn(failure): f(a3, b3, c3, c3, 3, 3) f(a4, b4, c4, d4, 3, 3) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['c'] is L['d']""", failure_reason) + self.assertExpectedInline(failure_reason, """L['c'] is L['d']""") @patch("torch._functorch.config.debug_assert", True) def test_arg_dupe_via_dynamo_recompiles_many_with_global(self): @@ -512,9 +512,9 @@ def guard_fail_fn(failure): f(a1, a1, a1, a1, 2, 2) f(a2, b2, b2, b2, 2, 2) self.assertEqual(cc.frame_count, 2) - self.assertIn( - """L['a'] is L['b']""", + self.assertExpectedInline( failure_reason, + """L['a'] is L['b']""", ) @patch("torch._functorch.config.debug_assert", True) @@ -550,9 +550,9 @@ def guard_fail_fn(failure): f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1) f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2) self.assertEqual(cc.frame_count, 2) - self.assertIn( - """L['a'] is L['b']""", + self.assertExpectedInline( failure_reason, + """L['a'] is L['b']""", ) torch._dynamo.reset() @@ -602,9 +602,9 @@ def guard_fail_fn(failure): f(a1, a1, a1, a1) f(a2, b2, b2, b2) self.assertEqual(cc.frame_count, 2) - self.assertIn( - """L['a'] is L['b']""", + self.assertExpectedInline( failure_reason, + """L['a'] is L['b']""", ) torch._dynamo.reset() @@ -620,7 +620,7 @@ def guard_fail_fn(failure): f(a3, b3, c3, c3) f(a4, b4, c4, d4) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['c'] is L['d']""", failure_reason) + self.assertExpectedInline(failure_reason, """L['c'] is L['d']""") @patch("torch._functorch.config.debug_assert", True) def test_arg_dupe_via_dynamo_recompiles_many_args(self): @@ -651,9 +651,9 @@ def guard_fail_fn(failure): f(a1, a1, a1, a1) f(a2, b2, b2, b2) self.assertEqual(cc.frame_count, 2) - self.assertIn( - """L['a'] is L['b']""", + self.assertExpectedInline( failure_reason, + """L['a'] is L['b']""", ) torch._dynamo.reset() @@ -669,7 +669,7 @@ def guard_fail_fn(failure): f(a3, b3, c3, c3) f(a4, b4, c4, d4) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['c'] is L['d']""", failure_reason) + self.assertExpectedInline(failure_reason, """L['c'] is L['d']""") @expectedFailureDynamic # https://github.com/pytorch/pytorch/issues/103539 @torch._dynamo.config.patch(automatic_dynamic_shapes=False) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 7e233263d3e2..9b77b2f2e3aa 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -596,38 +596,6 @@ def fn(x): ~~^~~""", ) - @make_logging_test(guards=True, recompiles=True) - def test_guards_recompiles(self, records): - def fn(x, ys, zs): - return inner(x, ys, zs) - - def inner(x, ys, zs): - for y, z in zip(ys, zs): - x += y * z - return x - - ys = [1.0, 2.0] - zs = [3.0] - x = torch.tensor([1.0]) - - fn_opt = torch._dynamo.optimize("eager")(fn) - fn_opt(x, ys, zs) - fn_opt(x, ys[:1], zs) - - record_str = "\n".join(r.getMessage() for r in records) - - self.assertIn( - """\ -L['zs'][0] == 3.0 # for y, z in zip(ys, zs):""", - record_str, - ) - self.assertIn( - """\ - triggered by the following guard failure(s):\n\ - - len(L['ys']) == 2 # for y, z in zip(ys, zs):""", - record_str, - ) - @make_logging_test(**torch._logging.DEFAULT_LOGGING) def test_default_logging(self, records): def fn(a): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 61928d4abd84..99d7e1ff0ffb 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -650,9 +650,9 @@ def guard_failures(failure): )(compare_shapes) opt_fn(torch.randn([3, 4])) opt_fn(torch.randn([4, 3])) - self.assertIn( - """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""", + self.assertExpectedInline( guard_failure.reason, + """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""", ) def test_builtin_abs(self): @@ -716,9 +716,9 @@ def fn(x, y): ), sorted(guard_code), ) - guard_code_str = "\n".join(guard_code) - - for line in """\ + self.assertExpectedInline( + "\n".join(guard_code), + """\ 2 <= L['x'].size()[0] L['x'] is L['y'] L['x'].ndimension() == 2 @@ -734,13 +734,8 @@ def fn(x, y): not ___dict_contains('cccccccc', G['sys'].modules) str(L['x'].device) == 'cpu' str(L['x'].dtype) == 'torch.float32' -utils_device.CURRENT_DEVICE == None""".split( - "\n" - ): - self.assertIn( - line, - guard_code_str, - ) +utils_device.CURRENT_DEVICE == None""", + ) def test_fold(self): def fn(a): @@ -5245,12 +5240,12 @@ def guard_failures(failure): self.assertTrue(guard_failure is not None) first_guard_failure = guard_failure[0].partition("\n")[0] if torch._dynamo.config.assume_static_by_default: - self.assertIn( - """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""", + self.assertExpectedInline( first_guard_failure, + """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""", ) else: - self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure) + self.assertExpectedInline(first_guard_failure, """L['x'].size()[0] < 3""") def test_guard_failure_fn2(self): def fn(x, y): @@ -5278,9 +5273,9 @@ def guard_failures(failure): opt_fn(x2, y2) if torch._dynamo.config.assume_static_by_default: - self.assertIn( - """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""", + self.assertExpectedInline( guard_failure[0], + """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""", ) else: self.assertTrue(guard_failure is None) @@ -5313,9 +5308,9 @@ def guard_failures(failure): # guard is expected for both static and dynamic shapes self.assertTrue(guard_failure is not None) - self.assertIn( - """len(L['x']) == 10""", + self.assertExpectedInline( guard_failure[0], + """len(L['x']) == 10""", ) def test_restore_graphstate(self): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1b068402019b..9c182ac40a82 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1031,15 +1031,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): # Don't report this guard, it's always the same, useless! code_parts = ["___guarded_code.valid", "___check_global_state()"] - verbose_code_parts = code_parts[:] def add_code_part(code, guard, log_only=False): extra = "" if guard.user_stack: for fs in reversed(guard.user_stack): if fs.filename not in uninteresting_files(): - extra = f" # {format_frame(fs, line=True)}" break + else: + extra = f" # {format_frame(fs, line=True)}" elif guard.stack: extra = f" # {format_frame(guard.stack.summary()[-1])}" @@ -1064,7 +1064,6 @@ def add_code_part(code, guard, log_only=False): if not log_only: code_parts.append(code) - verbose_code_parts.append(f"{code:<60}{extra}") seen = set() for gcl in builder.code: @@ -1114,7 +1113,6 @@ def convert(size_or_stride): ) # Do this manually, to un-stagger the guards in log message code_parts.append(f"___check_tensors({tensor_check_args})") - verbose_code_parts.append(f"___check_tensors({tensor_check_args})") tensor_check_guards = builder.tensor_check_guards for i, name in enumerate(tensor_check_names): @@ -1185,7 +1183,6 @@ def convert(size_or_stride): # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both guard_fn.args = largs guard_fn.code_parts = code_parts - guard_fn.verbose_code_parts = verbose_code_parts # Grab only G, but preserve "G" because guards access it as "G" guard_fn.global_scope = { "G": builder.scope["G"], @@ -1285,7 +1282,7 @@ def get_guard_fail_reason( scope.update(guard_fn.closure_vars) scope["___check_tensors"] = scope["___check_tensors_verbose"] reasons: List[str] = [] - for part in guard_fn.verbose_code_parts: + for part in guard_fn.code_parts: global_scope = dict(guard_fn.global_scope) global_scope["__compile_source__"] = part with report_compile_source_on_error(): From 2aa486de9b3f38d28ee2db9c15220a4f0919522d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 21 Nov 2023 11:51:20 +0000 Subject: [PATCH 045/221] vendor packaging.version (#114108) Fixes #113940. This vendors the relevant parts of [`packaging==23.2.0`]() to have access to `Version` and `InvalidVersion` without taking a runtime dependency on `setuptools` or `packaging`. I didn't find any vendoring policy so I put it under `torch._vendor.packaging`. While I have only vendored the files we need, I have not touched or trimmed the files otherwise. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114108 Approved by: https://github.com/malfet, https://github.com/albanD --- .lintrunner.toml | 1 + setup.py | 5 - test/run_test.py | 2 +- torch/_vendor/README.md | 14 + torch/_vendor/__init__.py | 0 torch/_vendor/packaging/LICENSE | 3 + torch/_vendor/packaging/LICENSE.APACHE | 177 ++++++++ torch/_vendor/packaging/LICENSE.BSD | 23 + torch/_vendor/packaging/__init__.py | 15 + torch/_vendor/packaging/_structures.py | 61 +++ torch/_vendor/packaging/version.py | 563 +++++++++++++++++++++++++ torch/torch_version.py | 41 +- 12 files changed, 862 insertions(+), 43 deletions(-) create mode 100644 torch/_vendor/README.md create mode 100644 torch/_vendor/__init__.py create mode 100644 torch/_vendor/packaging/LICENSE create mode 100644 torch/_vendor/packaging/LICENSE.APACHE create mode 100644 torch/_vendor/packaging/LICENSE.BSD create mode 100644 torch/_vendor/packaging/__init__.py create mode 100644 torch/_vendor/packaging/_structures.py create mode 100644 torch/_vendor/packaging/version.py diff --git a/.lintrunner.toml b/.lintrunner.toml index d2ccb509c2dc..a991cb8d4c4e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1681,6 +1681,7 @@ exclude_patterns = [ 'torch/_higher_order_ops/__init__.py', 'torch/_higher_order_ops/out_dtype.py', 'torch/_higher_order_ops/wrap.py', + 'torch/_vendor/**', 'torch/ao/__init__.py', 'torch/ao/nn/__init__.py', 'torch/ao/nn/intrinsic/__init__.py', diff --git a/setup.py b/setup.py index 2f10daf7195d..5047e4fda704 100644 --- a/setup.py +++ b/setup.py @@ -1117,11 +1117,6 @@ def main(): "fsspec", ] - if IS_WINDOWS and sys.version_info >= (3, 12, 0): - # torch.version requires this and it is not part - # of the default cpython install on windows in 3.12+ - install_requires.append("packaging") - # Parse the command line and check the arguments before we proceed with # building deps and setup. We need to set values so `--help` works. dist = Distribution() diff --git a/test/run_test.py b/test/run_test.py index 7440aed74992..69f001ddb879 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -819,7 +819,7 @@ def run_doctests(test_module, test_directory, options): pkgpath = pathlib.Path(torch.__file__).parent - exclude_module_list = [] + exclude_module_list = ["torch._vendor.*"] enabled = { # TODO: expose these options to the user # For now disable all feature-conditional tests diff --git a/torch/_vendor/README.md b/torch/_vendor/README.md new file mode 100644 index 000000000000..f7580057ff7a --- /dev/null +++ b/torch/_vendor/README.md @@ -0,0 +1,14 @@ +# Vendored libraries + +## `packaging` + +Source: https://github.com/pypa/packaging/ + +PyPI: https://pypi.org/project/packaging/ + +Vendored version: `23.2.0` + +Instructions to update: + +- Copy the file `packaging/version.py` and all files that it is depending on +- Check if the licensing has changed from the BSD / Apache dual licensing and update the license files accordingly diff --git a/torch/_vendor/__init__.py b/torch/_vendor/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/_vendor/packaging/LICENSE b/torch/_vendor/packaging/LICENSE new file mode 100644 index 000000000000..6f62d44e4ef7 --- /dev/null +++ b/torch/_vendor/packaging/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the licenses +found in LICENSE.APACHE or LICENSE.BSD. Contributions to this software is made +under the terms of *both* these licenses. diff --git a/torch/_vendor/packaging/LICENSE.APACHE b/torch/_vendor/packaging/LICENSE.APACHE new file mode 100644 index 000000000000..f433b1a53f5b --- /dev/null +++ b/torch/_vendor/packaging/LICENSE.APACHE @@ -0,0 +1,177 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS diff --git a/torch/_vendor/packaging/LICENSE.BSD b/torch/_vendor/packaging/LICENSE.BSD new file mode 100644 index 000000000000..42ce7b75c92f --- /dev/null +++ b/torch/_vendor/packaging/LICENSE.BSD @@ -0,0 +1,23 @@ +Copyright (c) Donald Stufft and individual contributors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/_vendor/packaging/__init__.py b/torch/_vendor/packaging/__init__.py new file mode 100644 index 000000000000..22809cfd5dc2 --- /dev/null +++ b/torch/_vendor/packaging/__init__.py @@ -0,0 +1,15 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "23.2" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014 %s" % __author__ diff --git a/torch/_vendor/packaging/_structures.py b/torch/_vendor/packaging/_structures.py new file mode 100644 index 000000000000..90a6465f9682 --- /dev/null +++ b/torch/_vendor/packaging/_structures.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/torch/_vendor/packaging/version.py b/torch/_vendor/packaging/version.py new file mode 100644 index 000000000000..5faab9bd0dcf --- /dev/null +++ b/torch/_vendor/packaging/version.py @@ -0,0 +1,563 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. +""" +.. testsetup:: + + from packaging.version import parse, Version +""" + +import itertools +import re +from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"] + +LocalType = Tuple[Union[int, str], ...] + +CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]] +CmpLocalType = Union[ + NegativeInfinityType, + Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...], +] +CmpKey = Tuple[ + int, + Tuple[int, ...], + CmpPrePostDevType, + CmpPrePostDevType, + CmpPrePostDevType, + CmpLocalType, +] +VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool] + + +class _Version(NamedTuple): + epoch: int + release: Tuple[int, ...] + dev: Optional[Tuple[str, int]] + pre: Optional[Tuple[str, int]] + post: Optional[Tuple[str, int]] + local: Optional[LocalType] + + +def parse(version: str) -> "Version": + """Parse the given version string. + + >>> parse('1.0.dev1') + + + :param version: The version string to parse. + :raises InvalidVersion: When the version string is not a valid version. + """ + return Version(version) + + +class InvalidVersion(ValueError): + """Raised when a version string is not a valid version. + + >>> Version("invalid") + Traceback (most recent call last): + ... + packaging.version.InvalidVersion: Invalid version: 'invalid' + """ + + +class _BaseVersion: + _key: Tuple[Any, ...] + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. + def __lt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +_VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
                                          # pre-release
+            [-_\.]?
+            (?Palpha|a|beta|b|preview|pre|c|rc)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+        (?P                                         # post release
+            (?:-(?P[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?Ppost|rev|r)
+                [-_\.]?
+                (?P[0-9]+)?
+            )
+        )?
+        (?P                                          # dev release
+            [-_\.]?
+            (?Pdev)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+    )
+    (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+VERSION_PATTERN = _VERSION_PATTERN
+"""
+A string containing the regular expression used to match a valid version.
+
+The pattern is not anchored at either end, and is intended for embedding in larger
+expressions (for example, matching a version number as part of a file name). The
+regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
+flags set.
+
+:meta hide-value:
+"""
+
+
+class Version(_BaseVersion):
+    """This class abstracts handling of a project's versions.
+
+    A :class:`Version` instance is comparison aware and can be compared and
+    sorted using the standard Python interfaces.
+
+    >>> v1 = Version("1.0a5")
+    >>> v2 = Version("1.0")
+    >>> v1
+    
+    >>> v2
+    
+    >>> v1 < v2
+    True
+    >>> v1 == v2
+    False
+    >>> v1 > v2
+    False
+    >>> v1 >= v2
+    False
+    >>> v1 <= v2
+    True
+    """
+
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    _key: CmpKey
+
+    def __init__(self, version: str) -> None:
+        """Initialize a Version object.
+
+        :param version:
+            The string representation of a version which will be parsed and normalized
+            before use.
+        :raises InvalidVersion:
+            If the ``version`` does not conform to PEP 440 in any way then this
+            exception will be raised.
+        """
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        """A representation of the Version that shows all internal state.
+
+        >>> Version('1.0.0')
+        
+        """
+        return f""
+
+    def __str__(self) -> str:
+        """A string representation of the version that can be rounded-tripped.
+
+        >>> str(Version("1.0a5"))
+        '1.0a5'
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        """The epoch of the version.
+
+        >>> Version("2.0.0").epoch
+        0
+        >>> Version("1!2.0.0").epoch
+        1
+        """
+        return self._version.epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        """The components of the "release" segment of the version.
+
+        >>> Version("1.2.3").release
+        (1, 2, 3)
+        >>> Version("2.0.0").release
+        (2, 0, 0)
+        >>> Version("1!2.0.0.post0").release
+        (2, 0, 0)
+
+        Includes trailing zeroes but not the epoch or any pre-release / development /
+        post-release suffixes.
+        """
+        return self._version.release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        """The pre-release segment of the version.
+
+        >>> print(Version("1.2.3").pre)
+        None
+        >>> Version("1.2.3a1").pre
+        ('a', 1)
+        >>> Version("1.2.3b1").pre
+        ('b', 1)
+        >>> Version("1.2.3rc1").pre
+        ('rc', 1)
+        """
+        return self._version.pre
+
+    @property
+    def post(self) -> Optional[int]:
+        """The post-release number of the version.
+
+        >>> print(Version("1.2.3").post)
+        None
+        >>> Version("1.2.3.post1").post
+        1
+        """
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        """The development number of the version.
+
+        >>> print(Version("1.2.3").dev)
+        None
+        >>> Version("1.2.3.dev1").dev
+        1
+        """
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        """The local version segment of the version.
+
+        >>> print(Version("1.2.3").local)
+        None
+        >>> Version("1.2.3+abc").local
+        'abc'
+        """
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        """The public portion of the version.
+
+        >>> Version("1.2.3").public
+        '1.2.3'
+        >>> Version("1.2.3+abc").public
+        '1.2.3'
+        >>> Version("1.2.3+abc.dev1").public
+        '1.2.3'
+        """
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        """The "base version" of the version.
+
+        >>> Version("1.2.3").base_version
+        '1.2.3'
+        >>> Version("1.2.3+abc").base_version
+        '1.2.3'
+        >>> Version("1!1.2.3+abc.dev1").base_version
+        '1!1.2.3'
+
+        The "base version" is the public version of the project without any pre or post
+        release markers.
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        """Whether this version is a pre-release.
+
+        >>> Version("1.2.3").is_prerelease
+        False
+        >>> Version("1.2.3a1").is_prerelease
+        True
+        >>> Version("1.2.3b1").is_prerelease
+        True
+        >>> Version("1.2.3rc1").is_prerelease
+        True
+        >>> Version("1.2.3dev1").is_prerelease
+        True
+        """
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        """Whether this version is a post-release.
+
+        >>> Version("1.2.3").is_postrelease
+        False
+        >>> Version("1.2.3.post1").is_postrelease
+        True
+        """
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        """Whether this version is a development release.
+
+        >>> Version("1.2.3").is_devrelease
+        False
+        >>> Version("1.2.3.dev1").is_devrelease
+        True
+        """
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        """The first item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").major
+        1
+        """
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        """The second item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").minor
+        2
+        >>> Version("1").minor
+        0
+        """
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        """The third item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").micro
+        3
+        >>> Version("1").micro
+        0
+        """
+        return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
+) -> Optional[Tuple[str, int]]:
+
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[LocalType],
+) -> CmpKey:
+
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll use a reverse the list, drop all the now
+    # leading zeros until we come to something non zero, then take the rest
+    # re-reverse it back into the correct order and make it a tuple and use
+    # that for our sorting key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: CmpPrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: CmpPrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: CmpPrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: CmpLocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
diff --git a/torch/torch_version.py b/torch/torch_version.py
index f9445ce82c41..3d3b5aed2fa5 100644
--- a/torch/torch_version.py
+++ b/torch/torch_version.py
@@ -1,42 +1,9 @@
 from typing import Any, Iterable
 from .version import __version__ as internal_version
+from ._vendor.packaging.version import Version, InvalidVersion
 
-__all__ = ['TorchVersion', 'Version', 'InvalidVersion']
+__all__ = ['TorchVersion']
 
-class _LazyImport:
-    """Wraps around classes lazy imported from packaging.version
-    Output of the function v in following snippets are identical:
-       from packaging.version import Version
-       def v():
-           return Version('1.2.3')
-    and
-       Version = _LazyImport('Version')
-       def v():
-           return Version('1.2.3')
-    The difference here is that in later example imports
-    do not happen until v is called
-    """
-    def __init__(self, cls_name: str) -> None:
-        self._cls_name = cls_name
-
-    def get_cls(self):
-        try:
-            import packaging.version  # type: ignore[import]
-        except ImportError:
-            # If packaging isn't installed, try and use the vendored copy
-            # in pkg_resources
-            from pkg_resources import packaging  # type: ignore[attr-defined, no-redef]
-        return getattr(packaging.version, self._cls_name)
-
-    def __call__(self, *args, **kwargs):
-        return self.get_cls()(*args, **kwargs)
-
-    def __instancecheck__(self, obj):
-        return isinstance(obj, self.get_cls())
-
-
-Version = _LazyImport("Version")
-InvalidVersion = _LazyImport("InvalidVersion")
 
 class TorchVersion(str):
     """A string with magic powers to compare to both Version and iterables!
@@ -57,7 +24,7 @@ class TorchVersion(str):
     """
     # fully qualified type names here to appease mypy
     def _convert_to_version(self, inp: Any) -> Any:
-        if isinstance(inp, Version.get_cls()):
+        if isinstance(inp, Version):
             return inp
         elif isinstance(inp, str):
             return Version(inp)
@@ -76,7 +43,7 @@ def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
         try:
             return getattr(Version(self), method)(self._convert_to_version(cmp))
         except BaseException as e:
-            if not isinstance(e, InvalidVersion.get_cls()):
+            if not isinstance(e, InvalidVersion):
                 raise
             # Fall back to regular string comparison if dealing with an invalid
             # version like 'parrot'

From 7733599b2eba42cca7a4dff388c0cf6c7826009b Mon Sep 17 00:00:00 2001
From: Max Ren 
Date: Tue, 21 Nov 2023 12:45:16 +0000
Subject: [PATCH 046/221] update pthreadpool to
 4fe0e1e183925bf8cfa6aae24237e724a96479b (#113904)

submodule / Updating pthreadpool to this revision.

This is in preparation for upgrading XNNPACK, as the new XNNPACK version uses some of the new pthreadpool APIs introduced in this revision.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113904
Approved by: https://github.com/Skylion007
---
 third_party/pthreadpool | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/pthreadpool b/third_party/pthreadpool
index a134dd5d4cee..4fe0e1e18392 160000
--- a/third_party/pthreadpool
+++ b/third_party/pthreadpool
@@ -1 +1 @@
-Subproject commit a134dd5d4cee80cce15db81a72e7f929d71dd413
+Subproject commit 4fe0e1e183925bf8cfa6aae24237e724a96479b8

From 1f8d00c5a312b490e97c31db5481cdc6544ebbcd Mon Sep 17 00:00:00 2001
From: vfdev-5 
Date: Tue, 21 Nov 2023 13:03:44 +0000
Subject: [PATCH 047/221] [inductor] Added decomposition for
 upsample_nearest_exact Nd (#113749)

Description:
- Added decomposition for upsample_nearest_exact: 1d, 2d, 3d

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113749
Approved by: https://github.com/lezcano
---
 ...DecompTest.test_aten_core_operators.expect |   3 +
 ...asDecompTest.test_has_decomposition.expect |   3 -
 torch/_decomp/decompositions.py               | 124 +++++++++++++++---
 torch/_inductor/lowering.py                   |  39 +++++-
 torch/_meta_registrations.py                  |  19 ++-
 .../_internal/common_methods_invocations.py   |   5 -
 6 files changed, 164 insertions(+), 29 deletions(-)

diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 326f69e39bfb..2ea3fbf1e287 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -19,6 +19,9 @@ aten::_softmax
 aten::_softmax.out
 aten::_to_copy
 aten::_to_copy.out
+aten::_upsample_nearest_exact1d
+aten::_upsample_nearest_exact2d
+aten::_upsample_nearest_exact3d
 aten::abs
 aten::abs.out
 aten::abs_
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index b43759643a09..5694725a291c 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -576,15 +576,12 @@ aten::_upsample_bilinear2d_aa
 aten::_upsample_bilinear2d_aa.out
 aten::_upsample_bilinear2d_aa_backward
 aten::_upsample_bilinear2d_aa_backward.grad_input
-aten::_upsample_nearest_exact1d
 aten::_upsample_nearest_exact1d.out
 aten::_upsample_nearest_exact1d_backward
 aten::_upsample_nearest_exact1d_backward.grad_input
-aten::_upsample_nearest_exact2d
 aten::_upsample_nearest_exact2d.out
 aten::_upsample_nearest_exact2d_backward
 aten::_upsample_nearest_exact2d_backward.grad_input
-aten::_upsample_nearest_exact3d
 aten::_upsample_nearest_exact3d.out
 aten::_upsample_nearest_exact3d_backward
 aten::_upsample_nearest_exact3d_backward.grad_input
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index f037221f469c..9dd35bb49e7b 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -2389,7 +2389,17 @@ def upsample_nearest1d_vec(input, output_size, scale_factors):
     osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
     scale = get_scale_value(scale_factors, 0)
 
-    return upsample_nearest1d(input, osize, scale)
+    return aten.upsample_nearest1d.default(input, osize, scale)
+
+
+@register_decomposition(aten._upsample_nearest_exact1d.vec)
+@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
+@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
+def _upsample_nearest_exact1d_vec(input, output_size, scale_factors):
+    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
+    scale = get_scale_value(scale_factors, 0)
+
+    return aten._upsample_nearest_exact1d.default(input, osize, scale)
 
 
 @register_decomposition(aten.upsample_nearest2d.vec)
@@ -2400,7 +2410,18 @@ def upsample_nearest2d_vec(input, output_size, scale_factors):
     scale_h = get_scale_value(scale_factors, 0)
     scale_w = get_scale_value(scale_factors, 1)
 
-    return upsample_nearest2d(input, osize, scale_h, scale_w)
+    return aten.upsample_nearest2d.default(input, osize, scale_h, scale_w)
+
+
+@register_decomposition(aten._upsample_nearest_exact2d.vec)
+@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
+@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
+def _upsample_nearest_exact2d_vec(input, output_size, scale_factors):
+    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
+    scale_h = get_scale_value(scale_factors, 0)
+    scale_w = get_scale_value(scale_factors, 1)
+
+    return aten._upsample_nearest_exact2d.default(input, osize, scale_h, scale_w)
 
 
 @register_decomposition(aten.upsample_nearest3d.vec)
@@ -2412,26 +2433,49 @@ def upsample_nearest3d_vec(input, output_size, scale_factors):
     scale_h = get_scale_value(scale_factors, 1)
     scale_w = get_scale_value(scale_factors, 2)
 
-    return upsample_nearest3d(input, osize, scale_d, scale_h, scale_w)
+    return aten.upsample_nearest3d.default(input, osize, scale_d, scale_h, scale_w)
+
+
+@register_decomposition(aten._upsample_nearest_exact3d.vec)
+@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
+@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
+def _upsample_nearest_exact3d_vec(input, output_size, scale_factors):
+    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
+    scale_d = get_scale_value(scale_factors, 0)
+    scale_h = get_scale_value(scale_factors, 1)
+    scale_w = get_scale_value(scale_factors, 2)
+
+    return aten._upsample_nearest_exact3d.default(
+        input, osize, scale_d, scale_h, scale_w
+    )
 
 
-def _compute_upsample_nearest_indices(input, output_size, scales):
+def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
     # For each dim in output_size, compute the set of input indices used
     # to produce the upsampled output.
     indices = []
     num_spatial_dims = len(output_size)
-    input_dtype = torch.float if input.dtype == torch.uint8 else input.dtype
+    offset = 0.5 if exact else 0.0
+
     for d in range(num_spatial_dims):
         # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
+        #
         # Indices are computed as following:
         # scale = isize / osize
+        # Case: exact=False
         # input_index = floor(output_index * scale)
         # Same as OpenCV INTER_NEAREST
+        #
+        # Case: exact=False
+        # index_f32 = (output_index + 0.5) * scale - 0.5
+        # input_index = round(index_f32)
+        # Same as Pillow and Scikit-Image/Scipy ndi.zoom
         osize = output_size[d]
-        output_indices = torch.arange(osize, dtype=input_dtype, device=input.device)
         isize = input.shape[-num_spatial_dims + d]
         scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize
-        input_indices = (output_indices * scale).to(torch.int64)
+
+        output_indices = torch.arange(osize, dtype=torch.float32, device=input.device)
+        input_indices = ((output_indices + offset) * scale).to(torch.int64)
         for _ in range(num_spatial_dims - 1 - d):
             input_indices = input_indices.unsqueeze(-1)
         indices.append(input_indices)
@@ -2450,18 +2494,21 @@ def upsample_nearest1d(
     return aten._unsafe_index(input, (None, None, l_indices))
 
 
-@register_decomposition(aten.upsample_nearest2d.default)
-@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
+@register_decomposition(aten._upsample_nearest_exact1d.default)
+@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
 @pw_cast_for_opmath
-def upsample_nearest2d(
+def _upsample_nearest_exact1d(
     input: Tensor,
     output_size: List[int],
-    scales_h: Optional[float] = None,
-    scales_w: Optional[float] = None,
+    scales: Optional[float] = None,
 ) -> Tensor:
-    h_indices, w_indices = _compute_upsample_nearest_indices(
-        input, output_size, (scales_h, scales_w)
+    (l_indices,) = _compute_upsample_nearest_indices(
+        input, output_size, (scales,), exact=True
     )
+    return aten._unsafe_index(input, (None, None, l_indices))
+
+
+def _upsample_nearest2d_common(input, h_indices, w_indices):
     result = aten._unsafe_index(input, (None, None, h_indices, w_indices))
 
     # convert output to correct memory format, if necessary
@@ -2473,10 +2520,39 @@ def upsample_nearest2d(
         memory_format = torch.contiguous_format
 
     result = result.contiguous(memory_format=memory_format)
-
     return result
 
 
+@register_decomposition(aten.upsample_nearest2d.default)
+@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
+@pw_cast_for_opmath
+def upsample_nearest2d(
+    input: Tensor,
+    output_size: List[int],
+    scales_h: Optional[float] = None,
+    scales_w: Optional[float] = None,
+) -> Tensor:
+    h_indices, w_indices = _compute_upsample_nearest_indices(
+        input, output_size, (scales_h, scales_w)
+    )
+    return _upsample_nearest2d_common(input, h_indices, w_indices)
+
+
+@register_decomposition(aten._upsample_nearest_exact2d.default)
+@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
+@pw_cast_for_opmath
+def _upsample_nearest_exact2d(
+    input: Tensor,
+    output_size: List[int],
+    scales_h: Optional[float] = None,
+    scales_w: Optional[float] = None,
+) -> Tensor:
+    h_indices, w_indices = _compute_upsample_nearest_indices(
+        input, output_size, (scales_h, scales_w), exact=True
+    )
+    return _upsample_nearest2d_common(input, h_indices, w_indices)
+
+
 @register_decomposition(aten.upsample_nearest3d.default)
 @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
 @pw_cast_for_opmath
@@ -2495,6 +2571,24 @@ def upsample_nearest3d(
     return result
 
 
+@register_decomposition(aten._upsample_nearest_exact3d.default)
+@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
+@pw_cast_for_opmath
+def _upsample_nearest_exact3d(
+    input: Tensor,
+    output_size: List[int],
+    scales_d: Optional[float] = None,
+    scales_h: Optional[float] = None,
+    scales_w: Optional[float] = None,
+) -> Tensor:
+    d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
+        input, output_size, (scales_d, scales_h, scales_w), exact=True
+    )
+    result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
+
+    return result
+
+
 def gather_params(params, has_biases, has_projections):
     if has_biases and has_projections:
         group_size = 5
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 2dac15949d2e..cf30ef7a7aac 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -102,6 +102,7 @@ def add_layout_constraint(fn, constraint):
         aten.max_pool2d_with_indices_backward,
         aten.mm,
         aten.upsample_nearest2d,
+        aten._upsample_nearest_exact2d,
         aten.upsample_bicubic2d,
         aten._int_mm,
     ]
@@ -3258,7 +3259,11 @@ def backend_reduce_str(reduce):
 
 
 def upsample_nearestnd(
-    x, output_size, scales_x: Tuple[Optional[float], ...], n: int = 2
+    x,
+    output_size,
+    scales_x: Tuple[Optional[float], ...],
+    n: int = 2,
+    exact: bool = False,
 ):
     x.realize_hint()  # elements are reused
     x_loader = x.make_loader()
@@ -3270,12 +3275,17 @@ def upsample_nearestnd(
     o_sizes = output_size
 
     scales = [i / o for i, o in zip(i_sizes, o_sizes)]
-    for i, scale in enumerate(scales):
+    for i, scale in enumerate(scales_x):
         if scale:
             scales[i] = scale
 
     def scale_fn(x, scale, size):
+        # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5)
+        #                            = floor(scale * (output_index + 0.5))
+        # Nearest: input_index = floor(scale * output_index)
         x = ops.index_expr(x, torch.float32)
+        if exact:
+            x = ops.add(x, ops.constant(0.5, torch.float32))
         x = ops.mul(x, ops.constant(scale, torch.float32))
         x = ops.to_dtype(x, torch.int32)
         return ops.indirect_indexing(x, size, check=False)
@@ -3300,6 +3310,11 @@ def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
     return upsample_nearestnd(x, output_size, (scales,), n=1)
 
 
+@register_lowering(aten._upsample_nearest_exact1d.default)
+def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None):
+    return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True)
+
+
 @register_lowering(aten.upsample_nearest2d.default)
 def upsample_nearest2d(
     x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
@@ -3307,6 +3322,13 @@ def upsample_nearest2d(
     return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2)
 
 
+@register_lowering(aten._upsample_nearest_exact2d.default)
+def _upsample_nearest_exact2d(
+    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
+):
+    return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True)
+
+
 @register_lowering(aten.upsample_nearest3d.default)
 def upsample_nearest3d(
     x,
@@ -3318,6 +3340,19 @@ def upsample_nearest3d(
     return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3)
 
 
+@register_lowering(aten._upsample_nearest_exact3d.default)
+def _upsample_nearest_exact3d(
+    x,
+    output_size,
+    scales_d: Optional[float] = None,
+    scales_h: Optional[float] = None,
+    scales_w: Optional[float] = None,
+):
+    return upsample_nearestnd(
+        x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True
+    )
+
+
 def _create_constants(*args, dtype):
     return tuple(ops.constant(a, dtype) for a in args)
 
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index d91df2b4f8c3..73879f3dc273 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5386,7 +5386,9 @@ def upsample_common_check(input_size, output_size, num_spatial_dims):
     return (nbatch, channels, *output_size)
 
 
-@register_meta(aten.upsample_nearest1d.default)
+@register_meta(
+    [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
+)
 def upsample_nearest1d(input, output_size, scales=None):
     torch._check(
         input.numel() != 0 or multiply_integers(input.size()[1:]),
@@ -5400,7 +5402,9 @@ def upsample_nearest1d(input, output_size, scales=None):
     )
 
 
-@register_meta(aten.upsample_nearest2d.default)
+@register_meta(
+    [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
+)
 def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
     torch._check(
         input.numel() != 0 or multiply_integers(input.size()[1:]),
@@ -5424,7 +5428,12 @@ def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
     return output
 
 
-@register_meta(aten.upsample_nearest2d_backward.default)
+@register_meta(
+    [
+        aten.upsample_nearest2d_backward.default,
+        aten._upsample_nearest_exact2d_backward.default,
+    ]
+)
 def upsample_nearest2d_backward(
     grad_output: Tensor,
     output_size: Sequence[Union[int, torch.SymInt]],
@@ -5454,7 +5463,9 @@ def upsample_nearest2d_backward(
     )  # type: ignore[call-overload]
 
 
-@register_meta(aten.upsample_nearest3d.default)
+@register_meta(
+    [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
+)
 def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
     torch._check(
         input.numel() != 0 or multiply_integers(input.size()[1:]),
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 33595ad51bb2..fde2a7b24801 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13517,11 +13517,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
                DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
                DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'),
-               # MissingOperatorWithoutDecomp: missing lowering
-               DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
-               # RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides
-               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
-               DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
                # NotImplementedError: The operator 'aten::_upsample_nearest_exact3d.out' is not currently implemented
                # for the MPS device.
                DecorateInfo(unittest.expectedFailure, 'TestConsistency'),

From 4b7f9fa436be8f21a1b59c3ab6295cda02340570 Mon Sep 17 00:00:00 2001
From: Isuru Fernando 
Date: Mon, 20 Nov 2023 18:50:36 +0000
Subject: [PATCH 048/221] Meta register all foreach ops (#112281)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112281
Approved by: https://github.com/lezcano
---
 torch/_meta_registrations.py                  | 330 +++++++++--------
 .../_internal/common_methods_invocations.py   | 331 +++---------------
 2 files changed, 226 insertions(+), 435 deletions(-)

diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 73879f3dc273..dc8948bf0a42 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1,5 +1,6 @@
 import math
 from enum import Enum
+from functools import partial
 from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -2959,37 +2960,166 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
     return self.new_empty(self.size())
 
 
-@register_meta(
+def register_meta_foreach(ops):
+    def wrapper(fn):
+        def register(op):
+            op_name = str(op).split(".")[1]
+            scalar_op = getattr(aten, op_name.replace("_foreach_", ""))
+
+            _add_op_to_registry(
+                meta_table,
+                op,
+                partial(
+                    fn,
+                    _scalar_op=scalar_op,
+                ),
+            )
+
+        pytree.tree_map_(register, ops)
+        return fn
+
+    return wrapper
+
+
+@register_meta_foreach(
     [
-        aten._foreach_abs_.default,
-        aten._foreach_neg_.default,
-        aten._foreach_reciprocal_.default,
-        aten._foreach_sqrt_.default,
-        aten._foreach_sign_.default,
-    ]
+        aten._foreach_abs,
+        aten._foreach_acos,
+        aten._foreach_asin,
+        aten._foreach_atan,
+        aten._foreach_ceil,
+        aten._foreach_cos,
+        aten._foreach_cosh,
+        aten._foreach_erf,
+        aten._foreach_erfc,
+        aten._foreach_exp,
+        aten._foreach_expm1,
+        aten._foreach_frac,
+        aten._foreach_floor,
+        aten._foreach_lgamma,
+        aten._foreach_log,
+        aten._foreach_log10,
+        aten._foreach_log1p,
+        aten._foreach_log2,
+        aten._foreach_neg,
+        aten._foreach_reciprocal,
+        aten._foreach_round,
+        aten._foreach_sigmoid,
+        aten._foreach_sign,
+        aten._foreach_sin,
+        aten._foreach_sinh,
+        aten._foreach_sqrt,
+        aten._foreach_tan,
+        aten._foreach_tanh,
+        aten._foreach_trunc,
+        aten._foreach_zero,
+        aten._foreach_add,
+        aten._foreach_sub,
+        aten._foreach_mul,
+        aten._foreach_div,
+        aten._foreach_clamp_min,
+        aten._foreach_clamp_max,
+        aten._foreach_lerp,
+    ],
 )
-def meta__foreach_unaop_(self):
+def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs):
     torch._check(
-        isinstance(self, List),
-        lambda: f"Expect List[Tensor] but got {type(self)}",
+        isinstance(args[0], list),
+        lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."),
     )
 
+    nelem = len(args[0])
+    torch._check(
+        nelem > 0,
+        lambda: ("Tensor list must have at least one tensor."),
+    )
 
-@register_meta(
+    nlists = 1
+    for iarg, arg in enumerate(args[1:]):
+        if isinstance(arg, list):
+            nlists += 1
+            torch._check(
+                len(arg) == nelem,
+                lambda: (
+                    f"self and argument-{iarg+2} must match in length, "
+                    f"but got {nelem} and {len(arg)}."
+                ),
+            )
+        elif isinstance(arg, Tensor):
+            torch._check(
+                arg.dim() == 0 and arg.numel() == 1,
+                lambda: (
+                    "scalar tensor expected to be 0 dim but it has "
+                    f"{arg.dim()} dimensions and {arg.numel()} elements."
+                ),
+            )
+        else:
+            break
+
+    result = []
+    for elem in range(nelem):
+        each_args = [args[i][elem] for i in range(nlists)]
+        result.append(_scalar_op(*each_args, *args[nlists:], **kwargs))
+
+    return result
+
+
+@register_meta_foreach(
     [
-        aten._foreach_abs.default,
-        aten._foreach_neg.default,
-        aten._foreach_reciprocal.default,
-        aten._foreach_sqrt.default,
-        aten._foreach_sign.default,
+        aten._foreach_abs_,
+        aten._foreach_acos_,
+        aten._foreach_asin_,
+        aten._foreach_atan_,
+        aten._foreach_ceil_,
+        aten._foreach_cos_,
+        aten._foreach_cosh_,
+        aten._foreach_erf_,
+        aten._foreach_erfc_,
+        aten._foreach_exp_,
+        aten._foreach_expm1_,
+        aten._foreach_frac_,
+        aten._foreach_floor_,
+        aten._foreach_lgamma_,
+        aten._foreach_log_,
+        aten._foreach_log10_,
+        aten._foreach_log1p_,
+        aten._foreach_log2_,
+        aten._foreach_neg_,
+        aten._foreach_reciprocal_,
+        aten._foreach_round_,
+        aten._foreach_sigmoid_,
+        aten._foreach_sign_,
+        aten._foreach_sin_,
+        aten._foreach_sinh_,
+        aten._foreach_sqrt_,
+        aten._foreach_tan_,
+        aten._foreach_tanh_,
+        aten._foreach_trunc_,
+        aten._foreach_zero_,
+        aten._foreach_add_,
+        aten._foreach_sub_,
+        aten._foreach_mul_,
+        aten._foreach_div_,
+        aten._foreach_clamp_min_,
+        aten._foreach_clamp_max_,
+        aten._foreach_lerp_,
+        aten._foreach_copy_,
     ]
 )
-def meta__foreach_unaop(self):
+def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
+    _meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs)
+    return
+
+
+@register_meta([aten._foreach_pow.ScalarAndTensor])
+def meta__foreach_pow_scalar_and_tensor(self, exponent):
+    # Only foreach_pow has a ScalarAndTensor method and needs special
+    # handling because it does not work with _meta_foreach_out_of_place.
     torch._check(
-        isinstance(self, List),
-        lambda: f"Expect List[Tensor] but got {type(self)}",
+        isinstance(exponent, List),
+        lambda: f"exponent must be a tensor list but got {type(exponent)}",
     )
-    return [torch.empty_like(s) for s in self]
+    return [torch.empty_like(e) for e in exponent]
 
 
 def _check_foreach_binop_tensor_lists(self, other):
@@ -3011,130 +3141,25 @@ def _check_foreach_binop_tensor_lists(self, other):
 
 @register_meta(
     [
-        aten._foreach_add.List,
-        aten._foreach_sub.List,
-        aten._foreach_mul.List,
-        aten._foreach_div.List,
-        aten._foreach_maximum.List,
-        aten._foreach_minimum.List,
-        aten._foreach_clamp_min.List,
-        aten._foreach_clamp_max.List,
-    ]
-)
-def meta__foreach_binop_list(self, other, alpha=1):
-    _check_foreach_binop_tensor_lists(self, other)
-    return [torch.empty_like(s) for s in self]
-
-
-@register_meta(
-    [
-        aten._foreach_add_.List,
-        aten._foreach_sub_.List,
-        aten._foreach_mul_.List,
-        aten._foreach_div_.List,
-        aten._foreach_maximum_.List,
-        aten._foreach_minimum_.List,
-        aten._foreach_clamp_min_.List,
-        aten._foreach_clamp_max_.List,
-    ]
-)
-def meta__foreach_binop__list(self, other, alpha=1):
-    _check_foreach_binop_tensor_lists(self, other)
-
-
-@register_meta(
-    [
-        aten._foreach_add.Tensor,
-    ]
-)
-def meta__foreach_binop_tensor(self, other, alpha=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument must be List[Tensor], but got {type(self)}.",
-    )
-    torch._check(
-        isinstance(other, torch.Tensor),
-        lambda: f"The second argument must be Tensor, but got {type(other)}.",
-    )
-    return [torch.empty_like(s) for s in self]
-
-
-@register_meta(
-    [
-        aten._foreach_add_.Tensor,
-    ]
-)
-def meta__foreach_binop__tensor(self, other, alpha=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument must be List[Tensor], but got {type(self)}.",
-    )
-    torch._check(
-        isinstance(other, torch.Tensor),
-        lambda: f"The second argument must be Tensor, but got {type(other)}.",
-    )
-
-
-@register_meta(
-    [
-        aten._foreach_add_.Scalar,
-        aten._foreach_mul_.Scalar,
-        aten._foreach_sub_.Scalar,
-        aten._foreach_div_.Scalar,
-        aten._foreach_maximum_.Scalar,
+        aten._foreach_maximum,
+        aten._foreach_minimum,
     ]
 )
-def meta__foreach_binop__scalar(self, scalar=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
-    )
+def meta__foreach_binop_scalar(*args):
+    # aten.maximum(Tensor, Scalar) does not exist.
+    return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min)
 
 
 @register_meta(
     [
-        aten._foreach_add.Scalar,
-        aten._foreach_div.Scalar,
-        aten._foreach_mul.Scalar,
-        aten._foreach_sub.Scalar,
-    ]
-)
-def meta__foreach_binop_scalar(self, scalar=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
-    )
-    return [torch.empty_like(s) for s in self]
-
-
-@register_meta(
-    [
-        aten._foreach_addcdiv_.Scalar,
-        aten._foreach_addcmul_.Scalar,
+        aten._foreach_maximum_,
+        aten._foreach_minimum_,
     ]
 )
-def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
-    torch._check(
-        all(isinstance(l, List) for l in [self, tensor1, tensor2]),
-        lambda: (
-            "All arguments of _foreach_addc*_ must be List[Tensor], "
-            f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
-        ),
-    )
-    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
-    torch._check(
-        len(self) == len(tensor1) and len(self) == len(tensor2),
-        lambda: "All input tensor lists must have the same length",
-    )
-
-
-@register_meta(
-    [
-        aten._foreach_lerp_.Scalar,
-    ]
-)
-def meta__foreach_lerp__scalar(self, other, scalar=1):
-    _check_foreach_binop_tensor_lists(self, other)
+def meta__foreach_binop__scalar(*args):
+    # aten.maximum(Tensor, Scalar) does not exist
+    _meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_)
+    return
 
 
 @register_meta(
@@ -3144,6 +3169,8 @@ def meta__foreach_lerp__scalar(self, other, scalar=1):
     ]
 )
 def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
+    # forach_addcdiv and addcdiv have different signatures and
+    # cannot use _meta_foreach_out_of_place.
     torch._check(
         all(isinstance(l, List) for l in [self, tensor1, tensor2]),
         lambda: (
@@ -3160,15 +3187,6 @@ def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
     return [torch.empty_like(s) for s in self]
 
 
-@register_meta([aten._foreach_pow.ScalarAndTensor])
-def meta__foreach_pow_scalar_and_tensor(self, exponent):
-    torch._check(
-        isinstance(exponent, List),
-        lambda: f"exponent must be a tensor list but got {type(exponent)}",
-    )
-    return [torch.empty_like(e) for e in exponent]
-
-
 @register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
 def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
     torch._check(
@@ -3186,9 +3204,25 @@ def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
     )
 
 
-@register_meta([aten._foreach_copy_])
-def meta__foreach_copy_inplace(self, src, non_blocking=False):
-    _check_foreach_binop_tensor_lists(self, src)
+@register_meta(
+    [
+        aten._foreach_addcdiv_.Scalar,
+        aten._foreach_addcmul_.Scalar,
+    ]
+)
+def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
+    torch._check(
+        all(isinstance(l, List) for l in [self, tensor1, tensor2]),
+        lambda: (
+            "All arguments of _foreach_addc*_ must be List[Tensor], "
+            f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
+        ),
+    )
+    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+    torch._check(
+        len(self) == len(tensor1) and len(self) == len(tensor2),
+        lambda: "All input tensor lists must have the same length",
+    )
 
 
 @register_meta([aten._fused_adam_.default])
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fde2a7b24801..2629c481c495 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8974,128 +8974,38 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         'exp',
         foreach_inputs_sample_func(1, False, False),
         backward_requires_result=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'acos',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'asin',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'atan',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'cos',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'cosh',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log10',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log2',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'tan',
@@ -9114,16 +9024,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
                 device_type='cuda'
             ),
         ),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'tanh',
@@ -9139,44 +9039,14 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
                 device_type='cuda'
             ),
         ),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'sin',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'sinh',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'neg',
@@ -9196,48 +9066,18 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'erf',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'erfc',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'expm1',
@@ -9245,80 +9085,30 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         backward_requires_result=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'floor',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log1p',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'round',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'frac',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'reciprocal',
@@ -9333,32 +9123,12 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half),
         backward_requires_result=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'trunc',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'abs',
@@ -9367,22 +9137,12 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
         supports_fwgrad_bwgrad=True,
         skips=(
+            DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestMeta",
+                         "test_dispatch_symbolic_meta_inplace", dtypes=complex_types()),
             DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestMeta",
                          "test_dispatch_meta_inplace", dtypes=complex_types()),
             DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestMeta",
                          "test_meta_inplace", dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
-                         dtypes=complex_types()),
         ),
     ),
     ForeachFuncInfo(
@@ -9391,13 +9151,9 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
         has_no_out_of_place=True,
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
         ),
     ),
@@ -9413,14 +9169,12 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
         dtypesIfCUDA=all_types_and(torch.bool, torch.float16),
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
         ),
     ),
 ]
@@ -9433,14 +9187,16 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
         supports_alpha_param=True,
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
+            # These tests fail with aten._local_scalar_dense not being implemented.
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
         ),
     ),
     ForeachFuncInfo(
@@ -9466,14 +9222,15 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
         sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
         ),
     ),
     ForeachFuncInfo(
@@ -9482,14 +9239,24 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
         sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            # fails with div_cpu is not implemented with ComplexHalf
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=(torch.float16,), device_type='cpu'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.float16,), device_type='cpu'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=(torch.float16,), device_type='cpu'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=(torch.float16,), device_type='cpu'),
         ),
     ),
     ForeachFuncInfo(
@@ -9660,16 +9427,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
 ]
 

From b5dd37f23efecdd27d1ff8862dcd59613c36211d Mon Sep 17 00:00:00 2001
From: Nikita Shulga 
Date: Tue, 21 Nov 2023 14:52:55 +0000
Subject: [PATCH 049/221] [MPS] Fix memory leak in copy_from_mps_ (#114197)

By always calling `[destBuffer release]` before leaving the scope in which it was allocated.
Leak was introduced by https://github.com/pytorch/pytorch/pull/84928
Add regression test.
Before the change:
```
% python ../test/test_mps.py -v -k test_copy_cast_no_leak --repeat 10
test_copy_cast_no_leak (__main__.TestMemoryLeak) ... FAIL

======================================================================
FAIL: test_copy_cast_no_leak (__main__.TestMemoryLeak)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/build/../test/test_mps.py", line 1064, in test_copy_cast_no_leak
    self.assertTrue(driver_before == driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
AssertionError: False is not true : Detected 65536 bytes leak of GPU memory

To execute this test, run the following from the base repo dir:
     python test/test_mps.py -k test_copy_cast_no_leak

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 1.102s

FAILED (failures=1)
```
After:
```
% python ../test/test_mps.py -k test_copy_cast_no_leak --repeat 10
.
----------------------------------------------------------------------
Ran 1 test in 0.819s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.002s

OK
...
```

Fixes https://github.com/pytorch/pytorch/issues/114096

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114197
Approved by: https://github.com/kit1980
---
 aten/src/ATen/native/mps/operations/Copy.mm |  2 +-
 test/test_mps.py                            | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index d10b577aa729..8b1dd402e4f3 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -140,8 +140,8 @@ static void copy_cast_mps(at::Tensor& dst,
 
       stream->copy_and_sync(
           tmpBuffer, destBuffer, size_to_copy, storage_byte_offset, destOffset, non_blocking, profile_id);
-      [destBuffer release];
     }
+    [destBuffer release];
   }
   if (!dst.is_same(dst_)) {
     dst_.copy_(dst, non_blocking);
diff --git a/test/test_mps.py b/test/test_mps.py
index 600c7bfb5ced..240977e4390d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1053,6 +1053,16 @@ def leak_gpu0():
         with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
             leak_gpu0()
 
+    def test_copy_cast_no_leak(self):
+        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
+        torch.mps.empty_cache()
+        driver_before = torch.mps.driver_allocated_memory()
+        a = a.to(device='cpu', dtype=torch.float32)
+        a = a.to(device='mps', dtype=torch.float16)
+        torch.mps.empty_cache()
+        driver_after = torch.mps.driver_allocated_memory()
+        self.assertTrue(driver_before == driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
+
 
 class TestPixelShuffle(TestCaseMPS):
     def test_pixel_shuffle_unshuffle(self):

From ef90508f7541d61ec96d24ad5b17b9d280caa38e Mon Sep 17 00:00:00 2001
From: Oguz Ulgen 
Date: Tue, 21 Nov 2023 00:03:30 -0800
Subject: [PATCH 050/221] [AOTI] Support ReinterpretView in abi mode (#114169)

https://github.com/pytorch/pytorch/pull/113967 added support for
ReinterpretView but it turnes out we codegen it differently in abi
compat mode. This PR adds support for abi compat mode as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114169
Approved by: https://github.com/aakhundov
---
 test/inductor/test_aot_inductor.py      | 21 +++++++++++++++++++++
 torch/_inductor/codegen/common.py       |  7 +++++--
 torch/_inductor/codegen/triton_utils.py |  2 +-
 torch/_inductor/codegen/wrapper.py      |  8 +++++++-
 4 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 14b9368df302..94d1bb2227f4 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1231,6 +1231,27 @@ def forward(self, x):
         ]
         self.check_model(Model(), (a,), constraints=constraints)
 
+    def test_triton_kernel_reinterpret_view(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        @triton.jit
+        def pass_kernel(x, y):
+            pass
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # AOT export does not allow for input mutation
+                x = x.clone()
+                pass_kernel[(1,)](x, torch.empty_like(x))
+                return x
+
+        example_inputs = (torch.randn(4, device=self.device),)
+        self.check_model(Model(), example_inputs)
+
     def test_shifted_constraint_ranges(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 951e3f9cd2ff..920dc318d715 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -48,7 +48,7 @@ def data_type_logger(msg):
         schedule_log.debug("Data type propagation: %s", msg)
 
 
-TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
+TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype", "check_alignment"])
 SizeArg = namedtuple("SizeArg", ["name", "expr"])
 
 DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"])
@@ -633,6 +633,7 @@ def python_argdefs(self):
                     inplaced.inner_name,
                     inplaced.other_names[-1],
                     V.graph.get_dtype(inplaced.other_names[-1]),
+                    True,
                 )
             )
         for outer, inner in chain(
@@ -642,7 +643,9 @@ def python_argdefs(self):
                 continue
             arg_defs.append(inner)
             call_args.append(outer)
-            precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))
+            precompile_args.append(
+                TensorArg(inner, outer, V.graph.get_dtype(outer), True)
+            )
         for outer, inner in self.sizevars.items():
             arg_defs.append(inner)
             call_args.append(outer)
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 7fb23d9a6f55..7e4596e47690 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -62,7 +62,7 @@ def is_aligned(
         https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
         """
         if isinstance(x, TensorArg):
-            if x.buffer.startswith("reinterpret_tensor"):
+            if not x.check_alignment:
                 return False
             if include_tensor:
                 return not V.graph.scheduler.is_unaligned_buffer(x.buffer)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index b66381a6f1f8..6a8000fe4d16 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -882,7 +882,13 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
                 continue
             if isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
                 signature.append(
-                    TensorArg(key, arg.codegen_reference(), arg.get_dtype())
+                    TensorArg(
+                        key,
+                        arg.codegen_reference(),
+                        arg.get_dtype(),
+                        # For ReinterpretView, we do not want to check alignment
+                        not isinstance(arg, ReinterpretView),
+                    )
                 )
             else:
                 signature.append(SizeArg(key, arg))

From 7694b0541690638f632931e957973022ffcad6f3 Mon Sep 17 00:00:00 2001
From: Andrew Gu 
Date: Mon, 20 Nov 2023 15:20:17 -0800
Subject: [PATCH 051/221] [DTensor] Reduced to one `isinstance` call in
 `is_shard` (#114140)

This is a nit change to save one `isinstance` call for when `dim` is not `None` but the placement is not `Shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114140
Approved by: https://github.com/Skylion007, https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915
---
 torch/distributed/_tensor/placement_types.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py
index 7c1c61c21bf2..a8a79298c566 100644
--- a/torch/distributed/_tensor/placement_types.py
+++ b/torch/distributed/_tensor/placement_types.py
@@ -16,10 +16,11 @@ class Placement:
 
     # convenient utils to check for placement types
     def is_shard(self, dim: Optional[int] = None) -> bool:
-        if dim is not None and isinstance(self, Shard):
-            return self.dim == dim
+        is_shard_instance = isinstance(self, Shard)
+        if dim is not None and is_shard_instance:
+            return cast(Shard, self).dim == dim
         else:
-            return isinstance(self, Shard)
+            return is_shard_instance
 
     def is_replicate(self) -> bool:
         return isinstance(self, Replicate)

From f66add9b854cf021a0f7fcbb200c9480e43d9372 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Tue, 21 Nov 2023 18:19:33 +0000
Subject: [PATCH 052/221] [dynamo] graph break on `np.ndarray.tobytes`
 (#114208)

We can't model this accurately across np and tnp https://github.com/pytorch/pytorch/issues/114204#issuecomment-1820269949

So let's not even try. Just graph break.

Fixes: https://github.com/pytorch/pytorch/issues/114204

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114208
Approved by: https://github.com/lezcano
---
 test/dynamo/test_repros.py        | 13 +++++++++++++
 torch/_dynamo/variables/tensor.py |  2 ++
 2 files changed, 15 insertions(+)

diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 34680624bac4..97ba0a8af258 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -3638,6 +3638,19 @@ def fn(x):
         fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
         self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1)))
 
+    def test_numpy_tobytes_no_error(self):
+        def fn(x):
+            x += 1
+            z = x.tobytes()
+            x += 1
+            return z
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnt)(fn)
+        opt_arg, arg = np.array([1, 2]), np.array([1, 2])
+        self.assertEqual(opt_fn(opt_arg), fn(arg))
+        self.assertEqual(cnt.frame_count, 2)
+
     def test_numpy_not_ndarray_recompiles(self):
         import torch
 
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index a87eec6cfbbf..6fbe6c2afcc8 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -924,6 +924,8 @@ def call_method(
         if name in ["__len__", "size", "tolist"]:
             # delegate back to TensorVariable
             return super().call_method(tx, name, args, kwargs)
+        if name == "tobytes":
+            unimplemented("tobytes is not modelled in torch._numpy")
         proxy = tx.output.create_proxy(
             "call_function",
             numpy_method_wrapper(name),

From 1a3dbf57ca946d633b668cea1e7bf36f8367ccdf Mon Sep 17 00:00:00 2001
From: kshitij12345 
Date: Tue, 21 Nov 2023 18:55:51 +0000
Subject: [PATCH 053/221] vmap: simple inplace batch rule (#113513)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113513
Approved by: https://github.com/zou3519
---
 aten/src/ATen/functorch/BatchRulesBinaryOps.cpp      | 5 +++++
 aten/src/ATen/functorch/BatchRulesDecompositions.cpp | 1 +
 test/functorch/test_vmap.py                          | 5 -----
 test/functorch/test_vmap_registrations.py            | 1 -
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
index 6bbe7fdcc1ec..1dd417052cf1 100644
--- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
@@ -461,16 +461,21 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
   using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
   using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
+  using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, c10::optional) const;
   using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const;
   using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const;
 
   POINTWISE_BOXED(add_.Tensor); // just testing
+  POINTWISE_BOXED(atan2_);
+  POINTWISE_BOXED(gcd_);
+  POINTWISE_BOXED(lcm_);
   VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule));
   VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule));
   VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule));
   VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule));
   VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule));
   VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule));
+  VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule>));
   VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule));
   VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule));
   VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule));
diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index 8a25a350f950..1b179a505e9a 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -46,6 +46,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(absolute);
   OP_DECOMPOSE(absolute_);
   OP_DECOMPOSE(arctan2);
+  OP_DECOMPOSE(arctan2_);
   OP_DECOMPOSE(argsort);
   OP_DECOMPOSE(avg_pool1d);
   OP_DECOMPOSE(adaptive_max_pool1d);
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index b0c21421b8b4..fc28228045f1 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3762,23 +3762,18 @@ def test_op_has_batch_rule(self, device, dtype, op):
             'addmm',
             'addmv',
             'addr',
-            'atan2',
             'baddbmm',
             'clamp',
             'conj_physical',
             'cumprod',
             'cumsum',
-            'div',
-            'div',
             'floor_divide',
             'fmod',
-            'gcd',
             'heaviside',
             'hypot',
             'igamma',
             'igammac',
             'index_copy',
-            'lcm',
             'ldexp',
             'lerp',
             'neg',
diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py
index b25732dfcf5e..4952f2745b6d 100644
--- a/test/functorch/test_vmap_registrations.py
+++ b/test/functorch/test_vmap_registrations.py
@@ -38,7 +38,6 @@
     "aten::align_to.ellipsis_idx",
     "aten::alpha_dropout",
     "aten::alpha_dropout_",
-    "aten::arctan2_",
     "aten::argwhere",
     "aten::bilinear",
     "aten::can_cast",

From 85b97605ab6f47efe8ce0675a6e2ee87383de28b Mon Sep 17 00:00:00 2001
From: Ying Liu 
Date: Tue, 21 Nov 2023 19:47:24 +0000
Subject: [PATCH 054/221] Enable set sequence nr (#114120)

Summary:
In some cases (especially those involving collective calls) - we would want to always kick off a collective call first before running going down another path.

For  example:

```
tbe lookup -> a2a ->
                     overarch
dense ------------->
```

if the forward code is written as
a2a_out = a2a
dense = dense_net
out = overarch(a2a_out, dense)
out.backward()

The current default is running backwards in the opposite order the forward is called. However, there is no data dependency between a2a and dense, so in reality either of them could be run first. We would like the a2a to run first because it provides optimal (on average) overlap.

Changing the seq_nr of a2a_out to something large enough would allow autograd engine to kick it off first.

Test Plan: Tests incoming

Differential Revision: D51445261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114120
Approved by: https://github.com/ezyang, https://github.com/albanD
---
 test/test_autograd.py                       | 41 +++++++++++++++++++++
 torch/csrc/autograd/function.h              |  6 ++-
 torch/csrc/autograd/python_cpp_function.cpp | 11 ++++++
 torch/csrc/autograd/python_cpp_function.h   |  9 ++++-
 torch/csrc/autograd/python_function.cpp     |  9 +++++
 5 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index e7dddeba57ae..f91324bab6bd 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -11820,6 +11820,47 @@ def backward(ctx, gO):
         TestFn.apply(inp, None).sum().backward()
         self.assertEqual(local.my_obj[10], 5)
 
+    def test_set_sequence_nr(self):
+        x = torch.randn((10,), dtype=torch.float32, requires_grad=True)
+        y = torch.randn((10,), dtype=torch.float32, requires_grad=True)
+        z = torch.randn((10,), dtype=torch.float32, requires_grad=True)
+
+        a = x + y
+        b = y + z
+        c = a + b
+
+        self.assertIsNotNone(a.grad_fn)
+        self.assertIsNotNone(b.grad_fn)
+        self.assertIsNotNone(c.grad_fn)
+
+        a.grad_fn._set_sequence_nr(100)
+        b.grad_fn._set_sequence_nr(99)
+        c.grad_fn._set_sequence_nr(98)
+
+        self.assertEqual(a.grad_fn._sequence_nr(), 100)
+        self.assertEqual(b.grad_fn._sequence_nr(), 99)
+        self.assertEqual(c.grad_fn._sequence_nr(), 98)
+
+        def log_grad_order(grad: torch.Tensor, name: str, order):
+            order.append(name)
+            return grad
+
+        order = []
+        a.register_hook(partial(log_grad_order, name="a", order=order))
+        b.register_hook(partial(log_grad_order, name="b", order=order))
+        c.register_hook(partial(log_grad_order, name="c", order=order))
+
+        c.sum().backward()
+
+        # Expect to see that even though c has the smallest sequence number, it is still the first node to get run in autograd.
+        # Also check that although a comes first during the forward, after giving it priority with sequence_nr,
+        # its autograd node is run before that of b.
+        self.assertEqual(order, ['c', 'a', 'b'])
+
+        self.assertEqual(x.grad, torch.ones_like(x))
+        self.assertEqual(y.grad, 2 * torch.ones_like(x))
+        self.assertEqual(z.grad, torch.ones_like(x))
+
 
 # Import test cases from below autograd/ here. These are found
 # implicitly by the loader, so Flake8 thinks they are unused, hence
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 337b49f469fc..af6f7a77b695 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -328,6 +328,10 @@ struct TORCH_API Node : std::enable_shared_from_this {
     return sequence_nr_;
   }
 
+  void set_sequence_nr(uint64_t sequence_nr) {
+    sequence_nr_ = sequence_nr;
+  }
+
   // NOTE [ Topological Number ]
   //
   // topological_nr is used to prune branches in the DAG during autograd
@@ -590,7 +594,7 @@ struct TORCH_API Node : std::enable_shared_from_this {
   // Sequence number used to correlate backward nodes with forward ops in the
   // profiler and provide determinism in the engine.
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-  const uint64_t sequence_nr_;
+  uint64_t sequence_nr_;
 
   // See NOTE [ Topological Number ]
   uint64_t topological_nr_ = 0;
diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp
index 1f957b172301..66a5a152ec78 100644
--- a/torch/csrc/autograd/python_cpp_function.cpp
+++ b/torch/csrc/autograd/python_cpp_function.cpp
@@ -201,6 +201,17 @@ PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
   auto& fn = *((THPCppFunction*)self)->cdata;
   return THPUtils_packUInt64(fn.sequence_nr());
 }
+
+PyObject* THPCppFunction_set_sequence_nr(
+    PyObject* self,
+    PyObject* sequence_nr) {
+  HANDLE_TH_ERRORS
+  auto& fn = *((THPCppFunction*)self)->cdata;
+  fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
 static struct PyMethodDef default_methods[] = {
     THP_FUNCTION_DEFAULT_METHODS,
diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h
index 1cd91c9dca50..c1f5219203ff 100644
--- a/torch/csrc/autograd/python_cpp_function.h
+++ b/torch/csrc/autograd/python_cpp_function.h
@@ -43,8 +43,13 @@ PyObject* CppFunction_pynew(
        THPCppFunction_register_prehook,                                        \
        METH_O,                                                                 \
        nullptr},                                                               \
-      {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, {            \
-    (char*)"_sequence_nr", THPCppFunction_sequence_nr, METH_NOARGS, nullptr    \
+      {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr},              \
+      {(char*)"_sequence_nr",                                                  \
+       THPCppFunction_sequence_nr,                                             \
+       METH_NOARGS,                                                            \
+       nullptr},                                                               \
+  {                                                                            \
+    (char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
   }
 
 #define THP_FUNCTION_DEFAULT_PROPERTIES                                   \
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 93646684a8cd..fa286998dfd6 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -964,6 +964,14 @@ PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
+  HANDLE_TH_ERRORS;
+  auto cdata = ((THPFunction*)self)->cdata.lock();
+  cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject* THPFunction_maybe_clear_saved_tensors(
     PyObject* self,
     PyObject* noargs) {
@@ -1532,6 +1540,7 @@ static struct PyGetSetDef THPFunction_properties[] = {
 static struct PyMethodDef THPFunction_methods[] = {
     {(char*)"name", THPFunction_name, METH_NOARGS, nullptr},
     {(char*)"_sequence_nr", THPFunction_sequence_nr, METH_NOARGS, nullptr},
+    {(char*)"_set_sequence_nr", THPFunction_set_sequence_nr, METH_O, nullptr},
     {(char*)"maybe_clear_saved_tensors",
      THPFunction_maybe_clear_saved_tensors,
      METH_NOARGS,

From 4e4a6ad6ecd71a1aefde3992ecf7f77e37d2e264 Mon Sep 17 00:00:00 2001
From: Xuehai Pan 
Date: Wed, 22 Nov 2023 00:16:41 +0800
Subject: [PATCH 055/221] [pytree] register pytree node type in both C++ pytree
 and Python pytree (#112111)

Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
---
 test/export/test_export.py                    |  15 +-
 test/test_fx.py                               |   2 +-
 test/test_pytree.py                           |  64 +++++-
 torch/_export/utils.py                        |   4 +-
 torch/_functorch/aot_autograd.py              |  15 +-
 torch/fx/experimental/proxy_tensor.py         |   2 +-
 torch/fx/immutable_collections.py             |   6 +-
 .../_internal/fx/dynamo_graph_extractor.py    |  13 +-
 torch/return_types.py                         |   2 +-
 torch/utils/_cxx_pytree.py                    | 201 +++++++++++++++++-
 torch/utils/_pytree.py                        | 136 +++++++++---
 11 files changed, 395 insertions(+), 65 deletions(-)

diff --git a/test/export/test_export.py b/test/export/test_export.py
index 221ea9ba075b..b1ae7c6b5b8e 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -623,16 +623,23 @@ class MyDataClass:
         roundtrip_spec = treespec_loads(treespec_dumps(spec))
         self.assertEqual(roundtrip_spec, spec)
 
+        @dataclass
+        class MyOtherDataClass:  # the pytree registration don't allow registering the same class twice
+            x: int
+            y: int
+            z: int = None
+
         # Override the registration with keep none fields
-        register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass")
+        register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass")
 
+        dt = MyOtherDataClass(x=3, y=4)
         flat, spec = tree_flatten(dt)
         self.assertEqual(
             spec,
             TreeSpec(
-                MyDataClass,
+                MyOtherDataClass,
                 (
-                    MyDataClass,
+                    MyOtherDataClass,
                     ['x', 'y', 'z'],
                     [],
                 ),
@@ -642,7 +649,7 @@ class MyDataClass:
         self.assertEqual(flat, [3, 4, None])
 
         orig_dt = tree_unflatten(flat, spec)
-        self.assertTrue(isinstance(orig_dt, MyDataClass))
+        self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
         self.assertEqual(orig_dt.x, 3)
         self.assertEqual(orig_dt.y, 4)
         self.assertEqual(orig_dt.z, None)
diff --git a/test/test_fx.py b/test/test_fx.py
index 30c5f838f127..9a9f046e1b0b 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -3529,7 +3529,7 @@ def f_dict_add(x):
         def f_namedtuple_add(x):
             return x.x + x.y
 
-        pytree._register_pytree_node(
+        pytree.register_pytree_node(
             Foo,
             lambda x: ([x.a, x.b], None),
             lambda x, _: Foo(x[0], x[1]),
diff --git a/test/test_pytree.py b/test/test_pytree.py
index 0c0120397eea..d943db41fe7e 100644
--- a/test/test_pytree.py
+++ b/test/test_pytree.py
@@ -1,7 +1,7 @@
 # Owner(s): ["module: pytree"]
 
 import unittest
-from collections import namedtuple, OrderedDict
+from collections import namedtuple, OrderedDict, UserDict
 
 import torch
 import torch.utils._cxx_pytree as cxx_pytree
@@ -26,6 +26,45 @@ def __init__(self, x, y):
 
 
 class TestGenericPytree(TestCase):
+    @parametrize(
+        "pytree_impl",
+        [
+            subtest(py_pytree, name="py"),
+            subtest(cxx_pytree, name="cxx"),
+        ],
+    )
+    def test_register_pytree_node(self, pytree_impl):
+        class MyDict(UserDict):
+            pass
+
+        d = MyDict(a=1, b=2, c=3)
+
+        # Custom types are leaf nodes by default
+        values, spec = pytree_impl.tree_flatten(d)
+        self.assertEqual(values, [d])
+        self.assertIs(values[0], d)
+        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
+        self.assertTrue(spec.is_leaf())
+
+        # Register MyDict as a pytree node
+        pytree_impl.register_pytree_node(
+            MyDict,
+            lambda d: (list(d.values()), list(d.keys())),
+            lambda values, keys: MyDict(zip(keys, values)),
+        )
+
+        values, spec = pytree_impl.tree_flatten(d)
+        self.assertEqual(values, [1, 2, 3])
+        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
+
+        # Do not allow registering the same type twice
+        with self.assertRaisesRegex(ValueError, "already registered"):
+            pytree_impl.register_pytree_node(
+                MyDict,
+                lambda d: (list(d.values()), list(d.keys())),
+                lambda values, keys: MyDict(zip(keys, values)),
+            )
+
     @parametrize(
         "pytree_impl",
         [
@@ -407,6 +446,21 @@ def test_pytree_serialize_bad_input(self, pytree_impl):
 
 
 class TestPythonPytree(TestCase):
+    def test_deprecated_register_pytree_node(self):
+        class DummyType:
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+        with self.assertWarnsRegex(
+            UserWarning, "torch.utils._pytree._register_pytree_node"
+        ):
+            py_pytree._register_pytree_node(
+                DummyType,
+                lambda dummy: ([dummy.x, dummy.y], None),
+                lambda xs, _: DummyType(*xs),
+            )
+
     def test_treespec_equality(self):
         self.assertTrue(
             py_pytree.LeafSpec() == py_pytree.LeafSpec(),
@@ -540,7 +594,7 @@ def __init__(self, x, y):
                 self.x = x
                 self.y = y
 
-        py_pytree._register_pytree_node(
+        py_pytree.register_pytree_node(
             DummyType,
             lambda dummy: ([dummy.x, dummy.y], None),
             lambda xs, _: DummyType(*xs),
@@ -560,7 +614,7 @@ def __init__(self, x, y):
                 self.x = x
                 self.y = y
 
-        py_pytree._register_pytree_node(
+        py_pytree.register_pytree_node(
             DummyType,
             lambda dummy: ([dummy.x, dummy.y], None),
             lambda xs, _: DummyType(*xs),
@@ -585,7 +639,7 @@ def __init__(self, x, y):
         with self.assertRaisesRegex(
             ValueError, "Both to_dumpable_context and from_dumpable_context"
         ):
-            py_pytree._register_pytree_node(
+            py_pytree.register_pytree_node(
                 DummyType,
                 lambda dummy: ([dummy.x, dummy.y], None),
                 lambda xs, _: DummyType(*xs),
@@ -599,7 +653,7 @@ def __init__(self, x, y):
                 self.x = x
                 self.y = y
 
-        py_pytree._register_pytree_node(
+        py_pytree.register_pytree_node(
             DummyType,
             lambda dummy: ([dummy.x, dummy.y], None),
             lambda xs, _: DummyType(*xs),
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index afee8efc5946..d8344783a0a3 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node(
     flatten_fn: Optional[FlattenFunc] = None,
     unflatten_fn: Optional[UnflattenFunc] = None,
     *,
+    serialized_type_name: Optional[str] = None,
     to_dumpable_context: Optional[ToDumpableContextFn] = None,
     from_dumpable_context: Optional[FromDumpableContextFn] = None,
-    serialized_type_name: Optional[str] = None,
     return_none_fields: bool = False,
 ) -> None:
     assert dataclasses.is_dataclass(
         cls
     ), f"Only dataclasses can be registered with this function: {cls}"
 
-    serialized_type = f"{cls.__module__}.{cls.__name__}"
+    serialized_type = f"{cls.__module__}.{cls.__qualname__}"
     SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
 
     def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 0057902bbe83..4c29f1a85002 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -29,7 +29,7 @@
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch._subclasses.fake_tensor import is_fake
 from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
-from torch.fx import immutable_collections, Interpreter
+from torch.fx import Interpreter
 from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
 from torch.fx.experimental.symbolic_shapes import (
     ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
@@ -95,19 +95,6 @@ def strict_zip(*iterables, strict=True, **kwargs):
     )
 )
 
-pytree._register_pytree_node(
-    immutable_collections.immutable_list,
-    lambda x: (list(x), None),
-    lambda x, c: immutable_collections.immutable_list(x),
-)
-pytree._register_pytree_node(
-    immutable_collections.immutable_dict,
-    lambda x: (list(x.values()), list(x.keys())),
-    lambda x, c: immutable_collections.immutable_dict(
-        dict(zip(c, x))
-    ),
-)
-
 def partial_asdict(obj: Any) -> Any:
     if dataclasses.is_dataclass(obj):
         return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index dd3520f541aa..e3d8bd673a4d 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -49,7 +49,7 @@
 
 # We currently convert all SymInt to proxies before we use them.
 # This could plausibly be handled at the Dynamo level.
-pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
+pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
 
 def fake_signature(fn, nargs):
     """FX gets confused by varargs, de-confuse it"""
diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py
index 616555015f0e..a359335f6ece 100644
--- a/torch/fx/immutable_collections.py
+++ b/torch/fx/immutable_collections.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Iterable, List, Tuple
 
 from ._compatibility import compatibility
-from torch.utils._pytree import Context, _register_pytree_node
+from torch.utils._pytree import Context, register_pytree_node
 
 __all__ = ["immutable_list", "immutable_dict"]
 
@@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A
     return immutable_list(values)
 
 
-_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
-_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
+register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
+register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py
index f55afefd1bbd..79a690f5f48a 100644
--- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py
+++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py
@@ -40,7 +40,11 @@ def __init__(self):
 
     def __enter__(self):
         for class_type, (flatten_func, unflatten_func) in self._extensions.items():
-            pytree._register_pytree_node(class_type, flatten_func, unflatten_func)
+            pytree._private_register_pytree_node(
+                class_type,
+                flatten_func,
+                unflatten_func,
+            )
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -93,8 +97,11 @@ def model_output_unflatten(
         # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
         named_model_output_classes = inspect.getmembers(
             modeling_outputs,
-            lambda x: inspect.isclass(x)
-            and issubclass(x, modeling_outputs.ModelOutput),
+            lambda x: (
+                inspect.isclass(x)
+                and issubclass(x, modeling_outputs.ModelOutput)
+                and x is not modeling_outputs.ModelOutput
+            ),
         )
 
         for _, class_type in named_model_output_classes:
diff --git a/torch/return_types.py b/torch/return_types.py
index 9f8c85285279..b1284c813387 100644
--- a/torch/return_types.py
+++ b/torch/return_types.py
@@ -13,7 +13,7 @@ def structseq_flatten(structseq):
     def structseq_unflatten(values, context):
         return cls(values)
 
-    torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten)
+    torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten)
 
 for name in dir(return_types):
     if name.startswith('__'):
diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py
index 392c0e2688db..ab82367fccbe 100644
--- a/torch/utils/_cxx_pytree.py
+++ b/torch/utils/_cxx_pytree.py
@@ -13,6 +13,7 @@
 """
 
 import functools
+import warnings
 from typing import (
     Any,
     Callable,
@@ -26,6 +27,11 @@
     Union,
 )
 
+import torch
+
+if torch._running_with_deploy():
+    raise ImportError("C++ pytree utilities do not work with torch::deploy.")
+
 import optree
 from optree import PyTreeSpec  # direct import for type annotations
 
@@ -35,6 +41,9 @@
     "Context",
     "FlattenFunc",
     "UnflattenFunc",
+    "DumpableContext",
+    "ToDumpableContextFn",
+    "FromDumpableContextFn",
     "TreeSpec",
     "LeafSpec",
     "register_pytree_node",
@@ -68,6 +77,9 @@
 FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
 UnflattenFunc = Callable[[Iterable, Context], PyTree]
 OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
+DumpableContext = Any  # Any json dumpable text
+ToDumpableContextFn = Callable[[Context], DumpableContext]
+FromDumpableContextFn = Callable[[DumpableContext], Context]
 
 
 def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@@ -84,9 +96,11 @@ def register_pytree_node(
     unflatten_fn: UnflattenFunc,
     *,
     serialized_type_name: Optional[str] = None,
+    to_dumpable_context: Optional[ToDumpableContextFn] = None,
+    from_dumpable_context: Optional[FromDumpableContextFn] = None,
     namespace: str = "torch",
 ) -> None:
-    """Extend the set of types that are considered internal nodes in pytrees.
+    """Register a container-like type as pytree node.
 
     The ``namespace`` argument is used to avoid collisions that occur when different libraries
     register the same Python type with different behaviors. It is recommended to add a unique prefix
@@ -109,6 +123,13 @@ def register_pytree_node(
             The function should return an instance of ``cls``.
         serialized_type_name (str, optional): A keyword argument used to specify the fully
             qualified name used when serializing the tree spec.
+        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
+            to convert the context of the pytree to a custom json dumpable representation. This is
+            used for json serialization, which is being used in :mod:`torch.export` right now.
+        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
+            how to convert the custom json dumpable representation of the context back to the
+            original context. This is used for json deserialization, which is being used in
+            :mod:`torch.export` right now.
         namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
             type registry. This is used to isolate the registry from other modules that might
             register a different custom behavior for the same type. (default: :const:`"torch"`)
@@ -193,24 +214,192 @@ def register_pytree_node(
             )
         )
     """
-    from ._pytree import _register_pytree_node
+    _private_register_pytree_node(
+        cls,
+        flatten_fn,
+        unflatten_fn,
+        serialized_type_name=serialized_type_name,
+        to_dumpable_context=to_dumpable_context,
+        from_dumpable_context=from_dumpable_context,
+        namespace=namespace,
+    )
+
+    from . import _pytree as python
 
-    _register_pytree_node(
+    python._private_register_pytree_node(
         cls,
         flatten_fn,
         unflatten_fn,
         serialized_type_name=serialized_type_name,
+        to_dumpable_context=to_dumpable_context,
+        from_dumpable_context=from_dumpable_context,
     )
 
-    optree.register_pytree_node(
+
+def _register_pytree_node(
+    cls: Type[Any],
+    flatten_fn: FlattenFunc,
+    unflatten_fn: UnflattenFunc,
+    *,
+    serialized_type_name: Optional[str] = None,
+    to_dumpable_context: Optional[ToDumpableContextFn] = None,
+    from_dumpable_context: Optional[FromDumpableContextFn] = None,
+    namespace: str = "torch",
+) -> None:
+    """Register a container-like type as pytree node for the C++ pytree only.
+
+    The ``namespace`` argument is used to avoid collisions that occur when different libraries
+    register the same Python type with different behaviors. It is recommended to add a unique prefix
+    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
+    the same class in different namespaces for different use cases.
+
+    .. warning::
+        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
+        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
+        prevent accidental collisions between different libraries that may register the same type.
+
+    Args:
+        cls (type): A Python type to treat as an internal pytree node.
+        flatten_fn (callable): A function to be used during flattening, taking an instance of
+            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
+            recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
+            passed to the ``unflatten_fn``.
+        unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
+            returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
+            The function should return an instance of ``cls``.
+        serialized_type_name (str, optional): A keyword argument used to specify the fully
+            qualified name used when serializing the tree spec.
+        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
+            to convert the context of the pytree to a custom json dumpable representation. This is
+            used for json serialization, which is being used in :mod:`torch.export` right now.
+        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
+            how to convert the custom json dumpable representation of the context back to the
+            original context. This is used for json deserialization, which is being used in
+            :mod:`torch.export` right now.
+        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
+            type registry. This is used to isolate the registry from other modules that might
+            register a different custom behavior for the same type. (default: :const:`"torch"`)
+
+    Example::
+
+        >>> # xdoctest: +SKIP
+        >>> # Registry a Python type with lambda functions
+        >>> register_pytree_node(
+        ...     set,
+        ...     lambda s: (sorted(s), None, None),
+        ...     lambda children, _: set(children),
+        ...     namespace='set',
+        ... )
+
+        >>> # xdoctest: +SKIP
+        >>> # Register a Python type into a namespace
+        >>> import torch
+        >>> register_pytree_node(
+        ...     torch.Tensor,
+        ...     flatten_func=lambda tensor: (
+        ...         (tensor.cpu().detach().numpy(),),
+        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
+        ...     ),
+        ...     unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
+        ...     namespace='torch2numpy',
+        ... )
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
+        >>> tree
+        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> # Flatten without specifying the namespace
+        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes  # xdoctest: +SKIP
+        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
+
+        >>> # xdoctest: +SKIP
+        >>> # Flatten with the namespace
+        >>> tree_flatten(tree, namespace='torch2numpy')  # xdoctest: +SKIP
+        (
+            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
+            PyTreeSpec(
+                {
+                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
+                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
+                },
+                namespace='torch2numpy'
+            )
+        )
+
+        >>> # xdoctest: +SKIP
+        >>> # Register the same type with a different namespace for different behaviors
+        >>> def tensor2flatparam(tensor):
+        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
+        ...
+        >>> def flatparam2tensor(children, metadata):
+        ...     return children[0].reshape(metadata)
+        ...
+        >>> register_pytree_node(
+        ...     torch.Tensor,
+        ...     flatten_func=tensor2flatparam,
+        ...     unflatten_func=flatparam2tensor,
+        ...     namespace='tensor2flatparam',
+        ... )
+
+        >>> # xdoctest: +SKIP
+        >>> # Flatten with the new namespace
+        >>> tree_flatten(tree, namespace='tensor2flatparam')  # xdoctest: +SKIP
+        (
+            [
+                Parameter containing: tensor([0., 0.], requires_grad=True),
+                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
+            ],
+            PyTreeSpec(
+                {
+                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
+                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
+                },
+                namespace='tensor2flatparam'
+            )
+        )
+    """
+    warnings.warn(
+        "torch.utils._cxx_pytree._register_pytree_node is deprecated. "
+        "Please use torch.utils._cxx_pytree.register_pytree_node instead.",
+        stacklevel=2,
+    )
+
+    _private_register_pytree_node(
         cls,
         flatten_fn,
-        _reverse_args(unflatten_fn),
+        unflatten_fn,
+        serialized_type_name=serialized_type_name,
+        to_dumpable_context=to_dumpable_context,
+        from_dumpable_context=from_dumpable_context,
         namespace=namespace,
     )
 
 
-_register_pytree_node = register_pytree_node
+def _private_register_pytree_node(
+    cls: Type[Any],
+    flatten_fn: FlattenFunc,
+    unflatten_fn: UnflattenFunc,
+    *,
+    serialized_type_name: Optional[str] = None,
+    to_dumpable_context: Optional[ToDumpableContextFn] = None,
+    from_dumpable_context: Optional[FromDumpableContextFn] = None,
+    namespace: str = "torch",
+) -> None:
+    """This is an internal function that is used to register a pytree node type
+    for the C++ pytree only. End-users should use :func:`register_pytree_node`
+    instead.
+    """
+    # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
+    # PyStructSequence types
+    if not optree.is_structseq_class(cls):
+        optree.register_pytree_node(
+            cls,
+            flatten_fn,
+            _reverse_args(unflatten_fn),
+            namespace=namespace,
+        )
 
 
 def tree_flatten(
diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py
index f74d4a76e5b8..5faa6c7c16ad 100644
--- a/torch/utils/_pytree.py
+++ b/torch/utils/_pytree.py
@@ -17,6 +17,7 @@
 
 import dataclasses
 import json
+import threading
 import warnings
 from collections import deque, namedtuple, OrderedDict
 from typing import (
@@ -99,6 +100,7 @@ class NodeDef(NamedTuple):
     unflatten_fn: UnflattenFunc
 
 
+_NODE_REGISTRY_LOCK = threading.Lock()
 SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
 
 
@@ -120,6 +122,59 @@ class _SerializeNodeDef(NamedTuple):
 SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
 
 
+def register_pytree_node(
+    cls: Any,
+    flatten_fn: FlattenFunc,
+    unflatten_fn: UnflattenFunc,
+    *,
+    serialized_type_name: Optional[str] = None,
+    to_dumpable_context: Optional[ToDumpableContextFn] = None,
+    from_dumpable_context: Optional[FromDumpableContextFn] = None,
+) -> None:
+    """Register a container-like type as pytree node.
+
+    Args:
+        cls: the type to register
+        flatten_fn: A callable that takes a pytree and returns a flattened
+            representation of the pytree and additional context to represent the
+            flattened pytree.
+        unflatten_fn: A callable that takes a flattened version of the pytree,
+            additional context, and returns an unflattened pytree.
+        serialized_type_name: A keyword argument used to specify the fully qualified
+            name used when serializing the tree spec.
+        to_dumpable_context: An optional keyword argument to custom specify how
+            to convert the context of the pytree to a custom json dumpable
+            representation. This is used for json serialization, which is being
+            used in torch.export right now.
+        from_dumpable_context: An optional keyword argument to custom specify how
+            to convert the custom json dumpable representation of the context
+            back to the original context. This is used for json deserialization,
+            which is being used in torch.export right now.
+    """
+    _private_register_pytree_node(
+        cls,
+        flatten_fn,
+        unflatten_fn,
+        serialized_type_name=serialized_type_name,
+        to_dumpable_context=to_dumpable_context,
+        from_dumpable_context=from_dumpable_context,
+    )
+
+    try:
+        from . import _cxx_pytree as cxx
+    except ImportError:
+        pass
+    else:
+        cxx._private_register_pytree_node(
+            cls,
+            flatten_fn,
+            unflatten_fn,
+            serialized_type_name=serialized_type_name,
+            to_dumpable_context=to_dumpable_context,
+            from_dumpable_context=from_dumpable_context,
+        )
+
+
 def _register_pytree_node(
     cls: Any,
     flatten_fn: FlattenFunc,
@@ -131,7 +186,8 @@ def _register_pytree_node(
     to_dumpable_context: Optional[ToDumpableContextFn] = None,
     from_dumpable_context: Optional[FromDumpableContextFn] = None,
 ) -> None:
-    """
+    """Register a container-like type as pytree node for the Python pytree only.
+
     Args:
         cls: the type to register
         flatten_fn: A callable that takes a pytree and returns a flattened
@@ -150,39 +206,69 @@ def _register_pytree_node(
             back to the original context. This is used for json deserialization,
             which is being used in torch.export right now.
     """
+    warnings.warn(
+        "torch.utils._pytree._register_pytree_node is deprecated. "
+        "Please use torch.utils._pytree.register_pytree_node instead.",
+        stacklevel=2,
+    )
+
     if to_str_fn is not None or maybe_from_str_fn is not None:
         warnings.warn(
             "to_str_fn and maybe_from_str_fn is deprecated. "
             "Please use to_dumpable_context and from_dumpable_context instead."
         )
 
-    node_def = NodeDef(
+    _private_register_pytree_node(
         cls,
         flatten_fn,
         unflatten_fn,
+        serialized_type_name=serialized_type_name,
+        to_dumpable_context=to_dumpable_context,
+        from_dumpable_context=from_dumpable_context,
     )
-    SUPPORTED_NODES[cls] = node_def
 
-    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
-        raise ValueError(
-            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
-            "be None or registered."
-        )
 
-    if serialized_type_name is None:
-        serialized_type_name = f"{cls.__module__}.{cls.__name__}"
+def _private_register_pytree_node(
+    cls: Any,
+    flatten_fn: FlattenFunc,
+    unflatten_fn: UnflattenFunc,
+    *,
+    serialized_type_name: Optional[str] = None,
+    to_dumpable_context: Optional[ToDumpableContextFn] = None,
+    from_dumpable_context: Optional[FromDumpableContextFn] = None,
+) -> None:
+    """This is an internal function that is used to register a pytree node type
+    for the Python pytree only. End-users should use :func:`register_pytree_node`
+    instead.
+    """
+    with _NODE_REGISTRY_LOCK:
+        if cls in SUPPORTED_NODES:
+            raise ValueError(f"{cls} is already registered as pytree node.")
+
+        node_def = NodeDef(
+            cls,
+            flatten_fn,
+            unflatten_fn,
+        )
+        SUPPORTED_NODES[cls] = node_def
 
-    serialize_node_def = _SerializeNodeDef(
-        cls,
-        serialized_type_name,
-        to_dumpable_context,
-        from_dumpable_context,
-    )
-    SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
-    SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
+        if (to_dumpable_context is None) ^ (from_dumpable_context is None):
+            raise ValueError(
+                f"Both to_dumpable_context and from_dumpable_context for {cls} must "
+                "be None or registered."
+            )
 
+        if serialized_type_name is None:
+            serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
 
-register_pytree_node = _register_pytree_node
+        serialize_node_def = _SerializeNodeDef(
+            cls,
+            serialized_type_name,
+            to_dumpable_context,
+            from_dumpable_context,
+        )
+        SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
+        SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
 
 
 def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
@@ -243,25 +329,25 @@ def _odict_unflatten(
     return OrderedDict((key, value) for key, value in zip(context, values))
 
 
-_register_pytree_node(
+_private_register_pytree_node(
     dict,
     _dict_flatten,
     _dict_unflatten,
     serialized_type_name="builtins.dict",
 )
-_register_pytree_node(
+_private_register_pytree_node(
     list,
     _list_flatten,
     _list_unflatten,
     serialized_type_name="builtins.list",
 )
-_register_pytree_node(
+_private_register_pytree_node(
     tuple,
     _tuple_flatten,
     _tuple_unflatten,
     serialized_type_name="builtins.tuple",
 )
-_register_pytree_node(
+_private_register_pytree_node(
     namedtuple,
     _namedtuple_flatten,
     _namedtuple_unflatten,
@@ -269,7 +355,7 @@ def _odict_unflatten(
     from_dumpable_context=_namedtuple_deserialize,
     serialized_type_name="collections.namedtuple",
 )
-_register_pytree_node(
+_private_register_pytree_node(
     OrderedDict,
     _odict_flatten,
     _odict_unflatten,
@@ -729,7 +815,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
 
     if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
         raise NotImplementedError(
-            f"Serializing {treespec.type} in pytree is not registered."
+            f"Serializing {treespec.type} in pytree is not registered.",
         )
 
     serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]

From c47d2b80355db2120a591f21df494bdacff5ef30 Mon Sep 17 00:00:00 2001
From: CaoE 
Date: Tue, 21 Nov 2023 20:08:28 +0000
Subject: [PATCH 056/221] Add Half support for CPU autocast on eager mode
 (#112484)

Add Half support for CPU autocast on eager mode since common operators have Half support on CPU.
https://github.com/pytorch/pytorch/issues/96093.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112484
Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang
---
 test/test_autocast.py                         | 62 ++++++++++++++-----
 torch/amp/autocast_mode.py                    |  5 +-
 .../testing/_internal/autocast_test_lists.py  | 48 +++++++-------
 3 files changed, 77 insertions(+), 38 deletions(-)

diff --git a/test/test_autocast.py b/test/test_autocast.py
index 256aa627b580..85998107a062 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -17,7 +17,16 @@ def tearDown(self):
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
         # helper to cast args
         def cast(val, to_type):
             if isinstance(val, torch.Tensor):
@@ -31,7 +40,7 @@ def cast(val, to_type):
             add_kwargs = {}
 
         self.assertFalse(torch.is_autocast_cpu_enabled())
-        with torch.cpu.amp.autocast():
+        with torch.cpu.amp.autocast(dtype=amp_dtype):
             self.assertTrue(torch.is_autocast_cpu_enabled())
             out_type = out_type if out_type is not None else run_as_type
             output = output_method = None
@@ -92,36 +101,61 @@ def args_maybe_kwargs(self, op_with_args):
             return op_with_args[0], op_with_args[1], op_with_args[2]
 
     def test_autocast_torch_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
 
     def test_autocast_methods_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
 
-    def test_autocast_torch_bf16(self):
-        for op_with_args in self.autocast_lists.torch_bf16:
+    def test_autocast_torch_16(self):
+        for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
-    def test_autocast_nn_bf16(self):
-        for op_with_args in self.autocast_lists.nn_bf16:
+    def test_autocast_nn_16(self):
+        for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float16,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
     def test_autocast_nn_fp32(self):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_need_autocast_promote(self):
-        for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_autocast_rnn(self):
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 8ed2c92b10ec..30c6aefcf1bd 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -257,11 +257,12 @@ def __init__(
             self._cache_enabled = cache_enabled
 
         if self.device == "cpu":
-            supported_dtype = [torch.bfloat16]
+            supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype and enabled:
                 error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
                 error_message += (
-                    "CPU Autocast only supports dtype of torch.bfloat16 currently."
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                 )
                 warnings.warn(error_message)
                 enabled = False
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index 5eced2f65c73..e6b6dcfc0f40 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -244,6 +244,9 @@ def __init__(self, dev):
         mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
         mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
 
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+
         dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
 
         dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
@@ -275,29 +278,30 @@ def __init__(self, dev):
         # Some ops implement built-in type promotion.  These don't need autocasting,
         # but autocasting relies on their promotion, so we include tests to double-check.
         self.torch_expect_builtin_promote = [
-            ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
+
         self.methods_expect_builtin_promote = [
-            ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
         # The remaining lists organize ops that autocast treats explicitly.
-        self.torch_bf16 = [
+        self.torch_16 = [
             ("conv1d", conv_args_fp32[0]),
             ("conv2d", conv_args_fp32[1]),
             ("conv3d", conv_args_fp32[2]),
@@ -337,7 +341,7 @@ def __init__(self, dev):
             ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
             ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
         ]
-        self.nn_bf16 = [
+        self.nn_16 = [
             ("linear", mat0_fp32 + mat1_fp32, {}),
         ]
         self.nn_fp32 = [
@@ -358,6 +362,6 @@ def __init__(self, dev):
             ("huber_loss", mat0_bf16 + mat1_bf16),
         ]
         self.torch_need_autocast_promote = [
-            ("cat", (pointwise0_bf16 + pointwise1_fp32,)),
-            ("stack", (pointwise0_bf16 + pointwise1_fp32,)),
+            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
         ]

From 2abfb8ec7d7a3970097c12caabe1ccb7a05bb5d5 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang" 
Date: Tue, 21 Nov 2023 10:13:43 -0500
Subject: [PATCH 057/221] Correctly codegen math.inf in Inductor (#114159)

Signed-off-by: Edward Z. Yang 

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114159
Approved by: https://github.com/lezcano
---
 .../test_torchinductor_dynamic_shapes.py      | 20 +++++++++++++++++++
 torch/_inductor/codegen/common.py             |  6 ++++++
 2 files changed, 26 insertions(+)

diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 459059a7434c..3745afa72bbd 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -235,6 +235,26 @@ def f(x):
 
         f(torch.tensor([3], device=device))
 
+    @torch._dynamo.config.patch(
+        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
+    )
+    def test_float_item_inf(self, device):
+        @torch.compile(fullgraph=True)
+        def f(x):
+            return x.item() == math.inf
+
+        f(torch.tensor([3.0], device=device))
+
+    @torch._dynamo.config.patch(
+        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
+    )
+    def test_float_item_neginf(self, device):
+        @torch.compile(fullgraph=True)
+        def f(x):
+            return x.item() == -math.inf
+
+        f(torch.tensor([3.0], device=device))
+
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_item_to_inputs_kernel_nobreak(self, device):
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 920dc318d715..be949e8f92a9 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -311,6 +311,12 @@ def _print_Pow(self, expr):
         else:  # exp == 0
             return "1"
 
+    def _print_Infinity(self, expr):
+        return "math.inf"
+
+    def _print_NegativeInfinity(self, expr):
+        return "-math.inf"
+
     def _print_Relational(self, expr):
         return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
 

From 3b108a150a8b51c3741e24862402feac7c4f76f7 Mon Sep 17 00:00:00 2001
From: Ying Zhang 
Date: Tue, 21 Nov 2023 20:34:02 +0000
Subject: [PATCH 058/221] A fix for reduction + pointwise + multi-level
 reduction optimization (#112935)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

ATT, for cases like reduction + multiple pointwises + multi-level reduction, previously to decide num_splits of the multi-level reduction, we only check whether the input of multi-level reduction or input of input of multi-level reduction is a reduction node (i.e. max search level is 2). This PR changes the behavior to search for a reduction input node recursively if previous input nodes are pointwise nodes.

Performance-wise it looks fine.
![Screenshot 2023-11-15 at 11 52 28 PM](https://github.com/pytorch/pytorch/assets/10527447/e726948c-0c00-4839-87a4-bcf9044c66d7)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112935
Approved by: https://github.com/chenyang78
---
 test/inductor/test_perf.py      | 30 +++++++++++------------
 torch/_inductor/dependencies.py | 43 ++++++++++++++++++++++-----------
 2 files changed, 44 insertions(+), 29 deletions(-)

diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index fc68491e5f16..2e5c367500a8 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -372,15 +372,11 @@ def f(a, b, c):
 
     def test_reduction_pointwise_multi_level_reduction(self):
         hidden_size = 4096
+        layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()
 
+        @torch.inference_mode()
         def f(x, scale, amax_keep_dim):
-            x = torch.nn.functional.layer_norm(
-                x.to(dtype=torch.float),
-                [hidden_size],
-                weight=None,
-                bias=None,
-                eps=1e-05,
-            )
+            x = layer_norm(x.to(dtype=torch.float))
             amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
             x_scaled = x * scale
             y = torch.nn.functional.sigmoid(x_scaled)
@@ -389,22 +385,26 @@ def f(x, scale, amax_keep_dim):
         inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))
 
         # 3 kernels:
-        # kernel 1: (input = X, scale, output = LN_pointwise(X), welford_reduction(X) * 2)
-        # kernel 2: (input = X, welford_reduction(X) * 2, output = first-level amax (split-reduction))
+        # kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), welford_reduction(X) * 2)
+        # kernel 2: (input = X, welford_reduction(X) * 2, LN scale, LN bias, output = first-level amax (split-reduction))
         # kernel 3: (input = first-level amax, output = final amax)
-        # scale (1) + X (4*2048*hidden_size) * 3 + welford_reduction (4*2048) * 4 + amax (num_splits * 2 + 1)
+        # scale (1) + X (4*2048*hidden_size) * 3 + welford_reduction (4*2048) * 4 +
+        #   LN scale (hidden_size) * 2 + LN bias (hidden_size) * 2 + amax (num_splits * 2 + 1)
         # num_splits depends on SM architectures.
-        expected_amax_keep_dim_numel = 1 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
+        expected_amax_keep_dim_numel = (
+            1 + hidden_size * 4 + 4 * 2048 * hidden_size * 3 + 4 * 2048 * 4 + 1
+        )
         self.assertGreaterAlmostEqual(
-            count_numel(f, *inp, True), str(expected_amax_keep_dim_numel)
+            int(count_numel(f, *inp, True)), expected_amax_keep_dim_numel
         )
 
         # 2 kernels:
-        # kernel 1: (input = X, scale, output = LN_pointwise(X), first-level amax (split-reduction))
+        # kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction))
         # kernel 2: (input = first-level amax, output = final amax)
-        # scale (1) + X (4*2048*hidden_size) * 2 + amax (4 * 2048 * 2 + 1)
+        # scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1)
+
         expected_amax_no_keep_dim_numel = (
-            1 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
+            1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
         )
         self.assertExpectedInline(
             count_numel(f, *inp, False), str(expected_amax_no_keep_dim_numel)
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 7effbdcc3974..5a5982c68516 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -407,21 +407,36 @@ def extract_input_node_reduction_ranges(
     reads = input_node.get_reads()
     reduction_size = None
     size = None
-    for read in reads:
-        if not isinstance(read, MemoryDep):
-            continue
-        buffer = V.graph.get_buffer(read.name)
-        if buffer is None:
-            continue
-        if isinstance(buffer, ComputedBuffer) and len(buffer.get_reduction_size()) > 0:
-            if reduction_size is None:
-                reduction_size = buffer.get_reduction_size()
-                size = buffer.get_size()
-            elif (
-                reduction_size != buffer.get_reduction_size()
-                or size != buffer.get_size()
+    while reduction_size is None and len(reads) > 0:
+        seen = set()
+        new_reads = []
+        for read in reads:
+            if not isinstance(read, MemoryDep):
+                continue
+            if read.name in seen:
+                continue
+            seen.add(read.name)
+            buffer = V.graph.get_buffer(read.name)
+            if buffer is None:
+                continue
+            if (
+                isinstance(buffer, ComputedBuffer)
+                and len(buffer.get_reduction_size()) > 0
             ):
-                return (None, None)
+                if reduction_size is None:
+                    reduction_size = buffer.get_reduction_size()
+                    size = buffer.get_size()
+                elif (
+                    reduction_size != buffer.get_reduction_size()
+                    or size != buffer.get_size()
+                ):
+                    return (None, None)
+            else:
+                new_reads.extend(buffer.get_reads())
+        if reads == new_reads:
+            return (size, reduction_size)
+        else:
+            reads = new_reads
     return (size, reduction_size)
 
 

From 64a5372e6ce9b6ca0ee5c7482b27e24561725b28 Mon Sep 17 00:00:00 2001
From: Chip Turner 
Date: Tue, 21 Nov 2023 21:03:48 +0000
Subject: [PATCH 059/221] Opportunistically use `ncclCommSplit` when creating
 new NCCL groups (#112889)

Currently `ncclCommInitRankConfig` is always used when creating new
communicator groups.  This is wasteful as it creates non-shared pairs
of endpoint queues as well as costs time to re-establish
communication.

This change is transparent and opportunistic; when `dist.new_group` is
called, it will use the existing, healthy world process group to
select the right ranks to include in the process group.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112889
Approved by: https://github.com/kwen2501
---
 test/cpp/c10d/ProcessGroupNCCLTest.cpp        | 78 ++++++++++++++++---
 test/distributed/test_c10d_nccl.py            | 22 +++++-
 torch/csrc/distributed/c10d/NCCLUtils.hpp     | 26 +++++++
 .../distributed/c10d/ProcessGroupNCCL.cpp     | 47 +++++++++--
 .../distributed/c10d/ProcessGroupNCCL.hpp     | 11 +++
 torch/csrc/distributed/c10d/init.cpp          | 14 +++-
 torch/distributed/distributed_c10d.py         | 35 ++++++++-
 7 files changed, 212 insertions(+), 21 deletions(-)

diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
index 61e9753988ea..6a0d60b57315 100644
--- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
@@ -31,12 +31,20 @@ class NCCLTestBase {
     pg_ = std::move(other.pg_);
   }
 
-  ::c10d::ProcessGroupNCCL& getProcessGroup() {
-    return *pg_;
+  std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
+    return pg_;
   }
 
-  void initialize(int rank, int size) {
-    auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
+  ::c10::intrusive_ptr<::c10d::Store>& getProcessGroupStore() {
+    return store_;
+  }
+
+  void initialize(
+      int rank,
+      int size,
+      c10::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
+          c10::nullopt) {
+    store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
 
     c10::intrusive_ptr opts =
         c10::make_intrusive();
@@ -45,14 +53,22 @@ class NCCLTestBase {
         c10d::TORCH_ENABLE_NCCL_HEALTH_CHECK[0].c_str(),
         "1",
         /* overwrite */ 1);
+#ifdef NCCL_HAS_COMM_SPLIT
+    if (split_from) {
+      opts->split_from = *split_from;
+      opts->split_color = ++color_;
+    }
+#endif
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
+        new ::c10d::ProcessGroupNCCL(store_, rank, size, std::move(opts)));
   }
 
  protected:
   std::string path_;
-  std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
   std::chrono::milliseconds pgTimeout_;
+  ::c10::intrusive_ptr<::c10d::Store> store_;
+  int color_{1};
 };
 
 class NCCLTest : public NCCLTestBase {
@@ -718,9 +734,9 @@ void testSequenceNumInit(
   auto runTest = [&](int i) {
     NCCLTest test(path, worldSize);
     test.initialize(i, worldSize);
-    test.getProcessGroup().setSequenceNumberForGroup();
+    test.getProcessGroup()->setSequenceNumberForGroup();
     std::lock_guard lock(m);
-    auto seqNum = test.getProcessGroup().getSequenceNumberForGroup();
+    auto seqNum = test.getProcessGroup()->getSequenceNumberForGroup();
     nums.insert(seqNum);
   };
   std::vector threads;
@@ -877,11 +893,55 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) {
     auto test = NCCLTestBase(file.path);
     test.initialize(rank_, size_);
     EXPECT_EQ(
-        test.getProcessGroup().getBackendName(),
+        test.getProcessGroup()->getBackendName(),
         std::string(c10d::NCCL_BACKEND_NAME));
   }
 }
 
+TEST_F(ProcessGroupNCCLTest, testSplittingCommunicator) {
+  if (skipTest()) {
+    return;
+  }
+  TemporaryFile file;
+  auto test1 = BroadcastNCCLTest(file.path, size_);
+  test1.initialize(rank_, size_);
+
+  auto test2 = BroadcastNCCLTest(file.path, size_);
+  test2.initialize(rank_, size_, test1.getProcessGroup());
+
+  // Steal the broadcast test and issue it for both of our groups.
+  // This ensures consistent full collective communication.  TODO:
+  // maybe refactor the guts rather than copy-pasta, but it may not be
+  // worth it.
+  for (auto test : {&test1, &test2}) {
+    const int numDevices = test->numDevices();
+    // try every permutation of root rank and root tensor
+    for (const auto rootRank : c10::irange(size_)) {
+      for (const auto rootTensor : c10::irange(numDevices)) {
+        auto work = test->run(rootRank, rootTensor);
+        test->wait(work);
+
+        // Check results
+        const auto expected = (rootRank * numDevices + rootTensor);
+        const auto tensors = test->getTensors();
+        for (const auto& tensor : tensors) {
+          const auto* const data = tensor.data_ptr();
+          for (const auto k : c10::irange(tensor.numel())) {
+            EXPECT_EQ(data[k], expected)
+                << "Broadcast outputs do not match expected outputs";
+          }
+        }
+      }
+    }
+  }
+
+  // Now that we've run full operations on both the original and split process
+  // group, ensure we saw exactly as many splits as we expected: 0 in the
+  // original process group, and one per device in the second.
+  EXPECT_EQ(test2.getProcessGroup()->getCommSplitCounter(), 0);
+  EXPECT_EQ(test1.getProcessGroup()->getCommSplitCounter(), test1.numDevices());
+}
+
 #ifdef IS_NCCL_EXP
 TEST_F(ProcessGroupNCCLTest, testSparseAllreduce) {
   if (skipTest()) {
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index ada84507aef9..4ac72c2bd207 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -1272,6 +1272,27 @@ def allgather_base(output_t, input_t):
         # Verification
         self.assertEqual(torch.arange(self.world_size), output_t)
 
+    @requires_nccl()
+    def test_comm_split_optimization(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+
+        # Test lazy splitting behavior across each per-device backend.
+        for device in self.rank_to_GPU[self.rank]:
+            backend = pg._get_backend(torch.device(device))
+
+            # split doesn't happen unless the original process group has lazily
+            # created communicators, so first verify we haven't split even when
+            # making the new group and running an operation on the original pg.
+            ng = c10d.new_group()
+            tensor = torch.tensor([self.rank]).cuda(device)
+            pg.broadcast(tensor, 0)
+            self.assertEqual(backend.comm_split_count(), 0)
+
+            # The new group will force a split of the original on first use.
+            ng.broadcast(tensor, 0)
+            self.assertEqual(backend.comm_split_count(), 1)
+
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
 ):
@@ -3676,7 +3697,6 @@ def gather_trace():
 
 
 
-
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index e6c05e228cfd..2b4885f02ffc 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -17,6 +17,11 @@
 #define NCCL_HAS_COMM_NONBLOCKING
 #endif
 
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+    (NCCL_MINOR >= 18)
+#define NCCL_HAS_COMM_SPLIT
+#endif
+
 // ncclGetLastError() is enabled only for NCCL versions 2.13+
 // ncclRemoteError only exists in NCCL versions 2.13+
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -246,6 +251,22 @@ class NCCLComm {
   }
 #endif
 
+#ifdef NCCL_HAS_COMM_SPLIT
+  static std::shared_ptr split(
+      NCCLComm* source,
+      int color_id,
+      int rank,
+      ncclConfig_t& config) {
+    auto comm = std::make_shared();
+    C10D_NCCL_CHECK(
+        ncclCommSplit(
+            source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
+        c10::nullopt);
+    ++source->ncclCommSplitCounter_;
+    return comm;
+  }
+#endif
+
   ncclUniqueId getNcclId() {
     return ncclId_;
   }
@@ -325,6 +346,10 @@ class NCCLComm {
     return aborted_;
   }
 
+  uint64_t getCommSplitCounter() const {
+    return ncclCommSplitCounter_;
+  }
+
   ncclResult_t checkForNcclError() {
     std::unique_lock lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
@@ -401,6 +426,7 @@ class NCCLComm {
   // Unique nccl_id for this communicator.
   ncclUniqueId ncclId_;
   bool aborted_;
+  uint64_t ncclCommSplitCounter_{0};
   ncclResult_t ncclAsyncErr_;
   mutable std::mutex mutex_;
   // Rank that this communicator corresponds to.
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index f4691c815c61..d2b74c046918 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1898,11 +1898,40 @@ std::vector>& ProcessGroupNCCL::getNCCLComm(
     int deviceIndex = devices[i].index();
 
     gpuGuard.set_index(deviceIndex);
+#ifdef NCCL_HAS_COMM_SPLIT
+    if (options_->split_from) {
+      TORCH_CHECK(
+          options_->split_color != 0,
+          "Must specify a non-zero color when splitting");
+      // Find a valid, healthy communicator to split from if possible.
+      std::lock_guard lock(options_->split_from->mutex_);
+      auto& other_comms = options_->split_from->devNCCLCommMap_;
+      auto dit = other_comms.find(devicesKey);
+      if (dit != other_comms.end() && !dit->second.empty()) {
+        TORCH_INTERNAL_ASSERT(
+            dit->second.size() == ncclComms.size(),
+            "split_from->devNCCLCommMap_ should be empty or the same size as ncclComms!");
+        if (dit->second[i] && !dit->second[i]->isAborted()) {
+          ncclComms[i] = NCCLComm::split(
+              dit->second[i].get(),
+              options_->split_color,
+              rank,
+              options_->config);
+        }
+      }
+    }
+#endif
+
+    // To simplify conditioonal nesting, just create the ncclComms[i]
+    // entry if it hasn't been yet rather than untangling the
+    // conditions that might have resulted in a split above.
+    if (!ncclComms[i]) {
 #ifdef NCCL_HAS_COMM_NONBLOCKING
-    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
+      ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
 #else
-    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
+      ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
 #endif
+    }
 
     // Creates the NCCL streams
     streamVal.push_back(
@@ -1948,9 +1977,6 @@ std::vector>& ProcessGroupNCCL::getNCCLComm(
       std::make_tuple(devicesKey),
       std::make_tuple(devices.size()));
 
-  // Hold the lock before modifying the cache.
-  std::lock_guard lock(mutex_);
-
   // Record the communicators based on ncclUniqueId.
   ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms);
 
@@ -1994,9 +2020,20 @@ std::vector>& ProcessGroupNCCL::getNCCLComm(
   it = devNCCLCommMap_.find(devicesKey);
   TORCH_INTERNAL_ASSERT(
       it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
+
   return it->second;
 }
 
+uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
+  uint64_t ret = 0;
+  for (const auto& i : ncclIdToCommMap_) {
+    for (const auto& j : i.second) {
+      ret += j->getCommSplitCounter();
+    }
+  }
+  return ret;
+}
+
 namespace {
 
 // Check validity of tensor
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 79c986e7512f..6404d01f6cc7 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -342,6 +342,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Configure ranks
     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
 #endif
+
+    // Optional "parent" backend and color to create communicators from
+    // via `ncclCommSplit`
+#ifdef NCCL_HAS_COMM_SPLIT
+    std::shared_ptr split_from;
+    int64_t split_color{0};
+#endif
   };
 
   // If you wish to create multiple process groups, each with a potentially
@@ -510,6 +517,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // may indicate that there is some sort of collective desynchronization.
   uint64_t getSequenceNumberForGroup() override;
 
+  // Return the total number of splits the communicators held by this process
+  // group have performed.
+  uint64_t getCommSplitCounter() const;
+
   void registerOnCompletionHook(
       std::function)>&& hook) override;
   void waitForPendingWorks() override;
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index b4119666d0bc..59ff55db9992 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2290,6 +2290,9 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
               py::call_guard())
           .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
           .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
+          .def(
+              "comm_split_count",
+              &::c10d::ProcessGroupNCCL::getCommSplitCounter)
           .def_property_readonly(
               "options", &::c10d::ProcessGroupNCCL::getOptions)
           .def_property_readonly(
@@ -2354,15 +2357,18 @@ Example::
       )")
       .def(py::init(), py::arg("is_high_priority_stream") = false)
 #ifdef NCCL_HAS_COMM_CTA_CGA
+      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
+#endif
       .def_readwrite(
           "is_high_priority_stream",
           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
-      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
-#else
+#ifdef NCCL_HAS_COMM_SPLIT
       .def_readwrite(
-          "is_high_priority_stream",
-          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
+          "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
+      .def_readwrite(
+          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
 #endif
+      ;
 
 #endif
 
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 63f6c48d35f3..3bd35709505d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -8,6 +8,7 @@
 import logging
 import os
 import pickle
+import sys
 import time
 import warnings
 from collections import namedtuple
@@ -1314,7 +1315,29 @@ def _new_process_group_helper(
                 pg_options.is_high_priority_stream = False
             pg_options._timeout = timeout
 
-            backend_class = ProcessGroupNCCL(backend_prefix_store, group_rank, group_size, pg_options)
+            # If our new group includes all ranks, we can reduce
+            # overhead by splitting the communicator (`nccCommSplit`).
+
+            # TODO: support this in the general case by calling
+            # `nccCommSplit` with `NCCL_SPLIT_NOCOLOR` for the ranks
+            # not in the communicator.
+            split_from = None
+            if (
+                is_initialized()
+                and _world.default_pg._get_backend_name() == Backend.NCCL
+                and len(global_ranks_in_group) == _world.default_pg.size()
+            ):
+                # If possible, find a backend to split from by peeling
+                # process group wrappers from the world's default pg.
+                split_from = _world.default_pg._get_backend(_get_pg_default_device())
+                while isinstance(split_from, _ProcessGroupWrapper):
+                    split_from = split_from.wrapped_pg
+
+                if split_from:
+                    pg_options.split_from = split_from
+                    pg_options.split_color = _process_group_color(global_ranks_in_group)
+            backend_class = ProcessGroupNCCL(
+                backend_prefix_store, group_rank, group_size, pg_options)
             backend_type = ProcessGroup.BackendType.NCCL
         elif backend_str == Backend.UCC and is_ucc_available():
             # TODO: once UCC plugin is fully deprecated, remove
@@ -3514,11 +3537,19 @@ def _create_process_group_wrapper(
     wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
     return wrapped_pg
 
+# helper function for deterministically hashing a list of ranks
+def _hash_ranks(ranks: List[int]):
+    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
+
+# Takes a list of ranks and computes an integer color
+def _process_group_color(ranks: List[int]) -> int:
+    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
+    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
 
 def _process_group_name(ranks, use_hashed_name):
     global _world
     if use_hashed_name:
-        pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
+        pg_name = _hash_ranks(ranks)
         while pg_name in _world.pg_names.values():
             pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
     else:

From ebeaec71bf821d8ae34877e2e837eb70dd61c8c3 Mon Sep 17 00:00:00 2001
From: Yang Chen 
Date: Mon, 20 Nov 2023 16:40:04 -0800
Subject: [PATCH 060/221] [aotinductor] don't generate python profiling code in
 the cpp world (#114182)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114182
Approved by: https://github.com/aakhundov, https://github.com/desertfire
---
 test/inductor/test_aot_inductor.py | 16 ++++++++++++++++
 torch/_inductor/codegen/wrapper.py | 16 ++++++++++++++--
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 94d1bb2227f4..53b8d6a0a009 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1075,6 +1075,22 @@ def forward(self, x):
         x = torch.randn(5, device=self.device)
         self.check_model(Model(self.device), (x,))
 
+    def test_with_profiler(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x, y):
+                return x + self.linear(y)
+
+        example_inputs = (
+            torch.randn(10, 10, device=self.device),
+            torch.randn(10, 10, device=self.device),
+        )
+        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
+            self.check_model(Model(), example_inputs)
+
     def test_repeat_output(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 6a8000fe4d16..1787937f39ab 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -577,7 +577,7 @@ def generate(self, is_inference):
                 self.generate_profiler_mark_wrapper_call(stack)
             if config.profile_bandwidth:
                 self.write_triton_header_once()
-                self.wrapper_call.writeline("start_graph()")
+                self.generate_start_graph()
 
             # We disable planning during training because it presently increases peak memory consumption.
             if is_inference and config.memory_planning:
@@ -606,7 +606,7 @@ def generate(self, is_inference):
                 self.wrapper_call.writeline("torch.cuda.synchronize()")
 
             if config.profile_bandwidth:
-                self.wrapper_call.writeline("end_graph()")
+                self.generate_end_graph()
 
             self.generate_return(output_refs)
 
@@ -987,6 +987,12 @@ def generate_profiler_mark_wrapper_call(self, stack):
         )
         stack.enter_context(self.wrapper_call.indent())
 
+    def generate_start_graph(self):
+        self.wrapper_call.writeline("start_graph()")
+
+    def generate_end_graph(self):
+        self.wrapper_call.writeline("end_graph()")
+
     def generate_default_grid(self, name: str, grid_args: List[Any]):
         return grid_args
 
@@ -1874,6 +1880,12 @@ def generate_profiler_mark_wrapper_call(self, stack):
             'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());'
         )
 
+    def generate_start_graph(self):
+        pass
+
+    def generate_end_graph(self):
+        pass
+
     def generate_inf_and_nan_checker(self, nodes):
         for buf in nodes.get_names():
             # TODO: Add buf name directly into check_inf_and_nan.

From 799d8c303558d1501eb7829cdf631fcbb71200de Mon Sep 17 00:00:00 2001
From: Bin Bao 
Date: Tue, 21 Nov 2023 13:23:07 -0500
Subject: [PATCH 061/221] [CI] Rename the inductor test config names for
 dynamic shapes tests (#113574)

Summary: To make the naming consistent with tests in inductor-periodic and simplify update_expected.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113574
Approved by: https://github.com/eellison, https://github.com/malfet, https://github.com/jansel
---
 .github/workflows/inductor.yml                            | 8 ++++----
 ...nce.csv => dynamic_inductor_huggingface_inference.csv} | 0
 ...ning.csv => dynamic_inductor_huggingface_training.csv} | 0
 ..._inference.csv => dynamic_inductor_timm_inference.csv} | 0
 ...ic_training.csv => dynamic_inductor_timm_training.csv} | 0
 ...ence.csv => dynamic_inductor_torchbench_inference.csv} | 0
 ...ining.csv => dynamic_inductor_torchbench_training.csv} | 0
 benchmarks/dynamo/ci_expected_accuracy/update_expected.py | 8 ++++----
 8 files changed, 8 insertions(+), 8 deletions(-)
 rename benchmarks/dynamo/ci_expected_accuracy/{inductor_huggingface_dynamic_inference.csv => dynamic_inductor_huggingface_inference.csv} (100%)
 rename benchmarks/dynamo/ci_expected_accuracy/{inductor_huggingface_dynamic_training.csv => dynamic_inductor_huggingface_training.csv} (100%)
 rename benchmarks/dynamo/ci_expected_accuracy/{inductor_timm_dynamic_inference.csv => dynamic_inductor_timm_inference.csv} (100%)
 rename benchmarks/dynamo/ci_expected_accuracy/{inductor_timm_dynamic_training.csv => dynamic_inductor_timm_training.csv} (100%)
 rename benchmarks/dynamo/ci_expected_accuracy/{inductor_torchbench_dynamic_inference.csv => dynamic_inductor_torchbench_inference.csv} (100%)
 rename benchmarks/dynamo/ci_expected_accuracy/{inductor_torchbench_dynamic_training.csv => dynamic_inductor_torchbench_training.csv} (100%)

diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml
index be4a04347867..f38204a55e83 100644
--- a/.github/workflows/inductor.yml
+++ b/.github/workflows/inductor.yml
@@ -27,10 +27,10 @@ jobs:
           { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "inductor_huggingface_dynamic", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "inductor_timm_dynamic", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "inductor_timm_dynamic", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "inductor_torchbench_dynamic", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
           { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv
similarity index 100%
rename from benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv
rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv
similarity index 100%
rename from benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv
rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_inference.csv
similarity index 100%
rename from benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_inference.csv
rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_inference.csv
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv
similarity index 100%
rename from benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
similarity index 100%
rename from benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv
rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv
similarity index 100%
rename from benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv
diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
index cca8ad9e6067..77640b2af835 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
+++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
@@ -82,7 +82,7 @@ def get_artifacts_urls(results, suites):
 
 def normalize_suite_filename(suite_name):
     strs = suite_name.split("_")
-    subsuite = strs[2] if strs[0] == "aot" else strs[1]
+    subsuite = strs[-1]
     if "timm" in subsuite:
         subsuite = subsuite.replace("timm", "timm_models")
 
@@ -143,13 +143,13 @@ def apply_lints(filename):
     suites = {
         "aot_inductor_huggingface",
         "inductor_huggingface",
-        "inductor_huggingface_dynamic",
+        "dynamic_inductor_huggingface",
         "aot_inductor_timm",
         "inductor_timm",
-        "inductor_timm_dynamic",
+        "dynamic_inductor_timm",
         "aot_inductor_torchbench",
         "inductor_torchbench",
-        "inductor_torchbench_dynamic",
+        "dynamic_inductor_torchbench",
     }
 
     root_path = "benchmarks/dynamo/ci_expected_accuracy/"

From 3c8a4f01b93d2cb905ea8a64ffc4678597533e39 Mon Sep 17 00:00:00 2001
From: Bin Bao 
Date: Tue, 21 Nov 2023 13:23:07 -0500
Subject: [PATCH 062/221] [CI] Increase the shard numbers for torchbench tests
 (#113575)

Summary: torchbench tests are always the lagging shards when comparing to other integration test shards, so let's bump up the corresponding shard numbers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113575
Approved by: https://github.com/ezyang, https://github.com/malfet, https://github.com/jansel
ghstack dependencies: #113574
---
 .github/workflows/inductor-periodic.yml |  9 ++++++---
 .github/workflows/inductor.yml          | 17 +++++++++++------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml
index 3c116e2ef34a..f775acf1e9e7 100644
--- a/.github/workflows/inductor-periodic.yml
+++ b/.github/workflows/inductor-periodic.yml
@@ -24,15 +24,18 @@ jobs:
       cuda-arch-list: '8.6'
       test-matrix: |
         { include: [
-          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "aot_eager_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml
index f38204a55e83..141c8e619dc1 100644
--- a/.github/workflows/inductor.yml
+++ b/.github/workflows/inductor.yml
@@ -23,19 +23,22 @@ jobs:
       test-matrix: |
         { include: [
           { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
           { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
+          { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
           { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
-          { config: "aot_inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
+          { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
         ]}
     secrets:
       HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
@@ -88,11 +91,13 @@ jobs:
           { config: "inductor_huggingface_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
           { config: "inductor_timm_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
           { config: "inductor_timm_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
-          { config: "inductor_torchbench_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.4xlarge" },
+          { config: "inductor_torchbench_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
+          { config: "inductor_torchbench_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
           { config: "inductor_huggingface_dynamic_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
           { config: "inductor_timm_dynamic_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
           { config: "inductor_timm_dynamic_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
-          { config: "inductor_torchbench_dynamic_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
+          { config: "inductor_torchbench_dynamic_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
+          { config: "inductor_torchbench_dynamic_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
         ]}
     secrets:
       HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}

From 212f668408f63f228f2d02c9be3ea62105552d67 Mon Sep 17 00:00:00 2001
From: Bin Bao 
Date: Tue, 21 Nov 2023 13:23:07 -0500
Subject: [PATCH 063/221] [CI] Remove CI skip list for inductor integration
 tests (#113446)

Summary: Switch to completely rely on checking against expected result files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113446
Approved by: https://github.com/ezyang, https://github.com/malfet, https://github.com/jansel
ghstack dependencies: #113574, #113575
---
 ...dynamic_inductor_huggingface_inference.csv |  8 +++
 .../dynamic_inductor_huggingface_training.csv | 12 ++++
 .../dynamic_inductor_timm_training.csv        | 40 ++++++++++++
 .../dynamic_inductor_torchbench_inference.csv | 52 ++++++++++++++++
 .../dynamic_inductor_torchbench_training.csv  | 40 ++++++++++++
 .../inductor_huggingface_training.csv         |  4 ++
 .../inductor_timm_training.csv                | 32 ++++++++++
 .../inductor_torchbench_inference.csv         | 32 ++++++++++
 .../inductor_torchbench_training.csv          | 28 +++++++++
 benchmarks/dynamo/common.py                   | 62 -------------------
 benchmarks/dynamo/timm_models.py              | 10 +++
 11 files changed, 258 insertions(+), 62 deletions(-)

diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv
index ef9ba4763a9b..349239b058a7 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv
@@ -18,6 +18,10 @@ BartForCausalLM,pass,0
 
 
 
+BartForConditionalGeneration,pass,0
+
+
+
 BertForMaskedLM,pass,0
 
 
@@ -54,6 +58,10 @@ DebertaV2ForMaskedLM,pass_due_to_skip,0
 
 
 
+DebertaV2ForQuestionAnswering,pass,0
+
+
+
 DistilBertForMaskedLM,pass,0
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv
index 3111b7cf5402..6bccf5082c70 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv
@@ -6,6 +6,10 @@ AlbertForMaskedLM,pass,6
 
 
 
+AlbertForQuestionAnswering,pass,6
+
+
+
 AllenaiLongformerBase,pass,10
 
 
@@ -14,6 +18,10 @@ BartForCausalLM,pass,14
 
 
 
+BartForConditionalGeneration,pass,26
+
+
+
 BertForMaskedLM,pass,6
 
 
@@ -46,6 +54,10 @@ DebertaV2ForMaskedLM,pass_due_to_skip,0
 
 
 
+DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0
+
+
+
 DistilBertForMaskedLM,pass,6
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv
index dfb2526f85c6..0f52a123eb78 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv
@@ -30,10 +30,18 @@ convnext_base,pass,8
 
 
 
+crossvit_9_240,pass,8
+
+
+
 cspdarknet53,pass,8
 
 
 
+deit_base_distilled_patch16_224,pass,8
+
+
+
 dla102,pass,8
 
 
@@ -102,6 +110,10 @@ lcnet_050,pass,8
 
 
 
+levit_128,pass,8
+
+
+
 mixer_b16_224,pass,8
 
 
@@ -122,10 +134,18 @@ mobilenetv3_large_100,pass,8
 
 
 
+mobilevit_s,pass,8
+
+
+
 nfnet_l0,pass,8
 
 
 
+pit_b_224,pass,8
+
+
+
 pnasnet5large,pass,6
 
 
@@ -166,6 +186,10 @@ rexnet_100,pass,8
 
 
 
+sebotnet33ts_256,pass,8
+
+
+
 selecsls42b,pass,8
 
 
@@ -198,4 +222,20 @@ tnt_s_patch16_224,pass,8
 
 
 
+twins_pcpvt_base,pass,8
+
+
+
+visformer_small,pass,8
+
+
+
+vit_base_patch16_224,pass,8
+
+
+
 volo_d1_224,pass,8
+
+
+
+xcit_large_24_p8_224,pass_due_to_skip,8
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
index 87538263eb66..3d975ef1fe68 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
@@ -10,6 +10,10 @@ Background_Matting,pass_due_to_skip,0
 
 
 
+DALLE2_pytorch,fail_to_run,21
+
+
+
 LearningToPaint,pass,0
 
 
@@ -46,6 +50,14 @@ dcgan,pass,0
 
 
 
+densenet121,pass,0
+
+
+
+detectron2_fcos_r_50_fpn,infra_error,0
+
+
+
 dlrm,pass,0
 
 
@@ -82,6 +94,10 @@ hf_Bart,pass,0
 
 
 
+hf_BigBird,fail_to_run,0
+
+
+
 hf_DistilBert,pass,0
 
 
@@ -90,10 +106,22 @@ hf_GPT2,pass,0
 
 
 
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
 hf_Reformer,pass,5
 
 
 
+hf_T5,pass,0
+
+
+
+hf_T5_generate,fail_to_run,10
+
+
+
 hf_T5_large,pass_due_to_skip,0
 
 
@@ -130,6 +158,10 @@ moco,pass,11
 
 
 
+nanogpt,pass,0
+
+
+
 nvidia_deeprecommender,pass,0
 
 
@@ -146,6 +178,18 @@ phlippe_resnet,pass,0
 
 
 
+pyhpc_equation_of_state,pass,0
+
+
+
+pyhpc_isoneutral_mixing,pass,0
+
+
+
+pyhpc_turbulent_kinetic_energy,infra_error,0
+
+
+
 pytorch_CycleGAN_and_pix2pix,pass,0
 
 
@@ -186,6 +230,10 @@ soft_actor_critic,pass,0
 
 
 
+speech_transformer,pass,10
+
+
+
 squeezenet1_1,pass,0
 
 
@@ -230,4 +278,8 @@ vgg16,pass,0
 
 
 
+vision_maskrcnn,pass,17
+
+
+
 yolov3,pass,2
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv
index a5a7ee0e85e2..9e3eea9ed570 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv
@@ -10,6 +10,10 @@ BERT_pytorch,pass,8
 
 
 
+Background_Matting,pass_due_to_skip,0
+
+
+
 LearningToPaint,pass,8
 
 
@@ -42,6 +46,14 @@ dcgan,pass,8
 
 
 
+demucs,fail_to_run,5
+
+
+
+densenet121,pass,8
+
+
+
 dlrm,pass,8
 
 
@@ -70,6 +82,10 @@ hf_Bart,pass,7
 
 
 
+hf_BigBird,fail_to_run,4
+
+
+
 hf_DistilBert,pass,7
 
 
@@ -78,10 +94,18 @@ hf_GPT2,pass,7
 
 
 
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
 hf_Reformer,pass,27
 
 
 
+hf_T5_base,OOM,4
+
+
+
 hf_T5_large,pass_due_to_skip,0
 
 
@@ -106,10 +130,18 @@ mobilenet_v2,pass,8
 
 
 
+mobilenet_v3_large,pass,8
+
+
+
 moco,pass,19
 
 
 
+nanogpt,pass,8
+
+
+
 nvidia_deeprecommender,pass,8
 
 
@@ -158,6 +190,10 @@ soft_actor_critic,pass,7
 
 
 
+speech_transformer,infra_error,0
+
+
+
 squeezenet1_1,pass,8
 
 
@@ -202,4 +238,8 @@ vgg16,pass,8
 
 
 
+vision_maskrcnn,fail_accuracy,36
+
+
+
 yolov3,pass,10
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv
index 6db7429479fd..6bccf5082c70 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv
@@ -6,6 +6,10 @@ AlbertForMaskedLM,pass,6
 
 
 
+AlbertForQuestionAnswering,pass,6
+
+
+
 AllenaiLongformerBase,pass,10
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
index 5b5a272027c6..0f52a123eb78 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
@@ -30,10 +30,18 @@ convnext_base,pass,8
 
 
 
+crossvit_9_240,pass,8
+
+
+
 cspdarknet53,pass,8
 
 
 
+deit_base_distilled_patch16_224,pass,8
+
+
+
 dla102,pass,8
 
 
@@ -126,10 +134,18 @@ mobilenetv3_large_100,pass,8
 
 
 
+mobilevit_s,pass,8
+
+
+
 nfnet_l0,pass,8
 
 
 
+pit_b_224,pass,8
+
+
+
 pnasnet5large,pass,6
 
 
@@ -206,4 +222,20 @@ tnt_s_patch16_224,pass,8
 
 
 
+twins_pcpvt_base,pass,8
+
+
+
+visformer_small,pass,8
+
+
+
+vit_base_patch16_224,pass,8
+
+
+
 volo_d1_224,pass,8
+
+
+
+xcit_large_24_p8_224,pass_due_to_skip,8
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
index ea4760ffaa99..ac90d0bbb8ac 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
@@ -10,6 +10,10 @@ Background_Matting,pass_due_to_skip,0
 
 
 
+DALLE2_pytorch,fail_to_run,21
+
+
+
 LearningToPaint,pass,0
 
 
@@ -46,6 +50,14 @@ dcgan,pass,0
 
 
 
+densenet121,pass,0
+
+
+
+detectron2_fcos_r_50_fpn,pass,42
+
+
+
 dlrm,pass,0
 
 
@@ -82,6 +94,10 @@ hf_Bart,pass,0
 
 
 
+hf_BigBird,fail_accuracy,0
+
+
+
 hf_DistilBert,pass,0
 
 
@@ -90,10 +106,18 @@ hf_GPT2,pass,0
 
 
 
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
 hf_Reformer,pass,5
 
 
 
+hf_T5,pass,0
+
+
+
 hf_T5_generate,pass,20
 
 
@@ -154,10 +178,18 @@ phlippe_resnet,pass,0
 
 
 
+pyhpc_equation_of_state,pass,0
+
+
+
 pyhpc_isoneutral_mixing,pass,0
 
 
 
+pyhpc_turbulent_kinetic_energy,pass,0
+
+
+
 pytorch_CycleGAN_and_pix2pix,pass,0
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
index 1dc5c17908eb..c1af501dda96 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
@@ -10,6 +10,10 @@ BERT_pytorch,pass,8
 
 
 
+Background_Matting,pass_due_to_skip,0
+
+
+
 LearningToPaint,pass,8
 
 
@@ -42,6 +46,14 @@ dcgan,pass,8
 
 
 
+demucs,pass,11
+
+
+
+densenet121,pass,8
+
+
+
 dlrm,pass,8
 
 
@@ -70,6 +82,10 @@ hf_Bart,pass,7
 
 
 
+hf_BigBird,pass,7
+
+
+
 hf_DistilBert,pass,7
 
 
@@ -78,10 +94,18 @@ hf_GPT2,pass,7
 
 
 
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
 hf_Reformer,pass,27
 
 
 
+hf_T5_base,OOM,4
+
+
+
 hf_T5_large,pass_due_to_skip,0
 
 
@@ -106,6 +130,10 @@ mobilenet_v2,pass,8
 
 
 
+mobilenet_v3_large,pass,8
+
+
+
 moco,pass,19
 
 
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index b6b1d63de65e..fee02e094880 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -192,35 +192,6 @@ class CI(NamedTuple):
     "xcit_large_24_p8_224",  # fp64_OOM,
 ]
 
-CI_SKIP[CI("inductor", training=False)] = [
-    # TorchBench
-    "DALLE2_pytorch",  # AttributeError: text_encodings
-    "demucs",  # OOM
-    "detectron2_fasterrcnn_r_101_c4",
-    "detectron2_fasterrcnn_r_101_dc5",
-    "detectron2_fasterrcnn_r_101_fpn",
-    "detectron2_fasterrcnn_r_50_c4",
-    "detectron2_fasterrcnn_r_50_dc5",
-    "detectron2_fasterrcnn_r_50_fpn",
-    "detectron2_fcos_r_50_fpn",
-    "detectron2_maskrcnn_r_101_c4",
-    "detectron2_maskrcnn_r_101_fpn",
-    "detectron2_maskrcnn_r_50_c4",
-    "detectron2_maskrcnn_r_50_fpn",
-    # TorchBench
-    "detectron2",
-    "densenet121",  # flaky accuracy
-    "hf_T5",  # accuracy
-    "hf_BigBird",  # accuracy
-    "hf_GPT2_large",  # OOM
-    "maml",  # accuracy
-    "mobilenet_v2_quantized_qat",  # The eval test only supports CPU
-    "pytorch_struct",  # Test eval is not implemented
-    "pyhpc_equation_of_state",  # Accuracy
-    "pyhpc_turbulent_kinetic_energy",  # Accuracy
-    "tacotron2",
-]
-
 CI_SKIP[CI("inductor", training=False, device="cpu")] = [
     # TorchBench
     "drq",  # Need to update torchbench
@@ -256,24 +227,6 @@ class CI(NamedTuple):
     "opacus_cifar10",  # Fails to run https://github.com/pytorch/pytorch/issues/99201
 ]
 
-CI_SKIP[CI("inductor", training=True)] = [
-    *CI_SKIP[CI("inductor", training=False)],
-    # TorchBench
-    "Background_Matting",  # fp64_OOM
-    "hf_T5_base",  # accuracy
-    "mobilenet_v3_large",  # accuracy
-    "resnet50_quantized_qat",  # Eager model failed to run
-    "AlbertForQuestionAnswering",  # accuracy
-    "crossvit_9_240",  # fails to run on timm 0.8.22 with cudagraphs, mempools
-    "deit_base_distilled_patch16_224",  # fails to run in timm 0.8.22, cudagraphs
-    "mobilevit_s",
-    "pit_b_224",
-    "twins_pcpvt_base",
-    "visformer_small",
-    "vit_base_patch16_224",
-    "xcit_large_24_p8_224",
-]
-
 # Skips for dynamic=True
 
 CI_SKIP[CI("aot_eager", training=False, dynamic=True)] = [
@@ -291,21 +244,6 @@ class CI(NamedTuple):
     "torchrec_dlrm",  # RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
 ]
 
-CI_SKIP[CI("inductor", training=False, dynamic=True)] = [
-    *CI_SKIP[CI("aot_eager", training=False, dynamic=True)],
-    *CI_SKIP[CI("inductor", training=False)],
-    "nanogpt",  # Assertion `index out of bounds: 0 <= tmp0 < 64` failed.
-]
-
-CI_SKIP[CI("inductor", training=True, dynamic=True)] = [
-    # NB: Intentionally omitting for symmetry with dynamic=False
-    # *CI_SKIP[CI("aot_eager", training=True, dynamic=True)],
-    *CI_SKIP[CI("inductor", training=False, dynamic=True)],
-    *CI_SKIP[CI("inductor", training=True)],
-    "levit_128",  # Accuracy fails on A10G, passes on A100
-    "sebotnet33ts_256",  # Flaky accuracy failed
-]
-
 CI_SKIP[CI("inductor", training=False, dynamic=True, device="cpu")] = [
     *CI_SKIP[CI("inductor", training=False, device="cpu")],
     "pyhpc_isoneutral_mixing",
diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py
index 89684c83acab..8e39dd5a4f1c 100755
--- a/benchmarks/dynamo/timm_models.py
+++ b/benchmarks/dynamo/timm_models.py
@@ -82,6 +82,10 @@ def pip_install(package):
     "xcit_large_24_p8_224",
 }
 
+SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
+    "xcit_large_24_p8_224",
+}
+
 
 def refresh_model_names():
     import glob
@@ -181,6 +185,12 @@ def force_amp_for_fp16_bf16_models(self):
     def force_fp16_for_bf16_models(self):
         return set()
 
+    @property
+    def skip_accuracy_check_as_eager_non_deterministic(self):
+        if self.args.accuracy and self.args.training:
+            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
+        return set()
+
     @download_retry_decorator
     def _download_model(self, model_name):
         model = create_model(

From a9f9f98e2f0f906170b681a57c420601fd88ea5a Mon Sep 17 00:00:00 2001
From: Bin Bao 
Date: Tue, 21 Nov 2023 13:23:07 -0500
Subject: [PATCH 064/221] [CI] Switch to check against expected result files
 for dynamo_eager and aot_eager benchmark tests (#113559)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113559
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: #113574, #113575, #113446
---
 .ci/pytorch/test.sh                           |   3 +-
 .../aot_eager_huggingface_inference.csv       | 185 ++++++++++++
 .../aot_eager_huggingface_training.csv        | 181 +++++++++++
 .../aot_eager_timm_inference.csv              | 245 +++++++++++++++
 .../aot_eager_timm_training.csv               | 241 +++++++++++++++
 .../aot_eager_torchbench_inference.csv        | 285 ++++++++++++++++++
 .../aot_eager_torchbench_training.csv         | 245 +++++++++++++++
 ...ynamic_aot_eager_huggingface_inference.csv | 185 ++++++++++++
 ...dynamic_aot_eager_huggingface_training.csv | 181 +++++++++++
 .../dynamic_aot_eager_timm_inference.csv      | 245 +++++++++++++++
 .../dynamic_aot_eager_timm_training.csv       | 241 +++++++++++++++
 ...dynamic_aot_eager_torchbench_inference.csv | 285 ++++++++++++++++++
 .../dynamic_aot_eager_torchbench_training.csv | 245 +++++++++++++++
 .../dynamo_eager_huggingface_inference.csv    | 185 ++++++++++++
 .../dynamo_eager_huggingface_training.csv     | 181 +++++++++++
 .../dynamo_eager_timm_inference.csv           | 245 +++++++++++++++
 .../dynamo_eager_timm_training.csv            | 241 +++++++++++++++
 .../dynamo_eager_torchbench_inference.csv     | 285 ++++++++++++++++++
 .../dynamo_eager_torchbench_training.csv      | 245 +++++++++++++++
 .../ci_expected_accuracy/update_expected.py   |  28 +-
 benchmarks/dynamo/common.py                   |  97 ------
 21 files changed, 4165 insertions(+), 109 deletions(-)
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_training.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index f2857b841397..b085e427acfd 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -452,8 +452,7 @@ test_single_dynamo_benchmark() {
       "$@" "${partition_flags[@]}" \
       --output "$TEST_REPORTS_DIR/${name}_${suite}.csv"
 
-    if [[ "${TEST_CONFIG}" == *inductor* ]] && [[ "${TEST_CONFIG}" != *cpu_accuracy* ]]; then
-      # other jobs (e.g. periodic, cpu-accuracy) may have different set of expected models.
+    if [[ "${TEST_CONFIG}" != *cpu_accuracy* ]]; then
       python benchmarks/dynamo/check_accuracy.py \
         --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \
         --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv"
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv
new file mode 100644
index 000000000000..349239b058a7
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv
@@ -0,0 +1,185 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,0
+
+
+
+AlbertForQuestionAnswering,pass,0
+
+
+
+AllenaiLongformerBase,pass,4
+
+
+
+BartForCausalLM,pass,0
+
+
+
+BartForConditionalGeneration,pass,0
+
+
+
+BertForMaskedLM,pass,0
+
+
+
+BertForQuestionAnswering,pass,0
+
+
+
+BlenderbotForCausalLM,pass_due_to_skip,0
+
+
+
+BlenderbotSmallForCausalLM,pass,0
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,0
+
+
+
+CamemBert,pass,0
+
+
+
+DebertaForMaskedLM,pass,0
+
+
+
+DebertaForQuestionAnswering,pass,0
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,pass,0
+
+
+
+DistilBertForMaskedLM,pass,0
+
+
+
+DistilBertForQuestionAnswering,pass,0
+
+
+
+DistillGPT2,pass,0
+
+
+
+ElectraForCausalLM,pass,0
+
+
+
+ElectraForQuestionAnswering,pass,0
+
+
+
+GPT2ForSequenceClassification,pass,2
+
+
+
+GoogleFnet,pass,0
+
+
+
+LayoutLMForMaskedLM,pass,0
+
+
+
+LayoutLMForSequenceClassification,pass,2
+
+
+
+M2M100ForConditionalGeneration,pass,0
+
+
+
+MBartForCausalLM,pass,0
+
+
+
+MBartForConditionalGeneration,pass,0
+
+
+
+MT5ForConditionalGeneration,pass,0
+
+
+
+MegatronBertForCausalLM,pass,0
+
+
+
+MegatronBertForQuestionAnswering,pass,0
+
+
+
+MobileBertForMaskedLM,pass,0
+
+
+
+MobileBertForQuestionAnswering,pass,0
+
+
+
+OPTForCausalLM,pass,0
+
+
+
+PLBartForCausalLM,pass,0
+
+
+
+PLBartForConditionalGeneration,pass,0
+
+
+
+PegasusForCausalLM,pass,0
+
+
+
+PegasusForConditionalGeneration,pass,0
+
+
+
+RobertaForCausalLM,pass,0
+
+
+
+RobertaForQuestionAnswering,pass,0
+
+
+
+Speech2Text2ForCausalLM,pass,0
+
+
+
+T5ForConditionalGeneration,pass,0
+
+
+
+T5Small,pass,0
+
+
+
+TrOCRForCausalLM,pass,0
+
+
+
+XGLMForCausalLM,pass,0
+
+
+
+XLNetLMHeadModel,pass,0
+
+
+
+YituTechConvBert,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv
new file mode 100644
index 000000000000..6bccf5082c70
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv
@@ -0,0 +1,181 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,6
+
+
+
+AlbertForQuestionAnswering,pass,6
+
+
+
+AllenaiLongformerBase,pass,10
+
+
+
+BartForCausalLM,pass,14
+
+
+
+BartForConditionalGeneration,pass,26
+
+
+
+BertForMaskedLM,pass,6
+
+
+
+BertForQuestionAnswering,pass,6
+
+
+
+BlenderbotSmallForCausalLM,pass,14
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,26
+
+
+
+CamemBert,pass,6
+
+
+
+DebertaForMaskedLM,pass,6
+
+
+
+DebertaForQuestionAnswering,pass,6
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0
+
+
+
+DistilBertForMaskedLM,pass,6
+
+
+
+DistilBertForQuestionAnswering,pass,6
+
+
+
+DistillGPT2,pass,6
+
+
+
+ElectraForCausalLM,pass,6
+
+
+
+ElectraForQuestionAnswering,pass,6
+
+
+
+GPT2ForSequenceClassification,pass,8
+
+
+
+GoogleFnet,pass,6
+
+
+
+LayoutLMForMaskedLM,pass,6
+
+
+
+LayoutLMForSequenceClassification,pass,8
+
+
+
+M2M100ForConditionalGeneration,pass,6
+
+
+
+MBartForCausalLM,pass,14
+
+
+
+MBartForConditionalGeneration,pass,26
+
+
+
+MT5ForConditionalGeneration,pass,6
+
+
+
+MegatronBertForCausalLM,pass,6
+
+
+
+MegatronBertForQuestionAnswering,pass,6
+
+
+
+MobileBertForMaskedLM,pass,4
+
+
+
+MobileBertForQuestionAnswering,pass,4
+
+
+
+OPTForCausalLM,pass,14
+
+
+
+PLBartForCausalLM,pass,14
+
+
+
+PLBartForConditionalGeneration,pass,31
+
+
+
+PegasusForCausalLM,pass,14
+
+
+
+PegasusForConditionalGeneration,pass,24
+
+
+
+RobertaForCausalLM,pass,6
+
+
+
+RobertaForQuestionAnswering,pass,6
+
+
+
+Speech2Text2ForCausalLM,pass,14
+
+
+
+T5ForConditionalGeneration,pass,6
+
+
+
+T5Small,pass,6
+
+
+
+TrOCRForCausalLM,pass,14
+
+
+
+XGLMForCausalLM,pass,14
+
+
+
+XLNetLMHeadModel,pass,6
+
+
+
+YituTechConvBert,pass,6
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_inference.csv
new file mode 100644
index 000000000000..c889ba0e8d2f
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_inference.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,0
+
+
+
+beit_base_patch16_224,pass,0
+
+
+
+botnet26t_256,pass,0
+
+
+
+cait_m36_384,pass,0
+
+
+
+coat_lite_mini,pass,0
+
+
+
+convit_base,pass,0
+
+
+
+convmixer_768_32,pass,0
+
+
+
+convnext_base,pass,0
+
+
+
+crossvit_9_240,pass,0
+
+
+
+cspdarknet53,pass,0
+
+
+
+deit_base_distilled_patch16_224,pass,0
+
+
+
+dla102,pass,0
+
+
+
+dm_nfnet_f0,pass,0
+
+
+
+dpn107,pass,0
+
+
+
+eca_botnext26ts_256,pass,0
+
+
+
+eca_halonext26ts,pass,0
+
+
+
+ese_vovnet19b_dw,pass,0
+
+
+
+fbnetc_100,pass,0
+
+
+
+fbnetv3_b,pass,0
+
+
+
+gernet_l,pass,0
+
+
+
+ghostnet_100,pass,0
+
+
+
+gluon_inception_v3,pass,0
+
+
+
+gmixer_24_224,pass,0
+
+
+
+gmlp_s16_224,pass,0
+
+
+
+hrnet_w18,pass,0
+
+
+
+inception_v3,pass,0
+
+
+
+jx_nest_base,pass,0
+
+
+
+lcnet_050,pass,0
+
+
+
+levit_128,pass,0
+
+
+
+mixer_b16_224,pass,0
+
+
+
+mixnet_l,pass,0
+
+
+
+mnasnet_100,pass,0
+
+
+
+mobilenetv2_100,pass,0
+
+
+
+mobilenetv3_large_100,pass,0
+
+
+
+mobilevit_s,pass,0
+
+
+
+nfnet_l0,pass,0
+
+
+
+pit_b_224,pass,0
+
+
+
+pnasnet5large,pass,0
+
+
+
+poolformer_m36,pass,0
+
+
+
+regnety_002,pass,0
+
+
+
+repvgg_a2,pass,0
+
+
+
+res2net101_26w_4s,pass,0
+
+
+
+res2net50_14w_8s,pass,0
+
+
+
+res2next50,pass,0
+
+
+
+resmlp_12_224,pass,0
+
+
+
+resnest101e,pass,0
+
+
+
+rexnet_100,pass,0
+
+
+
+sebotnet33ts_256,pass,0
+
+
+
+selecsls42b,pass,0
+
+
+
+spnasnet_100,pass,0
+
+
+
+swin_base_patch4_window7_224,pass,0
+
+
+
+swsl_resnext101_32x16d,pass,0
+
+
+
+tf_efficientnet_b0,pass,0
+
+
+
+tf_mixnet_l,pass,0
+
+
+
+tinynet_a,pass,0
+
+
+
+tnt_s_patch16_224,pass,0
+
+
+
+twins_pcpvt_base,pass,0
+
+
+
+visformer_small,pass,0
+
+
+
+vit_base_patch16_224,pass,0
+
+
+
+volo_d1_224,pass,0
+
+
+
+xcit_large_24_p8_224,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv
new file mode 100644
index 000000000000..d733037b8e88
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv
@@ -0,0 +1,241 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,8
+
+
+
+beit_base_patch16_224,pass,8
+
+
+
+botnet26t_256,pass,8
+
+
+
+coat_lite_mini,pass,8
+
+
+
+convit_base,pass,8
+
+
+
+convmixer_768_32,pass,6
+
+
+
+convnext_base,pass,8
+
+
+
+crossvit_9_240,pass,8
+
+
+
+cspdarknet53,pass,8
+
+
+
+deit_base_distilled_patch16_224,pass,8
+
+
+
+dla102,pass,8
+
+
+
+dm_nfnet_f0,pass,8
+
+
+
+dpn107,pass,8
+
+
+
+eca_botnext26ts_256,pass,8
+
+
+
+eca_halonext26ts,pass,8
+
+
+
+ese_vovnet19b_dw,pass,8
+
+
+
+fbnetc_100,pass,8
+
+
+
+fbnetv3_b,pass,8
+
+
+
+gernet_l,pass,8
+
+
+
+ghostnet_100,pass,8
+
+
+
+gluon_inception_v3,pass,8
+
+
+
+gmixer_24_224,pass,8
+
+
+
+gmlp_s16_224,pass,8
+
+
+
+hrnet_w18,pass,6
+
+
+
+inception_v3,pass,8
+
+
+
+jx_nest_base,pass,8
+
+
+
+lcnet_050,fail_accuracy,8
+
+
+
+levit_128,pass,8
+
+
+
+mixer_b16_224,pass,8
+
+
+
+mixnet_l,pass,8
+
+
+
+mnasnet_100,pass,8
+
+
+
+mobilenetv2_100,pass,8
+
+
+
+mobilenetv3_large_100,pass,8
+
+
+
+mobilevit_s,pass,8
+
+
+
+nfnet_l0,pass,8
+
+
+
+pit_b_224,pass,8
+
+
+
+pnasnet5large,pass,6
+
+
+
+poolformer_m36,pass,8
+
+
+
+regnety_002,pass,8
+
+
+
+repvgg_a2,pass,8
+
+
+
+res2net101_26w_4s,pass,8
+
+
+
+res2net50_14w_8s,pass,8
+
+
+
+res2next50,pass,8
+
+
+
+resmlp_12_224,pass,8
+
+
+
+resnest101e,pass,8
+
+
+
+rexnet_100,pass,8
+
+
+
+sebotnet33ts_256,pass,8
+
+
+
+selecsls42b,pass,8
+
+
+
+spnasnet_100,pass,8
+
+
+
+swin_base_patch4_window7_224,pass,8
+
+
+
+swsl_resnext101_32x16d,pass,8
+
+
+
+tf_efficientnet_b0,pass,8
+
+
+
+tf_mixnet_l,pass,8
+
+
+
+tinynet_a,pass,8
+
+
+
+tnt_s_patch16_224,pass,8
+
+
+
+twins_pcpvt_base,pass,8
+
+
+
+visformer_small,pass,8
+
+
+
+vit_base_patch16_224,pass,8
+
+
+
+volo_d1_224,pass,8
+
+
+
+xcit_large_24_p8_224,pass_due_to_skip,8
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
new file mode 100644
index 000000000000..0854562b587a
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
@@ -0,0 +1,285 @@
+name,accuracy,graph_breaks
+
+
+
+BERT_pytorch,pass,0
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+DALLE2_pytorch,fail_to_run,21
+
+
+
+LearningToPaint,pass,0
+
+
+
+Super_SloMo,pass,0
+
+
+
+alexnet,pass,0
+
+
+
+basic_gnn_edgecnn,pass,0
+
+
+
+basic_gnn_gcn,pass,6
+
+
+
+basic_gnn_gin,pass,0
+
+
+
+basic_gnn_sage,pass,0
+
+
+
+cm3leon_generate,pass,4
+
+
+
+dcgan,pass,0
+
+
+
+densenet121,pass,0
+
+
+
+detectron2_fcos_r_50_fpn,pass,41
+
+
+
+dlrm,pass,0
+
+
+
+doctr_det_predictor,pass,5
+
+
+
+doctr_reco_predictor,pass,4
+
+
+
+drq,pass,0
+
+
+
+fastNLP_Bert,pass,4
+
+
+
+functorch_dp_cifar10,pass,0
+
+
+
+functorch_maml_omniglot,pass,0
+
+
+
+hf_Albert,pass,0
+
+
+
+hf_Bart,pass,0
+
+
+
+hf_BigBird,pass,0
+
+
+
+hf_DistilBert,pass,0
+
+
+
+hf_GPT2,pass,0
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,5
+
+
+
+hf_T5,pass,0
+
+
+
+hf_T5_generate,pass,20
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+hf_Whisper,pass,0
+
+
+
+lennard_jones,pass,0
+
+
+
+llama,pass,0
+
+
+
+maml_omniglot,pass,0
+
+
+
+mnasnet1_0,pass,0
+
+
+
+mobilenet_v2,pass,0
+
+
+
+mobilenet_v3_large,pass,0
+
+
+
+moco,pass,11
+
+
+
+nanogpt,pass,0
+
+
+
+nvidia_deeprecommender,pass,0
+
+
+
+opacus_cifar10,pass,0
+
+
+
+phlippe_densenet,pass,0
+
+
+
+phlippe_resnet,pass,0
+
+
+
+pyhpc_equation_of_state,pass,0
+
+
+
+pyhpc_isoneutral_mixing,pass,0
+
+
+
+pyhpc_turbulent_kinetic_energy,pass,0
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,0
+
+
+
+pytorch_stargan,pass,0
+
+
+
+pytorch_unet,pass,0
+
+
+
+resnet152,pass,0
+
+
+
+resnet18,pass,0
+
+
+
+resnet50,pass,0
+
+
+
+resnext50_32x4d,pass,0
+
+
+
+sam,pass,0
+
+
+
+shufflenet_v2_x1_0,pass,0
+
+
+
+soft_actor_critic,pass,0
+
+
+
+speech_transformer,pass,10
+
+
+
+squeezenet1_1,pass,0
+
+
+
+stable_diffusion_text_encoder,pass,0
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,0
+
+
+
+timm_regnet,pass,0
+
+
+
+timm_resnest,pass,0
+
+
+
+timm_vision_transformer,pass,0
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,0
+
+
+
+tts_angular,pass,2
+
+
+
+vgg16,pass,0
+
+
+
+vision_maskrcnn,pass,17
+
+
+
+yolov3,pass,2
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv
new file mode 100644
index 000000000000..aa7e8fab9eee
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+torchrec_dlrm,pass,6
+
+
+
+BERT_pytorch,pass,8
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+LearningToPaint,pass,8
+
+
+
+Super_SloMo,pass,8
+
+
+
+alexnet,pass,8
+
+
+
+basic_gnn_edgecnn,pass,23
+
+
+
+basic_gnn_gcn,pass,14
+
+
+
+basic_gnn_gin,pass,8
+
+
+
+basic_gnn_sage,pass,8
+
+
+
+dcgan,pass,8
+
+
+
+demucs,pass,11
+
+
+
+densenet121,pass,8
+
+
+
+dlrm,pass,8
+
+
+
+drq,pass,7
+
+
+
+fastNLP_Bert,pass,12
+
+
+
+functorch_dp_cifar10,pass,8
+
+
+
+functorch_maml_omniglot,pass,8
+
+
+
+hf_Albert,pass,7
+
+
+
+hf_Bart,pass,7
+
+
+
+hf_BigBird,pass,7
+
+
+
+hf_DistilBert,pass,7
+
+
+
+hf_GPT2,pass,7
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,27
+
+
+
+hf_T5_base,OOM,7
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+hf_Whisper,pass,7
+
+
+
+lennard_jones,pass,8
+
+
+
+maml_omniglot,pass,8
+
+
+
+mnasnet1_0,pass,8
+
+
+
+mobilenet_v2,pass,8
+
+
+
+mobilenet_v3_large,pass,8
+
+
+
+moco,pass,19
+
+
+
+nanogpt,pass,8
+
+
+
+nvidia_deeprecommender,pass,8
+
+
+
+phlippe_densenet,pass,8
+
+
+
+phlippe_resnet,pass,8
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,8
+
+
+
+pytorch_stargan,pass,8
+
+
+
+pytorch_unet,pass,8
+
+
+
+resnet152,pass,8
+
+
+
+resnet18,pass,8
+
+
+
+resnet50,pass,8
+
+
+
+resnext50_32x4d,pass,8
+
+
+
+shufflenet_v2_x1_0,pass,8
+
+
+
+soft_actor_critic,pass,7
+
+
+
+speech_transformer,pass,18
+
+
+
+squeezenet1_1,pass,8
+
+
+
+stable_diffusion_text_encoder,pass,7
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,8
+
+
+
+timm_regnet,pass,8
+
+
+
+timm_resnest,pass,8
+
+
+
+timm_vision_transformer,pass,8
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,8
+
+
+
+tts_angular,pass,10
+
+
+
+vgg16,pass,8
+
+
+
+vision_maskrcnn,pass,37
+
+
+
+yolov3,pass,10
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv
new file mode 100644
index 000000000000..349239b058a7
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv
@@ -0,0 +1,185 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,0
+
+
+
+AlbertForQuestionAnswering,pass,0
+
+
+
+AllenaiLongformerBase,pass,4
+
+
+
+BartForCausalLM,pass,0
+
+
+
+BartForConditionalGeneration,pass,0
+
+
+
+BertForMaskedLM,pass,0
+
+
+
+BertForQuestionAnswering,pass,0
+
+
+
+BlenderbotForCausalLM,pass_due_to_skip,0
+
+
+
+BlenderbotSmallForCausalLM,pass,0
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,0
+
+
+
+CamemBert,pass,0
+
+
+
+DebertaForMaskedLM,pass,0
+
+
+
+DebertaForQuestionAnswering,pass,0
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,pass,0
+
+
+
+DistilBertForMaskedLM,pass,0
+
+
+
+DistilBertForQuestionAnswering,pass,0
+
+
+
+DistillGPT2,pass,0
+
+
+
+ElectraForCausalLM,pass,0
+
+
+
+ElectraForQuestionAnswering,pass,0
+
+
+
+GPT2ForSequenceClassification,pass,2
+
+
+
+GoogleFnet,pass,0
+
+
+
+LayoutLMForMaskedLM,pass,0
+
+
+
+LayoutLMForSequenceClassification,pass,2
+
+
+
+M2M100ForConditionalGeneration,pass,0
+
+
+
+MBartForCausalLM,pass,0
+
+
+
+MBartForConditionalGeneration,pass,0
+
+
+
+MT5ForConditionalGeneration,pass,0
+
+
+
+MegatronBertForCausalLM,pass,0
+
+
+
+MegatronBertForQuestionAnswering,pass,0
+
+
+
+MobileBertForMaskedLM,pass,0
+
+
+
+MobileBertForQuestionAnswering,pass,0
+
+
+
+OPTForCausalLM,pass,0
+
+
+
+PLBartForCausalLM,pass,0
+
+
+
+PLBartForConditionalGeneration,pass,0
+
+
+
+PegasusForCausalLM,pass,0
+
+
+
+PegasusForConditionalGeneration,pass,0
+
+
+
+RobertaForCausalLM,pass,0
+
+
+
+RobertaForQuestionAnswering,pass,0
+
+
+
+Speech2Text2ForCausalLM,pass,0
+
+
+
+T5ForConditionalGeneration,pass,0
+
+
+
+T5Small,pass,0
+
+
+
+TrOCRForCausalLM,pass,0
+
+
+
+XGLMForCausalLM,pass,0
+
+
+
+XLNetLMHeadModel,pass,0
+
+
+
+YituTechConvBert,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv
new file mode 100644
index 000000000000..6bccf5082c70
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv
@@ -0,0 +1,181 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,6
+
+
+
+AlbertForQuestionAnswering,pass,6
+
+
+
+AllenaiLongformerBase,pass,10
+
+
+
+BartForCausalLM,pass,14
+
+
+
+BartForConditionalGeneration,pass,26
+
+
+
+BertForMaskedLM,pass,6
+
+
+
+BertForQuestionAnswering,pass,6
+
+
+
+BlenderbotSmallForCausalLM,pass,14
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,26
+
+
+
+CamemBert,pass,6
+
+
+
+DebertaForMaskedLM,pass,6
+
+
+
+DebertaForQuestionAnswering,pass,6
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0
+
+
+
+DistilBertForMaskedLM,pass,6
+
+
+
+DistilBertForQuestionAnswering,pass,6
+
+
+
+DistillGPT2,pass,6
+
+
+
+ElectraForCausalLM,pass,6
+
+
+
+ElectraForQuestionAnswering,pass,6
+
+
+
+GPT2ForSequenceClassification,pass,8
+
+
+
+GoogleFnet,pass,6
+
+
+
+LayoutLMForMaskedLM,pass,6
+
+
+
+LayoutLMForSequenceClassification,pass,8
+
+
+
+M2M100ForConditionalGeneration,pass,6
+
+
+
+MBartForCausalLM,pass,14
+
+
+
+MBartForConditionalGeneration,pass,26
+
+
+
+MT5ForConditionalGeneration,pass,6
+
+
+
+MegatronBertForCausalLM,pass,6
+
+
+
+MegatronBertForQuestionAnswering,pass,6
+
+
+
+MobileBertForMaskedLM,pass,4
+
+
+
+MobileBertForQuestionAnswering,pass,4
+
+
+
+OPTForCausalLM,pass,14
+
+
+
+PLBartForCausalLM,pass,14
+
+
+
+PLBartForConditionalGeneration,pass,31
+
+
+
+PegasusForCausalLM,pass,14
+
+
+
+PegasusForConditionalGeneration,pass,24
+
+
+
+RobertaForCausalLM,pass,6
+
+
+
+RobertaForQuestionAnswering,pass,6
+
+
+
+Speech2Text2ForCausalLM,pass,14
+
+
+
+T5ForConditionalGeneration,pass,6
+
+
+
+T5Small,pass,6
+
+
+
+TrOCRForCausalLM,pass,14
+
+
+
+XGLMForCausalLM,pass,14
+
+
+
+XLNetLMHeadModel,pass,6
+
+
+
+YituTechConvBert,pass,6
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_inference.csv
new file mode 100644
index 000000000000..c889ba0e8d2f
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_inference.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,0
+
+
+
+beit_base_patch16_224,pass,0
+
+
+
+botnet26t_256,pass,0
+
+
+
+cait_m36_384,pass,0
+
+
+
+coat_lite_mini,pass,0
+
+
+
+convit_base,pass,0
+
+
+
+convmixer_768_32,pass,0
+
+
+
+convnext_base,pass,0
+
+
+
+crossvit_9_240,pass,0
+
+
+
+cspdarknet53,pass,0
+
+
+
+deit_base_distilled_patch16_224,pass,0
+
+
+
+dla102,pass,0
+
+
+
+dm_nfnet_f0,pass,0
+
+
+
+dpn107,pass,0
+
+
+
+eca_botnext26ts_256,pass,0
+
+
+
+eca_halonext26ts,pass,0
+
+
+
+ese_vovnet19b_dw,pass,0
+
+
+
+fbnetc_100,pass,0
+
+
+
+fbnetv3_b,pass,0
+
+
+
+gernet_l,pass,0
+
+
+
+ghostnet_100,pass,0
+
+
+
+gluon_inception_v3,pass,0
+
+
+
+gmixer_24_224,pass,0
+
+
+
+gmlp_s16_224,pass,0
+
+
+
+hrnet_w18,pass,0
+
+
+
+inception_v3,pass,0
+
+
+
+jx_nest_base,pass,0
+
+
+
+lcnet_050,pass,0
+
+
+
+levit_128,pass,0
+
+
+
+mixer_b16_224,pass,0
+
+
+
+mixnet_l,pass,0
+
+
+
+mnasnet_100,pass,0
+
+
+
+mobilenetv2_100,pass,0
+
+
+
+mobilenetv3_large_100,pass,0
+
+
+
+mobilevit_s,pass,0
+
+
+
+nfnet_l0,pass,0
+
+
+
+pit_b_224,pass,0
+
+
+
+pnasnet5large,pass,0
+
+
+
+poolformer_m36,pass,0
+
+
+
+regnety_002,pass,0
+
+
+
+repvgg_a2,pass,0
+
+
+
+res2net101_26w_4s,pass,0
+
+
+
+res2net50_14w_8s,pass,0
+
+
+
+res2next50,pass,0
+
+
+
+resmlp_12_224,pass,0
+
+
+
+resnest101e,pass,0
+
+
+
+rexnet_100,pass,0
+
+
+
+sebotnet33ts_256,pass,0
+
+
+
+selecsls42b,pass,0
+
+
+
+spnasnet_100,pass,0
+
+
+
+swin_base_patch4_window7_224,pass,0
+
+
+
+swsl_resnext101_32x16d,pass,0
+
+
+
+tf_efficientnet_b0,pass,0
+
+
+
+tf_mixnet_l,pass,0
+
+
+
+tinynet_a,pass,0
+
+
+
+tnt_s_patch16_224,pass,0
+
+
+
+twins_pcpvt_base,pass,0
+
+
+
+visformer_small,pass,0
+
+
+
+vit_base_patch16_224,pass,0
+
+
+
+volo_d1_224,pass,0
+
+
+
+xcit_large_24_p8_224,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv
new file mode 100644
index 000000000000..d733037b8e88
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv
@@ -0,0 +1,241 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,8
+
+
+
+beit_base_patch16_224,pass,8
+
+
+
+botnet26t_256,pass,8
+
+
+
+coat_lite_mini,pass,8
+
+
+
+convit_base,pass,8
+
+
+
+convmixer_768_32,pass,6
+
+
+
+convnext_base,pass,8
+
+
+
+crossvit_9_240,pass,8
+
+
+
+cspdarknet53,pass,8
+
+
+
+deit_base_distilled_patch16_224,pass,8
+
+
+
+dla102,pass,8
+
+
+
+dm_nfnet_f0,pass,8
+
+
+
+dpn107,pass,8
+
+
+
+eca_botnext26ts_256,pass,8
+
+
+
+eca_halonext26ts,pass,8
+
+
+
+ese_vovnet19b_dw,pass,8
+
+
+
+fbnetc_100,pass,8
+
+
+
+fbnetv3_b,pass,8
+
+
+
+gernet_l,pass,8
+
+
+
+ghostnet_100,pass,8
+
+
+
+gluon_inception_v3,pass,8
+
+
+
+gmixer_24_224,pass,8
+
+
+
+gmlp_s16_224,pass,8
+
+
+
+hrnet_w18,pass,6
+
+
+
+inception_v3,pass,8
+
+
+
+jx_nest_base,pass,8
+
+
+
+lcnet_050,fail_accuracy,8
+
+
+
+levit_128,pass,8
+
+
+
+mixer_b16_224,pass,8
+
+
+
+mixnet_l,pass,8
+
+
+
+mnasnet_100,pass,8
+
+
+
+mobilenetv2_100,pass,8
+
+
+
+mobilenetv3_large_100,pass,8
+
+
+
+mobilevit_s,pass,8
+
+
+
+nfnet_l0,pass,8
+
+
+
+pit_b_224,pass,8
+
+
+
+pnasnet5large,pass,6
+
+
+
+poolformer_m36,pass,8
+
+
+
+regnety_002,pass,8
+
+
+
+repvgg_a2,pass,8
+
+
+
+res2net101_26w_4s,pass,8
+
+
+
+res2net50_14w_8s,pass,8
+
+
+
+res2next50,pass,8
+
+
+
+resmlp_12_224,pass,8
+
+
+
+resnest101e,pass,8
+
+
+
+rexnet_100,pass,8
+
+
+
+sebotnet33ts_256,pass,8
+
+
+
+selecsls42b,pass,8
+
+
+
+spnasnet_100,pass,8
+
+
+
+swin_base_patch4_window7_224,pass,8
+
+
+
+swsl_resnext101_32x16d,pass,8
+
+
+
+tf_efficientnet_b0,pass,8
+
+
+
+tf_mixnet_l,pass,8
+
+
+
+tinynet_a,pass,8
+
+
+
+tnt_s_patch16_224,pass,8
+
+
+
+twins_pcpvt_base,pass,8
+
+
+
+visformer_small,pass,8
+
+
+
+vit_base_patch16_224,pass,8
+
+
+
+volo_d1_224,pass,8
+
+
+
+xcit_large_24_p8_224,pass_due_to_skip,8
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
new file mode 100644
index 000000000000..3d975ef1fe68
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
@@ -0,0 +1,285 @@
+name,accuracy,graph_breaks
+
+
+
+BERT_pytorch,pass,0
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+DALLE2_pytorch,fail_to_run,21
+
+
+
+LearningToPaint,pass,0
+
+
+
+Super_SloMo,pass,0
+
+
+
+alexnet,pass,0
+
+
+
+basic_gnn_edgecnn,pass,0
+
+
+
+basic_gnn_gcn,pass,6
+
+
+
+basic_gnn_gin,pass,0
+
+
+
+basic_gnn_sage,pass,0
+
+
+
+cm3leon_generate,pass,4
+
+
+
+dcgan,pass,0
+
+
+
+densenet121,pass,0
+
+
+
+detectron2_fcos_r_50_fpn,infra_error,0
+
+
+
+dlrm,pass,0
+
+
+
+doctr_det_predictor,pass,5
+
+
+
+doctr_reco_predictor,pass,4
+
+
+
+drq,pass,0
+
+
+
+fastNLP_Bert,pass,4
+
+
+
+functorch_dp_cifar10,pass,0
+
+
+
+functorch_maml_omniglot,pass,0
+
+
+
+hf_Albert,pass,0
+
+
+
+hf_Bart,pass,0
+
+
+
+hf_BigBird,fail_to_run,0
+
+
+
+hf_DistilBert,pass,0
+
+
+
+hf_GPT2,pass,0
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,5
+
+
+
+hf_T5,pass,0
+
+
+
+hf_T5_generate,fail_to_run,10
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+hf_Whisper,pass,0
+
+
+
+lennard_jones,pass,0
+
+
+
+llama,pass,0
+
+
+
+maml_omniglot,pass,0
+
+
+
+mnasnet1_0,pass,0
+
+
+
+mobilenet_v2,pass,0
+
+
+
+mobilenet_v3_large,pass,0
+
+
+
+moco,pass,11
+
+
+
+nanogpt,pass,0
+
+
+
+nvidia_deeprecommender,pass,0
+
+
+
+opacus_cifar10,pass,0
+
+
+
+phlippe_densenet,pass,0
+
+
+
+phlippe_resnet,pass,0
+
+
+
+pyhpc_equation_of_state,pass,0
+
+
+
+pyhpc_isoneutral_mixing,pass,0
+
+
+
+pyhpc_turbulent_kinetic_energy,infra_error,0
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,0
+
+
+
+pytorch_stargan,pass,0
+
+
+
+pytorch_unet,pass,0
+
+
+
+resnet152,pass,0
+
+
+
+resnet18,pass,0
+
+
+
+resnet50,pass,0
+
+
+
+resnext50_32x4d,pass,0
+
+
+
+sam,pass,0
+
+
+
+shufflenet_v2_x1_0,pass,0
+
+
+
+soft_actor_critic,pass,0
+
+
+
+speech_transformer,pass,10
+
+
+
+squeezenet1_1,pass,0
+
+
+
+stable_diffusion_text_encoder,pass,0
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,0
+
+
+
+timm_regnet,pass,0
+
+
+
+timm_resnest,pass,0
+
+
+
+timm_vision_transformer,pass,0
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,0
+
+
+
+tts_angular,pass,2
+
+
+
+vgg16,pass,0
+
+
+
+vision_maskrcnn,pass,17
+
+
+
+yolov3,pass,2
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv
new file mode 100644
index 000000000000..3440767110d8
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+torchrec_dlrm,infra_error,0
+
+
+
+BERT_pytorch,pass,8
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+LearningToPaint,pass,8
+
+
+
+Super_SloMo,pass,8
+
+
+
+alexnet,pass,8
+
+
+
+basic_gnn_edgecnn,pass,23
+
+
+
+basic_gnn_gcn,pass,14
+
+
+
+basic_gnn_gin,pass,8
+
+
+
+basic_gnn_sage,pass,8
+
+
+
+dcgan,pass,8
+
+
+
+demucs,pass,11
+
+
+
+densenet121,pass,8
+
+
+
+dlrm,pass,8
+
+
+
+drq,pass,7
+
+
+
+fastNLP_Bert,pass,12
+
+
+
+functorch_dp_cifar10,pass,8
+
+
+
+functorch_maml_omniglot,pass,8
+
+
+
+hf_Albert,pass,7
+
+
+
+hf_Bart,pass,7
+
+
+
+hf_BigBird,fail_to_run,4
+
+
+
+hf_DistilBert,pass,7
+
+
+
+hf_GPT2,pass,7
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,27
+
+
+
+hf_T5_base,OOM,7
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+hf_Whisper,pass,7
+
+
+
+lennard_jones,pass,8
+
+
+
+maml_omniglot,pass,8
+
+
+
+mnasnet1_0,pass,8
+
+
+
+mobilenet_v2,pass,8
+
+
+
+mobilenet_v3_large,pass,8
+
+
+
+moco,pass,19
+
+
+
+nanogpt,pass,8
+
+
+
+nvidia_deeprecommender,pass,8
+
+
+
+phlippe_densenet,pass,8
+
+
+
+phlippe_resnet,pass,8
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,8
+
+
+
+pytorch_stargan,pass,8
+
+
+
+pytorch_unet,pass,8
+
+
+
+resnet152,pass,8
+
+
+
+resnet18,pass,8
+
+
+
+resnet50,pass,8
+
+
+
+resnext50_32x4d,pass,8
+
+
+
+shufflenet_v2_x1_0,pass,8
+
+
+
+soft_actor_critic,pass,7
+
+
+
+speech_transformer,infra_error,0
+
+
+
+squeezenet1_1,pass,8
+
+
+
+stable_diffusion_text_encoder,pass,7
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,8
+
+
+
+timm_regnet,pass,8
+
+
+
+timm_resnest,pass,8
+
+
+
+timm_vision_transformer,pass,8
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,8
+
+
+
+tts_angular,pass,10
+
+
+
+vgg16,pass,8
+
+
+
+vision_maskrcnn,pass,37
+
+
+
+yolov3,pass,10
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv
new file mode 100644
index 000000000000..349239b058a7
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv
@@ -0,0 +1,185 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,0
+
+
+
+AlbertForQuestionAnswering,pass,0
+
+
+
+AllenaiLongformerBase,pass,4
+
+
+
+BartForCausalLM,pass,0
+
+
+
+BartForConditionalGeneration,pass,0
+
+
+
+BertForMaskedLM,pass,0
+
+
+
+BertForQuestionAnswering,pass,0
+
+
+
+BlenderbotForCausalLM,pass_due_to_skip,0
+
+
+
+BlenderbotSmallForCausalLM,pass,0
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,0
+
+
+
+CamemBert,pass,0
+
+
+
+DebertaForMaskedLM,pass,0
+
+
+
+DebertaForQuestionAnswering,pass,0
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,pass,0
+
+
+
+DistilBertForMaskedLM,pass,0
+
+
+
+DistilBertForQuestionAnswering,pass,0
+
+
+
+DistillGPT2,pass,0
+
+
+
+ElectraForCausalLM,pass,0
+
+
+
+ElectraForQuestionAnswering,pass,0
+
+
+
+GPT2ForSequenceClassification,pass,2
+
+
+
+GoogleFnet,pass,0
+
+
+
+LayoutLMForMaskedLM,pass,0
+
+
+
+LayoutLMForSequenceClassification,pass,2
+
+
+
+M2M100ForConditionalGeneration,pass,0
+
+
+
+MBartForCausalLM,pass,0
+
+
+
+MBartForConditionalGeneration,pass,0
+
+
+
+MT5ForConditionalGeneration,pass,0
+
+
+
+MegatronBertForCausalLM,pass,0
+
+
+
+MegatronBertForQuestionAnswering,pass,0
+
+
+
+MobileBertForMaskedLM,pass,0
+
+
+
+MobileBertForQuestionAnswering,pass,0
+
+
+
+OPTForCausalLM,pass,0
+
+
+
+PLBartForCausalLM,pass,0
+
+
+
+PLBartForConditionalGeneration,pass,0
+
+
+
+PegasusForCausalLM,pass,0
+
+
+
+PegasusForConditionalGeneration,pass,0
+
+
+
+RobertaForCausalLM,pass,0
+
+
+
+RobertaForQuestionAnswering,pass,0
+
+
+
+Speech2Text2ForCausalLM,pass,0
+
+
+
+T5ForConditionalGeneration,pass,0
+
+
+
+T5Small,pass,0
+
+
+
+TrOCRForCausalLM,pass,0
+
+
+
+XGLMForCausalLM,pass,0
+
+
+
+XLNetLMHeadModel,pass,0
+
+
+
+YituTechConvBert,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv
new file mode 100644
index 000000000000..6bccf5082c70
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv
@@ -0,0 +1,181 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,6
+
+
+
+AlbertForQuestionAnswering,pass,6
+
+
+
+AllenaiLongformerBase,pass,10
+
+
+
+BartForCausalLM,pass,14
+
+
+
+BartForConditionalGeneration,pass,26
+
+
+
+BertForMaskedLM,pass,6
+
+
+
+BertForQuestionAnswering,pass,6
+
+
+
+BlenderbotSmallForCausalLM,pass,14
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,26
+
+
+
+CamemBert,pass,6
+
+
+
+DebertaForMaskedLM,pass,6
+
+
+
+DebertaForQuestionAnswering,pass,6
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0
+
+
+
+DistilBertForMaskedLM,pass,6
+
+
+
+DistilBertForQuestionAnswering,pass,6
+
+
+
+DistillGPT2,pass,6
+
+
+
+ElectraForCausalLM,pass,6
+
+
+
+ElectraForQuestionAnswering,pass,6
+
+
+
+GPT2ForSequenceClassification,pass,8
+
+
+
+GoogleFnet,pass,6
+
+
+
+LayoutLMForMaskedLM,pass,6
+
+
+
+LayoutLMForSequenceClassification,pass,8
+
+
+
+M2M100ForConditionalGeneration,pass,6
+
+
+
+MBartForCausalLM,pass,14
+
+
+
+MBartForConditionalGeneration,pass,26
+
+
+
+MT5ForConditionalGeneration,pass,6
+
+
+
+MegatronBertForCausalLM,pass,6
+
+
+
+MegatronBertForQuestionAnswering,pass,6
+
+
+
+MobileBertForMaskedLM,pass,4
+
+
+
+MobileBertForQuestionAnswering,pass,4
+
+
+
+OPTForCausalLM,pass,14
+
+
+
+PLBartForCausalLM,pass,14
+
+
+
+PLBartForConditionalGeneration,pass,31
+
+
+
+PegasusForCausalLM,pass,14
+
+
+
+PegasusForConditionalGeneration,pass,24
+
+
+
+RobertaForCausalLM,pass,6
+
+
+
+RobertaForQuestionAnswering,pass,6
+
+
+
+Speech2Text2ForCausalLM,pass,14
+
+
+
+T5ForConditionalGeneration,pass,6
+
+
+
+T5Small,pass,6
+
+
+
+TrOCRForCausalLM,pass,14
+
+
+
+XGLMForCausalLM,pass,14
+
+
+
+XLNetLMHeadModel,pass,6
+
+
+
+YituTechConvBert,pass,6
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_inference.csv
new file mode 100644
index 000000000000..c889ba0e8d2f
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_inference.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,0
+
+
+
+beit_base_patch16_224,pass,0
+
+
+
+botnet26t_256,pass,0
+
+
+
+cait_m36_384,pass,0
+
+
+
+coat_lite_mini,pass,0
+
+
+
+convit_base,pass,0
+
+
+
+convmixer_768_32,pass,0
+
+
+
+convnext_base,pass,0
+
+
+
+crossvit_9_240,pass,0
+
+
+
+cspdarknet53,pass,0
+
+
+
+deit_base_distilled_patch16_224,pass,0
+
+
+
+dla102,pass,0
+
+
+
+dm_nfnet_f0,pass,0
+
+
+
+dpn107,pass,0
+
+
+
+eca_botnext26ts_256,pass,0
+
+
+
+eca_halonext26ts,pass,0
+
+
+
+ese_vovnet19b_dw,pass,0
+
+
+
+fbnetc_100,pass,0
+
+
+
+fbnetv3_b,pass,0
+
+
+
+gernet_l,pass,0
+
+
+
+ghostnet_100,pass,0
+
+
+
+gluon_inception_v3,pass,0
+
+
+
+gmixer_24_224,pass,0
+
+
+
+gmlp_s16_224,pass,0
+
+
+
+hrnet_w18,pass,0
+
+
+
+inception_v3,pass,0
+
+
+
+jx_nest_base,pass,0
+
+
+
+lcnet_050,pass,0
+
+
+
+levit_128,pass,0
+
+
+
+mixer_b16_224,pass,0
+
+
+
+mixnet_l,pass,0
+
+
+
+mnasnet_100,pass,0
+
+
+
+mobilenetv2_100,pass,0
+
+
+
+mobilenetv3_large_100,pass,0
+
+
+
+mobilevit_s,pass,0
+
+
+
+nfnet_l0,pass,0
+
+
+
+pit_b_224,pass,0
+
+
+
+pnasnet5large,pass,0
+
+
+
+poolformer_m36,pass,0
+
+
+
+regnety_002,pass,0
+
+
+
+repvgg_a2,pass,0
+
+
+
+res2net101_26w_4s,pass,0
+
+
+
+res2net50_14w_8s,pass,0
+
+
+
+res2next50,pass,0
+
+
+
+resmlp_12_224,pass,0
+
+
+
+resnest101e,pass,0
+
+
+
+rexnet_100,pass,0
+
+
+
+sebotnet33ts_256,pass,0
+
+
+
+selecsls42b,pass,0
+
+
+
+spnasnet_100,pass,0
+
+
+
+swin_base_patch4_window7_224,pass,0
+
+
+
+swsl_resnext101_32x16d,pass,0
+
+
+
+tf_efficientnet_b0,pass,0
+
+
+
+tf_mixnet_l,pass,0
+
+
+
+tinynet_a,pass,0
+
+
+
+tnt_s_patch16_224,pass,0
+
+
+
+twins_pcpvt_base,pass,0
+
+
+
+visformer_small,pass,0
+
+
+
+vit_base_patch16_224,pass,0
+
+
+
+volo_d1_224,pass,0
+
+
+
+xcit_large_24_p8_224,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_training.csv
new file mode 100644
index 000000000000..0f52a123eb78
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_timm_training.csv
@@ -0,0 +1,241 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,8
+
+
+
+beit_base_patch16_224,pass,8
+
+
+
+botnet26t_256,pass,8
+
+
+
+coat_lite_mini,pass,8
+
+
+
+convit_base,pass,8
+
+
+
+convmixer_768_32,pass,6
+
+
+
+convnext_base,pass,8
+
+
+
+crossvit_9_240,pass,8
+
+
+
+cspdarknet53,pass,8
+
+
+
+deit_base_distilled_patch16_224,pass,8
+
+
+
+dla102,pass,8
+
+
+
+dm_nfnet_f0,pass,8
+
+
+
+dpn107,pass,8
+
+
+
+eca_botnext26ts_256,pass,8
+
+
+
+eca_halonext26ts,pass,8
+
+
+
+ese_vovnet19b_dw,pass,8
+
+
+
+fbnetc_100,pass,8
+
+
+
+fbnetv3_b,pass,8
+
+
+
+gernet_l,pass,8
+
+
+
+ghostnet_100,pass,8
+
+
+
+gluon_inception_v3,pass,8
+
+
+
+gmixer_24_224,pass,8
+
+
+
+gmlp_s16_224,pass,8
+
+
+
+hrnet_w18,pass,6
+
+
+
+inception_v3,pass,8
+
+
+
+jx_nest_base,pass,8
+
+
+
+lcnet_050,pass,8
+
+
+
+levit_128,pass,8
+
+
+
+mixer_b16_224,pass,8
+
+
+
+mixnet_l,pass,8
+
+
+
+mnasnet_100,pass,8
+
+
+
+mobilenetv2_100,pass,8
+
+
+
+mobilenetv3_large_100,pass,8
+
+
+
+mobilevit_s,pass,8
+
+
+
+nfnet_l0,pass,8
+
+
+
+pit_b_224,pass,8
+
+
+
+pnasnet5large,pass,6
+
+
+
+poolformer_m36,pass,8
+
+
+
+regnety_002,pass,8
+
+
+
+repvgg_a2,pass,8
+
+
+
+res2net101_26w_4s,pass,8
+
+
+
+res2net50_14w_8s,pass,8
+
+
+
+res2next50,pass,8
+
+
+
+resmlp_12_224,pass,8
+
+
+
+resnest101e,pass,8
+
+
+
+rexnet_100,pass,8
+
+
+
+sebotnet33ts_256,pass,8
+
+
+
+selecsls42b,pass,8
+
+
+
+spnasnet_100,pass,8
+
+
+
+swin_base_patch4_window7_224,pass,8
+
+
+
+swsl_resnext101_32x16d,pass,8
+
+
+
+tf_efficientnet_b0,pass,8
+
+
+
+tf_mixnet_l,pass,8
+
+
+
+tinynet_a,pass,8
+
+
+
+tnt_s_patch16_224,pass,8
+
+
+
+twins_pcpvt_base,pass,8
+
+
+
+visformer_small,pass,8
+
+
+
+vit_base_patch16_224,pass,8
+
+
+
+volo_d1_224,pass,8
+
+
+
+xcit_large_24_p8_224,pass_due_to_skip,8
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
new file mode 100644
index 000000000000..0854562b587a
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
@@ -0,0 +1,285 @@
+name,accuracy,graph_breaks
+
+
+
+BERT_pytorch,pass,0
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+DALLE2_pytorch,fail_to_run,21
+
+
+
+LearningToPaint,pass,0
+
+
+
+Super_SloMo,pass,0
+
+
+
+alexnet,pass,0
+
+
+
+basic_gnn_edgecnn,pass,0
+
+
+
+basic_gnn_gcn,pass,6
+
+
+
+basic_gnn_gin,pass,0
+
+
+
+basic_gnn_sage,pass,0
+
+
+
+cm3leon_generate,pass,4
+
+
+
+dcgan,pass,0
+
+
+
+densenet121,pass,0
+
+
+
+detectron2_fcos_r_50_fpn,pass,41
+
+
+
+dlrm,pass,0
+
+
+
+doctr_det_predictor,pass,5
+
+
+
+doctr_reco_predictor,pass,4
+
+
+
+drq,pass,0
+
+
+
+fastNLP_Bert,pass,4
+
+
+
+functorch_dp_cifar10,pass,0
+
+
+
+functorch_maml_omniglot,pass,0
+
+
+
+hf_Albert,pass,0
+
+
+
+hf_Bart,pass,0
+
+
+
+hf_BigBird,pass,0
+
+
+
+hf_DistilBert,pass,0
+
+
+
+hf_GPT2,pass,0
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,5
+
+
+
+hf_T5,pass,0
+
+
+
+hf_T5_generate,pass,20
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+hf_Whisper,pass,0
+
+
+
+lennard_jones,pass,0
+
+
+
+llama,pass,0
+
+
+
+maml_omniglot,pass,0
+
+
+
+mnasnet1_0,pass,0
+
+
+
+mobilenet_v2,pass,0
+
+
+
+mobilenet_v3_large,pass,0
+
+
+
+moco,pass,11
+
+
+
+nanogpt,pass,0
+
+
+
+nvidia_deeprecommender,pass,0
+
+
+
+opacus_cifar10,pass,0
+
+
+
+phlippe_densenet,pass,0
+
+
+
+phlippe_resnet,pass,0
+
+
+
+pyhpc_equation_of_state,pass,0
+
+
+
+pyhpc_isoneutral_mixing,pass,0
+
+
+
+pyhpc_turbulent_kinetic_energy,pass,0
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,0
+
+
+
+pytorch_stargan,pass,0
+
+
+
+pytorch_unet,pass,0
+
+
+
+resnet152,pass,0
+
+
+
+resnet18,pass,0
+
+
+
+resnet50,pass,0
+
+
+
+resnext50_32x4d,pass,0
+
+
+
+sam,pass,0
+
+
+
+shufflenet_v2_x1_0,pass,0
+
+
+
+soft_actor_critic,pass,0
+
+
+
+speech_transformer,pass,10
+
+
+
+squeezenet1_1,pass,0
+
+
+
+stable_diffusion_text_encoder,pass,0
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,0
+
+
+
+timm_regnet,pass,0
+
+
+
+timm_resnest,pass,0
+
+
+
+timm_vision_transformer,pass,0
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,0
+
+
+
+tts_angular,pass,2
+
+
+
+vgg16,pass,0
+
+
+
+vision_maskrcnn,pass,17
+
+
+
+yolov3,pass,2
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv
new file mode 100644
index 000000000000..00bff55b77ec
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+torchrec_dlrm,pass,8
+
+
+
+BERT_pytorch,pass,8
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+LearningToPaint,pass,8
+
+
+
+Super_SloMo,pass,8
+
+
+
+alexnet,pass,8
+
+
+
+basic_gnn_edgecnn,pass,23
+
+
+
+basic_gnn_gcn,pass,14
+
+
+
+basic_gnn_gin,pass,8
+
+
+
+basic_gnn_sage,pass,8
+
+
+
+dcgan,pass,8
+
+
+
+demucs,pass,11
+
+
+
+densenet121,pass,8
+
+
+
+dlrm,pass,8
+
+
+
+drq,pass,7
+
+
+
+fastNLP_Bert,pass,12
+
+
+
+functorch_dp_cifar10,pass,8
+
+
+
+functorch_maml_omniglot,pass,8
+
+
+
+hf_Albert,pass,7
+
+
+
+hf_Bart,pass,7
+
+
+
+hf_BigBird,pass,7
+
+
+
+hf_DistilBert,pass,7
+
+
+
+hf_GPT2,pass,7
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,27
+
+
+
+hf_T5_base,pass,7
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+hf_Whisper,pass,7
+
+
+
+lennard_jones,pass,8
+
+
+
+maml_omniglot,pass,8
+
+
+
+mnasnet1_0,pass,8
+
+
+
+mobilenet_v2,pass,8
+
+
+
+mobilenet_v3_large,pass,8
+
+
+
+moco,pass,19
+
+
+
+nanogpt,pass,8
+
+
+
+nvidia_deeprecommender,pass,8
+
+
+
+phlippe_densenet,pass,8
+
+
+
+phlippe_resnet,pass,8
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,8
+
+
+
+pytorch_stargan,pass,8
+
+
+
+pytorch_unet,pass,8
+
+
+
+resnet152,pass,8
+
+
+
+resnet18,pass,8
+
+
+
+resnet50,pass,8
+
+
+
+resnext50_32x4d,pass,8
+
+
+
+shufflenet_v2_x1_0,pass,8
+
+
+
+soft_actor_critic,pass,7
+
+
+
+speech_transformer,pass,18
+
+
+
+squeezenet1_1,pass,8
+
+
+
+stable_diffusion_text_encoder,pass,7
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,8
+
+
+
+timm_regnet,pass,8
+
+
+
+timm_resnest,pass,8
+
+
+
+timm_vision_transformer,pass,8
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,8
+
+
+
+tts_angular,pass,10
+
+
+
+vgg16,pass,8
+
+
+
+vision_maskrcnn,pass,37
+
+
+
+yolov3,pass,10
diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
index 77640b2af835..0835cd8b024e 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
+++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
@@ -23,6 +23,7 @@
 import sys
 import urllib
 from io import BytesIO
+from itertools import product
 from urllib.request import urlopen
 from zipfile import ZipFile
 
@@ -65,7 +66,10 @@ def parse_test_str(test_str):
 def get_artifacts_urls(results, suites):
     urls = {}
     for r in results:
-        if "inductor" == r["workflowName"] and "test" in r["jobName"]:
+        if (
+            r["workflowName"] in ("inductor", "inductor-periodic")
+            and "test" in r["jobName"]
+        ):
             config_str, test_str = parse_job_name(r["jobName"])
             suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
             workflowId = r["workflowId"]
@@ -140,16 +144,20 @@ def apply_lints(filename):
     args = parser.parse_args()
 
     repo = "pytorch/pytorch"
+
     suites = {
-        "aot_inductor_huggingface",
-        "inductor_huggingface",
-        "dynamic_inductor_huggingface",
-        "aot_inductor_timm",
-        "inductor_timm",
-        "dynamic_inductor_timm",
-        "aot_inductor_torchbench",
-        "inductor_torchbench",
-        "dynamic_inductor_torchbench",
+        f"{a}_{b}"
+        for a, b in product(
+            [
+                "aot_eager",
+                "aot_inductor",
+                "dynamic_aot_eager",
+                "dynamo_eager",
+                "inductor",
+                "dynamic_inductor",
+            ],
+            ["huggingface", "timm", "torchbench"],
+        )
     }
 
     root_path = "benchmarks/dynamo/ci_expected_accuracy/"
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index fee02e094880..17d090ec8221 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -112,86 +112,6 @@ class CI(NamedTuple):
 
 CI_SKIP = collections.defaultdict(list)
 
-
-# Skips for dynamic=False
-
-# Here eager really means dynamo+eager
-CI_SKIP[CI("eager", training=False)] = [
-    # TorchBench
-    "DALLE2_pytorch",  # AttributeError: text_encodings
-    "hf_BigBird",  # fail_accuracy
-    # TypeError: pad_center() takes 1 positional argument but 2 were given
-    "tacotron2",
-    # Huggingface
-    "DebertaV2ForQuestionAnswering",  # OOM
-]
-
-CI_SKIP[CI("eager", training=True)] = [
-    *CI_SKIP[CI("eager", training=False)],
-    # TorchBench
-    "BERT_pytorch",  # accuracy
-    "Background_Matting",  # fp64_OOM
-    "hf_BigBird",  # fp64_OOM
-    "hf_T5_base",  # fp64_OOM
-    "llama",  # Accuracy failed: allclose not within tol=0.001
-    "vision_maskrcnn",  # The size of tensor a (29) must match the size of tensor b (33) (doesn't repro)
-    # Huggingface
-    "XGLMForCausalLM",  # OOM
-    # TIMM
-    "cait_m36_384",  # fp64_OOM
-    "convit_base",  # fp64_OOM
-    "mobilenetv2_100",  # accuracy
-    "xcit_large_24_p8_224",  # fp64_OOM,
-]
-
-CI_SKIP[CI("aot_eager", training=False)] = [
-    *CI_SKIP[CI("eager", training=False)],
-    # all dynamic shapes errors for detectron variants
-    "demucs",  # OOM
-    "detectron2_fasterrcnn_r_101_c4",
-    "detectron2_fasterrcnn_r_101_dc5",
-    "detectron2_fasterrcnn_r_101_fpn",
-    "detectron2_fasterrcnn_r_50_c4",
-    "detectron2_fasterrcnn_r_50_dc5",
-    "detectron2_fasterrcnn_r_50_fpn",
-    "detectron2_fcos_r_50_fpn",
-    "detectron2_maskrcnn_r_101_c4",
-    "detectron2_maskrcnn_r_101_fpn",
-    "detectron2_maskrcnn_r_50_c4",
-    "detectron2_maskrcnn_r_50_fpn",
-    "hf_BigBird",  # OOM
-    "tacotron2",  # AssertionError: Deduped args out of bounds
-    # Huggingface
-    "BartForConditionalGeneration",  # OOM
-    "DebertaV2ForQuestionAnswering",  # OOM
-    # Torchbench
-    "speech_transformer",  # https://github.com/pytorch/pytorch/issues/99893
-    "pyhpc_isoneutral_mixing",  # https://github.com/pytorch/pytorch/issues/99893
-    "pyhpc_turbulent_kinetic_energy",  # https://github.com/pytorch/pytorch/issues/99893
-]
-
-CI_SKIP[CI("aot_eager", training=True)] = [
-    *CI_SKIP[CI("aot_eager", training=False)],
-    # TorchBench
-    "Background_Matting",  # fp64_OOM
-    "hf_T5_base",  # fp64_OOM
-    "mobilenet_v2_quantized_qat",  # fp64_OOM
-    "resnet50_quantized_qat",  # fp64_OOM
-    "pytorch_struct",
-    # Huggingface
-    "MBartForConditionalGeneration",  # OOM
-    "M2M100ForConditionalGeneration",  # OOM
-    "XGLMForCausalLM",  # OOM
-    # TIMM
-    "cait_m36_384",  # fp64_OOM
-    "convit_base",  # fp64_OOM
-    "fbnetv3_b",  # Accuracy (blocks.2.2.bn1.weight.grad)
-    "levit_128",  # Accuracy (patch_embed.0.c.weight.grad)
-    "lcnet_050",  # Accuracy (blocks.1.0.bn2.weight.grad)
-    "sebotnet33ts_256",  # Accuracy (stem.conv1.conv.weight.grad)
-    "xcit_large_24_p8_224",  # fp64_OOM,
-]
-
 CI_SKIP[CI("inductor", training=False, device="cpu")] = [
     # TorchBench
     "drq",  # Need to update torchbench
@@ -227,23 +147,6 @@ class CI(NamedTuple):
     "opacus_cifar10",  # Fails to run https://github.com/pytorch/pytorch/issues/99201
 ]
 
-# Skips for dynamic=True
-
-CI_SKIP[CI("aot_eager", training=False, dynamic=True)] = [
-    *CI_SKIP[CI("aot_eager", training=False)],
-    "vision_maskrcnn",  # accuracy failure on boxes, after https://github.com/pytorch/pytorch/issues/101093
-    # https://github.com/pytorch/pytorch/issues/103760
-    "hf_T5_generate",
-    "hf_Bert",  # Error: RelaxedUnspecConstraint(L['input_ids'].size()[0]) - inferred constant (4)
-]
-
-CI_SKIP[CI("aot_eager", training=True, dynamic=True)] = [
-    *CI_SKIP[CI("aot_eager", training=True)],
-    *CI_SKIP[CI("aot_eager", training=False, dynamic=True)],
-    "llama",  # AssertionError: cannot compute free_symbols of True
-    "torchrec_dlrm",  # RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
-]
-
 CI_SKIP[CI("inductor", training=False, dynamic=True, device="cpu")] = [
     *CI_SKIP[CI("inductor", training=False, device="cpu")],
     "pyhpc_isoneutral_mixing",

From 6ff72607000eb6019650fc44e352de528051e583 Mon Sep 17 00:00:00 2001
From: Bin Bao 
Date: Tue, 21 Nov 2023 13:23:08 -0500
Subject: [PATCH 065/221] [CI] Switch to check against expected result files
 for cpu inductor integration tests (#113668)

Summary: With this, we can completely remove CI_SKIP from common.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113668
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: #113574, #113575, #113446, #113559
---
 .ci/pytorch/test.sh                           |  26 +-
 .github/workflows/inductor.yml                |  20 +-
 .../cpu_inductor_huggingface_inference.csv    | 185 +++++++++++
 .../cpu_inductor_timm_inference.csv           | 245 ++++++++++++++
 .../cpu_inductor_torchbench_inference.csv     | 297 +++++++++++++++++
 ...mic_cpu_inductor_huggingface_inference.csv | 185 +++++++++++
 .../dynamic_cpu_inductor_timm_inference.csv   | 245 ++++++++++++++
 ...amic_cpu_inductor_torchbench_inference.csv | 301 ++++++++++++++++++
 .../ci_expected_accuracy/update_expected.py   |   4 +-
 benchmarks/dynamo/common.py                   |  74 +----
 10 files changed, 1484 insertions(+), 98 deletions(-)
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_timm_inference.csv
 create mode 100644 benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index b085e427acfd..679485dfa0ec 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -332,7 +332,7 @@ if [[ "${TEST_CONFIG}" == *dynamic* ]]; then
   DYNAMO_BENCHMARK_FLAGS+=(--dynamic-shapes --dynamic-batch-only)
 fi
 
-if [[ "${TEST_CONFIG}" == *cpu_accuracy* ]]; then
+if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then
   DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
 else
   DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
@@ -451,18 +451,12 @@ test_single_dynamo_benchmark() {
       "${DYNAMO_BENCHMARK_FLAGS[@]}" \
       "$@" "${partition_flags[@]}" \
       --output "$TEST_REPORTS_DIR/${name}_${suite}.csv"
-
-    if [[ "${TEST_CONFIG}" != *cpu_accuracy* ]]; then
-      python benchmarks/dynamo/check_accuracy.py \
-        --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \
-        --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv"
-      python benchmarks/dynamo/check_graph_breaks.py \
-        --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \
-        --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv"
-    else
-      python benchmarks/dynamo/check_csv.py \
-        -f "$TEST_REPORTS_DIR/${name}_${suite}.csv"
-    fi
+    python benchmarks/dynamo/check_accuracy.py \
+      --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \
+      --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv"
+    python benchmarks/dynamo/check_graph_breaks.py \
+      --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \
+      --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv"
   fi
 }
 
@@ -480,7 +474,7 @@ test_dynamo_benchmark() {
   elif [[ "${TEST_CONFIG}" == *perf* ]]; then
     test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
   else
-    if [[ "${TEST_CONFIG}" == *cpu_accuracy* ]]; then
+    if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then
       test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 "$@"
     elif [[ "${TEST_CONFIG}" == *aot_inductor* ]]; then
       test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@"
@@ -1062,7 +1056,7 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then
   id=$((SHARD_NUMBER-1))
   test_dynamo_benchmark timm_models "$id"
 elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
-  if [[ "${TEST_CONFIG}" == *cpu_accuracy* ]]; then
+  if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then
     install_torchaudio cpu
   else
     install_torchaudio cuda
@@ -1079,7 +1073,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
     checkout_install_torchbench
     # Do this after checkout_install_torchbench to ensure we clobber any
     # nightlies that torchbench may pull in
-    if [[ "${TEST_CONFIG}" != *cpu_accuracy* ]]; then
+    if [[ "${TEST_CONFIG}" != *cpu_inductor* ]]; then
       install_torchrec_and_fbgemm
     fi
     PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id"
diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml
index 141c8e619dc1..06d75f40a6a0 100644
--- a/.github/workflows/inductor.yml
+++ b/.github/workflows/inductor.yml
@@ -88,16 +88,16 @@ jobs:
       docker-image-name: pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks
       test-matrix: |
         { include: [
-          { config: "inductor_huggingface_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
-          { config: "inductor_timm_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
-          { config: "inductor_timm_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
-          { config: "inductor_torchbench_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
-          { config: "inductor_torchbench_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
-          { config: "inductor_huggingface_dynamic_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
-          { config: "inductor_timm_dynamic_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
-          { config: "inductor_timm_dynamic_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
-          { config: "inductor_torchbench_dynamic_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
-          { config: "inductor_torchbench_dynamic_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
+          { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
+          { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
+          { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
+          { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
+          { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
+          { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
+          { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
+          { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
+          { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
+          { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
         ]}
     secrets:
       HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv
new file mode 100644
index 000000000000..349239b058a7
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv
@@ -0,0 +1,185 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,0
+
+
+
+AlbertForQuestionAnswering,pass,0
+
+
+
+AllenaiLongformerBase,pass,4
+
+
+
+BartForCausalLM,pass,0
+
+
+
+BartForConditionalGeneration,pass,0
+
+
+
+BertForMaskedLM,pass,0
+
+
+
+BertForQuestionAnswering,pass,0
+
+
+
+BlenderbotForCausalLM,pass_due_to_skip,0
+
+
+
+BlenderbotSmallForCausalLM,pass,0
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,0
+
+
+
+CamemBert,pass,0
+
+
+
+DebertaForMaskedLM,pass,0
+
+
+
+DebertaForQuestionAnswering,pass,0
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,pass,0
+
+
+
+DistilBertForMaskedLM,pass,0
+
+
+
+DistilBertForQuestionAnswering,pass,0
+
+
+
+DistillGPT2,pass,0
+
+
+
+ElectraForCausalLM,pass,0
+
+
+
+ElectraForQuestionAnswering,pass,0
+
+
+
+GPT2ForSequenceClassification,pass,2
+
+
+
+GoogleFnet,pass,0
+
+
+
+LayoutLMForMaskedLM,pass,0
+
+
+
+LayoutLMForSequenceClassification,pass,2
+
+
+
+M2M100ForConditionalGeneration,pass,0
+
+
+
+MBartForCausalLM,pass,0
+
+
+
+MBartForConditionalGeneration,pass,0
+
+
+
+MT5ForConditionalGeneration,pass,0
+
+
+
+MegatronBertForCausalLM,pass,0
+
+
+
+MegatronBertForQuestionAnswering,pass,0
+
+
+
+MobileBertForMaskedLM,pass,0
+
+
+
+MobileBertForQuestionAnswering,pass,0
+
+
+
+OPTForCausalLM,pass,0
+
+
+
+PLBartForCausalLM,pass,0
+
+
+
+PLBartForConditionalGeneration,pass,0
+
+
+
+PegasusForCausalLM,pass,0
+
+
+
+PegasusForConditionalGeneration,pass,0
+
+
+
+RobertaForCausalLM,pass,0
+
+
+
+RobertaForQuestionAnswering,pass,0
+
+
+
+Speech2Text2ForCausalLM,pass,0
+
+
+
+T5ForConditionalGeneration,pass,0
+
+
+
+T5Small,pass,0
+
+
+
+TrOCRForCausalLM,pass,0
+
+
+
+XGLMForCausalLM,pass,0
+
+
+
+XLNetLMHeadModel,pass,0
+
+
+
+YituTechConvBert,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_inference.csv
new file mode 100644
index 000000000000..dd89a722815d
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_inference.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,0
+
+
+
+beit_base_patch16_224,pass,0
+
+
+
+botnet26t_256,pass,0
+
+
+
+cait_m36_384,infra_error,0
+
+
+
+coat_lite_mini,pass,0
+
+
+
+convit_base,pass,0
+
+
+
+convmixer_768_32,pass,0
+
+
+
+convnext_base,pass,0
+
+
+
+crossvit_9_240,pass,0
+
+
+
+cspdarknet53,pass,0
+
+
+
+deit_base_distilled_patch16_224,pass,0
+
+
+
+dla102,pass,0
+
+
+
+dm_nfnet_f0,pass,0
+
+
+
+dpn107,pass,0
+
+
+
+eca_botnext26ts_256,pass,0
+
+
+
+eca_halonext26ts,pass,0
+
+
+
+ese_vovnet19b_dw,pass,0
+
+
+
+fbnetc_100,pass,0
+
+
+
+fbnetv3_b,pass,0
+
+
+
+gernet_l,pass,0
+
+
+
+ghostnet_100,pass,0
+
+
+
+gluon_inception_v3,pass,0
+
+
+
+gmixer_24_224,pass,0
+
+
+
+gmlp_s16_224,pass,0
+
+
+
+hrnet_w18,pass,0
+
+
+
+inception_v3,pass,0
+
+
+
+jx_nest_base,pass,0
+
+
+
+lcnet_050,pass,0
+
+
+
+levit_128,pass,0
+
+
+
+mixer_b16_224,pass,0
+
+
+
+mixnet_l,pass,0
+
+
+
+mnasnet_100,pass,0
+
+
+
+mobilenetv2_100,pass,0
+
+
+
+mobilenetv3_large_100,pass,0
+
+
+
+mobilevit_s,pass,0
+
+
+
+nfnet_l0,pass,0
+
+
+
+pit_b_224,pass,0
+
+
+
+pnasnet5large,pass,0
+
+
+
+poolformer_m36,pass,0
+
+
+
+regnety_002,pass,0
+
+
+
+repvgg_a2,pass,0
+
+
+
+res2net101_26w_4s,pass,0
+
+
+
+res2net50_14w_8s,pass,0
+
+
+
+res2next50,pass,0
+
+
+
+resmlp_12_224,pass,0
+
+
+
+resnest101e,pass,0
+
+
+
+rexnet_100,pass,0
+
+
+
+sebotnet33ts_256,pass,0
+
+
+
+selecsls42b,pass,0
+
+
+
+spnasnet_100,pass,0
+
+
+
+swin_base_patch4_window7_224,pass,0
+
+
+
+swsl_resnext101_32x16d,pass,0
+
+
+
+tf_efficientnet_b0,pass,0
+
+
+
+tf_mixnet_l,pass,0
+
+
+
+tinynet_a,pass,0
+
+
+
+tnt_s_patch16_224,pass,0
+
+
+
+twins_pcpvt_base,pass,0
+
+
+
+visformer_small,pass,0
+
+
+
+vit_base_patch16_224,pass,0
+
+
+
+volo_d1_224,pass,0
+
+
+
+xcit_large_24_p8_224,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
new file mode 100644
index 000000000000..e8e91d7c83f3
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
@@ -0,0 +1,297 @@
+name,accuracy,graph_breaks
+
+
+
+BERT_pytorch,pass,0
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+LearningToPaint,pass,0
+
+
+
+Super_SloMo,pass,0
+
+
+
+alexnet,pass,0
+
+
+
+basic_gnn_edgecnn,pass,0
+
+
+
+basic_gnn_gcn,pass,6
+
+
+
+basic_gnn_gin,pass,0
+
+
+
+basic_gnn_sage,pass,0
+
+
+
+dcgan,pass,0
+
+
+
+demucs,pass,3
+
+
+
+densenet121,pass,0
+
+
+
+detectron2_fasterrcnn_r_101_c4,pass,52
+
+
+
+detectron2_fasterrcnn_r_101_dc5,pass,52
+
+
+
+detectron2_fasterrcnn_r_101_fpn,pass,56
+
+
+
+detectron2_fasterrcnn_r_50_c4,pass,52
+
+
+
+detectron2_fasterrcnn_r_50_dc5,pass,52
+
+
+
+detectron2_fasterrcnn_r_50_fpn,pass,56
+
+
+
+detectron2_fcos_r_50_fpn,pass,44
+
+
+
+detectron2_maskrcnn_r_101_c4,fail_accuracy,67
+
+
+
+detectron2_maskrcnn_r_101_fpn,pass,74
+
+
+
+detectron2_maskrcnn_r_50_c4,pass,67
+
+
+
+detectron2_maskrcnn_r_50_fpn,pass,74
+
+
+
+dlrm,pass,0
+
+
+
+doctr_det_predictor,pass,5
+
+
+
+doctr_reco_predictor,pass,4
+
+
+
+drq,pass,0
+
+
+
+fastNLP_Bert,pass,4
+
+
+
+functorch_dp_cifar10,pass,0
+
+
+
+functorch_maml_omniglot,pass,0
+
+
+
+hf_Albert,pass,0
+
+
+
+hf_Bart,pass,0
+
+
+
+hf_DistilBert,pass,0
+
+
+
+hf_GPT2,pass,0
+
+
+
+hf_GPT2_large,infra_error,0
+
+
+
+hf_Reformer,pass,5
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+lennard_jones,pass,0
+
+
+
+llama,pass,0
+
+
+
+maml_omniglot,pass,0
+
+
+
+mnasnet1_0,pass,0
+
+
+
+mobilenet_v2,pass,0
+
+
+
+mobilenet_v3_large,pass,0
+
+
+
+nvidia_deeprecommender,pass,0
+
+
+
+opacus_cifar10,pass,0
+
+
+
+phi_1_5,pass,74
+
+
+
+phlippe_densenet,pass,0
+
+
+
+phlippe_resnet,pass,0
+
+
+
+pyhpc_equation_of_state,pass,0
+
+
+
+pyhpc_isoneutral_mixing,pass,0
+
+
+
+pyhpc_turbulent_kinetic_energy,pass,0
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,0
+
+
+
+pytorch_stargan,pass,0
+
+
+
+pytorch_unet,pass,0
+
+
+
+resnet152,pass,0
+
+
+
+resnet18,pass,0
+
+
+
+resnet50,pass,0
+
+
+
+resnext50_32x4d,pass,0
+
+
+
+shufflenet_v2_x1_0,pass,0
+
+
+
+soft_actor_critic,pass,0
+
+
+
+speech_transformer,pass,10
+
+
+
+squeezenet1_1,pass,0
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,0
+
+
+
+timm_nfnet,pass,0
+
+
+
+timm_regnet,pass,0
+
+
+
+timm_resnest,pass,0
+
+
+
+timm_vision_transformer,pass,0
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,0
+
+
+
+tts_angular,pass,2
+
+
+
+vgg16,pass,0
+
+
+
+vision_maskrcnn,pass,29
+
+
+
+yolov3,pass,2
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv
new file mode 100644
index 000000000000..349239b058a7
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv
@@ -0,0 +1,185 @@
+name,accuracy,graph_breaks
+
+
+
+AlbertForMaskedLM,pass,0
+
+
+
+AlbertForQuestionAnswering,pass,0
+
+
+
+AllenaiLongformerBase,pass,4
+
+
+
+BartForCausalLM,pass,0
+
+
+
+BartForConditionalGeneration,pass,0
+
+
+
+BertForMaskedLM,pass,0
+
+
+
+BertForQuestionAnswering,pass,0
+
+
+
+BlenderbotForCausalLM,pass_due_to_skip,0
+
+
+
+BlenderbotSmallForCausalLM,pass,0
+
+
+
+BlenderbotSmallForConditionalGeneration,pass,0
+
+
+
+CamemBert,pass,0
+
+
+
+DebertaForMaskedLM,pass,0
+
+
+
+DebertaForQuestionAnswering,pass,0
+
+
+
+DebertaV2ForMaskedLM,pass_due_to_skip,0
+
+
+
+DebertaV2ForQuestionAnswering,pass,0
+
+
+
+DistilBertForMaskedLM,pass,0
+
+
+
+DistilBertForQuestionAnswering,pass,0
+
+
+
+DistillGPT2,pass,0
+
+
+
+ElectraForCausalLM,pass,0
+
+
+
+ElectraForQuestionAnswering,pass,0
+
+
+
+GPT2ForSequenceClassification,pass,2
+
+
+
+GoogleFnet,pass,0
+
+
+
+LayoutLMForMaskedLM,pass,0
+
+
+
+LayoutLMForSequenceClassification,pass,2
+
+
+
+M2M100ForConditionalGeneration,pass,0
+
+
+
+MBartForCausalLM,pass,0
+
+
+
+MBartForConditionalGeneration,pass,0
+
+
+
+MT5ForConditionalGeneration,pass,0
+
+
+
+MegatronBertForCausalLM,pass,0
+
+
+
+MegatronBertForQuestionAnswering,pass,0
+
+
+
+MobileBertForMaskedLM,pass,0
+
+
+
+MobileBertForQuestionAnswering,pass,0
+
+
+
+OPTForCausalLM,pass,0
+
+
+
+PLBartForCausalLM,pass,0
+
+
+
+PLBartForConditionalGeneration,pass,0
+
+
+
+PegasusForCausalLM,pass,0
+
+
+
+PegasusForConditionalGeneration,pass,0
+
+
+
+RobertaForCausalLM,pass,0
+
+
+
+RobertaForQuestionAnswering,pass,0
+
+
+
+Speech2Text2ForCausalLM,pass,0
+
+
+
+T5ForConditionalGeneration,pass,0
+
+
+
+T5Small,pass,0
+
+
+
+TrOCRForCausalLM,pass,0
+
+
+
+XGLMForCausalLM,pass,0
+
+
+
+XLNetLMHeadModel,pass,0
+
+
+
+YituTechConvBert,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_timm_inference.csv
new file mode 100644
index 000000000000..c889ba0e8d2f
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_timm_inference.csv
@@ -0,0 +1,245 @@
+name,accuracy,graph_breaks
+
+
+
+adv_inception_v3,pass,0
+
+
+
+beit_base_patch16_224,pass,0
+
+
+
+botnet26t_256,pass,0
+
+
+
+cait_m36_384,pass,0
+
+
+
+coat_lite_mini,pass,0
+
+
+
+convit_base,pass,0
+
+
+
+convmixer_768_32,pass,0
+
+
+
+convnext_base,pass,0
+
+
+
+crossvit_9_240,pass,0
+
+
+
+cspdarknet53,pass,0
+
+
+
+deit_base_distilled_patch16_224,pass,0
+
+
+
+dla102,pass,0
+
+
+
+dm_nfnet_f0,pass,0
+
+
+
+dpn107,pass,0
+
+
+
+eca_botnext26ts_256,pass,0
+
+
+
+eca_halonext26ts,pass,0
+
+
+
+ese_vovnet19b_dw,pass,0
+
+
+
+fbnetc_100,pass,0
+
+
+
+fbnetv3_b,pass,0
+
+
+
+gernet_l,pass,0
+
+
+
+ghostnet_100,pass,0
+
+
+
+gluon_inception_v3,pass,0
+
+
+
+gmixer_24_224,pass,0
+
+
+
+gmlp_s16_224,pass,0
+
+
+
+hrnet_w18,pass,0
+
+
+
+inception_v3,pass,0
+
+
+
+jx_nest_base,pass,0
+
+
+
+lcnet_050,pass,0
+
+
+
+levit_128,pass,0
+
+
+
+mixer_b16_224,pass,0
+
+
+
+mixnet_l,pass,0
+
+
+
+mnasnet_100,pass,0
+
+
+
+mobilenetv2_100,pass,0
+
+
+
+mobilenetv3_large_100,pass,0
+
+
+
+mobilevit_s,pass,0
+
+
+
+nfnet_l0,pass,0
+
+
+
+pit_b_224,pass,0
+
+
+
+pnasnet5large,pass,0
+
+
+
+poolformer_m36,pass,0
+
+
+
+regnety_002,pass,0
+
+
+
+repvgg_a2,pass,0
+
+
+
+res2net101_26w_4s,pass,0
+
+
+
+res2net50_14w_8s,pass,0
+
+
+
+res2next50,pass,0
+
+
+
+resmlp_12_224,pass,0
+
+
+
+resnest101e,pass,0
+
+
+
+rexnet_100,pass,0
+
+
+
+sebotnet33ts_256,pass,0
+
+
+
+selecsls42b,pass,0
+
+
+
+spnasnet_100,pass,0
+
+
+
+swin_base_patch4_window7_224,pass,0
+
+
+
+swsl_resnext101_32x16d,pass,0
+
+
+
+tf_efficientnet_b0,pass,0
+
+
+
+tf_mixnet_l,pass,0
+
+
+
+tinynet_a,pass,0
+
+
+
+tnt_s_patch16_224,pass,0
+
+
+
+twins_pcpvt_base,pass,0
+
+
+
+visformer_small,pass,0
+
+
+
+vit_base_patch16_224,pass,0
+
+
+
+volo_d1_224,pass,0
+
+
+
+xcit_large_24_p8_224,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
new file mode 100644
index 000000000000..d92d45f6b0bf
--- /dev/null
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
@@ -0,0 +1,301 @@
+name,accuracy,graph_breaks
+
+
+
+BERT_pytorch,pass,0
+
+
+
+Background_Matting,pass_due_to_skip,0
+
+
+
+LearningToPaint,pass,0
+
+
+
+Super_SloMo,pass,0
+
+
+
+alexnet,pass,0
+
+
+
+basic_gnn_edgecnn,pass,0
+
+
+
+basic_gnn_gcn,pass,6
+
+
+
+basic_gnn_gin,pass,0
+
+
+
+basic_gnn_sage,pass,0
+
+
+
+dcgan,pass,0
+
+
+
+demucs,pass,3
+
+
+
+densenet121,pass,0
+
+
+
+detectron2_fasterrcnn_r_101_c4,infra_error,0
+
+
+
+detectron2_fasterrcnn_r_101_dc5,infra_error,0
+
+
+
+detectron2_fasterrcnn_r_101_fpn,infra_error,0
+
+
+
+detectron2_fasterrcnn_r_50_c4,infra_error,0
+
+
+
+detectron2_fasterrcnn_r_50_dc5,infra_error,0
+
+
+
+detectron2_fasterrcnn_r_50_fpn,infra_error,0
+
+
+
+detectron2_fcos_r_50_fpn,infra_error,0
+
+
+
+detectron2_maskrcnn_r_101_c4,infra_error,0
+
+
+
+detectron2_maskrcnn_r_101_fpn,infra_error,0
+
+
+
+detectron2_maskrcnn_r_50_c4,infra_error,0
+
+
+
+detectron2_maskrcnn_r_50_fpn,infra_error,0
+
+
+
+dlrm,pass,0
+
+
+
+doctr_det_predictor,pass,5
+
+
+
+doctr_reco_predictor,pass,4
+
+
+
+drq,pass,0
+
+
+
+fastNLP_Bert,pass,4
+
+
+
+functorch_dp_cifar10,pass,0
+
+
+
+functorch_maml_omniglot,pass,0
+
+
+
+hf_Albert,pass,0
+
+
+
+hf_Bart,pass,0
+
+
+
+hf_DistilBert,pass,0
+
+
+
+hf_GPT2,pass,0
+
+
+
+hf_GPT2_large,pass_due_to_skip,0
+
+
+
+hf_Reformer,pass,5
+
+
+
+hf_T5_base,pass,0
+
+
+
+hf_T5_large,pass_due_to_skip,0
+
+
+
+lennard_jones,pass,0
+
+
+
+llama,pass,0
+
+
+
+maml_omniglot,pass,0
+
+
+
+mnasnet1_0,pass,0
+
+
+
+mobilenet_v2,pass,0
+
+
+
+mobilenet_v3_large,pass,0
+
+
+
+nvidia_deeprecommender,pass,0
+
+
+
+opacus_cifar10,pass,0
+
+
+
+phi_1_5,pass,74
+
+
+
+phlippe_densenet,pass,0
+
+
+
+phlippe_resnet,pass,0
+
+
+
+pyhpc_equation_of_state,pass,0
+
+
+
+pyhpc_isoneutral_mixing,pass,0
+
+
+
+pyhpc_turbulent_kinetic_energy,infra_error,0
+
+
+
+pytorch_CycleGAN_and_pix2pix,pass,0
+
+
+
+pytorch_stargan,pass,0
+
+
+
+pytorch_unet,pass,0
+
+
+
+resnet152,pass,0
+
+
+
+resnet18,pass,0
+
+
+
+resnet50,pass,0
+
+
+
+resnext50_32x4d,pass,0
+
+
+
+shufflenet_v2_x1_0,pass,0
+
+
+
+soft_actor_critic,pass,0
+
+
+
+speech_transformer,pass,10
+
+
+
+squeezenet1_1,pass,0
+
+
+
+stable_diffusion_unet,pass_due_to_skip,0
+
+
+
+timm_efficientnet,pass,0
+
+
+
+timm_nfnet,pass,0
+
+
+
+timm_regnet,pass,0
+
+
+
+timm_resnest,pass,0
+
+
+
+timm_vision_transformer,pass,0
+
+
+
+timm_vision_transformer_large,pass_due_to_skip,0
+
+
+
+timm_vovnet,pass,0
+
+
+
+tts_angular,pass,2
+
+
+
+vgg16,pass,0
+
+
+
+vision_maskrcnn,pass,29
+
+
+
+yolov3,pass,2
diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
index 0835cd8b024e..5d73cf658c17 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
+++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
@@ -151,10 +151,12 @@ def apply_lints(filename):
             [
                 "aot_eager",
                 "aot_inductor",
+                "cpu_inductor",
                 "dynamic_aot_eager",
+                "dynamic_cpu_inductor",
+                "dynamic_inductor",
                 "dynamo_eager",
                 "inductor",
-                "dynamic_inductor",
             ],
             ["huggingface", "timm", "torchbench"],
         )
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 17d090ec8221..21387c123158 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -110,49 +110,6 @@ class CI(NamedTuple):
     device: str = "cuda"
 
 
-CI_SKIP = collections.defaultdict(list)
-
-CI_SKIP[CI("inductor", training=False, device="cpu")] = [
-    # TorchBench
-    "drq",  # Need to update torchbench
-    "detectron2_fasterrcnn_r_101_c4",
-    "detectron2_fasterrcnn_r_101_dc5",
-    "detectron2_fasterrcnn_r_101_fpn",
-    "detectron2_fasterrcnn_r_50_c4",
-    "detectron2_fasterrcnn_r_50_dc5",
-    "detectron2_fasterrcnn_r_50_fpn",
-    "detectron2_fcos_r_50_fpn",
-    "detectron2_maskrcnn_r_101_c4",
-    "detectron2_maskrcnn_r_101_fpn",
-    "detectron2_maskrcnn_r_50_c4",
-    "detectron2_maskrcnn_r_50_fpn",
-    "doctr_det_predictor",  # requires newer gcc
-    "doctr_reco_predictor",  # requires newer gcc
-    "gat",  # does not work with fp32
-    "gcn",  # does not work with fp32
-    "hf_Bert_large",  # OOM
-    "hf_GPT2_large",  # Intermittent failure on CI
-    "hf_T5_base",  # OOM
-    "mobilenet_v2_quantized_qat",
-    "pyhpc_turbulent_kinetic_energy",
-    "resnet50_quantized_qat",  # Eager model failed to run(Quantize only works on Float Tensor, got Double)
-    "sage",  # does not work with fp32
-    # Huggingface
-    "MBartForConditionalGeneration",  # Accuracy https://github.com/pytorch/pytorch/issues/94793
-    "PLBartForConditionalGeneration",  # Accuracy https://github.com/pytorch/pytorch/issues/94794
-    # TIMM
-    "cait_m36_384",  # Accuracy
-    "pnasnet5large",  # OOM
-    "xcit_large_24_p8_224",  # OOM https://github.com/pytorch/pytorch/issues/95984
-    "opacus_cifar10",  # Fails to run https://github.com/pytorch/pytorch/issues/99201
-]
-
-CI_SKIP[CI("inductor", training=False, dynamic=True, device="cpu")] = [
-    *CI_SKIP[CI("inductor", training=False, device="cpu")],
-    "pyhpc_isoneutral_mixing",
-    "dpn107",
-]
-
 CI_SKIP_OPTIMIZER = {
     # TIMM
     "convmixer_768_32",  # accuracy
@@ -2728,17 +2685,6 @@ def parse_args(args=None):
     parser.add_argument(
         "--ci", action="store_true", help="Flag to tell that its a CI run"
     )
-    parser.add_argument(
-        "--dynamic-ci-skips-only",
-        action="store_true",
-        help=(
-            "Run only the models that would have been skipped in CI "
-            "if dynamic-shapes, compared to running without dynamic-shapes.  "
-            "This is useful for checking if more models are now "
-            "successfully passing with dynamic shapes.  "
-            "Implies --dynamic-shapes and --ci"
-        ),
-    )
     parser.add_argument(
         "--dashboard", action="store_true", help="Flag to tell that its a Dashboard run"
     )
@@ -3194,9 +3140,6 @@ def run(runner, args, original_dir=None):
     if args.inductor:
         assert args.backend is None
         args.backend = "inductor"
-    if args.dynamic_ci_skips_only:
-        args.dynamic_shapes = True
-        args.ci = True
     if args.dynamic_batch_only:
         args.dynamic_shapes = True
         torch._dynamo.config.assume_static_by_default = True
@@ -3213,20 +3156,9 @@ def run(runner, args, original_dir=None):
             # Set translation validation on by default on CI accuracy runs.
             torch.fx.experimental._config.translation_validation = True
 
-        if args.dynamic_ci_skips_only:
-            # Test only the incremental set of jobs whose skipped was
-            # caused solely by turning on dynamic shapes
-            assert args.dynamic_shapes
-            ci = functools.partial(CI, args.backend, training=args.training)
-            args.filter = list(
-                set(CI_SKIP[ci(dynamic=True)]) - set(CI_SKIP[ci(dynamic=False)])
-            )
-        else:
-            ci = functools.partial(
-                CI, args.backend, training=args.training, dynamic=args.dynamic_shapes
-            )
-            for device in args.devices:
-                args.exclude_exact.extend(CI_SKIP[ci(device=device)])
+        ci = functools.partial(
+            CI, args.backend, training=args.training, dynamic=args.dynamic_shapes
+        )
     if args.ddp:
         # TODO: we could also hook DDP bench up to --speedup bench, _not_ for mgpu e2e perf,
         # but just to measure impact on singlenode of performing graph-breaks.

From 54d04553eaa2b67957f883f5b60ea007c1040a71 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Sun, 19 Nov 2023 01:48:43 -0500
Subject: [PATCH 066/221] [fx, DDP] fx.split_module will setup/unwind autocast
 & grad_mode (#113374)

---

Replaces: https://github.com/pytorch/pytorch/pull/112231
Fixes: https://github.com/pytorch/pytorch/issues/111794

DDPOptimizer splits modules. We need to setup/unwind global states (autocast, grad_enabled) for each split, as this affects downstream compilation.

---

See before and after this PR for the split fx modules here (for autocast mode): https://github.com/pytorch/pytorch/pull/112231#issuecomment-1804274605

---

### Discussion
We don't actually have to do this for grad mode: https://github.com/pytorch/pytorch/pull/112231#issuecomment-1804280031. It's not wrong to do it anyway, but maybe unnecessary? But may still be better to keep this PR's changes so we're sure what the grad mode state ought to be for each subgraph.

It may come in handy in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113374
Approved by: https://github.com/wconstab
---
 test/distributed/test_dynamo_distributed.py |  70 +++++++++--
 torch/fx/passes/split_module.py             | 133 ++++++++++++++++++--
 2 files changed, 187 insertions(+), 16 deletions(-)

diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 1db04610b42e..f4426e8ef60c 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -47,8 +47,9 @@ def init_weights(m):
         m.bias.data.fill_(0.01)
 
 class ToyModel(nn.Module):
-    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
+    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
         super().__init__()
+        self.ctx_manager = ctx_manager
         self.net = nn.Sequential(
             *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
             + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
@@ -57,10 +58,14 @@ def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
         )
 
     def forward(self, inputs):
-        return self.net(inputs)
-
-def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
-    m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(device)
+        if self.ctx_manager is not None:
+            with self.ctx_manager():
+                return self.net(inputs)
+        else:
+            return self.net(inputs)
+
+def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
+    m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat, ctx_manager=ctx_manager).to(device)
     m.apply(init_weights)
     inputs = torch.rand(bsz, in_feat).to(device)
     outputs = m(inputs)
@@ -508,8 +513,8 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
     Use TestMultiProc for things that really need to run on multiple nodes
     """
 
-    def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
-        m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(self.device)
+    def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
+        m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat, ctx_manager=ctx_manager).to(self.device)
         m.apply(init_weights)
         inputs = torch.rand(bsz, in_feat).to(self.device)
         outputs = m(inputs)
@@ -565,6 +570,57 @@ def opt_fn(inputs):
         self.assertEqual(len(break_reasons), 3)
         self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
 
+    @patch.object(config, "optimize_ddp", True)
+    def test_graph_split_ctx_manager(self):
+        """
+        Ensures that we get the right number of splits and that the respective
+        context managers' effects are applied to the computation.
+        """
+
+        for get_compiler in [
+            lambda: CheckSplitsCompiler(),
+            lambda: None,
+        ]:
+            for ctx_manager, output_test in [
+                (
+                    lambda: torch.autocast(torch.device(self.device).type, torch.float16),
+                    lambda out: self.assertEqual(out.dtype, torch.float16),
+                ),
+                (
+                    torch.enable_grad,
+                    lambda out: self.assertTrue(out.requires_grad)
+                ),
+                (
+                    torch.no_grad,
+                    lambda out: self.assertTrue(not out.requires_grad)
+                ),
+            ]:
+                m, inputs, correct_outputs = self.get_model(out_feat=1000, hidden_feat=1000, in_feat=1000, ctx_manager=ctx_manager)
+                # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
+                # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
+                bucket_cap_mb = 3.5  # 4MB
+                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
+
+                compiler = get_compiler()
+
+                @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
+                def opt_fn(inputs):
+                    return ddp_m(inputs)
+
+                opt_outputs = opt_fn(inputs)
+                self.assertTrue(same(correct_outputs, opt_outputs))
+                if compiler:
+                    self.assertEqual(compiler.compiler_called, 4)
+
+                output_test(opt_outputs)
+
+                # ensure compatibility with dynamo explain
+
+                explain_out = torch._dynamo.explain(ddp_m)(inputs)
+                break_reasons = explain_out.break_reasons
+                self.assertEqual(len(break_reasons), 4)
+                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
+
     @patch.object(config, "optimize_ddp", True)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_graph_split_inductor(self):
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index 7c96a9c49615..d175d6d1de88 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -1,5 +1,7 @@
 import inspect
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set
+from collections import OrderedDict
+import logging
 
 import torch
 from torch.fx._compatibility import compatibility
@@ -7,7 +9,7 @@
 from torch.fx.node import Node
 
 __all__ = ["Partition", "split_module"]
-
+_LOGGER = logging.getLogger(__name__)
 
 @compatibility(is_backward_compatible=True)
 class Partition:
@@ -194,6 +196,76 @@ def instantiate_node_partition_mapping(node):
         partition.node_names.append(node.name)
         node._fx_partition = partition_name
 
+    # Global State Nodes are nodes which by their global state effects,
+    # "taint" all downstream nodes while they are active.
+    GLOBAL_STATE_NODES = [
+        torch.amp._enter_autocast,
+        torch.amp._exit_autocast,
+        torch._C._set_grad_enabled
+    ]
+
+    # For grad regions:
+    # ------------------------
+    # 1. first region: we do nothing
+    # 2. subsequent regions: we insert the set_grad at the beginning
+    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()
+
+    # For autocast regions:
+    # ------------------------
+    # 1. first region: we will only insert the _exit at the end
+    # 2. intermediate regions: we will insert both the
+    #    _enter at the beginning and _exit at the end
+    # 3. last region: we will only insert _enter at the beginning
+    # We will do so in the order in which the autocasts were instantiated.
+    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
+    autocast_exits: Dict[Node, Optional[Node]] = {}
+
+    active_grad = None
+    active_autocasts = set()
+
+    for node in m.graph.nodes:
+        if node.op in ["placeholder", "get_attr", "output"]:
+            continue
+
+        instantiate_node_partition_mapping(node)
+
+        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
+            if node.target == torch._C._set_grad_enabled:
+                assert len(node.args) == 1
+                assert isinstance(node.args[0], bool)
+                active_grad = node
+                grad_regions[active_grad] = set({split_callback(node)})
+            elif node.target == torch.amp._enter_autocast:
+                # Should all be python constants
+                assert all(not isinstance(arg, Node) for arg in node.args)
+                active_autocasts.add(node)
+                autocast_regions[node] = set({split_callback(node)})
+                autocast_exits[node] = None
+            elif node.target == torch.amp._exit_autocast:
+                assert len(node.args) == 1
+                autocast_regions[node.args[0]].add(split_callback(node))
+                active_autocasts.remove(node.args[0])
+                autocast_exits[node.args[0]] = node
+
+        if active_grad is not None:
+            grad_regions[active_grad].add(split_callback(node))
+
+        for a in active_autocasts:
+            autocast_regions[a].add(split_callback(node))
+
+    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
+
+    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
+    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
+
+    if _LOGGER.isEnabledFor(logging.DEBUG):
+        _LOGGER.debug("autocast_regions: %s", autocast_regions)
+        _LOGGER.debug("grad_regions: %s", grad_regions)
+
+    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
+
+    # split nodes into partitions
+    highest_partition = -1
     for node in m.graph.nodes:
         orig_nodes[node.name] = node
 
@@ -207,14 +279,22 @@ def instantiate_node_partition_mapping(node):
             )
             continue
 
-        instantiate_node_partition_mapping(node)
+        if assert_monotonically_increasing:
+            pid = split_callback(node)
+            assert highest_partition <= pid,\
+                ("autocast or set_grad_enabled require monotonically increasing partitions:"
+                 f"highest: {highest_partition}, this node's: {pid}")
+            highest_partition = pid
 
-        torch.fx.graph.map_arg(
-            node.args, lambda def_node: record_cross_partition_use(def_node, node)
-        )
-        torch.fx.graph.map_arg(
-            node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
-        )  # noqa: B950
+        # do not capture cross-partition dependencies for global state nodes as they will be
+        # self-contained - their setup and unwind will be isolated to each partition submodule.
+        if node.target not in GLOBAL_STATE_NODES:
+            torch.fx.graph.map_arg(
+                node.args, lambda def_node: record_cross_partition_use(def_node, node)
+            )
+            torch.fx.graph.map_arg(
+                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
+            )  # noqa: B950
 
     original_partition_order = list(partitions.keys())
     # find partitions with no dependencies
@@ -235,6 +315,23 @@ def instantiate_node_partition_mapping(node):
     if len(sorted_partitions) != len(partitions):
         raise RuntimeError("cycle exists between partitions!")
 
+    # Enter prelude
+    for regions_mapping in [autocast_regions, grad_regions]:
+        for node, regions in regions_mapping.items():
+            assert len(regions) > 0
+            partitions[str(regions[0])].environment[node] = node
+            for r in regions[1:]:
+                partition = partitions[str(r)]
+                new_node = partition.graph.create_node(
+                    op=node.op,
+                    target=node.target,
+                    args=tuple(arg for arg in node.args),
+                    kwargs={},
+                    type_expr=node.type,
+                )
+                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
+                partition.environment[node] = new_node
+
     # add placeholders to partition inputs
     for partition_name in sorted_partitions:
         partition = partitions[partition_name]
@@ -289,6 +386,24 @@ def instantiate_node_partition_mapping(node):
             new_node.meta = node.meta.copy()
             partition.environment[node] = new_node
 
+    # Exit epilogue
+    for regions_mapping in [autocast_regions]:
+        for node in reversed(regions_mapping):
+            regions = regions_mapping[node]
+            assert len(regions) > 0
+            for r in regions[:-1]:
+                partition = partitions[str(r)]
+                exit_node = autocast_exits[node]
+                assert exit_node is not None, "Missing exit node"
+                new_node = partition.graph.create_node(
+                    op=exit_node.op,
+                    target=exit_node.target,
+                    args=(partition.environment[node],),
+                    kwargs={},
+                    type_expr=exit_node.type,
+                )
+                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?
+
     # original module environment dict mapping node names to nodes
     orig_mod_env: Dict[str, Node] = {}
     # Set up values to construct base module

From 266054c3cac0f800f37348aea1409c4759dd2315 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Tue, 21 Nov 2023 22:40:04 +0000
Subject: [PATCH 067/221] [dynamo / DDP] - lazily compile submodules - to
 propagate real tensor strides to backend compiler (#114154)

Fixes https://github.com/pytorch/pytorch/issues/113812, https://github.com/pytorch/pytorch/issues/102591, Probably fixes: https://github.com/pytorch/pytorch/issues/113740, https://github.com/pytorch/pytorch/issues/113786, https://github.com/pytorch/pytorch/issues/113788

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114154
Approved by: https://github.com/wconstab
---
 test/distributed/test_dynamo_distributed.py |  41 ++++++
 torch/_dynamo/backends/distributed.py       | 141 +++++++-------------
 2 files changed, 90 insertions(+), 92 deletions(-)

diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index f4426e8ef60c..24298c671538 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -543,6 +543,7 @@ def test_ddp_baseline_inductor(self):
 
     @patch.object(config, "optimize_ddp", True)
     def test_graph_split(self):
+        assert config.optimize_ddp
         """
         Just ensures that the appropriate number of splits happen (based on
         bucket size and model parameters) - verifies the number of times
@@ -624,6 +625,7 @@ def opt_fn(inputs):
     @patch.object(config, "optimize_ddp", True)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_graph_split_inductor(self):
+        assert config.optimize_ddp
         """
         Same as above, but using inductor backend.
         We observed issues with inductor/fx interface in the past.
@@ -638,6 +640,45 @@ def opt_fn(inputs):
         opt_outputs = opt_fn(inputs)
         self.assertTrue(same(correct_outputs, opt_outputs))
 
+    @torch._inductor.config.patch({"layout_optimization": True, "keep_output_stride": False})
+    @patch.object(config, "optimize_ddp", True)
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    def test_graph_split_inductor_layout_optimizations(self):
+        assert config.optimize_ddp
+        channel_dim = 512
+        # channel dim must be > 64 for inductor to do layout optimization and use NHWC
+
+        class ToyModelConv(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.net = nn.Sequential(
+                    *[nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
+                    + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
+                    + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
+                    + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
+                )
+
+            def forward(self, inputs):
+                return self.net(inputs)
+
+        def get_model():
+            m = ToyModelConv().to(self.device)
+            m.apply(init_weights)
+            inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
+            outputs = m(inputs)
+            return m, inputs, outputs
+
+        m, inputs, correct_outputs = get_model()
+        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
+
+        @torch._dynamo.optimize("inductor")
+        def opt_fn(inputs):
+            return ddp_m(inputs)
+
+        opt_outputs = opt_fn(inputs)
+        self.assertTrue(same(correct_outputs, opt_outputs))
+
+
     @patch.object(config, "optimize_ddp", True)
     def test_no_split(self):
         """
diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py
index adc68bb30bff..774f378b15b2 100644
--- a/torch/_dynamo/backends/distributed.py
+++ b/torch/_dynamo/backends/distributed.py
@@ -6,7 +6,8 @@
 import torch
 from torch import fx
 from torch._dynamo.output_graph import GraphCompileReason
-from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
+from torch._dynamo.utils import detect_fake_mode
+from torch._subclasses.fake_tensor import is_fake
 from torch.fx.node import Node
 
 log = logging.getLogger(__name__)
@@ -214,23 +215,6 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
         and returns its callable.
         """
 
-        # Today, optimize_ddp=True and keep_output_stride=False can lead to silent
-        # correctness issues. The problem is that ddp_optimizer works by partitioning
-        # the dynamo graph, sending each subgraph through aot autograd to inductor,
-        # and creates example inputs by eagerly interpreting each subgraph to get
-        # an output that with the same metadata that we'd get from eager mode.
-        # This is a problem though, for torch._inductor.config.keep_output_stride.
-        # The above config can cause the outputs of the first graph to have
-        # **different** strides from eager, causing the inputs that we pass
-        # to the second graph to be wrong.
-        # To really fix this, we would need to faithfully ask inductor
-        # what the outputs to each graph it expects are.
-        assert torch._inductor.config.keep_output_stride, """\
-Detected that you are running DDP with torch.compile, along with these two flags:
-- torch._dynamo.config.optimize_ddp = True
-- torch._inductor.config.keep_output_stride = False
-This combination of flags is incompatible. Please set keep_output_stride to False,
-or file a github issue."""
         fake_mode = detect_fake_mode(example_inputs)
         if fake_mode is None:
             fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
@@ -329,32 +313,54 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
         debug_str += "\n---------------\n"
         ddp_graph_log.debug(debug_str)
 
-        # 3: compile each of the partitioned submodules using the user-provided compiler
-        class SubmodCompiler(torch.fx.interpreter.Interpreter):
+        # 3: Replace submodules with lazily compiling submodule
+        class SubmoduleReplacer(torch.fx.interpreter.Interpreter):
             def __init__(self, module, compiler):
                 super().__init__(module)
                 self.compiler = compiler
 
-            def compile_submod(self, input_mod, args, kwargs):
+            def lazily_compiled_submod(self, input_mod):
                 """
-                Compile the submodule,
-                using a wrapper to make sure its output is always a tuple,
-                which is required by AotAutograd based compilers
+                Create a wrapper around submodules which:
+                - lazily compiles each of the partitioned submodules using the user-provided compiler
+                - unpacks singleton tuples/lists into flat arg
                 """
-                assert len(kwargs) == 0, "We assume only args for these modules"
 
-                class WrapperModule(torch.nn.Module):
-                    def __init__(self, submod, unwrap_singleton_tuple):
+                class LazilyCompiledModule(torch.nn.Module):
+                    def __init__(self, submod, compiler, unwrap_singleton_tuple):
                         super().__init__()
                         self.submod = submod
+                        self.compiler = compiler
+                        self.compiled = False
                         self.unwrap_singleton_tuple = unwrap_singleton_tuple
 
                     def forward(self, *args):
+                        if not self.compiled:
+                            assert (
+                                fake_mode
+                            ), "fake mode must have been available when creating lazy submod"
+                            fake_args = []
+                            for arg in args:
+                                if isinstance(arg, torch.Tensor) and not is_fake(arg):
+                                    fake_args.append(
+                                        torch._dynamo.utils.to_fake_tensor(
+                                            arg, fake_mode
+                                        )
+                                    )
+                                else:
+                                    fake_args.append(arg)
+                            # First trace with fake args
+                            new_submod = self.compiler(self.submod, tuple(fake_args))
+                            del self.submod
+                            self.submod = new_submod
+                            self.compiled = True
+                            self.compiler = None
+
                         x = self.submod(*args)
-                        # TODO(whc)
-                        # for some reason the isinstance check is necessary if I split one node per submod
-                        # - even though I supposedly wrapped the output in a tuple in those cases, the real
-                        # compiled module was still returning a tensor
+                        # we must let 'input_mod' return a tuple, to make AOT happy.
+                        # (aot_autograd compile_fn literally requires that the output of a graph it compiles is a tuple).
+                        # however, we don't acutally want this tuple to be returned, since the fx logic that calls the submod
+                        # will again wrap outputs from the submod in a tuple.  So we unwrap it, and count on it being re-wrapped
                         if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                             return x[0]
                         return x
@@ -375,84 +381,35 @@ def forward(self, *args):
                         traceback.FrameSummary(__file__, 0, DDPOptimizer),
                     ],
                 )
-                wrapper = WrapperModule(
-                    self.compiler(input_mod, args),
+                wrapper = LazilyCompiledModule(
+                    input_mod,
+                    self.compiler,
                     unwrap_singleton_tuple,
                 )
                 return wrapper
 
-            # Note:
-            #
-            # The way distributed works today around fake tensors can be somewhat confusing.
-            # Some of these codepaths are shared in both runtime, and compile time. The presence
-            # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
-            #
-            # A few things to keep in mind:
-            #
-            # 1) We invoke `compile_submod` with a real module. The output of that gets stored
-            # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
-            #
-            # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
-            # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
-            #
-            # 3) Fake tensors should always be around during compile time.
-            #
-            # 4) Fake tensors should never be around at runtime.
-            #
-            # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
-            # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
+            # We replace the submodules with lazy submodules which compile
+            # the corresponding submodules when they are run with real values
+            # Always returns `None` - we do not need to propagate values in order
+            # to replace submodules.
             def run_node(self, n: Node) -> Any:
-                args, kwargs = self.fetch_args_kwargs_from_env(n)
-                new_args = []
-                assert fake_mode
-                for arg in args:
-                    if isinstance(arg, torch.Tensor) and not isinstance(
-                        arg, torch._subclasses.FakeTensor
-                    ):
-                        new_args.append(
-                            torch._dynamo.utils.to_fake_tensor(arg, fake_mode)
-                        )
-                    else:
-                        new_args.append(arg)
-
-                log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
-                assert isinstance(args, tuple)
-                assert isinstance(kwargs, dict)
-
                 if n.op == "call_module":
                     real_mod = self.fetch_attr(n.target)
-                    if fake_mode:
-                        curr_submod = deepcopy_to_fake_tensor(real_mod, fake_mode)
-                    else:
-                        curr_submod = real_mod
 
                     ddp_graph_log.debug(
-                        "\n---%s graph---\n%s", n.target, curr_submod.graph
+                        "\n---%s graph---\n%s", n.target, real_mod.graph
                     )
 
-                    # When calling the compiler on the submod, inputs (new_args) are expected to
-                    # be FakeTensors already since Dynamo would have made them FakeTensors in the
-                    # non-DDP flow.  However, the parameters are _not_ expected to be FakeTensors,
-                    # since this wrapping happens during compilation
-                    compiled_submod_real = self.compile_submod(
-                        real_mod, new_args, kwargs
-                    )
+                    assert len(n.kwargs) == 0, "We assume only args for these modules"
+                    lazily_compiled_submod = self.lazily_compiled_submod(real_mod)
 
                     # We update the original (outer) graph with a call into the compiled module
                     # instead of the uncompiled one.
                     self.module.delete_submodule(n.target)
                     n.target = "compiled_" + n.target
-                    self.module.add_submodule(n.target, compiled_submod_real)
-
-                    # Finally, we have to produce inputs for use compiling the next submodule,
-                    # and these need to be FakeTensors, so we execute the module under fake_mode
-                    with fake_mode:
-                        return curr_submod(*new_args, **kwargs)
-                else:
-                    # placeholder or output nodes don't need to get compiled, just executed
-                    return getattr(self, n.op)(n.target, new_args, kwargs)
+                    self.module.add_submodule(n.target, lazily_compiled_submod)
 
-        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
+        submod_compiler = SubmoduleReplacer(split_gm, self.backend_compile_fn)
         submod_compiler.run(*example_inputs)
         split_gm.recompile()
 

From 62de29d06f1ddcc7c4b11757adc75f1459ef6991 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Tue, 21 Nov 2023 22:44:46 +0000
Subject: [PATCH 068/221] [optim] be explicit about CPU scalar tensor dtypes
 (#111008)

Fixes https://github.com/pytorch/pytorch/issues/110940

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111008
Approved by: https://github.com/janeyx99
---
 test/optim/test_optim.py | 26 +++++++++++++++++++++++---
 torch/optim/adagrad.py   |  4 ++--
 torch/optim/adam.py      |  6 +++---
 torch/optim/adamax.py    |  4 ++--
 torch/optim/adamw.py     |  6 +++---
 torch/optim/asgd.py      | 12 ++++++------
 torch/optim/nadam.py     | 12 ++++++------
 torch/optim/radam.py     |  4 ++--
 8 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 5224d6d3d8db..e4c047527263 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -737,7 +737,7 @@ def _test_derived_optimizers_varying_tensors(self, optimizer_with_kwargs, kwarg)
                     actual = mt_p_state[k]
                     self.assertEqual(st_p_state[k], actual, rtol=rtol, atol=atol)
 
-    def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
+    def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag, reduced_precision=False):
         if not torch.cuda.is_available():
             return
         assert flag in ("foreach", "fused")
@@ -794,15 +794,20 @@ def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
 
             st_state = state[0]
             mt_state = state[1]
+
+            assert_eq_kwargs = {}
+            if reduced_precision:
+                assert_eq_kwargs = {'atol': 1e-5, 'rtol': 1e-4}
+
             for st_p, mt_p in zip(res[0], res[1]):
-                self.assertEqual(st_p, mt_p)
+                self.assertEqual(st_p, mt_p, **assert_eq_kwargs)
 
                 # check that optimizer states are the same
                 st_p_state = st_state[st_p]
                 mt_p_state = mt_state[mt_p]
 
                 for k in st_p_state:
-                    self.assertEqual(st_p_state[k], mt_p_state[k])
+                    self.assertEqual(st_p_state[k], mt_p_state[k], **assert_eq_kwargs)
 
     def _test_foreach_memory(self, optimizer_pairs_with_flags):
         if not torch.cuda.is_available():
@@ -959,6 +964,21 @@ def _multi_tensor_optimizer_configs(self):
     def test_multi_tensor_optimizers(self):
         self._test_derived_optimizers(self._multi_tensor_optimizer_configs, "foreach")
 
+    def test_multi_tensor_optimizers_default_dtype(self):
+        # https://github.com/pytorch/pytorch/issues/110940
+        # We coerce step to always be float32
+        default_dtype = torch.tensor(0.0).dtype
+        for dtype in [torch.float64, torch.float16]:
+            try:
+                torch.set_default_dtype(dtype)
+                self._test_derived_optimizers(
+                    self._multi_tensor_optimizer_configs,
+                    "foreach",
+                    reduced_precision=dtype == torch.float16
+                )
+            finally:
+                torch.set_default_dtype(default_dtype)
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_multi_tensor_optimizers_with_varying_tensors(self):
         self._test_derived_optimizers_varying_tensors(self._multi_tensor_optimizer_configs, "foreach")
diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py
index 8fdaeb61eb47..2634333863f9 100644
--- a/torch/optim/adagrad.py
+++ b/torch/optim/adagrad.py
@@ -50,7 +50,7 @@ def __init__(
         for group in self.param_groups:
             for p in group["params"]:
                 state = self.state[p]
-                state["step"] = torch.tensor(0.0)
+                state["step"] = torch.tensor(0.0, dtype=torch.float32)
                 init_value = (
                     complex(initial_accumulator_value, initial_accumulator_value)
                     if torch.is_complex(p)
@@ -73,7 +73,7 @@ def __setstate__(self, state):
         )
         if not step_is_tensor:
             for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]))
+                s["step"] = torch.tensor(float(s["step"]), dtype=torch.float32)
 
     def share_memory(self):
         for group in self.param_groups:
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index b14753fdb6b2..fade018c8834 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -75,7 +75,7 @@ def __setstate__(self, state):
         step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
         if not step_is_tensor:
             for s in state_values:
-                s['step'] = torch.tensor(float(s['step']))
+                s['step'] = torch.tensor(float(s['step']), dtype=torch.float32)
 
     def _init_group(
         self,
@@ -103,9 +103,9 @@ def _init_group(
                     # Deliberately host `step` on CPU if both capturable and fused are off.
                     # This is because kernel launches are costly on CUDA and XLA.
                     state['step'] = (
-                        torch.zeros((), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float32, device=p.device)
                         if group['capturable'] or group['fused']
-                        else torch.tensor(0.)
+                        else torch.tensor(0.0, dtype=torch.float32)
                     )
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index c5b85df909f8..c0b9362abadd 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -56,7 +56,7 @@ def __setstate__(self, state):
         )
         if not step_is_tensor:
             for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]))
+                s["step"] = torch.tensor(float(s["step"]), dtype=torch.float32)
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps):
         has_complex = False
@@ -73,7 +73,7 @@ def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_
 
             # State initialization
             if len(state) == 0:
-                state["step"] = torch.tensor(0.0)
+                state["step"] = torch.tensor(0.0, dtype=torch.float32)
                 state["exp_avg"] = torch.zeros_like(
                     p, memory_format=torch.preserve_format
                 )
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index d7745e3966c2..17d72c66c6db 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -85,7 +85,7 @@ def __setstate__(self, state):
         )
         if not step_is_tensor:
             for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]))
+                s["step"] = torch.tensor(float(s["step"]), dtype=torch.float32)
 
     def _init_group(
         self,
@@ -115,9 +115,9 @@ def _init_group(
                 # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (
-                    torch.zeros((), dtype=torch.float, device=p.device)
+                    torch.zeros((), dtype=torch.float32, device=p.device)
                     if group["capturable"] or group["fused"]
-                    else torch.tensor(0.0)
+                    else torch.tensor(0.0, dtype=torch.float32)
                 )
                 # Exponential moving average of gradient values
                 state["exp_avg"] = torch.zeros_like(
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index a97f3bf70e20..104550361527 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -62,19 +62,19 @@ def __setstate__(self, state):
         )
         if not step_is_tensor:
             for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]))
+                s["step"] = torch.tensor(float(s["step"]), dtype=torch.float32)
         eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
             state_values[0]["eta"]
         )
         if not eta_is_tensor:
             for s in state_values:
-                s["eta"] = torch.tensor(s["eta"])
+                s["eta"] = torch.tensor(s["eta"], dtype=torch.float32)
         mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
             state_values[0]["mu"]
         )
         if not mu_is_tensor:
             for s in state_values:
-                s["mu"] = torch.tensor(float(s["mu"]))
+                s["mu"] = torch.tensor(float(s["mu"]), dtype=torch.float32)
 
     def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
         has_complex = False
@@ -89,9 +89,9 @@ def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_step
                 state = self.state[p]
                 # State initialization
                 if len(state) == 0:
-                    state["step"] = torch.zeros((), device=p.device)
-                    state["eta"] = torch.tensor(group["lr"], device=p.device)
-                    state["mu"] = torch.ones((), device=p.device)
+                    state["step"] = torch.zeros((), device=p.device, dtype=torch.float32)
+                    state["eta"] = torch.tensor(group["lr"], device=p.device, dtype=torch.float32)
+                    state["mu"] = torch.ones((), device=p.device, dtype=torch.float32)
                     state["ax"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format
                     )
diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py
index 9ce3fb2ad33c..d1e0abbefbbf 100644
--- a/torch/optim/nadam.py
+++ b/torch/optim/nadam.py
@@ -40,11 +40,11 @@ def __setstate__(self, state):
         step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
         if not step_is_tensor:
             for s in state_values:
-                s['step'] = torch.tensor(float(s['step']))
+                s['step'] = torch.tensor(float(s['step']), dtype=torch.float32)
         mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
         if not mu_product_is_tensor:
             for s in state_values:
-                s['mu_product'] = torch.tensor(s['mu_product'])
+                s['mu_product'] = torch.tensor(s['mu_product'], dtype=torch.float32)
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
         has_complex = False
@@ -63,12 +63,12 @@ def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_
                     # Deliberately host `step` and `mu_product` on CPU if capturable is False.
                     # This is because kernel launches are costly on CUDA and XLA.
                     state['step'] = (
-                        torch.zeros((), dtype=torch.float, device=p.device)
-                        if group['capturable'] else torch.tensor(0.)
+                        torch.zeros((), dtype=torch.float32, device=p.device)
+                        if group['capturable'] else torch.tensor(0.0, dtype=torch.float32)
                     )
                     state['mu_product'] = (
-                        torch.ones((), dtype=torch.float, device=p.device)
-                        if group['capturable'] else torch.tensor(1.)
+                        torch.ones((), dtype=torch.float32, device=p.device)
+                        if group['capturable'] else torch.tensor(1.0, dtype=torch.float32)
                     )
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index bde93c949201..60ae225ab495 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -65,7 +65,7 @@ def __setstate__(self, state):
         )
         if not step_is_tensor:
             for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]))
+                s["step"] = torch.tensor(float(s["step"]), dtype=torch.float32)
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps):
         has_complex = False
@@ -80,7 +80,7 @@ def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, sta
                 state = self.state[p]
                 # Lazy state initialization
                 if len(state) == 0:
-                    state["step"] = torch.tensor(0.0)
+                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                     # Exponential moving average of gradient values
                     state["exp_avg"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format

From b88abb16748538a5cd81d2e23d506e4902266db1 Mon Sep 17 00:00:00 2001
From: CYuxian 
Date: Tue, 21 Nov 2023 22:45:46 +0000
Subject: [PATCH 069/221] [ONNX] Fix export issue of aten::layer_norm in opset
 17 (#114058)

For torch.nn.LayerNorm, weight and bias could be None(when parameter elementwise_affine is False or bias is False), but for onnx op LayerNormalization from opset 17, weight and bias cannot be None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114058
Approved by: https://github.com/thiagocrepaldi
---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 10 +++++++---
 torch/onnx/symbolic_opset17.py             | 10 ++++++++++
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index f2280b2c4f12..05171d3ef995 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3977,9 +3977,13 @@ def test_layer_norm(self):
         # As layer_norm works on the last D dimension, please keep
         # this test case at least three dimension to prevent the
         # situation of axis=2 mapping to the same axis as axis=-2
-        model = torch.nn.LayerNorm([10, 10, 10])
-        x = torch.randn(20, 5, 10, 10, 10)
-        self.run_test(model, x)
+        for elementwise_affine in (True, False):
+            for bias in (True, False):
+                model = torch.nn.LayerNorm(
+                    [10, 10, 10], elementwise_affine=elementwise_affine, bias=bias
+                )
+                x = torch.randn(20, 5, 10, 10, 10)
+                self.run_test(model, x)
 
     def test_batchnorm1d(self):
         x = torch.randn(10, 10)
diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py
index 3fa03ab8e119..3aad249a1126 100644
--- a/torch/onnx/symbolic_opset17.py
+++ b/torch/onnx/symbolic_opset17.py
@@ -47,6 +47,16 @@ def layer_norm(
     # layer_norm normalizes on the last D dimensions,
     # where D is the size of normalized_shape
     axis = -len(normalized_shape)
+    scalar_type = _type_utils.JitScalarType.from_value(
+        input, _type_utils.JitScalarType.FLOAT
+    )
+    dtype = scalar_type.dtype()
+    if symbolic_helper._is_none(weight):
+        weight_value = torch.ones(normalized_shape, dtype=dtype)
+        weight = g.op("Constant", value_t=weight_value)
+    if symbolic_helper._is_none(bias):
+        bias_value = torch.zeros(normalized_shape, dtype=dtype)
+        bias = g.op("Constant", value_t=bias_value)
     return g.op(
         "LayerNormalization",
         input,

From 7fc292930c3b8ae5f6dec0a6176d4b5ca0b29d8f Mon Sep 17 00:00:00 2001
From: Antonio Kim 
Date: Tue, 21 Nov 2023 23:07:21 +0000
Subject: [PATCH 070/221] Add support for `torch.Generator` type in TorchScript
 (#110413)

- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
---
 aten/src/ATen/core/ivalue.cpp                 |   8 +
 aten/src/ATen/core/type_factory.cpp           |   1 +
 aten/src/ATen/native/ts_native_functions.yaml |   4 +-
 build_variables.bzl                           |   1 -
 docs/source/jit_unsupported.rst               |   1 -
 test/jit/test_generator.py                    | 195 ++++++++++++++++++
 test/lazy/test_generator.py                   | 103 +++++++++
 test/lazy/test_ts_opinfo.py                   |   7 +
 test/test_jit.py                              |  36 ++++
 test/test_public_bindings.py                  |   1 +
 torch/_C/__init__.pyi.in                      |   4 +
 torch/_decomp/decompositions.py               |   5 +-
 torch/_prims/__init__.py                      |  14 +-
 torch/_refs/__init__.py                       |   2 +-
 torch/csrc/jit/frontend/sugared_value.cpp     |  11 +
 torch/csrc/jit/frontend/tracer.cpp            |  10 +-
 torch/csrc/jit/ir/constants.cpp               |   8 +
 torch/csrc/jit/ir/node_hashing.cpp            |   3 +
 torch/csrc/jit/python/pybind_utils.h          |   2 +
 torch/csrc/jit/python/python_ir.cpp           |   4 +
 torch/csrc/jit/runtime/register_ops_utils.cpp |  44 ++++
 torch/csrc/jit/runtime/register_ops_utils.h   |   6 +
 torch/csrc/jit/runtime/register_prim_ops.cpp  |  39 ++++
 .../csrc/jit/runtime/register_special_ops.cpp |  14 +-
 torch/csrc/lazy/core/hash.h                   |   5 +
 torch/csrc/lazy/core/shape_inference.cpp      |  16 ++
 torch/csrc/lazy/core/shape_inference.h        |   2 +
 torch/csrc/lazy/python/init.cpp               |  15 ++
 torch/csrc/lazy/ts_backend/ops/random_ops.cpp |  47 -----
 torch/csrc/lazy/ts_backend/ops/random_ops.h   |  30 ---
 .../csrc/lazy/ts_backend/tensor_aten_ops.cpp  |   1 -
 .../lazy/ts_backend/ts_native_functions.cpp   |  31 ---
 torch/jit/annotations.py                      |   3 +
 torch/nn/init.py                              |  93 ++++++---
 torch/overrides.py                            |   6 +-
 .../_internal/common_methods_invocations.py   |  19 +-
 torch/testing/_internal/common_utils.py       |  22 ++
 torchgen/api/lazy.py                          |  20 +-
 torchgen/dest/lazy_ir.py                      |  11 +-
 39 files changed, 666 insertions(+), 178 deletions(-)
 create mode 100644 test/jit/test_generator.py
 create mode 100644 test/lazy/test_generator.py
 delete mode 100644 torch/csrc/lazy/ts_backend/ops/random_ops.cpp
 delete mode 100644 torch/csrc/lazy/ts_backend/ops/random_ops.h

diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index eebdd3b330a9..2e98c3649106 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -644,6 +644,13 @@ std::ostream& IValue::repr(
       c10::printQuotedString(out, device_stream.str());
       return out << ")";
     }
+    case IValue::Tag::Generator: {
+      auto generator = v.toGenerator();
+      out << "torch.Generator(device=";
+      c10::printQuotedString(out, generator.device().str());
+      out << ", seed=" << generator.current_seed() << ")";
+      return out;
+    }
     case IValue::Tag::GenericDict:
       return printMaybeAnnotatedDict(out, v, formatter);
     case IValue::Tag::Enum: {
@@ -956,6 +963,7 @@ IValue IValue::deepcopy(
     case IValue::Tag::SymBool:
     case IValue::Tag::Bool:
     case IValue::Tag::Device:
+    case IValue::Tag::Generator:
     case IValue::Tag::Uninitialized: {
       copy = *this;
     } break;
diff --git a/aten/src/ATen/core/type_factory.cpp b/aten/src/ATen/core/type_factory.cpp
index 78c5a31b86ef..b36c25c8c775 100644
--- a/aten/src/ATen/core/type_factory.cpp
+++ b/aten/src/ATen/core/type_factory.cpp
@@ -28,6 +28,7 @@ namespace c10 {
   _(complex, ComplexType)           \
   _(str, StringType)                \
   _(Device, DeviceObjType)          \
+  _(Generator, GeneratorType)       \
   _(Stream, StreamObjType)          \
   _(number, NumberType)             \
   _(None, NoneType)                 \
diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml
index 85ac57e127c4..17c9bd4234f3 100644
--- a/aten/src/ATen/native/ts_native_functions.yaml
+++ b/aten/src/ATen/native/ts_native_functions.yaml
@@ -168,6 +168,9 @@ full_codegen:
   - slice_scatter
   - diagonal_scatter
   - as_strided_scatter
+  # random ops
+  - normal_functional
+  - uniform
 ir_gen:
   - selu
 supported:
@@ -177,7 +180,6 @@ supported:
   - empty.memory_format
   - empty_strided
   - fill_.Scalar
-  - normal_
   - max_pool3d_with_indices
   - max_pool3d_with_indices_backward
   - _to_copy
diff --git a/build_variables.bzl b/build_variables.bzl
index 70d48d836443..a634f640e8cb 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -447,7 +447,6 @@ lazy_tensor_ts_sources = [
     "torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
     "torch/csrc/lazy/ts_backend/config.cpp",
     "torch/csrc/lazy/ts_backend/ops/device_data.cpp",
-    "torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
     "torch/csrc/lazy/ts_backend/ops/generic.cpp",
     "torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
     "torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",
diff --git a/docs/source/jit_unsupported.rst b/docs/source/jit_unsupported.rst
index 7a6538300984..60bca7d6d92c 100644
--- a/docs/source/jit_unsupported.rst
+++ b/docs/source/jit_unsupported.rst
@@ -88,4 +88,3 @@ we suggest using :meth:`torch.jit.trace`.
   * :class:`torch.nn.AdaptiveLogSoftmaxWithLoss`
   * :class:`torch.autograd.Function`
   * :class:`torch.autograd.enable_grad`
-  * :class:`torch.Generator`
diff --git a/test/jit/test_generator.py b/test/jit/test_generator.py
new file mode 100644
index 000000000000..8a993c7fed10
--- /dev/null
+++ b/test/jit/test_generator.py
@@ -0,0 +1,195 @@
+# Owner(s): ["oncall: jit"]
+
+import io
+import math
+import unittest
+
+import torch
+from torch.nn import init
+from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
+from torch.testing._internal.jit_utils import JitTestCase
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
+
+
+class TestGenerator(JitTestCase):
+    # torch.jit.trace does not properly capture the generator manual seed
+    # and thus is non deterministic even if the generator is manually seeded
+    @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
+    @unittest.expectedFailure
+    def test_trace(self):
+        def f():
+            generator = torch.Generator()
+            generator.seed()
+            generator.manual_seed(2023)
+            generator.initial_seed()
+            tensor = torch.empty(2, 2)
+            tensor.uniform_(0, 1, generator=generator)
+            return tensor
+
+        traced_f = torch.jit.trace(f, ())
+
+        # Run this 3 times to ensure that the generator is being manually seeded
+        # each time the traced function is run
+        for i in range(3):
+            torch.manual_seed(1)
+
+            eager_tensor = f()
+
+            # Change the seed of the default generator to
+            # check that we're using the generator from the
+            # trace
+            torch.manual_seed(2)
+            traced_tensor = traced_f()
+
+            self.assertEqual(eager_tensor, traced_tensor)
+
+    def test_script(self):
+        def f():
+            generator = torch.Generator()
+            generator.seed()
+            generator.manual_seed(2023)
+            generator.initial_seed()
+            tensor = torch.empty(2, 2)
+            tensor.normal_(-1.0, 1.0, generator=generator)
+            return tensor
+
+        script_f = torch.jit.script(f, ())
+
+        # Run this 3 times to ensure that the generator is being manually seeded
+        # each time the traced function is run
+        for i in range(3):
+            torch.manual_seed(1)
+
+            eager_tensor = f()
+
+            # Change the seed of the default generator to
+            # check that we're using the generator from the
+            # trace
+            torch.manual_seed(2)
+
+            script_tensor = script_f()
+
+            self.assertEqual(eager_tensor, script_tensor)
+
+    def test_default_generator(self):
+        def f():
+            # check that calling manual seed for the default generator works
+            torch.manual_seed(2023)
+            tensor = torch.empty(2, 2)
+            tensor.normal_(-1.0, 1.0)
+            return tensor
+
+        torch.manual_seed(1)
+
+        eager_tensor = f()
+
+        torch.manual_seed(2)
+
+        script_f = torch.jit.script(f, ())
+        script_tensor = script_f()
+
+        self.assertEqual(eager_tensor, script_tensor)
+
+    def test_generator_arg(self):
+        def f(generator: torch.Generator):
+            tensor = torch.empty(2, 2)
+            tensor.normal_(-1.0, 1.0, generator=generator)
+            return tensor
+
+        generator = torch.Generator()
+        generator.manual_seed(2023)
+
+        script_f = torch.jit.script(f, (generator,))
+
+        for i in range(3):
+            generator = torch.Generator()
+            generator.manual_seed(2023 + i)
+
+            torch.manual_seed(1 + i)
+
+            eager_tensor = f(generator)
+
+            generator = torch.Generator()
+            generator.manual_seed(2023 + i)
+
+            torch.manual_seed(1 + i)
+
+            script_tensor = script_f(generator)
+
+            self.assertEqual(eager_tensor, script_tensor)
+
+    def test_save_load(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo = torch.nn.Linear(2, 2, bias=False)
+                self.bar = torch.nn.Linear(2, 2, bias=False)
+
+                self.reset_parameters()
+
+            def reset_linear(self, module, generator):
+                init.kaiming_uniform_(
+                    module.weight, a=math.sqrt(5), generator=generator
+                )
+
+            def reset_parameters(self):
+                generator = torch.Generator()
+                generator.manual_seed(1)
+                self.reset_linear(self.foo, generator)
+
+                generator = torch.Generator()
+                generator.manual_seed(2)
+                self.reset_linear(self.bar, generator)
+
+            def forward(self, x):
+                x = self.foo(x)
+                x = self.bar(x)
+
+                generator = torch.Generator()
+                generator.manual_seed(3)
+                r = torch.empty_like(x)
+                r.normal_(0.0, 1.0, generator=generator)
+
+                return x, r
+
+        eager_foo = Foo()
+
+        script_module = torch.jit.script(Foo())
+        saved_module = io.BytesIO()
+        torch.jit.save(script_module, saved_module)
+        saved_module.seek(0)
+
+        loaded_module = torch.jit.load(saved_module)
+
+        self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight)
+        self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight)
+
+        try:
+            # Run this 3 times so make sure that the generator seed is being set
+            # every time forward is called
+            for i in range(3):
+                x = torch.ones(2, 2)
+                out1, r1 = eager_foo(x)
+                out2, r2 = loaded_module(x)
+
+                try:
+                    self.assertEqual(out1, out2)
+                except:  # noqa: B001, E722
+                    print(f"Iteration {i}:\n{out1=}\n{out2=}")
+                    raise
+
+                try:
+                    self.assertEqual(r1, r2)
+                except:  # noqa: B001, E722
+                    print(f"Iteration {i}:\n{r1=}\n{r2=}")
+                    raise
+        except:  # noqa: B001, E722
+            print(loaded_module.forward.code)
+            raise
diff --git a/test/lazy/test_generator.py b/test/lazy/test_generator.py
new file mode 100644
index 000000000000..a4bc94cf26f8
--- /dev/null
+++ b/test/lazy/test_generator.py
@@ -0,0 +1,103 @@
+# Owner(s): ["oncall: jit"]
+
+import torch
+import torch._lazy.metrics as metrics
+import torch._lazy.ts_backend
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
+
+torch._lazy.ts_backend.init()
+
+
+class LazyGeneratorTest(TestCase):
+    def test_generator(self):
+        """
+        Test that generators are being inserted into the TorchScript
+        graph by setting different seeds before each call to
+        generate_tensor but the resulting tensor is the same
+        """
+
+        def generate_tensor():
+            g1 = torch.Generator()
+            g1.manual_seed(2023)
+            t1 = torch.tensor(1.0)
+            t1.uniform_(generator=g1)
+
+            g2 = torch.Generator()
+            g2.manual_seed(2024)
+            t2 = torch.tensor(1.0)
+            t2.normal_(generator=g2)
+
+            return t1, t2
+
+        torch.manual_seed(1)
+
+        with torch.device("cpu"):
+            cpu_t1, cpu_t2 = generate_tensor()
+
+        torch.manual_seed(2)
+
+        with torch.device("lazy"):
+            lazy_t1, lazy_t2 = generate_tensor()
+
+        torch._lazy.mark_step()
+
+        assert torch.allclose(
+            cpu_t1, lazy_t1.to("cpu")
+        ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
+        assert torch.allclose(
+            cpu_t2, lazy_t2.to("cpu")
+        ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
+
+    @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
+    def test_generator_causes_multiple_compiles(self):
+        """
+        Test that inserting generators with different seed caused recompile
+        """
+
+        def generate_tensor(seed):
+            t = torch.tensor(1.0)
+            g = torch.Generator()
+            g.manual_seed(seed)
+            t.uniform_(-1, 1, generator=g)
+            return t
+
+        metrics.reset()
+
+        with torch.device("lazy"):
+            t = generate_tensor(1)
+            torch._lazy.mark_step()
+
+            uncached_compile = metrics.counter_value("UncachedCompile")
+            assert (
+                uncached_compile == 1
+            ), f"Expected 1 uncached compiles, got {uncached_compile}"
+
+            t = generate_tensor(2)
+            torch._lazy.mark_step()
+
+            uncached_compile = metrics.counter_value("UncachedCompile")
+            assert (
+                uncached_compile == 2
+            ), f"Expected 2 uncached compiles, got {uncached_compile}"
+
+            t = generate_tensor(1)
+            torch._lazy.mark_step()
+
+            uncached_compile = metrics.counter_value("UncachedCompile")
+            assert (
+                uncached_compile == 2
+            ), f"Expected 2 uncached compiles, got {uncached_compile}"
+            cached_compile = metrics.counter_value("CachedCompile")
+            assert (
+                cached_compile == 1
+            ), f"Expected 1 cached compile, got {cached_compile}"
+
+        metrics.reset()
+
+        latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
+        assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
+        assert "aten::uniform" in latest_graph
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py
index a16265da20b8..6cc5ccec7454 100644
--- a/test/lazy/test_ts_opinfo.py
+++ b/test/lazy/test_ts_opinfo.py
@@ -231,6 +231,9 @@ def assert_allclose_rec(t):
 
         samples = op.sample_inputs("lazy", dtype, requires_grad=False)
         for sample in samples:
+            # Need to run mark step so that all random ops are computed in the right order
+            torch._lazy.mark_step()
+
             args = [sample.input] + list(sample.args)
             kwargs = sample.kwargs
             copy_args = clone_to_device(args, test_device)
@@ -238,6 +241,7 @@ def assert_allclose_rec(t):
             r_exp = op(*copy_args, **kwargs)
             r_actual = op(*args, **kwargs)
 
+            torch._lazy.mark_step()
             assert_allclose_rec((r_actual, r_exp))
 
     @ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,))  # noqa: B950
@@ -263,6 +267,9 @@ def assert_allclose_rec(t):
 
         samples = op.sample_inputs("lazy", dtype, requires_grad=False)
         for sample in samples:
+            # Need to run mark step so that all random ops are computed in the right order
+            torch._lazy.mark_step()
+
             args = [sample.input] + list(sample.args)
             kwargs = sample.kwargs
             copy_args = clone_to_device(args, test_device)
diff --git a/test/test_jit.py b/test/test_jit.py
index 58e9cc3e3a97..d693a266e0f6 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -75,6 +75,7 @@
 from jit.test_sparse import TestSparse  # noqa: F401
 from jit.test_tensor_methods import TestTensorMethods  # noqa: F401
 from jit.test_dataclasses import TestDataclasses  # noqa: F401
+from jit.test_generator import TestGenerator  # noqa: F401
 
 # Torch
 from torch import Tensor
@@ -14169,6 +14170,41 @@ def test({arg_str}):
 
             FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
 
+    def test_nn_init_generator(self):
+        init_fns = (
+            'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_',
+        )
+
+        for name in init_fns:
+            # Build test code
+            code = dedent('''
+                def test(tensor, generator):
+                    # type: (Tensor, Generator)
+                    return torch.nn.init.{name}(tensor, generator=generator)
+            ''').format(name=name)
+            cu = torch.jit.CompilationUnit(code)
+
+            # Compare functions
+            init_fn = getattr(torch.nn.init, name)
+
+            torch.manual_seed(1)
+
+            g = torch.Generator()
+            g.manual_seed(2023)
+            script_out = cu.test(torch.ones(2, 2), g)
+
+            # Change the seed of the default generator to make
+            # sure that we're using the provided generator
+            torch.manual_seed(2)
+
+            g = torch.Generator()
+            g.manual_seed(2023)
+            eager_out = init_fn(torch.ones(2, 2), generator=g)
+
+            self.assertEqual(script_out, eager_out)
+
+            FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
+
     def test_early_return_rewrite(self):
         def test_foo(x: bool):
             if x:
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 9241b0dfdf9e..4bfac1e9d459 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -94,6 +94,7 @@ def test_no_new_bindings(self):
             "Future",
             "FutureType",
             "Generator",
+            "GeneratorType",
             "get_autocast_cpu_dtype",
             "get_autocast_ipu_dtype",
             "get_default_dtype",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 6847955348ed..a12befa5cdc8 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1969,6 +1969,10 @@ class DeviceObjType(JitType):
     @staticmethod
     def get() -> DeviceObjType: ...
 
+class _GeneratorType(JitType):
+    @staticmethod
+    def get() -> _GeneratorType: ...
+
 class StreamObjType(JitType):
     @staticmethod
     def get() -> StreamObjType: ...
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 9dd35bb49e7b..645bc3259cd1 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -2331,6 +2331,7 @@ def uniform(
     x: Tensor,
     low: Union[bool, int, float] = 0.0,
     high: Union[bool, int, float] = 1.0,
+    generator: Optional[torch.Generator] = None,
 ):
     return prims._uniform_helper(
         x.shape,
@@ -2338,13 +2339,13 @@ def uniform(
         high=sym_float(high),
         dtype=x.dtype,
         device=x.device,
+        generator=generator,
     )
 
 
 @register_decomposition(aten.uniform_)
 def uniform_(self, low=0, high=1, generator=None):
-    assert generator is None
-    return self.copy_(uniform(self, low, high))
+    return self.copy_(uniform(self, low, high, generator))
 
 
 # aten/src/ATen/native/UpSample.cpp compute_output_size
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index dcb9b24f1b2e..e25aec08a266 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -2765,8 +2765,6 @@ def _svd_aten(
 #
 
 
-# TODO: add generator support
-# NOTE: there is currently no way of acquiring the "default" torch generator
 def _normal_meta(
     shape: ShapeType,
     *,
@@ -2775,6 +2773,7 @@ def _normal_meta(
     dtype: torch.dtype,
     device: torch.device,
     requires_grad: bool,
+    generator: Optional[torch.Generator] = None,
 ) -> TensorLikeType:
     torch._check(
         std >= 0.0,
@@ -2798,11 +2797,12 @@ def _normal_aten(
     dtype: torch.dtype,
     device: torch.device,
     requires_grad: bool,
+    generator: Optional[torch.Generator] = None,
 ) -> Tensor:
     a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
     with torch.no_grad():
         # NOTE: normal_ is incorrectly annotated to expect mean to be a float
-        a.normal_(mean, std)  # type: ignore[arg-type]
+        a.normal_(mean, std, generator=generator)  # type: ignore[arg-type]
     return a
 
 
@@ -2815,7 +2815,7 @@ def _normal_aten(
 
 normal = _make_prim(
     schema=(
-        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad) -> Tensor"
+        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor"  # noqa: B950
     ),
     return_type=RETURN_TYPE.NEW,
     meta=_normal_meta,
@@ -2831,6 +2831,7 @@ def _uniform_meta(
     high: float,
     dtype: torch.dtype,
     device: torch.device,
+    generator: Optional[torch.Generator] = None,
 ) -> TensorLikeType:
     strides = utils.make_contiguous_strides_for(shape)
     return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
@@ -2843,9 +2844,10 @@ def _uniform_aten(
     high: float,
     dtype: torch.dtype,
     device: torch.device,
+    generator: Optional[torch.Generator] = None,
 ) -> Tensor:
     a = torch.empty(shape, dtype=dtype, device=device)
-    a.uniform_(low, high)
+    a.uniform_(low, high, generator=generator)
     return a
 
 
@@ -2856,7 +2858,7 @@ def _uniform_aten(
 # TODO: we should more seriously review randomness modeling and prims
 _uniform_helper = _make_prim(
     schema=(
-        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor"
+        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
     ),
     return_type=RETURN_TYPE.NEW,
     meta=_uniform_meta,
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 66da794183e4..450c198067aa 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -5900,7 +5900,6 @@ def normal(
     device=None,
     pin_memory=None,
 ):
-    assert generator is None
     assert layout is None or layout == torch.strided
 
     if not isinstance(std, TensorLike):
@@ -5937,6 +5936,7 @@ def normal(
         dtype=dtype,
         device=device,
         requires_grad=False,
+        generator=generator,
     )
     return std * normal_samples + mean
 
diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp
index 307cd04e9fa7..3095d0cc01b4 100644
--- a/torch/csrc/jit/frontend/sugared_value.cpp
+++ b/torch/csrc/jit/frontend/sugared_value.cpp
@@ -241,6 +241,17 @@ std::shared_ptr SimpleValue::attr(
     return SpecialFormValue::create(aten::index);
   }
 
+  if (auto generator_type = value_->type()->cast()) {
+    // Handle access to Generator's `manual_seed`, `initial_seed` and `seed`
+    // attributes.
+    if (field == "manual_seed" || field == "initial_seed" || field == "seed") {
+      if (auto builtin = BuiltinFunction::tryCreate(
+              Symbol::aten(field), NamedValue(loc, "self", value_))) {
+        return builtin;
+      }
+    }
+  }
+
   ErrorReport report(loc);
   report << "'" << value_->type()->repr_str()
          << "' object has no attribute or method '" << field << "'.";
diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp
index c34e044e7db8..823b27f30fcb 100644
--- a/torch/csrc/jit/frontend/tracer.cpp
+++ b/torch/csrc/jit/frontend/tracer.cpp
@@ -679,12 +679,14 @@ void addInputs(
     Node* n,
     const char* name,
     const c10::optional& value) {
+  Graph* g = n->owningGraph();
+
   if (value.has_value() && value->defined()) {
-    detail::badArgType(*value);
+    detail::genericAddInput(n, *value);
+  } else {
+    Value* undef_gen = g->insertNode(g->createNone())->output();
+    n->addInput(undef_gen);
   }
-  Graph* g = n->owningGraph();
-  Value* undef_gen = g->insertNode(g->createNone())->output();
-  n->addInput(undef_gen);
 }
 void addInputs(Node* n, const char* name, at::Device value) {
   detail::genericAddInput(n, value);
diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp
index ae639df7a301..905088a20d1e 100644
--- a/torch/csrc/jit/ir/constants.cpp
+++ b/torch/csrc/jit/ir/constants.cpp
@@ -5,6 +5,7 @@
 #include 
 #include 
 #include 
+#include 
 
 namespace torch::jit {
 
@@ -108,6 +109,10 @@ c10::optional tryInsertConstant(
     ss << val.toDevice();
     n->s_(attr::value, ss.str());
     n->output()->setType(DeviceObjType::get());
+  } else if (val.isGenerator()) {
+    auto generator = val.toGenerator();
+    n->ival_(attr::value, generator);
+    n->output()->setType(GeneratorType::get());
   } else if (val.isStream()) {
     // packing into int64_t removed
     n->ival_(attr::value, val);
@@ -194,6 +199,9 @@ c10::optional toIValue(const Value* v) {
   } else if (type == DeviceObjType::get()) {
     auto d = c10::Device(node->s(attr::value));
     return d;
+  } else if (type == GeneratorType::get()) {
+    auto generator = node->ival(attr::value).toGenerator();
+    return generator;
   } else if (type == StreamObjType::get()) {
     // int64_t packing removed
     auto s = node->ival(attr::value).toStream();
diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp
index 56f70c36bbf1..372ca2fcf497 100644
--- a/torch/csrc/jit/ir/node_hashing.cpp
+++ b/torch/csrc/jit/ir/node_hashing.cpp
@@ -142,6 +142,9 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) {
   if (a1.isObject()) {
     return &a1.toObjectRef() == &a2.toObjectRef();
   }
+  if (a1.isGenerator()) {
+    return a1.toGenerator() == a2.toGenerator();
+  }
   TORCH_INTERNAL_ASSERT(false);
 }
 
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 60b71faa0347..e145ee09290b 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -397,6 +397,8 @@ inline InferredType tryToInferType(py::handle input) {
     return InferredType(IntType::get());
   } else if (THPDevice_Check(input.ptr())) {
     return InferredType(DeviceObjType::get());
+  } else if (THPGenerator_Check(input.ptr())) {
+    return InferredType(GeneratorType::get());
   } else if (THPStream_Check(input.ptr())) {
     return InferredType(StreamObjType::get());
   } else if (THPDtype_Check(input.ptr())) {
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index 46ac9f78ddc0..71e6a2fefa13 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -1012,6 +1012,10 @@ void initPythonIRBindings(PyObject* module_) {
       .def_static("get", &StringType::get);
   py::class_(m, "DeviceObjType")
       .def_static("get", &DeviceObjType::get);
+  // TODO(antoniojkim): Add GeneratorType to the public API once its been added
+  //                    to the public documentation
+  py::class_(m, "_GeneratorType")
+      .def_static("get", &GeneratorType::get);
   py::class_(m, "StreamObjType")
       .def_static("get", &StreamObjType::get);
   py::class_(m, "PyObjectType")
diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp
index 397b83d588ef..6d91f28328be 100644
--- a/torch/csrc/jit/runtime/register_ops_utils.cpp
+++ b/torch/csrc/jit/runtime/register_ops_utils.cpp
@@ -1,3 +1,12 @@
+#include 
+// TODO(antoniojkim): Add CUDA support for make_generator_for_device
+// #ifdef USE_CUDA
+// #include 
+// #endif
+#ifdef USE_MPS
+#include 
+#endif
+
 #include 
 #include 
 #include 
@@ -392,4 +401,39 @@ void listSetItem(Stack& stack) {
 
   push(stack, std::move(list));
 }
+
+at::Generator make_generator_for_device(
+    c10::Device device,
+    c10::optional seed) {
+  if (device.is_cpu()) {
+    if (seed.has_value()) {
+      return at::detail::createCPUGenerator(seed.value());
+    } else {
+      return at::detail::createCPUGenerator();
+    }
+// TODO(antoniojkim): Enable support for CUDA device
+//                    Implementation below causes issues during rocm build
+// #ifdef USE_CUDA
+//   } else if (device.is_cuda()) {
+//     auto generator = at::cuda::detail::createCUDAGenerator(device.index());
+//     if (seed.has_value()) {
+//       generator.set_current_seed(seed.value());
+//     }
+//     return generator;
+// #endif
+#ifdef USE_MPS
+  } else if (device.is_mps()) {
+    if (seed.has_value()) {
+      return at::mps::detail::createMPSGenerator(seed.value());
+    } else {
+      return at::mps::detail::createMPSGenerator();
+    }
+#endif
+  } else {
+    AT_ERROR(
+        "Unsupported device for at::make_generator_for_device found: ",
+        device.str());
+  }
+}
+
 } // namespace torch::jit
diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h
index 25d30b432e60..2a269931afbc 100644
--- a/torch/csrc/jit/runtime/register_ops_utils.h
+++ b/torch/csrc/jit/runtime/register_ops_utils.h
@@ -26,7 +26,9 @@
 #include 
 #include 
 #include 
+#include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -876,4 +878,8 @@ struct OperatorGeneratorArgs {
           aten_op, op, op, op, bool),                                    \
       DEFINE_STR_CMP_OP(aten_op, op)
 
+TORCH_API at::Generator make_generator_for_device(
+    c10::Device device,
+    c10::optional seed = c10::nullopt);
+
 } // namespace torch::jit
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 4267332851e2..4d8a0cd89d8f 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1,4 +1,5 @@
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -2492,6 +2493,44 @@ static const std::vector opGenArgs1{
         TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"),
         [](Stack& stack) { at::manual_seed(pop(stack).toInt()); },
         aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::Generator(*, Device? device=None, int? seed=None) -> Generator"),
+        [](Stack& stack) {
+          auto seed = pop(stack).toOptional();
+          auto device = pop(stack).toOptional();
+          push(
+              stack,
+              torch::jit::make_generator_for_device(
+                  device.value_or(c10::Device("cpu")), seed));
+        },
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::initial_seed(Generator self) -> int"),
+        [](Stack& stack) {
+          auto generator = pop(stack);
+          auto current_seed = generator.toGenerator().current_seed();
+          push(stack, (int64_t)current_seed);
+        },
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::manual_seed.generator(Generator(a!) self, int seed) -> Generator(a!)"),
+        [](Stack& stack) {
+          auto seed = pop(stack).toInt();
+          auto generator = pop(stack);
+          generator.toGenerator().set_current_seed(seed);
+          push(stack, generator);
+        },
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::seed(Generator(a!) self) -> int"),
+        [](Stack& stack) {
+          auto generator = pop(stack);
+          auto current_seed = generator.toGenerator().seed();
+          push(stack, (int64_t)current_seed);
+        },
+        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"),
         [](Stack& stack) {
diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp
index 4a9cc5695bbd..65de32eaa668 100644
--- a/torch/csrc/jit/runtime/register_special_ops.cpp
+++ b/torch/csrc/jit/runtime/register_special_ops.cpp
@@ -393,7 +393,7 @@ RegisterOperators reg({
         aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA(
-            "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"),
+            "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b, Generator? generator=None) -> Tensor(a!)"),
         [](Stack& stack) {
           // TODO: remove when script supports setting grad mode
           torch::NoGradGuard no_grad;
@@ -403,13 +403,16 @@ RegisterOperators reg({
           double a;
           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
           double b;
+          c10::optional generator =
+              pop(stack).toOptional();
+
           pop(stack, tensor, a, b);
-          push(stack, tensor.uniform_(a, b));
+          push(stack, tensor.uniform_(a, b, generator));
         },
         aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA(
-            "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std) -> Tensor(a!)"),
+            "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std, Generator? generator=None) -> Tensor(a!)"),
         [](Stack& stack) {
           // TODO: remove when script supports setting grad mode
           torch::NoGradGuard no_grad;
@@ -419,8 +422,11 @@ RegisterOperators reg({
           double mean;
           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
           double std;
+          c10::optional generator =
+              pop(stack).toOptional();
+
           pop(stack, tensor, mean, std);
-          push(stack, tensor.normal_(mean, std));
+          push(stack, tensor.normal_(mean, std, generator));
         },
         aliasAnalysisFromSchema()),
     OperatorGenerator(
diff --git a/torch/csrc/lazy/core/hash.h b/torch/csrc/lazy/core/hash.h
index 4e08213fcb20..bb6a779555f2 100644
--- a/torch/csrc/lazy/core/hash.h
+++ b/torch/csrc/lazy/core/hash.h
@@ -148,6 +148,11 @@ static inline hash_t Hash(const std::string& value) {
 static inline hash_t Hash(const c10::string_view& value) {
   return DataHash(value.data(), value.size());
 }
+
+static inline hash_t Hash(const at::Generator& value) {
+  return TensorHash(value.get_state());
+}
+
 // Taken from glibc's implementation of hashing optionals,
 // we want to include a contribution to the hash to distinguish
 // cases where one or another option was null, but we hope it doesn't
diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp
index b738cbbaada3..39f26d8043de 100644
--- a/torch/csrc/lazy/core/shape_inference.cpp
+++ b/torch/csrc/lazy/core/shape_inference.cpp
@@ -1368,6 +1368,22 @@ std::vector compute_shape_as_strided_scatter_symint(
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
 
+std::vector compute_shape_normal_functional(
+    const at::Tensor& self,
+    double mean,
+    double std,
+    c10::optional generator) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
+std::vector compute_shape_uniform(
+    const at::Tensor& self,
+    double from,
+    double to,
+    c10::optional generator) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
 // Restore unused-parameters warnings
 #pragma GCC diagnostic pop
 
diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h
index e243798cfc77..a8388a0b2235 100644
--- a/torch/csrc/lazy/core/shape_inference.h
+++ b/torch/csrc/lazy/core/shape_inference.h
@@ -70,6 +70,7 @@ TORCH_API std::vector compute_shape_new_empty_strided(const
 TORCH_API std::vector compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
 TORCH_API std::vector compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction, int64_t ignore_index);
 TORCH_API std::vector compute_shape_nonzero(const at::Tensor & self);
+TORCH_API std::vector compute_shape_normal_functional(const at::Tensor & self, double mean, double std, c10::optional generator);
 TORCH_API std::vector compute_shape_random(const at::Tensor & self, c10::optional generator);
 TORCH_API std::vector compute_shape_random(const at::Tensor & self, int64_t to, c10::optional generator);
 TORCH_API std::vector compute_shape_random(const at::Tensor & self, int64_t from, c10::optional to, c10::optional generator);
@@ -91,6 +92,7 @@ TORCH_API std::vector compute_shape_narrow_copy_symint(const
 TORCH_API std::vector compute_shape_hardswish(const at::Tensor & self);
 TORCH_API std::vector compute_shape_hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self);
 TORCH_API std::vector compute_shape_selu(const at::Tensor & self);
+TORCH_API std::vector compute_shape_uniform(const at::Tensor & self, double from, double to, c10::optional generator);
 
 // Non-Native ops
 TORCH_API std::vector compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp
index af4afb78a4fd..ca996c43fb00 100644
--- a/torch/csrc/lazy/python/init.cpp
+++ b/torch/csrc/lazy/python/init.cpp
@@ -307,6 +307,21 @@ void initLazyBindings(PyObject* module) {
 #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
         return result;
       });
+  lazy_ts_backend.def("_get_latest_computation_graph", []() {
+#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
+    auto computation = LazyGraphExecutor::Get()
+                           ->GetComputationCache()
+                           ->GetLatest()
+                           ->computation;
+    auto ts_computation = dynamic_cast(computation.get());
+    TORCH_CHECK(ts_computation, "Found non-TSComputation in cache");
+    return ts_computation->graph()->toString();
+#else
+    TORCH_CHECK(
+        false, "TorchScript backend not yet supported in FBCODE builds");
+    return "";
+#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
+  });
 
   // GetPythonFramesFunction() has not ever worked with torchdeploy/multipy
   // possibly becuase GetPythonFrames resolves to external cpython rather
diff --git a/torch/csrc/lazy/ts_backend/ops/random_ops.cpp b/torch/csrc/lazy/ts_backend/ops/random_ops.cpp
deleted file mode 100644
index 7c2e1f4386c9..000000000000
--- a/torch/csrc/lazy/ts_backend/ops/random_ops.cpp
+++ /dev/null
@@ -1,47 +0,0 @@
-#include 
-#include 
-
-namespace torch {
-namespace lazy {
-
-Normal::Normal(
-    const torch::lazy::Value& self,
-    const double& mean,
-    const double& std,
-    std::vector&& shapes)
-    : torch::lazy::TsNode(
-          ClassOpKind(),
-          {self},
-          std::move(shapes),
-          /* num_outputs */ 1,
-          torch::lazy::MHash(mean, std)),
-      mean_(mean),
-      std_(std) {}
-
-std::string Normal::ToString() const {
-  std::stringstream ss;
-  ss << TsNode::ToString();
-  ss << ", mean=" << mean_;
-  ss << ", std=" << std_;
-  return ss.str();
-}
-
-torch::lazy::TSOpVector Normal::Lower(
-    std::shared_ptr function,
-    torch::lazy::TSLoweringContext* loctx) const {
-  std::vector arguments;
-  std::vector kwarguments;
-  arguments.reserve(3);
-  size_t i = 0;
-  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
-  arguments.emplace_back("mean", mean_);
-  arguments.emplace_back("std", std_);
-  torch::lazy::TSOpVector normal__out =
-      torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
-  TORCH_CHECK_EQ(normal__out.size(), 1);
-
-  return normal__out;
-}
-
-} // namespace lazy
-} // namespace torch
diff --git a/torch/csrc/lazy/ts_backend/ops/random_ops.h b/torch/csrc/lazy/ts_backend/ops/random_ops.h
deleted file mode 100644
index eb095a6a9542..000000000000
--- a/torch/csrc/lazy/ts_backend/ops/random_ops.h
+++ /dev/null
@@ -1,30 +0,0 @@
-#pragma once
-
-#include 
-
-namespace torch {
-namespace lazy {
-
-class Normal : public torch::lazy::TsNode {
- public:
-  static OpKind ClassOpKind() {
-    return OpKind::Get("aten::normal_");
-  }
-
-  Normal(
-      const torch::lazy::Value& self,
-      const double& mean,
-      const double& std,
-      std::vector&& shapes);
-
-  std::string ToString() const override;
-  torch::lazy::TSOpVector Lower(
-      std::shared_ptr function,
-      torch::lazy::TSLoweringContext* loctx) const override;
-
-  double mean_;
-  double std_;
-};
-
-} // namespace lazy
-} // namespace torch
diff --git a/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp b/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp
index 8970e5354a7f..f7da32b698c7 100644
--- a/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp
+++ b/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp
@@ -13,7 +13,6 @@
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 
diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
index e3315f9b9f5d..d01c3fc1e616 100644
--- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
+++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
@@ -14,7 +14,6 @@
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -372,36 +371,6 @@ at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
           indices);
 }
 
-at::Tensor& LazyNativeFunctions::normal_(
-    at::Tensor& self,
-    double mean,
-    double std,
-    c10::optional generator) {
-  // Unconditionally fall back.
-  // implementing normal_ via lazy tensor caused differences in results compared
-  // to eager.
-  return at::native::call_fallback_fn<<c_eager_fallback, ATEN_OP(normal_)>::
-      call(self, mean, std, generator);
-
-  // if (force_eager_fallback(c10::Symbol::fromQualString("aten::normal_"))) {
-  //   return at::native::call_fallback_fn<<c_eager_fallback,
-  //   ATEN_OP(normal_)>::call(self, mean, std, generator);
-  // }
-
-  // if (generator.has_value()) {
-  //   return at::native::call_fallback_fn<<c_eager_fallback,
-  //   ATEN_OP(normal_)>::call(self, mean, std, generator);
-  // }
-
-  // TORCH_LAZY_FN_COUNTER("lazy::");
-  // auto device = bridge::GetBackendDevice(self);
-  // LazyTensor lazy_self = GetLtcTensorOrCreateForWrappedNumber(self, *device);
-  // std::vector shapes =
-  // {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())}; auto node =
-  // torch::lazy::MakeNode(lazy_self.GetIrValue(), mean, std,
-  // std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self;
-};
-
 at::Tensor LazyNativeFunctions::_unsafe_view(
     const at::Tensor& self,
     at::IntArrayRef size) {
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index fffb650e6347..804475b35e1d 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -13,6 +13,7 @@
 import torch
 
 from torch._C import (
+    _GeneratorType,
     AnyType,
     AwaitType,
     BoolType,
@@ -479,6 +480,8 @@ def try_ann_to_type(ann, loc, rcb=None):
         return InterfaceType(ann.__torch_script_interface__)
     if ann is torch.device:
         return DeviceObjType.get()
+    if ann is torch.Generator:
+        return _GeneratorType.get()
     if ann is torch.Stream:
         return StreamObjType.get()
     if ann is torch.dtype:
diff --git a/torch/nn/init.py b/torch/nn/init.py
index ad99ddf8f769..426069d780c0 100644
--- a/torch/nn/init.py
+++ b/torch/nn/init.py
@@ -10,14 +10,14 @@
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
 # managers, so these need to be implemented as builtins. Using these wrappers
 # lets us keep those builtins small and re-usable.
-def _no_grad_uniform_(tensor, a, b):
+def _no_grad_uniform_(tensor, a, b, generator=None):
     with torch.no_grad():
-        return tensor.uniform_(a, b)
+        return tensor.uniform_(a, b, generator=generator)
 
 
-def _no_grad_normal_(tensor, mean, std):
+def _no_grad_normal_(tensor, mean, std, generator=None):
     with torch.no_grad():
-        return tensor.normal_(mean, std)
+        return tensor.normal_(mean, std, generator=generator)
 
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
@@ -121,7 +121,12 @@ def calculate_gain(nonlinearity, param=None):
         raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
 
 
-def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
+def uniform_(
+    tensor: Tensor,
+    a: float = 0.0,
+    b: float = 1.0,
+    generator: _Optional[torch.Generator] = None,
+) -> Tensor:
     r"""Fill the input Tensor with values drawn from the uniform distribution.
 
     :math:`\mathcal{U}(a, b)`.
@@ -130,17 +135,25 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         tensor: an n-dimensional `torch.Tensor`
         a: the lower bound of the uniform distribution
         b: the upper bound of the uniform distribution
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
     if torch.overrides.has_torch_function_variadic(tensor):
-        return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
-    return _no_grad_uniform_(tensor, a, b)
+        return torch.overrides.handle_torch_function(
+            uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
+        )
+    return _no_grad_uniform_(tensor, a, b, generator)
 
 
-def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
+def normal_(
+    tensor: Tensor,
+    mean: float = 0.0,
+    std: float = 1.0,
+    generator: _Optional[torch.Generator] = None,
+) -> Tensor:
     r"""Fill the input Tensor with values drawn from the normal distribution.
 
     :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
@@ -149,14 +162,17 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
         std: the standard deviation of the normal distribution
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
         >>> nn.init.normal_(w)
     """
     if torch.overrides.has_torch_function_variadic(tensor):
-        return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
-    return _no_grad_normal_(tensor, mean, std)
+        return torch.overrides.handle_torch_function(
+            normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
+        )
+    return _no_grad_normal_(tensor, mean, std, generator)
 
 def trunc_normal_(
     tensor: Tensor,
@@ -180,6 +196,7 @@ def trunc_normal_(
         std: the standard deviation of the normal distribution
         a: the minimum cutoff value
         b: the maximum cutoff value
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
@@ -314,7 +331,9 @@ def _calculate_fan_in_and_fan_out(tensor):
     return fan_in, fan_out
 
 
-def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
+def xavier_uniform_(
+    tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
+) -> Tensor:
     r"""Fill the input `Tensor` with values using a Xavier uniform distribution.
 
     The method is described in `Understanding the difficulty of training
@@ -330,6 +349,7 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
     Args:
         tensor: an n-dimensional `torch.Tensor`
         gain: an optional scaling factor
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
@@ -339,10 +359,14 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
     std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
     a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
 
-    return _no_grad_uniform_(tensor, -a, a)
+    return _no_grad_uniform_(tensor, -a, a, generator)
 
 
-def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
+def xavier_normal_(
+    tensor: Tensor,
+    gain: float = 1.0,
+    generator: _Optional[torch.Generator] = None,
+) -> Tensor:
     r"""Fill the input `Tensor` with values using a Xavier normal distribution.
 
     The method is described in `Understanding the difficulty of training deep feedforward
@@ -357,6 +381,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
     Args:
         tensor: an n-dimensional `torch.Tensor`
         gain: an optional scaling factor
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
@@ -365,7 +390,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
     std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
 
-    return _no_grad_normal_(tensor, 0., std)
+    return _no_grad_normal_(tensor, 0., std, generator)
 
 
 def _calculate_correct_fan(tensor, mode):
@@ -379,7 +404,11 @@ def _calculate_correct_fan(tensor, mode):
 
 
 def kaiming_uniform_(
-    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+    tensor: Tensor,
+    a: float = 0,
+    mode: str = "fan_in",
+    nonlinearity: str = "leaky_relu",
+    generator: _Optional[torch.Generator] = None,
 ):
     r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.
 
@@ -403,6 +432,7 @@ def kaiming_uniform_(
             backwards pass.
         nonlinearity: the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
@@ -415,7 +445,8 @@ def kaiming_uniform_(
             tensor=tensor,
             a=a,
             mode=mode,
-            nonlinearity=nonlinearity)
+            nonlinearity=nonlinearity,
+            generator=generator)
 
     if 0 in tensor.shape:
         warnings.warn("Initializing zero-element tensors is a no-op")
@@ -425,11 +456,15 @@ def kaiming_uniform_(
     std = gain / math.sqrt(fan)
     bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
     with torch.no_grad():
-        return tensor.uniform_(-bound, bound)
+        return tensor.uniform_(-bound, bound, generator=generator)
 
 
 def kaiming_normal_(
-    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+    tensor: Tensor,
+    a: float = 0,
+    mode: str = "fan_in",
+    nonlinearity: str = "leaky_relu",
+    generator: _Optional[torch.Generator] = None,
 ):
     r"""Fill the input `Tensor` with values using a Kaiming normal distribution.
 
@@ -453,6 +488,7 @@ def kaiming_normal_(
             backwards pass.
         nonlinearity: the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
@@ -465,10 +501,14 @@ def kaiming_normal_(
     gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
     with torch.no_grad():
-        return tensor.normal_(0, std)
+        return tensor.normal_(0, std, generator=generator)
 
 
-def orthogonal_(tensor, gain=1):
+def orthogonal_(
+    tensor,
+    gain=1,
+    generator: _Optional[torch.Generator] = None,
+):
     r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
 
     Described in `Exact solutions to the nonlinear dynamics of learning in deep
@@ -479,6 +519,7 @@ def orthogonal_(tensor, gain=1):
     Args:
         tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
         gain: optional scaling factor
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
@@ -493,7 +534,7 @@ def orthogonal_(tensor, gain=1):
         return tensor
     rows = tensor.size(0)
     cols = tensor.numel() // rows
-    flattened = tensor.new(rows, cols).normal_(0, 1)
+    flattened = tensor.new(rows, cols).normal_(0, 1, generator=generator)
 
     if rows < cols:
         flattened.t_()
@@ -514,7 +555,12 @@ def orthogonal_(tensor, gain=1):
     return tensor
 
 
-def sparse_(tensor, sparsity, std=0.01):
+def sparse_(
+    tensor,
+    sparsity,
+    std=0.01,
+    generator: _Optional[torch.Generator] = None,
+):
     r"""Fill the 2D input `Tensor` as a sparse matrix.
 
     The non-zero elements will be drawn from the normal distribution
@@ -526,6 +572,7 @@ def sparse_(tensor, sparsity, std=0.01):
         sparsity: The fraction of elements in each column to be set to zero
         std: the standard deviation of the normal distribution used to generate
             the non-zero values
+        generator: the torch Generator to sample from (default: None)
 
     Examples:
         >>> w = torch.empty(3, 5)
@@ -538,7 +585,7 @@ def sparse_(tensor, sparsity, std=0.01):
     num_zeros = int(math.ceil(sparsity * rows))
 
     with torch.no_grad():
-        tensor.normal_(0, std)
+        tensor.normal_(0, std, generator=generator)
         for col_idx in range(cols):
             row_indices = torch.randperm(rows)
             zero_indices = row_indices[:num_zeros]
diff --git a/torch/overrides.py b/torch/overrides.py
index 1767d43dd7a4..3084bf066825 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -927,10 +927,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
                                                                 distance_function=None, margin=1.0,
                                                                 swap=False, reduction='mean': -1),
         torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
-        torch.nn.init.uniform_: lambda tensor, a=0., b=1.: -1,
-        torch.nn.init.normal_: lambda tensor, mean=0., std=1.: -1,
+        torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
+        torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
         torch.nn.init.constant_: lambda tensor, val: -1,
-        torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu': -1,
+        torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
         torch.nonzero: lambda input, as_tuple=False: -1,
         torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
         torch.argwhere: lambda input: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2629c481c495..4e42a4497162 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9559,7 +9559,14 @@ def wrapper_set_seed(op, *args, **kwargs):
     """
     with freeze_rng_state():
         torch.manual_seed(42)
-        return op(*args, **kwargs)
+        output = op(*args, **kwargs)
+
+        if isinstance(output, torch.Tensor) and output.device.type == "lazy":
+            # We need to call mark step inside freeze_rng_state so that numerics
+            # match eager execution
+            torch._lazy.mark_step()
+
+        return output
 
 
 def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
@@ -18095,8 +18102,14 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
             DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
             # Lazy tensor failures
             DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
-            DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
-            DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
+            # These tests fail only when built with ASAN
+            DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN),
+            DecorateInfo(
+                unittest.skip("Fails with ASAN"),
+                'TestLazyOpInfo',
+                'test_correctness_with_reusing_ir',
+                active_if=TEST_WITH_ASAN
+            ),
         ),
     ),
     OpInfo(
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 30f0311ba7b3..49b8db0f8081 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1331,6 +1331,28 @@ def wrapper(*args, **kwargs):
 def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
     return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
 
+def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert(isinstance(fn, type))
+        if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = msg
+
+        return fn
+
+
+    return decorator
+
+
 # Run PyTorch tests with translation validation on.
 TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'
 
diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py
index 90dc93cfa969..8fdd2ddcfa7a 100644
--- a/torchgen/api/lazy.py
+++ b/torchgen/api/lazy.py
@@ -7,6 +7,7 @@
     CType,
     deviceT,
     doubleT,
+    generatorT,
     layoutT,
     ListCType,
     longT,
@@ -109,6 +110,8 @@ def process_ir_type(
             return BaseCType(stringT)
         elif typ.name == BaseTy.Device:
             return BaseCType(deviceT)
+        elif typ.name == BaseTy.Generator:
+            return BaseCType(generatorT)
         elif typ.name == BaseTy.Layout:
             return BaseCType(layoutT)
         elif typ.name == BaseTy.MemoryFormat:
@@ -218,16 +221,7 @@ def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: boo
         self.symint = symint
         self.is_optional = isinstance(arg.type, OptionalType)
         self.is_generator = isGeneratorType(arg.type)
-        if self.is_generator:
-            assert (
-                self.is_optional
-            ), "We expect all generators are optional since currently they are"
-            # there is no handling for generators in TorchScript IR (or XLA)
-            # so we fall back to eager if the (optional)generator has value, and otherwise
-            # its null and safe to exclude from lazy IR
-            self.lazy_type_ = None
-        else:
-            self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
+        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
         self.is_wrapped_scalar = isWrappedScalarType(arg.type)
         self.is_symint_or_list = symint and (
             isSymIntType(arg.type)
@@ -236,9 +230,7 @@ def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: boo
             # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
         )
 
-        self.is_lazy_value = not self.is_generator and isValueType(
-            self.lazy_type, properties
-        )
+        self.is_lazy_value = isValueType(self.lazy_type, properties)
 
     @property
     def lazy_type(self) -> CType:
@@ -419,7 +411,7 @@ def filtered_args(
         keyword: bool = True,
         values: bool = True,
         scalars: bool = True,
-        generator: bool = False,
+        generator: bool = True,
     ) -> List[LazyArgument]:
         # This function maintains the sorted order of arguments but provides different filtered views.
         # Some parts of the code care about kwargs vs args (TS lowerings),
diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py
index 5476bdd32a40..43cde1e04043 100644
--- a/torchgen/dest/lazy_ir.py
+++ b/torchgen/dest/lazy_ir.py
@@ -122,12 +122,8 @@ def gen_fallback_code(
         aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
     else:
         aten_op_str = f"ATEN_OP({schema.aten_name})"
-    or_has_generator = ""
-    if schema.generator_arg:
-        # generators are always optional and there is never more than one, at least currently
-        or_has_generator = f" || ({schema.generator_arg.name}.has_value() && {schema.generator_arg.name}->defined())"
     return f"""
-        if (force_eager_fallback({aten_symbol(schema)}){or_has_generator}) {{
+        if (force_eager_fallback({aten_symbol(schema)})) {{
             return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call(
                 {fallback_args}
             );
@@ -290,9 +286,12 @@ def gen(self, schema: LazyIrSchema) -> List[str]:
         members_to_string = []
         for arg in scalar_args:
             if isinstance(arg.lazy_type, OptionalCType):
+                value = f"{arg.name}.value()"
+                if arg.is_generator:
+                    value = '"torch.Generator()"'
                 members_to_string.append(
                     f"""if ({arg.name}.has_value()) {{
-      ss << ", {arg.name}=" << {arg.name}.value();
+      ss << ", {arg.name}=" << {value};
     }} else {{
       ss << ", {arg.name}=null";
     }}"""

From afdc5285203041277c8136d295a886482f761f44 Mon Sep 17 00:00:00 2001
From: Joel Schlosser 
Date: Tue, 21 Nov 2023 13:56:57 -0500
Subject: [PATCH 071/221] Print the index and summary of the SampleInput that
 failed an OpInfo test (#99444)

Related to the Reproducible Testing BE project. Goal is to print out the sample input that failed an OpInfo test.

Crazy idea: to avoid requiring widespread changes across tests that use OpInfo sample inputs, return a new special iterator type from `OpInfo.sample_inputs()`, etc. that tracks the most recent item seen. If a test fails later on, print out this info to identify the sample that failed the test.

This solves the problem that the test framework currently has no concept of which sample input is being operated on.

This PR contains the following changes:
* New `TrackedInputIter` that wraps a sample inputs func iterator and tracks the most recent input seen in a `TrackedInput` structure
    * The information is stored in a dictionary on the test function itself, mapping `full test ID -> most recent TrackedInput`
* To determine the test function that is being run, we do some stack crawling hackery in `extract_test_fn_and_id()`
* Above applies only when one of the following is called: `OpInfo.sample_inputs()`, `OpInfo.error_inputs()`, `OpInfo.reference_inputs()`, and `OpInfo.conjugate_sample_inputs()`. This could easily be extended to `ModuleInfo`s and the sparse sample input funcs as well

Example output when a sample input causes a failure:
```
======================================================================
ERROR: test_foo_add_cpu_uint8 (__main__.TestFakeTensorCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 911, in test_wrapper
    return test(*args, **kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 1097, in only_fn
    return fn(slf, *args, **kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/test/test_ops.py", line 2211, in test_foo
    self.fail('Example failure')
AssertionError: Example failure

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_utils.py", line 2436, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 414, in instantiated_test
    result = test(self, **param_kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 917, in test_wrapper
    raise Exception(
Exception: Caused by sample input at index 2: SampleInput(input=Tensor[size=(5, 1), device="cpu", dtype=torch.uint8], args=TensorList[Tensor[size=(5,), device="cpu", dtype=torch.uint8]], kwargs={}, broadcasts_input=True, name='')

To execute this test, run the following from the base repo dir:
     python test/test_ops.py -k test_foo_add_cpu_uint8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
```

This notably doesn't print the actual `SampleInput` values, as that's hard without fully reproducible random sample generation. I went down this path for a while and it seems infeasible without adding an untenable amount of overhead to set the random seed per SampleInput (see https://github.com/pytorch/pytorch/issues/86694#issuecomment-1614943708 for more details). For now, I am settling for at least spitting out the index and some metadata of the `SampleInput`, as it seems better than nothing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99444
Approved by: https://github.com/janeyx99
---
 test/test_testing.py                          |  8 +-
 torch/testing/_internal/common_device_type.py | 29 ++++++-
 torch/testing/_internal/common_utils.py       | 76 +++++++++++++++++++
 torch/testing/_internal/opinfo/core.py        | 32 ++++++--
 4 files changed, 131 insertions(+), 14 deletions(-)

diff --git a/test/test_testing.py b/test/test_testing.py
index feb408773f4c..542601d7ed97 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -12,7 +12,7 @@
 import subprocess
 import sys
 import unittest.mock
-from typing import Any, Callable, Iterator, List, Tuple, Generator
+from typing import Any, Callable, Iterator, List, Tuple
 
 import torch
 
@@ -2397,19 +2397,19 @@ class TestOpInfoSampleFunctions(TestCase):
     def test_opinfo_sample_generators(self, device, dtype, op):
         # Test op.sample_inputs doesn't generate multiple samples when called
         samples = op.sample_inputs(device, dtype)
-        self.assertIsInstance(samples, Generator)
+        self.assertIsInstance(samples, Iterator)
 
     @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
     def test_opinfo_reference_generators(self, device, dtype, op):
         # Test op.reference_inputs doesn't generate multiple samples when called
         samples = op.reference_inputs(device, dtype)
-        self.assertIsInstance(samples, Generator)
+        self.assertIsInstance(samples, Iterator)
 
     @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
     def test_opinfo_error_generators(self, device, op):
         # Test op.error_inputs doesn't generate multiple inputs when called
         samples = op.error_inputs(device)
-        self.assertIsInstance(samples, Generator)
+        self.assertIsInstance(samples, Iterator)
 
 
 instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 96b7817b5c4a..b5d1e769209b 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -15,7 +15,8 @@
     skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
     IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, \
     _TestParametrizer, compose_parametrize_fns, dtype_name, \
-    TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo
+    TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \
+    get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE
 from torch.testing._internal.common_cuda import _get_torch_cuda_version, \
     TEST_CUSPARSE_GENERIC, TEST_HIPSPARSE_GENERIC, _get_torch_rocm_version
 from torch.testing._internal.common_dtype import get_all_dtypes
@@ -796,6 +797,12 @@ class OpDTypes(Enum):
     torch.bool
 )
 
+def _serialize_sample(sample_input):
+    # NB: For OpInfos, SampleInput.summary() prints in a cleaner way.
+    if getattr(sample_input, "summary", None) is not None:
+        return sample_input.summary()
+    return str(sample_input)
+
 # Decorator that defines the OpInfos a test template should be instantiated for.
 #
 # Example usage:
@@ -905,7 +912,25 @@ def _parametrize_test(self, test, generic_cls, device_cls):
                 try:
                     @wraps(test)
                     def test_wrapper(*args, **kwargs):
-                        return test(*args, **kwargs)
+                        try:
+                            return test(*args, **kwargs)
+                        except unittest.SkipTest as e:
+                            raise e
+                        except Exception as e:
+                            tracked_input = get_tracked_input()
+                            if PRINT_REPRO_ON_FAILURE and tracked_input is not None:
+                                raise Exception(
+                                    f"Caused by {tracked_input.type_desc} "
+                                    f"at index {tracked_input.index}: "
+                                    f"{_serialize_sample(tracked_input.val)}") from e
+                            raise e
+                        finally:
+                            clear_tracked_input()
+
+                    # Initialize info for the last input seen. This is useful for tracking
+                    # down which inputs caused a test failure. Note that TrackedInputIter is
+                    # responsible for managing this.
+                    test.tracked_input = None
 
                     decorator_fn = partial(op.get_decorators, generic_cls.__name__,
                                            test.__name__, device_cls.device_type, dtype)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 49b8db0f8081..72a8dfa8c024 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -36,6 +36,7 @@
 from collections.abc import Mapping, Sequence
 from contextlib import closing, contextmanager
 from copy import deepcopy
+from dataclasses import dataclass
 from enum import Enum
 from functools import partial, wraps
 from itertools import product, chain
@@ -237,6 +238,81 @@ def wrapper(*args, **kwargs):
         fn(*args, **kwargs)
     return wrapper
 
+# Tries to extract the current test function by crawling the stack.
+# If unsuccessful, return None.
+def extract_test_fn() -> Optional[Callable]:
+    try:
+        stack = inspect.stack()
+        for frame_info in stack:
+            frame = frame_info.frame
+            if "self" not in frame.f_locals:
+                continue
+            self_val = frame.f_locals["self"]
+            if isinstance(self_val, unittest.TestCase):
+                test_id = self_val.id()
+                test_name = test_id.split('.')[2]
+                test_fn = getattr(self_val, test_name).__func__
+                return test_fn
+    except Exception:
+        pass
+    return None
+
+# Contains tracked input data useful for debugging purposes
+@dataclass
+class TrackedInput:
+    index: int
+    val: Any
+    type_desc: str
+
+# Attempt to pull out tracked input information from the test function.
+# A TrackedInputIter is used to insert this information.
+def get_tracked_input() -> Optional[TrackedInput]:
+    test_fn = extract_test_fn()
+    if test_fn is None:
+        return None
+    if not hasattr(test_fn, "tracked_input"):
+        return None
+    return test_fn.tracked_input
+
+def clear_tracked_input():
+    test_fn = extract_test_fn()
+    if test_fn is None:
+        return
+    if not hasattr(test_fn, "tracked_input"):
+        return None
+    test_fn.tracked_input = None
+
+# Wraps an iterator and tracks the most recent value the iterator produces
+# for debugging purposes. Tracked values are stored on the test function.
+class TrackedInputIter:
+    def __init__(self, child_iter, input_type_desc, callback=lambda x: x):
+        self.child_iter = enumerate(child_iter)
+        # Input type describes the things we're tracking (e.g. "sample input", "error input").
+        self.input_type_desc = input_type_desc
+        # Callback is run on each iterated thing to get the thing to track.
+        self.callback = callback
+        self.test_fn = extract_test_fn()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        # allow StopIteration to bubble up
+        input_idx, input_val = next(self.child_iter)
+        self._set_tracked_input(
+            TrackedInput(
+                index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc
+            )
+        )
+        return input_val
+
+    def _set_tracked_input(self, tracked_input: TrackedInput):
+        if self.test_fn is None:
+            return
+        if not hasattr(self.test_fn, "tracked_input"):
+            return
+        self.test_fn.tracked_input = tracked_input
+
 class _TestParametrizer:
     """
     Decorator class for parametrizing a test function, yielding a set of new tests spawned
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index fc0fbf95864f..23b6e89e4a21 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -29,6 +29,7 @@
     noncontiguous_like,
     TEST_WITH_ROCM,
     torch_to_numpy_dtype_dict,
+    TrackedInputIter,
 )
 from torch.testing._internal.opinfo import utils
 
@@ -207,7 +208,6 @@ def _repr_helper(self, formatter):
             f"input={formatter(self.input)}",
             f"args={formatter(self.args)}",
             f"kwargs={formatter(self.kwargs)}",
-            f"output_process_fn_grad={self.output_process_fn_grad}",
             f"broadcasts_input={self.broadcasts_input}",
             f"name={repr(self.name)}",
         ]
@@ -227,8 +227,15 @@ def formatter(arg):
             # by Tensor[TensorShape]
             # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4]
             if isinstance(arg, torch.Tensor):
-                shape = str(tuple(arg.shape)).replace("(", "").replace(")", "")
-                return f"Tensor[{shape}]"
+                shape = str(tuple(arg.shape))
+                dtype = str(arg.dtype)
+                device = str(arg.device)
+                contiguity_suffix = ""
+                # NB: sparse CSR tensors annoyingly return is_sparse=False
+                is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr
+                if not is_sparse and not arg.is_contiguous():
+                    contiguity_suffix = ", contiguous=False"
+                return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]'
             elif isinstance(arg, dict):
                 return {k: formatter(v) for k, v in arg.items()}
             elif is_iterable_of_tensors(arg):
@@ -1155,7 +1162,7 @@ def conjugate(tensor):
             else:
                 sample.input[0] = conjugate(sample.input[0])
 
-        return tuple(conj_samples)
+        return TrackedInputIter(iter(conj_samples), "conjugate sample input")
 
     def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
         """
@@ -1174,7 +1181,7 @@ def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
             samples_list.extend(conj_samples)
             samples = tuple(samples_list)
 
-        return samples
+        return TrackedInputIter(iter(samples), "sample input")
 
     def reference_inputs(self, device, dtype, requires_grad=False, **kwargs):
         """
@@ -1185,18 +1192,27 @@ def reference_inputs(self, device, dtype, requires_grad=False, **kwargs):
         the sample inputs.
         """
         if self.reference_inputs_func is None:
-            return self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
+            samples = self.sample_inputs_func(
+                self, device, dtype, requires_grad, **kwargs
+            )
+            return TrackedInputIter(iter(samples), "sample input")
 
         if kwargs.get("include_conjugated_inputs", False):
             raise NotImplementedError
 
-        return self.reference_inputs_func(self, device, dtype, requires_grad, **kwargs)
+        references = self.reference_inputs_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+        return TrackedInputIter(iter(references), "reference input")
 
     def error_inputs(self, device, **kwargs):
         """
         Returns an iterable of ErrorInputs.
         """
-        return self.error_inputs_func(self, device, **kwargs)
+        errs = self.error_inputs_func(self, device, **kwargs)
+        return TrackedInputIter(
+            iter(errs), "error input", callback=lambda e: e.sample_input
+        )
 
     def error_inputs_sparse(self, device, layout, **kwargs):
         """

From 066ac56e0202ee6a9388eafd13867cbb046de03a Mon Sep 17 00:00:00 2001
From: Eli Uriegas 
Date: Tue, 21 Nov 2023 14:16:10 -0600
Subject: [PATCH 072/221] ci: Clean up logic for `merge -r` (#114295)

Rely on built in bash conditionals for doing the if statement rather
than relying on $?

To avoid issues observed in https://github.com/pytorch/pytorch/pull/111008#issuecomment-1821547141

Signed-off-by: Eli Uriegas 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114295
Approved by: https://github.com/huydhn, https://github.com/malfet
---
 .github/workflows/trymerge.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml
index c95db9efab1b..db0b80b32fa4 100644
--- a/.github/workflows/trymerge.yml
+++ b/.github/workflows/trymerge.yml
@@ -33,6 +33,7 @@ jobs:
           git config --global user.email "pytorchmergebot@users.noreply.github.com"
           git config --global user.name "PyTorch MergeBot"
       - name: Merge PR
+        shell: bash
         env:
           GITHUB_TOKEN: ${{ secrets.MERGEBOT_TOKEN }}
           PR_NUM: ${{ github.event.client_payload.pr_num }}
@@ -45,8 +46,8 @@ jobs:
         run: |
           set -x
           if [ -n "${REBASE}" ]; then
-            python3 .github/scripts/tryrebase.py "${PR_NUM}" --branch "${REBASE}"
-            if [ $? != 1 ]; then
+            # attempt to rebase, if it fails then comment on the PR that it failed
+            if ! python3 .github/scripts/tryrebase.py "${PR_NUM}" --branch "${REBASE}"; then
               python3 .github/scripts/comment_on_pr.py "${PR_NUM}" "merge"
               exit 0
             fi

From 628586606ebc000aef6063d9b7f93131bb5c7db7 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Wed, 22 Nov 2023 00:04:35 +0000
Subject: [PATCH 073/221] [test] fix broken test, enable test (#114235)

Fixes root cause of https://github.com/pytorch/pytorch/pull/114053#issuecomment-1820632457

This test was not running on OSS CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114235
Approved by: https://github.com/ezyang
---
 test/dynamo/test_recompile_ux.py | 89 +++++++++++++++++---------------
 1 file changed, 46 insertions(+), 43 deletions(-)

diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py
index 3d57b84cc77d..4039d924aec5 100644
--- a/test/dynamo/test_recompile_ux.py
+++ b/test/dynamo/test_recompile_ux.py
@@ -10,6 +10,7 @@
 import torch._dynamo.testing
 
 import torch._logging
+from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
 
 
 class RecompileUxTests(torch._dynamo.test_case.TestCase):
@@ -223,67 +224,69 @@ def f(x):
             opt_f(torch.randn(8 + i))
 
         failure_str = "\n".join(failure_reasons)
-        self.assertExpectedInline(
-            failure_str,
-            """\
+        for line in """\
 tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
 tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
 tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
-tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""",
-        )
+tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
+            "\n"
+        ):
+            self.assertIn(
+                line,
+                failure_str,
+            )
 
     @torch._dynamo.config.patch("cache_size_limit", 32)
     def test_multiple_guard_fails_report_all(self):
-        torch._logging.set_logs(recompiles_verbose=True)
-        failure_reasons = []
+        with log_settings(kwargs_to_settings(recompiles_verbose=True)):
+            failure_reasons = []
 
-        def guard_fail_fn(failure):
-            failure_reasons.append(failure[0])
+            def guard_fail_fn(failure):
+                failure_reasons.append(failure[0])
 
-        def f(x):
-            return torch.ones(len(x), x[-1])
+            def f(x):
+                return torch.ones(len(x), x[-1])
 
-        opt_f = torch._dynamo.optimize(
-            backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
-        )(f)
+            opt_f = torch._dynamo.optimize(
+                backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
+            )(f)
 
-        opt_f([4, 5, 6])
+            opt_f([4, 5, 6])
 
-        def filter_reasons():
-            return "\n".join(
-                [
-                    line
-                    for line in "\n".join(failure_reasons).splitlines()
-                    if not line.startswith("___check_type_id")
-                ]
-            )
+            def filter_reasons():
+                return "\n".join(
+                    [
+                        line
+                        for line in "\n".join(failure_reasons).splitlines()
+                        if not line.startswith("___check_type_id")
+                    ]
+                )
+
+            failure_reasons.clear()
+            opt_f([7, 8])
 
-        failure_reasons.clear()
-        opt_f([7, 8])
-        self.assertExpectedInline(
-            filter_reasons(),
-            """\
+            for line in """\
 len(L['x']) == 3
 L['x'][0] == 4
-L['x'][1] == 5""",
-        )
+L['x'][1] == 5""".split(
+                "\n"
+            ):
+                self.assertIn(line, filter_reasons())
+
+            failure_reasons.clear()
+            opt_f([9])
 
-        failure_reasons.clear()
-        opt_f([9])
-        self.assertExpectedInline(
-            filter_reasons(),
-            """\
+            for line in """\
 len(L['x']) == 2
 L['x'][0] == 7
 len(L['x']) == 3
-L['x'][0] == 4""",
-        )
+L['x'][0] == 4""".split(
+                "\n"
+            ):
+                self.assertIn(line, filter_reasons())
 
-        # reset logging state
-        torch._logging.set_logs()
 
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
 
-# TODO(jansel): these pass with pytest, but not with pytorch CI
-# if __name__ == "__main__":
-#     from torch._dynamo.testing import run_tests
-#     run_tests()
+    run_tests()

From 995fae6060dc4e284dc29901e7775e821bf3bd8f Mon Sep 17 00:00:00 2001
From: atalman 
Date: Wed, 22 Nov 2023 00:10:03 +0000
Subject: [PATCH 074/221] Move small pypi build as default for linux cuda 12.1
 (#114281)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is first PR to resolve: https://github.com/pytorch/pytorch/issues/113972
Move our small wheel build as default
Test:
```
pip3 install --no-cache-dir --pre torch-2.2.0.dev20231121%2Bcu121-cp310-cp310-linux_x86_64.whl  --index-url https://download.pytorch.org/whl/nightly/cu121
Looking in indexes: https://download.pytorch.org/whl/nightly/cu121
Processing ./torch-2.2.0.dev20231121%2Bcu121-cp310-cp310-linux_x86_64.whl
Collecting filelock (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/filelock-3.9.0-py3-none-any.whl (9.7 kB)
Collecting typing-extensions>=4.8.0 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/typing_extensions-4.8.0-py3-none-any.whl (31 kB)
Collecting sympy (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/sympy-1.11.1-py3-none-any.whl (6.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.5/6.5 MB 253.4 MB/s eta 0:00:00
Collecting networkx (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/networkx-3.0rc1-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 387.1 MB/s eta 0:00:00
Collecting jinja2 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/Jinja2-3.1.2-py3-none-any.whl (133 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 133.1/133.1 kB 365.3 MB/s eta 0:00:00
Collecting fsspec (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/fsspec-2023.4.0-py3-none-any.whl (153 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.0/154.0 kB 370.6 MB/s eta 0:00:00
Collecting pytorch-triton==2.1.0+6e4932cda8 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/pytorch_triton-2.1.0%2B6e4932cda8-cp310-cp310-linux_x86_64.whl (125.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.4/125.4 MB 384.1 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 404.9 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 kB 402.5 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 383.9 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 731.7/731.7 MB 406.9 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 388.2 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 410.5 MB/s eta 0:00:00
Collecting nvidia-curand-cu12==10.3.2.106 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 272.9 MB/s eta 0:00:00
Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 381.5 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 394.6 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.19.3 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 166.0/166.0 MB 384.7 MB/s eta 0:00:00
Collecting nvidia-nvtx-cu12==12.1.105 (from torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 kB 281.8 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/cu121/nvidia_nvjitlink_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (19.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 19.8/19.8 MB 367.3 MB/s eta 0:00:00
Collecting MarkupSafe>=2.0 (from jinja2->torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25 kB)
Collecting mpmath>=0.19 (from sympy->torch==2.2.0.dev20231121+cu121)
  Downloading https://download.pytorch.org/whl/nightly/mpmath-1.2.1-py3-none-any.whl (532 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 532.6/532.6 kB 391.3 MB/s eta 0:00:00
Installing collected packages: mpmath, typing-extensions, sympy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, MarkupSafe, fsspec, filelock, pytorch-triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114281
Approved by: https://github.com/malfet, https://github.com/huydhn
---
 .circleci/scripts/binary_populate_env.sh      |   7 -
 .circleci/scripts/binary_upload.sh            |   5 -
 .../scripts/generate_binary_build_matrix.py   |  52 ++-
 .github/scripts/generate_ci_workflows.py      |   1 -
 .../generated-linux-binary-manywheel-main.yml |  10 +-
 ...nerated-linux-binary-manywheel-nightly.yml | 315 +-----------------
 6 files changed, 34 insertions(+), 356 deletions(-)

diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh
index d32e4d3f4bd3..b3ac4c655478 100755
--- a/.circleci/scripts/binary_populate_env.sh
+++ b/.circleci/scripts/binary_populate_env.sh
@@ -77,15 +77,8 @@ else
   export PYTORCH_BUILD_VERSION="${BASE_BUILD_VERSION}+$DESIRED_CUDA"
 fi
 
-# The build with with-pypi-cudnn suffix is only applicabe to
-# pypi small wheel Linux x86 build
-if [[ -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]] && [[ "$(uname)" == 'Linux' && "$(uname -m)" == "x86_64" ]]; then
-  export PYTORCH_BUILD_VERSION="${PYTORCH_BUILD_VERSION}-with-pypi-cudnn"
-fi
-
 export PYTORCH_BUILD_NUMBER=1
 
-
 JAVA_HOME=
 BUILD_JNI=OFF
 if [[ "$PACKAGE_TYPE" == libtorch ]]; then
diff --git a/.circleci/scripts/binary_upload.sh b/.circleci/scripts/binary_upload.sh
index a980ce098b57..ac5e9485a185 100755
--- a/.circleci/scripts/binary_upload.sh
+++ b/.circleci/scripts/binary_upload.sh
@@ -16,11 +16,6 @@ UPLOAD_BUCKET="s3://pytorch"
 BACKUP_BUCKET="s3://pytorch-backup"
 BUILD_NAME=${BUILD_NAME:-}
 
-# this is temporary change to upload pypi-cudnn builds to separate folder
-if [[ ${BUILD_NAME} == *with-pypi-cudnn* ]]; then
-  UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_pypi_cudnn"
-fi
-
 DRY_RUN=${DRY_RUN:-enabled}
 # Don't actually do work unless explicit
 ANACONDA="true anaconda"
diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py
index 4c257075e6bd..8511e8a289ec 100644
--- a/.github/scripts/generate_binary_build_matrix.py
+++ b/.github/scripts/generate_binary_build_matrix.py
@@ -264,7 +264,6 @@ def generate_wheels_matrix(
     os: str,
     arches: Optional[List[str]] = None,
     python_versions: Optional[List[str]] = None,
-    gen_special_an_non_special_wheel: bool = True,
 ) -> List[Dict[str, str]]:
     package_type = "wheel"
     if os == "linux" or os == "linux-aarch64":
@@ -298,8 +297,7 @@ def generate_wheels_matrix(
                 else arch_version
             )
 
-            # special 12.1 wheels package without dependencies
-            # dependency downloaded via pip install
+            # 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
             if arch_version == "12.1" and os == "linux":
                 ret.append(
                     {
@@ -313,35 +311,33 @@ def generate_wheels_matrix(
                         "container_image": WHEEL_CONTAINER_IMAGES[arch_version],
                         "package_type": package_type,
                         "pytorch_extra_install_requirements": PYTORCH_EXTRA_INSTALL_REQUIREMENTS,
-                        "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn".replace(  # noqa: B950
+                        "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace(  # noqa: B950
                             ".", "_"
                         ),
                     }
                 )
-                if not gen_special_an_non_special_wheel:
-                    continue
-
-            ret.append(
-                {
-                    "python_version": python_version,
-                    "gpu_arch_type": gpu_arch_type,
-                    "gpu_arch_version": gpu_arch_version,
-                    "desired_cuda": translate_desired_cuda(
-                        gpu_arch_type, gpu_arch_version
-                    ),
-                    "devtoolset": "cxx11-abi"
-                    if arch_version == "cpu-cxx11-abi"
-                    else "",
-                    "container_image": WHEEL_CONTAINER_IMAGES[arch_version],
-                    "package_type": package_type,
-                    "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace(
-                        ".", "_"
-                    ),
-                    "pytorch_extra_install_requirements": PYTORCH_EXTRA_INSTALL_REQUIREMENTS
-                    if os != "linux"
-                    else "",
-                }
-            )
+            else:
+                ret.append(
+                    {
+                        "python_version": python_version,
+                        "gpu_arch_type": gpu_arch_type,
+                        "gpu_arch_version": gpu_arch_version,
+                        "desired_cuda": translate_desired_cuda(
+                            gpu_arch_type, gpu_arch_version
+                        ),
+                        "devtoolset": "cxx11-abi"
+                        if arch_version == "cpu-cxx11-abi"
+                        else "",
+                        "container_image": WHEEL_CONTAINER_IMAGES[arch_version],
+                        "package_type": package_type,
+                        "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace(
+                            ".", "_"
+                        ),
+                        "pytorch_extra_install_requirements": PYTORCH_EXTRA_INSTALL_REQUIREMENTS
+                        if os != "linux"
+                        else "",
+                    }
+                )
     return ret
 
 
diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py
index e944c607956f..1075db4255ed 100755
--- a/.github/scripts/generate_ci_workflows.py
+++ b/.github/scripts/generate_ci_workflows.py
@@ -158,7 +158,6 @@ class OperatingSystem:
             OperatingSystem.LINUX,
             arches=["11.8", "12.1"],
             python_versions=["3.8"],
-            gen_special_an_non_special_wheel=False,
         ),
         branches="main",
     ),
diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml
index 11a1c2867b48..deaf04b0350d 100644
--- a/.github/workflows/generated-linux-binary-manywheel-main.yml
+++ b/.github/workflows/generated-linux-binary-manywheel-main.yml
@@ -70,7 +70,7 @@ jobs:
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
 
-  manywheel-py3_8-cuda12_1-with-pypi-cudnn-build:
+  manywheel-py3_8-cuda12_1-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml
     with:
@@ -84,14 +84,14 @@ jobs:
       GPU_ARCH_TYPE: cuda
       DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
       DESIRED_PYTHON: "3.8"
-      build_name: manywheel-py3_8-cuda12_1-with-pypi-cudnn
+      build_name: manywheel-py3_8-cuda12_1
       build_environment: linux-binary-manywheel
       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_8-cuda12_1-with-pypi-cudnn-test:  # Testing
+  manywheel-py3_8-cuda12_1-test:  # Testing
     if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_8-cuda12_1-with-pypi-cudnn-build
+    needs: manywheel-py3_8-cuda12_1-build
     uses: ./.github/workflows/_binary-test-linux.yml
     with:
       PYTORCH_ROOT: /pytorch
@@ -104,7 +104,7 @@ jobs:
       GPU_ARCH_TYPE: cuda
       DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
       DESIRED_PYTHON: "3.8"
-      build_name: manywheel-py3_8-cuda12_1-with-pypi-cudnn
+      build_name: manywheel-py3_8-cuda12_1
       build_environment: linux-binary-manywheel
       runs_on: linux.4xlarge.nvidia.gpu
     secrets:
diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml
index e61b31ae916e..663d0c2d666b 100644
--- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml
+++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml
@@ -216,68 +216,6 @@ jobs:
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
 
-  manywheel-py3_8-cuda12_1-with-pypi-cudnn-build:
-    if: ${{ github.repository_owner == 'pytorch' }}
-    uses: ./.github/workflows/_binary-build-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.8"
-      build_name: manywheel-py3_8-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_8-cuda12_1-with-pypi-cudnn-test:  # Testing
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_8-cuda12_1-with-pypi-cudnn-build
-    uses: ./.github/workflows/_binary-test-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.8"
-      build_name: manywheel-py3_8-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      runs_on: linux.4xlarge.nvidia.gpu
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_8-cuda12_1-with-pypi-cudnn-upload:  # Uploading
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_8-cuda12_1-with-pypi-cudnn-test
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.8"
-      build_name: manywheel-py3_8-cuda12_1-with-pypi-cudnn
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
-      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
-      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
-      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
-    uses: ./.github/workflows/_binary-upload.yml
-
   manywheel-py3_8-cuda12_1-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml
@@ -294,6 +232,7 @@ jobs:
       DESIRED_PYTHON: "3.8"
       build_name: manywheel-py3_8-cuda12_1
       build_environment: linux-binary-manywheel
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_8-cuda12_1-test:  # Testing
@@ -723,68 +662,6 @@ jobs:
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
 
-  manywheel-py3_9-cuda12_1-with-pypi-cudnn-build:
-    if: ${{ github.repository_owner == 'pytorch' }}
-    uses: ./.github/workflows/_binary-build-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.9"
-      build_name: manywheel-py3_9-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_9-cuda12_1-with-pypi-cudnn-test:  # Testing
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_9-cuda12_1-with-pypi-cudnn-build
-    uses: ./.github/workflows/_binary-test-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.9"
-      build_name: manywheel-py3_9-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      runs_on: linux.4xlarge.nvidia.gpu
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_9-cuda12_1-with-pypi-cudnn-upload:  # Uploading
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_9-cuda12_1-with-pypi-cudnn-test
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.9"
-      build_name: manywheel-py3_9-cuda12_1-with-pypi-cudnn
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
-      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
-      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
-      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
-    uses: ./.github/workflows/_binary-upload.yml
-
   manywheel-py3_9-cuda12_1-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml
@@ -801,6 +678,7 @@ jobs:
       DESIRED_PYTHON: "3.9"
       build_name: manywheel-py3_9-cuda12_1
       build_environment: linux-binary-manywheel
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_9-cuda12_1-test:  # Testing
@@ -1230,68 +1108,6 @@ jobs:
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
 
-  manywheel-py3_10-cuda12_1-with-pypi-cudnn-build:
-    if: ${{ github.repository_owner == 'pytorch' }}
-    uses: ./.github/workflows/_binary-build-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.10"
-      build_name: manywheel-py3_10-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_10-cuda12_1-with-pypi-cudnn-test:  # Testing
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_10-cuda12_1-with-pypi-cudnn-build
-    uses: ./.github/workflows/_binary-test-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.10"
-      build_name: manywheel-py3_10-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      runs_on: linux.4xlarge.nvidia.gpu
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_10-cuda12_1-with-pypi-cudnn-upload:  # Uploading
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_10-cuda12_1-with-pypi-cudnn-test
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.10"
-      build_name: manywheel-py3_10-cuda12_1-with-pypi-cudnn
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
-      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
-      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
-      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
-    uses: ./.github/workflows/_binary-upload.yml
-
   manywheel-py3_10-cuda12_1-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml
@@ -1308,6 +1124,7 @@ jobs:
       DESIRED_PYTHON: "3.10"
       build_name: manywheel-py3_10-cuda12_1
       build_environment: linux-binary-manywheel
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_10-cuda12_1-test:  # Testing
@@ -1737,68 +1554,6 @@ jobs:
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
 
-  manywheel-py3_11-cuda12_1-with-pypi-cudnn-build:
-    if: ${{ github.repository_owner == 'pytorch' }}
-    uses: ./.github/workflows/_binary-build-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.11"
-      build_name: manywheel-py3_11-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_11-cuda12_1-with-pypi-cudnn-test:  # Testing
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_11-cuda12_1-with-pypi-cudnn-build
-    uses: ./.github/workflows/_binary-test-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.11"
-      build_name: manywheel-py3_11-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      runs_on: linux.4xlarge.nvidia.gpu
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_11-cuda12_1-with-pypi-cudnn-upload:  # Uploading
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_11-cuda12_1-with-pypi-cudnn-test
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.11"
-      build_name: manywheel-py3_11-cuda12_1-with-pypi-cudnn
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
-      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
-      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
-      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
-    uses: ./.github/workflows/_binary-upload.yml
-
   manywheel-py3_11-cuda12_1-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml
@@ -1815,6 +1570,7 @@ jobs:
       DESIRED_PYTHON: "3.11"
       build_name: manywheel-py3_11-cuda12_1
       build_environment: linux-binary-manywheel
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_11-cuda12_1-test:  # Testing
@@ -2244,68 +2000,6 @@ jobs:
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
 
-  manywheel-py3_12-cuda12_1-with-pypi-cudnn-build:
-    if: ${{ github.repository_owner == 'pytorch' }}
-    uses: ./.github/workflows/_binary-build-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.12"
-      build_name: manywheel-py3_12-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_12-cuda12_1-with-pypi-cudnn-test:  # Testing
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_12-cuda12_1-with-pypi-cudnn-build
-    uses: ./.github/workflows/_binary-test-linux.yml
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.12"
-      build_name: manywheel-py3_12-cuda12_1-with-pypi-cudnn
-      build_environment: linux-binary-manywheel
-      runs_on: linux.4xlarge.nvidia.gpu
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  manywheel-py3_12-cuda12_1-with-pypi-cudnn-upload:  # Uploading
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs: manywheel-py3_12-cuda12_1-with-pypi-cudnn-test
-    with:
-      PYTORCH_ROOT: /pytorch
-      BUILDER_ROOT: /builder
-      PACKAGE_TYPE: manywheel
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: cu121
-      GPU_ARCH_VERSION: 12.1
-      GPU_ARCH_TYPE: cuda
-      DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main
-      DESIRED_PYTHON: "3.12"
-      build_name: manywheel-py3_12-cuda12_1-with-pypi-cudnn
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
-      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
-      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
-      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
-    uses: ./.github/workflows/_binary-upload.yml
-
   manywheel-py3_12-cuda12_1-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml
@@ -2322,6 +2016,7 @@ jobs:
       DESIRED_PYTHON: "3.12"
       build_name: manywheel-py3_12-cuda12_1
       build_environment: linux-binary-manywheel
+      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
   manywheel-py3_12-cuda12_1-test:  # Testing

From 9cbee4757e296645039b5b3fbcea4115fab5b718 Mon Sep 17 00:00:00 2001
From: Hongtao Yu 
Date: Wed, 22 Nov 2023 00:28:04 +0000
Subject: [PATCH 075/221] [Autotune] Reduce XLBOCK for outer reduction 
 (#114284)

I have observed that quite a few Reduction.Outer kernels have potential for performance improvement by reducing register pressure. This is due to our current register pressure reduction logics, which only reduces RBLOCK, doesn't work for outer reductions. While we can tighten up there, which will likely increase compile time, I found a better workaround to tune down XBLOCK in the first place.

Perf job: main 9efbb4ea73 (11/16) vs hoy/autotune/reduction
Slight compile time and perf improvement seen.
I also saw perf improvement locally for the few kernels being investigated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114284
Approved by: https://github.com/jansel
---
 torch/_inductor/triton_heuristics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index 17094067693f..4f0a8c7aff94 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -1076,7 +1076,7 @@ def reduction(
         contiguous_config = triton_config_reduction(
             size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048)
         )
-        outer_config = triton_config_reduction(size_hints, 128, 8)
+        outer_config = triton_config_reduction(size_hints, 64, 8)
         tiny_config = triton_config_reduction(
             size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048)
         )

From 4d07428edee863e7f5920f0672957a9711a9f0b5 Mon Sep 17 00:00:00 2001
From: Andrew Calvano 
Date: Wed, 22 Nov 2023 01:05:39 +0000
Subject: [PATCH 076/221] Fix for out of bounds read in mobile interpreter
 FORMAT opcode handler (#110303)

Summary:
The FORMAT opcode for the mobile TorchScript interpreter contained an out of bounds read issue leading to memory corruption.

This change adds an explicit check that the number of inputs passed to the format method called when handling the FORMAT opcode is a valid and within bounds of the stack.

Test Plan: contbuild + OSS signals

Differential Revision: D49739095

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110303
Approved by: https://github.com/malfet
---
 torch/csrc/jit/runtime/vararg_functions.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index bb28b61fe7e2..688c26939885 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -106,6 +106,10 @@ void tupleUnpack(Stack& stack) {
 }
 
 void format(Stack& stack, size_t num_inputs) {
+  if (num_inputs == 0 || num_inputs > stack.size()) {
+    AT_ERROR("Invalid number of inputs for format string: ", num_inputs);
+  }
+
   // static const std::regex unsupported_options("\\{(.*?)\\}");
   auto format = peek(stack, 0, num_inputs).toStringRef();
   // // Temporally comment out the warning message because of

From d5d62e85615fdf345e0556a9d8edbee2d3c64ae2 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Tue, 21 Nov 2023 16:51:28 -0500
Subject: [PATCH 077/221] [fx/DDP] add nested ctx_manager test for DDP Dynamo
 (#114056)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114056
Approved by: https://github.com/wconstab
---
 test/distributed/test_c10d_nccl.py            |  13 +-
 test/distributed/test_dynamo_distributed.py   | 196 ++++++++++++++----
 torch/testing/_internal/common_distributed.py |  11 +
 3 files changed, 168 insertions(+), 52 deletions(-)

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 4ac72c2bd207..bfd9a4fff4f9 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -12,7 +12,6 @@
 import pickle
 import time
 import warnings
-from contextlib import contextmanager
 from datetime import timedelta
 from itertools import chain, product
 from unittest import mock
@@ -45,6 +44,7 @@
     skip_if_rocm,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
+    first_bucket_size,
 )
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -2039,17 +2039,6 @@ def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size):
         local_batch_start = self.rank * local_batch_size
         local_batch_end = (self.rank + 1) * local_batch_size
 
-        # Reducer.cpp sneakily creates one "initial bucket" that ignores the "bucket_cap_mb"
-        # argument.  The following makes sure the initial bucket also complies.
-        @contextmanager
-        def first_bucket_size(ddp_bucket_mb):
-            old_DEFAULT_FIRST_BUCKET_BYTES = dist._DEFAULT_FIRST_BUCKET_BYTES
-            dist._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6)
-            try:
-                yield
-            finally:
-                dist._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES
-
         with torch.backends.cudnn.flags(
             enabled=True, deterministic=True, benchmark=False
         ):
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 24298c671538..b914bbd9a465 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -3,6 +3,7 @@
 import functools
 from io import StringIO
 from typing import List
+from itertools import product
 import random
 import unittest
 from unittest.mock import patch
@@ -29,6 +30,7 @@
     skip_if_lt_x_gpu,
     requires_nccl,
     _dynamo_dist_per_rank_init,
+    first_bucket_size,
 )
 import torch._dynamo.logging
 from torch.testing._internal.common_cuda import (
@@ -71,7 +73,33 @@ def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_mana
     outputs = m(inputs)
     return m, inputs, outputs
 
+class ToyModelMultiOutput(nn.Module):
+    def __init__(self, ctx_manager_1, ctx_manager_2, hidden_feat=1000):
+        super().__init__()
+        self.ctx_manager_1 = ctx_manager_1
+        self.ctx_manager_2 = ctx_manager_2
+        self.net1 = nn.Sequential(
+            * [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+        )
+        self.net2 = nn.Sequential(
+            * [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+        )
+        self.net3 = nn.Sequential(
+            * [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+        )
 
+    def forward(self, inputs):
+        with self.ctx_manager_1():
+            intermediates_1 = self.net1(inputs)
+            with self.ctx_manager_2():
+                intermediates_2 = self.net2(intermediates_1)
+            outputs = self.net3(inputs)
+        return intermediates_1, intermediates_2, outputs
 
 class ToyInnerModel(nn.Module):
     def __init__(self):
@@ -520,6 +548,13 @@ def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manage
         outputs = m(inputs)
         return m, inputs, outputs
 
+    def get_model_multi_output(self, ctx_manager_1, ctx_manager_2, bsz=20, hidden_feat=1000):
+        m = ToyModelMultiOutput(ctx_manager_1, ctx_manager_2, hidden_feat=hidden_feat).to(self.device)
+        m.apply(init_weights)
+        inputs = torch.rand(bsz, hidden_feat).to(self.device)
+        outputs = m(inputs)
+        return m, inputs, outputs
+
     @patch.object(config, "optimize_ddp", False)
     def test_ddp_baseline_aot_eager(self):
         from torch.nn.parallel import DistributedDataParallel as DDP
@@ -578,49 +613,130 @@ def test_graph_split_ctx_manager(self):
         context managers' effects are applied to the computation.
         """
 
-        for get_compiler in [
-            lambda: CheckSplitsCompiler(),
-            lambda: None,
-        ]:
-            for ctx_manager, output_test in [
-                (
-                    lambda: torch.autocast(torch.device(self.device).type, torch.float16),
-                    lambda out: self.assertEqual(out.dtype, torch.float16),
-                ),
-                (
-                    torch.enable_grad,
-                    lambda out: self.assertTrue(out.requires_grad)
-                ),
-                (
-                    torch.no_grad,
-                    lambda out: self.assertTrue(not out.requires_grad)
-                ),
-            ]:
-                m, inputs, correct_outputs = self.get_model(out_feat=1000, hidden_feat=1000, in_feat=1000, ctx_manager=ctx_manager)
-                # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
-                # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
-                bucket_cap_mb = 3.5  # 4MB
-                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
-
-                compiler = get_compiler()
-
-                @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
-                def opt_fn(inputs):
-                    return ddp_m(inputs)
+        bucket_cap_mb = 3.5
+        # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
+        with first_bucket_size(bucket_cap_mb):
+            for ambient_grad, get_compiler in product([
+                torch.no_grad,
+                torch.enable_grad,
+            ], [
+                lambda: CheckSplitsCompiler(),
+                lambda: None,
+            ]):
+                with ambient_grad():
+                    for ctx_manager, output_test in [
+                        (
+                            lambda: torch.autocast(torch.device(self.device).type, torch.float16),
+                            lambda out: self.assertEqual(out.dtype, torch.float16),
+                        ),
+                        (
+                            torch.enable_grad,
+                            lambda out: self.assertTrue(out.requires_grad)
+                        ),
+                        (
+                            torch.no_grad,
+                            lambda out: self.assertTrue(not out.requires_grad)
+                        ),
+                    ]:
+                        m, inputs, correct_outputs = self.get_model(
+                            out_feat=1000, hidden_feat=1000, in_feat=1000, ctx_manager=ctx_manager
+                        )
+                        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
+
+                        compiler = get_compiler()
+
+                        @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
+                        def opt_fn(inputs):
+                            return ddp_m(inputs)
+
+                        opt_outputs = opt_fn(inputs)
+                        self.assertTrue(same(correct_outputs, opt_outputs))
+                        if compiler:
+                            self.assertEqual(compiler.compiler_called, 4)
+
+                        output_test(opt_outputs)
+
+                        # ensure compatibility with dynamo explain
+
+                        explain_out = torch._dynamo.explain(ddp_m)(inputs)
+                        break_reasons = explain_out.break_reasons
+                        self.assertEqual(len(break_reasons), 4)
+                        self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
 
-                opt_outputs = opt_fn(inputs)
-                self.assertTrue(same(correct_outputs, opt_outputs))
-                if compiler:
-                    self.assertEqual(compiler.compiler_called, 4)
 
-                output_test(opt_outputs)
-
-                # ensure compatibility with dynamo explain
+    @patch.object(config, "optimize_ddp", True)
+    def test_graph_split_ctx_manager_nested(self):
+        """
+        Ensures that we get the right number of splits and that the respective
+        context managers' effects are applied to the computation.
+        """
+        try:
+            torch.autocast(torch.device(self.device).type, torch.bfloat16)
+            torch.autocast(torch.device(self.device).type, torch.float16)
+        except Exception:
+            self.skipTest("Need both bfloat16, float16 support on device")
+
+        ctx_managers_outer = [
+            (
+                lambda: torch.autocast(torch.device(self.device).type, torch.float16),
+                lambda out: self.assertEqual(out.dtype, torch.float16),
+            ),
+            (
+                torch.enable_grad,
+                lambda out: self.assertTrue(out.requires_grad)
+            ),
+        ]
+        ctx_managers_inner = [
+            (
+                lambda: torch.autocast(torch.device(self.device).type, torch.bfloat16),
+                lambda out: self.assertEqual(out.dtype, torch.bfloat16),
+            ),
+            (
+                torch.no_grad,
+                lambda out: self.assertTrue(not out.requires_grad)
+            ),
+        ]
 
-                explain_out = torch._dynamo.explain(ddp_m)(inputs)
-                break_reasons = explain_out.break_reasons
-                self.assertEqual(len(break_reasons), 4)
-                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
+        bucket_cap_mb = 7.5
+        # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
+        with first_bucket_size(bucket_cap_mb):
+            for ambient_grad, get_compiler in product([
+                torch.no_grad,
+                torch.enable_grad,
+            ], [
+                lambda: CheckSplitsCompiler(),
+                lambda: None,
+            ]):
+                with ambient_grad():
+                    for ctx_manager_1, output_test_1 in ctx_managers_outer:
+                        for ctx_manager_2, output_test_2 in ctx_managers_inner:
+                            m, inputs, correct_outputs = self.get_model_multi_output(
+                                ctx_manager_1, ctx_manager_2, hidden_feat=1000
+                            )
+                            ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
+
+                            compiler = get_compiler()
+
+                            @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
+                            def opt_fn(inputs):
+                                return ddp_m(inputs)
+
+                            opt_outputs = opt_fn(inputs)
+                            self.assertTrue(same(correct_outputs, opt_outputs))
+                            if compiler:
+                                self.assertEqual(compiler.compiler_called, 4)
+
+                            opt_outputs_1, opt_outputs_2, opt_outputs_3 = opt_outputs
+                            output_test_1(opt_outputs_1)
+                            output_test_2(opt_outputs_2)
+                            output_test_1(opt_outputs_3)
+
+                            # ensure compatibility with dynamo explain
+
+                            explain_out = torch._dynamo.explain(ddp_m)(inputs)
+                            break_reasons = explain_out.break_reasons
+                            self.assertEqual(len(break_reasons), 4)
+                            self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
 
     @patch.object(config, "optimize_ddp", True)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 8cbca096b500..520785f07e75 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -1253,3 +1253,14 @@ def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
         self.rank = rank
         self.file_name = file_name
         self.run_test(test_name, parent_pipe)
+
+# Reducer.cpp sneakily creates one "initial bucket" that ignores the "bucket_cap_mb"
+# argument.  The following makes sure the initial bucket also complies.
+@contextmanager
+def first_bucket_size(ddp_bucket_mb):
+    old_DEFAULT_FIRST_BUCKET_BYTES = c10d._DEFAULT_FIRST_BUCKET_BYTES
+    c10d._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6)
+    try:
+        yield
+    finally:
+        c10d._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES

From 044cd56dcc7b64c51fd5b3c2a8460254c8f05f3d Mon Sep 17 00:00:00 2001
From: voznesenskym 
Date: Tue, 21 Nov 2023 14:18:56 -0800
Subject: [PATCH 078/221] [Easy] make @markDynamoStrictTest set nopython=True
 (#114308)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114308
Approved by: https://github.com/zou3519, https://github.com/oulgen
---
 torch/testing/_internal/common_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 72a8dfa8c024..67e9e5992a3b 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2738,7 +2738,7 @@ def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_re
                 super_run = torch._dynamo.optimize("aot_eager_decomp_partition", save_config=False)(super_run)
             elif TEST_WITH_TORCHDYNAMO:
                 # TorchDynamo optimize annotation
-                super_run = torch._dynamo.optimize("eager", save_config=False)(super_run)
+                super_run = torch._dynamo.optimize("eager", save_config=False, nopython=strict_mode)(super_run)
 
             super_run(result=result)
 

From 3f736c2d77b5ded0f38a44213883088113c42ff8 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi 
Date: Tue, 21 Nov 2023 19:54:50 +0000
Subject: [PATCH 079/221] Add ONNXProgram.__call__ API to run model with ONNX
 Runtime (#113495)

Currently the user can use torch.onnx.dynamo_export to export the model.
to ONNX.

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

onnx_program = torch.onnx.dynamo_export(
    Model(),
    torch.randn(1, 1, 2, dtype=torch.float),
)
```

The next step would be instantiating a ONNX runtime to execute it.

```python
import onnxruntime  # type: ignore[import]

onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
options = options or {}
providers = options.get("providers", onnxruntime.get_available_providers())
onnx_model = self.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )

onnxruntime_input = {
    k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}

return ort_session.run(None, onnxruntime_input)
```

This PR provides the `ONNXProgram.__call__` method as facilitator to use ONNX Runtime under the hood, similar to how `torch.export.ExportedProgram.__call__` which allows the underlying `torch.fx.GraphModule` to be executed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113495
Approved by: https://github.com/titaiwangms
---
 docs/source/onnx_dynamo.rst                   |  3 +
 test/onnx/onnx_test_common.py                 |  8 +--
 test/onnx/test_fx_to_onnx_with_onnxruntime.py | 15 ++++
 torch/onnx/__init__.py                        |  3 +
 torch/onnx/_internal/exporter.py              | 70 ++++++++++++++++++-
 5 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst
index a156c51310c3..09a09bc3a300 100644
--- a/docs/source/onnx_dynamo.rst
+++ b/docs/source/onnx_dynamo.rst
@@ -146,6 +146,9 @@ API Reference
 .. autoclass:: torch.onnx.ONNXProgramSerializer
     :members:
 
+.. autoclass:: torch.onnx.ONNXRuntimeOptions
+    :members:
+
 .. autoclass:: torch.onnx.InvalidExportOptionsError
     :members:
 
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index bcc8aa7a012d..2892a23f520a 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -439,15 +439,11 @@ def _compare_pytorch_onnx_with_ort(
         ref_input_args = input_args
         ref_input_kwargs = input_kwargs
 
-    # Format original model inputs into the format expected by exported ONNX model.
-    onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
-        *input_args, **input_kwargs
-    )
-
     ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
         ref_model(*ref_input_args, **ref_input_kwargs)
     )
-    ort_outputs = run_ort(onnx_program, onnx_format_args)
+
+    ort_outputs = onnx_program(*input_args, **input_kwargs)
 
     if len(ref_outputs) != len(ort_outputs):
         raise AssertionError(
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 728430cba994..26fa6f215bec 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -896,6 +896,21 @@ def create_pytorch_only_extra_kwargs():
             create_pytorch_only_extra_kwargs,
         )
 
+    def test_execute_model_with___call__(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x + 1.0
+
+        input_x = torch.randn(1, 1, 2, dtype=torch.float)
+        onnx_program = torch.onnx.dynamo_export(
+            Model(),
+            input_x,
+        )
+
+        # The other tests use ONNXProgram.__call__ indirectly and check for output equality
+        # This test aims to ensure ONNXProgram.__call__ API runs successfully despite internal test infra code
+        _ = onnx_program(input_x)
+
     def test_exported_program_as_input(self):
         class Model(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index e50dfb33004c..ad3af0984d4d 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -48,6 +48,7 @@
     ExportOptions,
     ONNXProgram,
     ONNXProgramSerializer,
+    ONNXRuntimeOptions,
     InvalidExportOptionsError,
     OnnxExporterError,
     OnnxRegistry,
@@ -103,6 +104,7 @@
     "ExportOptions",
     "ONNXProgram",
     "ONNXProgramSerializer",
+    "ONNXRuntimeOptions",
     "InvalidExportOptionsError",
     "OnnxExporterError",
     "OnnxRegistry",
@@ -118,6 +120,7 @@
 ExportOptions.__module__ = "torch.onnx"
 ONNXProgram.__module__ = "torch.onnx"
 ONNXProgramSerializer.__module__ = "torch.onnx"
+ONNXRuntimeOptions.__module__ = "torch.onnx"
 dynamo_export.__module__ = "torch.onnx"
 InvalidExportOptionsError.__module__ = "torch.onnx"
 OnnxExporterError.__module__ = "torch.onnx"
diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py
index fbda341ee039..807ef52a0483 100644
--- a/torch/onnx/_internal/exporter.py
+++ b/torch/onnx/_internal/exporter.py
@@ -1,5 +1,6 @@
-# necessary to surface onnx.ModelProto through ONNXProgram:
-from __future__ import annotations
+from __future__ import (  # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions)
+    annotations,
+)
 
 import abc
 
@@ -52,6 +53,7 @@
 # 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
 if TYPE_CHECKING:
     import onnx
+    import onnxruntime  # type: ignore[import]
     import onnxscript  # type: ignore[import]
     from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
         registration as torchlib_registry,
@@ -602,6 +604,41 @@ def serialize(
             )
 
 
+class ONNXRuntimeOptions:
+    """Options to influence the execution of the ONNX model through ONNX Runtime.
+
+    Attributes:
+        session_options: ONNX Runtime session options.
+        execution_providers: ONNX Runtime execution providers to use during model execution.
+        execution_provider_options: ONNX Runtime execution provider options.
+    """
+
+    session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
+    """ONNX Runtime session options."""
+
+    execution_providers: Optional[
+        Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
+    ] = None
+    """ONNX Runtime execution providers to use during model execution."""
+
+    execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None
+    """ONNX Runtime execution provider options."""
+
+    @_beartype.beartype
+    def __init__(
+        self,
+        *,
+        session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None,
+        execution_providers: Optional[
+            Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
+        ] = None,
+        execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None,
+    ):
+        self.session_options = session_options
+        self.execution_providers = execution_providers
+        self.execution_provider_options = execution_provider_options
+
+
 class ONNXProgram:
     """An in-memory representation of a PyTorch model that has been exported to ONNX.
 
@@ -643,6 +680,34 @@ def __init__(
         self._fake_context = fake_context
         self._export_exception = export_exception
 
+    def __call__(
+        self, *args: Any, options: Optional[ONNXRuntimeOptions] = None, **kwargs: Any
+    ) -> Any:
+        """Runs the ONNX model using ONNX Runtime
+
+        Args:
+            args: The positional inputs to the model.
+            kwargs: The keyword inputs to the model.
+            options: The options to use for running the model with ONNX Runtime.
+
+        Returns:
+            The model output as computed by ONNX Runtime
+        """
+        import onnxruntime  # type: ignore[import]
+
+        onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
+        options = options or ONNXRuntimeOptions()
+        providers = options.execution_providers or onnxruntime.get_available_providers()
+        onnx_model = self.model_proto.SerializeToString()
+        ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
+
+        onnxruntime_input = {
+            k.name: v.numpy(force=True)
+            for k, v in zip(ort_session.get_inputs(), onnx_input)
+        }
+
+        return ort_session.run(None, onnxruntime_input)
+
     @property
     def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
         """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
@@ -1416,6 +1481,7 @@ def common_pre_export_passes(
     "ExportOptions",
     "ONNXProgram",
     "ONNXProgramSerializer",
+    "ONNXRuntimeOptions",
     "InvalidExportOptionsError",
     "OnnxExporterError",
     "OnnxRegistry",

From a785fbe513d72dd475a3640ceda370580ccd0ca8 Mon Sep 17 00:00:00 2001
From: Jerry Zhang 
Date: Mon, 20 Nov 2023 12:01:15 -0800
Subject: [PATCH 080/221] [reland][quant][pt2e] Refactor insert observer to do
 sharing checking in the same place (#113458) (#113920)

Summary:
Previously it is scatter in two different places: before inserting observer and during observer,
this PR moved everything before we insert observer

* Next: refactor QuantizationSpec and check more fields for sharing

Test Plan:
CI (regression tests)

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D51420029](https://our.internmc.facebook.com/intern/diff/D51420029)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113920
Approved by: https://github.com/andrewor14
---
 torch/ao/quantization/pt2e/prepare.py | 71 +++++++++++++++------------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index 6ebef94c8610..25e90f2fcf0d 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -66,6 +66,7 @@ def _update_shared_with(edge_or_node: EdgeOrNode, qspec: QuantizationSpecBase, s
         # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
         _union(sharing_with, edge_or_node, shared_with_map)
 
+# TODO: simplify this
 def _find_root_qspec(
     qspec: QuantizationSpecBase,
     edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
@@ -113,6 +114,24 @@ def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode,
                 edge_or_node_to_qspec[output_node] = qspec
     return edge_or_node_to_qspec
 
+def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
+    root_qspec = None
+    if edge_or_node in edge_or_node_to_qspec:
+        qspec = edge_or_node_to_qspec[edge_or_node]
+        root_qspec = _find_root_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+    # TODO: add assertions for types of root qspecs
+    if (
+        root_qspec is not None and
+        _has_same_dtype(root_qspec, input_edge_root_qspec) and
+        _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
+    ):
+        # the input arg to the node should reuse the existing output observer for arg
+        # since dtype is the same (we may want to extend this to be a more strict check
+        # in the future)
+        # so we point from `input_edge` to `arg` (output of the argument)
+        _union(edge_or_node, input_edge, shared_with_map)
+
+
 def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
     """Map from edge/node to the group ID, generated from quantization annotations,
     edge/node with the same group ID should use the same observer/fake_quant instance
@@ -179,21 +198,23 @@ def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, Quanti
             assert isinstance(input_edge, tuple)
             arg, n = input_edge
             if n.meta["quantization_annotation"].allow_implicit_sharing:
-                arg_as_output_root_qspec = None
-                if arg in edge_or_node_to_qspec:
-                    arg_as_output_qspec = edge_or_node_to_qspec[arg]
-                    arg_as_output_root_qspec = _find_root_qspec(arg_as_output_qspec, edge_or_node_to_qspec, shared_with_map)
-                # TODO: add assertions for types of root qspecs
-                if (
-                    arg_as_output_root_qspec is not None and
-                    _has_same_dtype(arg_as_output_root_qspec, input_edge_root_qspec) and
-                    _has_same_is_dynamic(arg_as_output_root_qspec, input_edge_root_qspec)
-                ):
-                    # the input arg to the node should reuse the existing output observer for arg
-                    # since dtype is the same (we may want to extend this to be a more strict check
-                    # in the future)
-                    # so we point from `input_edge` to `arg` (output of the argument)
-                    _union(arg, input_edge, shared_with_map)
+                # sharing with previous output
+                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)
+
+                # sharing with other users of the previous output
+                # (arg, user)
+                for user in arg.users:
+                    if user is n:
+                        continue
+                    arg_to_user_edge = (arg, user)
+                    _union_input_edge_with(
+                        input_edge,
+                        input_edge_root_qspec,
+                        arg_to_user_edge,
+                        edge_or_node_to_qspec,
+                        shared_with_map
+                    )
+
             _update_shared_with(input_edge, qspec, shared_with_map)
 
     # now that we get the sharing relations between all edges and nodes, we can assingn group ids
@@ -281,10 +302,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     # otherwise, we'll insert a new observer/fake_quant node
 
     existing_obs_node = None
-    # skip inserting new observers if there is an observer inserted for the arg before
-    # that has the same dtype that we want to insert here
-    # alternatively we could have a dedup pass after we insert all observers to deduplicate
-    # observers
+    # skip inserting new observers if the same observer instance is inserted before for another user
     # Example:
     # conv1 -> obs1 -> existing_obs -> conv2
     #             \ -> conv3
@@ -296,19 +314,10 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
         if not _is_activation_post_process_node(maybe_obs_node, named_modules):
             continue
         maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
-        if (
-            type(maybe_obs_mod) == type(input_edge_obs_or_fq) and
-            maybe_obs_mod.dtype == input_edge_obs_or_fq.dtype
-        ):
-            input_edge_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
-            existing_obs_node = maybe_obs_node
-            break
-
-    if existing_obs_node is None:
-        new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
-    else:
-        new_arg = existing_obs_node
+        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
+            return maybe_obs_node
 
+    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
     return new_arg
 
 def _maybe_insert_input_observers_for_node(

From 6187153753c9d870e012780d9e57beee38562a75 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang" 
Date: Tue, 21 Nov 2023 20:59:54 -0500
Subject: [PATCH 081/221] Consolidate sym/non-sym overloads for
 _make_wrapper_subclass (#114236)

I'm not sure why we needed two overloads previously, let's find out! Removing the int overload is load bearing because it now forces specialization on SymInt arguments instead of falling through to the SymInt overload, see new test.

I decided NOT to allow storage offset simultaneously with None strides.

Signed-off-by: Edward Z. Yang 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114236
Approved by: https://github.com/albanD
---
 test/distributed/test_dynamo_distributed.py   | 13 ++++
 torch/csrc/autograd/python_variable.cpp       | 70 +++++++++----------
 .../_internal/common_methods_invocations.py   |  4 +-
 3 files changed, 48 insertions(+), 39 deletions(-)

diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index b914bbd9a465..2cdad86269c4 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -37,6 +37,7 @@
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
 )
 from torch._dynamo.comptime import comptime
+from torch.distributed._functional_collectives import _maybe_wrap_tensor
 
 def reset_rng_state():
     torch.manual_seed(1337)
@@ -1129,6 +1130,18 @@ def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         for test_out in test_outs:
             self.assertEqual(test_out, ref_out)
 
+    def test_async_subclass_no_specialize(self):
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")
+
+        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
+        def f(x):
+            return x + 1
+
+        f(_maybe_wrap_tensor(torch.randn(10)))
+        f(_maybe_wrap_tensor(torch.randn(12)))
+
+        self.assertEqual(cnt.frame_count, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index dde7066b2b85..ba0e913896d7 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -652,12 +652,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
   // NB: pin_memory doesn't actually do anything
   // TODO: strides variant?
   static PythonArgParser parser({
-      "_make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, "
-      "int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
-      "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
-      "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, "
-      "DispatchKeySet _extra_dispatch_keys=None)",
-      "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, "
+      "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef? strides=None, "
       "SymInt? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
       "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
       "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, "
@@ -699,42 +694,40 @@ static PyObject* THPVariable_make_wrapper_subclass(
 
   // don't bother releasing GIL here, as we are not allocating any nontrivial
   // data
-  // TODO: for_blob produces non-resizable tensors, we might want this to be
-  // resizable (have to define a custom allocator in that case)
   Tensor tensor;
-  if (r.idx == 0) {
-    TORCH_CHECK(
-        !r.toDispatchKeySetOptional(13),
-        "This overload of _make_wrapper_subclass does not support _extra_dispatch_keys");
-    tensor = at::for_blob(nullptr, r.intlist(1))
-                 .strides(r.intlistOptional(2))
-                 .storage_offset(r.toInt64Optional(3))
-                 .context(nullptr, [](void* ctx) {})
-                 .target_device(
-                     options.device()) // TODO: this shouldn't be necessary if
-                                       // it came from options
-                 .options(options)
-                 .allocator(c10::GetAllocator(c10::kMeta))
-                 .resizeable_storage()
-                 .make_tensor();
 
-    const auto sizes_strides_policy = r.stringViewOptional(10);
-    if (sizes_strides_policy.has_value()) {
-      tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
-          parseSizesStridesPolicyArgument(*sizes_strides_policy));
-    }
-  } else {
+  {
     AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
     tracer::impl::NoTracerDispatchMode tracer_guard{};
 
+    auto sym_sizes = r.symintlist(1);
+    auto sym_strides_own = r.symintlistOptional(2);
+    auto sym_strides =
+        static_cast>(sym_strides_own);
+    auto sym_storage_offset = r.toSymIntOptional(3);
+
+    c10::SymInt size_bytes;
+    auto dtype_itemsize = static_cast(options.dtype().itemsize());
+    if (sym_strides.has_value()) {
+      size_bytes = at::detail::computeStorageNbytes(
+          sym_sizes,
+          sym_strides.value(),
+          dtype_itemsize,
+          sym_storage_offset.value_or(0));
+    } else {
+      size_bytes = at::detail::computeStorageNbytesContiguous(
+          sym_sizes, dtype_itemsize, sym_storage_offset.value_or(0));
+    }
+
     // We use storages **only** to track aliasing of subclasses during tracing.
     // The actual data pointers are not valid.
     Storage storage{
         Storage::use_byte_size_t{},
-        0,
-        at::DataPtr{nullptr, r.device(7)},
+        size_bytes,
         /*allocator=*/c10::GetAllocator(c10::kMeta),
         /*resizable=*/true};
+    // TODO: constructor should probably accept data pointer
+    storage.set_data_ptr_noswap(at::DataPtr{nullptr, r.device(7)});
 
     auto keys = c10::DispatchKeySet({options.computeDispatchKey()});
     if (auto mb_extra_keys = r.toDispatchKeySetOptional(13)) {
@@ -743,14 +736,17 @@ static PyObject* THPVariable_make_wrapper_subclass(
     tensor = at::detail::make_tensor(
         std::move(storage), keys, options.dtype());
 
-    auto sym_sizes = r.symintlist(1);
-    auto sym_strides = r.symintlist(2);
-    auto sym_storage_offset = r.toSymIntOptional(3);
-
     TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
 
-    tensor_impl->set_sizes_and_strides(
-        sym_sizes, sym_strides, sym_storage_offset.value_or(0));
+    if (sym_strides.has_value()) {
+      tensor_impl->set_sizes_and_strides(
+          sym_sizes, sym_strides.value(), sym_storage_offset);
+    } else {
+      TORCH_CHECK(
+          !sym_storage_offset.has_value(),
+          "setting storage offset without stride not supported");
+      tensor_impl->generic_set_sizes_contiguous(sym_sizes);
+    }
 
     const auto sizes_strides_policy = r.stringViewOptional(10);
     if (sizes_strides_policy.has_value()) {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4e42a4497162..1a0c7cfd26bf 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -12573,8 +12573,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                # RuntimeError: This operator is not Composite Compliant: the
                # storage_offset of the tensor was modified directly without
                # going through the PyTorch dispatcher.
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance'),
-
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
 
                # These fail because the test changes the input's in-memory layout
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),

From db8f9686a7e596d7155dfd363d6133e74e0eb730 Mon Sep 17 00:00:00 2001
From: Sunita Nadampalli 
Date: Wed, 22 Nov 2023 02:49:30 +0000
Subject: [PATCH 082/221] [cmake] set 'mcpu=generic' as the default build flag
 for mkldnn on aarch64 (#113820)

This is to remove the dependencies on mkldnn cmake default definitions

Fixes https://github.com/pytorch/pytorch/issues/109312

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113820
Approved by: https://github.com/malfet
---
 cmake/Modules/FindMKLDNN.cmake | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake
index 47694f6856e5..e62d86897fee 100644
--- a/cmake/Modules/FindMKLDNN.cmake
+++ b/cmake/Modules/FindMKLDNN.cmake
@@ -89,6 +89,8 @@ IF(NOT MKLDNN_FOUND)
     IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
       IF(CPU_INTEL)
         SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE)
+      ELSEIF(CPU_AARCH64)
+        SET(DNNL_ARCH_OPT_FLAGS "-mcpu=generic" CACHE STRING "" FORCE)
       ENDIF()
     ELSE()
       SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE)

From 2c4930a91d8e4b8f7938f712bf3b78c1fdc2882f Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot 
Date: Wed, 22 Nov 2023 02:52:31 +0000
Subject: [PATCH 083/221] Revert "[fx/DDP] add nested ctx_manager test for DDP
 Dynamo (#114056)"

This reverts commit d5d62e85615fdf345e0556a9d8edbee2d3c64ae2.

Reverted https://github.com/pytorch/pytorch/pull/114056 on behalf of https://github.com/malfet due to Breaks inductor_distributed, see https://hud.pytorch.org/pytorch/pytorch/commit/d5d62e85615fdf345e0556a9d8edbee2d3c64ae2 ([comment](https://github.com/pytorch/pytorch/pull/114056#issuecomment-1822006423))
---
 test/distributed/test_c10d_nccl.py            |  13 +-
 test/distributed/test_dynamo_distributed.py   | 196 ++++--------------
 torch/testing/_internal/common_distributed.py |  11 -
 3 files changed, 52 insertions(+), 168 deletions(-)

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index bfd9a4fff4f9..4ac72c2bd207 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -12,6 +12,7 @@
 import pickle
 import time
 import warnings
+from contextlib import contextmanager
 from datetime import timedelta
 from itertools import chain, product
 from unittest import mock
@@ -44,7 +45,6 @@
     skip_if_rocm,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
-    first_bucket_size,
 )
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -2039,6 +2039,17 @@ def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size):
         local_batch_start = self.rank * local_batch_size
         local_batch_end = (self.rank + 1) * local_batch_size
 
+        # Reducer.cpp sneakily creates one "initial bucket" that ignores the "bucket_cap_mb"
+        # argument.  The following makes sure the initial bucket also complies.
+        @contextmanager
+        def first_bucket_size(ddp_bucket_mb):
+            old_DEFAULT_FIRST_BUCKET_BYTES = dist._DEFAULT_FIRST_BUCKET_BYTES
+            dist._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6)
+            try:
+                yield
+            finally:
+                dist._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES
+
         with torch.backends.cudnn.flags(
             enabled=True, deterministic=True, benchmark=False
         ):
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 2cdad86269c4..1547e595c924 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -3,7 +3,6 @@
 import functools
 from io import StringIO
 from typing import List
-from itertools import product
 import random
 import unittest
 from unittest.mock import patch
@@ -30,7 +29,6 @@
     skip_if_lt_x_gpu,
     requires_nccl,
     _dynamo_dist_per_rank_init,
-    first_bucket_size,
 )
 import torch._dynamo.logging
 from torch.testing._internal.common_cuda import (
@@ -74,33 +72,7 @@ def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_mana
     outputs = m(inputs)
     return m, inputs, outputs
 
-class ToyModelMultiOutput(nn.Module):
-    def __init__(self, ctx_manager_1, ctx_manager_2, hidden_feat=1000):
-        super().__init__()
-        self.ctx_manager_1 = ctx_manager_1
-        self.ctx_manager_2 = ctx_manager_2
-        self.net1 = nn.Sequential(
-            * [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-        )
-        self.net2 = nn.Sequential(
-            * [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-        )
-        self.net3 = nn.Sequential(
-            * [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
-        )
 
-    def forward(self, inputs):
-        with self.ctx_manager_1():
-            intermediates_1 = self.net1(inputs)
-            with self.ctx_manager_2():
-                intermediates_2 = self.net2(intermediates_1)
-            outputs = self.net3(inputs)
-        return intermediates_1, intermediates_2, outputs
 
 class ToyInnerModel(nn.Module):
     def __init__(self):
@@ -549,13 +521,6 @@ def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manage
         outputs = m(inputs)
         return m, inputs, outputs
 
-    def get_model_multi_output(self, ctx_manager_1, ctx_manager_2, bsz=20, hidden_feat=1000):
-        m = ToyModelMultiOutput(ctx_manager_1, ctx_manager_2, hidden_feat=hidden_feat).to(self.device)
-        m.apply(init_weights)
-        inputs = torch.rand(bsz, hidden_feat).to(self.device)
-        outputs = m(inputs)
-        return m, inputs, outputs
-
     @patch.object(config, "optimize_ddp", False)
     def test_ddp_baseline_aot_eager(self):
         from torch.nn.parallel import DistributedDataParallel as DDP
@@ -614,130 +579,49 @@ def test_graph_split_ctx_manager(self):
         context managers' effects are applied to the computation.
         """
 
-        bucket_cap_mb = 3.5
-        # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
-        with first_bucket_size(bucket_cap_mb):
-            for ambient_grad, get_compiler in product([
-                torch.no_grad,
-                torch.enable_grad,
-            ], [
-                lambda: CheckSplitsCompiler(),
-                lambda: None,
-            ]):
-                with ambient_grad():
-                    for ctx_manager, output_test in [
-                        (
-                            lambda: torch.autocast(torch.device(self.device).type, torch.float16),
-                            lambda out: self.assertEqual(out.dtype, torch.float16),
-                        ),
-                        (
-                            torch.enable_grad,
-                            lambda out: self.assertTrue(out.requires_grad)
-                        ),
-                        (
-                            torch.no_grad,
-                            lambda out: self.assertTrue(not out.requires_grad)
-                        ),
-                    ]:
-                        m, inputs, correct_outputs = self.get_model(
-                            out_feat=1000, hidden_feat=1000, in_feat=1000, ctx_manager=ctx_manager
-                        )
-                        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
-
-                        compiler = get_compiler()
-
-                        @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
-                        def opt_fn(inputs):
-                            return ddp_m(inputs)
-
-                        opt_outputs = opt_fn(inputs)
-                        self.assertTrue(same(correct_outputs, opt_outputs))
-                        if compiler:
-                            self.assertEqual(compiler.compiler_called, 4)
-
-                        output_test(opt_outputs)
-
-                        # ensure compatibility with dynamo explain
-
-                        explain_out = torch._dynamo.explain(ddp_m)(inputs)
-                        break_reasons = explain_out.break_reasons
-                        self.assertEqual(len(break_reasons), 4)
-                        self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
+        for get_compiler in [
+            lambda: CheckSplitsCompiler(),
+            lambda: None,
+        ]:
+            for ctx_manager, output_test in [
+                (
+                    lambda: torch.autocast(torch.device(self.device).type, torch.float16),
+                    lambda out: self.assertEqual(out.dtype, torch.float16),
+                ),
+                (
+                    torch.enable_grad,
+                    lambda out: self.assertTrue(out.requires_grad)
+                ),
+                (
+                    torch.no_grad,
+                    lambda out: self.assertTrue(not out.requires_grad)
+                ),
+            ]:
+                m, inputs, correct_outputs = self.get_model(out_feat=1000, hidden_feat=1000, in_feat=1000, ctx_manager=ctx_manager)
+                # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
+                # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
+                bucket_cap_mb = 3.5  # 4MB
+                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
+
+                compiler = get_compiler()
 
+                @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
+                def opt_fn(inputs):
+                    return ddp_m(inputs)
 
-    @patch.object(config, "optimize_ddp", True)
-    def test_graph_split_ctx_manager_nested(self):
-        """
-        Ensures that we get the right number of splits and that the respective
-        context managers' effects are applied to the computation.
-        """
-        try:
-            torch.autocast(torch.device(self.device).type, torch.bfloat16)
-            torch.autocast(torch.device(self.device).type, torch.float16)
-        except Exception:
-            self.skipTest("Need both bfloat16, float16 support on device")
-
-        ctx_managers_outer = [
-            (
-                lambda: torch.autocast(torch.device(self.device).type, torch.float16),
-                lambda out: self.assertEqual(out.dtype, torch.float16),
-            ),
-            (
-                torch.enable_grad,
-                lambda out: self.assertTrue(out.requires_grad)
-            ),
-        ]
-        ctx_managers_inner = [
-            (
-                lambda: torch.autocast(torch.device(self.device).type, torch.bfloat16),
-                lambda out: self.assertEqual(out.dtype, torch.bfloat16),
-            ),
-            (
-                torch.no_grad,
-                lambda out: self.assertTrue(not out.requires_grad)
-            ),
-        ]
+                opt_outputs = opt_fn(inputs)
+                self.assertTrue(same(correct_outputs, opt_outputs))
+                if compiler:
+                    self.assertEqual(compiler.compiler_called, 4)
+
+                output_test(opt_outputs)
+
+                # ensure compatibility with dynamo explain
 
-        bucket_cap_mb = 7.5
-        # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
-        with first_bucket_size(bucket_cap_mb):
-            for ambient_grad, get_compiler in product([
-                torch.no_grad,
-                torch.enable_grad,
-            ], [
-                lambda: CheckSplitsCompiler(),
-                lambda: None,
-            ]):
-                with ambient_grad():
-                    for ctx_manager_1, output_test_1 in ctx_managers_outer:
-                        for ctx_manager_2, output_test_2 in ctx_managers_inner:
-                            m, inputs, correct_outputs = self.get_model_multi_output(
-                                ctx_manager_1, ctx_manager_2, hidden_feat=1000
-                            )
-                            ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
-
-                            compiler = get_compiler()
-
-                            @torch._dynamo.optimize(compiler.compile_fn if compiler else "aot_eager")
-                            def opt_fn(inputs):
-                                return ddp_m(inputs)
-
-                            opt_outputs = opt_fn(inputs)
-                            self.assertTrue(same(correct_outputs, opt_outputs))
-                            if compiler:
-                                self.assertEqual(compiler.compiler_called, 4)
-
-                            opt_outputs_1, opt_outputs_2, opt_outputs_3 = opt_outputs
-                            output_test_1(opt_outputs_1)
-                            output_test_2(opt_outputs_2)
-                            output_test_1(opt_outputs_3)
-
-                            # ensure compatibility with dynamo explain
-
-                            explain_out = torch._dynamo.explain(ddp_m)(inputs)
-                            break_reasons = explain_out.break_reasons
-                            self.assertEqual(len(break_reasons), 4)
-                            self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
+                explain_out = torch._dynamo.explain(ddp_m)(inputs)
+                break_reasons = explain_out.break_reasons
+                self.assertEqual(len(break_reasons), 4)
+                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
 
     @patch.object(config, "optimize_ddp", True)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 520785f07e75..8cbca096b500 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -1253,14 +1253,3 @@ def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
         self.rank = rank
         self.file_name = file_name
         self.run_test(test_name, parent_pipe)
-
-# Reducer.cpp sneakily creates one "initial bucket" that ignores the "bucket_cap_mb"
-# argument.  The following makes sure the initial bucket also complies.
-@contextmanager
-def first_bucket_size(ddp_bucket_mb):
-    old_DEFAULT_FIRST_BUCKET_BYTES = c10d._DEFAULT_FIRST_BUCKET_BYTES
-    c10d._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6)
-    try:
-        yield
-    finally:
-        c10d._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES

From 0a33cf95c6034d978baf7259df8387a2a14818c4 Mon Sep 17 00:00:00 2001
From: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Date: Wed, 22 Nov 2023 03:26:32 +0000
Subject: [PATCH 084/221] Add python-3.12 to triton wheels build matrix
 (#114327)

Not sure if it will work, but perhaps worth a try

Inspired by [following comment](https://github.com/pytorch/builder/blob/56556d0aaca4da61c0497608b9136b058573c8d6/manywheel/build_cuda.sh#L266):
```
# No triton dependency for now on 3.12 since we don't have binaries for it
# and torch.compile doesn't work.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114327
Approved by: https://github.com/kit1980, https://github.com/PaliC
---
 .github/workflows/build-triton-wheel.yml | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml
index 103f35b3d519..48e1b2e3b7d0 100644
--- a/.github/workflows/build-triton-wheel.yml
+++ b/.github/workflows/build-triton-wheel.yml
@@ -34,7 +34,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        py_vers: [ "3.8", "3.9", "3.10", "3.11" ]
+        py_vers: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
         device: ["cuda", "rocm"]
         include:
           - device: "rocm"
@@ -94,6 +94,9 @@ jobs:
           3.11)
             PYTHON_EXECUTABLE=/opt/python/cp311-cp311/bin/python
             ;;
+          3.12)
+            PYTHON_EXECUTABLE=/opt/python/cp312-cp312/bin/python
+            ;;
           *)
             echo "Unsupported python version ${PY_VERS}"
             exit 1

From e0ec71deab2aedd6d44f4ea3e03b52bdaf5db3da Mon Sep 17 00:00:00 2001
From: Will Constable 
Date: Tue, 21 Nov 2023 16:19:32 -0800
Subject: [PATCH 085/221] Fix module: distributed labeler (#114324)

Removes preceding `/` which was preventing labeler from working.  (looks like a typo in the original PR)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114324
Approved by: https://github.com/XilunWu, https://github.com/fegin
---
 .github/labeler.yml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/.github/labeler.yml b/.github/labeler.yml
index d7c2ae2da190..8dc33342d511 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -69,8 +69,8 @@
 - .ci/docker/ci_commit_pins/triton.txt
 
 "module: distributed":
-- /torch/csrc/distributed/**
-- /torch/distributed/**
-- /torch/nn/parallel/**
-- /test/distributed/**
-- /torch/testing/_internal/distributed/**
+- torch/csrc/distributed/**
+- torch/distributed/**
+- torch/nn/parallel/**
+- test/distributed/**
+- torch/testing/_internal/distributed/**

From 9e657ce2edc1fb13c392f9a0d4571c7ed5c98369 Mon Sep 17 00:00:00 2001
From: ydwu4 
Date: Tue, 21 Nov 2023 16:50:09 -0800
Subject: [PATCH 086/221] [HigherOrderOp] set should_flatten_output=True for
 cond (#113819)

This PR add should_flatten_outpu=True for cond. This effectively allows cond to support pytree output with the output being flattened. Note: a single tensor output will be automatically casted as tuple for torch.ops.higher_order.cond.

This PR also adds support for comparing BuiltinVariables e.g. tuple, this is to make sure we could make dynamo inline comparing two tree_spec to make sure both branches returns the same tree_spec.

Test Plan:
Existing tests. Will add more pytree tests and modify the documentations in the follow-up prs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113819
Approved by: https://github.com/zou3519
---
 test/dynamo/test_export.py                  | 32 ++++++-------
 test/dynamo/test_higher_order_ops.py        | 17 ++++---
 test/functorch/test_control_flow.py         | 36 +++++++++-----
 torch/_dynamo/variables/builtin.py          |  5 ++
 torch/_dynamo/variables/higher_order_ops.py | 53 ++++++++++++++++-----
 5 files changed, 96 insertions(+), 47 deletions(-)

diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 3f6f906cf9b8..9ee881414d13 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -3099,18 +3099,17 @@ def f_branch_return_non_tensor(x):
             )(*example_inputs)
 
     def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
-        def f_branch_return_multiple_tensors(x, y):
-            return cond(x, lambda x: (x, x), lambda x: (x, x), [y])
+        def f_branch_return_multiple_tensors(pred, x, y):
+            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
 
-        example_inputs = (torch.randn(4), torch.randn(2))
-        with self.assertRaisesRegex(
-            torch._dynamo.exc.UncapturedHigherOrderOpError,
-            "Cond doesn't work unless it is captured completely with torch.compile",
-        ):
-            torch._dynamo.export(
-                f_branch_return_multiple_tensors,
-                aten_graph=True,
-            )(*example_inputs)
+        example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
+        gm, _ = torch._dynamo.export(
+            f_branch_return_multiple_tensors,
+            aten_graph=True,
+        )(*example_inputs)
+        self.assertEqual(
+            gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs)
+        )
 
     def test_multiple_outputs_op_with_evaluator(self):
         class TopKModel(torch.nn.Module):
@@ -3654,7 +3653,8 @@ def forward(self, pred, x):
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
     cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]);  l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
-    return pytree.tree_unflatten([cond], self._out_spec)""",  # noqa: B950,E122
+    getitem = cond[0];  cond = None
+    return pytree.tree_unflatten([getitem], self._out_spec)""",  # noqa: B950,E122
         )
 
         self.assertExpectedInline(
@@ -3671,7 +3671,7 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
     add_2 = add_1 + cos_1;  add_1 = cos_1 = None
     cos_2 = d_true_branch.cos();  d_true_branch = None
     add_3 = add_2 + cos_2;  add_2 = cos_2 = None
-    return add_3""",
+    return (add_3,)""",
         )
 
         self.assertExpectedInline(
@@ -3688,7 +3688,7 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
     add_1 = add + sin_1;  add = sin_1 = None
     sin_2 = c_false_branch.sin();  c_false_branch = None
     add_2 = add_1 + sin_2;  add_1 = sin_2 = None
-    return add_2""",
+    return (add_2,)""",
         )
 
     @unittest.skipIf(
@@ -3989,7 +3989,7 @@ def forward(self, x):
 def forward(self, arg0_1, arg1_1, arg2_1):
     out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg0_1, arg2_1);  arg0_1 = arg2_1 = None
     sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
-    return sum_1""",
+    return (sum_1,)""",
         )
 
         self.assertExpectedInline(
@@ -3998,7 +3998,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 def forward(self, arg0_1, arg1_1, arg2_1):
     out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg0_1, arg2_1);  arg0_1 = arg2_1 = None
     sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
-    return sum_1""",
+    return (sum_1,)""",
         )
 
     def test_export_nn_module_stack_patched_module(self):
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index e3c7ede22b85..ba5427d5271c 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -1321,7 +1321,8 @@ def forward(self, L_x_ : torch.Tensor):
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
     cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_]);  gt = cond_true_0 = cond_false_0 = l_x_ = None
-    return (cond,)""",
+    getitem = cond[0];  cond = None
+    return (getitem,)""",
             )
             self.assertExpectedInline(
                 true_graph,
@@ -1329,7 +1330,7 @@ def forward(self, L_x_ : torch.Tensor):
 def forward(self, l_x_):
     l_x__1 = l_x_
     sin = torch.sin(l_x__1);  l_x__1 = None
-    return sin""",
+    return (sin,)""",
             )
             self.assertExpectedInline(
                 false_graph,
@@ -1337,7 +1338,7 @@ def forward(self, l_x_):
 def forward(self, l_x_):
     l_x__1 = l_x_
     cos = torch.cos(l_x__1);  l_x__1 = None
-    return cos""",
+    return (cos,)""",
             )
 
     def test_cond_branches_no_arguments_no_closure(self):
@@ -1364,14 +1365,15 @@ def forward(self, L_x_ : torch.Tensor):
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
     cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, []);  gt = cond_true_0 = cond_false_0 = None
-    return (cond,)""",
+    getitem = cond[0];  cond = None
+    return (getitem,)""",
             )
             self.assertExpectedInline(
                 true_graph,
                 """\
 def forward(self):
     ones = torch.ones(3, 4)
-    return ones""",
+    return (ones,)""",
             )
             self.assertExpectedInline(
                 false_graph,
@@ -1379,7 +1381,7 @@ def forward(self):
 def forward(self):
     ones = torch.ones(3, 4)
     sin = ones.sin();  ones = None
-    return sin""",
+    return (sin,)""",
             )
 
     def test_cond_side_effect_in_one_branches(self):
@@ -2185,7 +2187,8 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
     cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]);  l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
-    return (cond,)""",  # noqa: B950
+    getitem = cond[0];  cond = None
+    return (getitem,)""",  # noqa: B950
         )
 
     def test_cond_pytree_operands_with_non_tensor_leaves(self):
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index e472c7fa59af..fec1a20d9f67 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -316,7 +316,7 @@ def f(x, pred, pred2):
         graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False), torch.tensor(False))
         self.assertEqual(graph(x, torch.tensor(True), torch.tensor(True)), f(x, torch.tensor(True), torch.tensor(True)))
 
-    def test_cond_functionalized(self):
+    def test_cond_functionalized_hah(self):
         def true_fn(x):
             y = x.sin()
             y.add_(4)
@@ -685,15 +685,17 @@ def forward(self, x_1, pred_1, pred2_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);  pred_1 = true_graph_0 = false_graph_0 = None
+    getitem = conditional[0];  conditional = None
     true_graph_1 = self.true_graph_1
     false_graph_1 = self.false_graph_1
     conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
-    add = torch.ops.aten.add.Tensor(conditional, conditional_1);  conditional = conditional_1 = None
+    getitem_1 = conditional_1[0];  conditional_1 = None
+    add = torch.ops.aten.add.Tensor(getitem, getitem_1);  getitem = getitem_1 = None
     return add""")  # noqa: B950
         self.assertExpectedInline(graph.true_graph_0.code.strip(), """\
 def forward(self, arg0_1):
     mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
-    return mul""")
+    return (mul,)""")
 
     def test_raise_error_on_mismatch_type_size(self):
         def true_fn(x):
@@ -836,15 +838,17 @@ def forward(self, x_1, pred_1, pred2_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);  pred_1 = true_graph_0 = false_graph_0 = None
+    getitem = conditional[0];  conditional = None
     true_graph_1 = self.true_graph_1
     false_graph_1 = self.false_graph_1
     conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
-    add = torch.ops.aten.add.Tensor(conditional, conditional_1);  conditional = conditional_1 = None
+    getitem_1 = conditional_1[0];  conditional_1 = None
+    add = torch.ops.aten.add.Tensor(getitem, getitem_1);  getitem = getitem_1 = None
     return add""")  # noqa: B950
         self.assertExpectedInline(graph.true_graph_0.code.strip(), """\
 def forward(self, arg0_1):
     mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
-    return mul""")
+    return (mul,)""")
 
     def test_raise_error_on_mismatch_type_size_fake_tensor(self):
         def true_fn(x):
@@ -1260,7 +1264,8 @@ def forward(self, x_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]);  eq = true_graph_0 = false_graph_0 = x_1 = None
-    return conditional""")  # noqa: B950
+    getitem = conditional[0];  conditional = None
+    return getitem""")  # noqa: B950
 
         # We expect the traced graph module to work even if input size changes.
         x = torch.ones(4, 3, 2)
@@ -1334,11 +1339,12 @@ def forward(self, x_1):
     _tensor_constant0 = self._tensor_constant0
     _tensor_constant1 = self._tensor_constant1
     conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]);  true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
-    return conditional""")  # noqa: B950
+    getitem = conditional[0];  conditional = None
+    return getitem""")  # noqa: B950
         self.assertExpectedInline(gm.true_graph_0.code.strip(), """\
 def forward(self, arg0_1, arg1_1, arg2_1):
     add = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
-    return add""")
+    return (add,)""")
 
     def test_cond_with_module_param_closure(self):
         class Mod(torch.nn.Module):
@@ -1455,7 +1461,8 @@ def forward(self, arg0_1, arg1_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]);  arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
-    return [conditional]""")  # noqa: B950
+    getitem = conditional[0];  conditional = None
+    return [getitem]""")  # noqa: B950
 
     def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
 
@@ -1503,7 +1510,9 @@ def make_dummy_fn(op):
         for _ in range(iter_n):
             # each lambda has a different object id thus fails the guard
             self.assertEqual(foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out)
-        self.assertEqual(counters["stats"]["calls_captured"], iter_n)
+
+        # each iteration captures a cond and a getitem from the tuple output
+        self.assertEqual(counters["stats"]["calls_captured"], iter_n * 2)
         self.assertEqual(counters["stats"]["unique_graphs"], iter_n)
 
     def test_cond_with_consecutive_make_fx_symbolic(self):
@@ -1526,19 +1535,20 @@ def forward(self, x_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]);  eq = true_graph_0 = false_graph_0 = x_1 = None
-    return conditional""")  # noqa: B950
+    getitem = conditional[0];  conditional = None
+    return getitem""")  # noqa: B950
 
             self.assertExpectedInline(gm.true_graph_0.code.strip(), """\
 def forward(self, arg0_1):
     cos = torch.ops.aten.cos.default(arg0_1)
     sub = torch.ops.aten.sub.Tensor(arg0_1, cos);  arg0_1 = cos = None
-    return sub""")
+    return (sub,)""")
 
             self.assertExpectedInline(gm.false_graph_0.code.strip(), """\
 def forward(self, arg0_1):
     sin = torch.ops.aten.sin.default(arg0_1)
     add = torch.ops.aten.add.Tensor(arg0_1, sin);  arg0_1 = sin = None
-    return add""")
+    return (add,)""")
 
     def _create_test_fns_for_cond(self, pred, inner_most_fn, operands, closure_list, nested_level):
         if nested_level == 0:
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index dc09bfe1e1e7..5ddc1f32a763 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1479,6 +1479,11 @@ def _unimplemented():
             if type(left) is not type(right):
                 return ConstantVariable.create(False)
 
+        if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable):
+            return ConstantVariable.create(op(left.fn, right.fn))
+
+        _unimplemented()
+
     def call_and_(self, tx, a, b):
         # Rely on constant_handler
         if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 699bd1df854b..927cf6751511 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -479,7 +479,11 @@ def speculate_branch(branch):
             # NB: 0 is predicate
             ix = 1 if branch else 2
             # TODO: Support kwargs
-            (ret_val, _), ret_graph, ret_lifted_freevars = speculate_subgraph(
+            (
+                (ret_val, ret_treespec),
+                ret_graph,
+                ret_lifted_freevars,
+            ) = speculate_subgraph(
                 tx,
                 args[ix],
                 operands,
@@ -489,20 +493,34 @@ def speculate_branch(branch):
                 "cond",
                 source_target=self.value,
                 manually_set_subgraph_inputs=False,
+                should_flatten_outputs=True,
             )
 
-            if not isinstance(ret_val, TensorVariable):
+            if not only_consist_of(ret_val, (TensorVariable,)):
                 unimplemented(
-                    "Expected branch to return a single tensor",
+                    "Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
                 )
-            return ret_val, ret_graph, ret_lifted_freevars
+            return ret_val, ret_treespec, ret_graph, ret_lifted_freevars
 
-        (true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
+        (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
+            True
+        )
         true_nn_modules = tx.copy_graphstate().output.nn_modules
 
-        (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
+        (
+            false_r,
+            false_treespec,
+            false_graph,
+            false_lifted_freevars,
+        ) = speculate_branch(False)
         false_nn_modules = tx.copy_graphstate().output.nn_modules
 
+        same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
+            true_treespec, false_treespec
+        )
+        if not same_treespec.as_python_constant():
+            unimplemented("Expected branches to return the same pytree structure.")
+
         def dedup_and_sort_lifted_freevars(true_lifted_freevars, false_lifted_freevars):
             shared_freevars = true_lifted_freevars.keys() & false_lifted_freevars.keys()
             unique_true_freevars = true_lifted_freevars.keys() - shared_freevars
@@ -584,12 +602,14 @@ def _insert_or_replace_phs(new_args, name_suffix):
             false_node,
             shared + unique_true + unique_false,
         )
-        # TODO: assert that the true/false return values are
-        # consistent
-        example_value = true_r.as_proxy().node.meta["example_value"]
+        flat_example_value = pytree.tree_map_only(
+            torch.fx.Proxy,
+            lambda a: a.node.meta["example_value"],
+            true_r.as_proxy(),
+        )
 
         # Store the invocation as a call
-        return wrap_fx_proxy(
+        flat_variable = wrap_fx_proxy(
             tx=tx,
             proxy=tx.output.create_proxy(
                 "call_function",
@@ -597,7 +617,18 @@ def _insert_or_replace_phs(new_args, name_suffix):
                 args=tuple(p_args),
                 kwargs={},
             ),
-            example_value=example_value,
+            example_value=flat_example_value,
+        )
+
+        # Transform variable back into a list (previously made into a tuple by
+        # speculate_subgraph function) so as to respect the pytree API typing.
+        flat_list_variable = BuiltinVariable(list).call_function(
+            tx, [flat_variable], {}
+        )
+        return (
+            _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, true_treespec)
+            if true_treespec
+            else flat_variable
         )
 
 

From c5ddfa79b3c8bde3099743f24f7d662f6d04152e Mon Sep 17 00:00:00 2001
From: ydwu4 
Date: Tue, 21 Nov 2023 16:50:09 -0800
Subject: [PATCH 087/221] [HigherOrderOp] add output tensor meta check for cond
 (#113900)

This PR checks the tensor meta of the outputs of cond's branches. This helps us to identify several tests that return outputs that have different requires_grad. Also fix the error messages, which previously was in torch.ops.higher_order.cond now is raised in dynamo CondHigherOrder.

Test Plan:
Existing tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113900
Approved by: https://github.com/zou3519
ghstack dependencies: #113819
---
 test/dynamo/test_export.py                  | 70 +++++++++++++++------
 test/functorch/test_control_flow.py         | 11 ++--
 torch/_dynamo/variables/higher_order_ops.py | 21 +++++++
 3 files changed, 78 insertions(+), 24 deletions(-)

diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 9ee881414d13..33cab08435ee 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -1639,7 +1639,7 @@ def false_fn(x, y):
 
         for Module in [Foo, Bar, FooBar]:
             mod = Module()
-            x = torch.randn([3, 3])
+            x = torch.randn([3, 3], requires_grad=True)
             pred = torch.tensor(x[0][0].item() < 0)
             real_result = mod.forward(pred, x)
             out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
@@ -1682,16 +1682,7 @@ def false_fn(x):
 
                 return cond(x.shape[0] <= 2, true_fn, false_fn, [x])
 
-        mod = Module()
-        x = torch.randn(2, 2)
-        out_graph, _ = torch._dynamo.export(mod)(x)
-        test_x = torch.randn(3, 2)
-        self.assertEqual(out_graph(test_x), mod(test_x))
-
-    def test_export_with_cond_dynamic_shape_pred_tuple_operands(self):
-        from functorch.experimental.control_flow import cond
-
-        class Module(torch.nn.Module):
+        class Module2(torch.nn.Module):
             def forward(self, x):
                 def true_fn(x):
                     return x + x
@@ -1701,11 +1692,54 @@ def false_fn(x):
 
                 return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))
 
-        mod = Module()
-        x = torch.randn(2, 2)
-        out_graph, _ = torch._dynamo.export(mod)(x)
-        test_x = torch.randn(3, 2)
-        self.assertEqual(out_graph(test_x), mod(test_x))
+        mods = [Module(), Module2()]
+        for mod in mods:
+            x = torch.randn(2, 2)
+            out_graph, guards = torch._dynamo.export(mod)(x)
+            self.assertExpectedInline(
+                out_graph.code.strip(),
+                """\
+def forward(self, x):
+    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+    l_x_ = arg0
+    size = l_x_.size()
+    getitem = size[0];  size = None
+    le = getitem <= 2;  getitem = None
+    cond_true_0 = self.cond_true_0
+    cond_false_0 = self.cond_false_0
+    cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]);  le = cond_true_0 = cond_false_0 = l_x_ = None
+    getitem_2 = cond[0];  cond = None
+    return pytree.tree_unflatten([getitem_2], self._out_spec)""",
+            )
+            self.assertExpectedInline(
+                out_graph.cond_true_0.code.strip(),
+                """\
+def forward(self, l_x_):
+    l_x__1 = l_x_
+    add = l_x__1 + l_x__1;  l_x__1 = None
+    return (add,)""",
+            )
+            self.assertExpectedInline(
+                out_graph.cond_false_0.code.strip(),
+                """\
+def forward(self, l_x_):
+    l_x__1 = l_x_
+    getitem = l_x__1[slice(None, 2, None)];  l_x__1 = None
+    return (getitem,)""",
+            )
+            with self.assertRaisesRegex(
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile",
+            ):
+                # True branch and false branch return tensors of different shape
+                torch._dynamo.export(mod)(torch.randn(3, 2))
+            with self.assertRaisesRegex(
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile",
+            ):
+                # True branch and false branch return tensors of different shape
+                test_x = torch.randn(3, 2)
+                mod(test_x)
 
     def test_export_with_map_cond(self):
         from functorch.experimental.control_flow import cond, map
@@ -3152,8 +3186,8 @@ def f_return_tensor_mismatch(x):
 
         example_inputs = (torch.rand(5),)
         with self.assertRaisesRegex(
-            RuntimeError,
-            "Expected each tensor to have same metadata but got",
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile",
         ):
             torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(
                 *example_inputs,
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index fec1a20d9f67..da5596d59969 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -9,7 +9,6 @@
 from functorch.experimental.control_flow import UnsupportedAliasMutationException, cond
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch._dynamo.exc import CondOpArgsMismatchError
 from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
 from torch._subclasses.functional_tensor import FunctionalTensor
 
@@ -726,8 +725,8 @@ def f(x, y):
 
         x = torch.randn(4)
         with self.assertRaisesRegex(
-            CondOpArgsMismatchError,
-            "Expected each tensor to have same metadata but got",
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile"
         ):
             make_fx(f)(x, torch.tensor(False))
 
@@ -880,8 +879,8 @@ def f(x, y):
 
         x = torch.randn(4)
         with self.assertRaisesRegex(
-            CondOpArgsMismatchError,
-            "Expected each tensor to have same metadata but got",
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile"
         ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
 
@@ -1350,7 +1349,7 @@ def test_cond_with_module_param_closure(self):
         class Mod(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_parameter("param", torch.nn.Parameter(torch.ones(2, 3)))
+                self.register_parameter("param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False))
                 self.register_buffer("buffer", torch.ones(2, 3) + 1)
 
         my_mode = Mod()
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 927cf6751511..bd06d487652d 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -16,6 +16,7 @@
 from torch._dynamo.variables.functions import UserFunctionVariable
 from torch._dynamo.variables.tensor import SymNodeVariable
 from torch._guards import Source
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.utils import _pytree as pytree
 
 from ..exc import (
@@ -521,6 +522,26 @@ def speculate_branch(branch):
         if not same_treespec.as_python_constant():
             unimplemented("Expected branches to return the same pytree structure.")
 
+        def diff_meta(tensor_vars1, tensor_vars2):
+            assert all(
+                isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
+            )
+            all_diffs = []
+            for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
+                # We check the meta data associated with meta["example_value"]
+                meta1 = _extract_tensor_metadata(var1.proxy.node.meta["example_value"])
+                meta2 = _extract_tensor_metadata(var2.proxy.node.meta["example_value"])
+                if meta1 != meta2:
+                    all_diffs.append((f"pair{i}:", meta1, meta2))
+            return all_diffs
+
+        if diffs := diff_meta(
+            true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
+        ):
+            unimplemented(
+                f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
+            )
+
         def dedup_and_sort_lifted_freevars(true_lifted_freevars, false_lifted_freevars):
             shared_freevars = true_lifted_freevars.keys() & false_lifted_freevars.keys()
             unique_true_freevars = true_lifted_freevars.keys() - shared_freevars

From e7326ec295559c16795088e79a5631e784bb4d61 Mon Sep 17 00:00:00 2001
From: Andrew Gu 
Date: Tue, 21 Nov 2023 16:11:32 -0800
Subject: [PATCH 088/221] [DTensor] Computed `DTensorSpec` hash lazily
 (#114322)

This is a forward fix for https://github.com/pytorch/pytorch/issues/113781.

We lazily compute the hash so that we do not try to compute the hash on `SymInt`s (for the stride) during Dynamo tracing.

Tested via:
```
python test/distributed/_tensor/test_dtensor_compile.py -k test_2d_fsdp_tp_ac_compile
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114322
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915, #114140
---
 torch/distributed/_tensor/placement_types.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py
index a8a79298c566..e4153565fe03 100644
--- a/torch/distributed/_tensor/placement_types.py
+++ b/torch/distributed/_tensor/placement_types.py
@@ -388,7 +388,7 @@ class DTensorSpec:
     def __post_init__(self):
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
-        self._hash = self._hash_impl()
+        self._hash: Optional[int] = None
 
     def __setattr__(self, attr: str, value: Any):
         super().__setattr__(attr, value)
@@ -397,7 +397,7 @@ def __setattr__(self, attr: str, value: Any):
         if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
             self._hash = self._hash_impl()
 
-    def _hash_impl(self):
+    def _hash_impl(self) -> int:
         # hashing and equality check for DTensorSpec are used to cache the sharding
         # propagation results. We only need to consider the mesh, placements, shape
         # dtype and stride.
@@ -416,9 +416,12 @@ def _hash_impl(self):
         return hash((self.mesh, self.placements))
 
     def __hash__(self) -> int:
-        # We eagerly cache the spec to avoid recomputing the hash upon each
+        # We lazily cache the spec to avoid recomputing the hash upon each
         # use, where we make sure to update the hash when the `tensor_meta`
-        # changes by overriding `__setattr__`.
+        # changes by overriding `__setattr__`. This must be lazy so that Dynamo
+        # does not try to hash non-singleton `SymInt`s for the stride.
+        if self._hash is None:
+            self._hash = self._hash_impl()
         return self._hash
 
     def __eq__(self, __o: object) -> bool:

From bd44bdb6750327aa03fa7e69c8d0d9596f319f0a Mon Sep 17 00:00:00 2001
From: BowenBao 
Date: Mon, 20 Nov 2023 11:58:36 -0800
Subject: [PATCH 089/221] [ONNX][dynamo_export] Turn off opmath for type
 promotion (#113780)

Although opmath is the right thing to do to retain on-par precision, it inserts
upcasts everywhere in the graph. This is particularly hard for backend to optimize
since there is no way to differentiate between inserted upcasts and model code
casts. Hence we consolidate the input dtype to the result dtype to avoid this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113780
Approved by: https://github.com/titaiwangms, https://github.com/justinchuby
---
 test/onnx/onnx_test_common.py                 |  5 +--
 test/onnx/test_fx_op_consistency.py           |  7 +++++
 .../_internal/fx/passes/type_promotion.py     | 31 +++++++++++++++++--
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index 2892a23f520a..f664e7e84a42 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -220,7 +220,6 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
         rtol: Optional[float] = 1e-3,
         atol: Optional[float] = 1e-7,
         has_mutation: bool = False,
-        verbose: bool = False,
         additional_test_inputs: Optional[
             List[
                 Union[
@@ -242,8 +241,6 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             has_mutation (bool, optional): Whether the model mutates its input or state.
                 `mutation` as `True` incurs extra overhead of cloning the inputs and model.
                 Defaults to False.
-            verbose (bool, optional): Whether to save diagnostics as Sarif log and print
-                verbose information. Defaults to False.
             additional_test_inputs: Test the models with another dataset input, which
                 is designed for dynamic axes testing. Defaults to None. It's a list of
                 different input sets in tuples. Inside tuple, the first element is a tuple
@@ -308,7 +305,7 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             export_error = e
             onnx_program = e.onnx_program
 
-        if verbose and diagnostics.is_onnx_diagnostics_log_artifact_enabled():
+        if diagnostics.is_onnx_diagnostics_log_artifact_enabled():
             onnx_program.save_diagnostics(
                 f"test_report_{self._testMethodName}"
                 f"_op_level_debug_{self.op_level_debug}"
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index 085fb05f67a7..b1ba49f7310a 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -450,6 +450,12 @@ def skip_torchlib_forward_compatibility(
         dtypes=onnx_test_common.INT_TYPES,
         reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
     ),
+    xfail(
+        # NOTE: this is a temporary skip, see https://github.com/pytorch/pytorch/issues/113808.
+        "nn.functional.celu",
+        dtypes=(torch.float16,),
+        reason=onnx_test_common.reason_onnx_does_not_support("Celu", "float16"),
+    ),
     xfail(
         "nn.functional.conv1d",
         dtypes=(torch.int64,),
@@ -780,6 +786,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
     )
 
     fp16_low_precision_list = [
+        "baddbmm",
         "nn.functional.batch_norm",
         "native_batch_norm",
         "dot",
diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py
index 0e14e7a3980e..04fea082d7d7 100644
--- a/torch/onnx/_internal/fx/passes/type_promotion.py
+++ b/torch/onnx/_internal/fx/passes/type_promotion.py
@@ -138,6 +138,12 @@ def preview_type_promotion(
 class ElementwiseTypePromotionRule(TypePromotionRule):
     """Defines how to perform elementwise type promotion for 'torch.ops.{namespace}.{op_name}'."""
 
+    _USE_OPMATH: bool = False
+    """Whether to use opmath to compute the promoted input dtype.
+    If used, upcasts will be inserted everywhere for lower precision models.
+    Set to False and have torchlib handle upcasts in op implementation internally.
+    """
+
     def __init__(
         self,
         namespace: str,
@@ -180,6 +186,23 @@ def __eq__(self, __value: object) -> bool:
     def __hash__(self) -> int:
         return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()
 
+    def _consolidate_input_dtype(
+        self, computed_dtype: torch.dtype, result_dtype: torch.dtype
+    ) -> torch.dtype:
+        """
+        Although opmath is the right thing to do to retain on-par precision, it inserts
+        upcasts everywhere in the graph. This is particularly hard for backend to optimize
+        since there is no way to differentiate between inserted upcasts and model code
+        casts. Hence we consolidate the input dtype to the result dtype to avoid this.
+        """
+        if (
+            not self._USE_OPMATH
+            and self.promotion_kind
+            == _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+        ):
+            return result_dtype
+        return computed_dtype
+
     def preview_type_promotion(
         self, args: tuple, kwargs: dict
     ) -> TypePromotionSnapshot:
@@ -199,9 +222,13 @@ def preview_type_promotion(
             type_promotion_kind=self.promotion_kind,
         )
 
+        consolidated_input_dtype = self._consolidate_input_dtype(
+            computed_dtype, result_dtype
+        )
+
         return TypePromotionSnapshot(
-            {i: computed_dtype for i in candidate_args.keys()},
-            {name: computed_dtype for name in candidate_kwargs.keys()},
+            {i: consolidated_input_dtype for i in candidate_args.keys()},
+            {name: consolidated_input_dtype for name in candidate_kwargs.keys()},
             result_dtype,
         )
 

From bebe66e26261d3c938371cbe87c5f23cfd65a443 Mon Sep 17 00:00:00 2001
From: BowenBao 
Date: Mon, 20 Nov 2023 11:58:40 -0800
Subject: [PATCH 090/221] [ONNX] Benchmark to save sample inputs to disk before
 running (#114163)

Such that even if failures occur during model run, the sample inputs
are accessible for later investigation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114163
Approved by: https://github.com/thiagocrepaldi
ghstack dependencies: #113780
---
 benchmarks/dynamo/common.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 21387c123158..4339d1f958fe 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -1193,14 +1193,15 @@ def save_tensor_data(cls, numpy_tensor, output_path):
             f.write(proto_tensor.SerializeToString())
 
     def run_and_serialize_inputs_outputs(self, pt_inputs):
-        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
-        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
-
         test_data_dir = self.model_dir / "test_data_set_0"
         test_data_dir.mkdir(parents=True, exist_ok=True)
 
+        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
         for i, onnx_input in enumerate(onnx_inputs.values()):
             self.save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
+
+        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
+
         for i, onnx_output in enumerate(onnx_outputs):
             self.save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))
 

From 9f0deb132b3f271561bea07610cd01a5e99fb3a6 Mon Sep 17 00:00:00 2001
From: Yanbo Liang 
Date: Wed, 22 Nov 2023 05:46:23 +0000
Subject: [PATCH 091/221] [Inductor] Refactor group/batch fusion to support
 user defined execution order and configs (#113738)

Meta internal customers need more flexible configs on these group/batch fusion's execution order and parameters, I'd like to provide a new inductor config that users can fine and auto tune these group/batch fusions for different models.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113738
Approved by: https://github.com/xuzhao9
---
 test/inductor/test_group_batch_fusion.py      |  2 +-
 torch/_inductor/config.py                     | 20 +++-
 .../_inductor/fx_passes/group_batch_fusion.py | 98 ++++++++++++++-----
 torch/_inductor/fx_passes/post_grad.py        |  4 +-
 torch/_inductor/fx_passes/pre_grad.py         |  4 +-
 5 files changed, 94 insertions(+), 34 deletions(-)

diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 2637eaf41036..f36cee3354a6 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -224,7 +224,7 @@ def forward(self, x):
 
 
 @requires_cuda()
-@torch._inductor.config.patch(group_fusion=True, batch_fusion=True)
+@torch._inductor.config.patch(post_grad_fusion_options={"group_linear": {}})
 class TestGroupBatchFusion(TestCase):
     def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
         if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index e07c90654467..3912c65a4d1e 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1,6 +1,6 @@
 import os  # noqa: C101
 import sys
-from typing import TYPE_CHECKING
+from typing import Any, Dict, TYPE_CHECKING
 
 import torch
 
@@ -85,12 +85,26 @@
 # Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
 efficient_conv_bn_eval_fx_passes = False
 
-# enable pattern match with group fusion (using fbgemm)
+# Deprecated
 group_fusion = False
 
-# enable pattern match with batch fusion (using torch op)
+# Deprecated
 batch_fusion = True
 
+# Pre grad group/batch fusion and options in order, set to empty dict to disable fusion.
+# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
+pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {
+    "batch_linear": {},
+    "batch_linear_lhs": {},
+    "batch_layernorm": {},
+    "batch_tanh": {},
+    "batch_relu": {},
+}
+
+# Post grad group/batch fusion and options, set to empty dict to disable fusion.
+# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
+post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
+
 # enable reordering pass for improving memory locality
 reorder_for_locality = True
 
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index 778550999333..a3783f113619 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -1,7 +1,7 @@
 import collections
 import logging
 import operator
-from typing import Any, DefaultDict, Deque, Iterator, List, Optional, Set, Tuple
+from typing import Any, DefaultDict, Deque, Dict, Iterator, List, Optional, Set, Tuple
 
 import torch
 from torch._dynamo.utils import counters
@@ -40,7 +40,22 @@
 SEARCH_EXCLUSIONS = {operator.getitem}
 
 
+default_graph_search_options = {
+    "min_fuse_set_size": MIN_FUSE_SET_SIZE,
+    "max_fuse_set_size": MAX_FUSE_SET_SIZE,
+    "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
+    "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
+}
+
+graph_search_options = default_graph_search_options
+
+
 class GroupBatchFusionBase:
+    def __init__(self, **kwargs):
+        self.graph_search_options = kwargs.pop(
+            "graph_search_options", default_graph_search_options
+        )
+
     def match(self, node):
         raise NotImplementedError("match called on base")
 
@@ -48,6 +63,28 @@ def fuse(self, graph, subset):
         raise NotImplementedError("fuse called on base")
 
 
+PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = dict()
+POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = dict()
+
+
+def register_fusion(name: str, pre_grad=True):
+    def decorator(fusion_cls: GroupBatchFusionBase):
+        if pre_grad:
+            PRE_GRAD_FUSIONS[name] = fusion_cls
+        else:
+            POST_GRAD_FUSIONS[name] = fusion_cls
+        return fusion_cls
+
+    return decorator
+
+
+def list_group_batch_fusions(pre_grad=True) -> List[str]:
+    if pre_grad:
+        return list(PRE_GRAD_FUSIONS.keys())
+    else:
+        return list(POST_GRAD_FUSIONS.keys())
+
+
 class GroupFusion(GroupBatchFusionBase):
     """
     Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
@@ -64,6 +101,7 @@ class BatchFusion(GroupBatchFusionBase):
     pass
 
 
+@register_fusion("group_linear", pre_grad=False)
 class GroupLinearFusion(GroupFusion):
     def _addmm_node_can_be_fused(self, node: torch.fx.Node):
         input_shape = node.args[1].meta["tensor_meta"].shape
@@ -75,7 +113,7 @@ def _addmm_node_can_be_fused(self, node: torch.fx.Node):
             and len(weight_shape) == 2
             and all(x % 2 == 0 for x in input_shape + weight_shape)
             and all(
-                shape <= MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR
+                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                 for shape in input_shape + weight_shape
             )
         )
@@ -88,7 +126,7 @@ def _mm_node_can_be_fused(self, node: torch.fx.Node):
             and len(weight_shape) == 2
             and all(x % 2 == 0 for x in input_shape + weight_shape)
             and all(
-                shape <= MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR
+                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                 for shape in input_shape + weight_shape
             )
         )
@@ -143,6 +181,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
             graph.erase_node(original_mm)
 
 
+@register_fusion("batch_linear_lhs")
 class BatchLinearLHSFusion(BatchFusion):
     """
     Batch linear left-hand side fusion. This pass tries to fuse the following patterns:
@@ -238,6 +277,7 @@ def is_linear_node_can_be_fused(node: torch.fx.Node):
     )
 
 
+@register_fusion("batch_linear")
 class BatchLinearFusion(BatchFusion):
     """
     Batch linear fusion in pre grad pass.
@@ -316,6 +356,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
                 graph.erase_node(linear)
 
 
+@register_fusion("batch_tanh")
 class BatchTanhFusion(BatchFusion):
     """
     Batch tanh fusion in pre grad pass.
@@ -375,6 +416,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
                 graph.erase_node(node)
 
 
+@register_fusion("batch_layernorm")
 class BatchLayernormFusion(BatchFusion):
     """
     Batch layer norm fusion in pre grad pass
@@ -486,6 +528,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
             graph.erase_node(node)
 
 
+@register_fusion("batch_relu")
 class BatchReLUFusion(BatchFusion):
     """
     Batch relu fusion in pre grad pass.
@@ -550,6 +593,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
 
 def find_independent_subset_greedy(
     node_list: List[torch.fx.Node],
+    graph_search_options: Dict[str, Any],
 ) -> Iterator[List[torch.fx.Node]]:
     """
     Return a list of subset from node_list, all nodes in each subset are independent with each other and can be fused together.
@@ -572,7 +616,7 @@ def find_dependent_nodes(src_node, cur_node):
         subset_deps: Set[torch.fx.Node] = set()
 
         for node in node_list:
-            if len(subset) >= MAX_FUSE_SET_SIZE:
+            if len(subset) >= graph_search_options["max_fuse_set_size"]:
                 break
 
             visited_node_set.clear()
@@ -583,7 +627,7 @@ def find_dependent_nodes(src_node, cur_node):
                 subset.append(node)
                 subset_deps.update(dep_set)
 
-        if len(subset) >= MIN_FUSE_SET_SIZE:
+        if len(subset) >= graph_search_options["min_fuse_set_size"]:
             yield subset
 
         next_round_node_list = [node for node in node_list if node not in subset]
@@ -595,7 +639,7 @@ def get_fusion_candidates(
 ) -> DefaultDict[Any, List[torch.fx.Node]]:
     """
     Search fusion candidates for a specific rule using BFS starting from the root node.
-    We only search the subgraph within MAX_FUSE_SEARCH_DEPTH.
+    We only search the subgraph within graph_search_options["max_fuse_search_depth"].
     """
     q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()
 
@@ -624,7 +668,7 @@ def get_fusion_candidates(
             if node not in candidate_nodes:
                 candidate_nodes.append(node)
         else:
-            if depth < MAX_FUSE_SEARCH_DEPTH:
+            if depth < rule.graph_search_options["max_fuse_search_depth"]:
                 for next_node in node.all_input_nodes:
                     if next_node not in visited_set:
                         visited_set.add(next_node)
@@ -644,7 +688,9 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion
             if len(candidate_nodes) < MIN_FUSE_SET_SIZE:
                 continue
 
-            for subset in find_independent_subset_greedy(candidate_nodes):
+            for subset in find_independent_subset_greedy(
+                candidate_nodes, rule.graph_search_options
+            ):
                 rule.fuse(graph, subset)
                 fused_set.update(subset)
                 if isinstance(rule, GroupFusion):
@@ -662,29 +708,29 @@ def print_graph(graph: torch.fx.Graph, msg: str):
         log.info("%s Print graph: %s", msg, get_everpaste_url(str(graph)))  # noqa: F401
 
 
-def group_batch_fusion_post_grad_passes(graph: torch.fx.Graph):
-    print_graph(graph, "Before group_batch fusion in post grads pass.")
+def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
     fusions: List[GroupBatchFusionBase] = []
+    for name, options in config_options.items():
+        fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name]
+        _options = graph_search_options.copy()
+        _options.update(options)
+        fusions.append(fusion_cls(graph_search_options=_options))  # type: ignore[operator]
+    return fusions
 
-    if config.group_fusion and has_fbgemm:
-        fusions += [GroupLinearFusion()]
 
-    for rule in fusions:
-        apply_group_batch_fusion(graph, rule)
-        print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
+def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
+    print_graph(graph, "Before group_batch fusion in post grads pass.")
+    fusions: List[GroupBatchFusionBase] = []
 
+    if pre_grad:
+        fusions = generate_fusion_from_config(
+            config.pre_grad_fusion_options, pre_grad=True
+        )
+    elif has_fbgemm:  # Only group fusion (which needs fbgemm) in post grad.
+        fusions = generate_fusion_from_config(
+            config.post_grad_fusion_options, pre_grad=False
+        )
 
-def group_batch_fusion_pre_grad_passes(graph: torch.fx.Graph):
-    print_graph(graph, "Before group_batch fusion in pre grads pass.")
-    fusions: List[GroupBatchFusionBase] = []
-    if config.batch_fusion:
-        fusions += [
-            BatchLinearFusion(),
-            BatchLinearLHSFusion(),
-            BatchLayernormFusion(),
-            BatchTanhFusion(),
-            BatchReLUFusion(),
-        ]
     for rule in fusions:
         apply_group_batch_fusion(graph, rule)
         print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index c044400c060a..86f84a63b6c8 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -39,7 +39,7 @@
 )
 from ..utils import decode_device, is_pointwise_use
 from ..virtualized import V
-from .group_batch_fusion import group_batch_fusion_post_grad_passes
+from .group_batch_fusion import group_batch_fusion_passes
 
 
 log = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     if config.pattern_matcher:
         lazy_init()
 
-        group_batch_fusion_post_grad_passes(gm.graph)
+        group_batch_fusion_passes(gm.graph, pre_grad=False)
         remove_noop_ops(gm.graph)
 
         for patterns in pass_patterns:
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 6861ea89fb0a..fe879bce6cb4 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -22,7 +22,7 @@
     stable_topological_sort,
 )
 from ..utils import is_cpu_device
-from .group_batch_fusion import group_batch_fusion_pre_grad_passes
+from .group_batch_fusion import group_batch_fusion_passes
 from .misc_patterns import numpy_compat_normalization
 
 log = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
         lazy_init()
         gm = fuse_fx(gm, example_inputs)
         numpy_compat_normalization(gm.graph)
-        group_batch_fusion_pre_grad_passes(gm.graph)
+        group_batch_fusion_passes(gm.graph, pre_grad=True)
         for pattern_matcher_pass in pattern_matcher_passes:
             pattern_matcher_pass.apply(gm.graph)
 

From c77a4a409654dbc0ac4a528c37873b0acb1be32d Mon Sep 17 00:00:00 2001
From: Isuru Fernando 
Date: Wed, 22 Nov 2023 07:32:16 +0000
Subject: [PATCH 092/221] Fix compiling add with torch.int32 and scalars
 (#113965)

Fixes #113944

When `b` and `alpha` are both scalars, using `prims.mul` will create a tensor with dtype `int64` resulting in wrong dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113965
Approved by: https://github.com/ezyang
---
 test/inductor/test_torchinductor.py | 3 ++-
 torch/_refs/__init__.py             | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index fbe6ef45a5b4..42480e275a08 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -598,7 +598,8 @@ def test_add_const_int(self):
         def fn(a):
             return (a + 1, torch.add(a, 1, alpha=2))
 
-        self.common(fn, (torch.randn(32),))
+        for dtype in [torch.float32, torch.int32, torch.int64]:
+            self.common(fn, (torch.arange(32, dtype=dtype),))
 
     def test_add_const_float(self):
         def fn(a):
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 450c198067aa..7a78c08dd0c2 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -1074,7 +1074,10 @@ def add(
         ):
             msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
             raise ValueError(msg)
-        b = prims.mul(b, alpha)
+        if isinstance(b, TensorLike):
+            b = prims.mul(b, alpha)
+        else:
+            b = b * alpha
 
     output = prims.add(a, b)
     return handle_noncontiguous_outputs([a, b], output)

From 172a103857020a50c3b0dd353499e3e36e95267c Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Wed, 22 Nov 2023 08:48:47 +0000
Subject: [PATCH 093/221] [dynamo] `strict=True` kwarg for zip (#114047)

Fixes https://github.com/pytorch/pytorch/issues/113894

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114047
Approved by: https://github.com/ezyang
---
 test/dynamo/test_functions.py      | 28 ++++++++++++++++++++++++++++
 torch/_dynamo/variables/builtin.py | 16 +++++++++++-----
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index cb2a1ae3a4dd..f699789b5402 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -2600,6 +2600,34 @@ def fn(param, param2):
         self.assertEqual(opt_fn(param, param), fn(param, param))
         self.assertEqual(cnts.frame_count, 2)  # Recompiles
 
+    @unittest.skipIf(
+        sys.version_info < (3, 10),
+        "zip strict kwargs not implemented for Python < 3.10",
+    )
+    def test_zip_strict(self):
+        def fn(x, ys, zs):
+            x = x.clone()
+            for y, z in zip(ys, zs, strict=True):
+                x += y * z
+            return x
+
+        opt_fn = torch._dynamo.optimize(backend="eager")(fn)
+        nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+
+        x = torch.ones(3)
+        ys = [1.0, 2.0, 3.0]
+        zs = [2.0, 5.0, 8.0]
+
+        self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs))
+
+        # If nopython, should raise UserError
+        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
+            nopython_fn(x, ys[:1], zs)
+
+        # Should cause fallback if allow graph break
+        with self.assertRaisesRegex(ValueError, "zip()"):
+            opt_fn(x, ys[:1], zs)
+
     def test_compare_constant_and_tensor(self):
         for op in [
             operator.lt,
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 5ddc1f32a763..101e2a17c248 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -899,12 +899,18 @@ def call_custom_dict(tx, user_cls, *args, **kwargs):
             )
         unimplemented(f"dict(): {args} {kwargs}")
 
-    def call_zip(self, tx, *args):
+    def call_zip(self, tx, *args, **kwargs):
+        if kwargs:
+            assert len(kwargs) == 1 and "strict" in kwargs
         if all(x.has_unpack_var_sequence(tx) for x in args):
-            items = [
-                variables.TupleVariable(list(item))
-                for item in zip(*[arg.unpack_var_sequence(tx) for arg in args])
-            ]
+            unpacked = [arg.unpack_var_sequence(tx) for arg in args]
+            if kwargs.pop("strict", False) and len(unpacked) > 0:
+                if not all(len(u) == len(unpacked[0]) for u in unpacked):
+                    raise UserError(
+                        ValueError,
+                        "zip() has one argument of len differing from others",
+                    )
+            items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
             return variables.TupleVariable(items)
 
     def call_enumerate(self, tx, *args):

From 3e1abde46d7904dea60cc4fe317730a0c47b6e9e Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot 
Date: Wed, 22 Nov 2023 10:13:48 +0000
Subject: [PATCH 094/221] Revert "AOTAutograd: handle set_(), detect metadata
 mutations that cancel out (#111554)"

This reverts commit a911b4db9d82238a1d423e2b4c0a3d700217f0c1.

Reverted https://github.com/pytorch/pytorch/pull/111554 on behalf of https://github.com/DanilBaibak due to The lower PR in the stack #113926 breaks the internal build ([comment](https://github.com/pytorch/pytorch/pull/111554#issuecomment-1822472206))
---
 aten/src/ATen/FunctionalTensorWrapper.cpp     |  29 ---
 aten/src/ATen/FunctionalTensorWrapper.h       |  14 --
 aten/src/ATen/FunctionalizeFallbackKernel.cpp |  25 ---
 test/functorch/test_aotdispatch.py            | 126 ++-----------
 torch/_dynamo/variables/tensor.py             |   8 -
 torch/_functorch/aot_autograd.py              | 171 ++++++------------
 .../python_torch_functions_manual.cpp         |  51 ------
 torchgen/gen_functionalization_type.py        |   6 +-
 8 files changed, 72 insertions(+), 358 deletions(-)

diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index 5ab225467766..7a6c5c41632e 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -232,35 +232,6 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
   mutation_counter_++;
 }
 
-bool FunctionalTensorWrapper::has_data_mutation() {
-  // Current tensor's data was mutated if its storage saw any mutations.
-  return functional_storage_impl()->generation() > 0;
-}
-
-void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
-  // self.set_(src) will cause self to have all of the tensor properties of self.
-  value_ = other->value_;
-  generation_ = other->generation_;
-  view_metas_ = other->view_metas_;
-  // FREEZE the old storage, preventing mutations to it.
-  // this is a huge pain to handle properly in all cases, so we ban it.
-  functional_storage_impl()->freeze();
-  // Unsafely swap out the storage with other's storage,
-  // disconnecting `self` with its view chain
-  storage_ = other->storage_;
-  /// explicitly mark the tensor as having its storage changed from set_()
-  // Otherwise, we don't actually have a 100% accurate way to check this.
-  // (We could check if the updated value has a new storage than the original value,
-  // but this won't also let us uniquely determine if the tensor **also**
-  // experienced a data mutation).
-  was_storage_changed_ = true;
-
-  auto sizes_ = value_.sym_sizes();
-  auto strides_ = value_.sym_strides();
-  auto storage_offset_ = value_.sym_storage_offset();
-  set_sizes_and_strides(sizes_, strides_, storage_offset_);
-}
-
 void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
   // Note [resize_() in functionalization pass]
   // resize_() is a special operator in functionalization because it can reallocate its underlying storage.
diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h
index 7b22ceeb01a6..3d899038c1e7 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.h
+++ b/aten/src/ATen/FunctionalTensorWrapper.h
@@ -122,18 +122,6 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   // tensor by replaying the views off of the alias.
   void mutate_view_meta(at::functionalization::ViewMeta meta);
 
-  // Custom implementation of self.set_(src)
-  void set__impl(const FunctionalTensorWrapper* other);
-
-  // Returns whether the current tensor's data was ever mutated
-  bool has_data_mutation();
-  //
-  // Returns whether the current FunctionalTensorWrapper
-  // experienced a set_() call.
-  bool was_storage_changed() {
-    return was_storage_changed_;
-  }
-
   // The functionalization pass can be used to remove mutations.
   // It does so by replacing any mutation op with it's corresponding
   // out-of-place op, followed by a call to replace_(). e.g:
@@ -207,8 +195,6 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   uint64_t mutation_hidden_from_autograd_counter_ = 0;
   bool has_metadata_mutation_ = false;
   bool is_multi_output_view_ = false;
-  // Did the tensor experience a set_() call.
-  bool was_storage_changed_ = false;
 
   size_t generation_ = 0;
   std::vector view_metas_;
diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp
index 783a925d6983..3e9e234db45a 100644
--- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp
+++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -299,28 +299,6 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
   return out;
 }
 
-static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
-  // error case
-  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
-    "set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");
-
-  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
-    "set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");
-
-  // nop case
-  if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
-    at::AutoDispatchSkipFunctionalize guard;
-    return self.set_(src);
-  }
-
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
-  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
-  auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
-  self_impl->set__impl(src_impl);
-  return self;
-}
-
 TORCH_LIBRARY_IMPL(_, Functionalize, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
 }
@@ -332,7 +310,4 @@ TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
   m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
   m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
   m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
-  // The overloads of set_() that take in a storage should never
-  // appear with torch.compile, because dynamo graph breaks
-  m.impl("set_.source_Tensor", TORCH_FN(set__functionalize));
 }
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index ab310c247abe..5da8308205c4 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -455,97 +455,6 @@ def forward(self, primals_1):
     mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
     return [mul, mul_1]""")
 
-    def test_input_mutation_set__input_mutation(self):
-        def f(a):
-            b = torch.arange(9, dtype=a.dtype).reshape(3, 3)
-            with torch.no_grad():
-                a.set_(b)
-            return a * b
-        inp = [torch.ones(3, 3, requires_grad=True)]
-        self.verify_aot_autograd(f, inp, test_mutation=True)
-        inp = [torch.ones(3, 3, requires_grad=False)]
-        self.verify_aot_autograd(f, inp, test_mutation=True)
-
-    def test_set__steals_view_chain(self):
-        def f(a, b):
-            a_ = a.mul(2)
-            b_ = b.mul(2)
-            b_slice = b_[1].view(3, 3)
-            # a_clone should inherit the view chain from b_slice
-            a_.set_(b_slice)
-            # Also mutates b_,
-            a_.view(-1).mul_(2)
-            return a_ * b_slice
-        inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 9, requires_grad=False)]
-        self.verify_aot_autograd(f, inp)
-
-    def test_set__and_data_mutation_good(self):
-        def f(a, b):
-            # The data mutation happens *after* the set_(). This is ok (see the graph below)
-            with torch.no_grad():
-                a.set_(b)
-            b.mul_(2)
-            return a + b
-        inp = [torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True)]
-        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
-        inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 3, requires_grad=False)]
-        self.verify_aot_autograd(f, inp, test_mutation=True)
-        # Important things to note:
-        # - "return a.set_(b)" desugars into "return b"
-        # - Both a and b are recorded as experiencing mutations,
-        #   which is why we see "b_updated" (output of the mul) twice in the graph outputs.
-        #   a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage).
-        # - the runtime epilogue for a is "a.set_(mul)"
-        # - the runtime epilogue for b is "b.copy_(mul)"
-        self.assertExpectedInline(fw_graph.code.strip(), """\
-def forward(self, primals_1, primals_2):
-    clone = torch.ops.aten.clone.default(primals_2);  primals_2 = None
-    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
-    add = torch.ops.aten.add.Tensor(mul, mul)
-    return [mul, mul, add]""")
-
-    # This is a (hopefully) extremely rare case that is difficult to handle,
-    # so we ban it.
-    def test_set__and_data_mutation_bad(self):
-        def f(a):
-            a_view = a.view(-1)
-            tmp = torch.ones(3, 3, requires_grad=True)
-            # Now, any mutations on either tmp
-            # will be tracked as graph input mutations.
-            with torch.no_grad():
-                a.set_(tmp)
-            # BAD: a_view is now detached from every graph input,
-            # so we won't recognize that this caused an input mutation!
-            a_view.mul_(2)
-            return a + tmp
-        inp = [torch.ones(3, 3, requires_grad=True)]
-        with self.assertRaisesRegex(RuntimeError, "cannot mutate tensors with frozen storage"):
-            self.verify_aot_autograd(f, inp, test_mutation=True)
-
-    def test_input_mutation_set__nop(self):
-        def f(a):
-            b = torch.arange(9, dtype=a.dtype)
-            a_old = torch.ops.aten.alias.default(a)
-            with torch.no_grad():
-                a.set_(b)
-                a.set_(a_old)
-            return a + b.reshape(3, 3)
-        inp = [torch.ones(3, 3, requires_grad=True)]
-        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
-        inp = [torch.ones(3, 3, requires_grad=False)]
-        self.verify_aot_autograd(f, inp, test_mutation=True)
-        # Things to note:
-        # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b")
-        # - There is only **1** graph output. We properly realized that the two set_() calls
-        #   undo each other, and so effectively no inputs are mutated.
-        self.assertExpectedInline(fw_graph.code.strip(), """\
-def forward(self, primals_1):
-    arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
-    alias = torch.ops.aten.alias.default(primals_1);  primals_1 = None
-    view = torch.ops.aten.view.default(arange, [3, 3]);  arange = None
-    add = torch.ops.aten.add.Tensor(alias, view);  alias = view = None
-    return [add]""")
-
     def test_input_mutation_simple_with_none_and_nontensor(self):
         # Tensor, None, int
         def f(a, b, c):
@@ -1715,9 +1624,10 @@ def inp_callable(req_grad):
         # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base.
         self.assertExpectedInline(fw_graph.code.strip(), """\
 def forward(self, primals_1, primals_2):
-    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
-    add = torch.ops.aten.add.Tensor(t, primals_2);  t = primals_2 = None
-    return [add]""")
+    view = torch.ops.aten.view.default(primals_1, [4]);  primals_1 = None
+    t = torch.ops.aten.t.default(view);  view = None
+    add = torch.ops.aten.add.Tensor(t, primals_2);  primals_2 = None
+    return [t, add]""")
 
     def test_input_mutation_aliases_and_none_require_gradients(self):
         def f(a, b, c):
@@ -1756,7 +1666,7 @@ def test_input_mutation_aliases_bases_out_of_order(self):
         # So we don't need to do the base construction / deconstruction
         def f(a, b, c, d):
             b.add_(1)
-            d.unsqueeze_(0)
+            d.t_()
             return a + c + d, b.view(-1)
 
         def inp_callable(req_grad):
@@ -1785,11 +1695,11 @@ def forward(self, primals_1, primals_2, primals_3):
     as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0);  clone = add = None
     add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3);  primals_2 = primals_3 = None
     as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-    unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_3, 0);  as_strided_3 = None
-    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1);  add_1 = None
+    t_1 = torch.ops.aten.t.default(as_strided_3);  as_strided_3 = None
+    add_2 = torch.ops.aten.add.Tensor(add_1, t_1);  add_1 = None
     as_strided_11 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-    view_2 = torch.ops.aten.view.default(as_strided_11, [-1]);  as_strided_11 = None
-    return [as_strided_scatter, add_2, view_2, unsqueeze_1]""")  # noqa: B950
+    view_1 = torch.ops.aten.view.default(as_strided_11, [-1]);  as_strided_11 = None
+    return [as_strided_scatter, add_2, view_1, t_1]""")  # noqa: B950
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_synthetic_base_base_attribute_is_none(self):
@@ -2027,7 +1937,7 @@ def f(x, y):
     def test_dupe_arg_torture(self):
         def f(x, y):
             x.t_()
-            y.unsqueeze_(0)
+            y.t_()
             return x + y
 
         x = torch.randn(3, 3, requires_grad=True).clone()
@@ -2088,8 +1998,8 @@ def test_invalid_dupe_fake(self, counter):
     def _test_invalid_dupe(self, counter, fake):
         class F(torch.nn.Module):
             def forward(self, x, y):
-                x.unsqueeze_(0)
-                y.unsqueeze_(0)
+                x.t_()
+                y.t_()
                 return (x + y,)
 
         x = torch.randn(3, 3, requires_grad=True).clone()
@@ -2108,8 +2018,6 @@ def forward(self, x, y):
             fxy = aot_module_simplified(F(), (x, y), nop)
 
         fxy(x, y)
-        x = torch.randn(3, 3, requires_grad=True).clone()
-        y = torch.randn(3, 3, requires_grad=True).clone()
         fxy(x, x)  # is ok!
 
         if fake:
@@ -2117,13 +2025,9 @@ def forward(self, x, y):
         else:
             fxx = aot_module_simplified(F(), (x, x), nop)
 
-        x = torch.randn(3, 3, requires_grad=True).clone()
-        y = torch.randn(3, 3, requires_grad=True).clone()
         fxx(x, x)
         # Note This should not raise! Once we have guards in place here,
         # we will have this working correctly, as it should recompile.
-        x = torch.randn(3, 3, requires_grad=True).clone()
-        y = torch.randn(3, 3, requires_grad=True).clone()
         self.assertExpectedRaisesInline(
             AssertionError, lambda: fxx(x, y),
             """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch."""  # noqa: B950
@@ -2744,7 +2648,7 @@ def fn(p, x):
             x.t_()
             return (x * 2,)
         mod = TestMod(fn)
-        inp = torch.randn(2, 4)
+        inp = torch.randn(2)
         with self.assertRaisesRegex(
             RuntimeError, "Found an input that received a metadata mutation"
         ):
@@ -3453,7 +3357,7 @@ def f(a, b):
     def test_aot_dispatch_input_metadata_mutation(self):
         def f(a, b):
             a.t_()
-            b.unsqueeze_(0)
+            b.t_()
             return a + b
 
         b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
@@ -3498,7 +3402,7 @@ def f(a, b):
     def test_aot_dispatch_input_data_and_metadata_mutation(self):
         def f(a, b):
             a.t_()
-            b.unsqueeze_(0)
+            b.t_()
             a.mul_(2)
             b.mul_(3)
             return a + b
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 6fbe6c2afcc8..ca8d34988d56 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -605,14 +605,6 @@ def has_bool_key(v):
         elif name in ("resize_", "resize_as_"):
             # Handling resizing in its full generality is difficult.
             unimplemented(f"Tensor.{name}")
-        elif name == "set_" and len(args) > 1:
-            # torch.Tensor.set_() has several overloads.
-            # aten::set_.source_Tensor(Tensor) gets special handling
-            # in AOTAutograd and functionalization, because it is the most common
-            # overload and is used by FSDP.
-            # graph-breaking on aten::set_source_Tensor_storage_offset for now,
-            # unless we find that we need to make it work.
-            unimplemented("Tensor.set_.source_Tensor_storage_offset")
         elif (
             name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
         ):
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 4c29f1a85002..818388360d81 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -477,19 +477,9 @@ class InputAliasInfo:
     mutates_data: bool
     mutates_metadata: bool
     mutations_hidden_from_autograd: bool
-    # This can only happen from a call to aten.set_() on a graph input.
-    mutates_storage_metadata: bool
     requires_grad: bool
     mutation_type: MutationType
 
-    def __post_init__(self):
-        if self.mutates_storage_metadata:
-            # For convenience, we guarantee that this is always true.
-            # In practice, If we call .set_(), then at runtime there is no need
-            # to additionally fix  up the tensor metadata, since our runtime
-            # call to inp.set_(updated_inp) will already have the right metadata
-            assert self.mutates_metadata
-
 
 @dataclasses.dataclass
 class SubclassCreationMeta:
@@ -946,85 +936,65 @@ def is_fun(t):
 # t here is either
 # (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
 # (2) A traceable tensor subclass that holds a FunctionalTensor
-# (3) Not a tensor
-def has_data_mutation(t):
+def has_metadata_mutation(t):
     if is_traceable_wrapper_subclass(t):
         attrs, _ = t.__tensor_flatten__()
         # A tensor subclass was updated if any of its inner elements were updated
-        return any(has_data_mutation(getattr(t, attr)) for attr in attrs)
+        return any(has_metadata_mutation(getattr(t, attr)) for attr in attrs)
     else:
-        if isinstance(t, torch.Tensor):
-            assert isinstance(t, FunctionalTensor)
-            return torch._functionalize_has_data_mutation(t.elem)
-        return False
+        assert isinstance(t, FunctionalTensor)
+        return torch._functionalize_has_metadata_mutation(t.elem)
 
 def are_all_mutations_hidden_from_autograd(t):
     if is_traceable_wrapper_subclass(t):
         attrs, _ = t.__tensor_flatten__()
         # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd.
         return all(are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs)
-    elif isinstance(t, torch.Tensor):
+    else:
         assert isinstance(t, FunctionalTensor)
         return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem)
-    else:
-        return False
 
-# f_arg here is either
-# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
-# (2) A traceable tensor subclass that holds a FunctionalTensor
-# (3) Not a tensor
-# Assumption: arg promises to be the "original" tensor wrapped by f_arg
-# Note: "storage mutations" coming from set_() are a type of metadata mutation. So:
-# - check_only_storage_mutation=True: only return true if there was a storage mutation
-# - check_only_storage_mutation=Flse: return true if there was any metadata mutation (including a storage mutation)
-def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool):
-    if is_traceable_wrapper_subclass(f_arg):
-        attrs, _ = f_arg.__tensor_flatten__()
+# new_arg and arg here are either:
+# (1) both a FakeTensor
+# (2) both a traceable tensor subclass that holds a FakeTensor
+# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
+# When we run functionalization and wrap our inputs into FunctionalTensors,
+# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed
+#
+# Normally it would be enough just to check if arg is new_arg, which is normally enough for functionalization
+# to confirm that inputs were not mutated when running the user's model with functionalization on.
+# But when we have subclass inputs, we can't rely on that:
+# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs
+# a brand new subclass instance: we are calling __tensor_unflatten__, and going
+# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor))
+def was_updated(arg, new_arg):
+    if is_traceable_wrapper_subclass(arg):
+        assert is_traceable_wrapper_subclass(new_arg)
+        attrs, _ = arg.__tensor_flatten__()
+        new_attrs, _ = new_arg.__tensor_flatten__()
+        assert attrs == new_attrs
         # A tensor subclass was updated if any of its inner elements were updated
-        f_inner_ts = [getattr(f_arg, attr) for attr in attrs]
-        inner_ts = [getattr(arg, attr) for attr in attrs]
-        return any(has_metadata_mutation(f_inner_t, inner_t, check_only_storage_mutation=check_only_storage_mutation)
-                   for f_inner_t, inner_t in zip(f_inner_ts, inner_ts))
+        return any(was_updated(getattr(arg, attr), getattr(new_arg, attr)) for attr in attrs)
     else:
-        if not isinstance(f_arg, torch.Tensor):
-            assert not isinstance(arg, torch.Tensor)
-            return False
-        assert isinstance(f_arg, FunctionalTensor)
-        assert isinstance(arg, FakeTensor)
-
-        arg_after = torch._from_functional_tensor(f_arg.elem)
-        # This is true if the current tensor experienced at least one set_() call
-        maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem)
-        # However, multiple set_() calls can cancel out. So we also check whether the
-        # storage of the tensor has changed.
-        # Note: if an input experienced two set_() calls that cancel out, **and**
-        # it experiences an data mutation, we pessimistically think that the set_()
-        # call is necessary here. We could in theory fix this, but this will
-        # hopefully never happen in user code, and is not needed for fsdp.
-        same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(arg_after.untyped_storage())
-        has_storage_metadata_mutation = maybe_storage_changed and not same_storages
-        if check_only_storage_mutation:
-            return has_storage_metadata_mutation
-
-        # storage metadata mutation is a type of metadata mutation, so return true if we saw one
-        if has_storage_metadata_mutation:
-            return True
-
-        maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem)
-        # This is true if the current tensor experienced at least one metadata mutation.
-        # So if false, we know there was no metadata mutation
-        if not maybe_metadata_mutated:
-            return False
-
-        # However, multi metadata mutations can cancel out.
-        # So we also check if the concrete sizes/strides on the tensor have changed.
-        same_sizes = arg.shape == arg_after.shape
-        same_strides = arg.stride() == arg_after.stride()
-        same_offsets = arg.storage_offset() == arg_after.storage_offset()
-        has_metadata_mutation_ = maybe_metadata_mutated and not (same_sizes and same_strides and same_offsets)
-        # We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call.
-        return has_metadata_mutation_
-
+        return arg is not new_arg
+
+# new_arg and arg here are either:
+# (1) both a FakeTensor
+# (2) both a traceable tensor subclass that holds a FakeTensor
+# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
+# When we run functionalization and wrap our inputs into FunctionalTensors,
+# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed,
+# but shares storage with the old input
+def was_metadata_updated(arg, new_arg):
+    if is_traceable_wrapper_subclass(arg):
+        assert is_traceable_wrapper_subclass(new_arg)
+        attrs, _ = arg.__tensor_flatten__()
+        new_attrs, _ = new_arg.__tensor_flatten__()
+        assert attrs == new_attrs
+        # A tensor subclass was updated if any of its inner elements were updated
+        return any(was_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) for attr in attrs)
+    else:
+        return arg is not new_arg and StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage())
 
 def _get_hints(exprs):
     """
@@ -1152,27 +1122,18 @@ def inner(*flat_args):
                 new_arg = arg
             else:
                 new_arg = from_fun(f_arg)
-            mutates_metadata = has_metadata_mutation(f_arg, arg, check_only_storage_mutation=False)
-            mutates_storage_metadata = has_metadata_mutation(f_arg, arg, check_only_storage_mutation=True)
-            mutates_data = has_data_mutation(f_arg)
-            mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(f_arg)
-
-            # Here, we're saying that if an input experienced a set call, inp.set_(other),
-            # then we can effectively not have to worry about whether its data was mutated.
-            # There are 3 cases:
-            # (1) We mutate inp *after* the set_() call. other is a graph intermediate.
-            #     In this case, we're not really mutating the input storage of "inp";
-            #     we're mutating the storage of an intermdiate value (other),
-            #     and slamming that storage into the input tensor. So no data mutation is necessary.
-            # (2) We mutate inp *after* the set_() call. other is a graph *input*.
-            #     In this case, the data mutation will be properly handled in the runtime
-            #     epilogue during the processing of "other"
-            # (3) We mutate inp *before* the set_() call.
-            #     This case is *not* currently handled.
-            #     TODO: discuss this in the PR. Both supporting this, and detecting + erroring out,
-            #     seem painful to get working.
-            if mutates_storage_metadata:
+            if was_updated(arg, new_arg):
+                if was_metadata_updated(arg, new_arg):
+                    mutates_data = False
+                    mutates_metadata = True
+                else:
+                    mutates_data = True
+                    mutates_metadata = has_metadata_mutation(f_arg)
+                mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(f_arg)
+            else:
                 mutates_data = False
+                mutates_metadata = False
+                mutations_hidden_from_autograd = False
 
             requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
 
@@ -1181,7 +1142,6 @@ def inner(*flat_args):
                 mutates_data=mutates_data,
                 mutates_metadata=mutates_metadata,
                 mutations_hidden_from_autograd=mutations_hidden_from_autograd,
-                mutates_storage_metadata=mutates_storage_metadata,
                 requires_grad=requires_grad,
                 mutation_type=_get_mutation_type(
                     keep_input_mutations,
@@ -1710,8 +1670,6 @@ def maybe_to_fresh_input(idx, t, meta):
             # Make sure the primal we pass to autograd.grad()
             # sees the tensor before the mutation
             return t.clone()
-        # No need to do anything for  meta.input_info[idx].mutates_storage_metadata,
-        # Because autograd doesn't support set_()
         if meta.input_info[idx] and meta.input_info[idx].mutates_metadata:
             # Make sure the primal we pass to autograd.grad()
             # sees the tensor before the metadata mutation
@@ -2689,8 +2647,7 @@ def create_synthetic_base_metadata(
             # mutations, they will be hidden from the rest of aot autograd.
             mutates_data=mutates_data,
             mutates_metadata=mutates_metadata,
-            mutations_hidden_from_autograd=all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices),
-            mutates_storage_metadata=False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_storage_metadata,
+            mutations_hidden_from_autograd=mutations_hidden_from_autograd,
             is_leaf=any_leaf,
             requires_grad=requires_grad,
             mutation_type=mutation_type,
@@ -3240,22 +3197,6 @@ def runtime_wrapper(*args):
                     continue
                 original_inpt = args[inpt_idx]
                 updated_inpt = updated_inputs[i]
-                if meta.mutates_storage_metadata:
-                    # mutates_storage_metadata means our input saw a x.set_(y) call.
-                    # What if x **also** saw a data and/or a metadata mutation?
-                    # (1) If the [meta]data mutation occurred after the set_(),
-                    #     then there is no need to copy_() the data.
-                    #     When we perform x.set_(x_updated), we are guaranteed that
-                    #     x_updated already has the final version of the data/metadata
-                    # (2) If a data mutation occurred before the set_().
-                    #     This case seems very difficult to support.
-                    #     TODO: discuss on the PR and decide if we want to tr to
-                    #     either support it, or detect and ban it.
-                    if trace_joint:
-                        assert isinstance(updated_inpt, TensorAlias)
-                        updated_inpt = updated_inpt.alias
-                    original_inpt.set_(updated_inpt)
-                    continue
                 if meta.mutates_metadata and not meta.mutates_data:
                     if trace_joint:
                         assert isinstance(updated_inpt, TensorAlias)
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index 913d29b4b69c..d9b8e6787058 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -462,48 +462,6 @@ static PyObject* THPVariable__is_functional_tensor(
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPVariable__functionalize_was_storage_changed(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_was_storage_changed(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  if (wrapper->was_storage_changed()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_has_data_mutation(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_has_data_mutation(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  if (wrapper->has_data_mutation()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
 static PyObject* THPVariable__functionalize_has_metadata_mutation(
     PyObject* self,
     PyObject* args,
@@ -783,15 +741,6 @@ static PyMethodDef torch_functions_manual[] = {
          THPVariable__functionalize_is_multi_output_view),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
-    {"_functionalize_has_data_mutation",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_has_data_mutation),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_was_storage_changed",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_was_storage_changed),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
     {"_functionalize_enable_reapply_views",
      castPyCFunctionWithKeywords(
          THPVariable__functionalize_enable_reapply_views),
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index c39fc3e3e3bf..d918bfb562fb 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -716,11 +716,7 @@ def emit_registration_helper(f: NativeFunction) -> str:
         return view_str
 
     elif isinstance(g, NativeFunctionsGroup):
-        # Gets a hand-written functionalization kernel
-        if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
-            fns = []
-        else:
-            fns = list(g.functions())
+        fns = list(g.functions())
     else:
         if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
             return []

From a0e3321f0c4bae8102961a50eafaef1dd304d8cf Mon Sep 17 00:00:00 2001
From: Jiong Gong 
Date: Wed, 22 Nov 2023 16:00:38 +0800
Subject: [PATCH 095/221] [inductor cpp] vectorize embedding lookup (#114062)

For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.

This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed.

For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long x0=static_cast(0L); x0(128L); x0+=static_cast(1L))
            {
                #pragma GCC ivdep
                for(long x1=static_cast(0L); x1(128L); x1+=static_cast(1L))
                {
                    auto tmp0 = in_ptr0[static_cast(x0)];
                    auto tmp5 = in_ptr2[static_cast(x1 + (128L*x0))];
                    auto tmp1 = decltype(tmp0)(tmp0 + 64);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
                    auto tmp4 = in_ptr1[static_cast(x1 + (128L*tmp3))];
                    auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
                    out_ptr0[static_cast(x1 + (128L*x0))] = tmp6;
                }
            }
        }
    }
}
```

After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long x0=static_cast(0L); x0(128L); x0+=static_cast(1L))
            {
                for(long x1=static_cast(0L); x1(128L); x1+=static_cast(16L))
                {
                    auto tmp0 = in_ptr0[static_cast(x0)];
                    auto tmp5 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x1 + (128L*x0)));
                    auto tmp1 = decltype(tmp0)(tmp0 + 64);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
                    auto tmp4 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x1 + (128L*tmp3)));
                    auto tmp6 = tmp4 + tmp5;
                    tmp6.store(out_ptr0 + static_cast(x1 + (128L*x0)));
                }
            }
        }
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114062
Approved by: https://github.com/jansel
---
 test/inductor/test_cpu_repro.py      |  18 ++
 torch/_inductor/codegen/cpp.py       | 292 ++++++++++++++++++---------
 torch/_inductor/codegen/cpp_prefix.h |   6 +
 3 files changed, 225 insertions(+), 91 deletions(-)

diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index f5ce2369bfdb..06dff1d34e31 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1299,6 +1299,7 @@ def test_cpu_vec_cosim(self):
                 cpp_op_list.append(k)
 
         diff = [
+            "constant",
             "index_expr",
             "signbit",
             "isinf",
@@ -2612,6 +2613,23 @@ def forward(self, x):
         x = torch.randn(1, 39, 1, 18, 17)
         self.common(m, (x,))
 
+    def test_embedding_vec(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(64, 128)
+
+            def forward(self, idx, x):
+                return self.emb(idx) + x
+
+        idx = torch.randint(0, 64, (4, 32))
+        x = torch.randn(4, 32, 128)
+        m = M().eval()
+        with torch.no_grad():
+            metrics.reset()
+            self.common(m, (idx, x))
+            assert metrics.generated_cpp_vec_kernel_count == 1
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 90f55ed8b55a..b64a67bd86a2 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -7,7 +7,7 @@
 import re
 import sys
 from copy import copy, deepcopy
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import sympy
 
@@ -133,6 +133,19 @@
 ]
 
 
+def value_to_cpp(value, cpp_type):
+    if value == float("-inf"):
+        return f"-std::numeric_limits<{cpp_type}>::infinity()"
+    elif value == float("inf"):
+        return f"std::numeric_limits<{cpp_type}>::infinity()"
+    elif isinstance(value, bool):
+        return f"static_cast<{cpp_type}>({str(value).lower()})"
+    elif math.isnan(value):
+        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
+    else:
+        return f"static_cast<{cpp_type}>({repr(value)})"
+
+
 def reduction_init(reduction_type, dtype):
     if dtype in DTYPE_LOWP_FP:
         # Since load promotes all half-precision inputs to float, the initial
@@ -436,6 +449,54 @@ def get_current_node_opt_ctx() -> OptimizationContext:
     return get_opt_ctx(V.interpreter.current_node)
 
 
+class CppCSEVariable(CSEVariable):
+    def __init__(self, name, bounds: ValueRanges):
+        super().__init__(name, bounds)
+        self.is_vec = False
+        self.dtype: Optional[torch.dtype] = None
+        self.dependent_itervars: Set[sympy.Symbol] = set()
+
+    def update_on_args(self, name, args, kwargs):
+        if name == "load":
+            # args[1] is index
+            self._set_dependent_itervars(args[1])
+        else:
+            # propagate relevant itervars and is_vec from args
+            self.dependent_itervars.update(
+                *[
+                    arg.dependent_itervars
+                    for arg in args
+                    if isinstance(arg, CppCSEVariable)
+                ]
+            )
+            if name == "index_expr":
+                self._set_dependent_itervars(args[0])
+            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
+                self.is_vec = True
+        if (
+            hasattr(V.interpreter, "current_node")
+            and get_current_node_opt_ctx() is not None
+        ):
+            self.dtype = get_current_node_opt_ctx().dtype
+
+    def _set_dependent_itervars(self, index: sympy.Expr):
+        """
+        Set the relevant itervars for this variable based on the `index` expression.
+        This includes the itervars directly used in the `index` as well as relevant itervars
+        of other cse variables used in the `index`.
+        """
+        for s in index.free_symbols:
+            if s in V.kernel.itervars:
+                self.dependent_itervars.add(s)
+            elif s.name in V.kernel.cse.varname_map:
+                self.dependent_itervars.update(
+                    V.kernel.cse.varname_map[s.name].dependent_itervars
+                )
+
+    def depends_on(self, itervar: sympy.Symbol):
+        return itervar in self.dependent_itervars
+
+
 class CppOverrides(OpOverrides):
     """Map element-wise ops to C++"""
 
@@ -672,22 +733,20 @@ def mod(a, b):
 
     @staticmethod
     def constant(val, dtype):
+        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
+        assert opt_ctx and opt_ctx.dtype is not None
+        dtype = opt_ctx.dtype
         if dtype in DTYPE_LOWP_FP:
             # Since load promotes all half-precision inputs to float, constants
             # must be promoted as well
             dtype = torch.float32
-        if val == float("inf"):
-            return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
-        elif val == float("-inf"):
-            return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
-        elif math.isnan(val):
-            return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()"
-        elif val is True or val is False:
-            return ops.to_dtype(str(val).lower(), dtype)
-        return ops.to_dtype(repr(val), dtype)
+        return value_to_cpp(val, DTYPE_TO_CPP[dtype])
 
     @staticmethod
     def index_expr(expr, dtype):
+        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
+        assert opt_ctx and opt_ctx.dtype is not None
+        dtype = opt_ctx.dtype
         return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
 
     @staticmethod
@@ -704,19 +763,7 @@ def masked(mask, body, other):
         V.kernel.compute.splice(code)
 
         # Use the lambda's return type as the type of other
-        type = f"decltype({body_var}())"
-
-        if other == float("-inf"):
-            other_code = f"-std::numeric_limits<{type}>::infinity()"
-        elif other == float("inf"):
-            other_code = f"std::numeric_limits<{type}>::infinity()"
-        elif isinstance(other, bool):
-            other_code = f"static_cast<{type}>({str(other).lower()})"
-        elif math.isnan(other):
-            other_code = f"std::numeric_limits<{type}>::quiet_NaN()"
-        else:
-            other_code = f"static_cast<{type}>({repr(other)})"
-
+        other_code = value_to_cpp(other, f"decltype({body_var}())")
         return f"{mask} ? {body_var}() : {other_code}"
 
     @staticmethod
@@ -794,6 +841,54 @@ def sign(x):
 class CppVecOverrides(CppOverrides):
     """Map element-wise ops to aten vectorization C++"""
 
+    def __new__(cls, *args, **kargs):
+        self = super().__new__(cls)
+
+        def wrap(func):
+            # `CppVecKernel` generates both scalar ops and vector ops according to
+            # whether the inputs are scalars or vectors while all ops in `CppVecOverrides`
+            # (except for "masked") assume the inputs are vectors. We wrap the ops in
+            # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to
+            # `CppOverrides` when all inputs are scalars.
+            #
+            # Inputs to ops.masked are handled separately in its own function due to
+            # the need of recurive handling of masked body.
+            def wrapper(*args, **kwargs):
+                has_scalar = any(
+                    not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
+                )
+                has_vector = any(
+                    arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
+                )
+                new_args = list(args)
+                if has_scalar and has_vector:
+                    # broadcast scalar args to vector if needed
+                    new_args = []
+                    for arg in args:
+                        if isinstance(arg, CppCSEVariable) and not arg.is_vec:
+                            assert isinstance(V.kernel, CppVecKernel)
+                            new_arg = V.kernel.broadcast(arg)
+                            new_args.append(new_arg)
+                        else:
+                            new_args.append(arg)
+                if has_vector:
+                    return func(*new_args, **kwargs)
+                else:
+                    # fallback to scalar ops
+                    scalar_ops = super(CppVecOverrides, self)
+                    scalar_func = getattr(
+                        scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__)  # type: ignore[attr-defined]
+                    )
+                    assert scalar_func is not None
+                    return scalar_func(*args, **kwargs)
+
+            return wrapper
+
+        for name, method in vars(cls).items():
+            if getattr(method, "__class__", None) == staticmethod and name != "masked":
+                setattr(self, name, wrap(method.__func__))
+        return self
+
     @staticmethod
     def add(a, b):
         return f"{a} + {b}"
@@ -1006,28 +1101,6 @@ def acosh(x):
         vec_one = f"decltype({x})(1)"
         return f"({x} + ({x}*{x} - {vec_one}).sqrt()).log()"
 
-    @staticmethod
-    def constant(val, dtype):
-        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
-        assert opt_ctx
-        proposed_dtype = opt_ctx.dtype
-        assert proposed_dtype in [
-            torch.float,
-            torch.int32,
-        ]
-        if val == float("inf"):
-            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
-        elif val == float("-inf"):
-            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
-        elif math.isnan(val):
-            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()"
-        elif val is True or val is False:
-            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})"
-        else:
-            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})"
-
-        return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})"
-
     @staticmethod
     def relu(x):
         bug = config.cpp.inject_relu_bug_TESTING_ONLY
@@ -1159,32 +1232,24 @@ def masked(mask, body, other):
         code.writeline(";")
         V.kernel.compute.splice(code)
 
-        if other == float("-inf"):
-            other_code = (
-                "at::vec::Vectorized(-std::numeric_limits::infinity())"
-            )
-        elif other == float("inf"):
-            other_code = (
-                "at::vec::Vectorized(std::numeric_limits::infinity())"
-            )
-        elif math.isnan(other):
-            other_code = (
-                "at::vec::Vectorized(std::numeric_limits::quiet_NaN())"
+        other_code = value_to_cpp(other, "float")
+        other_code_vec = f"at::vec::Vectorized({other_code})"
+
+        if result.is_vec:
+            type = f"decltype({var}())"
+            float_mask = f"to_float_mask({new_mask})"
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute,
+                f"{type}::blendv({other_code_vec}, {var}(), {float_mask})",
             )
         else:
-            other_code = f"at::vec::Vectorized({other!r})"
-        type = f"decltype({var}())"
-        float_mask = f"to_float_mask({new_mask})"
-        return f"{type}::blendv({other_code}, {var}(), {float_mask})"
-
-    @staticmethod
-    def index_expr(expr, dtype):
-        assert dtype == torch.int64
-        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
-        assert opt_ctx
-        assert opt_ctx.dtype == torch.int32
-        assert opt_ctx.is_most_inner_loop_irrevelant
-        return f"at::vec::Vectorized(static_cast({cexpr(V.kernel.rename_indexing(expr))}))"
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute, f"{mask} ? {var}() : {other_code}"
+            )
+        # `result` is explicitly added to the args for correct propagation
+        # of relevant itervars and vectorization status.
+        csevar.update_on_args("masked", (mask, body, other, result), {})
+        return csevar
 
 
 class CppKernel(Kernel):
@@ -1242,7 +1307,9 @@ def load(self, name: str, index: sympy.Expr):
         line = f"{var}[{cexpr_index(index)}]"
         if V.graph.get_dtype(name) in [torch.float16]:
             line = f"static_cast({line})"
-        return self.cse.generate(self.loads, line)
+        csevar = self.cse.generate(self.loads, line)
+        csevar.update_on_args("load", (name, index), {})
+        return csevar
 
     def store(self, name, index, value, mode=None):
         assert "buf" in name
@@ -1472,6 +1539,9 @@ def write_to_suffix(self):
         self.reduction_suffix.splice(self.stores)
         (self.loads, self.compute, self.stores, self.cse) = prior
 
+    def create_cse_var(self, *args, **kwargs):
+        return CppCSEVariable(*args, **kwargs)
+
 
 class CppVecKernel(CppKernel):
     overrides = CppVecOverrides  # type: ignore[assignment]
@@ -1506,7 +1576,11 @@ def load(self, name: str, index: sympy.Expr):
         non_contiguous = (
             not is_broadcast
             and stride_at(tiling_var, index) != 1
-            or "tmp" in f"{index}"
+            or any(
+                self.cse.varname_map[s.name].depends_on(tiling_var)
+                for s in index.free_symbols
+                if s.name.startswith("tmp")
+            )
         )
         var_expr = (
             f"{var}[{cexpr_index(index)}]"
@@ -1515,13 +1589,9 @@ def load(self, name: str, index: sympy.Expr):
         )
         loadbuf = "tmpbuf" if non_contiguous else var_expr
         if is_broadcast:
-            # should always be broadcast as float for vectorization since we always use float to compute
-            if is_mask:
-                loadbuf = f"flag_to_float_scalar({loadbuf})"
-            if dtype in DTYPE_LOWP_FP:
-                line = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({loadbuf})"
-            else:
-                line = f"at::vec::Vectorized(static_cast({loadbuf}))"
+            csevar = super().load(name, index)
+            csevar.dtype = dtype
+            return csevar
         elif dtype in [torch.uint8] and opt_ctx.is_load_uint8_as_float:
             line = (
                 f"masked_load({loadbuf}, {load_mask})"
@@ -1563,7 +1633,11 @@ def load(self, name: str, index: sympy.Expr):
             tmpbufdefine += f"tmpbuf[{inner}] = {rhs};"
             line = f"([&]() {{ {tmpbufdeclare} {tmpbufdefine} return {line}; }})()"
 
-        return self.cse.generate(self.loads, line)
+        csevar = self.cse.generate(self.loads, line)
+        csevar.update_on_args("load", (name, index), {})
+        assert isinstance(csevar, CppCSEVariable)
+        csevar.is_vec = True
+        return csevar
 
     def get_vec_store_line(self, value, var, index, dtype):
         """
@@ -1572,6 +1646,11 @@ def get_vec_store_line(self, value, var, index, dtype):
         :param var: buffer to store into.
         :index: index into the `var`.
         """
+        # when value's type is str (e.g., welford reduction), caller should make sure
+        # it is a vector
+        assert isinstance(value, str) or (
+            isinstance(value, CppCSEVariable) and value.is_vec
+        ), value
         tiling_var = self.itervars[self.tiling_idx]
         assert index.has(tiling_var)
         var_expr = f"{var} + {cexpr_index(index)}"
@@ -1600,6 +1679,10 @@ def get_vec_store_line(self, value, var, index, dtype):
     def store(self, name, index, value, mode=None):
         assert "buf" in name
         assert mode is None
+        assert isinstance(value, CppCSEVariable), value
+        if not value.is_vec:
+            # this happens when we store a scalar into a vectorized buffer like "fill"
+            value = self.broadcast(value)
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
         var = self.args.output(name)
         index = self.rename_indexing(index)
@@ -1622,6 +1705,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
         }
         assert dtype == torch.float
         assert src_dtype == torch.float
+        assert isinstance(value, CppCSEVariable) and value.is_vec, value
 
         vec_ns = "at::vec"
         vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
@@ -1740,6 +1824,27 @@ def store_reduction(self, name, index, value):
             ]
             self.reduction_suffix.writelines(store_lines)
 
+    def broadcast(self, scalar_var: CppCSEVariable):
+        assert (
+            not scalar_var.is_vec
+            and self.itervars[self.tiling_idx] not in scalar_var.dependent_itervars
+        )
+        if scalar_var.dtype == torch.bool:
+            vec_var = self.cse.generate(
+                self.compute, f"to_float_mask({scalar_var.name})"
+            )
+        else:
+            assert scalar_var.dtype is not None
+            vec_var = self.cse.generate(
+                self.compute,
+                f"at::vec::Vectorized<{DTYPE_TO_CPP[scalar_var.dtype]}>({scalar_var.name})",
+            )
+        assert isinstance(vec_var, CppCSEVariable)
+        vec_var.dtype = scalar_var.dtype
+        vec_var.dependent_itervars = scalar_var.dependent_itervars
+        vec_var.is_vec = True
+        return vec_var
+
 
 class CppTile2DKernel(CppVecKernel):
     """
@@ -1849,7 +1954,11 @@ def load(self, name: str, index: sympy.Expr):
                 line = f"at::vec::Vectorized::loadu_one_fourth({loadbuf})"
             else:
                 line = f"at::vec::Vectorized::loadu({loadbuf})"
-            return self.cse.generate(self.loads, line)
+            csevar = self.cse.generate(self.loads, line)
+            csevar.update_on_args("load", (name, index), {})
+            assert isinstance(csevar, CppCSEVariable)
+            csevar.is_vec = True
+            return csevar
         else:
             new_index = self.scale_index_with_offset(
                 index,
@@ -1950,10 +2059,6 @@ def disable_vec(self, msg=None):
             schedule_log.debug("Disabled vectorization: %s", msg)
         self.simd_vec = False
 
-    def could_vec(self, name: str, index: sympy.Expr):
-        assert self.itervars is not None
-        return len(self.itervars) > 0
-
     def is_mask(self, name: str, users: Dict[torch.fx.Node, None]):
         load_type = V.graph.get_dtype(name)
         if load_type == torch.bool:
@@ -2036,6 +2141,10 @@ def load(self, name: str, index: sympy.Expr):
 
             var = self.cse.newvar()
 
+            if len(self.itervars) == 0:
+                self.disable_vec("not a loop")
+                return var
+
             if load_dtype in [torch.bool, torch.uint8] and not (
                 opt_ctx.is_load_as_mask or opt_ctx.is_load_uint8_as_float
             ):
@@ -2046,18 +2155,21 @@ def load(self, name: str, index: sympy.Expr):
                 return var
 
             if (
-                load_dtype not in self.load_supported_dtypes
-            ) and not self.is_load_integer_scalar_tensor(name, index):
+                (load_dtype not in self.load_supported_dtypes)
+                and not self.is_load_integer_scalar_tensor(name, index)
+                and index.has(self.itervars[self.tiling_idx])
+            ):
                 self.disable_vec(f"{load_dtype} not supported by load")
                 return var
 
-            index = self.rename_indexing(index)
-            if self.simd_vec and not self.could_vec(name, index):
-                self.disable_vec(f"not a loop: {index}")
             return var
 
     def store(self, name, index, value, mode=None):
         with RecordOptimizationContext(__name__) as node_ctx:
+            if len(self.itervars) == 0:
+                self.disable_vec("not a loop")
+                return self.simd_vec
+
             store_dtype = V.graph.get_dtype(name)
 
             opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
@@ -2085,8 +2197,6 @@ def store(self, name, index, value, mode=None):
 
             if index.is_number:
                 self.disable_vec(f"constant store index: {index}")
-            if self.simd_vec and not self.could_vec(name, index):
-                self.disable_vec(f"not a loop: {index}")
             return self.simd_vec
 
     def reduction(self, dtype, src_dtype, reduction_type, value):
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index 1a532029cdeb..23f72218a0cc 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -401,4 +401,10 @@ template <>
 inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) {
   return src;
 }
+
+inline at::vec::Vectorized to_float_mask(int src) {
+  float mask;
+  *(uint32_t*)&mask = src ? 0xFFFFFFFF : 0;
+  return at::vec::Vectorized(mask);
+}
 #endif

From 2b72543f3667096f3e41ddab528af14191c15ec3 Mon Sep 17 00:00:00 2001
From: ancestor-mithril 
Date: Wed, 22 Nov 2023 11:38:31 +0000
Subject: [PATCH 096/221] Solving pickle error when saving CyclicLR state_dict
 (#110931)

## How to reproduce:
```py
import os
import tempfile

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR

model = nn.Linear(100, 100)
opt = SGD(model.parameters(), lr=1.)
scheduler = CyclicLR(opt, base_lr=0.1, max_lr=0.2, scale_fn=lambda x: 0.99)

tmp = tempfile.NamedTemporaryFile(delete=False)
try:
    torch.save(scheduler.state_dict(), tmp.name)
    scheduler.load_state_dict(torch.load(tmp.name))
finally:
    tmp.close()
    os.unlink(tmp.name)
```
Error:
```
_pickle.PicklingError: Can't pickle  at 0x000001A51DF67600>: attribute lookup  on __main__ failed
```
## Fix:
Saving `scale_fn` to the state dict only if it is a callable object and not if it is a function or lambda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110931
Approved by: https://github.com/janeyx99
---
 test/optim/test_lrscheduler.py | 29 +++++++++++++++++++++++++++++
 torch/optim/lr_scheduler.py    | 16 +++++++++++++---
 2 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py
index 962026312f10..20c5028649e5 100644
--- a/test/optim/test_lrscheduler.py
+++ b/test/optim/test_lrscheduler.py
@@ -1530,10 +1530,39 @@ def test():
 
     def test_cycle_lr_state_dict_picklable(self):
         adam_opt = Adam(self.net.parameters())
+
+        # Case 1: Built-in mode
         scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
         self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType)
         state = scheduler.state_dict()
         self.assertNotIn("_scale_fn_ref", state)
+        self.assertIs(state["_scale_fn_custom"], None)
+        pickle.dumps(state)
+
+        # Case 2: Custom `scale_fn`, a function object
+        def scale_fn(_):
+            return 0.5
+
+        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
+        state = scheduler.state_dict()
+        self.assertNotIn("_scale_fn_ref", state)
+        self.assertIs(state["_scale_fn_custom"], None)
+        pickle.dumps(state)
+
+        # Case 3: Custom `scale_fn`, a callable class
+        class ScaleFn:
+            def __init__(self):
+                self.x = 0.5
+
+            def __call__(self, _):
+                return self.x
+
+        scale_fn = ScaleFn()
+
+        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
+        state = scheduler.state_dict()
+        self.assertNotIn("_scale_fn_ref", state)
+        self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__)
         pickle.dumps(state)
 
     def test_cycle_lr_scale_fn_restored_from_state_dict(self):
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 94ee154198cc..df659b61a998 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1365,16 +1365,26 @@ def get_lr(self):
 
     def state_dict(self):
         state = super().state_dict()
-        # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled
-        state.pop("_scale_fn_ref")
+        # We are dropping the `_scale_fn_ref` attribute because it is a
+        # `weakref.WeakMethod` and can't be pickled.
+        state.pop('_scale_fn_ref')
+        fn = state.pop('_scale_fn_custom')
+        state['_scale_fn_custom'] = None
+        if fn is not None and not isinstance(fn, types.FunctionType):
+            # The _scale_fn_custom will only be saved if it is a callable object
+            # and not if it is a function or lambda.
+            state['_scale_fn_custom'] = fn.__dict__.copy()
+
         return state
 
     def load_state_dict(self, state_dict):
+        fn = state_dict.pop('_scale_fn_custom')
         super().load_state_dict(state_dict)
+        if fn is not None:
+            self._scale_fn_custom.__dict__.update(fn)
         self._init_scale_fn()
 
 
-
 class CosineAnnealingWarmRestarts(LRScheduler):
     r"""Set the learning rate of each parameter group using a cosine annealing
     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`

From b4faa6bfa467d48959eeb56d0d739d05873e6fd0 Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Wed, 22 Nov 2023 12:26:37 +0000
Subject: [PATCH 097/221] [dynamo] report guard failure user stack, fix
 incorrectly skipping interesting files (#114053)

Fixes https://github.com/pytorch/pytorch/issues/114015

Before:
```
test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94696321555200)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94696321555200)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0
[2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:365 in init_ambient_guards
[2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140084534469552))  # _dynamo/output_graph.py:371 in init_ambient_guards
[2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1])
[2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539
[2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG]     - L['zs'][2] == 8.0

```

After:
```
test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # x = x.clone()  # test/dynamo/test_functions.py:2540 in fn
[2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94568804551424)                     # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94568804551424)                     # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:365 in init_ambient_guards
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140370726823264))  # _dynamo/output_graph.py:371 in init_ambient_guards
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1])  # x = x.clone()  # test/dynamo/test_functions.py:2540 in fn
[2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539
[2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG]     - L['zs'][2] == 8.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114053
Approved by: https://github.com/ezyang
---
 test/dynamo/test_aot_autograd.py | 32 +++++++++++++++----------------
 test/dynamo/test_logging.py      | 32 +++++++++++++++++++++++++++++++
 test/dynamo/test_misc.py         | 33 ++++++++++++++++++--------------
 torch/_dynamo/guards.py          |  9 ++++++---
 4 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
index 07af08acb36f..169c4b60452b 100644
--- a/test/dynamo/test_aot_autograd.py
+++ b/test/dynamo/test_aot_autograd.py
@@ -302,9 +302,9 @@ def guard_fail_fn(failure):
         fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
         compare_equal_outs_and_grads(self, F(), fxy, (x, y))
         compare_equal_outs_and_grads(self, F(), fxy, (x, z))
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
+            failure_reason,
         )
 
         # Reset failure reason
@@ -421,7 +421,7 @@ def guard_fail_fn(failure):
         fxx(x3, x3)
         fxx(x4, y4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['x'] is L['y']""")
+        self.assertIn("""L['x'] is L['y']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
@@ -456,9 +456,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1, 2, 2)
         f(a2, b2, b2, b2, 2, 2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -474,7 +474,7 @@ def guard_fail_fn(failure):
         f(a3, b3, c3, c3, 3, 3)
         f(a4, b4, c4, d4, 3, 3)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
+        self.assertIn("""L['c'] is L['d']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -512,9 +512,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1, 2, 2)
         f(a2, b2, b2, b2, 2, 2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
     @patch("torch._functorch.config.debug_assert", True)
@@ -550,9 +550,9 @@ def guard_fail_fn(failure):
         f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1)
         f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -602,9 +602,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1)
         f(a2, b2, b2, b2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -620,7 +620,7 @@ def guard_fail_fn(failure):
         f(a3, b3, c3, c3)
         f(a4, b4, c4, d4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
+        self.assertIn("""L['c'] is L['d']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_args(self):
@@ -651,9 +651,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1)
         f(a2, b2, b2, b2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -669,7 +669,7 @@ def guard_fail_fn(failure):
         f(a3, b3, c3, c3)
         f(a4, b4, c4, d4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
+        self.assertIn("""L['c'] is L['d']""", failure_reason)
 
     @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
     @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py
index 9b77b2f2e3aa..7e233263d3e2 100644
--- a/test/dynamo/test_logging.py
+++ b/test/dynamo/test_logging.py
@@ -596,6 +596,38 @@ def fn(x):
                    ~~^~~""",
         )
 
+    @make_logging_test(guards=True, recompiles=True)
+    def test_guards_recompiles(self, records):
+        def fn(x, ys, zs):
+            return inner(x, ys, zs)
+
+        def inner(x, ys, zs):
+            for y, z in zip(ys, zs):
+                x += y * z
+            return x
+
+        ys = [1.0, 2.0]
+        zs = [3.0]
+        x = torch.tensor([1.0])
+
+        fn_opt = torch._dynamo.optimize("eager")(fn)
+        fn_opt(x, ys, zs)
+        fn_opt(x, ys[:1], zs)
+
+        record_str = "\n".join(r.getMessage() for r in records)
+
+        self.assertIn(
+            """\
+L['zs'][0] == 3.0                                             # for y, z in zip(ys, zs):""",
+            record_str,
+        )
+        self.assertIn(
+            """\
+    triggered by the following guard failure(s):\n\
+    - len(L['ys']) == 2                                             # for y, z in zip(ys, zs):""",
+            record_str,
+        )
+
     @make_logging_test(**torch._logging.DEFAULT_LOGGING)
     def test_default_logging(self, records):
         def fn(a):
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 99d7e1ff0ffb..61928d4abd84 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -650,9 +650,9 @@ def guard_failures(failure):
         )(compare_shapes)
         opt_fn(torch.randn([3, 4]))
         opt_fn(torch.randn([4, 3]))
-        self.assertExpectedInline(
-            guard_failure.reason,
+        self.assertIn(
             """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
+            guard_failure.reason,
         )
 
     def test_builtin_abs(self):
@@ -716,9 +716,9 @@ def fn(x, y):
             ),
             sorted(guard_code),
         )
-        self.assertExpectedInline(
-            "\n".join(guard_code),
-            """\
+        guard_code_str = "\n".join(guard_code)
+
+        for line in """\
 2 <= L['x'].size()[0]
 L['x'] is L['y']
 L['x'].ndimension() == 2
@@ -734,8 +734,13 @@ def fn(x, y):
 not ___dict_contains('cccccccc', G['sys'].modules)
 str(L['x'].device) == 'cpu'
 str(L['x'].dtype) == 'torch.float32'
-utils_device.CURRENT_DEVICE == None""",
-        )
+utils_device.CURRENT_DEVICE == None""".split(
+            "\n"
+        ):
+            self.assertIn(
+                line,
+                guard_code_str,
+            )
 
     def test_fold(self):
         def fn(a):
@@ -5240,12 +5245,12 @@ def guard_failures(failure):
         self.assertTrue(guard_failure is not None)
         first_guard_failure = guard_failure[0].partition("\n")[0]
         if torch._dynamo.config.assume_static_by_default:
-            self.assertExpectedInline(
-                first_guard_failure,
+            self.assertIn(
                 """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
+                first_guard_failure,
             )
         else:
-            self.assertExpectedInline(first_guard_failure, """L['x'].size()[0] < 3""")
+            self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure)
 
     def test_guard_failure_fn2(self):
         def fn(x, y):
@@ -5273,9 +5278,9 @@ def guard_failures(failure):
         opt_fn(x2, y2)
 
         if torch._dynamo.config.assume_static_by_default:
-            self.assertExpectedInline(
-                guard_failure[0],
+            self.assertIn(
                 """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
+                guard_failure[0],
             )
         else:
             self.assertTrue(guard_failure is None)
@@ -5308,9 +5313,9 @@ def guard_failures(failure):
 
         # guard is expected for both static and dynamic shapes
         self.assertTrue(guard_failure is not None)
-        self.assertExpectedInline(
-            guard_failure[0],
+        self.assertIn(
             """len(L['x']) == 10""",
+            guard_failure[0],
         )
 
     def test_restore_graphstate(self):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 9c182ac40a82..1b068402019b 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -1031,15 +1031,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn):
 
         # Don't report this guard, it's always the same, useless!
         code_parts = ["___guarded_code.valid", "___check_global_state()"]
+        verbose_code_parts = code_parts[:]
 
         def add_code_part(code, guard, log_only=False):
             extra = ""
             if guard.user_stack:
                 for fs in reversed(guard.user_stack):
                     if fs.filename not in uninteresting_files():
+                        extra = f"  # {format_frame(fs, line=True)}"
                         break
-                else:
-                    extra = f"  # {format_frame(fs, line=True)}"
             elif guard.stack:
                 extra = f"  # {format_frame(guard.stack.summary()[-1])}"
 
@@ -1064,6 +1064,7 @@ def add_code_part(code, guard, log_only=False):
 
             if not log_only:
                 code_parts.append(code)
+                verbose_code_parts.append(f"{code:<60}{extra}")
 
         seen = set()
         for gcl in builder.code:
@@ -1113,6 +1114,7 @@ def convert(size_or_stride):
             )
             # Do this manually, to un-stagger the guards in log message
             code_parts.append(f"___check_tensors({tensor_check_args})")
+            verbose_code_parts.append(f"___check_tensors({tensor_check_args})")
             tensor_check_guards = builder.tensor_check_guards
 
             for i, name in enumerate(tensor_check_names):
@@ -1183,6 +1185,7 @@ def convert(size_or_stride):
         # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
         guard_fn.args = largs
         guard_fn.code_parts = code_parts
+        guard_fn.verbose_code_parts = verbose_code_parts
         # Grab only G, but preserve "G" because guards access it as "G"
         guard_fn.global_scope = {
             "G": builder.scope["G"],
@@ -1282,7 +1285,7 @@ def get_guard_fail_reason(
     scope.update(guard_fn.closure_vars)
     scope["___check_tensors"] = scope["___check_tensors_verbose"]
     reasons: List[str] = []
-    for part in guard_fn.code_parts:
+    for part in guard_fn.verbose_code_parts:
         global_scope = dict(guard_fn.global_scope)
         global_scope["__compile_source__"] = part
         with report_compile_source_on_error():

From e239a2b2d7d1376e81fc22733d48f45f6855c14a Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot 
Date: Wed, 22 Nov 2023 12:46:15 +0000
Subject: [PATCH 098/221] Revert "[dynamo / DDP] - lazily compile submodules -
 to propagate real tensor strides to backend compiler (#114154)"

This reverts commit 266054c3cac0f800f37348aea1409c4759dd2315.

Reverted https://github.com/pytorch/pytorch/pull/114154 on behalf of https://github.com/DanilBaibak due to The lower PR in the stack https://github.com/pytorch/pytorch/pull/113926 breaks the internal build ([comment](https://github.com/pytorch/pytorch/pull/114154#issuecomment-1822704476))
---
 test/distributed/test_dynamo_distributed.py |  41 ------
 torch/_dynamo/backends/distributed.py       | 141 +++++++++++++-------
 2 files changed, 92 insertions(+), 90 deletions(-)

diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 1547e595c924..82d4248fb6cb 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -544,7 +544,6 @@ def test_ddp_baseline_inductor(self):
 
     @patch.object(config, "optimize_ddp", True)
     def test_graph_split(self):
-        assert config.optimize_ddp
         """
         Just ensures that the appropriate number of splits happen (based on
         bucket size and model parameters) - verifies the number of times
@@ -626,7 +625,6 @@ def opt_fn(inputs):
     @patch.object(config, "optimize_ddp", True)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_graph_split_inductor(self):
-        assert config.optimize_ddp
         """
         Same as above, but using inductor backend.
         We observed issues with inductor/fx interface in the past.
@@ -641,45 +639,6 @@ def opt_fn(inputs):
         opt_outputs = opt_fn(inputs)
         self.assertTrue(same(correct_outputs, opt_outputs))
 
-    @torch._inductor.config.patch({"layout_optimization": True, "keep_output_stride": False})
-    @patch.object(config, "optimize_ddp", True)
-    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    def test_graph_split_inductor_layout_optimizations(self):
-        assert config.optimize_ddp
-        channel_dim = 512
-        # channel dim must be > 64 for inductor to do layout optimization and use NHWC
-
-        class ToyModelConv(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.net = nn.Sequential(
-                    *[nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
-                    + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
-                    + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
-                    + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()]
-                )
-
-            def forward(self, inputs):
-                return self.net(inputs)
-
-        def get_model():
-            m = ToyModelConv().to(self.device)
-            m.apply(init_weights)
-            inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
-            outputs = m(inputs)
-            return m, inputs, outputs
-
-        m, inputs, correct_outputs = get_model()
-        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
-
-        @torch._dynamo.optimize("inductor")
-        def opt_fn(inputs):
-            return ddp_m(inputs)
-
-        opt_outputs = opt_fn(inputs)
-        self.assertTrue(same(correct_outputs, opt_outputs))
-
-
     @patch.object(config, "optimize_ddp", True)
     def test_no_split(self):
         """
diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py
index 774f378b15b2..adc68bb30bff 100644
--- a/torch/_dynamo/backends/distributed.py
+++ b/torch/_dynamo/backends/distributed.py
@@ -6,8 +6,7 @@
 import torch
 from torch import fx
 from torch._dynamo.output_graph import GraphCompileReason
-from torch._dynamo.utils import detect_fake_mode
-from torch._subclasses.fake_tensor import is_fake
+from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
 from torch.fx.node import Node
 
 log = logging.getLogger(__name__)
@@ -215,6 +214,23 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
         and returns its callable.
         """
 
+        # Today, optimize_ddp=True and keep_output_stride=False can lead to silent
+        # correctness issues. The problem is that ddp_optimizer works by partitioning
+        # the dynamo graph, sending each subgraph through aot autograd to inductor,
+        # and creates example inputs by eagerly interpreting each subgraph to get
+        # an output that with the same metadata that we'd get from eager mode.
+        # This is a problem though, for torch._inductor.config.keep_output_stride.
+        # The above config can cause the outputs of the first graph to have
+        # **different** strides from eager, causing the inputs that we pass
+        # to the second graph to be wrong.
+        # To really fix this, we would need to faithfully ask inductor
+        # what the outputs to each graph it expects are.
+        assert torch._inductor.config.keep_output_stride, """\
+Detected that you are running DDP with torch.compile, along with these two flags:
+- torch._dynamo.config.optimize_ddp = True
+- torch._inductor.config.keep_output_stride = False
+This combination of flags is incompatible. Please set keep_output_stride to False,
+or file a github issue."""
         fake_mode = detect_fake_mode(example_inputs)
         if fake_mode is None:
             fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
@@ -313,54 +329,32 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
         debug_str += "\n---------------\n"
         ddp_graph_log.debug(debug_str)
 
-        # 3: Replace submodules with lazily compiling submodule
-        class SubmoduleReplacer(torch.fx.interpreter.Interpreter):
+        # 3: compile each of the partitioned submodules using the user-provided compiler
+        class SubmodCompiler(torch.fx.interpreter.Interpreter):
             def __init__(self, module, compiler):
                 super().__init__(module)
                 self.compiler = compiler
 
-            def lazily_compiled_submod(self, input_mod):
+            def compile_submod(self, input_mod, args, kwargs):
                 """
-                Create a wrapper around submodules which:
-                - lazily compiles each of the partitioned submodules using the user-provided compiler
-                - unpacks singleton tuples/lists into flat arg
+                Compile the submodule,
+                using a wrapper to make sure its output is always a tuple,
+                which is required by AotAutograd based compilers
                 """
+                assert len(kwargs) == 0, "We assume only args for these modules"
 
-                class LazilyCompiledModule(torch.nn.Module):
-                    def __init__(self, submod, compiler, unwrap_singleton_tuple):
+                class WrapperModule(torch.nn.Module):
+                    def __init__(self, submod, unwrap_singleton_tuple):
                         super().__init__()
                         self.submod = submod
-                        self.compiler = compiler
-                        self.compiled = False
                         self.unwrap_singleton_tuple = unwrap_singleton_tuple
 
                     def forward(self, *args):
-                        if not self.compiled:
-                            assert (
-                                fake_mode
-                            ), "fake mode must have been available when creating lazy submod"
-                            fake_args = []
-                            for arg in args:
-                                if isinstance(arg, torch.Tensor) and not is_fake(arg):
-                                    fake_args.append(
-                                        torch._dynamo.utils.to_fake_tensor(
-                                            arg, fake_mode
-                                        )
-                                    )
-                                else:
-                                    fake_args.append(arg)
-                            # First trace with fake args
-                            new_submod = self.compiler(self.submod, tuple(fake_args))
-                            del self.submod
-                            self.submod = new_submod
-                            self.compiled = True
-                            self.compiler = None
-
                         x = self.submod(*args)
-                        # we must let 'input_mod' return a tuple, to make AOT happy.
-                        # (aot_autograd compile_fn literally requires that the output of a graph it compiles is a tuple).
-                        # however, we don't acutally want this tuple to be returned, since the fx logic that calls the submod
-                        # will again wrap outputs from the submod in a tuple.  So we unwrap it, and count on it being re-wrapped
+                        # TODO(whc)
+                        # for some reason the isinstance check is necessary if I split one node per submod
+                        # - even though I supposedly wrapped the output in a tuple in those cases, the real
+                        # compiled module was still returning a tensor
                         if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                             return x[0]
                         return x
@@ -381,35 +375,84 @@ def forward(self, *args):
                         traceback.FrameSummary(__file__, 0, DDPOptimizer),
                     ],
                 )
-                wrapper = LazilyCompiledModule(
-                    input_mod,
-                    self.compiler,
+                wrapper = WrapperModule(
+                    self.compiler(input_mod, args),
                     unwrap_singleton_tuple,
                 )
                 return wrapper
 
-            # We replace the submodules with lazy submodules which compile
-            # the corresponding submodules when they are run with real values
-            # Always returns `None` - we do not need to propagate values in order
-            # to replace submodules.
+            # Note:
+            #
+            # The way distributed works today around fake tensors can be somewhat confusing.
+            # Some of these codepaths are shared in both runtime, and compile time. The presence
+            # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
+            #
+            # A few things to keep in mind:
+            #
+            # 1) We invoke `compile_submod` with a real module. The output of that gets stored
+            # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
+            #
+            # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
+            # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
+            #
+            # 3) Fake tensors should always be around during compile time.
+            #
+            # 4) Fake tensors should never be around at runtime.
+            #
+            # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
+            # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
             def run_node(self, n: Node) -> Any:
+                args, kwargs = self.fetch_args_kwargs_from_env(n)
+                new_args = []
+                assert fake_mode
+                for arg in args:
+                    if isinstance(arg, torch.Tensor) and not isinstance(
+                        arg, torch._subclasses.FakeTensor
+                    ):
+                        new_args.append(
+                            torch._dynamo.utils.to_fake_tensor(arg, fake_mode)
+                        )
+                    else:
+                        new_args.append(arg)
+
+                log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
+                assert isinstance(args, tuple)
+                assert isinstance(kwargs, dict)
+
                 if n.op == "call_module":
                     real_mod = self.fetch_attr(n.target)
+                    if fake_mode:
+                        curr_submod = deepcopy_to_fake_tensor(real_mod, fake_mode)
+                    else:
+                        curr_submod = real_mod
 
                     ddp_graph_log.debug(
-                        "\n---%s graph---\n%s", n.target, real_mod.graph
+                        "\n---%s graph---\n%s", n.target, curr_submod.graph
                     )
 
-                    assert len(n.kwargs) == 0, "We assume only args for these modules"
-                    lazily_compiled_submod = self.lazily_compiled_submod(real_mod)
+                    # When calling the compiler on the submod, inputs (new_args) are expected to
+                    # be FakeTensors already since Dynamo would have made them FakeTensors in the
+                    # non-DDP flow.  However, the parameters are _not_ expected to be FakeTensors,
+                    # since this wrapping happens during compilation
+                    compiled_submod_real = self.compile_submod(
+                        real_mod, new_args, kwargs
+                    )
 
                     # We update the original (outer) graph with a call into the compiled module
                     # instead of the uncompiled one.
                     self.module.delete_submodule(n.target)
                     n.target = "compiled_" + n.target
-                    self.module.add_submodule(n.target, lazily_compiled_submod)
+                    self.module.add_submodule(n.target, compiled_submod_real)
+
+                    # Finally, we have to produce inputs for use compiling the next submodule,
+                    # and these need to be FakeTensors, so we execute the module under fake_mode
+                    with fake_mode:
+                        return curr_submod(*new_args, **kwargs)
+                else:
+                    # placeholder or output nodes don't need to get compiled, just executed
+                    return getattr(self, n.op)(n.target, new_args, kwargs)
 
-        submod_compiler = SubmoduleReplacer(split_gm, self.backend_compile_fn)
+        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
         submod_compiler.run(*example_inputs)
         split_gm.recompile()
 

From 2f3beb715c608a060934c237de402faa40ea211f Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot 
Date: Wed, 22 Nov 2023 12:52:33 +0000
Subject: [PATCH 099/221] Revert "Add Stateful/Stateless symbolic contexts, use
 fresh fake mode for dynamo backends (#113926)"

This reverts commit 2ca1119d532af0ba385c7b5944b954c9385b4901.

Reverted https://github.com/pytorch/pytorch/pull/113926 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/113926#issuecomment-1822713852))
---
 docs/source/conf.py                      |   4 +-
 test/dynamo/test_export.py               |   4 +-
 test/dynamo/test_subclasses.py           |  10 +-
 test/test_dynamic_shapes.py              |   4 +-
 test/test_fake_tensor.py                 |   4 +-
 torch/_dynamo/backends/distributed.py    |   4 +-
 torch/_dynamo/eval_frame.py              |   4 +-
 torch/_dynamo/output_graph.py            |  11 --
 torch/_dynamo/utils.py                   |  14 ---
 torch/_dynamo/variables/builder.py       |  48 +++-----
 torch/_functorch/aot_autograd.py         |   9 +-
 torch/_guards.py                         |   3 -
 torch/_subclasses/fake_tensor.py         |  16 ++-
 torch/_subclasses/meta_utils.py          |  26 ++--
 torch/fx/experimental/symbolic_shapes.py | 144 ++++++-----------------
 15 files changed, 88 insertions(+), 217 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index dcd3c7694674..2ec2d66bbcb0 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -2866,9 +2866,7 @@
     "ShapeGuardPrinter",
     "StrictMinMaxConstraint",
     "SymDispatchMode",
-    "SymbolicContext",
-    "StatelessSymbolicContext",
-    "StatefulSymbolicContext",
+    "CreateSymbolicPolicy",
     # torch.fx.experimental.unification.match
     "Dispatcher",
     "VarDispatcher",
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 33cab08435ee..39b3040c3a83 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -28,8 +28,8 @@
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,
+    FreshCreateSymbolicPolicy,
     ShapeEnv,
-    StatelessSymbolicContext,
 )
 from torch.testing._internal import common_utils
 
@@ -3282,7 +3282,7 @@ def test_symbool_guards(
             ) as fake_mode:
                 fake_x = fake_mode.from_tensor(
                     x,
-                    symbolic_context=StatelessSymbolicContext(
+                    policy=FreshCreateSymbolicPolicy(
                         dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())],
                     ),
                 )
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 53fb7328e529..cadc164dd283 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -15,8 +15,8 @@
 
 from torch.fx.experimental.symbolic_shapes import (
     DimDynamic,
+    FreshCreateSymbolicPolicy,
     ShapeEnv,
-    StatelessSymbolicContext,
 )
 from torch.nested._internal.nested_tensor import (
     jagged_from_list,
@@ -337,13 +337,13 @@ def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
             ) as fake_mode:
                 x_fake = fake_mode.from_tensor(
                     x,
-                    symbolic_context=StatelessSymbolicContext(
+                    policy=FreshCreateSymbolicPolicy(
                         dynamic_sizes=[dim_dynamic for i in range(x.dim())]
                     ),
                 )
                 x1_fake = fake_mode.from_tensor(
                     x1,
-                    symbolic_context=StatelessSymbolicContext(
+                    policy=FreshCreateSymbolicPolicy(
                         dynamic_sizes=[dim_dynamic for i in range(x.dim())]
                     ),
                 )
@@ -373,7 +373,7 @@ def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
                 for inp in inps:
                     fake_inp = fake_mode.from_tensor(
                         inp,
-                        symbolic_context=StatelessSymbolicContext(
+                        policy=FreshCreateSymbolicPolicy(
                             [dim_dynamic for i in range(x.dim())]
                         ),
                     )
@@ -708,7 +708,7 @@ def test_recompilation(
             ) as fake_mode:
                 fake_inp = fake_mode.from_tensor(
                     x,
-                    symbolic_context=StatelessSymbolicContext(
+                    policy=FreshCreateSymbolicPolicy(
                         dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())]
                     ),
                 )
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index bf843587af50..daf293b43d00 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -27,7 +27,7 @@
     GuardOnDataDependentSymNode,
     ShapeEnv,
     is_symbolic,
-    StatelessSymbolicContext,
+    FreshCreateSymbolicPolicy,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -137,7 +137,7 @@ def create_symbolic_tensor(name, arg, shape_env):
         shape_env.create_symbolic_sizes_strides_storage_offset(
             arg,
             source=ConstantSource(name),
-            symbolic_context=StatelessSymbolicContext(
+            policy=FreshCreateSymbolicPolicy(
                 dynamic_sizes=dynamic_dims,
                 constraint_sizes=constraint_dims
             ),
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 14a596508824..0b9f895f0a64 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -15,7 +15,7 @@
     DynamicOutputShapeException,
     UnsupportedOperatorException,
 )
-from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, StatelessSymbolicContext
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, FreshCreateSymbolicPolicy
 from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.common_device_type import ops
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
@@ -541,7 +541,7 @@ def test_same_shape_env_preserved(self):
         mode1 = FakeTensorMode(shape_env=shape_env)
         t1 = mode1.from_tensor(
             torch.randn(10),
-            symbolic_context=StatelessSymbolicContext(
+            policy=FreshCreateSymbolicPolicy(
                 dynamic_sizes=[DimDynamic.DYNAMIC],
                 constraint_sizes=[None]
             )
diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py
index adc68bb30bff..90cb21c26351 100644
--- a/torch/_dynamo/backends/distributed.py
+++ b/torch/_dynamo/backends/distributed.py
@@ -409,9 +409,7 @@ def run_node(self, n: Node) -> Any:
                     if isinstance(arg, torch.Tensor) and not isinstance(
                         arg, torch._subclasses.FakeTensor
                     ):
-                        new_args.append(
-                            torch._dynamo.utils.to_fake_tensor(arg, fake_mode)
-                        )
+                        new_args.append(fake_mode.from_tensor(arg))
                     else:
                         new_args.append(arg)
 
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 7fff2c3392fc..4c0234d106e1 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -42,7 +42,7 @@
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,
-    StatelessSymbolicContext,
+    FreshCreateSymbolicPolicy,
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -903,7 +903,7 @@ def __init__(
                     # TODO(zhxchen17) Also preserve all the user constraints here.
                     arg.node.meta["val"] = fake_mode.from_tensor(
                         flat_args[i],
-                        symbolic_context=StatelessSymbolicContext(
+                        policy=FreshCreateSymbolicPolicy(
                             dynamic_sizes=[
                                 DimDynamic.DYNAMIC
                                 if d in flat_args_dynamic_dims[i]
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index b577b9ea94aa..e3a33c1503a8 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -1065,17 +1065,6 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             "%s", LazyString(lambda: self.get_graph_sizes_log_str(name))
         )
         self.call_cleanup_hooks()
-        old_fake_mode = self.tracing_context.fake_mode
-        if not self.export:
-            # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
-            backend_fake_mode = torch._subclasses.FakeTensorMode(
-                shape_env=old_fake_mode.shape_env,
-            )
-            # TODO(voz): Ostensibily, this should be scoped and
-            # restore back to old_fake_mode, but doing so currently violates
-            # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
-            self.tracing_context.fake_mode = backend_fake_mode
-
         with self.restore_global_state():
             compiled_fn = self.call_user_compiler(gm)
         compiled_fn = disable(compiled_fn)
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index ba876a0fbb82..aa0719a3ab56 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2295,17 +2295,3 @@ def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bo
         isinstance(vt, UserDefinedObjectVariable)
         and hasattr(vt.value, "__torch_function__")
     )
-
-
-# see note [Tensor Fakification and Symbol Caching]
-def to_fake_tensor(t, fake_mode):
-    symbolic_context = None
-    source = None
-    if tracing_context := torch._guards.TracingContext.try_get():
-        if t in tracing_context.tensor_to_context:
-            symbolic_context = tracing_context.tensor_to_context[t]
-            source = symbolic_context.tensor_source
-
-    return fake_mode.from_tensor(
-        t, static_shapes=False, symbolic_context=symbolic_context, source=source
-    )
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index a139efd3e166..be66d51c0f4d 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -26,11 +26,11 @@
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
+    CreateSymbolicPolicy,
     DimConstraint,
     DimDynamic,
+    FreshCreateSymbolicPolicy,
     RelaxedUnspecConstraint,
-    StatefulSymbolicContext,
-    SymbolicContext,
 )
 from torch.fx.immutable_collections import immutable_list
 from torch.nested._internal.nested_tensor import NestedTensor
@@ -1564,33 +1564,23 @@ def __eq__(self, other: object) -> bool:
 
 
 # Performs automatic dynamic dim determination.
-# Returns a SymbolicContext
-def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext:
-    name = source.name()
-    prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
-    source_to_symint_node_cache = (
-        prior_policy.source_to_symint_node_cache if prior_policy else None
-    )
-
+# Returns a CreateSymbolicPolicy
+def _automatic_dynamic(e, tx, name, static_shapes) -> CreateSymbolicPolicy:
     if static_shapes:
-        return StatefulSymbolicContext(
+        return FreshCreateSymbolicPolicy(
             dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
             constraint_sizes=[None] * e.dim(),
-            tensor_source=source,
-            source_to_symint_node_cache=source_to_symint_node_cache,
         )
 
     # We preserve the dynamism of inputs. For example, when users call
     # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
     if any(isinstance(s, SymInt) for s in e.size()):
-        return StatefulSymbolicContext(
+        return FreshCreateSymbolicPolicy(
             dynamic_sizes=[
                 DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
                 for s in e.size()
             ],
             constraint_sizes=[None] * e.dim(),
-            tensor_source=source,
-            source_to_symint_node_cache=source_to_symint_node_cache,
         )
 
     # Prep for automatic dynamic
@@ -1709,7 +1699,7 @@ def update_dim2constraint(dim, constraint_range, debug_name):
         # Now, figure out if the dim is dynamic/duck/static
         if constraint_dim is not None or marked_dynamic or marked_weak_dynamic:
             # NB: We could assert static_shapes is False here, but it
-            # seems better to allow the user to override symbolic_context in this
+            # seems better to allow the user to override policy in this
             # case
             dynamic = DimDynamic.DYNAMIC
         elif static_shapes or config.assume_static_by_default or marked_static:
@@ -1721,15 +1711,12 @@ def update_dim2constraint(dim, constraint_range, debug_name):
 
     tx.output.frame_state[name] = frame_state_entry
 
-    return StatefulSymbolicContext(
+    return FreshCreateSymbolicPolicy(
         dynamic_sizes=dynamic_dims,
         constraint_sizes=constraint_dims,
-        tensor_source=source,
-        source_to_symint_node_cache=source_to_symint_node_cache,
     )
 
 
-# See note [Tensor Fakification and Symbol Caching]
 def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor: bool):
     if (
         type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
@@ -1741,36 +1728,31 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor
             e, is_tensor, guard_source=source.guard_source()
         )
 
-        symbolic_context = None
+        policy = None
         if not e.is_nested:
             # TODO: We should probably support this for nested tensors too
-            symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
-
-        if symbolic_context:
-            tx.output.tracing_context.tensor_to_context[e] = symbolic_context
+            policy = _automatic_dynamic(e, tx, source.name(), static_shapes)
 
         log.debug(
             "wrap_to_fake %s %s %s %s",
             source.name(),
             tuple(e.shape),
-            symbolic_context.dynamic_sizes if symbolic_context is not None else None,
-            symbolic_context.constraint_sizes if symbolic_context is not None else None,
+            policy.dynamic_sizes if policy is not None else None,
+            policy.constraint_sizes if policy is not None else None,
         )
         fake_e = wrap_fake_exception(
             lambda: tx.fake_mode.from_tensor(
                 e,
                 source=source,
-                symbolic_context=symbolic_context,
+                policy=policy,
             )
         )
-        # TODO: just store the whole symbolic_context here
+        # TODO: just store the whole policy here
         tx.output.tracked_fakes.append(
             TrackedFake(
                 fake_e,
                 source,
-                symbolic_context.constraint_sizes
-                if symbolic_context is not None
-                else None,
+                policy.constraint_sizes if policy is not None else None,
             )
         )
         tx.output.tracked_fakes_id_to_source[id(e)].append(source)
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 818388360d81..a560db5be495 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -4368,17 +4368,14 @@ def convert(idx, x):
                     if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
                         assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
                         return x
-
-
+                # TODO: Ensure that this codepath is never exercised from
+                # Dynamo
                 if (
                     idx < aot_config.num_params_buffers
                     and config.static_weight_shapes
                 ):
-                    # TODO: Ensure that this codepath is never exercised from
-                    # Dynamo
                     return fake_mode.from_tensor(x, static_shapes=True)
-
-                return torch._dynamo.utils.to_fake_tensor(x, fake_mode)
+                return fake_mode.from_tensor(x, static_shapes=False)
 
             return [convert(idx, x) for idx, x in enumerate(flat_args)]
 
diff --git a/torch/_guards.py b/torch/_guards.py
index 69912b15313d..fe3a10d663b7 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -29,7 +29,6 @@
 import torch
 from torch.utils import _pytree as pytree
 from torch.utils._traceback import CapturedTraceback
-from torch.utils.weak import WeakTensorKeyDictionary
 
 log = logging.getLogger(__name__)
 
@@ -619,8 +618,6 @@ def __init__(self, fake_mode):
         # ints that are known to be size-like and may have 0/1 entries that we
         # must not specialize on.
         self.force_unspec_int_unbacked_size_like = False
-        # See note [Tensor Fakification and Symbol Caching]
-        self.tensor_to_context = WeakTensorKeyDictionary()
 
     @staticmethod
     @contextmanager
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e4e676c9d8be..b36bc4c5bf8b 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -312,7 +312,7 @@ def from_real_tensor(
         shape_env=None,
         *,
         source=None,
-        symbolic_context=None,
+        policy=None,
         memoized_only=False,
     ):
         maybe_memo = self._get_memo(t)
@@ -348,7 +348,7 @@ def mk_fake_tensor(make_meta_t):
             shape_env=shape_env,
             callback=mk_fake_tensor,
             source=source,
-            symbolic_context=symbolic_context,
+            policy=policy,
         )
         if out is NotImplemented:
             raise UnsupportedFakeTensorException("meta converter nyi")
@@ -383,7 +383,7 @@ def __call__(
         make_constant=False,
         shape_env=None,
         source=None,
-        symbolic_context=None,
+        policy=None,
         memoized_only=False,
     ):
         return self.from_real_tensor(
@@ -392,7 +392,7 @@ def __call__(
             make_constant,
             shape_env=shape_env,
             source=source,
-            symbolic_context=symbolic_context,
+            policy=policy,
             memoized_only=memoized_only,
         )
 
@@ -1855,7 +1855,7 @@ def from_tensor(
         *,
         static_shapes=None,
         source: Optional[Source] = None,
-        symbolic_context=None,
+        policy=None,
         # Setting this flag will force FakeTensorMode to return `None` if attempting to convert a tensor we have not
         # seen before.
         memoized_only=False,
@@ -1864,16 +1864,14 @@ def from_tensor(
         if static_shapes is None:
             static_shapes = self.static_shapes
         if static_shapes:
-            assert (
-                symbolic_context is None
-            ), "cannot set both static_shapes and symbolic_context"
+            assert policy is None, "cannot set both static_shapes and policy"
             shape_env = None
         return self.fake_tensor_converter(
             self,
             tensor,
             shape_env=shape_env,
             source=source,
-            symbolic_context=symbolic_context,
+            policy=policy,
             memoized_only=memoized_only,
         )
 
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 8db8f94b1b41..1ff2a156379d 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -23,7 +23,7 @@
 if TYPE_CHECKING:
     # Import the following modules during type checking to enable code intelligence features,
     # Do not import unconditionally, as they import sympy and importing sympy is very slow
-    from torch.fx.experimental.symbolic_shapes import SymbolicContext
+    from torch.fx.experimental.symbolic_shapes import CreateSymbolicPolicy
 
 DimList = List
 
@@ -184,7 +184,7 @@ def meta_tensor(
         shape_env=None,
         callback=lambda t: t(),
         source: Optional[Source] = None,
-        symbolic_context: Optional["SymbolicContext"] = None,
+        policy: Optional["CreateSymbolicPolicy"] = None,
     ):
         from torch._subclasses.fake_tensor import FakeTensor
 
@@ -250,10 +250,10 @@ def sym_sizes_strides_storage_offset(
                         # the wrapper tensor and any inner tensors.
                         # We can revisit this if this assumption does not hold
                         # for any important subclasses later.
-                        symbolic_context=symbolic_context,
+                        policy=policy,
                     )
             else:
-                assert symbolic_context is None
+                assert policy is None
             return (t.size(), t.stride(), t.storage_offset())
 
         # see expired-storages
@@ -315,22 +315,22 @@ def sym_sizes_strides_storage_offset(
                     from torch._dynamo.source import AttrSource
                     from torch.fx.experimental.symbolic_shapes import (
                         DimDynamic,
-                        StatelessSymbolicContext,
+                        FreshCreateSymbolicPolicy,
                     )
 
                     if shape_env and not t.is_nested and not t._base.is_nested:
-                        base_symbolic_context = StatelessSymbolicContext(
+                        base_policy = FreshCreateSymbolicPolicy(
                             dynamic_sizes=[DimDynamic.STATIC] * t._base.dim(),
                             constraint_sizes=[None] * t._base.dim(),
                         )
                     else:
-                        base_symbolic_context = None
+                        base_policy = None
                     base = self.meta_tensor(
                         t._base,
                         shape_env,
                         callback,
                         source=AttrSource(source, "_base"),
-                        symbolic_context=base_symbolic_context,
+                        policy=base_policy,
                     )
 
                     def is_c_of_r(complex_dtype, real_dtype):
@@ -620,7 +620,7 @@ def empty_create(inner_t, inner_src):
                         shape_env,
                         callback,
                         source=AttrSource(source, "grad"),
-                        symbolic_context=symbolic_context,
+                        policy=policy,
                     )
                 torch._C._set_conj(r, t.is_conj())
                 torch._C._set_neg(r, t.is_neg())
@@ -637,7 +637,7 @@ def __call__(
         *,
         callback=lambda t: t(),
         source=None,
-        symbolic_context=None,
+        policy=None,
     ):
         # TODO: zero tensors?  We appear to have eliminated them by
         # excluding complex for now
@@ -682,7 +682,7 @@ def __call__(
                                 shape_env=shape_env,
                                 callback=callback,
                                 source=source,
-                                symbolic_context=symbolic_context,
+                                policy=policy,
                             )
                         out = torch._to_functional_tensor(fake_t)
                         torch._mirror_autograd_meta_to(fake_t, out)
@@ -700,7 +700,7 @@ def __call__(
                                 shape_env=shape_env,
                                 callback=callback,
                                 source=source,
-                                symbolic_context=symbolic_context,
+                                policy=policy,
                             )
                         return _wrap_functional_tensor(fake_t, current_level())
                 self.miss += 1
@@ -712,7 +712,7 @@ def __call__(
                     shape_env=shape_env,
                     callback=callback,
                     source=source,
-                    symbolic_context=symbolic_context,
+                    policy=policy,
                 )
                 if type(t) is torch.nn.Parameter:
                     # NB: Cannot directly use Parameter constructor
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 7f056ec9d5a7..3d97727ff7b8 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -63,7 +63,7 @@ class GuardOnDataDependentSymNode(RuntimeError):
     "guard_int", "guard_float", "guard_scalar",
     "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
     "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
-    "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", "StatefulSymbolicContext"
+    "has_free_symbols", "sym_eq", "CreateSymbolicPolicy", "FreshCreateSymbolicPolicy",
 ]
 
 # FX node metadata keys for symbolic shape FX graph.
@@ -721,14 +721,8 @@ def render(self):
     def is_equal(self, source1, source2):
         return self._find(source1) == self._find(source2)
 
-
-def _assert_symbol_context(symbolic_context):
-    assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object"
-    assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC"
-
-
 @dataclass(frozen=True)
-class SymbolicContext:
+class CreateSymbolicPolicy:
     """
     Data structure specifying how we should create symbols in
     ``create_symbolic_sizes_strides_storage_offset``; e.g., should
@@ -742,67 +736,20 @@ class SymbolicContext:
 
 
 @dataclass(frozen=True)
-class StatelessSymbolicContext(SymbolicContext):
+class FreshCreateSymbolicPolicy(CreateSymbolicPolicy):
     """
     Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
-    a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``.
+    a policy determination as given by ``DimDynamic`` and ``DimConstraint``.
     This will cause fresh symbols to be allocated
     """
     dynamic_sizes: DimList[DimDynamic]
     constraint_sizes: DimList[DimConstraint] = None
-    # TODO: add storage offset and stride symbolic_context
+    # TODO: add storage offset and stride policy
 
     def __post_init__(self):
         if self.constraint_sizes is None:
             object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes))
 
-
-# note [Tensor Fakification and Symbol Caching]
-#
-# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
-# The reason we do this is because there are certain classes of operations, namely,
-# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor
-# state at the end of a dynamo trace is different than the fake tensor state at the beginning
-# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation,
-# view relationships, etc.
-#
-# As we create a new fake mode, we also lose the memoization that comes with it. Rather than
-# transfer the memoization cache, we instead transfer the shape env. However, with this
-# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in
-# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across
-# recompilations.
-#
-# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass
-# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext.
-# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is
-# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors
-# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env
-# is used.
-# TODO(voz): Shape env validation
-@dataclass(frozen=True)
-class StatefulSymbolicContext(StatelessSymbolicContext):
-    """
-    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
-    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
-    will reuse a stored symbol, and a cache miss will write to this cache.
-
-    This behaves like StatelessSymbolicContext, except the cache supersedes the
-    other values - dynamic_sizes and constraint_sizes will not be read if we cache
-    hit.
-
-    It is the cache owners responsibility to maintain the lifecycle of the cache
-    w/r/t different shape_envs, clearing, etc.
-    """
-    tensor_source: Source = None
-    source_to_symint_node_cache : Dict["TensorPropertySource", SymInt] = None
-
-    def __post_init__(self):
-        # The None default is annoying, but required because of dataclass limitations
-        assert self.tensor_source is not None
-        if not self.source_to_symint_node_cache:
-            object.__setattr__(self, 'source_to_symint_node_cache', {})
-
-
 def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
     if isinstance(val, (int, float, bool)):
         return False
@@ -1975,20 +1922,20 @@ def _update_version_counter(self):
     def _produce_dyn_sizes(self,
                            ex_size: Sequence[int],
                            source: Source,
-                           symbolic_context: SymbolicContext
+                           policy: CreateSymbolicPolicy
                            ) -> List[sympy.Expr]:
-        return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context)
+        return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, policy)
 
     def _produce_dyn_sizes_from_int_tuple(self,
                                           tensor_size: Tuple[int],
                                           source: Source,
-                                          symbolic_context: SymbolicContext,
+                                          policy: CreateSymbolicPolicy,
                                           ) -> List[sympy.Expr]:
         assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}"
         from torch._dynamo.source import TensorPropertySource, TensorProperty
-        _assert_symbol_context(symbolic_context)
-        dynamic_dims = symbolic_context.dynamic_sizes
-        constraint_dims = symbolic_context.constraint_sizes
+        assert isinstance(policy, FreshCreateSymbolicPolicy)
+        dynamic_dims = policy.dynamic_sizes
+        constraint_dims = policy.constraint_sizes
         size = []
         for i, val in enumerate(tensor_size):
             size.append(self.create_symbol(
@@ -2001,7 +1948,7 @@ def create_symbolic_sizes_strides_storage_offset(
         ex: torch.Tensor,
         source: Source,
         *,
-        symbolic_context: Optional[SymbolicContext] = None,
+        policy: Optional[CreateSymbolicPolicy] = None,
     ):
         """
         Returns a list of symbolic sizes and strides for the given tensor.
@@ -2063,7 +2010,7 @@ def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
             ex_storage_offset,
             [_is_dim_dynamic(ex, i) for i in range(ex.dim())],
             source,
-            symbolic_context=symbolic_context,
+            policy=policy,
         )
 
     @record_shapeenv_event()
@@ -2075,12 +2022,12 @@ def _create_symbolic_sizes_strides_storage_offset(
         is_dim_dynamic: Sequence[bool],
         source: Source,
         *,
-        symbolic_context: Optional[SymbolicContext] = None,
+        policy: Optional[CreateSymbolicPolicy] = None,
     ):
         dim = len(ex_size)
 
         # Reimplement the legacy behavior
-        if symbolic_context is None:
+        if policy is None:
             constraint_dims = [None] * dim
             dynamic_dims = []
             for i in range(dim):
@@ -2094,14 +2041,13 @@ def _create_symbolic_sizes_strides_storage_offset(
                     r = DimDynamic.DUCK
                 dynamic_dims.append(r)
             dynamic_dims = [DimDynamic.DUCK] * dim
-            # symbolic_context is None - set one
-            symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims)
-        # We got a StatelessSymbolicContext
-        _assert_symbol_context(symbolic_context)
-        constraint_dims = symbolic_context.constraint_sizes
-        dynamic_dims = symbolic_context.dynamic_sizes
-
-        # TODO: make this configurable from outside symbolic_context; we made a symbolic_context
+            policy = FreshCreateSymbolicPolicy(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims)
+
+        assert isinstance(policy, FreshCreateSymbolicPolicy)
+        constraint_dims = policy.constraint_sizes
+        dynamic_dims = policy.dynamic_sizes
+
+        # TODO: make this configurable from outside policy; we made a policy
         # decision here where if all sizes are static, we are going to
         # specialize all of the inner strides/offset too. We don't have to
         # do this.
@@ -2112,7 +2058,7 @@ def _create_symbolic_sizes_strides_storage_offset(
         assert len(constraint_dims) == dim
 
         from torch._dynamo.source import TensorPropertySource, TensorProperty
-        size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context)
+        size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, policy)
         stride: List[Optional[sympy.Expr]] = [None] * len(size)
         for i, val in enumerate(ex_stride):
             if val in (0, 1):
@@ -2150,12 +2096,7 @@ def _create_symbolic_sizes_strides_storage_offset(
         assert all(x is not None for x in stride)
 
         sym_sizes = [
-            self.create_symintnode(
-                sym,
-                hint=hint,
-                source=TensorPropertySource(source, TensorProperty.SIZE, i),
-                symbolic_context=symbolic_context
-            )
+            self.create_symintnode(sym, hint=hint, source=TensorPropertySource(source, TensorProperty.SIZE, i))
             for i, (sym, hint) in enumerate(zip(size, ex_size))
         ]
         sym_stride = []
@@ -2164,17 +2105,14 @@ def _create_symbolic_sizes_strides_storage_offset(
             # we computed
             assert stride_expr is not None
             sym_stride.append(self.create_symintnode(
-                stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i),
-                symbolic_context=symbolic_context))
-        sym_storage_offset = self.create_symintnode(
-            self.create_symbol(
-                ex_storage_offset,
-                TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
-                dynamic_dim=DimDynamic.DYNAMIC,
-                constraint_dim=None,
-            ),
-            hint=ex_storage_offset,
-            source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), symbolic_context=symbolic_context)
+                stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i)
+            ))
+        sym_storage_offset = self.create_symintnode(self.create_symbol(
+            ex_storage_offset,
+            TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
+            dynamic_dim=DimDynamic.DYNAMIC,
+            constraint_dim=None,
+        ), hint=ex_storage_offset, source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET))
         return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
 
     # If you know what the current hint value of the SymInt to be created
@@ -2187,10 +2125,7 @@ def create_symintnode(
             *,
             hint: Optional[int],
             source: Optional[Source] = None,
-            symbolic_context: Optional[SymbolicContext] = None,
     ):
-        source_name = source.name() if source else None
-
         if self._translation_validation_enabled and source is not None:
             # Create a new symbol for this source.
             symbol = self._create_symbol_for_source(source)
@@ -2204,20 +2139,11 @@ def create_symintnode(
         else:
             fx_node = None
 
-        # see note [Tensor Fakification and Symbol Caching]
-        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
-            if source_name in symbolic_context.source_to_symint_node_cache:
-                return symbolic_context.source_to_symint_node_cache[source_name]
-
         if isinstance(sym, sympy.Integer):
             if hint is not None:
                 assert int(sym) == hint
-            out = int(sym)
-        else:
-            out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
-        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
-            symbolic_context.source_to_symint_node_cache[source_name] = out
-        return out
+            return int(sym)
+        return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
 
     @record_shapeenv_event()
     def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
@@ -2312,7 +2238,7 @@ def create_symbol(
         assert isinstance(source, Source), f"{type(source)} {source}"
         assert not (positive and val < 0), f"positive set for negative value: {val}"
         # It's always sound to allocate a symbol as DYNAMIC.  If the user
-        # constrained the symbol, force the symbolic_context to DYNAMIC, because our
+        # constrained the symbol, force the policy to DYNAMIC, because our
         # constraint code will do weird stuff if, e.g., it's duck shaped
         if constraint_dim is not None:
             dynamic_dim = DimDynamic.DYNAMIC

From 33fad1c0d423d75801d74aa1c906262964d702b2 Mon Sep 17 00:00:00 2001
From: Bin Bao 
Date: Tue, 21 Nov 2023 14:19:30 -0800
Subject: [PATCH 100/221] [AOTI] Fix a weight loading issue when the weight
 size can be 0 (#114280)

Summary: When a weight tensor is 0-size, no device memory should be allocated for it. This PR fixes the weight loading logic for such a case. This problem was found when running the 14K model test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114280
Approved by: https://github.com/chenyang78
---
 test/inductor/test_aot_inductor.py            | 22 +++++++++++++++++++
 torch/_inductor/codecache.py                  |  3 +++
 torch/csrc/inductor/aoti_runtime/model.h      |  5 +++--
 .../csrc/inductor/aoti_torch/shim_common.cpp  | 13 ++++++-----
 4 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 53b8d6a0a009..a41860cb14cb 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1319,6 +1319,28 @@ def forward(
 
         self.check_model(Model(), inputs)
 
+    def test_zero_size_weight(self):
+        class Model(torch.nn.Module):
+            def __init__(self, channel, r=8):
+                super().__init__()
+                self.pool = torch.nn.AdaptiveAvgPool2d(1)
+                self.net = torch.nn.Sequential(
+                    torch.nn.Linear(channel, channel // r, bias=False),
+                    torch.nn.ReLU(inplace=True),
+                    torch.nn.Linear(channel // r, channel, bias=False),
+                    torch.nn.Sigmoid(),
+                )
+
+            def forward(self, inp):
+                b, c, _, _ = inp.shape
+                x = self.pool(inp).view(b, c)
+                x = self.net(x).view(b, c, 1, 1)
+                x = inp * x
+                return x
+
+        inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
+        self.check_model(Model(4), inputs)
+
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
 
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 8152f8795240..ce674a82c28a 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -1608,6 +1608,9 @@ def _to_bytes(t: torch.Tensor) -> bytes:
                         # the raw data of the underlying structure.
                         import ctypes
 
+                        if t.numel() == 0:
+                            return b""
+
                         t_cpu = t.untyped_storage().cpu()
                         raw_array = ctypes.cast(
                             t_cpu.data_ptr(),
diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h
index 1c9b6f3079ed..2860a8b251e6 100644
--- a/torch/csrc/inductor/aoti_runtime/model.h
+++ b/torch/csrc/inductor/aoti_runtime/model.h
@@ -241,8 +241,9 @@ class AOTInductorModelBase {
     for (size_t i = 0; i < num_constants; i++) {
       std::string name = this->constant_name(i);
       size_t data_size = this->constant_data_size(i);
-      uint8_t* internal_ptr =
-          constant_ptr(constants_internal_offset[i], bytes_read, data_size);
+      uint8_t* internal_ptr = (data_size != 0)
+          ? constant_ptr(constants_internal_offset[i], bytes_read, data_size)
+          : nullptr;
       bytes_read += data_size;
 
       // Create at::Tensor from copied memory.
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index c1cf7e2b6332..46bb26e52f54 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -224,11 +224,14 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
     c10::Device device = c10_device(device_type, device_index);
     c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
         static_cast(dtype));
-    at::Tensor* new_tensor = new at::Tensor(at::for_blob(data, sizes)
-                                                .strides(strides)
-                                                .storage_offset(storage_offset)
-                                                .options(options)
-                                                .make_tensor());
+    at::Tensor* new_tensor = (data != nullptr)
+        ? new at::Tensor(at::for_blob(data, sizes)
+                             .strides(strides)
+                             .storage_offset(storage_offset)
+                             .options(options)
+                             .make_tensor())
+        // data == nullptr can happen for a 0-size tensor
+        : new at::Tensor(at::empty_strided(sizes, strides, options));
     *ret_new_tensor = tensor_pointer_to_tensor_handle(new_tensor);
   });
 }

From 324cde59b2122644617eaec975feb6fcf6c26c65 Mon Sep 17 00:00:00 2001
From: Nikita Shulga 
Date: Wed, 22 Nov 2023 14:48:24 +0000
Subject: [PATCH 101/221] [MPS] Fix test_copy_cast_no_leak (#114313)

When running on MacOS-13.2 test always fails on first run, but succeeds on the second as presumably it reserves some memory to cache f32->f16 graph. Make it resilient against such failures by adding a warmup step when one conversion is performed before recording driver memory utilization.

Fixes https://github.com/pytorch/pytorch/issues/114305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114313
Approved by: https://github.com/huydhn
---
 test/test_mps.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/test/test_mps.py b/test/test_mps.py
index 240977e4390d..2a1bcdb30782 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1054,11 +1054,17 @@ def leak_gpu0():
             leak_gpu0()
 
     def test_copy_cast_no_leak(self):
+
+        def step(x):
+            x = x.to(device='cpu', dtype=torch.float32)
+            x = x.to(device='mps', dtype=torch.float16)
+
         a = torch.randn(128, 128, device='mps', dtype=torch.float16)
+        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
+        step(a)
         torch.mps.empty_cache()
         driver_before = torch.mps.driver_allocated_memory()
-        a = a.to(device='cpu', dtype=torch.float32)
-        a = a.to(device='mps', dtype=torch.float16)
+        step(a)
         torch.mps.empty_cache()
         driver_after = torch.mps.driver_allocated_memory()
         self.assertTrue(driver_before == driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")

From d7f698102e5b889af09082b86d23b9407bfe6863 Mon Sep 17 00:00:00 2001
From: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Date: Wed, 22 Nov 2023 15:08:15 +0000
Subject: [PATCH 102/221] Disable MPS tests on macos-m1-13 runners (#114360)

As all of them are down at the moment, see screenshot below from [HUD](https://hud.pytorch.org/metrics)
image

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114360
Approved by: https://github.com/atalman
---
 .github/workflows/mac-mps.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml
index 619210eb60ed..bea8f58d5851 100644
--- a/.github/workflows/mac-mps.yml
+++ b/.github/workflows/mac-mps.yml
@@ -28,7 +28,8 @@ jobs:
       test-matrix: |
         { include: [
           { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-12" },
-          { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
+          # TODO: Revert me when those runners are back online
+          # { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
         ]}
 
   macos-12-py3-arm64-mps-test:

From f2ca07b680c2a53f54ea5a9667faba61e1d30edc Mon Sep 17 00:00:00 2001
From: Ke Wen 
Date: Wed, 22 Nov 2023 15:35:03 +0000
Subject: [PATCH 103/221] [ProcessGroupNCCL] Remove jumper to UCC (#114170)

The "jumper" to UCC lib in ProcessGroupNCCL was a temporary solution a while back. Cleaning it now that UCC has its own "PG" representation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114170
Approved by: https://github.com/wconstab, https://github.com/fduwjj, https://github.com/XilunWu, https://github.com/Aidyn-A
---
 CMakeLists.txt                                |  3 --
 caffe2/CMakeLists.txt                         |  3 --
 cmake/Summary.cmake                           |  1 -
 .../distributed/c10d/ProcessGroupNCCL.cpp     | 35 -------------------
 .../distributed/c10d/ProcessGroupNCCL.hpp     |  9 -----
 torch/csrc/distributed/c10d/UCCForNCCL.hpp    | 27 --------------
 torch/csrc/distributed/c10d/init.cpp          |  4 +--
 7 files changed, 1 insertion(+), 81 deletions(-)
 delete mode 100644 torch/csrc/distributed/c10d/UCCForNCCL.hpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 713f98a48df6..c243652416ec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -328,9 +328,6 @@ cmake_dependent_option(
     USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF)
 cmake_dependent_option(
     USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF)
-cmake_dependent_option(
-    USE_NCCL_WITH_UCC "Enable UCC support for ProcessGroupNCCL. Only available if USE_C10D_NCCL is on." OFF
-    "USE_C10D_NCCL" OFF)
 cmake_dependent_option(
     USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF)
 cmake_dependent_option(
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 76585f7571b1..748363725bcc 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1330,9 +1330,6 @@ if(USE_DISTRIBUTED)
       target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL)
     else()
       target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL)
-      if(USE_NCCL_WITH_UCC)
-        target_compile_definitions(torch_cuda PUBLIC USE_NCCL_WITH_UCC)
-      endif()
     endif()
   endif()
   if(USE_MPI AND USE_C10D_MPI)
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index 8c7604f47cc3..0cb9aef3e621 100644
--- a/cmake/Summary.cmake
+++ b/cmake/Summary.cmake
@@ -154,7 +154,6 @@ function(caffe2_print_configuration_summary)
   message(STATUS "  USE_NCCL              : ${USE_NCCL}")
   if(${USE_NCCL})
     message(STATUS "    USE_SYSTEM_NCCL     : ${USE_SYSTEM_NCCL}")
-    message(STATUS "    USE_NCCL_WITH_UCC   : ${USE_NCCL_WITH_UCC}")
   endif()
   message(STATUS "  USE_NNPACK            : ${USE_NNPACK}")
   message(STATUS "  USE_NUMPY             : ${USE_NUMPY}")
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index d2b74c046918..5fd0a59acf3d 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1,6 +1,5 @@
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -1064,28 +1063,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
         &cacheAllocatorDeregisterHook);
     allocatorHooksAttached = true;
   }
-
-#ifdef USE_NCCL_WITH_UCC
-  static c10::once_flag initialize_ucc_lib_flag;
-  c10::call_once(initialize_ucc_lib_flag, [&] {
-    uccLib_ = loadTorchUCC();
-    if (uccLib_ != nullptr) {
-      LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded";
-    }
-  });
-
-  if (uccLib_ != nullptr) {
-    LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded";
-    typedef c10::intrusive_ptr fn(
-        const c10::intrusive_ptr& store, int rank, int size);
-    auto createProcessGroupUCC =
-        reinterpret_cast(uccLib_->sym("createProcessGroupUCC"));
-    if (createProcessGroupUCC != nullptr) {
-      uccPG_ = createProcessGroupUCC(store, rank_, size_);
-      LOG(INFO) << "[Rank " << rank_ << "] ProcessGroupUCC created.";
-    }
-  }
-#endif
 }
 
 void ProcessGroupNCCL::runHealthCheck() {
@@ -4134,18 +4111,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base(
       avoidRecordStreams);
 }
 
-#ifdef USE_NCCL_WITH_UCC
-std::shared_ptr ProcessGroupNCCL::uccLib_ = nullptr;
-#endif
-
-bool ProcessGroupNCCL::isUCCAvailable() const {
-#ifdef USE_NCCL_WITH_UCC
-  return (uccPG_ != nullptr);
-#else
-  return false;
-#endif
-}
-
 } // namespace c10d
 
 #endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 6404d01f6cc7..7d7b12fbd556 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -12,7 +12,6 @@
 #include 
 #include 
 #include 
-#include 
 
 #include 
 #include 
@@ -530,9 +529,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Provide an API for users to define their own ways to store NCCL debug info.
   void registerDebugInfoWriter(std::unique_ptr writer);
 
-  // Tests if the UCC fallback path is available
-  bool isUCCAvailable() const;
-
   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
   // instead of relying on ProcessGroupNCCL destructor.
   void abort(c10::optional abortReason = c10::nullopt);
@@ -899,11 +895,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // The callback function to store NCCL debug info.
   std::unique_ptr debugInfoWriter_ = nullptr;
 
-#ifdef USE_NCCL_WITH_UCC
-  // ProcessGroupUCC shared library handle and ProcessGroup pointer
-  static std::shared_ptr uccLib_;
-  c10::intrusive_ptr uccPG_;
-#endif
   size_t uid_;
 };
 
diff --git a/torch/csrc/distributed/c10d/UCCForNCCL.hpp b/torch/csrc/distributed/c10d/UCCForNCCL.hpp
deleted file mode 100644
index 5ed2545f7145..000000000000
--- a/torch/csrc/distributed/c10d/UCCForNCCL.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include 
-#include 
-#include 
-#include 
-
-#include 
-
-namespace c10d {
-
-inline std::shared_ptr loadTorchUCC() {
-  const char* path = std::getenv("TORCH_UCC_LIBRARY_PATH");
-  if (path != nullptr) {
-    try {
-      return std::make_shared(path);
-    } catch (const c10::DynamicLibraryError& e) {
-      TORCH_WARN(
-          "TORCH_UCC_LIBRARY_PATH is set, "
-          "but the loading of torch_ucc.so failed with:",
-          e.msg());
-    }
-  }
-  return nullptr;
-}
-
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 59ff55db9992..aae99cc239a6 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2294,9 +2294,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
               "comm_split_count",
               &::c10d::ProcessGroupNCCL::getCommSplitCounter)
           .def_property_readonly(
-              "options", &::c10d::ProcessGroupNCCL::getOptions)
-          .def_property_readonly(
-              "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
+              "options", &::c10d::ProcessGroupNCCL::getOptions);
 
 #ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_(

From 9bab96c78c899647fd4966a5a8769fdfc4916908 Mon Sep 17 00:00:00 2001
From: CYuxian 
Date: Wed, 22 Nov 2023 15:40:57 +0000
Subject: [PATCH 104/221] [ONNX] Consider negative dim in
 _index_fill_reshape_helper (#114050)

Fix export issue of index_copy op with negative dim.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114050
Approved by: https://github.com/thiagocrepaldi
---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 9 +++++++--
 torch/onnx/symbolic_helper.py              | 2 ++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 05171d3ef995..6971f815d317 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3771,13 +3771,18 @@ def forward(self, input):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_index_copy(self):
         class IndexCopyModel(torch.nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
             def forward(self, input):
                 index = torch.tensor([2, 0])
                 source = torch.ones(3, 2, 5)
-                return input.index_copy(1, index, source)
+                return input.index_copy(self.dim, index, source)
 
         x = torch.randn(3, 4, 5, requires_grad=True)
-        self.run_test(IndexCopyModel(), x)
+        for dim in (1, -2):
+            self.run_test(IndexCopyModel(dim), x)
 
     def test_select(self):
         class Select(torch.nn.Module):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index d151cc106463..c8b55c7dec99 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -1388,6 +1388,8 @@ def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
         return _unimplemented("index_fill", "input rank not accessible")
     self_dim = self.type().dim()
     dim_value = _parse_arg(dim, "i")
+    if dim_value < 0:
+        dim_value += self_dim
     unsqueezed_index = _unsqueeze_helper(
         g, index, [i for i in range(self_dim) if i != dim_value]
     )

From 9fcf1f9632fc1981e87a5948e5555d05896217b7 Mon Sep 17 00:00:00 2001
From: Angela Yi 
Date: Wed, 22 Nov 2023 16:43:43 +0000
Subject: [PATCH 105/221] [export] Update schema (#114172)

Summary: Will update CustomClassHolder in a followup

Test Plan: CI

Differential Revision: D51343522

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114172
Approved by: https://github.com/zhxchen17
---
 test/export/test_serialize.py    | 15 ++++++++-
 torch/_export/serde/schema.py    | 52 ++++++++++++++++++++++++-----
 torch/_export/serde/serialize.py | 56 ++++++++++++--------------------
 3 files changed, 77 insertions(+), 46 deletions(-)

diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 7f4b184baefd..7eaff4f75ce7 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -440,6 +440,19 @@ def forward(self, x, y):
 
         self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))
 
+    def test_list_of_optional_tensors(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                indices = [None, None, torch.tensor([1, 3, 5, 7])]
+                indexed = torch.ops.aten.index.Tensor(x + y, indices)
+                return indexed + z
+
+        inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4))
+        self.check_graph(MyModule(), inputs)
+
     @parametrize(
         "name,case",
         get_filtered_export_db_tests(),
@@ -603,7 +616,7 @@ def f(x):
 
             with self.assertRaisesRegex(RuntimeError, r"Serialized version -1 does not match our current"):
                 f.seek(0)
-                loaded_ep = load(f)
+                load(f)
 
     def test_save_constants(self):
         class Foo(torch.nn.Module):
diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py
index 02f2895951b0..048d9ce8098a 100644
--- a/torch/_export/serde/schema.py
+++ b/torch/_export/serde/schema.py
@@ -81,10 +81,21 @@ class Device:
     index: Optional[int]
 
 
+@dataclass(repr=False)
+class SymExprHint(_Union):
+    as_int: int
+    as_float: float
+    as_bool: bool
+
+
+# This is for storing the symbolic expressions behind symints/symfloats/symbools
+# For example, we can get something like
+# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4)
+# if we also have the hint that s0 and s1 are both 2.
 @dataclass
 class SymExpr:
     expr_str: str
-    hint: Optional[int]
+    hint: Optional[SymExprHint] = None
 
 
 @dataclass(repr=False)
@@ -95,7 +106,7 @@ class SymInt(_Union):
 
 @dataclass(repr=False)
 class SymBool(_Union):
-    as_expr: str
+    as_expr: SymExpr
     as_bool: bool
 
 
@@ -110,12 +121,24 @@ class TensorMeta:
     layout: Layout
 
 
+# In most cases we will use the "as_name" field to store arguments which are
+# SymInts.
+# The "as_int" field is used in the case where we have a list containing a mix
+# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
+# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
+# to the "as_int" field.
 @dataclass(repr=False)
 class SymIntArgument(_Union):
     as_name: str
     as_int: int
 
 
+# In most cases we will use the "as_name" field to store arguments which are
+# SymBools.
+# The "as_bool" field is used in the case where we have a list containing a mix
+# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
+# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools
+# to the "as_bool" field.
 @dataclass(repr=False)
 class SymBoolArgument(_Union):
     as_name: str
@@ -127,6 +150,10 @@ class TensorArgument:
     name: str
 
 
+# This is use for storing the contents of a list which contain optional tensors
+# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
+# type List[OptionalTensorArgument], with tensor values seiralized to the
+# "as_tensor" field, and None values serialized to the "as_none" field.
 @dataclass(repr=False)
 class OptionalTensorArgument(_Union):
     as_tensor: str
@@ -173,6 +200,7 @@ class Argument(_Union):
 
 @dataclass
 class NamedArgument:
+    # Argument name from the operator schema
     name: str
     arg: Argument
 
@@ -185,24 +213,24 @@ class Node:
     metadata: Dict[str, str]
 
 
-@dataclass
-class TensorValue:
-    meta: TensorMeta
-
-
 @dataclass
 class Graph:
     inputs: List[Argument]
     outputs: List[Argument]
     nodes: List[Node]
-    tensor_values: Dict[str, TensorValue]
+    tensor_values: Dict[str, TensorMeta]
     sym_int_values: Dict[str, SymInt]
     sym_bool_values: Dict[str, SymBool]
+    # This is for deserializing the submodule graphs from higher order ops
+    # (ex. cond, map) where single tensor returns will just return a single
+    # tensor, rather than following export schema and returning a singleton
+    # list.
     is_single_tensor_return: bool = False
 
 
 @dataclass
 class UserInputSpec:
+    # Actually, only tensors and SymInts are allowed here
     arg: Argument
 
 
@@ -285,6 +313,9 @@ class RangeConstraint:
 class ModuleCallSignature:
     inputs: List[Argument]
     outputs: List[Argument]
+
+    # These are serialized by calling pytree.treespec_loads
+    # And deserialized by calling pytree.treespec_dumps
     in_spec: str
     out_spec: str
 
@@ -299,14 +330,17 @@ class ModuleCallEntry:
 class GraphModule:
     graph: Graph
     signature: GraphSignature
+    # This is used for unflattening, by tracking the calling structure of all of
+    # the modules in order to unflatten the modules back to the eager calling
+    # conventions.
     module_call_graph: List[ModuleCallEntry]
 
 
 @dataclass
 class ExportedProgram:
     graph_module: GraphModule
+    # Key is the opset namespace (ex. aten), and value is the version number
     opset_version: Dict[str, int]
     range_constraints: Dict[str, RangeConstraint]
-    equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
     schema_version: int
     dialect: str
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index b8ee8735f29c..d3fad019aa09 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -11,7 +11,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Union
 
 import sympy
 
@@ -55,11 +55,11 @@
     SymBool,
     SymBoolArgument,
     SymExpr,
+    SymExprHint,
     SymInt,
     SymIntArgument,
     TensorArgument,
     TensorMeta,
-    TensorValue,
     TREESPEC_VERSION,
     UserInputSpec,
     UserOutputSpec,
@@ -181,7 +181,10 @@ def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
             return SymInt.create(as_int=int(s))
         else:
             assert isinstance(s, torch.SymInt)
-            return SymInt.create(as_expr=SymExpr(str(s), s.node.hint))
+            if s.node.hint is None:
+                return SymInt.create(as_expr=SymExpr(str(s)))
+            else:
+                return SymInt.create(as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint)))
     else:
         raise SerializeError(
             f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
@@ -193,7 +196,7 @@ def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
         if symbolic_shapes.is_concrete_bool(s):
             return SymBool.create(as_bool=bool(s))
         else:
-            return SymBool.create(as_expr=str(s))
+            return SymBool.create(as_expr=SymExpr(expr_str=str(s)))
     else:
         raise SerializeError(
             f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`"
@@ -276,24 +279,6 @@ def serialize_range_constraints(
     }
 
 
-def serialize_equality_constraints(
-    equality_constraints: List[Tuple[torch._export.exported_program.InputDim, torch._export.exported_program.InputDim]]
-) -> List[Tuple[Tuple[str, int], Tuple[str, int]]]:
-    return [
-        ((v1.input_name, v1.dim), (v2.input_name, v2.dim))
-        for (v1, v2) in equality_constraints
-    ]
-
-
-def deserialize_equality_constraints(
-    equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
-) -> List[Tuple[torch._export.exported_program.InputDim, torch._export.exported_program.InputDim]]:
-    return [
-        (torch._export.exported_program.InputDim(v1[0], v1[1]), torch._export.exported_program.InputDim(v2[0], v2[1]))
-        for (v1, v2) in equality_constraints
-    ]
-
-
 def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
     returns = target._schema.returns
     return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)
@@ -314,7 +299,7 @@ class GraphState:
     inputs: List[Argument] = field(default_factory=list)
     outputs: List[Argument] = field(default_factory=list)
     nodes: List[Node] = field(default_factory=list)
-    tensor_values: Dict[str, TensorValue] = field(default_factory=dict)
+    tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
     sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
     sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
     is_single_tensor_return: bool = False
@@ -344,9 +329,7 @@ def handle_placeholder(self, node: torch.fx.Node):
         assert node.op == "placeholder"
         if isinstance(node.meta['val'], torch.Tensor):
             graph_input = Argument.create(as_tensor=TensorArgument(name=node.name))
-            self.graph_state.tensor_values[node.name] = TensorValue(
-                meta=serialize_tensor_meta(node.meta["val"])
-            )
+            self.graph_state.tensor_values[node.name] = serialize_tensor_meta(node.meta["val"])
         elif isinstance(node.meta['val'], torch.SymInt):
             raise AssertionError("SymInt graph input is not implemented yet.")
         elif isinstance(node.meta['val'], (int, bool, str, float, type(None))):
@@ -675,7 +658,7 @@ def serialize_optional_tensor_args(a):
 
     def serialize_tensor_output(self, name, meta_val) -> TensorArgument:
         assert name not in self.graph_state.tensor_values
-        self.graph_state.tensor_values[name] = TensorValue(meta=serialize_tensor_meta(meta_val))
+        self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val)
         return TensorArgument(name=name)
 
     def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument:
@@ -973,14 +956,12 @@ def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact:
             ).serialize(exported_program.graph_module)
         )
         serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints)
-        serialized_equality_constraints = serialize_equality_constraints(exported_program.equality_constraints)
 
         return SerializedArtifact(
             ExportedProgram(
                 graph_module=serialized_graph_module,
                 opset_version=self.opset_version,
                 range_constraints=serialized_range_constraints,
-                equality_constraints=serialized_equality_constraints,
                 schema_version=SCHEMA_VERSION,
                 dialect=exported_program.dialect,
             ),
@@ -1053,7 +1034,13 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
                             runtime_max=vr.upper  # type: ignore[arg-type]
                         )
 
-            return self.shape_env.create_symintnode(sym, hint=val.hint)
+            if val.hint is None:
+                hint = None
+            else:
+                assert val.hint.type == "as_int"
+                hint = val.hint.value
+
+            return self.shape_env.create_symintnode(sym, hint=hint)
         elif s.type == "as_int":
             assert isinstance(val, int)
             return val
@@ -1065,7 +1052,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
     def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
         val = s.value
         if s.type == "as_expr":
-            expr = sympy.sympify(val, locals=self.symbol_name_to_symbol)
+            expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)
             return self.shape_env.create_symboolnode(expr)
         elif s.type == "as_bool":
             assert isinstance(val, bool)
@@ -1102,7 +1089,7 @@ def deserialize_graph_output(self, output) -> torch.fx.Node:
     def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
         # Handle the tensor metas.
         for name, tensor_value in serialized_graph.tensor_values.items():
-            meta_val = self.deserialize_tensor_meta(tensor_value.meta, self.fake_tensor_mode)
+            meta_val = self.deserialize_tensor_meta(tensor_value, self.fake_tensor_mode)
             self.serialized_name_to_meta[name] = meta_val
 
         for name, sym_int_value in serialized_graph.sym_int_values.items():
@@ -1580,9 +1567,6 @@ def deserialize(
 
         state_dict = deserialize_torch_artifact(serialized_artifact.state_dict)
         tensor_constants = deserialize_torch_artifact(serialized_artifact.tensor_constants)
-        equality_constraints = deserialize_equality_constraints(
-            serialized_artifact.exported_program.equality_constraints
-        )
 
         exported_program = ep.ExportedProgram(
             res.graph_module,
@@ -1590,7 +1574,7 @@ def deserialize(
             res.signature,
             state_dict,  # type: ignore[arg-type]
             range_constraints,
-            equality_constraints,
+            [],
             res.module_call_graph,
             None,
             load_verifier(serialized_artifact.exported_program.dialect),

From 00ae299016fd6f7c5ea68f311860b376bc9df3fa Mon Sep 17 00:00:00 2001
From: Pavan Balaji 
Date: Wed, 22 Nov 2023 17:31:16 +0000
Subject: [PATCH 106/221] [c10d] Remove unused function (#114341)

Summary: As the title suggests

Test Plan: OSS CI

Differential Revision: D51386619

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114341
Approved by: https://github.com/Skylion007
---
 torch/csrc/distributed/c10d/Utils.hpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp
index dd878b96fbe3..122c7732534f 100644
--- a/torch/csrc/distributed/c10d/Utils.hpp
+++ b/torch/csrc/distributed/c10d/Utils.hpp
@@ -187,10 +187,6 @@ inline bool getCvarBool(const std::vector& env, bool def) {
   return ret;
 }
 
-inline bool parseEnvVarFlag(const char* envVarName) {
-  return getCvarBool({envVarName}, false);
-}
-
 inline void assertSameSizes(
     const at::IntArrayRef& sizes,
     const std::vector& tensors) {

From b927a4e2cad62317c8d0bdadf94a20da1d0825ef Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot 
Date: Wed, 22 Nov 2023 17:43:51 +0000
Subject: [PATCH 107/221] Revert "Opportunistically use `ncclCommSplit` when
 creating new NCCL groups (#112889)"

This reverts commit 64a5372e6ce9b6ca0ee5c7482b27e24561725b28.

Reverted https://github.com/pytorch/pytorch/pull/112889 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing ROCm distributed jobs in trunk https://hud.pytorch.org/pytorch/pytorch/commit/4d07428edee863e7f5920f0672957a9711a9f0b5 ([comment](https://github.com/pytorch/pytorch/pull/112889#issuecomment-1823214376))
---
 test/cpp/c10d/ProcessGroupNCCLTest.cpp        | 78 +++----------------
 test/distributed/test_c10d_nccl.py            | 22 +-----
 torch/csrc/distributed/c10d/NCCLUtils.hpp     | 26 -------
 .../distributed/c10d/ProcessGroupNCCL.cpp     | 47 ++---------
 .../distributed/c10d/ProcessGroupNCCL.hpp     | 11 ---
 torch/csrc/distributed/c10d/init.cpp          | 14 +---
 torch/distributed/distributed_c10d.py         | 35 +--------
 7 files changed, 21 insertions(+), 212 deletions(-)

diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
index 6a0d60b57315..61e9753988ea 100644
--- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
@@ -31,20 +31,12 @@ class NCCLTestBase {
     pg_ = std::move(other.pg_);
   }
 
-  std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
-    return pg_;
+  ::c10d::ProcessGroupNCCL& getProcessGroup() {
+    return *pg_;
   }
 
-  ::c10::intrusive_ptr<::c10d::Store>& getProcessGroupStore() {
-    return store_;
-  }
-
-  void initialize(
-      int rank,
-      int size,
-      c10::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
-          c10::nullopt) {
-    store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
+  void initialize(int rank, int size) {
+    auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
 
     c10::intrusive_ptr opts =
         c10::make_intrusive();
@@ -53,22 +45,14 @@ class NCCLTestBase {
         c10d::TORCH_ENABLE_NCCL_HEALTH_CHECK[0].c_str(),
         "1",
         /* overwrite */ 1);
-#ifdef NCCL_HAS_COMM_SPLIT
-    if (split_from) {
-      opts->split_from = *split_from;
-      opts->split_color = ++color_;
-    }
-#endif
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store_, rank, size, std::move(opts)));
+        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
   }
 
  protected:
   std::string path_;
-  std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
   std::chrono::milliseconds pgTimeout_;
-  ::c10::intrusive_ptr<::c10d::Store> store_;
-  int color_{1};
 };
 
 class NCCLTest : public NCCLTestBase {
@@ -734,9 +718,9 @@ void testSequenceNumInit(
   auto runTest = [&](int i) {
     NCCLTest test(path, worldSize);
     test.initialize(i, worldSize);
-    test.getProcessGroup()->setSequenceNumberForGroup();
+    test.getProcessGroup().setSequenceNumberForGroup();
     std::lock_guard lock(m);
-    auto seqNum = test.getProcessGroup()->getSequenceNumberForGroup();
+    auto seqNum = test.getProcessGroup().getSequenceNumberForGroup();
     nums.insert(seqNum);
   };
   std::vector threads;
@@ -893,55 +877,11 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) {
     auto test = NCCLTestBase(file.path);
     test.initialize(rank_, size_);
     EXPECT_EQ(
-        test.getProcessGroup()->getBackendName(),
+        test.getProcessGroup().getBackendName(),
         std::string(c10d::NCCL_BACKEND_NAME));
   }
 }
 
-TEST_F(ProcessGroupNCCLTest, testSplittingCommunicator) {
-  if (skipTest()) {
-    return;
-  }
-  TemporaryFile file;
-  auto test1 = BroadcastNCCLTest(file.path, size_);
-  test1.initialize(rank_, size_);
-
-  auto test2 = BroadcastNCCLTest(file.path, size_);
-  test2.initialize(rank_, size_, test1.getProcessGroup());
-
-  // Steal the broadcast test and issue it for both of our groups.
-  // This ensures consistent full collective communication.  TODO:
-  // maybe refactor the guts rather than copy-pasta, but it may not be
-  // worth it.
-  for (auto test : {&test1, &test2}) {
-    const int numDevices = test->numDevices();
-    // try every permutation of root rank and root tensor
-    for (const auto rootRank : c10::irange(size_)) {
-      for (const auto rootTensor : c10::irange(numDevices)) {
-        auto work = test->run(rootRank, rootTensor);
-        test->wait(work);
-
-        // Check results
-        const auto expected = (rootRank * numDevices + rootTensor);
-        const auto tensors = test->getTensors();
-        for (const auto& tensor : tensors) {
-          const auto* const data = tensor.data_ptr();
-          for (const auto k : c10::irange(tensor.numel())) {
-            EXPECT_EQ(data[k], expected)
-                << "Broadcast outputs do not match expected outputs";
-          }
-        }
-      }
-    }
-  }
-
-  // Now that we've run full operations on both the original and split process
-  // group, ensure we saw exactly as many splits as we expected: 0 in the
-  // original process group, and one per device in the second.
-  EXPECT_EQ(test2.getProcessGroup()->getCommSplitCounter(), 0);
-  EXPECT_EQ(test1.getProcessGroup()->getCommSplitCounter(), test1.numDevices());
-}
-
 #ifdef IS_NCCL_EXP
 TEST_F(ProcessGroupNCCLTest, testSparseAllreduce) {
   if (skipTest()) {
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 4ac72c2bd207..ada84507aef9 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -1272,27 +1272,6 @@ def allgather_base(output_t, input_t):
         # Verification
         self.assertEqual(torch.arange(self.world_size), output_t)
 
-    @requires_nccl()
-    def test_comm_split_optimization(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        pg = self._create_process_group_nccl(store, self.opts())
-
-        # Test lazy splitting behavior across each per-device backend.
-        for device in self.rank_to_GPU[self.rank]:
-            backend = pg._get_backend(torch.device(device))
-
-            # split doesn't happen unless the original process group has lazily
-            # created communicators, so first verify we haven't split even when
-            # making the new group and running an operation on the original pg.
-            ng = c10d.new_group()
-            tensor = torch.tensor([self.rank]).cuda(device)
-            pg.broadcast(tensor, 0)
-            self.assertEqual(backend.comm_split_count(), 0)
-
-            # The new group will force a split of the original on first use.
-            ng.broadcast(tensor, 0)
-            self.assertEqual(backend.comm_split_count(), 1)
-
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
 ):
@@ -3697,6 +3676,7 @@ def gather_trace():
 
 
 
+
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 2b4885f02ffc..e6c05e228cfd 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -17,11 +17,6 @@
 #define NCCL_HAS_COMM_NONBLOCKING
 #endif
 
-#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
-    (NCCL_MINOR >= 18)
-#define NCCL_HAS_COMM_SPLIT
-#endif
-
 // ncclGetLastError() is enabled only for NCCL versions 2.13+
 // ncclRemoteError only exists in NCCL versions 2.13+
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -251,22 +246,6 @@ class NCCLComm {
   }
 #endif
 
-#ifdef NCCL_HAS_COMM_SPLIT
-  static std::shared_ptr split(
-      NCCLComm* source,
-      int color_id,
-      int rank,
-      ncclConfig_t& config) {
-    auto comm = std::make_shared();
-    C10D_NCCL_CHECK(
-        ncclCommSplit(
-            source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
-        c10::nullopt);
-    ++source->ncclCommSplitCounter_;
-    return comm;
-  }
-#endif
-
   ncclUniqueId getNcclId() {
     return ncclId_;
   }
@@ -346,10 +325,6 @@ class NCCLComm {
     return aborted_;
   }
 
-  uint64_t getCommSplitCounter() const {
-    return ncclCommSplitCounter_;
-  }
-
   ncclResult_t checkForNcclError() {
     std::unique_lock lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
@@ -426,7 +401,6 @@ class NCCLComm {
   // Unique nccl_id for this communicator.
   ncclUniqueId ncclId_;
   bool aborted_;
-  uint64_t ncclCommSplitCounter_{0};
   ncclResult_t ncclAsyncErr_;
   mutable std::mutex mutex_;
   // Rank that this communicator corresponds to.
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 5fd0a59acf3d..40eb9d06ef0d 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1875,40 +1875,11 @@ std::vector>& ProcessGroupNCCL::getNCCLComm(
     int deviceIndex = devices[i].index();
 
     gpuGuard.set_index(deviceIndex);
-#ifdef NCCL_HAS_COMM_SPLIT
-    if (options_->split_from) {
-      TORCH_CHECK(
-          options_->split_color != 0,
-          "Must specify a non-zero color when splitting");
-      // Find a valid, healthy communicator to split from if possible.
-      std::lock_guard lock(options_->split_from->mutex_);
-      auto& other_comms = options_->split_from->devNCCLCommMap_;
-      auto dit = other_comms.find(devicesKey);
-      if (dit != other_comms.end() && !dit->second.empty()) {
-        TORCH_INTERNAL_ASSERT(
-            dit->second.size() == ncclComms.size(),
-            "split_from->devNCCLCommMap_ should be empty or the same size as ncclComms!");
-        if (dit->second[i] && !dit->second[i]->isAborted()) {
-          ncclComms[i] = NCCLComm::split(
-              dit->second[i].get(),
-              options_->split_color,
-              rank,
-              options_->config);
-        }
-      }
-    }
-#endif
-
-    // To simplify conditioonal nesting, just create the ncclComms[i]
-    // entry if it hasn't been yet rather than untangling the
-    // conditions that might have resulted in a split above.
-    if (!ncclComms[i]) {
 #ifdef NCCL_HAS_COMM_NONBLOCKING
-      ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
+    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
 #else
-      ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
+    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
 #endif
-    }
 
     // Creates the NCCL streams
     streamVal.push_back(
@@ -1954,6 +1925,9 @@ std::vector>& ProcessGroupNCCL::getNCCLComm(
       std::make_tuple(devicesKey),
       std::make_tuple(devices.size()));
 
+  // Hold the lock before modifying the cache.
+  std::lock_guard lock(mutex_);
+
   // Record the communicators based on ncclUniqueId.
   ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms);
 
@@ -1997,20 +1971,9 @@ std::vector>& ProcessGroupNCCL::getNCCLComm(
   it = devNCCLCommMap_.find(devicesKey);
   TORCH_INTERNAL_ASSERT(
       it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
-
   return it->second;
 }
 
-uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
-  uint64_t ret = 0;
-  for (const auto& i : ncclIdToCommMap_) {
-    for (const auto& j : i.second) {
-      ret += j->getCommSplitCounter();
-    }
-  }
-  return ret;
-}
-
 namespace {
 
 // Check validity of tensor
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 7d7b12fbd556..b47983cf0e4b 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -341,13 +341,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Configure ranks
     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
 #endif
-
-    // Optional "parent" backend and color to create communicators from
-    // via `ncclCommSplit`
-#ifdef NCCL_HAS_COMM_SPLIT
-    std::shared_ptr split_from;
-    int64_t split_color{0};
-#endif
   };
 
   // If you wish to create multiple process groups, each with a potentially
@@ -516,10 +509,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // may indicate that there is some sort of collective desynchronization.
   uint64_t getSequenceNumberForGroup() override;
 
-  // Return the total number of splits the communicators held by this process
-  // group have performed.
-  uint64_t getCommSplitCounter() const;
-
   void registerOnCompletionHook(
       std::function)>&& hook) override;
   void waitForPendingWorks() override;
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index aae99cc239a6..909773dfe47e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2290,9 +2290,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
               py::call_guard())
           .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
           .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
-          .def(
-              "comm_split_count",
-              &::c10d::ProcessGroupNCCL::getCommSplitCounter)
           .def_property_readonly(
               "options", &::c10d::ProcessGroupNCCL::getOptions);
 
@@ -2355,18 +2352,15 @@ Example::
       )")
       .def(py::init(), py::arg("is_high_priority_stream") = false)
 #ifdef NCCL_HAS_COMM_CTA_CGA
-      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
-#endif
       .def_readwrite(
           "is_high_priority_stream",
           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
-#ifdef NCCL_HAS_COMM_SPLIT
+      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
+#else
       .def_readwrite(
-          "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
-      .def_readwrite(
-          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
+          "is_high_priority_stream",
+          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
 #endif
-      ;
 
 #endif
 
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 3bd35709505d..63f6c48d35f3 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -8,7 +8,6 @@
 import logging
 import os
 import pickle
-import sys
 import time
 import warnings
 from collections import namedtuple
@@ -1315,29 +1314,7 @@ def _new_process_group_helper(
                 pg_options.is_high_priority_stream = False
             pg_options._timeout = timeout
 
-            # If our new group includes all ranks, we can reduce
-            # overhead by splitting the communicator (`nccCommSplit`).
-
-            # TODO: support this in the general case by calling
-            # `nccCommSplit` with `NCCL_SPLIT_NOCOLOR` for the ranks
-            # not in the communicator.
-            split_from = None
-            if (
-                is_initialized()
-                and _world.default_pg._get_backend_name() == Backend.NCCL
-                and len(global_ranks_in_group) == _world.default_pg.size()
-            ):
-                # If possible, find a backend to split from by peeling
-                # process group wrappers from the world's default pg.
-                split_from = _world.default_pg._get_backend(_get_pg_default_device())
-                while isinstance(split_from, _ProcessGroupWrapper):
-                    split_from = split_from.wrapped_pg
-
-                if split_from:
-                    pg_options.split_from = split_from
-                    pg_options.split_color = _process_group_color(global_ranks_in_group)
-            backend_class = ProcessGroupNCCL(
-                backend_prefix_store, group_rank, group_size, pg_options)
+            backend_class = ProcessGroupNCCL(backend_prefix_store, group_rank, group_size, pg_options)
             backend_type = ProcessGroup.BackendType.NCCL
         elif backend_str == Backend.UCC and is_ucc_available():
             # TODO: once UCC plugin is fully deprecated, remove
@@ -3537,19 +3514,11 @@ def _create_process_group_wrapper(
     wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
     return wrapped_pg
 
-# helper function for deterministically hashing a list of ranks
-def _hash_ranks(ranks: List[int]):
-    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
-
-# Takes a list of ranks and computes an integer color
-def _process_group_color(ranks: List[int]) -> int:
-    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
-    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
 
 def _process_group_name(ranks, use_hashed_name):
     global _world
     if use_hashed_name:
-        pg_name = _hash_ranks(ranks)
+        pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
         while pg_name in _world.pg_names.values():
             pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
     else:

From d6578b36781b0402163e230e8838a802bf51ba0f Mon Sep 17 00:00:00 2001
From: Jerry Zhang 
Date: Tue, 21 Nov 2023 21:39:27 -0800
Subject: [PATCH 108/221] [quant][pt2e] Refactor some internal code for
 observer insertion (#113500)

Summary:
att

Test Plan:
python test/test_quantization.py TestQuantizePT2E

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113500
Approved by: https://github.com/kimishpatel
---
 torch/ao/quantization/pt2e/prepare.py | 27 ++++++++++++---------------
 1 file changed, 12 insertions(+), 15 deletions(-)

diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index 25e90f2fcf0d..5f89b7282502 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -30,7 +30,7 @@
 ]
 
 
-def _find_root(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
+def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
     """Find the root node for the sharing tree
     Args:
         edge_or_node: edge/node that we want to find the root
@@ -42,7 +42,7 @@ def _find_root(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeO
     parent = shared_with_map[edge_or_node]
     if parent == edge_or_node:
         return edge_or_node
-    root = _find_root(parent, shared_with_map)
+    root = _find_root_edge_or_node(parent, shared_with_map)
     # path compression
     shared_with_map[edge_or_node] = root
     return root
@@ -50,8 +50,8 @@ def _find_root(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeO
 def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
     """Merge the subtree for `child` with `parent`, the order is important here
     """
-    root_parent = _find_root(parent, shared_with_map)
-    root_child = _find_root(child, shared_with_map)
+    root_parent = _find_root_edge_or_node(parent, shared_with_map)
+    root_child = _find_root_edge_or_node(child, shared_with_map)
     # union the two trees by pointing the root of child to root of parent
     shared_with_map[root_child] = root_parent
 
@@ -66,22 +66,21 @@ def _update_shared_with(edge_or_node: EdgeOrNode, qspec: QuantizationSpecBase, s
         # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
         _union(sharing_with, edge_or_node, shared_with_map)
 
-# TODO: simplify this
-def _find_root_qspec(
+def _unwrap_shared_qspec(
     qspec: QuantizationSpecBase,
     edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
     shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
 ) -> QuantizationSpecBase:
     """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
     if qspec is SharedQuantizationSpec
-       (1). tries to find the root node for the node that the qspec points to
+       (1). tries to find the root edge or node for the node that the qspec points to
        (2). recursively find the root qspec based on the qspec for the root node
     """
     if isinstance(qspec, SharedQuantizationSpec):
         sharing_with = qspec.edge_or_node
-        root = _find_root(sharing_with, shared_with_map)
+        root = _find_root_edge_or_node(sharing_with, shared_with_map)
         qspec = edge_or_node_to_qspec[root]
-        return _find_root_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
     return qspec
 
 def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
@@ -115,10 +114,11 @@ def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode,
     return edge_or_node_to_qspec
 
 def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
+    # find root_qspec for `arg` Node (the output of previous node)
     root_qspec = None
     if edge_or_node in edge_or_node_to_qspec:
         qspec = edge_or_node_to_qspec[edge_or_node]
-        root_qspec = _find_root_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
     # TODO: add assertions for types of root qspecs
     if (
         root_qspec is not None and
@@ -190,11 +190,8 @@ def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, Quanti
             _update_shared_with(output_node, qspec, shared_with_map)
         else:
             input_edge = edge_or_node
-            input_edge_root = _find_root(input_edge, shared_with_map)
-            input_edge_root_qspec = edge_or_node_to_qspec[input_edge_root]
-            input_edge_root_qspec = _find_root_qspec(input_edge_root_qspec, edge_or_node_to_qspec, shared_with_map)
+            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
 
-            # find root_qspec for `arg` Node (the output of previous node)
             assert isinstance(input_edge, tuple)
             arg, n = input_edge
             if n.meta["quantization_annotation"].allow_implicit_sharing:
@@ -221,7 +218,7 @@ def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, Quanti
     cur_group_id = 0
     edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
     for edge_or_node in shared_with_map.keys():
-        root = _find_root(edge_or_node, shared_with_map)
+        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
         if root not in edge_or_node_to_group_id:
             edge_or_node_to_group_id[root] = cur_group_id
             cur_group_id += 1

From 1f1ff629a82432b6a5e8f2d8334166b36c52e3c4 Mon Sep 17 00:00:00 2001
From: Isuru Fernando 
Date: Tue, 21 Nov 2023 21:06:48 +0000
Subject: [PATCH 109/221] Use parent class attribute  supports_out for
 foreach_zero opinfo (#112778)

Instead of introducing a new has_no_out_of_place attribute
Also fixes foreach_copy tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112778
Approved by: https://github.com/lezcano
---
 test/test_foreach.py                          | 30 +++++++++----------
 .../_internal/common_methods_invocations.py   | 16 ++--------
 torch/testing/_internal/opinfo/core.py        | 10 ++++---
 3 files changed, 23 insertions(+), 33 deletions(-)

diff --git a/test/test_foreach.py b/test/test_foreach.py
index 52d91f681c30..1ac70fdc8cc6 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -126,7 +126,7 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
         wrapped_op, _, inplace_op, _ = self._get_funcs(op)
 
         for sample in op.sample_zero_size_inputs(device, dtype):
-            if not op.has_no_out_of_place:
+            if op.supports_out:
                 wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
             with InplaceForeachVersionBumpCheck(self, sample.input):
                 inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
@@ -168,7 +168,7 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace):
             except Exception as e:
                 with (
                     self.assertRaisesRegex(type(e), re.escape(str(e)))
-                    if not (op.has_no_in_place or op.has_no_out_of_place)
+                    if not (op.has_no_in_place or not op.supports_out)
                     else self.assertRaises(type(e))
                 ):
                     ref([ref_input, *sample.ref_args], **ref_kwargs)
@@ -355,7 +355,7 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
             self.assertEqual(res, tensors)
 
     @ops(
-        filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
+        filter(lambda op: op.supports_out, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
     )
     def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
@@ -374,7 +374,7 @@ def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
         self.assertEqual(res, expected)
 
     @ops(
-        filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
+        filter(lambda op: op.supports_out, foreach_binary_op_db),
         allowed_dtypes=[torch.float],
     )
     def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
@@ -392,7 +392,7 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
 
     @skipIfTorchDynamo("Different error msgs, TODO")
     @ops(
-        filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
+        filter(lambda op: op.supports_out, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
     )
     def test_binary_op_list_error_cases(self, device, dtype, op):
@@ -457,7 +457,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
     @ops(
-        filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
+        filter(lambda op: op.supports_out, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
     )
     def test_binary_op_list_slow_path(self, device, dtype, op):
@@ -509,7 +509,7 @@ def test_binary_op_list_slow_path(self, device, dtype, op):
             alpha=None, scalar_self_arg=False)
 
     @ops(
-        filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db),
+        filter(lambda op: op.supports_out, foreach_binary_op_db),
         dtypes=floating_types_and(torch.half, torch.bfloat16),
     )
     def test_binary_op_float_inf_nan(self, device, dtype, op):
@@ -539,12 +539,11 @@ def test_binary_op_float_inf_nan(self, device, dtype, op):
     @onlyCUDA
     @ops(foreach_unary_op_db)
     def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
-        op.has_no_out_of_place = op.name != "_foreach_zero"
         method, ref, inplace_method, ref_inplace = self._get_funcs(op)
         # tensors: ['cuda', 'cpu]
         tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2]))[0].input
         tensors[1] = tensors[1].to("cpu")
-        if op.has_no_out_of_place:
+        if not op.supports_out:
             try:
                 actual = method((tensors,), False, False, zero_size=False)
             except RuntimeError as e:
@@ -560,13 +559,13 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
             with self.assertRaisesRegex(type(e), str(e)):
                 ref_inplace((tensors,))
         else:
-            if op.has_no_out_of_place:
+            if not op.supports_out:
                 self.assertEqual(expected, tensors)
             else:
                 self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
 
     @onlyCUDA
-    @ops(filter(lambda op: not op.has_no_out_of_place, foreach_binary_op_db))
+    @ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
     def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
         # `tensors1`: ['cuda', 'cpu']
         # `tensors2`: ['cuda', 'cpu']
@@ -683,12 +682,13 @@ def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
 
     @onlyCUDA
     @ops(
-        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
+        filter(
+            lambda op: op.supports_out,
+            foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
+        ),
         dtypes=(torch.float,),
     )
     def test_outplace_with_invalid_grads(self, device, dtype, op):
-        if op.has_no_out_of_place:
-            self.skipTest(f"{op.name} does not have out-of-place implementation")
         func, *_ = self._get_funcs(op)
         sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0]
         self.assertTrue(all(t.requires_grad for t in sample.input))
@@ -831,7 +831,7 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
     def test_autodiff(self, device, dtype, op, inplace):
         if not (op.supports_autograd or op.supports_forward_ad):
             self.skipTest("neither reverse mode nor forward mode supported")
-        if (not inplace) and op.has_no_out_of_place:
+        if (not inplace) and not op.supports_out:
             self.skipTest("out-of-place not implemented")
         if inplace and op.has_no_in_place:
             self.skipTest("in-place not implemented")
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1a0c7cfd26bf..ff926426b8ca 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9149,13 +9149,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         'zero',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
-        has_no_out_of_place=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
+        supports_out=False,
     ),
     ForeachFuncInfo(
         'sign',
@@ -9354,15 +9348,9 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         "copy",
         foreach_inputs_sample_func(2, False, False),
         dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
-        has_no_out_of_place=True,
+        supports_out=False,
         supports_forward_ad=False,
         supports_autograd=False,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     )
 ]
 
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index 23b6e89e4a21..f100865f9d5d 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -2699,7 +2699,7 @@ def __init__(
         supports_scalar_self_arg=False,
         supports_forward_ad=True,
         backward_requires_result=False,
-        has_no_out_of_place=False,
+        supports_out=True,
         **kwargs,
     ):
         (
@@ -2708,13 +2708,15 @@ def __init__(
             torch_ref_method,
             torch_ref_inplace,
         ) = get_foreach_method_names(name)
-        if has_no_out_of_place:
+        if not supports_out:
             # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call
             # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero`
             # is not defined at the moment. Thus to skip the qualification, set a similar torch
             # function.
             assert foreach_method is None
-            foreach_method = getattr(torch.Tensor, f"{name}_")
+            assert torch_ref_method is None
+            foreach_method = foreach_method_inplace
+            torch_ref_method = torch_ref_inplace
         super().__init__(
             name="_foreach_" + name,
             op=foreach_method,
@@ -2727,6 +2729,7 @@ def __init__(
             sample_inputs_func=sample_inputs_func,
             supports_autograd=supports_autograd,
             supports_forward_ad=supports_forward_ad,
+            supports_out=supports_out,
             **kwargs,
         )
         self.supports_scalar_self_arg = supports_scalar_self_arg
@@ -2734,7 +2737,6 @@ def __init__(
         self.ref_inplace = torch_ref_inplace
         self.supports_alpha_param = supports_alpha_param
         self.backward_requires_result = backward_requires_result
-        self.has_no_out_of_place = has_no_out_of_place
         self.has_no_in_place = self.inplace_variant is None
         self.supports_inplace_autograd = supports_inplace_autograd
 

From 0f887a6d1a62449c92ad22b7659c471797ed3762 Mon Sep 17 00:00:00 2001
From: Xu Han 
Date: Wed, 22 Nov 2023 18:05:33 +0000
Subject: [PATCH 110/221] limit fused kernel num args. (#113131)

Fixes #97361

When fused kernel more than 1024 parameters, it should throw error from ctypes.
Limit args number is should be a mechanism to protect stack memory. As we known, CPP is passing args via stack memory, and stack memory has size limitation.

Code change:

1. cpp backend will check the fused nodes' args number, if it is reach the limitation. It will status flush status to ready.
2. scheduler will check `ready_to_flush` API and help backend flush codegen.
3. Add `ready_to_flush` API to `BaseScheduling`, Triton backend will return False due to not support it yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113131
Approved by: https://github.com/jgong5, https://github.com/mlazos
---
 test/inductor/test_torchinductor.py | 18 ++++++++++++++++++
 torch/_inductor/codegen/cpp.py      | 25 +++++++++++++++++++++++++
 torch/_inductor/codegen/triton.py   |  3 +++
 torch/_inductor/scheduler.py        | 12 ++++++++++++
 4 files changed, 58 insertions(+)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 42480e275a08..56e4db86b249 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7713,6 +7713,24 @@ def fn(x, y):
         b = torch.randn(65, 2**24, device=self.device)
         fn(a, b)
 
+    def test_fuse_large_params(self):
+        def pt2_optimizer_step(optimizer):
+            @torch.compile()
+            def f():
+                optimizer.step()
+
+            f()
+
+        params = [
+            torch.rand(10, 10, dtype=torch.float32, device=self.device)
+            for _ in range(194)
+        ]
+        for p in params:
+            p.grad = torch.rand_like(p)
+
+        o = torch.optim.AdamW(params)
+        pt2_optimizer_step(o)
+
     def test_adaptive_avg_pool1d_argmax(self):
         # https://github.com/pytorch/pytorch/issues/113013
         def fn(x):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index b64a67bd86a2..525bec5374a6 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -2919,9 +2919,18 @@ def codegen_loops(self, code, worksharing):
 
 
 class CppScheduling(BaseScheduling):
+    # ctypes limits the number of args to 1024, refer to:
+    # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
+    # We set a conservative threshold here.
+    MAX_FUSED_KERNEL_ARGS_NUM = 500
+
     def __init__(self, scheduler):
         self.scheduler = scheduler
         self.get_kernel_group()
+        self._ready_to_flush = False
+
+    def _set_flush_status(self, status: bool):
+        self._ready_to_flush = status
 
     def group_fn(self, sizes):
         return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
@@ -2968,12 +2977,23 @@ def codegen_nodes(self, nodes):
 
         kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
 
+        args_num = self._get_scheduled_num_args()
+        if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
+            self._set_flush_status(True)
+
+    def _get_scheduled_num_args(self):
+        return self.kernel_group.get_num_args()
+
+    def ready_to_flush(self):
+        return self._ready_to_flush
+
     def codegen_sync(self):
         pass
 
     def flush(self):
         self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
         self.get_kernel_group()
+        self._set_flush_status(False)
 
 
 class KernelGroup:
@@ -2995,6 +3015,11 @@ def finalize_kernel(self, new_kernel, nodes):
         ws = self.ws
         new_kernel.codegen_loops(code, ws)
 
+    def get_num_args(self):
+        arg_defs, call_args, arg_types = self.args.cpp_argdefs()
+        args_num = len(arg_defs)
+        return args_num
+
     def codegen_define_and_call(self, wrapper):
         self.stack.close()
         if not self.scheduled_nodes:
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index d178edae9813..0f08f728330f 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -2859,6 +2859,9 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
     def flush(self):
         pass
 
+    def ready_to_flush(self) -> bool:
+        return False
+
     def benchmark_fused_nodes(self, nodes):
         _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
         node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 7b8755a8699b..4d1a48f73ffc 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2194,6 +2194,11 @@ def codegen(self):
 
             self.available_buffer_names.update(node.get_names())
 
+            if not isinstance(node, NopKernelSchedulerNode):
+                device = node.get_device()
+                if self.get_backend(device).ready_to_flush():
+                    self.flush()
+
         self.flush()
 
     def is_unaligned_buffer(self, buf_name):
@@ -2250,6 +2255,13 @@ def codegen_sync(self):
         """
         raise NotImplementedError()
 
+    def ready_to_flush(self) -> bool:
+        """
+        Check whether the backend is requesting the scheduler to flush the generated kernel.
+        If not supported, please return False.
+        """
+        return False
+
     def flush(self):
         """
         Flush the generated kernel and python wrapper code to the source code file.

From 84909fef529e1709ea66e2735a63a254ddd2422f Mon Sep 17 00:00:00 2001
From: Tomasz Bohutyn 
Date: Wed, 22 Nov 2023 18:24:20 +0000
Subject: [PATCH 111/221] Add meta registration for aten.linear_backward
 (#114359)

Fixes #114358

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114359
Approved by: https://github.com/ezyang
---
 torch/_meta_registrations.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index dc8948bf0a42..4c54df447bdb 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5775,6 +5775,20 @@ def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has
     return grad_gates, grad_cx, grad_bias
 
 
+# From aten/src/ATen/native/mps/operations/Linear.mm
+@register_meta(aten.linear_backward.default)
+def linear_backward(input_, grad_output_, weight_, output_mask):
+    grad_input = None
+    grad_weight = None
+    grad_bias = None
+    if output_mask[0]:
+        grad_input = grad_output_.new_empty(input_.size())
+    if output_mask[1] or output_mask[2]:
+        grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
+        grad_bias = grad_output_.new_empty(grad_output_.size(-1))
+    return (grad_input, grad_weight, grad_bias)
+
+
 @register_meta(aten.pixel_shuffle.default)
 def meta_pixel_shuffle(self, upscale_factor):
     assert (

From 9d68cfee0dec46be44732667466303a36155f144 Mon Sep 17 00:00:00 2001
From: Jesse Cai 
Date: Tue, 21 Nov 2023 07:42:44 -0800
Subject: [PATCH 112/221] [sparse][semi-structured] Make cusparseLt handle +
 flag thread_local (#114273)

Summary:

As raised in this issue: https://github.com/pytorch/pytorch/issues/113776

cuSPARSELt does not support sharing handles across different threads.
Ideally we would use something like CuSparseHandlePool to do this, but
since cuSPARSELt handle creation is inconsitent with the rest of CUDA,
we have to do make these variables thread_local instead.

Test Plan:

`python test/test_sparse_semi_structured.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114273
Approved by: https://github.com/danthe3rd
---
 aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp
index 09462fd06ff0..6b4e143e6e8e 100644
--- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp
+++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp
@@ -18,8 +18,13 @@
 
 namespace at::native {
 
-cusparseLtHandle_t handle;
-bool handle_initialized = false;
+// Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp
+// which would handle this for us. However, the cuSPARSELt handle signature is different from that of cuSPARSE/cuBLAS,
+// so it's not possible to reuse the existing pooling mechanism. Instead we have to handle our handles ourselves, which
+// is why these variables are thread local. Once cuSPARSELt updates their handle signature to be consistent with the rest
+// of CUDA, we can switch to using DeviceThreadHandlePool.
+thread_local cusparseLtHandle_t handle;
+thread_local bool handle_initialized = false;
 
 at::Tensor _cslt_compress(const Tensor& sparse_input)
 {

From 07b6f377b401933e69a605037b8a5c2fba627601 Mon Sep 17 00:00:00 2001
From: Tianyu Liu 
Date: Tue, 21 Nov 2023 15:32:33 -0800
Subject: [PATCH 113/221] deprecate PairwiseParallel from test (#114314)

**Summary**
To solve issue #113706:
1. replace `PariwiseParallel` with `ColwiseParallel` and `RowwiseParallel`.
2. replace the input of ColwiseParallel from `make_input_replicate_1d` and `make_output_replicate_1d` to `input_layouts` and `output_layouts`.
3. deprecate the tests for `_parallelize_mlp` as it only supports `PariwiseParallel`.

**Test Plan**
`pytest pytorch/test/distributed/tensor/parallel/`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114314
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
---
 .../tensor/parallel/test_ddp_2d_parallel.py   | 14 +++--
 .../tensor/parallel/test_fsdp_2d_parallel.py  | 37 ++++++++----
 .../tensor/parallel/test_parallelize_api.py   | 57 ++-----------------
 .../tensor/parallel/test_tp_examples.py       | 11 ++--
 .../tensor/parallel/test_tp_random_state.py   | 14 +----
 5 files changed, 50 insertions(+), 83 deletions(-)

diff --git a/test/distributed/tensor/parallel/test_ddp_2d_parallel.py b/test/distributed/tensor/parallel/test_ddp_2d_parallel.py
index 4c78b8b2eba6..e68cff5e4023 100644
--- a/test/distributed/tensor/parallel/test_ddp_2d_parallel.py
+++ b/test/distributed/tensor/parallel/test_ddp_2d_parallel.py
@@ -3,7 +3,11 @@
 import torch
 import torch.distributed as dist
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
-from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
 
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -38,9 +42,11 @@ def init_model(device_type, model_parallel_size=TP_DEGREE):
 
     dp_pg = twod_mesh.get_dim_groups()[0]
 
-    twod_model = parallelize_module(
-        twod_model, twod_mesh, PairwiseParallel(), tp_mesh_dim=1
-    )
+    parallelize_plan = {
+        "net1": ColwiseParallel(),
+        "net2": RowwiseParallel(),
+    }
+    twod_model = parallelize_module(twod_model, twod_mesh, parallelize_plan, tp_mesh_dim=1)
     _pre_dp_module_transform(twod_model)
     # TODO: Add tests when using gradient_as_bucket_view and static_graph for DDP.
     twod_model = DDP(twod_model, process_group=dp_pg)
diff --git a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
index c0feafb3f2be..60dcb181715c 100644
--- a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
+++ b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
@@ -20,7 +20,6 @@
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
-    PairwiseParallel,
     parallelize_module,
     RowwiseParallel,
 )
@@ -105,8 +104,12 @@ def test_raise_invalid_tp_composition(self):
             mesh_2d = init_device_mesh(
                 self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp")
             )
+            parallelize_plan = {
+                "net1": ColwiseParallel(),
+                "net2": RowwiseParallel(),
+            }
             model_2d = parallelize_module(
-                SimpleModel().cuda(), mesh_2d["tp"], PairwiseParallel()
+                SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan
             )
 
     @with_comms
@@ -138,7 +141,11 @@ def _test_2d_e2e_training(
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]
-        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, PairwiseParallel())
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+        }
+        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
         model_2d = FSDP(
             model_2d,
             device_mesh=dp_mesh,
@@ -246,9 +253,11 @@ def test_2d_state_dict(self, is_even_sharded_model):
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]
-        model_2d = parallelize_module(
-            simple_model().cuda(), tp_mesh, PairwiseParallel()
-        )
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+        }
+        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
 
         FSDP.set_state_dict_type(
@@ -292,9 +301,11 @@ def test_2d_load_state_dict(self, is_even_sharded_model):
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]
-        model_2d = parallelize_module(
-            simple_model().cuda(), tp_mesh, PairwiseParallel()
-        )
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+        }
+        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
 
@@ -351,9 +362,11 @@ def test_2d_optim_state_dict(self, is_even_sharded_model):
         mesh_2d = init_device_mesh(
             self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
-        model_2d = parallelize_module(
-            simple_model().cuda(), mesh_2d["tp"], PairwiseParallel()
-        )
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+        }
+        model_2d = parallelize_module(simple_model().cuda(), mesh_2d["tp"], parallelize_plan)
         model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
         FSDP.set_state_dict_type(
             model_2d,
diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py
index 91fb2b50662b..44a8687ffb77 100644
--- a/test/distributed/tensor/parallel/test_parallelize_api.py
+++ b/test/distributed/tensor/parallel/test_parallelize_api.py
@@ -6,15 +6,10 @@
 from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
 from torch.distributed.tensor.parallel.api import (
     _parallelize_linear_like_module,
-    _parallelize_mlp,
     parallelize_module,
 )
 from torch.distributed.tensor.parallel.style import (
     ColwiseParallel,
-    make_input_replicate_1d,
-    make_output_replicate_1d,
-    PairwiseParallel,
-    ParallelStyle,
     PrepareModuleInput,
     PrepareModuleOutput,
     RowwiseParallel,
@@ -141,23 +136,6 @@ def _compare_module(
         dist_optim.step()
         self._compare_params(local_module, dist_module, rank0_only, rowwise)
 
-    @with_comms
-    def test_parallelize_mlp(self):
-        inp_size = [12, 10]
-        model = MLPModule(self.device_type)
-        model_tp = MLPModule(self.device_type)
-
-        # Ensure model are initialized the same way.
-        self.assertEqual(model.net1.weight, model_tp.net1.weight)
-        self.assertEqual(model.net1.bias, model_tp.net1.bias)
-        self.assertEqual(model.net2.weight, model_tp.net2.weight)
-        self.assertEqual(model.net2.bias, model_tp.net2.bias)
-
-        # Parallelize module.
-        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-        model_tp = _parallelize_mlp(model_tp, device_mesh, PairwiseParallel())
-        self._compare_module(model, model_tp, inp_size)
-
     @with_comms
     def test_parallelize_mlp_with_module_api(self):
         inp_size = [12, 10]
@@ -177,10 +155,10 @@ def test_parallelize_mlp_with_module_api(self):
             device_mesh,
             {
                 "net1": ColwiseParallel(
-                    make_input_replicate_1d, make_output_replicate_1d
+                    input_layouts=Replicate(), output_layouts=Replicate()
                 ),
                 "net2": ColwiseParallel(
-                    make_input_replicate_1d, make_output_replicate_1d
+                    input_layouts=Replicate(), output_layouts=Replicate()
                 ),
             },
         )
@@ -217,40 +195,15 @@ def test_parallelize_mlp_with_module_api_nested(self):
             device_mesh,
             {
                 "dummy_encoder.net1": ColwiseParallel(
-                    make_input_replicate_1d, make_output_replicate_1d
+                    input_layouts=Replicate(), output_layouts=Replicate()
                 ),
                 "dummy_encoder.net2": ColwiseParallel(
-                    make_input_replicate_1d, make_output_replicate_1d
+                    input_layouts=Replicate(), output_layouts=Replicate()
                 ),
             },
         )
         self._compare_module(model, model_tp, inp_size, rank0_only=False)
 
-    @with_comms
-    def test_parallelize_mlp_error(self):
-        class DummyParallel(ParallelStyle):
-            def __init__(self) -> None:
-                super().__init__(
-                    make_input_replicate_1d,
-                    make_output_replicate_1d,
-                    input_layouts=None,
-                    output_layouts=None,
-                    use_local_output=False,
-                )
-
-        model_tp = MLPModule(self.device_type)
-        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            "Only support PairwiseParallel for MLP parallelization.",
-        ):
-            _parallelize_mlp(model_tp, device_mesh, DummyParallel())
-
-        with self.assertRaisesRegex(
-            RuntimeError, "More than one nn.Linear needed for a MLP."
-        ):
-            _parallelize_mlp(torch.nn.Linear(10, 5), device_mesh, PairwiseParallel())
-
     @with_comms
     def test_linear_row_wise_parallel(self):
         # test RowwiseParallel
@@ -274,7 +227,7 @@ def test_linear_row_wise_parallel(self):
     def test_linear_col_wise_parallel(self):
         # test ColwiseParallel
         inp_size = [8, 10]
-        colwise = ColwiseParallel(make_input_replicate_1d, make_output_replicate_1d)
+        colwise = ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate())
 
         torch.manual_seed(5)
         model = torch.nn.Linear(10, 16, device=self.device_type)
diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py
index a15c818354e5..a37fec29574a 100644
--- a/test/distributed/tensor/parallel/test_tp_examples.py
+++ b/test/distributed/tensor/parallel/test_tp_examples.py
@@ -10,7 +10,6 @@
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
-    PairwiseParallel,
     parallelize_module,
     RowwiseParallel,
 )
@@ -62,7 +61,7 @@ def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=Fal
             self.device_type,
             torch.arange(0, NUM_DEVICES),
         )
-        parallel_style = {
+        parallelize_plan = {
             "net1": ColwiseParallel(input_layouts=Shard(0))
             if is_seq_parallel
             else ColwiseParallel(),
@@ -70,7 +69,7 @@ def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=Fal
             if is_seq_parallel
             else RowwiseParallel(),
         }
-        model_tp = parallelize_module(model_tp, device_mesh, parallel_style)
+        model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)
         if recompute_activation:
             model_tp = input_reshard(
                 checkpoint_wrapper(
@@ -124,7 +123,11 @@ def _test_mlp_inference(self, device_mesh):
         self._check_module(model, model_tp)
 
         # Shard module and initialize optimizer.
-        model_tp = parallelize_module(model_tp, device_mesh, PairwiseParallel())
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+        }
+        model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)
 
         output = model(inp)
         output_tp = model_tp(inp)
diff --git a/test/distributed/tensor/parallel/test_tp_random_state.py b/test/distributed/tensor/parallel/test_tp_random_state.py
index 75444c59afcc..812bcfb5a969 100644
--- a/test/distributed/tensor/parallel/test_tp_random_state.py
+++ b/test/distributed/tensor/parallel/test_tp_random_state.py
@@ -5,11 +5,7 @@
 
 from torch.distributed._tensor import DeviceMesh
 from torch.distributed.tensor.parallel.api import parallelize_module
-from torch.distributed.tensor.parallel.style import (
-    ColwiseParallel,
-    make_input_replicate_1d,
-    make_output_replicate_1d,
-)
+from torch.distributed.tensor.parallel.style import ColwiseParallel
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -55,12 +51,8 @@ def test_model_init(self):
                 model,
                 device_mesh,
                 {
-                    "net1": ColwiseParallel(
-                        make_input_replicate_1d, make_output_replicate_1d
-                    ),
-                    "net2": ColwiseParallel(
-                        make_input_replicate_1d, make_output_replicate_1d
-                    ),
+                    "net1": ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
+                    "net2": ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
                 },
             )
             # in most cases, the random number generator states is set by data loader

From f882c175d8e9731238c3f29ca10821f2fe9f0797 Mon Sep 17 00:00:00 2001
From: drisspg 
Date: Wed, 22 Nov 2023 20:02:47 +0000
Subject: [PATCH 114/221] Require less alignment for masking (#114173)

# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand, once on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
---
 .../ATen/native/transformers/attention.cpp    | 45 +++++++++----------
 test/test_transformers.py                     | 18 ++++++++
 torch/_meta_registrations.py                  | 16 ++++---
 3 files changed, 48 insertions(+), 31 deletions(-)

diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 63b4a52d8c07..ffecaf994520 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -522,9 +522,14 @@ c10::optional convert_boolean_attn_mask(const c10::optional& att
 // We apply this function to the top level SDPA so that
 // if padding is done it will be tracked for backward automatically
 
-template 
-bool is_aligned(const SymInt& size){
-  return size % alignment == 0;
+template
+bool aligned_tensor(const at::Tensor& tensor){
+  for(const auto i : c10::irange(tensor.dim() - 1)){
+    if(tensor.sym_stride(i) % alignment != 0){
+      return false;
+    }
+  }
+  return tensor.sym_stride(-1) == 1;
 }
 
 template 
@@ -540,31 +545,23 @@ at::Tensor preprocess_mask(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value) {
-  constexpr int mem_eff_alignment = 16;
-  // Expand to 4d case
-  at::Tensor attn_mask = mask.expand_symint(
+  constexpr int mem_eff_alignment = 8;
+  at::Tensor result_mask = mask;
+  if (!aligned_tensor(mask)) {
+    TORCH_WARN_ONCE(
+        "Memory Efficient Attention requires the attn_mask to be aligned to, ",
+        mem_eff_alignment,
+        " elements. "
+        "Prior to calling SDPA, pad the last dimension of the attn_mask "
+        "to be a multiple of ", mem_eff_alignment,
+        " and then slice the attn_mask to the original size.");
+    result_mask = pad_bias(mask);
+  }
+  return result_mask.expand_symint(
       {query.sym_size(0),
        query.sym_size(1),
        query.sym_size(2),
        key.sym_size(2)});
-
-  bool aligned_last_dim = is_aligned(attn_mask.sym_size(-1));
-  // Apply pad_bias and store the result in attn_mask
-  if (!aligned_last_dim) {
-    return pad_bias(attn_mask);
-  }
-  // Check and make the tensor contiguous if needed
-  auto needs_contig = [](const c10::SymInt& stride) {
-    return (stride % 16 != 0) || (stride == 0);
-  };
-  if (needs_contig(attn_mask.sym_stride(0)) ||
-      needs_contig(attn_mask.sym_stride(1)) ||
-      needs_contig(attn_mask.sym_stride(2)) ||
-      needs_contig(attn_mask.sym_stride(3))) {
-    return attn_mask.contiguous();
-  }
-
-  return attn_mask;
 }
 // FlashAttentionV2 requires that head dimension be a multiple of 8
 // This was previously done within the kernel, however
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 5785fedca0e1..81e574b75655 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1898,6 +1898,24 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
             out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
+    def test_mem_eff_attention_non_contig_mask_bug(self, device):
+        dtype = torch.float32
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim = 1, 16, 128
+        seq_len_q, seq_len_kv = 1, 16
+        query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2)
+        kv_shape = (batch, seq_len_kv, head_dim)
+        key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1)
+        key = key.expand(-1, num_heads, -1, -1)
+        value = value.expand(-1, num_heads, -1, -1)
+        mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool)
+        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+            out = F.scaled_dot_product_attention(query, key, value, mask)
+            out_no_mask = F.scaled_dot_product_attention(query, key, value, None)
+        max_diff = (out - out_no_mask).abs().mean()
+        assert max_diff.item() < 1e-9
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 4c54df447bdb..bb10a34c4c06 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5221,12 +5221,14 @@ def meta__scaled_dot_product_efficient_backward(
     )
     grad_bias = None
     if attn_bias is not None and grad_input_mask[3]:
-        grad_bias = torch.empty_strided(
-            attn_bias.size(),
-            attn_bias.stride(),
-            dtype=attn_bias.dtype,
-            device=attn_bias.device,
+        lastDim = attn_bias.size(-1)
+        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
+        new_sizes = list(attn_bias.size())
+        new_sizes[-1] = lastDimAligned
+        grad_bias = torch.empty(
+            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
         )
+        grad_bias = grad_bias[..., :lastDim]
 
     return grad_q, grad_k, grad_v, grad_bias
 
@@ -5303,12 +5305,12 @@ def meta__efficient_attention_backward(
     grad_value = torch.empty_like(value)
 
     if bias is not None:
-        assert bias is not None
         lastDim = bias.size(-1)
-        lastDimAligned = 16 * ((lastDim + 15) // 16)
+        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
         new_sizes = list(bias.size())
         new_sizes[-1] = lastDimAligned
         grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
+        grad_bias = grad_bias[..., :lastDim]
     else:
         grad_bias = torch.empty((), device=query.device)
 

From d416e5b34f5bfcb0d0d3cdaa53e59eacc5986bc5 Mon Sep 17 00:00:00 2001
From: Wanchao Liang 
Date: Tue, 21 Nov 2023 20:31:01 -0800
Subject: [PATCH 115/221] [torchrun] fix incorrect warning for non static
 backend (#114335)

This PR fixes a incorrect warning for non static rdzv backend, the
warning should only be thrown when the rdzv endpoint not specified.

error repro from @stas00

```
$ cat test.py
import torch

$ python -u -m torch.distributed.run --nproc_per_node=1 --rdzv_endpoint localhost:6000  --rdzv_backend c10d test.py
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114335
Approved by: https://github.com/H-Huang
---
 torch/distributed/run.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index 72b0df316fd9..507be2daada1 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -694,7 +694,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
     assert 0 < min_nodes <= max_nodes
     assert args.max_restarts >= 0
 
-    if hasattr(args, "master_addr") and args.rdzv_backend != "static":
+    if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint:
         log.warning(
             "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
             "is not specified."

From 1b66701379050ae5e11a014dd29bf10ca1849bb3 Mon Sep 17 00:00:00 2001
From: Eli Uriegas 
Date: Wed, 22 Nov 2023 13:58:19 -0600
Subject: [PATCH 116/221] ci: Bump TorchAudio, less third_party deps (#114393)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Installing the current pinned version of TorchAudio can be problematic because it
expects to be able to download a file from sourceware.org (see [ref](https://github.com/pytorch/audio/blob/a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602/third_party/bzip2/CMakeLists.txt#L14)) and that does
not have any guarantees of uptime.

This bumps this commit to the latest v2.1.1 commit (https://github.com/pytorch/audio/releases/tag/v2.1.1) which should have less third_party dependencies and thus be less flaky

Should help with errors like: logs link: https://github.com/pytorch/pytorch/actions/runs/6959510046/job/18942955523#step:15:592 ``` 5h+ pip install --progress-bar off --no-use-pep517 --user git+https://github.com/pytorch/audio.git@a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602 Collecting git+https://github.com/pytorch/audio.git@a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602 Cloning https://github.com/pytorch/audio.git (to revision a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602) to /tmp/pip-req-build-6b5hkzmq Running command git clone --filter=blob:none --quiet https://github.com/pytorch/audio.git /tmp/pip-req-build-6b5hkzmq Running command git rev-parse -q --verify 'sha^a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602' Running command git fetch -q https://github.com/pytorch/audio.git a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602 Running command git checkout -q a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602 Resolved https://github.com/pytorch/audio.git to commit a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602 Running command git submodule update --init --recursive -q Preparing metadata (setup.py) ... 25l- error error: subprocess-exited-with-error × python setup.py egg_info did not run successfully. │ exit code: 1 ╰─> [60 lines of output] Traceback (most recent call last): File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 1348, in do_open h.request(req.get_method(), req.selector, req.data, headers, File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 1283, in request self._send_request(method, url, body, headers, encode_chunked) File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 1329, in _send_request self.endheaders(body, encode_chunked=encode_chunked) File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 1278, in endheaders self._send_output(message_body, encode_chunked=encode_chunked) File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 1038, in _send_output self.send(msg) File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 976, in send self.connect() File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 1448, in connect super().connect() File "/opt/conda/envs/py_3.10/lib/python3.10/http/client.py", line 942, in connect self.sock = self._create_connection( File "/opt/conda/envs/py_3.10/lib/python3.10/socket.py", line 845, in create_connection raise err File "/opt/conda/envs/py_3.10/lib/python3.10/socket.py", line 833, in create_connection sock.connect(sa) OSError: [Errno 99] Cannot assign requested address During handling of the above exception, another exception occurred: Traceback (most recent call last): File "", line 2, in File "", line 34, in File "/tmp/pip-req-build-6b5hkzmq/setup.py", line 184, in _main() File "/tmp/pip-req-build-6b5hkzmq/setup.py", line 145, in _main _fetch_third_party_libraries() File "/tmp/pip-req-build-6b5hkzmq/setup.py", line 129, in _fetch_third_party_libraries _fetch_archives(_parse_sources()) File "/tmp/pip-req-build-6b5hkzmq/setup.py", line 123, in _fetch_archives torch.hub.download_url_to_file(url, dest, progress=False) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/hub.py", line 620, in download_url_to_file u = urlopen(req) File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 216, in urlopen return opener.open(url, data, timeout) File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 519, in open response = self._open(req, data) File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 536, in _open result = self._call_chain(self.handle_open, protocol, protocol + File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 496, in _call_chain result = func(*args) File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 1391, in https_open return self.do_open(http.client.HTTPSConnection, req, File "/opt/conda/envs/py_3.10/lib/python3.10/urllib/request.py", line 1351, in do_open raise URLError(err) urllib.error.URLError: -- Git branch: HEAD -- Git SHA: a8f4e97bd5356a7a77510cdf6a3a62e25a5dc[602](https://github.com/pytorch/pytorch/actions/runs/6959510046/job/18942955523#step:15:603) -- Git tag: None -- PyTorch dependency: torch -- Building version 2.0.0a0+a8f4e97 --- Initializing submodules --- Initialized submodule --- Fetching v1.2.12.tar.gz --- Fetching bzip2-1.0.8.tar.gz [end of output] note: This error originates from a subprocess, and is likely not a problem with pip. error: metadata-generation-failed ```
Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/114393 Approved by: https://github.com/atalman, https://github.com/kit1980, https://github.com/huydhn --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 975ee9d67de4..e8dd2d1c99fd 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -a8f4e97bd5356a7a77510cdf6a3a62e25a5dc602 \ No newline at end of file +db624844f5c95bb7618fe5a5f532bf9b68efeb45 From e7726b596e0a9bfd07eb6cf8744e2cdb78b5ffce Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 22 Nov 2023 10:21:43 -0800 Subject: [PATCH 117/221] [FSDP] Added DDP parity test for CPU training (#114372) This is a follow-up to https://github.com/pytorch/pytorch/pull/112145/ to include a numerical parity test with DDP for CPU training. ``` python -m pytest test/distributed/fsdp/test_fsdp_misc.py -k test_fsdp_cpu_training -s ``` We should follow-up on https://github.com/pytorch/pytorch/pull/112145/files#r1375102283 at some point too. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114372 Approved by: https://github.com/XilunWu --- test/distributed/fsdp/test_fsdp_misc.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index b409ec5bb3ff..d2589cc9d7e4 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -29,6 +29,7 @@ ) from torch.distributed.optim import _apply_optimizer_in_backward from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( _assert_module_states, @@ -505,7 +506,6 @@ def test_fsdp_optimizer_overlap(self): @skip_if_lt_x_gpu(2) def test_fsdp_cpu_training(self): """Tests FSDP training on CPU.""" - torch.manual_seed(0) gloo_pg = dist.new_group(backend="gloo") for ss in [ ShardingStrategy.NO_SHARD, @@ -514,15 +514,28 @@ def test_fsdp_cpu_training(self): ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2, ]: + torch.manual_seed(42) model = MyModel() - fsdp = FSDP( + ref_model = DDP(deepcopy(model), process_group=gloo_pg) + model = FSDP( model, auto_wrap_policy=always_wrap_policy, process_group=gloo_pg, device_id=torch.device("cpu"), ) + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + torch.manual_seed(42 + self.rank) inp = torch.randn(2, 2) - fsdp(inp, inp).sum().backward() + for _ in range(10): + losses = [] + for _model, _optim in ((ref_model, ref_optim), (model, optim)): + loss = _model(inp, inp).sum() + losses.append(loss) + loss.backward() + _optim.step() + _optim.zero_grad() + self.assertEqual(losses[0], losses[1]) @skip_if_lt_x_gpu(2) def test_fsdp_cpu_init_stays_on_cpu(self): From 88a8a0daa4447e19e8355fa93ca9b3d8b3347ce8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 22 Nov 2023 21:49:31 +0000 Subject: [PATCH 118/221] Revert "Require less alignment for masking (#114173)" This reverts commit f882c175d8e9731238c3f29ca10821f2fe9f0797. Reverted https://github.com/pytorch/pytorch/pull/114173 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing some inductor tests https://hud.pytorch.org/pytorch/pytorch/commit/f882c175d8e9731238c3f29ca10821f2fe9f0797 ([comment](https://github.com/pytorch/pytorch/pull/114173#issuecomment-1823552362)) --- .../ATen/native/transformers/attention.cpp | 45 ++++++++++--------- test/test_transformers.py | 18 -------- torch/_meta_registrations.py | 16 +++---- 3 files changed, 31 insertions(+), 48 deletions(-) diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index ffecaf994520..63b4a52d8c07 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -522,14 +522,9 @@ c10::optional convert_boolean_attn_mask(const c10::optional& att // We apply this function to the top level SDPA so that // if padding is done it will be tracked for backward automatically -template -bool aligned_tensor(const at::Tensor& tensor){ - for(const auto i : c10::irange(tensor.dim() - 1)){ - if(tensor.sym_stride(i) % alignment != 0){ - return false; - } - } - return tensor.sym_stride(-1) == 1; +template +bool is_aligned(const SymInt& size){ + return size % alignment == 0; } template @@ -545,23 +540,31 @@ at::Tensor preprocess_mask( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value) { - constexpr int mem_eff_alignment = 8; - at::Tensor result_mask = mask; - if (!aligned_tensor(mask)) { - TORCH_WARN_ONCE( - "Memory Efficient Attention requires the attn_mask to be aligned to, ", - mem_eff_alignment, - " elements. " - "Prior to calling SDPA, pad the last dimension of the attn_mask " - "to be a multiple of ", mem_eff_alignment, - " and then slice the attn_mask to the original size."); - result_mask = pad_bias(mask); - } - return result_mask.expand_symint( + constexpr int mem_eff_alignment = 16; + // Expand to 4d case + at::Tensor attn_mask = mask.expand_symint( {query.sym_size(0), query.sym_size(1), query.sym_size(2), key.sym_size(2)}); + + bool aligned_last_dim = is_aligned(attn_mask.sym_size(-1)); + // Apply pad_bias and store the result in attn_mask + if (!aligned_last_dim) { + return pad_bias(attn_mask); + } + // Check and make the tensor contiguous if needed + auto needs_contig = [](const c10::SymInt& stride) { + return (stride % 16 != 0) || (stride == 0); + }; + if (needs_contig(attn_mask.sym_stride(0)) || + needs_contig(attn_mask.sym_stride(1)) || + needs_contig(attn_mask.sym_stride(2)) || + needs_contig(attn_mask.sym_stride(3))) { + return attn_mask.contiguous(); + } + + return attn_mask; } // FlashAttentionV2 requires that head dimension be a multiple of 8 // This was previously done within the kernel, however diff --git a/test/test_transformers.py b/test/test_transformers.py index 81e574b75655..5785fedca0e1 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1898,24 +1898,6 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype): out = F.scaled_dot_product_attention(query, key, value, mask) out.sum().backward() - @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") - def test_mem_eff_attention_non_contig_mask_bug(self, device): - dtype = torch.float32 - make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) - batch, num_heads, head_dim = 1, 16, 128 - seq_len_q, seq_len_kv = 1, 16 - query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2) - kv_shape = (batch, seq_len_kv, head_dim) - key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1) - key = key.expand(-1, num_heads, -1, -1) - value = value.expand(-1, num_heads, -1, -1) - mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool) - with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): - out = F.scaled_dot_product_attention(query, key, value, mask) - out_no_mask = F.scaled_dot_product_attention(query, key, value, None) - max_diff = (out - out_no_mask).abs().mean() - assert max_diff.item() < 1e-9 - @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index bb10a34c4c06..4c54df447bdb 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5221,14 +5221,12 @@ def meta__scaled_dot_product_efficient_backward( ) grad_bias = None if attn_bias is not None and grad_input_mask[3]: - lastDim = attn_bias.size(-1) - lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 - new_sizes = list(attn_bias.size()) - new_sizes[-1] = lastDimAligned - grad_bias = torch.empty( - new_sizes, dtype=attn_bias.dtype, device=attn_bias.device + grad_bias = torch.empty_strided( + attn_bias.size(), + attn_bias.stride(), + dtype=attn_bias.dtype, + device=attn_bias.device, ) - grad_bias = grad_bias[..., :lastDim] return grad_q, grad_k, grad_v, grad_bias @@ -5305,12 +5303,12 @@ def meta__efficient_attention_backward( grad_value = torch.empty_like(value) if bias is not None: + assert bias is not None lastDim = bias.size(-1) - lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 + lastDimAligned = 16 * ((lastDim + 15) // 16) new_sizes = list(bias.size()) new_sizes[-1] = lastDimAligned grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device) - grad_bias = grad_bias[..., :lastDim] else: grad_bias = torch.empty((), device=query.device) From ea7d70aecc8336530411194bbd4076d3986b8272 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Wed, 22 Nov 2023 22:09:57 +0000 Subject: [PATCH 119/221] [BE]: ruff FURB136: replace ternary with min/max (preview) (#114382) Replaces ternary if else statements with simple min max when appropriate. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114382 Approved by: https://github.com/albanD --- tools/testing/test_selections.py | 2 +- torch/autograd/profiler.py | 2 +- torch/nn/modules/_functions.py | 2 +- torch/utils/benchmark/utils/sparse_fuzzer.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 340db6f499b8..19c9a312649a 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -37,7 +37,7 @@ count += 1 assert count > 0 # there must be at least 1 GPU # Limiting to 8 GPUs(PROCS) - NUM_PROCS = 8 if count > 8 else count + NUM_PROCS = min(count, 8) except subprocess.CalledProcessError as e: # The safe default for ROCm GHA runners is to run tests serially. NUM_PROCS = 1 diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index acae6ef5b337..9c5f8bd15169 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -471,7 +471,7 @@ def _privateuse1_memory_usage(mem_record): device_index=kineto_event.device_index(), flops=kineto_event.flops(), ) - max_evt_id = fe.id if fe.id > max_evt_id else max_evt_id + max_evt_id = max(max_evt_id, fe.id) if fe.device_type == DeviceType.CPU and not fe.is_async: if self.use_device: privateuse1_time = kineto_event.privateuse1_elapsed_us() diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 770609a9ec9c..669448ce4fda 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -212,7 +212,7 @@ def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): torch.pow(input, 2, out=input_square) pre_pad = int((ctx.size - 1) / 2 + 1) - pre_pad_crop = channels if pre_pad > channels else pre_pad + pre_pad_crop = min(pre_pad, channels) scale_first = ctx.scale.select(1, 0) scale_first.zero_() diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index 1b2e884ce956..eac6a6baf910 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -97,7 +97,7 @@ def _make_tensor(self, params, state): is_coalesced = params['coalesced'] sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size) - sparse_dim = len(size) if len(size) < sparse_dim else sparse_dim + sparse_dim = min(sparse_dim, len(size)) tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced) if self._cuda: From 2f536ff92c0c8b9ef07f2e75b08ecdffc78ad56d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 21 Nov 2023 21:06:48 +0000 Subject: [PATCH 120/221] Refactor values kwarg in foreach tests (#112781) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112781 Approved by: https://github.com/lezcano ghstack dependencies: #112778 --- test/test_foreach.py | 66 +++++++++---------- .../_internal/common_methods_invocations.py | 7 +- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 1ac70fdc8cc6..aeed73c5218c 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -30,12 +30,14 @@ class RegularFuncWrapper: def __init__(self, func): self.func = func - def __call__(self, inputs, values=None, **kwargs): - if values is not None: + def __call__(self, inputs, scalars=None, **kwargs): + if scalars is not None: assert len(inputs) == 3 - if isinstance(values, Number): - values = [values for _ in range(len(inputs[0]))] - return [self.func(*i, value=values[idx], **kwargs) for idx, i in enumerate(zip(*inputs))] + # We need to distribute each scalar to the regular func and it needs + # special consideration as it is a keyword only argument to the + # regular func. (Strangely, it is not a keyword only argument to the + # foreach func) + return [self.func(*i, value=scalars[idx], **kwargs) for idx, i in enumerate(zip(*inputs))] if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)): # binary op with tensorlist and scalar. inputs[1] = [inputs[1] for _ in range(len(inputs[0]))] @@ -149,14 +151,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): func, ref, _, _ = self._get_funcs(op) for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous): ref_kwargs = sample.kwargs - kwargs = ref_kwargs.copy() # div promotes ints to floats, so we cannot go on the fastpath there div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div' expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath) - if op in foreach_pointwise_op_db: - values = kwargs.pop("values", None) - if values is not None: - sample.args = (*sample.args, values) ref_input, ctxmgr = sample.input, nullcontext() if inplace: with torch.no_grad(): @@ -164,7 +161,7 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input) try: with ctxmgr: - actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **kwargs) + actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **sample.kwargs) except Exception as e: with ( self.assertRaisesRegex(type(e), re.escape(str(e))) @@ -256,40 +253,44 @@ def test_pointwise_op_with_tensor_of_scalarlist_overload(self, device, dtype, op assert isinstance(sample.args, tuple) assert len(sample.args) == 2 inputs = [sample.input, *sample.args] - kwargs = sample.kwargs + kwargs = sample.kwargs.copy() disable_fastpath = sample.disable_fastpath and is_fastpath wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op) - values = kwargs.pop("values", None) + scalars = kwargs.pop("scalars", None) - if is_fastpath and isinstance(values, list): + if is_fastpath and scalars: sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t) inputs = [sample.input, *sample.args] - tensor_values = torch.tensor(values) + tensor_values = torch.tensor(scalars) # 1D Tensor of scalars for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)): self._pointwise_test( op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace, - values=tensor_values) + scalars=tensor_values, **kwargs) self._pointwise_test( op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace, - values=tensor_values[0], + scalars=tensor_values[0], custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.", + **kwargs, ) if self.is_cuda: self._pointwise_test( op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace, - values=tensor_values.cuda(), + scalars=tensor_values.cuda(), custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.", + **kwargs, ) self._pointwise_test( op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace, - values=tensor_values[:2], - custom_values_err=f"Expected length of scalars to match input of length {len(values)} but got 2 instead.", + scalars=tensor_values[:2], + custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.", + **kwargs, ) self._pointwise_test( op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace, - values=torch.tensor([[0, 1], [2, 3]])[:, 1], + scalars=torch.tensor([[0, 1], [2, 3]])[:, 1], custom_values_err="Expected scalars to be contiguous.", + **kwargs, ) # Tests of implicit broadcasting @@ -307,41 +308,42 @@ def test_pointwise_op_with_tensor_of_scalarlist_overload(self, device, dtype, op ] self._pointwise_test( wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False, - values=values) + scalars=scalars, **kwargs) self._pointwise_test( inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath, - is_inplace=True, values=values) + is_inplace=True, scalars=scalars, **kwargs) def _pointwise_test( self, op, ref, inputs, is_fastpath, is_inplace, *, - values=None, custom_values_err=None, + scalars=None, custom_values_err=None, **kwargs ): - kwargs = {} ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs try: with (InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()): actual = op(inputs, self.is_cuda, is_fastpath, **kwargs) except RuntimeError as e: with self.assertRaisesRegex(type(e), re.escape(str(e))): - ref(ref_inputs) + ref(ref_inputs, **kwargs) else: - expected = ref(ref_inputs) + expected = ref(ref_inputs, **kwargs) self.assertEqual(expected, actual) - if values is not None: + if scalars is not None: + kwargs = kwargs.copy() + kwargs["scalars"] = scalars try: - actual = op(inputs + [values], self.is_cuda, is_fastpath, **kwargs) + actual = op(inputs, self.is_cuda, is_fastpath, **kwargs) except RuntimeError as e: # Match with error messages from regular non-foreach reference if no # custom error message was provided. if custom_values_err is None: with self.assertRaisesRegex(type(e), re.escape(str(e))): - ref(ref_inputs, values=values) + ref(ref_inputs, **kwargs) else: self.assertEqual(re.escape(str(e)), re.escape(custom_values_err)) else: - expected = ref(ref_inputs, values=values) + expected = ref(ref_inputs, **kwargs) self.assertEqual(expected, actual) @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) @@ -692,8 +694,6 @@ def test_outplace_with_invalid_grads(self, device, dtype, op): func, *_ = self._get_funcs(op) sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0] self.assertTrue(all(t.requires_grad for t in sample.input)) - if func.func in foreach_pointwise_op_db: - sample.kwargs.pop("values", None) (out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs) out1.backward(torch.ones_like(out1)) self.assertIsNotNone(sample.input[0].grad) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ff926426b8ca..e6cd427a99e2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8936,7 +8936,8 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) for _ in range(2) ] - kwargs["values"] = None + if "scalars" in kwargs: + del kwargs["scalars"] kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)) yield ForeachSampleInput(input, *args, **kwargs) @@ -8959,8 +8960,10 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): kwargs = {} if rightmost_arg_type == ForeachRightmostArgType.TensorList: args.append(rightmost_arg) + elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]: + kwargs["scalars"] = rightmost_arg else: - kwargs["values"] = rightmost_arg + kwargs["value"] = rightmost_arg kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)) assert len(args) == 2, f"{len(args)=}" sample = ForeachSampleInput(input, *args, **kwargs) From 6f3cd046ab99a293cf2560fbf249ab0d52e9b121 Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Wed, 22 Nov 2023 10:53:22 -0800 Subject: [PATCH 121/221] [BE] remove skipIfDynamo for some module hook tests (#114387) As titled. Test Plan: exiting tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114387 Approved by: https://github.com/ezyang --- test/nn/test_module_hooks.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 4102a26897cf..34958e6e5fd5 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -184,7 +184,6 @@ def __exit__(self, *args, **kwargs): class TestModuleHooks(TestCase): - @skipIfTorchDynamo("Dynamo does not yet capture hooks") @parametrize_test("named_tuple", (True, False)) def test_forward_hooks(self, named_tuple): fired_hooks: List[int] = [] @@ -207,7 +206,6 @@ def test_forward_hooks(self, named_tuple): model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") @parametrize_test("named_tuple", (True, False)) def test_forward_pre_hooks(self, named_tuple): fired_hooks: List[int] = [] @@ -234,7 +232,6 @@ def test_forward_pre_hooks(self, named_tuple): model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") @parametrize_test("named_tuple", (True, False)) def test_full_backward_hooks(self, named_tuple): fired_hooks: List[int] = [] @@ -257,7 +254,6 @@ def test_full_backward_hooks(self, named_tuple): model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") @parametrize_test("named_tuple", (True, False)) def test_full_backward_pre_hooks(self, named_tuple): fired_hooks: List[int] = [] @@ -297,7 +293,6 @@ def fn(_unused_module, grad_output): out.sum().backward() self.assertEqual(a.grad, torch.zeros_like(a)) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") @parametrize_test("named_tuple", (True, False)) def test_mixed_hooks(self, named_tuple): fired_hooks: List[int] = [] @@ -325,7 +320,6 @@ def test_mixed_hooks(self, named_tuple): model(x)[0].sum().backward() self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3]) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") def test_kwarg_hooks(self): # 1. test forward pre hook fired_hooks: List[int] = [] @@ -382,7 +376,6 @@ def test_kwarg_hooks(self): self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") def test_remove_kwarg_hooks(self): # test forward pre and forward hooks fired_hooks: List[int] = [] @@ -428,7 +421,6 @@ def test_remove_kwarg_hooks(self): forward_pre_hook_handle.id in model._forward_pre_hooks_with_kwargs ) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") def test_always_called_forward_hooks(self): x: torch.Tensor = torch.ones(10, 10) model = FailsInForwardModel() @@ -516,7 +508,6 @@ def throw_hook(m, i, o): model(x) self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2]) - @skipIfTorchDynamo("Dynamo does not yet capture hooks") def test_bw_hook_warning_for_non_tensor_or_tuple(self): # Test to verify that backward hook raises warning # if result is not a Tensor or tuple of Tensors. @@ -871,7 +862,6 @@ def bw_fail2(self, grad_input, grad_output): with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): module(input).sum().backward() - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/847") def test_module_backward_global_hook_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) @@ -911,7 +901,6 @@ def forward_hook(m, input, output): expected_grad = -sig_x * (1 - sig_x) * 2 * mask self.assertEqual(input.grad, expected_grad) - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_forward_preforward_hook_removable(self): """ This test is to test when multiple pre-forward hook functions can be @@ -947,7 +936,6 @@ def removable_hook_2(m, input): self.assertEqual(len(handle.hooks_dict_ref()), 0) self.assertEqual(len(handle_2.hooks_dict_ref()), 0) - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_forward_forward_hook_removable(self): """ This test is to test when multiple forward hook functions can be registered @@ -1131,7 +1119,6 @@ def bw_pre_hook(inc, h_module, grad_output): test_fwd.remove() test_bwd.remove() - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hooks(self): self._test_hooks("register_backward_hook") self._test_hooks("register_full_backward_hook") @@ -1149,7 +1136,6 @@ def hook(module, grad_inputs, grad_outputs): output = bn(torch.randn(5, 5, requires_grad=True)) output.sum().backward() - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_backward_hooks_interaction(self): # Test to make sure that the grad_outputs # updated by full_backward_pre_hook are received by @@ -1224,7 +1210,6 @@ def forward(self, arg1, arg2, arg3): mod.register_full_backward_hook(lambda mod, gI, gO: None) mod(inp, inp.detach(), inp) - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hook_no_requires_grad(self): mod = nn.Linear(2, 3) @@ -1418,7 +1403,6 @@ def bw_hook(module, grad_input, grad_output): with module.register_full_backward_hook(bw_hook): module(inp1, inp2).sum().backward() - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hook_backward_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) @@ -1436,7 +1420,6 @@ def bw_hook(module, grad_input, grad_output): expected_grad = sig_x * (1 - sig_x) * 2 self.assertEqual(input.grad, expected_grad) - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hook_forward_preforward_writable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) From aca6446a6e0b238715bc48e3b49f447875f37c06 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 22 Nov 2023 22:38:40 +0000 Subject: [PATCH 122/221] [executorch hash update] update the pinned executorch hash (#114325) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/_update-commit-hash.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114325 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index d60ac4ac9b62..8178012c2274 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -f5e4a1e74daa5397362d87d1a8cb81f09446d34f +ccdb12eebe91b9e04ec92988991228916a92292f From 5f504d1de74a5189f93e65aa9283fc4607b8631c Mon Sep 17 00:00:00 2001 From: Pedro Caldeira Date: Wed, 22 Nov 2023 22:57:32 +0000 Subject: [PATCH 123/221] Check for boolean values as argument on pow function. (#114133) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hello everyone! 😄 Also @lezcano , nice to meet you! :) Sorry if I miss anything, this is my first time around here. 🙃 This PR basically makes the same behaviour for cuda when using `torch.pow`. Basically Python considers True as 1 and False as 0. I just added this check into `pow` function. From what I understood, when I do `.equal` for `Scalar` that is boolean, I'm sure that types match so that won't cause more trouble. I know that the issue suggest to disable this case but that could be a little more complicated, in my humble opinion. And that can create some compability problems too, I guess. My argument is that code below is correct for native language, so I guess it does makes sense sending booleans as Scalar. ``` $ x = True $ x + x 2 ``` This was my first test: ``` Python 3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 17:29:18) [GCC 11.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import torch >>> torch.pow(torch.tensor([1, 2], device='cuda'), True) tensor([1, 2], device='cuda:0') >>> torch.pow(torch.tensor([1, 2]), True) tensor([1, 2]) >>> torch.pow(torch.tensor([1, 2]), False) tensor([1, 1]) >>> torch.pow(torch.tensor([1, 2], device='cuda'), False) tensor([1, 1], device='cuda:0') ``` I've run `test_torch.py` and got following results, so my guess is that I didn't break anything. I was just looking for a test that uses linear regression, as suggested. ``` Ran 1619 tests in 52.363s OK (skipped=111) [TORCH_VITAL] Dataloader.enabled True [TORCH_VITAL] Dataloader.basic_unit_test TEST_VALUE_STRING [TORCH_VITAL] CUDA.used true ``` (I can paste whole log, if necessary) If this is a bad idea overall, dont worry about it. It's not a big deal, it's actually a two line change 😅 so can we talk of how do things in a different strategy. For the record I've signed the agreement already. And I didn't run linter because it's not working 😞 . Looks like PyYaml 6.0 is broken and there's a 6.0.1 fix already but I have no idea how to update that 😅 Fixes #113198 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114133 Approved by: https://github.com/lezcano --- aten/src/ATen/native/Pow.cpp | 4 ++-- test/test_binary_ufuncs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 0fa0fceab6ca..5c8147d7ced3 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -50,9 +50,9 @@ TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, c } TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) { - if (exp.equal(0.0)) { + if (exp.equal(0.0) || exp.equal(false)) { out.fill_(1); - } else if (exp.equal(1.0)) { + } else if (exp.equal(1.0) || exp.equal(true) ) { out.copy_(base); } else { pow_tensor_scalar_stub(device_type(), *this, exp); diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 2f22569f9cf1..9fcc8b445eb1 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1345,7 +1345,7 @@ def test_pow(self, device, dtype): (100, 100), low=1, high=range_high, dtype=dtype, device=device ) - exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3] + exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False] complex_exponents = [ -2.5j, -1.0j, From 6a86cf00adb071fe7200229cc1996dfefbb79289 Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 22 Nov 2023 23:23:46 +0000 Subject: [PATCH 124/221] [CUDA][cuBLAS] Remove explicit cuBLAS workspace allocation for CUDA 12.2+ (#113994) cuBLAS should be using `cudaMallocAsync` in CUDA 12.2+, which removes the need for explicit workspace allocation to avoid increasing memory usage with multiple graph captures. CC @ptrblck @malfet Pull Request resolved: https://github.com/pytorch/pytorch/pull/113994 Approved by: https://github.com/ezyang, https://github.com/malfet --- aten/src/ATen/cuda/CublasHandlePool.cpp | 10 +++++++--- test/test_cuda.py | 4 +++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index e495dffe02a9..dae61a4365cb 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -40,7 +40,9 @@ using CuBlasPoolType = DeviceThreadHandlePoolreserve(device); auto stream = c10::cuda::getCurrentCUDAStream(); TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream)); -#if !defined(USE_ROCM) - // cublasSetWorkspace not available on CUDA 10.2 +#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200 + // cuBLAS should not need an explicitly allocated workspace after CUDA 12.2 + // to avoid increasing memory usage during graph captures + // original issue: https://github.com/pytorch/pytorch/pull/83461 cudaStream_t _stream = stream; auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); auto workspace_it = cublas_handle_stream_to_workspace().find(key); diff --git a/test/test_cuda.py b/test/test_cuda.py index e4bb3145d56a..6bd2d416ca84 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -29,7 +29,8 @@ NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_WINDOWS, \ slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_CUDA, TEST_CUDA_GRAPH, TEST_WITH_ROCM, TEST_NUMPY, \ get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest, IS_JETSON, gcIfJetson, NoTest, IS_LINUX -from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_MULTIGPU, _create_scaling_case, _create_scaling_models_optimizers +from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_MULTIGPU, \ + _create_scaling_case, _create_scaling_models_optimizers, _get_torch_cuda_version from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch.utils.viz._cycles import observe_tensor_cycles @@ -296,6 +297,7 @@ def test_serialization_array_with_storage(self): self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) @unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async") + @unittest.skipIf(_get_torch_cuda_version() >= (12, 2), "skipped as explicit workspace allocation is removed") def test_cublas_workspace_explicit_allocation(self): a = torch.randn(7, 7, device='cuda', requires_grad=False) default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024 # :4096:2:16:8 From f961bda939e85cf8973f23a28e33abb50f5cfbec Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Wed, 22 Nov 2023 23:44:20 +0000 Subject: [PATCH 125/221] [export] Move serialized custom class objs to toplevel (#114371) Summary: Move the serialized CustomClassHolder objects to the toplevel SerializedArtifact instead of embedding the bytes in the graph. Currently the CustomClassHolder objects are embedded in the graph instead of being lifted to the ExportedProgram, so there's some logic introduced to lift it to the higher level of the serialized ExportedProgram. However, once that CustomClassHolder objects get lifted, we can remove the TODOs I added. Test Plan: CI Reviewed By: zhxchen17 Differential Revision: D51479125 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114371 Approved by: https://github.com/ydwu4 --- torch/_export/serde/schema.py | 2 +- torch/_export/serde/serialize.py | 56 ++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 048d9ce8098a..047809350cd3 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -168,7 +168,7 @@ class GraphArgument: @dataclass class CustomObjArgument: - blob: bytes + name: str # This is actually a union type diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index d3fad019aa09..8324a769d991 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -5,7 +5,6 @@ import logging import math import operator -import pickle import typing from contextlib import contextmanager @@ -166,7 +165,7 @@ def _reverse_map(d: Dict[Any, Enum]): class SerializedArtifact: exported_program: Union[ExportedProgram, bytes] state_dict: bytes - tensor_constants: bytes + constants: bytes def deserialize_device(d: Device) -> torch.device: @@ -303,7 +302,6 @@ class GraphState: sym_int_values: Dict[str, SymInt] = field(default_factory=dict) sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) is_single_tensor_return: bool = False - constants: Dict[str, torch.Tensor] = field(default_factory=dict) class GraphModuleSerializer: @@ -315,6 +313,7 @@ def __init__( self.graph_state = GraphState() self.graph_signature = graph_signature self.module_call_graph = module_call_graph + self.custom_objs: Dict[str, torch._C.ScriptObject] = {} @contextmanager def save_graph_state(self): @@ -640,19 +639,20 @@ def serialize_optional_tensor_args(a): return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) elif isinstance(arg, torch._C.ScriptObject): if not ( - hasattr(type(arg), "__getstate__") and - hasattr(type(arg), "__setstate__") + arg._has_method("__getstate__") and # type: ignore[attr-defined] + arg._has_method("__setstate__") # type: ignore[attr-defined] ): raise SerializeError( - f"Unable to serialize ScriptObject {arg}. Please define " + f"Unable to serialize custom class {arg}. Please define " "serialization methods via def_pickle()." ) # Custom objects through torchind are serializable with pickle, # through implementing the .def_pickle function. This should result # in the object containing a __getstate__ and __setstate__ # serialize/deserialize function. - blob = pickle.dumps(arg) - return Argument.create(as_custom_obj=CustomObjArgument(blob)) + custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name)) else: raise SerializeError(f"Unsupported argument type: {type(arg)}") @@ -949,14 +949,23 @@ def __init__(self, opset_version: Optional[Dict[str, int]] = None): self.opset_version["aten"] = torch._C._get_max_operator_version() def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact: - serialized_graph_module = ( - GraphModuleSerializer( - exported_program.graph_signature, - exported_program.module_call_graph - ).serialize(exported_program.graph_module) + gm_serializer = GraphModuleSerializer( + exported_program.graph_signature, + exported_program.module_call_graph ) + serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints) + # TODO: Directly serialize exported_program.constants once + # CustomClassHolders get stored in the ExportedProgram rather than in + # the graph + constants = {} + for n, c in gm_serializer.custom_objs.items(): + constants[n] = c + for n, t in exported_program.tensor_constants.items(): + assert n not in constants + constants[n] = t + return SerializedArtifact( ExportedProgram( graph_module=serialized_graph_module, @@ -966,7 +975,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact: dialect=exported_program.dialect, ), serialize_torch_artifact(exported_program.state_dict), - serialize_torch_artifact(exported_program.tensor_constants), + serialize_torch_artifact(constants), ) @@ -1251,6 +1260,7 @@ def deserialize( self, serialized_graph_module: GraphModule, symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, + constants: Optional[Dict[str, Any]] = None, ) -> Result: self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) self.fake_tensor_mode = FakeTensorMode( @@ -1260,6 +1270,7 @@ def deserialize( ) self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range + self.constants = {} if constants is None else constants self.deserialize_graph(serialized_graph_module.graph) @@ -1357,10 +1368,7 @@ def deserialize_optional_tensor_args(a): else: raise SerializeError(f"Unhandled argument {inp}") elif isinstance(value, CustomObjArgument): - # Custom objects through torchind are deserializable with pickle, - # through implementing the .def_pickle function. - blob = base64.b64decode(value.blob) - return pickle.loads(blob) + return self.constants[value.name] else: raise SerializeError(f"Unhandled argument {inp}") @@ -1549,12 +1557,19 @@ def deserialize( k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)) for k, v in serialized_artifact.exported_program.range_constraints.items() } + constants = deserialize_torch_artifact(serialized_artifact.constants) + + # TODO: No need to do this once CustomClassHolders are lifted to the ExportedProgram + tensor_constants = { + k: v for k, v in constants.items() if isinstance(v, torch.Tensor) + } res = ( GraphModuleDeserializer() .deserialize( serialized_artifact.exported_program.graph_module, symbol_name_to_range, + constants, ) ) range_constraints = self.deserialize_range_constraints( @@ -1566,7 +1581,6 @@ def deserialize( upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version) state_dict = deserialize_torch_artifact(serialized_artifact.state_dict) - tensor_constants = deserialize_torch_artifact(serialized_artifact.tensor_constants) exported_program = ep.ExportedProgram( res.graph_module, @@ -1648,7 +1662,7 @@ def serialize( artifact = SerializedArtifact( json_bytes, serialized_artifact.state_dict, - serialized_artifact.tensor_constants + serialized_artifact.constants ) return artifact @@ -1705,7 +1719,7 @@ def deserialize( SerializedArtifact( serialized_exported_program, artifact.state_dict, - artifact.tensor_constants + artifact.constants ) ) ) From 272b40aee584bedde3c86f3c216f76ac2f8502c4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 23 Nov 2023 01:43:32 +0000 Subject: [PATCH 126/221] Revert "deprecate PairwiseParallel from test (#114314)" This reverts commit 07b6f377b401933e69a605037b8a5c2fba627601. Reverted https://github.com/pytorch/pytorch/pull/114314 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but this seems to fail periodic multigpu tests ([comment](https://github.com/pytorch/pytorch/pull/114314#issuecomment-1823727818)) --- .../tensor/parallel/test_ddp_2d_parallel.py | 14 ++--- .../tensor/parallel/test_fsdp_2d_parallel.py | 37 ++++-------- .../tensor/parallel/test_parallelize_api.py | 57 +++++++++++++++++-- .../tensor/parallel/test_tp_examples.py | 11 ++-- .../tensor/parallel/test_tp_random_state.py | 14 ++++- 5 files changed, 83 insertions(+), 50 deletions(-) diff --git a/test/distributed/tensor/parallel/test_ddp_2d_parallel.py b/test/distributed/tensor/parallel/test_ddp_2d_parallel.py index e68cff5e4023..4c78b8b2eba6 100644 --- a/test/distributed/tensor/parallel/test_ddp_2d_parallel.py +++ b/test/distributed/tensor/parallel/test_ddp_2d_parallel.py @@ -3,11 +3,7 @@ import torch import torch.distributed as dist from torch.distributed._tensor import DeviceMesh, DTensor, Replicate -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - parallelize_module, - RowwiseParallel, -) +from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform from torch.nn.parallel import DistributedDataParallel as DDP @@ -42,11 +38,9 @@ def init_model(device_type, model_parallel_size=TP_DEGREE): dp_pg = twod_mesh.get_dim_groups()[0] - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - twod_model = parallelize_module(twod_model, twod_mesh, parallelize_plan, tp_mesh_dim=1) + twod_model = parallelize_module( + twod_model, twod_mesh, PairwiseParallel(), tp_mesh_dim=1 + ) _pre_dp_module_transform(twod_model) # TODO: Add tests when using gradient_as_bucket_view and static_graph for DDP. twod_model = DDP(twod_model, process_group=dp_pg) diff --git a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py index 60dcb181715c..c0feafb3f2be 100644 --- a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py +++ b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py @@ -20,6 +20,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType from torch.distributed.tensor.parallel import ( ColwiseParallel, + PairwiseParallel, parallelize_module, RowwiseParallel, ) @@ -104,12 +105,8 @@ def test_raise_invalid_tp_composition(self): mesh_2d = init_device_mesh( self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp") ) - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } model_2d = parallelize_module( - SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan + SimpleModel().cuda(), mesh_2d["tp"], PairwiseParallel() ) @with_comms @@ -141,11 +138,7 @@ def _test_2d_e2e_training( ) tp_mesh = mesh_2d["tp"] dp_mesh = mesh_2d["dp"] - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan) + model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, PairwiseParallel()) model_2d = FSDP( model_2d, device_mesh=dp_mesh, @@ -253,11 +246,9 @@ def test_2d_state_dict(self, is_even_sharded_model): ) tp_mesh = mesh_2d["tp"] dp_mesh = mesh_2d["dp"] - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan) + model_2d = parallelize_module( + simple_model().cuda(), tp_mesh, PairwiseParallel() + ) model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True) FSDP.set_state_dict_type( @@ -301,11 +292,9 @@ def test_2d_load_state_dict(self, is_even_sharded_model): ) tp_mesh = mesh_2d["tp"] dp_mesh = mesh_2d["dp"] - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan) + model_2d = parallelize_module( + simple_model().cuda(), tp_mesh, PairwiseParallel() + ) model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True) optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01) @@ -362,11 +351,9 @@ def test_2d_optim_state_dict(self, is_even_sharded_model): mesh_2d = init_device_mesh( self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp") ) - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - model_2d = parallelize_module(simple_model().cuda(), mesh_2d["tp"], parallelize_plan) + model_2d = parallelize_module( + simple_model().cuda(), mesh_2d["tp"], PairwiseParallel() + ) model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True) FSDP.set_state_dict_type( model_2d, diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 44a8687ffb77..91fb2b50662b 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -6,10 +6,15 @@ from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh from torch.distributed.tensor.parallel.api import ( _parallelize_linear_like_module, + _parallelize_mlp, parallelize_module, ) from torch.distributed.tensor.parallel.style import ( ColwiseParallel, + make_input_replicate_1d, + make_output_replicate_1d, + PairwiseParallel, + ParallelStyle, PrepareModuleInput, PrepareModuleOutput, RowwiseParallel, @@ -136,6 +141,23 @@ def _compare_module( dist_optim.step() self._compare_params(local_module, dist_module, rank0_only, rowwise) + @with_comms + def test_parallelize_mlp(self): + inp_size = [12, 10] + model = MLPModule(self.device_type) + model_tp = MLPModule(self.device_type) + + # Ensure model are initialized the same way. + self.assertEqual(model.net1.weight, model_tp.net1.weight) + self.assertEqual(model.net1.bias, model_tp.net1.bias) + self.assertEqual(model.net2.weight, model_tp.net2.weight) + self.assertEqual(model.net2.bias, model_tp.net2.bias) + + # Parallelize module. + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + model_tp = _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) + self._compare_module(model, model_tp, inp_size) + @with_comms def test_parallelize_mlp_with_module_api(self): inp_size = [12, 10] @@ -155,10 +177,10 @@ def test_parallelize_mlp_with_module_api(self): device_mesh, { "net1": ColwiseParallel( - input_layouts=Replicate(), output_layouts=Replicate() + make_input_replicate_1d, make_output_replicate_1d ), "net2": ColwiseParallel( - input_layouts=Replicate(), output_layouts=Replicate() + make_input_replicate_1d, make_output_replicate_1d ), }, ) @@ -195,15 +217,40 @@ def test_parallelize_mlp_with_module_api_nested(self): device_mesh, { "dummy_encoder.net1": ColwiseParallel( - input_layouts=Replicate(), output_layouts=Replicate() + make_input_replicate_1d, make_output_replicate_1d ), "dummy_encoder.net2": ColwiseParallel( - input_layouts=Replicate(), output_layouts=Replicate() + make_input_replicate_1d, make_output_replicate_1d ), }, ) self._compare_module(model, model_tp, inp_size, rank0_only=False) + @with_comms + def test_parallelize_mlp_error(self): + class DummyParallel(ParallelStyle): + def __init__(self) -> None: + super().__init__( + make_input_replicate_1d, + make_output_replicate_1d, + input_layouts=None, + output_layouts=None, + use_local_output=False, + ) + + model_tp = MLPModule(self.device_type) + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + with self.assertRaisesRegex( + NotImplementedError, + "Only support PairwiseParallel for MLP parallelization.", + ): + _parallelize_mlp(model_tp, device_mesh, DummyParallel()) + + with self.assertRaisesRegex( + RuntimeError, "More than one nn.Linear needed for a MLP." + ): + _parallelize_mlp(torch.nn.Linear(10, 5), device_mesh, PairwiseParallel()) + @with_comms def test_linear_row_wise_parallel(self): # test RowwiseParallel @@ -227,7 +274,7 @@ def test_linear_row_wise_parallel(self): def test_linear_col_wise_parallel(self): # test ColwiseParallel inp_size = [8, 10] - colwise = ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()) + colwise = ColwiseParallel(make_input_replicate_1d, make_output_replicate_1d) torch.manual_seed(5) model = torch.nn.Linear(10, 16, device=self.device_type) diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py index a37fec29574a..a15c818354e5 100644 --- a/test/distributed/tensor/parallel/test_tp_examples.py +++ b/test/distributed/tensor/parallel/test_tp_examples.py @@ -10,6 +10,7 @@ ) from torch.distributed.tensor.parallel import ( ColwiseParallel, + PairwiseParallel, parallelize_module, RowwiseParallel, ) @@ -61,7 +62,7 @@ def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=Fal self.device_type, torch.arange(0, NUM_DEVICES), ) - parallelize_plan = { + parallel_style = { "net1": ColwiseParallel(input_layouts=Shard(0)) if is_seq_parallel else ColwiseParallel(), @@ -69,7 +70,7 @@ def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=Fal if is_seq_parallel else RowwiseParallel(), } - model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan) + model_tp = parallelize_module(model_tp, device_mesh, parallel_style) if recompute_activation: model_tp = input_reshard( checkpoint_wrapper( @@ -123,11 +124,7 @@ def _test_mlp_inference(self, device_mesh): self._check_module(model, model_tp) # Shard module and initialize optimizer. - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan) + model_tp = parallelize_module(model_tp, device_mesh, PairwiseParallel()) output = model(inp) output_tp = model_tp(inp) diff --git a/test/distributed/tensor/parallel/test_tp_random_state.py b/test/distributed/tensor/parallel/test_tp_random_state.py index 812bcfb5a969..75444c59afcc 100644 --- a/test/distributed/tensor/parallel/test_tp_random_state.py +++ b/test/distributed/tensor/parallel/test_tp_random_state.py @@ -5,7 +5,11 @@ from torch.distributed._tensor import DeviceMesh from torch.distributed.tensor.parallel.api import parallelize_module -from torch.distributed.tensor.parallel.style import ColwiseParallel +from torch.distributed.tensor.parallel.style import ( + ColwiseParallel, + make_input_replicate_1d, + make_output_replicate_1d, +) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -51,8 +55,12 @@ def test_model_init(self): model, device_mesh, { - "net1": ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()), - "net2": ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()), + "net1": ColwiseParallel( + make_input_replicate_1d, make_output_replicate_1d + ), + "net2": ColwiseParallel( + make_input_replicate_1d, make_output_replicate_1d + ), }, ) # in most cases, the random number generator states is set by data loader From 2bae888f659f991d29e73c91e703c652f4197615 Mon Sep 17 00:00:00 2001 From: Facebook Community Bot Date: Thu, 23 Nov 2023 01:46:30 +0000 Subject: [PATCH 127/221] Automated submodule update: FBGEMM (#113977) This is an automated pull request to update the first-party submodule for [pytorch/FBGEMM](https://github.com/pytorch/FBGEMM). New submodule commit: https://github.com/pytorch/FBGEMM/commit/a142e2064de172cb128ca086ca6438f92f99484e Test Plan: Ensure that CI jobs succeed on GitHub before landing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113977 Approved by: https://github.com/malfet --- third_party/fbgemm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/fbgemm b/third_party/fbgemm index f1bbb608d217..f49dea6a16f1 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit f1bbb608d217f8089f0e678b17ee4c6b7d749d7f +Subproject commit f49dea6a16f153584b926203f8b3a1ee5801dea7 From 7a697c4683375737cde91da6b061ab0c219b4fe8 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 23 Nov 2023 02:14:22 +0000 Subject: [PATCH 128/221] [RelEng] Tag docker images for release, pin unstable and disabled jobs, apply release only changes (#114355) 1. This tags docker images using docker pull/tag/push for current release 2. Sets RELEASE_VERSION_TAG var and regenerates the workflows using the new docker tag 3. Remove conda token setting and Binary tests release changes these are already automated 4. Pin unstable and disabled jobs, autumate: https://github.com/pytorch/pytorch/pull/111675 Test: ``` RELEASE_VERSION=2.2 ./scripts/release/apply-release-changes.sh Tagging pytorch/manylinux-builder:cuda11.8-main to pytorch/manylinux-builder:cuda11.8-2.2 , dry_run: enabled Tagging pytorch/manylinux-builder:cuda12.1-main to pytorch/manylinux-builder:cuda12.1-2.2 , dry_run: enabled Tagging pytorch/libtorch-cxx11-builder:cuda11.8-main to pytorch/libtorch-cxx11-builder:cuda11.8-2.2 , dry_run: enabled Tagging pytorch/libtorch-cxx11-builder:cuda12.1-main to pytorch/libtorch-cxx11-builder:cuda12.1-2.2 , dry_run: enabled Tagging pytorch/manylinux-builder:rocm5.6-main to pytorch/manylinux-builder:rocm5.6-2.2 , dry_run: enabled Tagging pytorch/manylinux-builder:rocm5.7-main to pytorch/manylinux-builder:rocm5.7-2.2 , dry_run: enabled Tagging pytorch/libtorch-cxx11-builder:rocm5.6-main to pytorch/libtorch-cxx11-builder:rocm5.6-2.2 , dry_run: enabled Tagging pytorch/libtorch-cxx11-builder:rocm5.7-main to pytorch/libtorch-cxx11-builder:rocm5.7-2.2 , dry_run: enabled Tagging pytorch/manylinux-builder:cpu-main to pytorch/manylinux-builder:cpu-2.2 , dry_run: enabled Tagging pytorch/libtorch-cxx11-builder:cpu-main to pytorch/libtorch-cxx11-builder:cpu-2.2 , dry_run: enabled Tagging pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main to pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.2 , dry_run: enabled Tagging pytorch/manylinuxaarch64-builder:cpu-aarch64-main to pytorch/manylinuxaarch64-builder:cpu-aarch64-2.2 , dry_run: enabled Tagging pytorch/conda-builder:cuda11.8-main to pytorch/conda-builder:cuda11.8-2.2 , dry_run: enabled Tagging pytorch/conda-builder:cuda12.1-main to pytorch/conda-builder:cuda12.1-2.2 , dry_run: enabled Tagging pytorch/conda-builder:cpu-main to pytorch/conda-builder:cpu-2.2 , dry_run: enabled /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-manywheel-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-conda-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-manywheel-main.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml /data/users/atalman/pytorch/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml /data/users/atalman/pytorch/.github/workflows/generated-windows-binary-wheel-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-windows-binary-conda-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-windows-binary-libtorch-release-main.yml /data/users/atalman/pytorch/.github/workflows/generated-windows-binary-libtorch-debug-main.yml /data/users/atalman/pytorch/.github/workflows/generated-macos-binary-wheel-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-macos-binary-conda-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml /data/users/atalman/pytorch/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml ```` Result of pinning unstable and disabled jobs: ``` # The link to the published list of disabled jobs DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionid=kKJlAXdrUbk3CilXbKu.6OwNTGQB8a.B" # and unstable jobs UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionid=vzaicOxSsh55iXBXwgGrW6dFeVtPfrhr" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114355 Approved by: https://github.com/malfet --- .../scripts/generate_binary_build_matrix.py | 4 +- .../scripts/tag_docker_images_for_release.py | 64 +++++++++++++++++++ scripts/release/apply-release-changes.sh | 21 +++--- 3 files changed, 78 insertions(+), 11 deletions(-) create mode 100644 .github/scripts/tag_docker_images_for_release.py diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 8511e8a289ec..a34cd3a76199 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -10,9 +10,9 @@ * Latest ROCM """ +import os from typing import Dict, List, Optional, Tuple - CUDA_ARCHES = ["11.8", "12.1"] @@ -95,7 +95,7 @@ def arch_type(arch_version: str) -> str: # This can be updated to the release version when cutting release branch, i.e. 2.1 -DEFAULT_TAG = "main" +DEFAULT_TAG = os.getenv("RELEASE_VERSION_TAG", "main") WHEEL_CONTAINER_IMAGES = { **{ diff --git a/.github/scripts/tag_docker_images_for_release.py b/.github/scripts/tag_docker_images_for_release.py new file mode 100644 index 000000000000..fb3f7233f89c --- /dev/null +++ b/.github/scripts/tag_docker_images_for_release.py @@ -0,0 +1,64 @@ +import argparse +import subprocess +from typing import Dict + +import generate_binary_build_matrix + + +def tag_image( + image: str, + default_tag: str, + release_version: str, + dry_run: str, + tagged_images: Dict[str, bool], +) -> None: + if image in tagged_images: + return + release_image = image.replace(f"-{default_tag}", f"-{release_version}") + print(f"Tagging {image} to {release_image} , dry_run: {dry_run}") + + if dry_run == "disabled": + subprocess.check_call(["docker", "pull", image]) + subprocess.check_call(["docker", "tag", image, release_image]) + subprocess.check_call(["docker", "push", release_image]) + tagged_images[image] = True + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--version", + help="Version to tag", + type=str, + default="2.2", + ) + parser.add_argument( + "--dry-run", + help="No Runtime Error check", + type=str, + choices=["enabled", "disabled"], + default="enabled", + ) + + options = parser.parse_args() + tagged_images: Dict[str, bool] = dict() + platform_images = [ + generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES, + generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES, + generate_binary_build_matrix.CONDA_CONTAINER_IMAGES, + ] + default_tag = generate_binary_build_matrix.DEFAULT_TAG + + for platform_image in platform_images: # type: ignore[attr-defined] + for arch in platform_image.keys(): # type: ignore[attr-defined] + tag_image( + platform_image[arch], # type: ignore[index] + default_tag, + options.version, + options.dry_run, + tagged_images, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/release/apply-release-changes.sh b/scripts/release/apply-release-changes.sh index c3da890e3e7e..4c07057e6661 100755 --- a/scripts/release/apply-release-changes.sh +++ b/scripts/release/apply-release-changes.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash # # Usage (run from root of project): -# RELEASE_VERSION=2.1 apply-release-changes.sh +# DRY_RUN=disabled RELEASE_VERSION=2.2 ./scripts/release/apply-release-changes.sh # # RELEASE_VERSION: Version of this current release @@ -10,6 +10,9 @@ set -eou pipefail # Create and Check out to Release Branch # git checkout -b "${RELEASE_BRANCH}" +DRY_RUN=${DRY_RUN:-enabled} +python3 .github/scripts/tag_docker_images_for_release.py --version ${RELEASE_VERSION} --dry-run ${DRY_RUN} + # Change all GitHub Actions to reference the test-infra release branch # as opposed to main. echo "Applying to workflows" @@ -22,8 +25,6 @@ echo "Applying to templates" for i in .github/templates/*.yml.j2; do sed -i 's#common.checkout(\(.*\))#common.checkout(\1, checkout_pr_head=False)#' $i; done -# Change conda token for test env for conda upload -sed -i 's#CONDA_PYTORCHBOT_TOKEN#CONDA_PYTORCHBOT_TOKEN_TEST#' .github/templates/upload.yml.j2 # Triton wheel echo "Triton Changes" @@ -34,14 +35,16 @@ echo "XLA Changes" sed -i -e s#--quiet#-b\ r"${RELEASE_VERSION}"# .ci/pytorch/common_utils.sh sed -i -e s#.*#r"${RELEASE_VERSION}"# .github/ci_commit_pins/xla.txt -# Binary tests -echo "Binary tests" -sed -i 's#/nightly/#/test/#' .circleci/scripts/binary_linux_test.sh -sed -i 's#"\\${PYTORCH_CHANNEL}"#pytorch-test#' .circleci/scripts/binary_linux_test.sh - -# Regenerated templates +# Regenerate templates +export RELEASE_VERSION_TAG=${RELEASE_VERSION} ./.github/regenerate.sh +# Pin Unstable and disabled jobs +UNSTABLE_VER=$(aws s3api list-object-versions --bucket ossci-metrics --prefix unstable-jobs.json --query 'Versions[?IsLatest].[VersionId]' --output text) +DISABLED_VER=$(aws s3api list-object-versions --bucket ossci-metrics --prefix disabled-jobs.json --query 'Versions[?IsLatest].[VersionId]' --output text) +sed -i -e s#unstable-jobs.json#"unstable-jobs.json?versionid=${UNSTABLE_VER}"# .github/scripts/filter_test_configs.py +sed -i -e s#disabled-jobs.json#"disabled-jobs.json?versionid=${DISABLED_VER}"# .github/scripts/filter_test_configs.py + # Optional # git commit -m "[RELEASE-ONLY CHANGES] Branch Cut for Release {RELEASE_VERSION}" # git push origin "${RELEASE_BRANCH}" From b27565ad7d07baed98df50b05f108e209a6f4755 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 23 Nov 2023 02:58:52 +0000 Subject: [PATCH 129/221] Forward fix D51468211 (#114381) Summary: Forward fix test failures caused by D51468211. The root cause is that when converting the param_buffer into fake_tensor, we didn't set the static_shapes=True, this causes the shape_env to have more symbols than expected. The current status is that we assume all param and buffers are constant sizes. Test Plan: buck2 test 'fbcode//mode/opt' fbcode//aps_models/ads/icvr/tests:export_test_cpu -- --exact 'aps_models/ads/icvr/tests:export_test_cpu - test_20x_icvr_export (aps_models.ads.icvr.tests.export_test.ExportTest)' Reviewed By: hongtansun-meta Differential Revision: D51531279 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114381 Approved by: https://github.com/angelayi --- torch/_export/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 13e366e6d396..438b54c2b0bd 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -426,7 +426,7 @@ def convert_to_fake(x): # TODO properly use the cached fake tensor fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs) fake_params_buffers = pytree.tree_map_only(torch.Tensor, - fake_mode.from_tensor, + functools.partial(fake_mode.from_tensor, static_shapes=True), {**dict(gm.named_parameters(remove_duplicate=False)), **dict(gm.named_buffers(remove_duplicate=False))}) return fake_args, fake_kwargs, fake_params_buffers, fake_mode From c4a22d6918b7ca218f2712d7e7e147aca7127fa3 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 22 Nov 2023 22:13:48 +0000 Subject: [PATCH 130/221] Add support for models with mutated buffer on torch.onnx.dynamo_export (#112272) This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model. The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model. However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output. This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272 Approved by: https://github.com/titaiwangms, https://github.com/BowenBao --- test/onnx/onnx_test_common.py | 10 +++--- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 31 ++++++++++++++++ .../fx/torch_export_graph_extractor.py | 4 +++ torch/onnx/_internal/io_adapter.py | 36 +++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index f664e7e84a42..0be906fada77 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -436,16 +436,18 @@ def _compare_pytorch_onnx_with_ort( ref_input_args = input_args ref_input_kwargs = input_kwargs - ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( - ref_model(*ref_input_args, **ref_input_kwargs) - ) - + # ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. + # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. + # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() ort_outputs = onnx_program(*input_args, **input_kwargs) + ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs) + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs) if len(ref_outputs) != len(ort_outputs): raise AssertionError( f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}" ) + for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close( ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 26fa6f215bec..dcede4718a8c 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -942,6 +942,37 @@ def forward(self, x): loaded_exported_program, (x,), skip_dynamic_shapes_check=True ) + @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( + "Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. " + "github issue: https://github.com/pytorch/pytorch/issues/114406" + ) + def test_exported_program_as_input_lifting_buffers_mutation(self): + for persistent in (True, False): + + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "my_buffer", torch.tensor(4.0), persistent=persistent + ) + + def forward(self, x, b): + output = x + b + ( + self.my_buffer.add_(1.0) + 3.0 + ) # Mutate buffer through in-place addition + return output + + inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3)) + model = CustomModule() + self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( + model, inputs, skip_dynamic_shapes_check=True + ) + # Buffer will be mutated after the first iteration + self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( + model, inputs, skip_dynamic_shapes_check=True + ) + def _parameterized_class_attrs_and_values_with_fake_options(): input_values = [] diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py index 51c31560b144..5f1fbb5c7481 100644 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py @@ -63,6 +63,10 @@ def generate_fx( # tensor, etc), we flatten the collection and register each element as output. options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) + options.fx_tracer.output_adapter.append_step( + io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model) + ) + # Export FX graph to ONNX ModelProto. return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value] diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 45134505000f..28db50a5b58a 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -550,3 +550,39 @@ def apply( if model_kwargs: return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs) return updated_args, {} + + +class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): + """Prepend model's mutated buffers to the user output. + + :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they + must be added to the user output after the model is executed. + + Args: + model: The PyTorch model with mutated buffers. + """ + + def __init__(self, model: torch_export.ExportedProgram): + assert isinstance( + model, torch_export.ExportedProgram + ), "'model' must be a torch.export.ExportedProgram." + self.model = model + + def apply(self, model_outputs: Any) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + + Returns: + flattened_outputs: The flattened model outputs. + """ + + ordered_buffers = tuple( + self.model.state_dict[name] + for name in self.model.graph_signature.buffers_to_mutate.values() + ) + + # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. + updated_outputs = (*ordered_buffers, *model_outputs) + return updated_outputs From d18e6b07aa61ec8b480f15fdcd02eed53580c85c Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 20 Nov 2023 14:49:11 +0800 Subject: [PATCH 131/221] Overload vec::dequantize to eliminate rounding error for quantized sigmoid (#114098) **Description** Fix #107030 Dequantize X by `(x_val - zp) * scale` instead of `x_val * scale + (-zp * scale)` to eliminate rounding error. Now this overload is used for sigmoid only. Performance impact: ![image](https://github.com/pytorch/pytorch/assets/12522207/655abd16-7d9d-4a9a-8c59-327ebf39157a) Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz (Ice Lake) **Test plan** `python test_quantization.py TestQuantizedOps.test_sigmoid_dequantize_rounding_error` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114098 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 --- aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 54 ++++++++++++++++++ .../cpu/vec/vec256/vsx/vec256_qint32_vsx.h | 14 +++++ .../ATen/cpu/vec/vec256/zarch/vec256_zarch.h | 10 ++++ aten/src/ATen/cpu/vec/vec512/vec512_qint.h | 55 +++++++++++++++++++ .../cpu/kernels/QuantizedOpKernels.cpp | 4 +- test/quantization/core/test_quantized_op.py | 17 ++++++ 6 files changed, 151 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 16550a6af20b..ee14de69324f 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -305,6 +305,13 @@ struct Vectorized : public Vectorizedqi { return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float scale, @@ -520,6 +527,26 @@ struct Vectorized : public Vectorizedqi { return {val0, val1, val2, val3}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float /*scale*/, @@ -698,6 +725,26 @@ struct Vectorized : public Vectorizedqi { return {val0, val1, val2, val3}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float /*scale*/, @@ -853,6 +900,13 @@ struct VectorizedQuantizedConverter { return rv; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + protected: VectorizedQuantizedConverter() {} }; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h index a85730c9a6df..746a5e27a5c1 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -121,6 +121,20 @@ struct Vectorized { vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return {Vectorized{ + (float_vals0 - zero_point0) * scale_vec0, + (float_vals1 - zero_point1) * scale_vec1}}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float scale, diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h index 25ca208ee24c..70b130421cdf 100644 --- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h +++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -1730,6 +1730,16 @@ struct Vectorized()>> { return {fmadd(scale, float_val, scale_zp_premul)}; } + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + auto float_val = convert_to_float(_vec); + return {(float_val - zero_point) * scale}; + } + template < typename U = T, std::enable_if_t::float_num_vecs() == 1, int> = 0> diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h index 493573ccacf1..b03da5d2c3e9 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -317,6 +317,13 @@ struct Vectorized : public Vectorizedqi { return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float scale, @@ -531,6 +538,26 @@ struct Vectorized : public Vectorizedqi { return {val0, val1, val2, val3}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float scale, @@ -708,6 +735,27 @@ struct Vectorized : public Vectorizedqi { return {val0, val1, val2, val3}; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + + return {val0, val1, val2, val3}; + } + static Vectorized quantize( const float_vec_return_type& rhs, float scale, @@ -865,6 +913,13 @@ struct VectorizedQuantizedConverter { return rv; } + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + protected: VectorizedQuantizedConverter() {} }; diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 479b023a56df..373ef4af33da 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -836,7 +836,6 @@ void qsigmoid_kernel( float scale = qx.q_scale(); auto scale_vec = Vectorized(scale); auto zero_point_vec = Vectorized((float)zero_point); - auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg(); AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() { float inv_output_scale = 1.0 / output_scale; @@ -861,8 +860,7 @@ void qsigmoid_kernel( output_scale, output_zero_point, value_dy); }, [&](Vec value_qx) -> Vec { - auto value_dx = value_qx.dequantize( - scale_vec, zero_point_vec, scale_neg_zp_premul_vec); + auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec); for (auto & value : value_dx) { value = value.neg(); value = value.exp(); diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 15aaee96da3c..59784a63d3ef 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -328,6 +328,23 @@ def test_sigmoid(self, X): ] self._test_activation_function(X, 'sigmoid', sigmoid_test_configs) + @skipIfNoFBGEMM + def test_sigmoid_dequantize_rounding_error(self): + # issue #107030 + sigmoid_test_configs = [ + { + 'quantized_fn': [ + torch.ops.quantized.sigmoid + ], + 'reference_fn': torch.sigmoid, + 'output_range': (0.0, 1.0), + 'change_zero_point': True, + 'output_is_observed': True, + } + ] + X = (np.full(64, 514., dtype=np.float32), (1028.02, 255, torch.quint8)) + self._test_activation_function(X, 'sigmoid', sigmoid_test_configs) + """Tests the correctness of the quantized::hardsigmoid op.""" @override_qengines def test_qhardsigmoid(self): From 088043fc496747156b192ab28047c28ef28303f7 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 22 Nov 2023 17:21:51 -0800 Subject: [PATCH 132/221] [FSDP] Passed `TORCH_NCCL_DESYNC_DEBUG` instead of `NCCL_DESYNC_DEBUG` (#114432) This is to silence some warnings like: ``` [rank0]:[W Utils.hpp:164] Warning: Environment variable NCCL_DESYNC_DEBUG is deprecated; use TORCH_NCCL_DESYNC_DEBUG instead (function getCvarBool) [rank3]:[W Utils.hpp:164] Warning: Environment variable NCCL_DESYNC_DEBUG is deprecated; use TORCH_NCCL_DESYNC_DEBUG instead (function getCvarBool) [rank1]:[W Utils.hpp:164] Warning: Environment variable NCCL_DESYNC_DEBUG is deprecated; use TORCH_NCCL_DESYNC_DEBUG instead (function getCvarBool) [rank2]:[W Utils.hpp:164] Warning: Environment variable NCCL_DESYNC_DEBUG is deprecated; use TORCH_NCCL_DESYNC_DEBUG instead (function getCvarBool) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114432 Approved by: https://github.com/fegin --- torch/testing/_internal/common_fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index d8935696ef98..90e9faa5ac37 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -865,10 +865,10 @@ def run_subtests(self, *args, **kwargs): class FSDPTest(MultiProcessTestCase): def setUp(self): super().setUp() - # Set NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`, + # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`, # which can cause unit test flakiness: # https://github.com/pytorch/pytorch/issues/90848 - os.environ["NCCL_DESYNC_DEBUG"] = "0" + os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0" self._spawn_processes() @property From c340db56d5e3eb74e3f9a63939c2bca59fd3ec92 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 23 Nov 2023 04:54:02 +0000 Subject: [PATCH 133/221] [executorch hash update] update the pinned executorch hash (#114427) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/_update-commit-hash.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114427 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 8178012c2274..143c259d1612 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -ccdb12eebe91b9e04ec92988991228916a92292f +f4578fc150f1690be27fd1ba3258b35a20d9c39d From 36763d31353eb087c116c8ddfcbecaad4f8b1157 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 23 Nov 2023 05:07:36 +0000 Subject: [PATCH 134/221] [ProcessGroupNCCL] Move new trace utils (#114367) to TraceUtils.h Pull Request resolved: https://github.com/pytorch/pytorch/pull/114367 Approved by: https://github.com/wconstab, https://github.com/XilunWu --- .../distributed/c10d/ProcessGroupNCCL.cpp | 257 +---------------- torch/csrc/distributed/c10d/TraceUtils.h | 260 ++++++++++++++++++ 2 files changed, 261 insertions(+), 256 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 40eb9d06ef0d..e816b6fc193d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -24,44 +24,16 @@ #include #include #include +#include #include #include #include - -#include -#include - -#include #include namespace c10d { constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM"; -DebugInfoWriter::DebugInfoWriter(int rank) { - std::string fileName = getCvarString( - {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); - filename_ = c10::str(fileName, rank); -} - -DebugInfoWriter::~DebugInfoWriter() = default; - -void DebugInfoWriter::write(const std::string& ncclTrace) { - // Open a file for writing. The ios::binary flag is used to write data as - // binary. - std::ofstream file(filename_, std::ios::binary); - - // Check if the file was opened successfully. - if (!file.is_open()) { - LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " - << filename_; - return; - } - - file.write(ncclTrace.data(), ncclTrace.size()); - LOG(INFO) << "Wrote finished "; -} - namespace { #if defined(NCCL_MAJOR) && \ @@ -317,31 +289,6 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { } // namespace -namespace { -std::string pickle_str(const c10::IValue& v) { - std::vector result; - { - auto writer = [&](const char* data, size_t size) { - result.insert(result.end(), data, data + size); - }; - torch::jit::Pickler pickler( - writer, nullptr, nullptr, nullptr, nullptr, false); - pickler.protocol(); - pickler.pushIValue(v); - pickler.stop(); - } - return std::string(result.begin(), result.end()); -} -c10::Dict new_dict() { - return c10::Dict( - c10::AnyType::get(), c10::AnyType::get()); -} -c10::List new_list() { - return c10::List(c10::AnyType::get()); -} - -} // namespace - // Map from each communicator to its device index. // This map is used when register/deregister cache segments from cache // allocator. See design notes below: @@ -391,208 +338,6 @@ void cacheAllocatorDeregisterHook( } } -struct NCCLTraceBuffer { - static NCCLTraceBuffer* get() { - // intentionally leak on exit - // because this will hold python state that may get destructed - static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); - return instance; - } - NCCLTraceBuffer() { - max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); - enabled_ = max_entries_ > 0; - } - using EventList = std::vector; - struct Entry { - size_t id_; // incremented id in the trace buffer - // used to figure out where in the circular entries - // buffer this entry will be located to - // update state information - size_t pg_id_; - size_t seq_id_; // as tracked by the process group - const char* profiling_name_; - - std::shared_ptr traceback_; - // we borrow pointser to start_ and end_ so we can query the state - // on reporting. However, once the event is completed, the call - // to `complete` will clear these. - EventList *start_, *end_; - const char* state_ = "scheduled"; - - // size information for input/output tensors - c10::SmallVector input_dims_; - c10::SmallVector output_dims_; - c10::SmallVector sizes_; // flattened from inputs, outputs - }; - - bool enabled_ = false; - std::mutex mutex_; - std::vector entries_; - size_t max_entries_ = 0; - size_t next_ = 0; - size_t id_ = 0; - - c10::optional record( - size_t pg_id, - size_t seq_id, - const char* profiling_name, - const std::vector& inputs, - const std::vector& outputs, - EventList* start, - EventList* end) { - if (!enabled_) { - return c10::nullopt; - } - auto traceback = torch::CapturedTraceback::gather(true, true, true); - std::lock_guard guard(mutex_); - - auto te = Entry{ - id_, - pg_id, - seq_id, - profiling_name, - std::move(traceback), - std::move(start), - std::move(end)}; - - for (const auto& input : inputs) { - c10::IntArrayRef sizes = input.sizes(); - te.input_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - for (const auto& output : outputs) { - c10::IntArrayRef sizes = output.sizes(); - te.output_dims_.push_back(sizes.size()); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - if (entries_.size() < max_entries_) { - entries_.emplace_back(std::move(te)); - } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } - } - return id_++; - } - - std::vector dump_entries() { - std::lock_guard guard(mutex_); - std::vector result; - result.reserve(entries_.size()); - result.insert(result.end(), entries_.begin() + next_, entries_.end()); - result.insert(result.end(), entries_.begin(), entries_.begin() + next_); - // query any remaining events - for (auto& r : result) { - if (r.start_ != nullptr) { - bool started = true; - for (auto& ev : *r.start_) { - if (!ev.query()) { - started = false; - break; - } - } - if (started) { - r.state_ = "started"; - } - r.start_ = nullptr; - } - if (r.end_ != nullptr) { - bool completed = true; - for (auto& ev : *r.end_) { - if (!ev.query()) { - completed = false; - break; - } - } - if (completed) { - r.state_ = "completed"; - } - r.end_ = nullptr; - } - } - return result; - } - - void complete(c10::optional id) { - if (!enabled_ || !id) { - return; - } - std::lock_guard guard(mutex_); - auto& entry = entries_.at(*id % max_entries_); - if (entry.id_ == *id) { - entry.state_ = "completed"; - entry.start_ = entry.end_ = nullptr; - } - } - - std::string dump() { - auto result = dump_entries(); - auto entries = new_list(); - c10::IValue pg_id_s = "pg_id"; - c10::IValue seq_id_s = "seq_id"; - c10::IValue profiling_name_s = "profiling_name"; - c10::IValue input_sizes_s = "input_sizes"; - c10::IValue output_sizes_s = "output_sizes"; - - c10::IValue frames_s = "frames"; - c10::IValue state_s = "state"; - c10::IValue line_s = "line"; - c10::IValue name_s = "name"; - c10::IValue filename_s = "filename"; - - std::vector tracebacks; - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - torch::SymbolizedTracebacks stracebacks = torch::symbolize(tracebacks); - std::vector all_frames; - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_s, f.funcname); - d.insert(filename_s, f.filename); - d.insert(line_s, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); - } - - for (auto i : c10::irange(result.size())) { - auto& e = result.at(i); - auto& tb = stracebacks.tracebacks.at(i); - auto dict = new_dict(); - dict.insert(pg_id_s, int64_t(e.pg_id_)); - dict.insert(seq_id_s, int64_t(e.seq_id_)); - dict.insert(profiling_name_s, e.profiling_name_); - - auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = new_list(); - for (auto dim : dims) { - auto arg_sizes = new_list(); - for (auto i : c10::irange(dim)) { - (void)i; - arg_sizes.push_back(*it++); - } - sizes.push_back(arg_sizes); - } - return sizes; - }; - - dict.insert(input_sizes_s, read_sizes(e.input_dims_)); - dict.insert(output_sizes_s, read_sizes(e.output_dims_)); - dict.insert(state_s, e.state_); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_s, frames); - entries.push_back(dict); - } - return pickle_str(entries); - } -}; - std::string dump_nccl_trace() { return NCCLTraceBuffer::get()->dump(); } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 2b3358d24c78..61eb5c8b7819 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include @@ -12,6 +14,9 @@ #include namespace c10d { + +/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */ + inline std::string getTraceStartKey(const std::string& pgName, int rank) { return pgName + "_" + std::to_string(rank) + "_trace_start"; } @@ -256,4 +261,259 @@ inline std::string retrieveDesyncReport( return report; } +/* Trace Utils Related to Flight Recorder */ + +/* Note: this is only used by PGNCCL (could be generalized in an ideal world but + * wasn't done that way, so isn't expected to be fully general at the moment) */ + +DebugInfoWriter::DebugInfoWriter(int rank) { + std::string fileName = getCvarString( + {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); + filename_ = c10::str(fileName, rank); +} + +DebugInfoWriter::~DebugInfoWriter() = default; + +void DebugInfoWriter::write(const std::string& ncclTrace) { + // Open a file for writing. The ios::binary flag is used to write data as + // binary. + std::ofstream file(filename_, std::ios::binary); + + // Check if the file was opened successfully. + if (!file.is_open()) { + LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " + << filename_; + return; + } + + file.write(ncclTrace.data(), ncclTrace.size()); + LOG(INFO) << "Wrote finished "; +} + +inline std::string pickle_str(const c10::IValue& v) { + std::vector result; + { + auto writer = [&](const char* data, size_t size) { + result.insert(result.end(), data, data + size); + }; + torch::jit::Pickler pickler( + writer, nullptr, nullptr, nullptr, nullptr, false); + pickler.protocol(); + pickler.pushIValue(v); + pickler.stop(); + } + return std::string(result.begin(), result.end()); +} + +inline c10::Dict new_dict() { + return c10::Dict( + c10::AnyType::get(), c10::AnyType::get()); +} + +inline c10::List new_list() { + return c10::List(c10::AnyType::get()); +} + +struct NCCLTraceBuffer { + static NCCLTraceBuffer* get() { + // intentionally leak on exit + // because this will hold python state that may get destructed + static NCCLTraceBuffer* instance = new NCCLTraceBuffer(); + return instance; + } + NCCLTraceBuffer() { + max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); + enabled_ = max_entries_ > 0; + } + using EventList = std::vector; + struct Entry { + size_t id_; // incremented id in the trace buffer + // used to figure out where in the circular entries + // buffer this entry will be located to + // update state information + size_t pg_id_; + size_t seq_id_; // as tracked by the process group + const char* profiling_name_; + + std::shared_ptr traceback_; + // we borrow pointser to start_ and end_ so we can query the state + // on reporting. However, once the event is completed, the call + // to `complete` will clear these. + EventList *start_, *end_; + const char* state_ = "scheduled"; + + // size information for input/output tensors + c10::SmallVector input_dims_; + c10::SmallVector output_dims_; + c10::SmallVector sizes_; // flattened from inputs, outputs + }; + + bool enabled_ = false; + std::mutex mutex_; + std::vector entries_; + size_t max_entries_ = 0; + size_t next_ = 0; + size_t id_ = 0; + + c10::optional record( + size_t pg_id, + size_t seq_id, + const char* profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventList* start, + EventList* end) { + if (!enabled_) { + return c10::nullopt; + } + auto traceback = torch::CapturedTraceback::gather(true, true, true); + std::lock_guard guard(mutex_); + + auto te = Entry{ + id_, + pg_id, + seq_id, + profiling_name, + std::move(traceback), + std::move(start), + std::move(end)}; + + for (const auto& input : inputs) { + c10::IntArrayRef sizes = input.sizes(); + te.input_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + for (const auto& output : outputs) { + c10::IntArrayRef sizes = output.sizes(); + te.output_dims_.push_back(sizes.size()); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + if (entries_.size() < max_entries_) { + entries_.emplace_back(std::move(te)); + } else { + entries_[next_++] = std::move(te); + if (next_ == max_entries_) { + next_ = 0; + } + } + return id_++; + } + + std::vector dump_entries() { + std::lock_guard guard(mutex_); + std::vector result; + result.reserve(entries_.size()); + result.insert(result.end(), entries_.begin() + next_, entries_.end()); + result.insert(result.end(), entries_.begin(), entries_.begin() + next_); + // query any remaining events + for (auto& r : result) { + if (r.start_ != nullptr) { + bool started = true; + for (auto& ev : *r.start_) { + if (!ev.query()) { + started = false; + break; + } + } + if (started) { + r.state_ = "started"; + } + r.start_ = nullptr; + } + if (r.end_ != nullptr) { + bool completed = true; + for (auto& ev : *r.end_) { + if (!ev.query()) { + completed = false; + break; + } + } + if (completed) { + r.state_ = "completed"; + } + r.end_ = nullptr; + } + } + return result; + } + + void complete(c10::optional id) { + if (!enabled_ || !id) { + return; + } + std::lock_guard guard(mutex_); + auto& entry = entries_.at(*id % max_entries_); + if (entry.id_ == *id) { + entry.state_ = "completed"; + entry.start_ = entry.end_ = nullptr; + } + } + + std::string dump() { + auto result = dump_entries(); + auto entries = new_list(); + c10::IValue pg_id_s = "pg_id"; + c10::IValue seq_id_s = "seq_id"; + c10::IValue profiling_name_s = "profiling_name"; + c10::IValue input_sizes_s = "input_sizes"; + c10::IValue output_sizes_s = "output_sizes"; + + c10::IValue frames_s = "frames"; + c10::IValue state_s = "state"; + c10::IValue line_s = "line"; + c10::IValue name_s = "name"; + c10::IValue filename_s = "filename"; + + std::vector tracebacks; + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + torch::SymbolizedTracebacks stracebacks = torch::symbolize(tracebacks); + std::vector all_frames; + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_s, f.funcname); + d.insert(filename_s, f.filename); + d.insert(line_s, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } + + for (auto i : c10::irange(result.size())) { + auto& e = result.at(i); + auto& tb = stracebacks.tracebacks.at(i); + auto dict = new_dict(); + dict.insert(pg_id_s, int64_t(e.pg_id_)); + dict.insert(seq_id_s, int64_t(e.seq_id_)); + dict.insert(profiling_name_s, e.profiling_name_); + + auto it = e.sizes_.begin(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = new_list(); + for (auto dim : dims) { + auto arg_sizes = new_list(); + for (auto i : c10::irange(dim)) { + (void)i; + arg_sizes.push_back(*it++); + } + sizes.push_back(arg_sizes); + } + return sizes; + }; + + dict.insert(input_sizes_s, read_sizes(e.input_dims_)); + dict.insert(output_sizes_s, read_sizes(e.output_dims_)); + dict.insert(state_s, e.state_); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_s, frames); + entries.push_back(dict); + } + return pickle_str(entries); + } +}; + } // namespace c10d From 34326e43eb7206f376a380125b24f473666ef1ae Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 23 Nov 2023 01:57:35 +0000 Subject: [PATCH 135/221] [DTensor] Made `DTensorSpec` hash recomputation lazy (#114379) If we assign `spec.tensor_meta = ...`, we do not have to recompute the hash eagerly. We just need to clear the existing hash so that the next call to `__hash__` recomputes it. We found that the breakage of the DTensor + `torch.compile` tests comes from https://github.com/pytorch/pytorch/pull/114236 and are not directly related to the `DTensorSpec` hashing changes. We fix that in the following PR temporarily by passing `dynamic=False`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114379 Approved by: https://github.com/wanchaol --- torch/distributed/_tensor/placement_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index e4153565fe03..8b7902ada026 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -395,7 +395,7 @@ def __setattr__(self, attr: str, value: Any): # Make sure to recompute the hash in case any of the hashed attributes # change (though we do not expect `mesh` or `placements` to change) if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): - self._hash = self._hash_impl() + self._hash = None def _hash_impl(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding From ed05af278caa47db55b4fea878341f683acb8767 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 23 Nov 2023 01:57:35 +0000 Subject: [PATCH 136/221] [DTensor] Passed `dynamic=False` for compile tests (#114390) Test Plan: ``` python test/distributed/_tensor/test_dtensor_compile.py ``` We found that after https://github.com/pytorch/pytorch/pull/114236 landed, DTensor + `torch.compile` tests were breaking (which was confounded with `DTensorSpec` hash changes). The temporary solution is to pass `dynamic=False`. Otherwise, we see errors like:
``` ====================================================================== ERROR: test_2d_fsdp_tp_ac_compile (__main__.TestDTensorCompileE2E) ---------------------------------------------------------------------- Traceback (most recent call last): File "/data/users/andgu/pytorch/torch/testing/_internal/common_distributed.py", line 533, in wrapper self._join_processes(fn) File "/data/users/andgu/pytorch/torch/testing/_internal/common_distributed.py", line 752, in _join_processes self._check_return_codes(elapsed_time) File "/data/users/andgu/pytorch/torch/testing/_internal/common_distributed.py", line 802, in _check_return_codes raise RuntimeError(error) RuntimeError: Process 2 exited with error code 10 and exception: Traceback (most recent call last): File "/data/users/andgu/pytorch/torch/testing/_internal/common_distributed.py", line 649, in run_test getattr(self, test_name)() File "/data/users/andgu/pytorch/torch/testing/_internal/common_distributed.py", line 535, in wrapper fn() File "/data/users/andgu/pytorch/torch/testing/_internal/common_utils.py", line 2652, in wrapper method(*args, **kwargs) File "/data/users/andgu/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 193, in wrapper func(self, *args, **kwargs) # type: ignore[misc] File "/data/users/andgu/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper return func(*args, **kwargs) File "/data/users/andgu/pytorch/test/distributed/_tensor/test_dtensor_compile.py", line 328, in test_2d_fsdp_tp_ac_compile compiled_output = compiled_2d(inp) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/data/users/andgu/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 848, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/eval_frame.py", line 655, in catch_errors return callback(frame, cache_entry, hooks, frame_state) File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 721, in _convert_frame result = inner_convert(frame, cache_entry, hooks, frame_state) File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert compiled_product = _compile( File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 645, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper r = func(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 562, in compile_inner out_code = transform_code_object(code, transform) File "/data/users/andgu/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object transformations(instructions, code_options) File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 151, in _fn return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 527, in transform tracer.run() File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 2123, in run super().run() File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run and self.step() File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step getattr(self, inst.opname)(inst) File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 2238, in RETURN_VALUE self.output.compile_subgraph( File "/data/users/andgu/pytorch/torch/_dynamo/output_graph.py", line 912, in compile_subgraph self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) File "/home/andgu/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/data/users/andgu/pytorch/torch/_dynamo/output_graph.py", line 1069, in compile_and_call_fx_graph compiled_fn = self.call_user_compiler(gm) File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper r = func(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/output_graph.py", line 1141, in call_user_compiler raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( File "/data/users/andgu/pytorch/torch/_dynamo/output_graph.py", line 1122, in call_user_compiler compiled_fn = compiler_fn(gm, self.example_inputs()) File "/data/users/andgu/pytorch/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper compiled_gm = compiler_fn(gm, example_inputs) File "/data/users/andgu/pytorch/torch/__init__.py", line 1696, in __call__ return self.compiler_fn(model_, inputs_, **self.kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/backends/common.py", line 55, in compiler_fn cg = aot_module_simplified(gm, example_inputs, **kwargs) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 4946, in aot_module_simplified compiled_fn = create_aot_dispatcher_function( File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper r = func(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 4486, in create_aot_dispatcher_function compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 2825, in aot_wrapper_dedupe return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 3011, in aot_wrapper_synthetic_base return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 3714, in aot_dispatch_autograd fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 3694, in aot_dispatch_autograd_graph fx_g = create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 1955, in create_graph fx_g = make_fx(f, decomposition_table=aot_config.decompositions)(*args) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 869, in wrapped t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs)) File "/data/users/andgu/pytorch/torch/_compile.py", line 24, in inner return torch._dynamo.disable(fn, recursive)(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 481, in dispatch_trace graph = tracer.trace(root, concrete_args) File "/data/users/andgu/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/fx/_symbolic_trace.py", line 821, in trace (self.create_arg(fn(*args)),), File "/data/users/andgu/pytorch/torch/fx/_symbolic_trace.py", line 688, in flatten_fn tree_out = root_fn(*tree_args) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 517, in wrapped out = f(*tensors) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 3607, in joint_fn return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 3591, in inner_fn wrapped_outs = fn(*all_args) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 1941, in joint_helper return functionalized_f_helper(primals, tangents) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 1894, in functionalized_f_helper f_outs = fn(*f_args) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 1862, in inner_fn_with_anomaly return inner_fn(*args) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 1796, in inner_fn outs, tangent_mask = fn(*primals) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 1724, in inner_fn outs = fn(*args_maybe_cloned) File "/data/users/andgu/pytorch/torch/_functorch/aot_autograd.py", line 4552, in functional_call out = Interpreter(mod).run(*args[params_len:], **kwargs) File "/data/users/andgu/pytorch/torch/fx/interpreter.py", line 138, in run self.env[node] = self.run_node(node) File "/data/users/andgu/pytorch/torch/fx/interpreter.py", line 195, in run_node return getattr(self, n.op)(n.target, args, kwargs) File "/data/users/andgu/pytorch/torch/fx/interpreter.py", line 267, in call_function return target(*args, **kwargs) File "/data/users/andgu/pytorch/torch/distributed/_tensor/api.py", line 280, in __torch_dispatch__ return DTensor._op_dispatcher.dispatch( File "/data/users/andgu/pytorch/torch/distributed/_tensor/dispatch.py", line 106, in dispatch self.sharding_propagator.propagate(op_info) File "/data/users/andgu/pytorch/torch/distributed/_tensor/sharding_prop.py", line 161, in propagate output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) File "/data/users/andgu/pytorch/torch/distributed/_tensor/sharding_prop.py", line 175, in propagate_op_sharding_non_cached out_tensor_meta = self._propagate_tensor_meta(op_schema) File "/data/users/andgu/pytorch/torch/distributed/_tensor/sharding_prop.py", line 85, in _propagate_tensor_meta fake_args = op_schema.gen_fake_args() File "/data/users/andgu/pytorch/torch/distributed/_tensor/op_schema.py", line 332, in gen_fake_args return tree_map_only( File "/data/users/andgu/pytorch/torch/utils/_cxx_pytree.py", line 765, in tree_map_only return tree_map( File "/data/users/andgu/pytorch/torch/utils/_cxx_pytree.py", line 607, in tree_map return optree.tree_map( File "/home/andgu/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/optree/ops.py", line 473, in tree_map return treespec.unflatten(flat_results) File "/data/users/andgu/pytorch/torch/utils/_cxx_pytree.py", line 713, in wrapped return func(x) File "/data/users/andgu/pytorch/torch/distributed/_tensor/op_schema.py", line 31, in _rebuild_tensor_from_dtensor_meta return torch.empty_strided( File "/data/users/andgu/pytorch/torch/_subclasses/functional_tensor.py", line 297, in __torch_dispatch__ outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped) File "/data/users/andgu/pytorch/torch/_ops.py", line 509, in __call__ return self._op(*args, **kwargs or {}) File "/data/users/andgu/pytorch/torch/utils/_stats.py", line 20, in wrapper return fn(*args, **kwargs) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 594, in __torch_dispatch__ return self.inner_torch_dispatch(func, types, args, kwargs) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 629, in inner_torch_dispatch return proxy_call(self, func, self.pre_dispatch, args, kwargs) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 317, in proxy_call proxy_args, proxy_kwargs = pytree.tree_map_only( File "/data/users/andgu/pytorch/torch/utils/_pytree.py", line 631, in tree_map_only return tree_map(map_only(__type_or_types)(func), tree) File "/data/users/andgu/pytorch/torch/utils/_pytree.py", line 523, in tree_map return tree_unflatten([func(i) for i in flat_args], spec) File "/data/users/andgu/pytorch/torch/utils/_pytree.py", line 523, in return tree_unflatten([func(i) for i in flat_args], spec) File "/data/users/andgu/pytorch/torch/utils/_pytree.py", line 591, in wrapped return func(x) File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 247, in inner return get_proxy_slot(n, tracer)() File "/data/users/andgu/pytorch/torch/fx/experimental/proxy_tensor.py", line 110, in get_proxy_slot raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}") torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised: RuntimeError: s0 is not tracked with proxy for While executing %result_2 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%prim_redistribute_2, %l_self_mlp_0_net2_weight, %l_self_mlp_0_net2_bias), kwargs = {}) Original traceback: File "/data/users/andgu/pytorch/test/distributed/_tensor/test_dtensor_compile.py", line 51, in forward return self.mlp_1(self.mlp_0(input)) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/data/users/andgu/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 64, in forward return self.net2(self.relu(self.net1(x))) File "/data/users/andgu/pytorch/torch/nn/modules/module.py", line 1561, in _call_impl result = forward_call(*args, **kwargs) File "/data/users/andgu/pytorch/torch/nn/modules/linear.py", line 116, in forward return F.linear(input, self.weight, self.bias) ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114390 Approved by: https://github.com/wanchaol, https://github.com/huydhn ghstack dependencies: #114379 --- .../_tensor/test_dtensor_compile.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index 3176947d239e..ee56000a02d7 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -102,7 +102,7 @@ def fn(x): x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False) ref = fn(x) - opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=False) res = opt_fn(x) self.assertEqual(res, ref) @@ -116,7 +116,7 @@ def fn(x): x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False) ref = fn(x) - opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=False) res = opt_fn(x) self.assertEqual(res, ref) @@ -141,7 +141,7 @@ def fn(x): x = torch.ones(1) ref = fn(x) - opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=False) res = opt_fn(x) self.assertEqual(res, ref) @@ -154,7 +154,7 @@ def from_local_kwargs_fn(x): ref = from_local_kwargs_fn(x) opt_kwargs_fn = torch.compile( - from_local_kwargs_fn, backend="aot_eager", fullgraph=True + from_local_kwargs_fn, backend="aot_eager", fullgraph=True, dynamic=False ) res = opt_kwargs_fn(x) self.assertEqual(res, ref) @@ -170,7 +170,7 @@ def fn(x): x = torch.ones(1) ref = fn(x) - opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=False) res = opt_fn(x) self.assertEqual(res, ref) @@ -184,7 +184,7 @@ def redistribute_kwargs_fn(x): x = torch.ones(1) ref = redistribute_kwargs_fn(x) opt_kwargs_fn = torch.compile( - redistribute_kwargs_fn, backend="aot_eager", fullgraph=True + redistribute_kwargs_fn, backend="aot_eager", fullgraph=True, dynamic=False ) res = opt_kwargs_fn(x) self.assertEqual(res, ref) @@ -232,7 +232,9 @@ def test_tp_compile_fullgraph(self, is_seq_parallel): torch.manual_seed(rng_seed) inp = torch.rand(20, 10, device=self.device_type) out = model(inp) - compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True) + compiled_mod = torch.compile( + model, backend="aot_eager", fullgraph=True, dynamic=False + ) compiled_out = compiled_mod(inp) self.assertEqual(compiled_out, out) @@ -280,7 +282,7 @@ def test_2d_fsdp_tp_compile(self): ) # TODO: once aot autograd support is ready we can just use default backend - compiled_2d = torch.compile(fsdp_2d, backend="aot_eager") + compiled_2d = torch.compile(fsdp_2d, backend="aot_eager", dynamic=False) compiled_output = compiled_2d(inp) self.assertEqual(out, compiled_output) @@ -321,7 +323,7 @@ def test_2d_fsdp_tp_ac_compile(self): use_orig_params=True, ) # TODO: once aot autograd support is ready we can just use default backend - compiled_2d = torch.compile(fsdp_2d, backend="aot_eager") + compiled_2d = torch.compile(fsdp_2d, backend="aot_eager", dynamic=False) # forward pass out = eager_2d(inp) @@ -346,9 +348,11 @@ def fn(x, y): dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False) dt_out = torch.matmul(dt, dt2) dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()]) - return dt_out.to_local() + return dt_out_redistribute.to_local() - opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True) + opt_fn = torch.compile( + fn, backend=aot_eager_graph, fullgraph=True, dynamic=False + ) x_ref = torch.arange(8, requires_grad=True, dtype=torch.float32) y_ref = torch.arange(8, requires_grad=True, dtype=torch.float32) From 066e072524718f376c8eed36b638683eaf628121 Mon Sep 17 00:00:00 2001 From: Chip Turner Date: Thu, 23 Nov 2023 06:59:57 +0000 Subject: [PATCH 137/221] Retry #112889 (Opportunistically use ncclCommSplit when creating new NCCL groups) (#114385) - [c10d] (retry) Opportunistically use `ncclCommSplit` when creating new NCCL groups (#112889) - Guard use of `split_from` with a `hasattr` check for cases when NCCL (or RCCL) lacks `ncclCommSplit` Fixes cause of revert of original PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/114385 Approved by: https://github.com/huydhn --- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 78 ++++++++++++++++--- test/distributed/test_c10d_nccl.py | 22 +++++- torch/csrc/distributed/c10d/NCCLUtils.hpp | 26 +++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 47 +++++++++-- .../distributed/c10d/ProcessGroupNCCL.hpp | 9 +++ torch/csrc/distributed/c10d/init.cpp | 13 ++-- torch/distributed/distributed_c10d.py | 35 ++++++++- 7 files changed, 208 insertions(+), 22 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 61e9753988ea..6a0d60b57315 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -31,12 +31,20 @@ class NCCLTestBase { pg_ = std::move(other.pg_); } - ::c10d::ProcessGroupNCCL& getProcessGroup() { - return *pg_; + std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { + return pg_; } - void initialize(int rank, int size) { - auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); + ::c10::intrusive_ptr<::c10d::Store>& getProcessGroupStore() { + return store_; + } + + void initialize( + int rank, + int size, + c10::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from = + c10::nullopt) { + store_ = c10::make_intrusive<::c10d::FileStore>(path_, size); c10::intrusive_ptr opts = c10::make_intrusive(); @@ -45,14 +53,22 @@ class NCCLTestBase { c10d::TORCH_ENABLE_NCCL_HEALTH_CHECK[0].c_str(), "1", /* overwrite */ 1); +#ifdef NCCL_HAS_COMM_SPLIT + if (split_from) { + opts->split_from = *split_from; + opts->split_color = ++color_; + } +#endif pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( - new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts))); + new ::c10d::ProcessGroupNCCL(store_, rank, size, std::move(opts))); } protected: std::string path_; - std::unique_ptr<::c10d::ProcessGroupNCCL> pg_; + std::shared_ptr<::c10d::ProcessGroupNCCL> pg_; std::chrono::milliseconds pgTimeout_; + ::c10::intrusive_ptr<::c10d::Store> store_; + int color_{1}; }; class NCCLTest : public NCCLTestBase { @@ -718,9 +734,9 @@ void testSequenceNumInit( auto runTest = [&](int i) { NCCLTest test(path, worldSize); test.initialize(i, worldSize); - test.getProcessGroup().setSequenceNumberForGroup(); + test.getProcessGroup()->setSequenceNumberForGroup(); std::lock_guard lock(m); - auto seqNum = test.getProcessGroup().getSequenceNumberForGroup(); + auto seqNum = test.getProcessGroup()->getSequenceNumberForGroup(); nums.insert(seqNum); }; std::vector threads; @@ -877,11 +893,55 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) { auto test = NCCLTestBase(file.path); test.initialize(rank_, size_); EXPECT_EQ( - test.getProcessGroup().getBackendName(), + test.getProcessGroup()->getBackendName(), std::string(c10d::NCCL_BACKEND_NAME)); } } +TEST_F(ProcessGroupNCCLTest, testSplittingCommunicator) { + if (skipTest()) { + return; + } + TemporaryFile file; + auto test1 = BroadcastNCCLTest(file.path, size_); + test1.initialize(rank_, size_); + + auto test2 = BroadcastNCCLTest(file.path, size_); + test2.initialize(rank_, size_, test1.getProcessGroup()); + + // Steal the broadcast test and issue it for both of our groups. + // This ensures consistent full collective communication. TODO: + // maybe refactor the guts rather than copy-pasta, but it may not be + // worth it. + for (auto test : {&test1, &test2}) { + const int numDevices = test->numDevices(); + // try every permutation of root rank and root tensor + for (const auto rootRank : c10::irange(size_)) { + for (const auto rootTensor : c10::irange(numDevices)) { + auto work = test->run(rootRank, rootTensor); + test->wait(work); + + // Check results + const auto expected = (rootRank * numDevices + rootTensor); + const auto tensors = test->getTensors(); + for (const auto& tensor : tensors) { + const auto* const data = tensor.data_ptr(); + for (const auto k : c10::irange(tensor.numel())) { + EXPECT_EQ(data[k], expected) + << "Broadcast outputs do not match expected outputs"; + } + } + } + } + } + + // Now that we've run full operations on both the original and split process + // group, ensure we saw exactly as many splits as we expected: 0 in the + // original process group, and one per device in the second. + EXPECT_EQ(test2.getProcessGroup()->getCommSplitCounter(), 0); + EXPECT_EQ(test1.getProcessGroup()->getCommSplitCounter(), test1.numDevices()); +} + #ifdef IS_NCCL_EXP TEST_F(ProcessGroupNCCLTest, testSparseAllreduce) { if (skipTest()) { diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index ada84507aef9..f330fb0b08f7 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1272,6 +1272,27 @@ def allgather_base(output_t, input_t): # Verification self.assertEqual(torch.arange(self.world_size), output_t) + @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") + def test_comm_split_optimization(self): + store = c10d.FileStore(self.file_name, self.world_size) + pg = self._create_process_group_nccl(store, self.opts()) + + # Test lazy splitting behavior across each per-device backend. + for device in self.rank_to_GPU[self.rank]: + backend = pg._get_backend(torch.device(device)) + + # split doesn't happen unless the original process group has lazily + # created communicators, so first verify we haven't split even when + # making the new group and running an operation on the original pg. + ng = c10d.new_group() + tensor = torch.tensor([self.rank]).cuda(device) + pg.broadcast(tensor, 0) + self.assertEqual(backend.comm_split_count(), 0) + + # The new group will force a split of the original on first use. + ng.broadcast(tensor, 0) + self.assertEqual(backend.comm_split_count(), 1) + class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase ): @@ -3676,7 +3697,6 @@ def gather_trace(): - if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index e6c05e228cfd..2b4885f02ffc 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -17,6 +17,11 @@ #define NCCL_HAS_COMM_NONBLOCKING #endif +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 18) +#define NCCL_HAS_COMM_SPLIT +#endif + // ncclGetLastError() is enabled only for NCCL versions 2.13+ // ncclRemoteError only exists in NCCL versions 2.13+ #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -246,6 +251,22 @@ class NCCLComm { } #endif +#ifdef NCCL_HAS_COMM_SPLIT + static std::shared_ptr split( + NCCLComm* source, + int color_id, + int rank, + ncclConfig_t& config) { + auto comm = std::make_shared(); + C10D_NCCL_CHECK( + ncclCommSplit( + source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), + c10::nullopt); + ++source->ncclCommSplitCounter_; + return comm; + } +#endif + ncclUniqueId getNcclId() { return ncclId_; } @@ -325,6 +346,10 @@ class NCCLComm { return aborted_; } + uint64_t getCommSplitCounter() const { + return ncclCommSplitCounter_; + } + ncclResult_t checkForNcclError() { std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING @@ -401,6 +426,7 @@ class NCCLComm { // Unique nccl_id for this communicator. ncclUniqueId ncclId_; bool aborted_; + uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_; mutable std::mutex mutex_; // Rank that this communicator corresponds to. diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index e816b6fc193d..50135339093f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1620,11 +1620,40 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( int deviceIndex = devices[i].index(); gpuGuard.set_index(deviceIndex); +#ifdef NCCL_HAS_COMM_SPLIT + if (options_->split_from) { + TORCH_CHECK( + options_->split_color != 0, + "Must specify a non-zero color when splitting"); + // Find a valid, healthy communicator to split from if possible. + std::lock_guard lock(options_->split_from->mutex_); + auto& other_comms = options_->split_from->devNCCLCommMap_; + auto dit = other_comms.find(devicesKey); + if (dit != other_comms.end() && !dit->second.empty()) { + TORCH_INTERNAL_ASSERT( + dit->second.size() == ncclComms.size(), + "split_from->devNCCLCommMap_ should be empty or the same size as ncclComms!"); + if (dit->second[i] && !dit->second[i]->isAborted()) { + ncclComms[i] = NCCLComm::split( + dit->second[i].get(), + options_->split_color, + rank, + options_->config); + } + } + } +#endif + + // To simplify conditioonal nesting, just create the ncclComms[i] + // entry if it hasn't been yet rather than untangling the + // conditions that might have resulted in a split above. + if (!ncclComms[i]) { #ifdef NCCL_HAS_COMM_NONBLOCKING - ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config); + ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config); #else - ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID); + ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID); #endif + } // Creates the NCCL streams streamVal.push_back( @@ -1670,9 +1699,6 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( std::make_tuple(devicesKey), std::make_tuple(devices.size())); - // Hold the lock before modifying the cache. - std::lock_guard lock(mutex_); - // Record the communicators based on ncclUniqueId. ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms); @@ -1716,9 +1742,20 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( it = devNCCLCommMap_.find(devicesKey); TORCH_INTERNAL_ASSERT( it != devNCCLCommMap_.end(), "Communicators not populated in cache!"); + return it->second; } +uint64_t ProcessGroupNCCL::getCommSplitCounter() const { + uint64_t ret = 0; + for (const auto& i : ncclIdToCommMap_) { + for (const auto& j : i.second) { + ret += j->getCommSplitCounter(); + } + } + return ret; +} + namespace { // Check validity of tensor diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index b47983cf0e4b..cca5f9276944 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -341,6 +341,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Configure ranks ncclConfig_t config = NCCL_CONFIG_INITIALIZER; #endif + + // Optional "parent" backend and color to create communicators from + // via `ncclCommSplit` + std::shared_ptr split_from; + int64_t split_color{0}; }; // If you wish to create multiple process groups, each with a potentially @@ -509,6 +514,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { // may indicate that there is some sort of collective desynchronization. uint64_t getSequenceNumberForGroup() override; + // Return the total number of splits the communicators held by this process + // group have performed. + uint64_t getCommSplitCounter() const; + void registerOnCompletionHook( std::function)>&& hook) override; void waitForPendingWorks() override; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 909773dfe47e..7d23ad1b6479 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2290,6 +2290,9 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::call_guard()) .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart) .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd) + .def( + "comm_split_count", + &::c10d::ProcessGroupNCCL::getCommSplitCounter) .def_property_readonly( "options", &::c10d::ProcessGroupNCCL::getOptions); @@ -2352,15 +2355,15 @@ Example:: )") .def(py::init(), py::arg("is_high_priority_stream") = false) #ifdef NCCL_HAS_COMM_CTA_CGA + .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config) +#endif .def_readwrite( "is_high_priority_stream", &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream) - .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config); -#else .def_readwrite( - "is_high_priority_stream", - &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream); -#endif + "split_from", &::c10d::ProcessGroupNCCL::Options::split_from) + .def_readwrite( + "split_color", &::c10d::ProcessGroupNCCL::Options::split_color); #endif diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 63f6c48d35f3..3bd35709505d 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -8,6 +8,7 @@ import logging import os import pickle +import sys import time import warnings from collections import namedtuple @@ -1314,7 +1315,29 @@ def _new_process_group_helper( pg_options.is_high_priority_stream = False pg_options._timeout = timeout - backend_class = ProcessGroupNCCL(backend_prefix_store, group_rank, group_size, pg_options) + # If our new group includes all ranks, we can reduce + # overhead by splitting the communicator (`nccCommSplit`). + + # TODO: support this in the general case by calling + # `nccCommSplit` with `NCCL_SPLIT_NOCOLOR` for the ranks + # not in the communicator. + split_from = None + if ( + is_initialized() + and _world.default_pg._get_backend_name() == Backend.NCCL + and len(global_ranks_in_group) == _world.default_pg.size() + ): + # If possible, find a backend to split from by peeling + # process group wrappers from the world's default pg. + split_from = _world.default_pg._get_backend(_get_pg_default_device()) + while isinstance(split_from, _ProcessGroupWrapper): + split_from = split_from.wrapped_pg + + if split_from: + pg_options.split_from = split_from + pg_options.split_color = _process_group_color(global_ranks_in_group) + backend_class = ProcessGroupNCCL( + backend_prefix_store, group_rank, group_size, pg_options) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): # TODO: once UCC plugin is fully deprecated, remove @@ -3514,11 +3537,19 @@ def _create_process_group_wrapper( wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) return wrapped_pg +# helper function for deterministically hashing a list of ranks +def _hash_ranks(ranks: List[int]): + return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() + +# Takes a list of ranks and computes an integer color +def _process_group_color(ranks: List[int]) -> int: + # Convert our hash to an int, but avoid negative numbers by shifting a bit. + return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) def _process_group_name(ranks, use_hashed_name): global _world if use_hashed_name: - pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() + pg_name = _hash_ranks(ranks) while pg_name in _world.pg_names.values(): pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest() else: From a43edd836ca8c1d3b65ed0122439ef14193a8b20 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 23 Nov 2023 07:07:56 +0000 Subject: [PATCH 138/221] Revert "Add support for models with mutated buffer on torch.onnx.dynamo_export (#112272)" This reverts commit c4a22d6918b7ca218f2712d7e7e147aca7127fa3. Reverted https://github.com/pytorch/pytorch/pull/112272 on behalf of https://github.com/huydhn due to Sorry for reverting you change but it is failing dynamo test in trunk https://hud.pytorch.org/pytorch/pytorch/commit/c4a22d6918b7ca218f2712d7e7e147aca7127fa3 ([comment](https://github.com/pytorch/pytorch/pull/112272#issuecomment-1823897964)) --- test/onnx/onnx_test_common.py | 10 +++--- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 31 ---------------- .../fx/torch_export_graph_extractor.py | 4 --- torch/onnx/_internal/io_adapter.py | 36 ------------------- 4 files changed, 4 insertions(+), 77 deletions(-) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 0be906fada77..f664e7e84a42 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -436,18 +436,16 @@ def _compare_pytorch_onnx_with_ort( ref_input_args = input_args ref_input_kwargs = input_kwargs - # ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. - # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. - # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( + ref_model(*ref_input_args, **ref_input_kwargs) + ) + ort_outputs = onnx_program(*input_args, **input_kwargs) - ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs) - ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs) if len(ref_outputs) != len(ort_outputs): raise AssertionError( f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}" ) - for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close( ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index dcede4718a8c..26fa6f215bec 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -942,37 +942,6 @@ def forward(self, x): loaded_exported_program, (x,), skip_dynamic_shapes_check=True ) - @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( - "Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. " - "github issue: https://github.com/pytorch/pytorch/issues/114406" - ) - def test_exported_program_as_input_lifting_buffers_mutation(self): - for persistent in (True, False): - - class CustomModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer( - "my_buffer", torch.tensor(4.0), persistent=persistent - ) - - def forward(self, x, b): - output = x + b - ( - self.my_buffer.add_(1.0) + 3.0 - ) # Mutate buffer through in-place addition - return output - - inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3)) - model = CustomModule() - self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( - model, inputs, skip_dynamic_shapes_check=True - ) - # Buffer will be mutated after the first iteration - self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( - model, inputs, skip_dynamic_shapes_check=True - ) - def _parameterized_class_attrs_and_values_with_fake_options(): input_values = [] diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py index 5f1fbb5c7481..51c31560b144 100644 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py @@ -63,10 +63,6 @@ def generate_fx( # tensor, etc), we flatten the collection and register each element as output. options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) - options.fx_tracer.output_adapter.append_step( - io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model) - ) - # Export FX graph to ONNX ModelProto. return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value] diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 28db50a5b58a..45134505000f 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -550,39 +550,3 @@ def apply( if model_kwargs: return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs) return updated_args, {} - - -class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): - """Prepend model's mutated buffers to the user output. - - :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they - must be added to the user output after the model is executed. - - Args: - model: The PyTorch model with mutated buffers. - """ - - def __init__(self, model: torch_export.ExportedProgram): - assert isinstance( - model, torch_export.ExportedProgram - ), "'model' must be a torch.export.ExportedProgram." - self.model = model - - def apply(self, model_outputs: Any) -> Sequence[Any]: - """Flatten the model outputs and validate the `SpecTree` output. - - Args: - model_outputs: The model outputs to flatten. - - Returns: - flattened_outputs: The flattened model outputs. - """ - - ordered_buffers = tuple( - self.model.state_dict[name] - for name in self.model.graph_signature.buffers_to_mutate.values() - ) - - # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. - updated_outputs = (*ordered_buffers, *model_outputs) - return updated_outputs From 6f340c6f3089a6acca599ab58d140a8b27ed2c44 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 23 Nov 2023 07:32:46 +0000 Subject: [PATCH 139/221] Handle the case when opening a reverted PR with deleted head branch (#114423) When reopening a reverted PR, `422: Unprocessable Entity` is returned when the head branch has been deleted, for example https://github.com/pytorch/pytorch/pull/112889#issuecomment-1823216686 ``` { "message": "Validation Failed", "errors": [ { "resource": "PullRequest", "code": "custom", "field": "state", "message": "state cannot be changed. The commsplit branch has been deleted." } ], "documentation_url": "https://docs.github.com/rest/pulls/pulls#update-a-pull-request" } ``` The revert still happens though, only reopening PR fails, which is ok to ignore in this case I think instead of going the complicated route of trying to restore the deleted branch by merge bot. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114423 Approved by: https://github.com/malfet, https://github.com/kit1980 --- .github/scripts/github_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index ff98546b385f..05b95fc91664 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -178,4 +178,14 @@ def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str: def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> None: url = f"{GITHUB_API_URL}/repos/{org}/{repo}/pulls/{pr_num}" - gh_fetch_url(url, method="PATCH", data={"state": state}) + try: + gh_fetch_url(url, method="PATCH", data={"state": state}) + except HTTPError as err: + # When trying to open the pull request, error 422 means that the branch + # has been deleted and the API couldn't re-open it + if err.code == 422 and state == "open": + warnings.warn( + f"Failed to open {pr_num} because its head branch has been deleted: {err}" + ) + else: + raise From 7daeb6509fe021462063d98c544339fc469075df Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 23 Nov 2023 07:36:55 +0000 Subject: [PATCH 140/221] Update audio pinned commit nightly (#114426) I think we could have this pinned commit being update nightly like what we have with vision. This will avoid having an outdated audio pinned commit that needs to be updated manually, i.e. https://github.com/pytorch/pytorch/pull/114393 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114426 Approved by: https://github.com/atalman, https://github.com/seemethere, https://github.com/malfet --- .github/merge_rules.yaml | 2 ++ .github/workflows/nightly.yml | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index f7d62cfdd6b6..8a5dfeef3284 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -74,6 +74,7 @@ - name: OSS CI / pytorchbot patterns: + - .github/ci_commit_pins/audio.txt - .github/ci_commit_pins/vision.txt - .github/ci_commit_pins/torchdynamo.txt - .ci/docker/ci_commit_pins/triton.txt @@ -84,6 +85,7 @@ - EasyCLA - Lint - pull + - inductor - name: OSS CI /pytorchbot / Executorch patterns: diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index f338f844af33..76c38c032f57 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -53,6 +53,23 @@ jobs: updatebot-token: ${{ secrets.UPDATEBOT_TOKEN }} pytorchbot-token: ${{ secrets.GH_PYTORCHBOT_TOKEN }} + update-audio-commit-hash: + runs-on: ubuntu-latest + environment: update-commit-hash + steps: + - name: Checkout repo + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: update-audio-commit-hash + uses: ./.github/actions/update-commit-hash + if: ${{ github.event_name == 'schedule' }} + with: + repo-name: audio + branch: main + updatebot-token: ${{ secrets.UPDATEBOT_TOKEN }} + pytorchbot-token: ${{ secrets.GH_PYTORCHBOT_TOKEN }} + update-executorch-commit-hash: runs-on: ubuntu-latest environment: update-commit-hash From a76bb5d84d017cc2cd8b95e13b4ea2637457bcc3 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 22 Nov 2023 22:13:48 +0000 Subject: [PATCH 141/221] Add support for models with mutated buffer on torch.onnx.dynamo_export (#112272) This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model. The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model. However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output. This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272 Approved by: https://github.com/titaiwangms, https://github.com/BowenBao --- test/onnx/onnx_test_common.py | 10 +++--- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 31 ++++++++++++++++ .../fx/torch_export_graph_extractor.py | 4 +++ torch/onnx/_internal/io_adapter.py | 36 +++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index f664e7e84a42..0be906fada77 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -436,16 +436,18 @@ def _compare_pytorch_onnx_with_ort( ref_input_args = input_args ref_input_kwargs = input_kwargs - ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( - ref_model(*ref_input_args, **ref_input_kwargs) - ) - + # ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. + # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. + # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() ort_outputs = onnx_program(*input_args, **input_kwargs) + ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs) + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs) if len(ref_outputs) != len(ort_outputs): raise AssertionError( f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}" ) + for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close( ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 26fa6f215bec..dcede4718a8c 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -942,6 +942,37 @@ def forward(self, x): loaded_exported_program, (x,), skip_dynamic_shapes_check=True ) + @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( + "Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. " + "github issue: https://github.com/pytorch/pytorch/issues/114406" + ) + def test_exported_program_as_input_lifting_buffers_mutation(self): + for persistent in (True, False): + + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "my_buffer", torch.tensor(4.0), persistent=persistent + ) + + def forward(self, x, b): + output = x + b + ( + self.my_buffer.add_(1.0) + 3.0 + ) # Mutate buffer through in-place addition + return output + + inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3)) + model = CustomModule() + self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( + model, inputs, skip_dynamic_shapes_check=True + ) + # Buffer will be mutated after the first iteration + self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( + model, inputs, skip_dynamic_shapes_check=True + ) + def _parameterized_class_attrs_and_values_with_fake_options(): input_values = [] diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py index 51c31560b144..5f1fbb5c7481 100644 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py @@ -63,6 +63,10 @@ def generate_fx( # tensor, etc), we flatten the collection and register each element as output. options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) + options.fx_tracer.output_adapter.append_step( + io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model) + ) + # Export FX graph to ONNX ModelProto. return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value] diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 45134505000f..28db50a5b58a 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -550,3 +550,39 @@ def apply( if model_kwargs: return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs) return updated_args, {} + + +class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): + """Prepend model's mutated buffers to the user output. + + :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they + must be added to the user output after the model is executed. + + Args: + model: The PyTorch model with mutated buffers. + """ + + def __init__(self, model: torch_export.ExportedProgram): + assert isinstance( + model, torch_export.ExportedProgram + ), "'model' must be a torch.export.ExportedProgram." + self.model = model + + def apply(self, model_outputs: Any) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + + Returns: + flattened_outputs: The flattened model outputs. + """ + + ordered_buffers = tuple( + self.model.state_dict[name] + for name in self.model.graph_signature.buffers_to_mutate.values() + ) + + # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. + updated_outputs = (*ordered_buffers, *model_outputs) + return updated_outputs From 01366efcc9d2e6f42d2783a8436a5357895140f4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 23 Nov 2023 09:59:29 +0000 Subject: [PATCH 142/221] Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)" This reverts commit 4e4a6ad6ecd71a1aefde3992ecf7f77e37d2e264. Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658)) --- test/export/test_export.py | 15 +- test/test_fx.py | 2 +- test/test_pytree.py | 64 +----- torch/_export/utils.py | 4 +- torch/_functorch/aot_autograd.py | 15 +- torch/fx/experimental/proxy_tensor.py | 2 +- torch/fx/immutable_collections.py | 6 +- .../_internal/fx/dynamo_graph_extractor.py | 13 +- torch/return_types.py | 2 +- torch/utils/_cxx_pytree.py | 201 +----------------- torch/utils/_pytree.py | 136 +++--------- 11 files changed, 65 insertions(+), 395 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index b1ae7c6b5b8e..221ea9ba075b 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -623,23 +623,16 @@ class MyDataClass: roundtrip_spec = treespec_loads(treespec_dumps(spec)) self.assertEqual(roundtrip_spec, spec) - @dataclass - class MyOtherDataClass: # the pytree registration don't allow registering the same class twice - x: int - y: int - z: int = None - # Override the registration with keep none fields - register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass") + register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass") - dt = MyOtherDataClass(x=3, y=4) flat, spec = tree_flatten(dt) self.assertEqual( spec, TreeSpec( - MyOtherDataClass, + MyDataClass, ( - MyOtherDataClass, + MyDataClass, ['x', 'y', 'z'], [], ), @@ -649,7 +642,7 @@ class MyOtherDataClass: # the pytree registration don't allow registering the s self.assertEqual(flat, [3, 4, None]) orig_dt = tree_unflatten(flat, spec) - self.assertTrue(isinstance(orig_dt, MyOtherDataClass)) + self.assertTrue(isinstance(orig_dt, MyDataClass)) self.assertEqual(orig_dt.x, 3) self.assertEqual(orig_dt.y, 4) self.assertEqual(orig_dt.z, None) diff --git a/test/test_fx.py b/test/test_fx.py index 9a9f046e1b0b..30c5f838f127 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3529,7 +3529,7 @@ def f_dict_add(x): def f_namedtuple_add(x): return x.x + x.y - pytree.register_pytree_node( + pytree._register_pytree_node( Foo, lambda x: ([x.a, x.b], None), lambda x, _: Foo(x[0], x[1]), diff --git a/test/test_pytree.py b/test/test_pytree.py index d943db41fe7e..0c0120397eea 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -1,7 +1,7 @@ # Owner(s): ["module: pytree"] import unittest -from collections import namedtuple, OrderedDict, UserDict +from collections import namedtuple, OrderedDict import torch import torch.utils._cxx_pytree as cxx_pytree @@ -26,45 +26,6 @@ def __init__(self, x, y): class TestGenericPytree(TestCase): - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_register_pytree_node(self, pytree_impl): - class MyDict(UserDict): - pass - - d = MyDict(a=1, b=2, c=3) - - # Custom types are leaf nodes by default - values, spec = pytree_impl.tree_flatten(d) - self.assertEqual(values, [d]) - self.assertIs(values[0], d) - self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) - self.assertTrue(spec.is_leaf()) - - # Register MyDict as a pytree node - pytree_impl.register_pytree_node( - MyDict, - lambda d: (list(d.values()), list(d.keys())), - lambda values, keys: MyDict(zip(keys, values)), - ) - - values, spec = pytree_impl.tree_flatten(d) - self.assertEqual(values, [1, 2, 3]) - self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) - - # Do not allow registering the same type twice - with self.assertRaisesRegex(ValueError, "already registered"): - pytree_impl.register_pytree_node( - MyDict, - lambda d: (list(d.values()), list(d.keys())), - lambda values, keys: MyDict(zip(keys, values)), - ) - @parametrize( "pytree_impl", [ @@ -446,21 +407,6 @@ def test_pytree_serialize_bad_input(self, pytree_impl): class TestPythonPytree(TestCase): - def test_deprecated_register_pytree_node(self): - class DummyType: - def __init__(self, x, y): - self.x = x - self.y = y - - with self.assertWarnsRegex( - UserWarning, "torch.utils._pytree._register_pytree_node" - ): - py_pytree._register_pytree_node( - DummyType, - lambda dummy: ([dummy.x, dummy.y], None), - lambda xs, _: DummyType(*xs), - ) - def test_treespec_equality(self): self.assertTrue( py_pytree.LeafSpec() == py_pytree.LeafSpec(), @@ -594,7 +540,7 @@ def __init__(self, x, y): self.x = x self.y = y - py_pytree.register_pytree_node( + py_pytree._register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -614,7 +560,7 @@ def __init__(self, x, y): self.x = x self.y = y - py_pytree.register_pytree_node( + py_pytree._register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -639,7 +585,7 @@ def __init__(self, x, y): with self.assertRaisesRegex( ValueError, "Both to_dumpable_context and from_dumpable_context" ): - py_pytree.register_pytree_node( + py_pytree._register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -653,7 +599,7 @@ def __init__(self, x, y): self.x = x self.y = y - py_pytree.register_pytree_node( + py_pytree._register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), diff --git a/torch/_export/utils.py b/torch/_export/utils.py index d8344783a0a3..afee8efc5946 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node( flatten_fn: Optional[FlattenFunc] = None, unflatten_fn: Optional[UnflattenFunc] = None, *, - serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, + serialized_type_name: Optional[str] = None, return_none_fields: bool = False, ) -> None: assert dataclasses.is_dataclass( cls ), f"Only dataclasses can be registered with this function: {cls}" - serialized_type = f"{cls.__module__}.{cls.__qualname__}" + serialized_type = f"{cls.__module__}.{cls.__name__}" SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index a560db5be495..c09ab6ba9b94 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -29,7 +29,7 @@ from torch._subclasses import FakeTensor, FakeTensorMode from torch._subclasses.fake_tensor import is_fake from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode -from torch.fx import Interpreter +from torch.fx import immutable_collections, Interpreter from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types from torch.fx.experimental.symbolic_shapes import ( ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq @@ -95,6 +95,19 @@ def strict_zip(*iterables, strict=True, **kwargs): ) ) +pytree._register_pytree_node( + immutable_collections.immutable_list, + lambda x: (list(x), None), + lambda x, c: immutable_collections.immutable_list(x), +) +pytree._register_pytree_node( + immutable_collections.immutable_dict, + lambda x: (list(x.values()), list(x.keys())), + lambda x, c: immutable_collections.immutable_dict( + dict(zip(c, x)) + ), +) + def partial_asdict(obj: Any) -> Any: if dataclasses.is_dataclass(obj): return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)} diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index e3d8bd673a4d..dd3520f541aa 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -49,7 +49,7 @@ # We currently convert all SymInt to proxies before we use them. # This could plausibly be handled at the Dynamo level. -pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs)) +pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs)) def fake_signature(fn, nargs): """FX gets confused by varargs, de-confuse it""" diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index a359335f6ece..616555015f0e 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Iterable, List, Tuple from ._compatibility import compatibility -from torch.utils._pytree import Context, register_pytree_node +from torch.utils._pytree import Context, _register_pytree_node __all__ = ["immutable_list", "immutable_dict"] @@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A return immutable_list(values) -register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) +_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) +_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index 79a690f5f48a..f55afefd1bbd 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -40,11 +40,7 @@ def __init__(self): def __enter__(self): for class_type, (flatten_func, unflatten_func) in self._extensions.items(): - pytree._private_register_pytree_node( - class_type, - flatten_func, - unflatten_func, - ) + pytree._register_pytree_node(class_type, flatten_func, unflatten_func) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -97,11 +93,8 @@ def model_output_unflatten( # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. named_model_output_classes = inspect.getmembers( modeling_outputs, - lambda x: ( - inspect.isclass(x) - and issubclass(x, modeling_outputs.ModelOutput) - and x is not modeling_outputs.ModelOutput - ), + lambda x: inspect.isclass(x) + and issubclass(x, modeling_outputs.ModelOutput), ) for _, class_type in named_model_output_classes: diff --git a/torch/return_types.py b/torch/return_types.py index b1284c813387..9f8c85285279 100644 --- a/torch/return_types.py +++ b/torch/return_types.py @@ -13,7 +13,7 @@ def structseq_flatten(structseq): def structseq_unflatten(values, context): return cls(values) - torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten) + torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten) for name in dir(return_types): if name.startswith('__'): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index ab82367fccbe..392c0e2688db 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -13,7 +13,6 @@ """ import functools -import warnings from typing import ( Any, Callable, @@ -27,11 +26,6 @@ Union, ) -import torch - -if torch._running_with_deploy(): - raise ImportError("C++ pytree utilities do not work with torch::deploy.") - import optree from optree import PyTreeSpec # direct import for type annotations @@ -41,9 +35,6 @@ "Context", "FlattenFunc", "UnflattenFunc", - "DumpableContext", - "ToDumpableContextFn", - "FromDumpableContextFn", "TreeSpec", "LeafSpec", "register_pytree_node", @@ -77,9 +68,6 @@ FlattenFunc = Callable[[PyTree], Tuple[List, Context]] UnflattenFunc = Callable[[Iterable, Context], PyTree] OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree] -DumpableContext = Any # Any json dumpable text -ToDumpableContextFn = Callable[[Context], DumpableContext] -FromDumpableContextFn = Callable[[DumpableContext], Context] def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc: @@ -96,11 +84,9 @@ def register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, namespace: str = "torch", ) -> None: - """Register a container-like type as pytree node. + """Extend the set of types that are considered internal nodes in pytrees. The ``namespace`` argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix @@ -123,13 +109,6 @@ def register_pytree_node( The function should return an instance of ``cls``. serialized_type_name (str, optional): A keyword argument used to specify the fully qualified name used when serializing the tree spec. - to_dumpable_context (callable, optional): An optional keyword argument to custom specify how - to convert the context of the pytree to a custom json dumpable representation. This is - used for json serialization, which is being used in :mod:`torch.export` right now. - from_dumpable_context (callable, optional): An optional keyword argument to custom specify - how to convert the custom json dumpable representation of the context back to the - original context. This is used for json deserialization, which is being used in - :mod:`torch.export` right now. namespace (str, optional): A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type. (default: :const:`"torch"`) @@ -214,192 +193,24 @@ def register_pytree_node( ) ) """ - _private_register_pytree_node( - cls, - flatten_fn, - unflatten_fn, - serialized_type_name=serialized_type_name, - to_dumpable_context=to_dumpable_context, - from_dumpable_context=from_dumpable_context, - namespace=namespace, - ) - - from . import _pytree as python + from ._pytree import _register_pytree_node - python._private_register_pytree_node( + _register_pytree_node( cls, flatten_fn, unflatten_fn, serialized_type_name=serialized_type_name, - to_dumpable_context=to_dumpable_context, - from_dumpable_context=from_dumpable_context, ) - -def _register_pytree_node( - cls: Type[Any], - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc, - *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, - namespace: str = "torch", -) -> None: - """Register a container-like type as pytree node for the C++ pytree only. - - The ``namespace`` argument is used to avoid collisions that occur when different libraries - register the same Python type with different behaviors. It is recommended to add a unique prefix - to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify - the same class in different namespaces for different use cases. - - .. warning:: - For safety reasons, a ``namespace`` must be specified while registering a custom type. It is - used to isolate the behavior of flattening and unflattening a pytree node type. This is to - prevent accidental collisions between different libraries that may register the same type. - - Args: - cls (type): A Python type to treat as an internal pytree node. - flatten_fn (callable): A function to be used during flattening, taking an instance of - ``cls`` and returning a pair, with (1) an iterable for the children to be flattened - recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be - passed to the ``unflatten_fn``. - unflatten_fn (callable): A function taking two arguments: the auxiliary data that was - returned by ``flatten_fn`` and stored in the treespec, and the unflattened children. - The function should return an instance of ``cls``. - serialized_type_name (str, optional): A keyword argument used to specify the fully - qualified name used when serializing the tree spec. - to_dumpable_context (callable, optional): An optional keyword argument to custom specify how - to convert the context of the pytree to a custom json dumpable representation. This is - used for json serialization, which is being used in :mod:`torch.export` right now. - from_dumpable_context (callable, optional): An optional keyword argument to custom specify - how to convert the custom json dumpable representation of the context back to the - original context. This is used for json deserialization, which is being used in - :mod:`torch.export` right now. - namespace (str, optional): A non-empty string that uniquely identifies the namespace of the - type registry. This is used to isolate the registry from other modules that might - register a different custom behavior for the same type. (default: :const:`"torch"`) - - Example:: - - >>> # xdoctest: +SKIP - >>> # Registry a Python type with lambda functions - >>> register_pytree_node( - ... set, - ... lambda s: (sorted(s), None, None), - ... lambda children, _: set(children), - ... namespace='set', - ... ) - - >>> # xdoctest: +SKIP - >>> # Register a Python type into a namespace - >>> import torch - >>> register_pytree_node( - ... torch.Tensor, - ... flatten_func=lambda tensor: ( - ... (tensor.cpu().detach().numpy(),), - ... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, - ... ), - ... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata), - ... namespace='torch2numpy', - ... ) - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) - >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))} - >>> tree - {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])} - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) - >>> # Flatten without specifying the namespace - >>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP - ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *})) - - >>> # xdoctest: +SKIP - >>> # Flatten with the namespace - >>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP - ( - [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)], - PyTreeSpec( - { - 'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]), - 'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]) - }, - namespace='torch2numpy' - ) - ) - - >>> # xdoctest: +SKIP - >>> # Register the same type with a different namespace for different behaviors - >>> def tensor2flatparam(tensor): - ... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None - ... - >>> def flatparam2tensor(children, metadata): - ... return children[0].reshape(metadata) - ... - >>> register_pytree_node( - ... torch.Tensor, - ... flatten_func=tensor2flatparam, - ... unflatten_func=flatparam2tensor, - ... namespace='tensor2flatparam', - ... ) - - >>> # xdoctest: +SKIP - >>> # Flatten with the new namespace - >>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP - ( - [ - Parameter containing: tensor([0., 0.], requires_grad=True), - Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True) - ], - PyTreeSpec( - { - 'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]), - 'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*]) - }, - namespace='tensor2flatparam' - ) - ) - """ - warnings.warn( - "torch.utils._cxx_pytree._register_pytree_node is deprecated. " - "Please use torch.utils._cxx_pytree.register_pytree_node instead.", - stacklevel=2, - ) - - _private_register_pytree_node( + optree.register_pytree_node( cls, flatten_fn, - unflatten_fn, - serialized_type_name=serialized_type_name, - to_dumpable_context=to_dumpable_context, - from_dumpable_context=from_dumpable_context, + _reverse_args(unflatten_fn), namespace=namespace, ) -def _private_register_pytree_node( - cls: Type[Any], - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc, - *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, - namespace: str = "torch", -) -> None: - """This is an internal function that is used to register a pytree node type - for the C++ pytree only. End-users should use :func:`register_pytree_node` - instead. - """ - # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support - # PyStructSequence types - if not optree.is_structseq_class(cls): - optree.register_pytree_node( - cls, - flatten_fn, - _reverse_args(unflatten_fn), - namespace=namespace, - ) +_register_pytree_node = register_pytree_node def tree_flatten( diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 5faa6c7c16ad..f74d4a76e5b8 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -17,7 +17,6 @@ import dataclasses import json -import threading import warnings from collections import deque, namedtuple, OrderedDict from typing import ( @@ -100,7 +99,6 @@ class NodeDef(NamedTuple): unflatten_fn: UnflattenFunc -_NODE_REGISTRY_LOCK = threading.Lock() SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} @@ -122,59 +120,6 @@ class _SerializeNodeDef(NamedTuple): SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} -def register_pytree_node( - cls: Any, - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc, - *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, -) -> None: - """Register a container-like type as pytree node. - - Args: - cls: the type to register - flatten_fn: A callable that takes a pytree and returns a flattened - representation of the pytree and additional context to represent the - flattened pytree. - unflatten_fn: A callable that takes a flattened version of the pytree, - additional context, and returns an unflattened pytree. - serialized_type_name: A keyword argument used to specify the fully qualified - name used when serializing the tree spec. - to_dumpable_context: An optional keyword argument to custom specify how - to convert the context of the pytree to a custom json dumpable - representation. This is used for json serialization, which is being - used in torch.export right now. - from_dumpable_context: An optional keyword argument to custom specify how - to convert the custom json dumpable representation of the context - back to the original context. This is used for json deserialization, - which is being used in torch.export right now. - """ - _private_register_pytree_node( - cls, - flatten_fn, - unflatten_fn, - serialized_type_name=serialized_type_name, - to_dumpable_context=to_dumpable_context, - from_dumpable_context=from_dumpable_context, - ) - - try: - from . import _cxx_pytree as cxx - except ImportError: - pass - else: - cxx._private_register_pytree_node( - cls, - flatten_fn, - unflatten_fn, - serialized_type_name=serialized_type_name, - to_dumpable_context=to_dumpable_context, - from_dumpable_context=from_dumpable_context, - ) - - def _register_pytree_node( cls: Any, flatten_fn: FlattenFunc, @@ -186,8 +131,7 @@ def _register_pytree_node( to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, ) -> None: - """Register a container-like type as pytree node for the Python pytree only. - + """ Args: cls: the type to register flatten_fn: A callable that takes a pytree and returns a flattened @@ -206,69 +150,39 @@ def _register_pytree_node( back to the original context. This is used for json deserialization, which is being used in torch.export right now. """ - warnings.warn( - "torch.utils._pytree._register_pytree_node is deprecated. " - "Please use torch.utils._pytree.register_pytree_node instead.", - stacklevel=2, - ) - if to_str_fn is not None or maybe_from_str_fn is not None: warnings.warn( "to_str_fn and maybe_from_str_fn is deprecated. " "Please use to_dumpable_context and from_dumpable_context instead." ) - _private_register_pytree_node( + node_def = NodeDef( cls, flatten_fn, unflatten_fn, - serialized_type_name=serialized_type_name, - to_dumpable_context=to_dumpable_context, - from_dumpable_context=from_dumpable_context, ) + SUPPORTED_NODES[cls] = node_def - -def _private_register_pytree_node( - cls: Any, - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc, - *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, -) -> None: - """This is an internal function that is used to register a pytree node type - for the Python pytree only. End-users should use :func:`register_pytree_node` - instead. - """ - with _NODE_REGISTRY_LOCK: - if cls in SUPPORTED_NODES: - raise ValueError(f"{cls} is already registered as pytree node.") - - node_def = NodeDef( - cls, - flatten_fn, - unflatten_fn, + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." ) - SUPPORTED_NODES[cls] = node_def - if (to_dumpable_context is None) ^ (from_dumpable_context is None): - raise ValueError( - f"Both to_dumpable_context and from_dumpable_context for {cls} must " - "be None or registered." - ) + if serialized_type_name is None: + serialized_type_name = f"{cls.__module__}.{cls.__name__}" - if serialized_type_name is None: - serialized_type_name = f"{cls.__module__}.{cls.__qualname__}" + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls - serialize_node_def = _SerializeNodeDef( - cls, - serialized_type_name, - to_dumpable_context, - from_dumpable_context, - ) - SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def - SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + +register_pytree_node = _register_pytree_node def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: @@ -329,25 +243,25 @@ def _odict_unflatten( return OrderedDict((key, value) for key, value in zip(context, values)) -_private_register_pytree_node( +_register_pytree_node( dict, _dict_flatten, _dict_unflatten, serialized_type_name="builtins.dict", ) -_private_register_pytree_node( +_register_pytree_node( list, _list_flatten, _list_unflatten, serialized_type_name="builtins.list", ) -_private_register_pytree_node( +_register_pytree_node( tuple, _tuple_flatten, _tuple_unflatten, serialized_type_name="builtins.tuple", ) -_private_register_pytree_node( +_register_pytree_node( namedtuple, _namedtuple_flatten, _namedtuple_unflatten, @@ -355,7 +269,7 @@ def _odict_unflatten( from_dumpable_context=_namedtuple_deserialize, serialized_type_name="collections.namedtuple", ) -_private_register_pytree_node( +_register_pytree_node( OrderedDict, _odict_flatten, _odict_unflatten, @@ -815,7 +729,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: if treespec.type not in SUPPORTED_SERIALIZED_TYPES: raise NotImplementedError( - f"Serializing {treespec.type} in pytree is not registered.", + f"Serializing {treespec.type} in pytree is not registered." ) serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] From 85aa3723749e0d06aa5fd34215b9b93529a60995 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 23 Nov 2023 13:13:06 +0000 Subject: [PATCH 143/221] [inductor] Fixed conv issue with dynamic shapes (#114351) EDIT: fixes https://github.com/pytorch/pytorch/issues/114354 Description: The following code is failing: ```python import torch def func(x, w): return torch.nn.functional.conv2d(x, w, groups=int(w.shape[0])) x = torch.rand(1, 3, 64, 64) w = torch.rand(3, 1, 3, 3) y1 = func(x, w) cfunc = torch.compile(func, fullgraph=True, dynamic=True) y2 = cfunc(x, w) torch.testing.assert_close(y1, y2) ``` with the error: ``` File "/pytorch/torch/_inductor/kernel/conv.py", line 315, in convolution assert isinstance(groups, int) torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: LoweringException: AssertionError: target: aten.convolution.default args[0]: TensorBox(StorageBox( InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.float32, size=[1, s0, s1, s1], stride=[s0*s1**2, s1**2, s1, 1])) )) args[1]: TensorBox(StorageBox( InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.float32, size=[s0, 1, s0, s0], stride=[s0**2, s0**2, s0, 1])) )) args[2]: None args[3]: [1, 1] args[4]: [0, 0] args[5]: [1, 1] args[6]: False args[7]: [0, 0] args[8]: s0 ``` where `groups` argument is a symbol but expected to be `int`. This PR specializes `group` to its int value and fixes the problem. Context: Failing tests in torchvision with gaussian blur and adjust_sharpness ops - https://github.com/pytorch/vision/actions/runs/6955843968/job/18926393710?pr=8127 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114351 Approved by: https://github.com/ezyang --- test/inductor/test_torchinductor.py | 14 ++++++++++++++ torch/_inductor/kernel/conv.py | 2 ++ 2 files changed, 16 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 56e4db86b249..02787e019317 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2514,6 +2514,20 @@ def test_convolution3(self): rtol=0.001, ) + @skipIfRocm + def test_convolution4(self): + def fn(x, w): + x = F.conv2d(x, w, groups=w.shape[0]) + return x.sum() + + self.common( + fn, + ( + torch.randn([2, 3, 16, 20]), + torch.randn([3, 1, 5, 5]), + ), + ) + def test_conv2d_channels_last(self): if self.device == "cuda": raise unittest.SkipTest("only support cpu conv2d channels_last") diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 0ef78bc4f785..849058349778 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -310,6 +310,8 @@ def convolution( padding = tuple(padding) dilation = tuple(dilation) output_padding = tuple(output_padding) + if not isinstance(groups, int): + groups = V.graph.sizevars.evaluate_static_shape(groups) assert isinstance(groups, int) kwargs: ConvLayoutParams = { "stride": stride, From fd1a01a393cfe26205795a7fae48ff4e49b5a564 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 23 Nov 2023 19:07:38 +0000 Subject: [PATCH 144/221] Set default LR value of SGD to 1e-3 (#114467) Fixes https://github.com/pytorch/pytorch/issues/114089 Set the lr to 1e-3 in SGD to increase the consistency of input signature of optimizers. @janeyx99 This should be the redacted PR #114434 , sincerely. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114467 Approved by: https://github.com/janeyx99 --- torch/optim/sgd.py | 10 +++++----- torch/optim/sgd.pyi | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 23f91c39f1df..8c7d73b83a2b 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,16 +1,17 @@ import torch from torch import Tensor -from .optimizer import (Optimizer, required, _use_grad_for_differentiable, _default_to_fused_or_foreach, +from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc, _maximize_doc) from typing import List, Optional __all__ = ['SGD', 'sgd'] + class SGD(Optimizer): - def __init__(self, params, lr=required, momentum=0, dampening=0, + def __init__(self, params, lr=1e-3, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None, differentiable: bool = False): - if lr is not required and lr < 0.0: + if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") @@ -51,7 +52,6 @@ def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): return has_sparse_grad - @_use_grad_for_differentiable def step(self, closure=None): """Performs a single optimization step. @@ -130,7 +130,7 @@ def step(self, closure=None): Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups - lr (float): learning rate + lr (float, optional): learning rate (default: 1e-3) momentum (float, optional): momentum factor (default: 0) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) dampening (float, optional): dampening for momentum (default: 0) diff --git a/torch/optim/sgd.pyi b/torch/optim/sgd.pyi index 48721a434bbd..ba1bcd60a1b8 100644 --- a/torch/optim/sgd.pyi +++ b/torch/optim/sgd.pyi @@ -4,7 +4,7 @@ class SGD(Optimizer): def __init__( self, params: ParamsT, - lr: float, + lr: float = ..., momentum: float = ..., dampening: float = ..., weight_decay: float = ..., From a28876832c8b32278bbb2fbd2e4fd7efe7bcfc57 Mon Sep 17 00:00:00 2001 From: Tobias Ringwald Date: Thu, 23 Nov 2023 21:17:43 +0000 Subject: [PATCH 145/221] Fixed an export problem when moving tensors to CPU during `torch.export.save` (#114029) For whatever reason calling`.cpu()` on a `nn.Parameter` wrapping a CUDA tensor will return a plain (non-parameter) tensor. This PR fixes the symptom in the linked issue, but not the underlying issue. Fixes #113999. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114029 Approved by: https://github.com/zhxchen17 --- test/dynamo/test_export.py | 33 ++++++++++++++++++++++++++++++-- torch/_export/serde/serialize.py | 2 ++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 39b3040c3a83..f8469ff15ca4 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -6,6 +6,7 @@ import copy import functools import inspect +import io import math import operator import unittest @@ -32,6 +33,7 @@ ShapeEnv, ) from torch.testing._internal import common_utils +from torch.testing._internal.common_cuda import TEST_CUDA class ExportTests(torch._dynamo.test_case.TestCase): @@ -2289,6 +2291,35 @@ def qux(x, y): ): torch.export.export(qux, (torch.tensor(3), 5)) + @unittest.skipIf(not TEST_CUDA, "No CUDA available.") + def test_export_with_parameters(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.features = torch.nn.Sequential( + torch.nn.Conv2d( + 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ), + torch.nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.features(x) + + model = MyModule().eval().cuda() + random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),) + dim_x = torch.export.Dim("dim_x", min=1, max=32) + exp_program = torch.export.export( + model, random_inputs, dynamic_shapes={"x": {0: dim_x}} + ) + output_buffer = io.BytesIO() + # Tests if we can restore saved nn.Parameters when we load them again + torch.export.save(exp_program, output_buffer) + loaded_model = torch.export.load(output_buffer) + self.assertTrue( + isinstance(loaded_model.module().features_0_weight, torch.nn.Parameter) + ) + def test_export_meta(self): class MyModule(torch.nn.Module): def __init__(self): @@ -2571,8 +2602,6 @@ def f(x): capture_scalar_outputs=True, ) def test_exported_graph_serialization(self): - import io - def f(x, y): b = x.item() torch._constrain_as_size(b) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 8324a769d991..44b8f8727647 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -229,6 +229,8 @@ def serialize_torch_artifact(artifact) -> bytes: def _tensor_to_cpu(t: torch.Tensor): if t.is_meta: return t + elif isinstance(t, torch.nn.Parameter): + return torch.nn.Parameter(t.cpu()) else: return t.cpu() artifact = tree_map_only(torch.Tensor, _tensor_to_cpu, artifact) From b76e2949f77ede4bdcc8af296dc5ecd7b098392e Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 23 Nov 2023 21:20:41 +0000 Subject: [PATCH 146/221] Fix pool_size type in TaskThreadPool (#114063) As negative values of pool_size mean calling defaultNumThreads() Pull Request resolved: https://github.com/pytorch/pytorch/pull/114063 Approved by: https://github.com/Skylion007 --- c10/core/thread_pool.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c10/core/thread_pool.h b/c10/core/thread_pool.h index 49090d678022..50bc6f99ac86 100644 --- a/c10/core/thread_pool.h +++ b/c10/core/thread_pool.h @@ -102,7 +102,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase { class C10_API TaskThreadPool : public c10::ThreadPool { public: - explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1) + explicit TaskThreadPool(int pool_size, int numa_node_id = -1) : ThreadPool(pool_size, numa_node_id, [numa_node_id]() { setThreadName("CaffeTaskThread"); NUMABind(numa_node_id); From a378ae33e99fbdcc2c73c6456947212cb3fe647c Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Thu, 23 Nov 2023 11:07:51 -0800 Subject: [PATCH 147/221] [BE][aot_autograd] Remove mutated_inp_indices (#114421) We should use mutated_inp_runtime_indices moving forward Pull Request resolved: https://github.com/pytorch/pytorch/pull/114421 Approved by: https://github.com/zhxchen17 --- torch/_functorch/aot_autograd.py | 19 ++++++------------- torch/_inductor/freezing.py | 2 +- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index c09ab6ba9b94..4e17864fd314 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -615,10 +615,6 @@ class ViewAndMutationMeta: grad_enabled_mutation: Optional[bool] = None def __post_init__(self): - mutated_inp_indices = [ - i for i, m in enumerate(self.input_info) - if m.mutation_type in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) - ] # pre-compute the indices of the inputs that are mutated. # When keep_input_mutations is set, we don't need to worry about our epilogue # handling data-only mutations, because we keep them directly in the graph. @@ -632,7 +628,10 @@ def __post_init__(self): if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH) ] else: - mutated_inp_runtime_indices = mutated_inp_indices + mutated_inp_runtime_indices = [ + i for i, m in enumerate(self.input_info) + if m.mutation_type in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) + ] mutated_graph_handled_indices = [ i for i, m in enumerate(self.input_info) @@ -650,18 +649,12 @@ def __post_init__(self): i for i, m in enumerate(self.output_info) if m.output_type is OutputType.unsafe_view_alias ] - self.mutated_inp_indices = mutated_inp_indices # This is pre-computed in post_init for perf. # It contains the index of every element # of input_info that corresponds to a mutation (data or metadata or both) self.mutated_inp_runtime_indices = mutated_inp_runtime_indices self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices) - assert ( - self.num_mutated_graph_handled_indices + self.num_mutated_inp_runtime_indices == - len(mutated_inp_indices) - ) - # This is pre-computed for perf. # It contains the index of every element # of output_info that corresponds to an alias (either of an input or intermediate) @@ -4004,7 +3997,7 @@ def forward(ctx, *deduped_flat_tensor_args): # so that autograd.Function doesn't treat them as tensors if num_mutated_metadata_only_inputs > 0: for i, idx in enumerate( - CompiledFunction.metadata.mutated_inp_indices + CompiledFunction.metadata.mutated_inp_runtime_indices ): # We could make this faster by only looping over inputs with metadata-only mutations # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. @@ -4500,7 +4493,7 @@ def convert(idx, x): if aot_config.is_export: mutated_user_inp_locs = [ idx - aot_config.num_params_buffers - for idx in fw_metadata.mutated_inp_indices + for idx in fw_metadata.mutated_inp_runtime_indices if idx >= aot_config.num_params_buffers ] if len(mutated_user_inp_locs) > 0: diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py index 09348d4a0d4e..0cd7335c2065 100644 --- a/torch/_inductor/freezing.py +++ b/torch/_inductor/freezing.py @@ -41,7 +41,7 @@ def replace_params_with_constants( if out_info.base_idx is not None ] for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)): - if i in fw_metadata.mutated_inp_indices or i in aliased_input_args: + if i in fw_metadata.mutated_inp_runtime_indices or i in aliased_input_args: preserved_arg_indices.append(i) continue replace_node_with_constant(gm, node, real_input) From e6e650d5eb06fde446e69c2b4f09285a974925a8 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Thu, 23 Nov 2023 11:40:03 -0800 Subject: [PATCH 148/221] [BE][aot_autograd] Remove num_mutated_inputs (#114479) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114479 Approved by: https://github.com/zhxchen17 ghstack dependencies: #114421 --- torch/_functorch/aot_autograd.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 4e17864fd314..5d2809d4c0e7 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -708,7 +708,6 @@ def __post_init__(self): if not x.mutates_data and x.mutates_metadata ] ) - self.num_mutated_inputs = self.num_mutated_data_inputs + self.num_mutated_metadata_only_inputs self.dynamic_outputs = any( o.dynamic_dims for o in self.output_info ) @@ -1605,9 +1604,9 @@ def from_tracing_metadata( buffer_name = state_names[idx] mutated_buffers.append(buffer_name) - assert len(mutated_buffers) == view_mutation_metadata.num_mutated_inputs + assert len(mutated_buffers) == view_mutation_metadata.num_mutated_inp_runtime_indices - start, stop = 0, view_mutation_metadata.num_mutated_inputs + start, stop = 0, view_mutation_metadata.num_mutated_inp_runtime_indices buffers_to_mutate = dict(zip(graph_outputs[start:stop], mutated_buffers)) start, stop = stop, stop + num_user_outputs @@ -3173,7 +3172,6 @@ def runtime_wrapper(*args): disable_amp=disable_amp, ) - num_mutated_inps = runtime_metadata.num_mutated_inputs num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices num_metadata_mutated_inps = runtime_metadata.num_mutated_metadata_inputs num_intermediate_bases = runtime_metadata.num_intermediate_bases @@ -3967,7 +3965,6 @@ def forward(ctx, *deduped_flat_tensor_args): num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw - num_mutated_inputs = CompiledFunction.metadata.num_mutated_inputs num_mutated_runtime_inps = CompiledFunction.metadata.num_mutated_inp_runtime_indices num_mutated_metadata_only_inputs = ( CompiledFunction.metadata.num_mutated_metadata_only_inputs @@ -4006,7 +4003,7 @@ def forward(ctx, *deduped_flat_tensor_args): raw_returns[i] = TensorAlias(raw_returns[i]) if config.debug_assert: - user_mutated_inputs_raw = raw_returns[0:num_mutated_inputs] + user_mutated_inputs_raw = raw_returns[0:num_mutated_runtime_inps] mut_inp_infos = [ x for x in CompiledFunction.metadata.input_info if x.mutates_data or x.mutates_metadata ] @@ -4064,7 +4061,6 @@ def backward(ctx, *flat_args): # - updated inputs due to metadata-only mutations. # We need to return them in the forward, but ensure that they all do not get gradients in the backward, # and we filter them out here before passing the remaining grad_outputs into the compiled backward. - num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases num_graph_handled_inputs = CompiledFunction.metadata.num_mutated_graph_handled_indices num_mutated_runtime_inps = CompiledFunction.metadata.num_mutated_inp_runtime_indices @@ -4077,12 +4073,10 @@ def backward(ctx, *flat_args): assert len(flat_args) == expected_grad_outs out_info = CompiledFunction.metadata.output_info - num_mutated_inps_returned = CompiledFunction.metadata.num_mutated_inp_runtime_indices - inp_tangents, out_tangents, intermediate_base_tangents = ( - flat_args[0:num_mutated_inps_returned], - flat_args[num_mutated_inps_returned:num_mutated_inps_returned + CompiledFunction.metadata.num_outputs], - flat_args[num_mutated_inps_returned + CompiledFunction.metadata.num_outputs:], + flat_args[0:num_mutated_runtime_inps], + flat_args[num_mutated_runtime_inps:num_mutated_runtime_inps + CompiledFunction.metadata.num_outputs], + flat_args[num_mutated_runtime_inps + CompiledFunction.metadata.num_outputs:], ) # input_info contains info on *every* input, # But in the backward(), we are only given grad outputs for every mutated input @@ -4643,7 +4637,7 @@ def create_graph_signature( if trace_joint: assert num_user_fw_outs is not None - num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inputs + num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices backward_output_names = graph_output_names[num_fw_outs:] grad_index = itertools.count(0) @@ -4671,7 +4665,7 @@ def create_graph_signature( ) else: backward_signature = None - num_user_fw_outs = len(graph_output_names) - fw_metadata.num_mutated_inputs + num_user_fw_outs = len(graph_output_names) - fw_metadata.num_mutated_inp_runtime_indices return GraphSignature.from_tracing_metadata( in_spec=in_spec, @@ -5110,7 +5104,7 @@ def flattened_joint(*args): # and there are therefore no tangents that are needed to run the joint graph. # This function "fixes" both of the above by removing any tangent inputs, # and removing pytrees from the original FX graph. - fake_tangents = [None for _ in range(metadata.num_outputs + metadata.num_mutated_inputs)] + fake_tangents = [None for _ in range(metadata.num_outputs + metadata.num_mutated_inp_runtime_indices)] fw_outs, gradients = fx_g(args, fake_tangents) assert len(gradients) == len(args) output_gradients = [] From fa71f5efdc72460f3c22867d641c7bab2dcfa54e Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Thu, 23 Nov 2023 11:46:21 -0800 Subject: [PATCH 149/221] [BE][aot_autograd] Remove unnecessary fields from ViewMutationData (#114481) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114481 Approved by: https://github.com/zhxchen17 ghstack dependencies: #114421, #114479 --- torch/_functorch/aot_autograd.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 5d2809d4c0e7..92ae0f6314f7 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -690,24 +690,7 @@ def __post_init__(self): self.num_outputs_aliased = ( self.num_outputs_aliased_to_inputs + self.num_outputs_aliased_to_intermediates ) - self.num_mutated_data_inputs = len( - [x for x in self.input_info if x.mutates_data] - ) - self.num_mutated_metadata_inputs = len( - [ - x - for x in self.input_info - if x.mutates_metadata - ] - ) - self.num_mutated_metadata_only_inputs = len( - [ - x - for x in self.input_info - if not x.mutates_data and x.mutates_metadata - ] - ) self.dynamic_outputs = any( o.dynamic_dims for o in self.output_info ) @@ -3173,7 +3156,6 @@ def runtime_wrapper(*args): ) num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices - num_metadata_mutated_inps = runtime_metadata.num_mutated_metadata_inputs num_intermediate_bases = runtime_metadata.num_intermediate_bases if keep_input_mutations and trace_joint: @@ -3509,7 +3491,7 @@ def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMe # output infos output_info = [] - subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[meta.num_mutated_data_inputs:] + subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[meta.num_mutated_inp_runtime_indices:] if meta.num_intermediate_bases > 0: subclass_out_meta_user_outs_only = subclass_out_meta_user_outs_only[:-meta.num_intermediate_bases] # sanity assert @@ -3966,9 +3948,6 @@ def forward(ctx, *deduped_flat_tensor_args): num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw num_mutated_runtime_inps = CompiledFunction.metadata.num_mutated_inp_runtime_indices - num_mutated_metadata_only_inputs = ( - CompiledFunction.metadata.num_mutated_metadata_only_inputs - ) num_forward_returns = CompiledFunction.metadata.num_forward_returns num_forward = CompiledFunction.metadata.num_forward @@ -3992,7 +3971,7 @@ def forward(ctx, *deduped_flat_tensor_args): # Wrap all autograd.Function.forward() outputs that are aliases # so that autograd.Function doesn't treat them as tensors - if num_mutated_metadata_only_inputs > 0: + if num_mutated_runtime_inps > 0: for i, idx in enumerate( CompiledFunction.metadata.mutated_inp_runtime_indices ): From dad3cc4d026794ad3ec309f7ae38e9c3798ce34f Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Thu, 23 Nov 2023 11:49:52 -0800 Subject: [PATCH 150/221] Fix type for keep_inference_mutation flag (#114482) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114482 Approved by: https://github.com/Skylion007 ghstack dependencies: #114421, #114479, #114481 --- torch/_functorch/aot_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 92ae0f6314f7..58bf329dc566 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -564,7 +564,7 @@ class ViewAndMutationMeta: num_intermediate_bases: int # For inference only: instructs us to keep data-only input mutations directly in the graph - keep_input_mutations: int + keep_input_mutations: bool # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) # + (# intermediate bases) From 07e00de8d7dca48c91b60e846fed559b63dc5e3d Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 24 Nov 2023 03:54:11 +0000 Subject: [PATCH 151/221] Add missing member initialization in c10::ExtraMeta constructor (#114448) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114448 Approved by: https://github.com/Skylion007 --- c10/core/TensorImpl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index c4cd54dcaece..11f148ff8753 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -257,6 +257,7 @@ struct C10_API ExtraMeta { : symbolic_shape_meta_(std::move(symbolic_shape_meta)), named_tensor_meta_(std::move(named_tensor_meta)), backend_meta_(std::move(backend_meta)), + custom_data_ptr_error_msg_(std::move(custom_data_ptr_error_msg)), custom_storage_error_msg_(std::move(custom_storage_access_error_msg)) {} std::unique_ptr clone() const { From 51390722e9b01cc3e166e3e398d7a4e0e61842f7 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 23 Nov 2023 17:46:39 -0800 Subject: [PATCH 152/221] Fix ConvolutionBinaryInplace using target node (#114436) This IR node mutates in place, it needs to use the argument not the target. Fixes #113440 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114436 Approved by: https://github.com/jansel ghstack dependencies: #114169 --- test/inductor/test_mkldnn_pattern_matcher.py | 95 ++++++++++++++++++++ torch/_inductor/ir.py | 5 +- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 65f69cd8918c..dc7bfe54c5e1 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1440,6 +1440,101 @@ def forward(self, input_tensor): include_ops = ["mkldnn._convolution_pointwise_.binary"] self._test_code_common(mod, (input,), include_ops, []) + def test_reproduce_113440_issue_1(self): + class Mod(torch.nn.Module): + def __init__( + self, + add_fn, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.add_fn = add_fn + self.relu = torch.nn.ReLU(inplace=True) + self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU(inplace=True) + self.use_relu = True + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.conv3(tmp) + tmp2 = self.conv4(tmp) + res = self.add_fn2(tmp1, tmp2) + if self.use_relu: + res = self.relu2(res) + return res + + with torch.no_grad(): + example_inputs = ( + torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( + 1 + ), + ) + example_inputs[0].get_device() + m = Mod( + lambda x, y: x.add_(y), + ).eval() + om = torch.compile(m) + om(*example_inputs) + om(*example_inputs) + + def test_reproduce_113440_issue_2(self): + class Mod(torch.nn.Module): + def __init__( + self, + add_fn, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.add_fn = add_fn + self.relu = torch.nn.ReLU(inplace=True) + self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU(inplace=True) + + self.conv5 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv6 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv7 = torch.nn.Conv2d(6, 6, kernel_size=1, stride=1) + self.add_fn3 = add_fn + self.relu3 = torch.nn.ReLU(inplace=True) + + self.use_relu = True + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + + tmp1 = self.conv3(tmp) + res = self.relu2(tmp1) + + return res + + with torch.no_grad(): + example_inputs = ( + torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( + 1 + ), + ) + m = Mod( + lambda x, y: x.add_(y), + ).eval() + om = torch.compile(m) + om(*example_inputs) + om(*example_inputs) + @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) class TestDynamicPatternMatcher(TestPatternMatcherBase): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index f1956373bed5..bfe26ade6c7d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5155,7 +5155,10 @@ def create( constant_args=constant_args, ) mark_node_as_mutating(packed, inputs[1]) - return packed + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] class MKLPackedLinear(ExternKernelAlloc): From cd7d6938c18d90870356553d4631f1388d2bb699 Mon Sep 17 00:00:00 2001 From: chundian Date: Thu, 23 Nov 2023 12:46:16 -0800 Subject: [PATCH 153/221] [inductor] Fix torch.split bug on unbacked symint (#113406) torch.split(x, l) fails when l's shape is the unbacked symint. E.g. l = y.tolist() makes l the unbacked shape, because l depends on the data access of y. The downdtream call `SliceView.create()` evaluates the shape even if the input shape is unbacked symint, which brings up the bug. Test Plan: python test/inductor/test_unbacked_symints.py -k test_split_with_sizes Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406 Approved by: https://github.com/aakhundov, https://github.com/ezyang --- test/inductor/test_unbacked_symints.py | 15 +++++++++++++++ torch/_inductor/codegen/common.py | 4 ++++ torch/_inductor/codegen/triton.py | 3 ++- torch/_inductor/ir.py | 11 ++++++----- torch/fx/experimental/validator.py | 2 ++ 5 files changed, 29 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 2797bea8ceb1..4ab72c6721bc 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -54,6 +54,21 @@ def fn(x, y): torch.testing.assert_close(actual, expected) + def test_split_with_sizes(self): + def fn(x, y): + l = y.tolist() + s = torch.split(x, l) + d = l[0] + l[1] + l[2] + return s[0].sum(), d + + example_inputs = (torch.randn((32), device="cuda"), torch.tensor((7, 16, 9))) + + with dynamo_config.patch({"capture_scalar_outputs": True}): + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + + torch.testing.assert_close(actual, expected) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index be949e8f92a9..19fa6de3219e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -378,6 +378,10 @@ def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" + def _print_Min(self, expr): + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + class OpOverrides: def __init__(self, parent): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0f08f728330f..6b288cec254d 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -19,6 +19,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import ValueRanges + from ..._dynamo.utils import counters from .. import config, ir, scheduler from ..codecache import code_hash, get_path, PyCodeCache @@ -1143,7 +1144,7 @@ def indexing( # indirect indexing cse_var = self.cse.varname_map[var.name] mask_vars.update(cse_var.mask_vars) - elif var.name.startswith(("s", "ps")): + elif var.name.startswith(("s", "ps", "i")): pass else: # var is one of xN, yN or rN diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index bfe26ade6c7d..4d5f78c7051c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2088,11 +2088,12 @@ def create(cls, x, dim, start, end, step=1): start = cls.handle_negative_index(start, new_size[dim]) end = cls.handle_negative_index(end, new_size[dim]) - end = sizevars.evaluate_min(end, new_size[dim]) - start = sizevars.evaluate_min(start, end) - if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1: - sizevars.guard_equals(end, new_size[dim]) - return x + if free_unbacked_symbols(start) or free_unbacked_symbols(end): + end = sympy.Min(end, new_size[dim]) + start = sympy.Min(start, end) + else: + end = sizevars.evaluate_min(end, new_size[dim]) + start = sizevars.evaluate_min(start, end) new_size[dim] = FloorDiv(end - start + (step - 1), step) diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 8c795bb5b3de..48ad07dd8559 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -363,6 +363,8 @@ def __getattr__(self, name: str) -> Any: "not_": z3.Not, "floor": self._ops.floor, "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, } if name in REPLACEMENT: From 0a063ad2c00d1a591bc3ecafa907031118e8c2a9 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Fri, 24 Nov 2023 00:32:21 -0800 Subject: [PATCH 154/221] [inductor] Pass None and skip constexpr in custom Triton kernel calls from C++ (#114475) Summary: `None` arguments are codegened as `*i8` in the `triton_meta` of the generated or user-defined Triton kernels: https://github.com/pytorch/pytorch/blob/85aa3723749e0d06aa5fd34215b9b93529a60995/torch/_inductor/codegen/triton_utils.py#L33-L36 Due to this, in contrary to the conventional Triton, we actually should pass `nullptr` to the Triton kernels in C++ wrapper codegen instead of passing nothing (as normally `None` doesn't make it to the generated PTX parameters, just like `tl.constexpr` args). This PR adds two things: 1. Proper C++ wrapper codegening (ABI and non-ABI) of `nullptr` and `c10::nullopt`, as the prior way codegening `c10::nullopt` as tensor breaks (also `c10` breaks in the ABI mode). 2. Skipping `tl.constexpr` args when calling the loaded-from-cubin compiled Triton kernel in the C++ wrapper codegen. As a side effect, this also resolves an issue with string arguments: now they are simply omitted in the C++ wrapper codegen. Test Plan: ``` $ python test/inductor/test_aot_inductor.py -k test_triton_kernel_with_none_input ... ---------------------------------------------------------------------- Ran 4 tests in 40.364s OK (skipped=2) ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/114475 Approved by: https://github.com/oulgen --- test/inductor/test_aot_inductor.py | 45 +++++++++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 10 ++++-- torch/_inductor/ir.py | 9 ++++- torch/testing/_internal/triton_utils.py | 21 ++++++++++++ 4 files changed, 82 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index a41860cb14cb..b2ebfbce5c92 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -37,6 +37,7 @@ add_kernel, add_kernel_2d_autotuned, add_kernel_autotuned, + add_kernel_with_optional_param, ) if IS_WINDOWS and IS_CI: @@ -1268,6 +1269,50 @@ def forward(self, x): example_inputs = (torch.randn(4, device=self.device),) self.check_model(Model(), example_inputs) + def test_triton_kernel_with_none_input(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + # AOT export does not allow for input mutation + n_elements = x.size()[0] + BLOCK_SIZE = 1024 + + x = x.clone() + y = y.clone() + output_wo_y = torch.empty_like(x) + output_with_y = torch.empty_like(x) + + wo_kernel = add_kernel_with_optional_param[(1,)]( + x, + None, + output_wo_y, + n_elements, + ARGS_PASSED="one", + BLOCK_SIZE=BLOCK_SIZE, + ) + with_kernel = add_kernel_with_optional_param[(1,)]( + x, + y, + output_with_y, + n_elements, + ARGS_PASSED="two", + BLOCK_SIZE=BLOCK_SIZE, + ) + + return 2.71 * output_wo_y + 3.14 * output_with_y + + example_inputs = ( + torch.randn(1023, device=self.device), + torch.randn(1023, device=self.device), + ) + + self.check_model(Model(), example_inputs) + def test_shifted_constraint_ranges(self): class Model(torch.nn.Module): def __init__(self): diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 1787937f39ab..196fdc35d298 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2350,8 +2350,10 @@ def extract_output_name(out): def val_to_arg_str(self, val): if val is None: # When None is passed as an argument, it represents an optional that does not contain a value. - # TODO: add abi-compatible support - return "c10::nullopt" + if config.aot_inductor.abi_compatible: + return "nullptr" + else: + return "c10::nullopt" elif isinstance(val, bool): if config.aot_inductor.abi_compatible: return "1" if val else "0" @@ -2532,6 +2534,10 @@ def generate_args_decl(self, call_args): self.writeline(f"float {var_name} = {arg};") elif any(str(arg) == s.name for s in dynamic_symbols): self.writeline(f"auto {var_name} = {arg};") + elif arg == "nullptr": + self.writeline(f"auto {var_name} = nullptr;") + elif arg == "c10::nullopt": + self.writeline(f"auto {var_name} = c10::nullopt;") else: if config.aot_inductor.abi_compatible: self.writeline(f"CUdeviceptr {var_name};") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4d5f78c7051c..4c45f4896678 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3831,13 +3831,20 @@ def codegen(self, wrapper): kernel, configs, self.kwargs ) + args = self.codegen_kwargs() + if V.graph.cpp_wrapper: + # in C++ wrapper, we don't pass constexpr args, as they don't + # get added as parameters to the PTX code compiled from the + # user-defined Triton kernel (only non-constexpr args do) + args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs] + # Call to kernel self.codegen_comment(wrapper) wrapper.generate_user_defined_triton_kernel( new_name, self.grid, configs, - self.codegen_kwargs(), + args, ) def should_allocate(self): diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index e711be91455e..ddd33109ba0e 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -27,6 +27,27 @@ def add_kernel( output = x + y tl.store(out_ptr + offsets, output, mask=mask) + @triton.jit + def add_kernel_with_optional_param( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + ARGS_PASSED: "tl.constexpr", + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + if ARGS_PASSED == "two": + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + else: + output = x + tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( configs=[ triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), From c6d88604d56a4e57b34a3c61982a57bdc0ccc0a1 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 23 Nov 2023 22:50:48 -0800 Subject: [PATCH 155/221] [Inductor] Fix mutation tracking of ConvolutionBinaryInplace (#114501) Init function reorders the arguments so the mutation actually happens on argument input[0] I am not sure if there's a good way to test this unfortunately.. Added tests on https://github.com/pytorch/pytorch/pull/114436 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114501 Approved by: https://github.com/leslie-fang-intel, https://github.com/aakhundov --- torch/_inductor/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4c45f4896678..ae031b8f2818 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5118,7 +5118,7 @@ def codegen(self, wrapper): ) def get_mutation_names(self): - return [self.inputs[1].get_name()] + return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self): return {} From d30497f6b62007c9d1e3c38179528e9d25ac1292 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 24 Nov 2023 23:29:51 +0000 Subject: [PATCH 156/221] [BE]: Enable Ruff + Flake8 G201,G202 logging format rule. (#114474) Standardizes logging calls to always use logging.exception instead of logging.error where appropriate and enforces it with a lint. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114474 Approved by: https://github.com/jansel --- .flake8 | 2 +- pyproject.toml | 2 +- torch/_dynamo/guards.py | 3 +-- torch/_dynamo/utils.py | 2 +- torch/distributed/elastic/multiprocessing/api.py | 3 +-- torch/distributed/elastic/timer/api.py | 9 ++++----- .../distributed/elastic/timer/file_based_local_timer.py | 8 ++++---- torch/distributed/elastic/timer/local_timer.py | 4 ++-- 8 files changed, 15 insertions(+), 18 deletions(-) diff --git a/.flake8 b/.flake8 index bca578ce563e..1e61b459df94 100644 --- a/.flake8 +++ b/.flake8 @@ -18,7 +18,7 @@ ignore = # these ignores are from flake8-comprehensions; please fix! C407, # these ignores are from flake8-logging-format; please fix! - G100,G101,G200,G201,G202 + G100,G101,G200 # these ignores are from flake8-simplify. please fix or ignore with commented reason SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, # flake8-simplify code styles diff --git a/pyproject.toml b/pyproject.toml index 279bd6fa058b..71157c4f3cf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ ignore = [ "F821", "F841", # these ignores are from flake8-logging-format; please fix! - "G101", "G201", "G202", + "G101", # these ignores are from RUFF perf; please fix! "PERF203", "PERF4", # these ignores are from PYI; please fix! diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1b068402019b..0ef173155e2f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1315,9 +1315,8 @@ def get_guard_fail_reason( GuardFail(reason_str or "unknown reason", orig_code_map[code]) ) except Exception as e: - log.error( + log.exception( "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", - exc_info=True, ) return reason_str diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index aa0719a3ab56..2a520eb304c5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -400,7 +400,7 @@ def write_record_to_file(filename, exec_record): with open(filename, "wb") as f: exec_record.dump(f) except Exception: - log.error("Unable to write execution record %s", filename, exc_info=True) + log.exception("Unable to write execution record %s", filename) def count_calls(g: fx.Graph): diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 32426be08010..c7c870bdb073 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -477,14 +477,13 @@ def _poll(self) -> Optional[RunProcsResult]: failed_proc = self._pc.processes[failed_local_rank] error_filepath = self.error_files[failed_local_rank] - log.error( + log.exception( "failed (exitcode: %s)" " local_rank: %s (pid: %s)" " of fn: %s (start_method: %s)", failed_proc.exitcode, failed_local_rank, e.pid, fn_name, self.start_method, - exc_info=True, ) self.close() diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 6dd308891988..566a3d4acbc7 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -169,11 +169,10 @@ def _reap_worker_no_throw(self, worker_id: Any) -> bool: """ try: return self._reap_worker(worker_id) - except Exception as e: - log.error( + except Exception: + log.exception( "Uncaught exception thrown from _reap_worker(), " "check that the implementation correctly catches exceptions", - exc_info=e, ) return True @@ -181,8 +180,8 @@ def _watchdog_loop(self): while not self._stop_signaled: try: self._run_watchdog() - except Exception as e: - log.error("Error running watchdog", exc_info=e) + except Exception: + log.exception("Error running watchdog") def _run_watchdog(self): batch_size = max(1, self._request_queue.size()) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 597000c6d20d..26ebce33dcb5 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -225,8 +225,8 @@ def _watchdog_loop(self) -> None: self._run_watchdog(fd) if run_once: break - except Exception as e: - log.error("Error running watchdog", exc_info=e) + except Exception: + log.exception("Error running watchdog") def _run_watchdog(self, fd: io.TextIOWrapper) -> None: timer_requests = self._get_requests(fd, self._max_interval) @@ -328,6 +328,6 @@ def _reap_worker(self, worker_pid: int, signal: int) -> bool: except ProcessLookupError: log.info("Process with pid=%s does not exist. Skipping", worker_pid) return True - except Exception as e: - log.error("Error terminating pid=%s", worker_pid, exc_info=e) + except Exception: + log.exception("Error terminating pid=%s", worker_pid) return False diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index 240163f1bf6c..05f467c807a5 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -120,6 +120,6 @@ def _reap_worker(self, worker_id: int) -> bool: except ProcessLookupError: log.info("Process with pid=%s does not exist. Skipping", worker_id) return True - except Exception as e: - log.error("Error terminating pid=%s", worker_id, exc_info=e) + except Exception: + log.exception("Error terminating pid=%s", worker_id) return False From 0f5e24bda9450a89ba56d2fdd471f56d97fe4546 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 24 Nov 2023 10:43:09 -0800 Subject: [PATCH 157/221] Properly type CachedFunction & rename to CachedMethod (#114161) Previously, I was unsure how to properly type the parameters of a decorated method. Then I found https://github.com/python/mypy/issues/13222#issuecomment-1193073470 which explains how to use `Concatenate` to hackily achieve it. Not entirely sure why we can't write a user-defined version of `Callable` that works seamlessly for both functions and methods... Pull Request resolved: https://github.com/pytorch/pytorch/pull/114161 Approved by: https://github.com/Skylion007 --- torch/_inductor/codegen/memory_planning.py | 32 ++++++++++++---------- torch/_inductor/utils.py | 10 ++++--- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 806d8504845f..0299941b40ea 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import dataclasses import itertools @@ -8,7 +10,7 @@ import torch from .. import config, ir -from ..utils import cache_on_self, CachedFunction, IndentedBuffer +from ..utils import cache_on_self, CachedMethod, IndentedBuffer from ..virtualized import V from .wrapper import ( @@ -63,11 +65,11 @@ class LiveRange: begin: float # int | ±inf end: float # int | ±inf - def contains(self, other: "LiveRange"): + def contains(self, other: LiveRange): """Is other entirely within self""" return self.begin <= other.begin and other.end <= self.end - def join(self, other: "LiveRange"): + def join(self, other: LiveRange): """Combine two ranges using a union operation""" return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) @@ -93,7 +95,7 @@ def __init__(self, ranges: Iterable[LiveRange]): else: self.ranges.append(r) - def overlaps(self, other: "LiveRanges"): + def overlaps(self, other: LiveRanges): """Check if any pair of ranges in self and other overlap""" left = collections.deque(self.ranges) right = collections.deque(other.ranges) @@ -123,7 +125,7 @@ class AllocationTreeNode: Abstract base class for nodes in allocation pool. """ - def allocate(self, block: "Allocation", is_last: bool) -> bool: + def allocate(self, block: Allocation, is_last: bool) -> bool: """ Try to assign block to a memory location in this bool. Return True if an assignment was made. @@ -142,7 +144,7 @@ def get_symbolic_size(self) -> sympy.Expr: """Number of bytes needed at runtime""" raise NotImplementedError() - def finalize(self, pool, offset) -> "AllocationTreeNode": + def finalize(self, pool, offset) -> AllocationTreeNode: """Called after all allocations have been made""" return self @@ -161,7 +163,7 @@ class Allocation(AllocationTreeNode): size_hint: int symbolic_size: sympy.Expr allocated: bool = False - pool: Optional["AllocationPool"] = None + pool: Optional[AllocationPool] = None offset: Optional[sympy.Expr] = None @property @@ -231,11 +233,11 @@ def is_empty(self): class MemorySplitProtocol(Protocol): - get_live_ranges: CachedFunction[LiveRanges] - get_size_hint: CachedFunction[int] - get_symbolic_size: CachedFunction[sympy.Expr] + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] - def _allocate(self, block: "Allocation", is_last: bool) -> bool: + def _allocate(self, block: Allocation, is_last: bool) -> bool: ... @@ -245,7 +247,7 @@ class ClearCacheOnAllocateMixin(MemorySplitProtocol): get_symbolic_size. """ - def allocate(self, block: "Allocation", is_last: bool): + def allocate(self, block: Allocation, is_last: bool): is_allocated = self._allocate(block, is_last) if is_allocated: self.clear_cache() @@ -268,7 +270,7 @@ class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): allocations: List[AllocationTreeNode] - def _allocate(self, block: "Allocation", is_last: bool): + def _allocate(self, block: Allocation, is_last: bool): slot_size = self.get_size_hint() block_size = block.get_size_hint() if not is_last and block_size > slot_size: @@ -354,7 +356,7 @@ def create(left, extra_space): assert isinstance(extra_space, int) and extra_space >= 1 return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) - def _allocate(self, block: "Allocation", is_last: bool): + def _allocate(self, block: Allocation, is_last: bool): return self.left.allocate(block, False) or self.right.allocate(block, is_last) @cache_on_self @@ -399,7 +401,7 @@ class AllocationPool: names_to_del: List[str] = dataclasses.field(default_factory=list) creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict) - def allocate(self, block: "Allocation", is_last: bool): + def allocate(self, block: Allocation, is_last: bool): if self.restrict_live_range and not self.restrict_live_range.contains( block.live_range ): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 598f3cb31898..40e106621913 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -38,6 +38,7 @@ from unittest import mock import sympy +from typing_extensions import Concatenate, ParamSpec import torch from torch._dynamo.device_interface import get_interface_for_device @@ -382,20 +383,21 @@ def sort_func(elem): return sorted(x, key=sort_func) +P = ParamSpec("P") RV = TypeVar("RV", covariant=True) -# FIXME this should take in a ParamSpec too -class CachedFunction(Generic[RV], Protocol): +class CachedMethod(Generic[P, RV], Protocol): @staticmethod def clear_cache(self) -> None: ... - def __call__(self, *args, **kwargs) -> RV: + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ... -def cache_on_self(fn: Callable[..., RV]) -> CachedFunction[RV]: +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: key = f"__{fn.__name__}_cache" @functools.wraps(fn) From d37c4c69954ad7bdccca96854105c48e93d4587e Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 25 Nov 2023 23:15:47 +0000 Subject: [PATCH 158/221] Update `torch.compiler_troubleshooting.rst` (#114530) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If you copy and paste the env var in the docs: ```console TORCHDYNAMO_REPRO_AFTER=“aot” ``` it leads to this error: ```python @functools.wraps(unconfigured_compiler_fn) def debug_wrapper(gm, example_inputs, **kwargs): compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) > assert config.repro_after in ("dynamo", "aot", None) E torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: E AssertionError: ``` because `config.repro_after` is being `'“aot”'` but not `'aot'`. --- It would've saved a few minutes of my time 😄 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114530 Approved by: https://github.com/Chillee --- docs/source/torch.compiler_troubleshooting.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/torch.compiler_troubleshooting.rst b/docs/source/torch.compiler_troubleshooting.rst index 1c655da54cc7..25795b38a0a5 100644 --- a/docs/source/torch.compiler_troubleshooting.rst +++ b/docs/source/torch.compiler_troubleshooting.rst @@ -274,7 +274,7 @@ Minifying TorchInductor Errors ------------------------------ From here, let’s run the minifier to get a minimal repro. Setting the -environment variable ``TORCHDYNAMO_REPRO_AFTER=“aot”`` (or setting +environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` (or setting ``torch._dynamo.config.repro_after="aot"`` directly) will generate a Python program which reduces the graph produced by AOTAutograd to the smallest subgraph which reproduces the error. (See below for an example @@ -376,7 +376,7 @@ through an example. In order to run the code after TorchDynamo has traced the forward graph, you can use the ``TORCHDYNAMO_REPRO_AFTER`` environment variable. Running -this program with ``TORCHDYNAMO_REPRO_AFTER=“dynamo”`` (or +this program with ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` (or ``torch._dynamo.config.repro_after="dynamo"``) should produce `this output `__\ and the following code in ``{torch._dynamo.config.base_dir}/repro.py``. From bbdd9b059fa06404701f0d39abaae75ebab5d27f Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sun, 26 Nov 2023 03:50:54 +0000 Subject: [PATCH 159/221] [executorch hash update] update the pinned executorch hash (#114486) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/_update-commit-hash.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114486 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 143c259d1612..4d5c9db5b489 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -f4578fc150f1690be27fd1ba3258b35a20d9c39d +1f584d77e610b98ed38138fce9922f9f4b7d9e21 From 028071c4a1609b822c328b9b4a6a6befd3588042 Mon Sep 17 00:00:00 2001 From: Scott Lowder Date: Sun, 26 Nov 2023 09:25:41 +0000 Subject: [PATCH 160/221] Fix test assertions in test_min_max_nodes_parse. (#114537) Calls to `assertTrue` corrected to be `assertEqual` in `ElasticLaunchTest test_min_max_nodes_parse`. As originally written, the `assertTrue` statements will always pass, not actually asserting anything of value for the test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114537 Approved by: https://github.com/Skylion007 --- test/distributed/launcher/run_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/run_test.py index 7cdab9f02769..9725ef0836d1 100644 --- a/test/distributed/launcher/run_test.py +++ b/test/distributed/launcher/run_test.py @@ -452,11 +452,11 @@ def test_launch_elastic_multiple_agents(self): def test_min_max_nodes_parse(self): min_nodes, max_nodes = launch.parse_min_max_nnodes("1") - self.assertTrue(min_nodes, max_nodes) - self.assertTrue(1, min_nodes) + self.assertEqual(min_nodes, max_nodes) + self.assertEqual(1, min_nodes) min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20") - self.assertTrue(2, min_nodes) - self.assertTrue(20, max_nodes) + self.assertEqual(2, min_nodes) + self.assertEqual(20, max_nodes) with self.assertRaises(RuntimeError): launch.parse_min_max_nnodes("2:20:30") From 4fa1ff8404b6c26c076288aa2a0aa77f0c24916a Mon Sep 17 00:00:00 2001 From: Khushi Agrawal Date: Sun, 26 Nov 2023 13:44:30 +0000 Subject: [PATCH 161/221] [opinfo][fix] conv3d & fix conv{1, 2}d for neg dilation|groups & add ErrorInputs for conv ops (#113885) Previous PR: https://github.com/pytorch/pytorch/pull/85202 Also, cc'ing @lezcano @kshitij12345 @zou3519, who reviewed my previous PR. Thanks! Pull Request resolved: https://github.com/pytorch/pytorch/pull/113885 Approved by: https://github.com/lezcano --- aten/src/ATen/native/Convolution.cpp | 10 + test/functorch/test_ops.py | 2 + test/test_mps.py | 2 + .../_internal/common_methods_invocations.py | 240 ++++++++++++++++-- 4 files changed, 228 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 9c31026af54c..43ee07b41107 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -343,6 +343,14 @@ struct ConvParams { return is_non_neg; } + bool is_dilation_neg() const { + bool is_non_neg = false; + for (const auto& p : dilation) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + bool is_stride_nonpos() const { bool is_nonpos = false; for (auto s : stride) { @@ -652,6 +660,7 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported"); TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); + TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, @@ -973,6 +982,7 @@ static Tensor convolution_same( auto k = weight.dim(); TORCH_CHECK(k > 2, "weight should have at least three dimensions"); + TORCH_CHECK(groups > 0, "non-positive groups is not supported"); auto dim = static_cast(k - 2); auto weight_sizes = weight.sym_sizes(); auto input_sizes = input.sym_sizes(); diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 945162ac69e4..eb90ba83d512 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1808,6 +1808,8 @@ def fn(input, weight, bias): {torch.float32: tol(atol=2e-04, rtol=1e-04)}, device_type='cuda'), tol2('linalg.pinv', 'hermitian', {torch.float32: tol(atol=5e-06, rtol=5e-06)}), + tol1('nn.functional.conv3d', + {torch.float32: tol(atol=5e-04, rtol=9e-03)}), )) def test_vmap_autograd_grad(self, device, dtype, op): def is_differentiable(inp): diff --git a/test/test_mps.py b/test/test_mps.py index 2a1bcdb30782..a1f3f5215eef 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -695,6 +695,7 @@ def mps_ops_modifier(ops): # Convolution for integral types is not supported on MPS 'nn.functional.conv1d': [torch.int64], 'nn.functional.conv2d': [torch.int64], + 'nn.functional.conv3d': [torch.int64], 'nn.functional.conv_transpose1d': [torch.int64], 'nn.functional.conv_transpose2d': [torch.int64], @@ -882,6 +883,7 @@ def mps_ops_error_inputs_modifier(ops): 'multinomial', 'nn.functional.conv1d', 'nn.functional.conv2d', + 'nn.functional.conv3d', 'gather', 'scatter', 'scatter_add', diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e6cd427a99e2..54cce7dcf3c3 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3848,10 +3848,6 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): ((1, 4, 5), (3, 4, 3), None, {}), ) - # TODO: (@krshrimali), add error_inputs_func once https://github.com/pytorch/pytorch/pull/67354 is merged - # Should replace test_conv_modules_raise_error_on_incorrect_input_size and test_conv_shapecheck - # in test/test_nn.py - for input_shape, weight, bias, kwargs in cases: # Batched yield SampleInput(make_arg(input_shape), args=( @@ -3866,33 +3862,112 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): def error_inputs_conv1d(opinfo, device, **kwargs): - input = torch.randn(size=(33, 16, 30), device=device, dtype=torch.float64) - weight = torch.randn(size=(20, 16, 5), device=device, dtype=torch.float64) - groups = 0 + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + + # error inputs for negative strides yield ErrorInput( - SampleInput(input, kwargs={"weight": weight, "groups": groups}), - error_regex="non-positive groups is not supported" - ) + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))), + error_regex="weight should have at least three dimensions") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") def error_inputs_conv2d(opinfo, device, **kwargs): - weight = torch.randint(high=10, size=(3, 2, 3, 3), device=device) - input = torch.randint(high=10, size=(2, 4, 4), device=device) - bias = torch.rand((3,), dtype=torch.float32, device=device) - yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same") - - weight = torch.rand(size=(3, 2, 3, 3), device=device, dtype=torch.float64) - input = torch.rand(size=(2, 4, 4), device=device, dtype=torch.float64) - bias = torch.rand((3,), dtype=torch.complex128, device=device) - yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same") - - input = torch.randn(size=(1, 4, 5, 5), device=device, dtype=torch.float64) - weight = torch.randn(size=(8, 4, 3, 3), device=device, dtype=torch.float64) - groups = 0 + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias yield ErrorInput( - SampleInput(input, kwargs={"weight": weight, "groups": groups}), - error_regex="non-positive groups is not supported" - ) + SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for groups the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs): @@ -3940,6 +4015,90 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample= ), kwargs=kwargs) +def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and dict of values of (stride, padding, dilation, groups) + cases: Tuple = ( + ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}), + ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}), + ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}), + ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}), + error_regex="non-positive groups is not supported") + + # error inputs for padding='same' not supported by strided convolutions + yield ErrorInput( + SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)), + make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}), + error_regex="padding='same' is not supported for strided convolutions") + + def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -13076,6 +13235,35 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), supports_expanded_weight=True, supports_out=False,), + OpInfo('nn.functional.conv3d', + aliases=('conv3d',), + aten_name='conv3d', + dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), + sample_inputs_func=sample_inputs_conv3d, + error_inputs_func=error_inputs_conv3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: Conv3D is not supported on MPS + DecorateInfo(unittest.expectedFailure, 'TestConsistency'), + ), + supports_expanded_weight=True, + supports_out=False,), OpInfo('nn.functional.group_norm', aten_name='group_norm', aliases=('group_norm',), From 081c5b3adc342d0e201df18977ffa4d939c5edad Mon Sep 17 00:00:00 2001 From: voznesenskym Date: Sun, 26 Nov 2023 23:40:32 +0000 Subject: [PATCH 162/221] Add Stateful/Stateless symbolic contexts, use fresh fake mode for dynamo backends (#113926) (#114526) Summary: The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor. This PR is the result of *a lot* of back and forth with ezyang and eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same: 1) We cache source->symbol in shape_env 2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification 3) We create a new fake mode for backends (from https://github.com/pytorch/pytorch/pull/113605/files) This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't). We went back to the drawing board here, but with a few concessions: 1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons 2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (ezyang did this) cc penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng imported-using-ghimport Test Plan: Imported from OSS Reviewed By: huydhn, Chillee Differential Revision: D51566250 Pulled By: voznesenskym Pull Request resolved: https://github.com/pytorch/pytorch/pull/114526 Approved by: https://github.com/Chillee, https://github.com/huydhn --- docs/source/conf.py | 4 +- test/dynamo/test_export.py | 4 +- test/dynamo/test_subclasses.py | 10 +- test/test_dynamic_shapes.py | 4 +- test/test_fake_tensor.py | 4 +- torch/_dynamo/backends/distributed.py | 4 +- torch/_dynamo/eval_frame.py | 4 +- torch/_dynamo/output_graph.py | 11 ++ torch/_dynamo/utils.py | 14 +++ torch/_dynamo/variables/builder.py | 48 +++++--- torch/_functorch/aot_autograd.py | 19 ++- torch/_guards.py | 3 + torch/_subclasses/fake_tensor.py | 29 +++-- torch/_subclasses/meta_utils.py | 26 ++-- torch/fx/experimental/symbolic_shapes.py | 144 +++++++++++++++++------ 15 files changed, 240 insertions(+), 88 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 2ec2d66bbcb0..dcd3c7694674 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2866,7 +2866,9 @@ "ShapeGuardPrinter", "StrictMinMaxConstraint", "SymDispatchMode", - "CreateSymbolicPolicy", + "SymbolicContext", + "StatelessSymbolicContext", + "StatefulSymbolicContext", # torch.fx.experimental.unification.match "Dispatcher", "VarDispatcher", diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index f8469ff15ca4..ca49583b7825 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -29,8 +29,8 @@ from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, - FreshCreateSymbolicPolicy, ShapeEnv, + StatelessSymbolicContext, ) from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import TEST_CUDA @@ -3311,7 +3311,7 @@ def test_symbool_guards( ) as fake_mode: fake_x = fake_mode.from_tensor( x, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())], ), ) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index cadc164dd283..53fb7328e529 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -15,8 +15,8 @@ from torch.fx.experimental.symbolic_shapes import ( DimDynamic, - FreshCreateSymbolicPolicy, ShapeEnv, + StatelessSymbolicContext, ) from torch.nested._internal.nested_tensor import ( jagged_from_list, @@ -337,13 +337,13 @@ def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): ) as fake_mode: x_fake = fake_mode.from_tensor( x, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[dim_dynamic for i in range(x.dim())] ), ) x1_fake = fake_mode.from_tensor( x1, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[dim_dynamic for i in range(x.dim())] ), ) @@ -373,7 +373,7 @@ def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): for inp in inps: fake_inp = fake_mode.from_tensor( inp, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( [dim_dynamic for i in range(x.dim())] ), ) @@ -708,7 +708,7 @@ def test_recompilation( ) as fake_mode: fake_inp = fake_mode.from_tensor( x, - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] ), ) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index daf293b43d00..bf843587af50 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -27,7 +27,7 @@ GuardOnDataDependentSymNode, ShapeEnv, is_symbolic, - FreshCreateSymbolicPolicy, + StatelessSymbolicContext, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -137,7 +137,7 @@ def create_symbolic_tensor(name, arg, shape_env): shape_env.create_symbolic_sizes_strides_storage_offset( arg, source=ConstantSource(name), - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims ), diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 0b9f895f0a64..14a596508824 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -15,7 +15,7 @@ DynamicOutputShapeException, UnsupportedOperatorException, ) -from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, FreshCreateSymbolicPolicy +from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, StatelessSymbolicContext from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.common_device_type import ops from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes @@ -541,7 +541,7 @@ def test_same_shape_env_preserved(self): mode1 = FakeTensorMode(shape_env=shape_env) t1 = mode1.from_tensor( torch.randn(10), - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None] ) diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index 90cb21c26351..adc68bb30bff 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -409,7 +409,9 @@ def run_node(self, n: Node) -> Any: if isinstance(arg, torch.Tensor) and not isinstance( arg, torch._subclasses.FakeTensor ): - new_args.append(fake_mode.from_tensor(arg)) + new_args.append( + torch._dynamo.utils.to_fake_tensor(arg, fake_mode) + ) else: new_args.append(arg) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 4c0234d106e1..7fff2c3392fc 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -42,7 +42,7 @@ from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, - FreshCreateSymbolicPolicy, + StatelessSymbolicContext, ) from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.nn.parallel.distributed import DistributedDataParallel @@ -903,7 +903,7 @@ def __init__( # TODO(zhxchen17) Also preserve all the user constraints here. arg.node.meta["val"] = fake_mode.from_tensor( flat_args[i], - policy=FreshCreateSymbolicPolicy( + symbolic_context=StatelessSymbolicContext( dynamic_sizes=[ DimDynamic.DYNAMIC if d in flat_args_dynamic_dims[i] diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index e3a33c1503a8..b577b9ea94aa 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1065,6 +1065,17 @@ def compile_and_call_fx_graph(self, tx, rv, root): "%s", LazyString(lambda: self.get_graph_sizes_log_str(name)) ) self.call_cleanup_hooks() + old_fake_mode = self.tracing_context.fake_mode + if not self.export: + # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting + backend_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=old_fake_mode.shape_env, + ) + # TODO(voz): Ostensibily, this should be scoped and + # restore back to old_fake_mode, but doing so currently violates + # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode + self.tracing_context.fake_mode = backend_fake_mode + with self.restore_global_state(): compiled_fn = self.call_user_compiler(gm) compiled_fn = disable(compiled_fn) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2a520eb304c5..47275ea04185 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2295,3 +2295,17 @@ def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bo isinstance(vt, UserDefinedObjectVariable) and hasattr(vt.value, "__torch_function__") ) + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + ) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index be66d51c0f4d..a139efd3e166 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -26,11 +26,11 @@ from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, - CreateSymbolicPolicy, DimConstraint, DimDynamic, - FreshCreateSymbolicPolicy, RelaxedUnspecConstraint, + StatefulSymbolicContext, + SymbolicContext, ) from torch.fx.immutable_collections import immutable_list from torch.nested._internal.nested_tensor import NestedTensor @@ -1564,23 +1564,33 @@ def __eq__(self, other: object) -> bool: # Performs automatic dynamic dim determination. -# Returns a CreateSymbolicPolicy -def _automatic_dynamic(e, tx, name, static_shapes) -> CreateSymbolicPolicy: +# Returns a SymbolicContext +def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext: + name = source.name() + prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) + source_to_symint_node_cache = ( + prior_policy.source_to_symint_node_cache if prior_policy else None + ) + if static_shapes: - return FreshCreateSymbolicPolicy( + return StatefulSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * e.dim(), constraint_sizes=[None] * e.dim(), + tensor_source=source, + source_to_symint_node_cache=source_to_symint_node_cache, ) # We preserve the dynamism of inputs. For example, when users call # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. if any(isinstance(s, SymInt) for s in e.size()): - return FreshCreateSymbolicPolicy( + return StatefulSymbolicContext( dynamic_sizes=[ DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC for s in e.size() ], constraint_sizes=[None] * e.dim(), + tensor_source=source, + source_to_symint_node_cache=source_to_symint_node_cache, ) # Prep for automatic dynamic @@ -1699,7 +1709,7 @@ def update_dim2constraint(dim, constraint_range, debug_name): # Now, figure out if the dim is dynamic/duck/static if constraint_dim is not None or marked_dynamic or marked_weak_dynamic: # NB: We could assert static_shapes is False here, but it - # seems better to allow the user to override policy in this + # seems better to allow the user to override symbolic_context in this # case dynamic = DimDynamic.DYNAMIC elif static_shapes or config.assume_static_by_default or marked_static: @@ -1711,12 +1721,15 @@ def update_dim2constraint(dim, constraint_range, debug_name): tx.output.frame_state[name] = frame_state_entry - return FreshCreateSymbolicPolicy( + return StatefulSymbolicContext( dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims, + tensor_source=source, + source_to_symint_node_cache=source_to_symint_node_cache, ) +# See note [Tensor Fakification and Symbol Caching] def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor: bool): if ( type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) @@ -1728,31 +1741,36 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor e, is_tensor, guard_source=source.guard_source() ) - policy = None + symbolic_context = None if not e.is_nested: # TODO: We should probably support this for nested tensors too - policy = _automatic_dynamic(e, tx, source.name(), static_shapes) + symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) + + if symbolic_context: + tx.output.tracing_context.tensor_to_context[e] = symbolic_context log.debug( "wrap_to_fake %s %s %s %s", source.name(), tuple(e.shape), - policy.dynamic_sizes if policy is not None else None, - policy.constraint_sizes if policy is not None else None, + symbolic_context.dynamic_sizes if symbolic_context is not None else None, + symbolic_context.constraint_sizes if symbolic_context is not None else None, ) fake_e = wrap_fake_exception( lambda: tx.fake_mode.from_tensor( e, source=source, - policy=policy, + symbolic_context=symbolic_context, ) ) - # TODO: just store the whole policy here + # TODO: just store the whole symbolic_context here tx.output.tracked_fakes.append( TrackedFake( fake_e, source, - policy.constraint_sizes if policy is not None else None, + symbolic_context.constraint_sizes + if symbolic_context is not None + else None, ) ) tx.output.tracked_fakes_id_to_source[id(e)].append(source) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 58bf329dc566..db83c84e8a6b 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -4347,14 +4347,27 @@ def convert(idx, x): if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs) return x - # TODO: Ensure that this codepath is never exercised from - # Dynamo + + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source if ( idx < aot_config.num_params_buffers and config.static_weight_shapes + and not symbolic_context ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo return fake_mode.from_tensor(x, static_shapes=True) - return fake_mode.from_tensor(x, static_shapes=False) + + return fake_mode.from_tensor( + x, static_shapes=False, symbolic_context=symbolic_context, source=source + ) return [convert(idx, x) for idx, x in enumerate(flat_args)] diff --git a/torch/_guards.py b/torch/_guards.py index fe3a10d663b7..69912b15313d 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -29,6 +29,7 @@ import torch from torch.utils import _pytree as pytree from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakTensorKeyDictionary log = logging.getLogger(__name__) @@ -618,6 +619,8 @@ def __init__(self, fake_mode): # ints that are known to be size-like and may have 0/1 entries that we # must not specialize on. self.force_unspec_int_unbacked_size_like = False + # See note [Tensor Fakification and Symbol Caching] + self.tensor_to_context = WeakTensorKeyDictionary() @staticmethod @contextmanager diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index b36bc4c5bf8b..e505dfb3bb2b 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -312,9 +312,16 @@ def from_real_tensor( shape_env=None, *, source=None, - policy=None, + symbolic_context=None, memoized_only=False, ): + # see note [Tensor Fakification and Symbol Caching] + if not symbolic_context and not source and shape_env: + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + maybe_memo = self._get_memo(t) if maybe_memo is not None: return maybe_memo @@ -348,7 +355,7 @@ def mk_fake_tensor(make_meta_t): shape_env=shape_env, callback=mk_fake_tensor, source=source, - policy=policy, + symbolic_context=symbolic_context, ) if out is NotImplemented: raise UnsupportedFakeTensorException("meta converter nyi") @@ -383,7 +390,7 @@ def __call__( make_constant=False, shape_env=None, source=None, - policy=None, + symbolic_context=None, memoized_only=False, ): return self.from_real_tensor( @@ -392,7 +399,7 @@ def __call__( make_constant, shape_env=shape_env, source=source, - policy=policy, + symbolic_context=symbolic_context, memoized_only=memoized_only, ) @@ -1855,7 +1862,7 @@ def from_tensor( *, static_shapes=None, source: Optional[Source] = None, - policy=None, + symbolic_context=None, # Setting this flag will force FakeTensorMode to return `None` if attempting to convert a tensor we have not # seen before. memoized_only=False, @@ -1864,14 +1871,22 @@ def from_tensor( if static_shapes is None: static_shapes = self.static_shapes if static_shapes: - assert policy is None, "cannot set both static_shapes and policy" + assert ( + symbolic_context is None + ), "cannot set both static_shapes and symbolic_context" shape_env = None + # see note [Tensor Fakification and Symbol Caching] + if not symbolic_context and not source and not static_shapes: + if tracing_context := torch._guards.TracingContext.try_get(): + if tensor in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[tensor] + source = symbolic_context.tensor_source return self.fake_tensor_converter( self, tensor, shape_env=shape_env, source=source, - policy=policy, + symbolic_context=symbolic_context, memoized_only=memoized_only, ) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 1ff2a156379d..8db8f94b1b41 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Import the following modules during type checking to enable code intelligence features, # Do not import unconditionally, as they import sympy and importing sympy is very slow - from torch.fx.experimental.symbolic_shapes import CreateSymbolicPolicy + from torch.fx.experimental.symbolic_shapes import SymbolicContext DimList = List @@ -184,7 +184,7 @@ def meta_tensor( shape_env=None, callback=lambda t: t(), source: Optional[Source] = None, - policy: Optional["CreateSymbolicPolicy"] = None, + symbolic_context: Optional["SymbolicContext"] = None, ): from torch._subclasses.fake_tensor import FakeTensor @@ -250,10 +250,10 @@ def sym_sizes_strides_storage_offset( # the wrapper tensor and any inner tensors. # We can revisit this if this assumption does not hold # for any important subclasses later. - policy=policy, + symbolic_context=symbolic_context, ) else: - assert policy is None + assert symbolic_context is None return (t.size(), t.stride(), t.storage_offset()) # see expired-storages @@ -315,22 +315,22 @@ def sym_sizes_strides_storage_offset( from torch._dynamo.source import AttrSource from torch.fx.experimental.symbolic_shapes import ( DimDynamic, - FreshCreateSymbolicPolicy, + StatelessSymbolicContext, ) if shape_env and not t.is_nested and not t._base.is_nested: - base_policy = FreshCreateSymbolicPolicy( + base_symbolic_context = StatelessSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * t._base.dim(), constraint_sizes=[None] * t._base.dim(), ) else: - base_policy = None + base_symbolic_context = None base = self.meta_tensor( t._base, shape_env, callback, source=AttrSource(source, "_base"), - policy=base_policy, + symbolic_context=base_symbolic_context, ) def is_c_of_r(complex_dtype, real_dtype): @@ -620,7 +620,7 @@ def empty_create(inner_t, inner_src): shape_env, callback, source=AttrSource(source, "grad"), - policy=policy, + symbolic_context=symbolic_context, ) torch._C._set_conj(r, t.is_conj()) torch._C._set_neg(r, t.is_neg()) @@ -637,7 +637,7 @@ def __call__( *, callback=lambda t: t(), source=None, - policy=None, + symbolic_context=None, ): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now @@ -682,7 +682,7 @@ def __call__( shape_env=shape_env, callback=callback, source=source, - policy=policy, + symbolic_context=symbolic_context, ) out = torch._to_functional_tensor(fake_t) torch._mirror_autograd_meta_to(fake_t, out) @@ -700,7 +700,7 @@ def __call__( shape_env=shape_env, callback=callback, source=source, - policy=policy, + symbolic_context=symbolic_context, ) return _wrap_functional_tensor(fake_t, current_level()) self.miss += 1 @@ -712,7 +712,7 @@ def __call__( shape_env=shape_env, callback=callback, source=source, - policy=policy, + symbolic_context=symbolic_context, ) if type(t) is torch.nn.Parameter: # NB: Cannot directly use Parameter constructor diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 3d97727ff7b8..7f056ec9d5a7 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -63,7 +63,7 @@ class GuardOnDataDependentSymNode(RuntimeError): "guard_int", "guard_float", "guard_scalar", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", - "has_free_symbols", "sym_eq", "CreateSymbolicPolicy", "FreshCreateSymbolicPolicy", + "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", "StatefulSymbolicContext" ] # FX node metadata keys for symbolic shape FX graph. @@ -721,8 +721,14 @@ def render(self): def is_equal(self, source1, source2): return self._find(source1) == self._find(source2) + +def _assert_symbol_context(symbolic_context): + assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" + assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" + + @dataclass(frozen=True) -class CreateSymbolicPolicy: +class SymbolicContext: """ Data structure specifying how we should create symbols in ``create_symbolic_sizes_strides_storage_offset``; e.g., should @@ -736,20 +742,67 @@ class CreateSymbolicPolicy: @dataclass(frozen=True) -class FreshCreateSymbolicPolicy(CreateSymbolicPolicy): +class StatelessSymbolicContext(SymbolicContext): """ Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via - a policy determination as given by ``DimDynamic`` and ``DimConstraint``. + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. This will cause fresh symbols to be allocated """ dynamic_sizes: DimList[DimDynamic] constraint_sizes: DimList[DimConstraint] = None - # TODO: add storage offset and stride policy + # TODO: add storage offset and stride symbolic_context def __post_init__(self): if self.constraint_sizes is None: object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes)) + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read if we cache + hit. + + It is the cache owners responsibility to maintain the lifecycle of the cache + w/r/t different shape_envs, clearing, etc. + """ + tensor_source: Source = None + source_to_symint_node_cache : Dict["TensorPropertySource", SymInt] = None + + def __post_init__(self): + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.source_to_symint_node_cache: + object.__setattr__(self, 'source_to_symint_node_cache', {}) + + def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: if isinstance(val, (int, float, bool)): return False @@ -1922,20 +1975,20 @@ def _update_version_counter(self): def _produce_dyn_sizes(self, ex_size: Sequence[int], source: Source, - policy: CreateSymbolicPolicy + symbolic_context: SymbolicContext ) -> List[sympy.Expr]: - return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, policy) + return self._produce_dyn_sizes_from_int_tuple(tuple(ex.size()), source, symbolic_context) def _produce_dyn_sizes_from_int_tuple(self, tensor_size: Tuple[int], source: Source, - policy: CreateSymbolicPolicy, + symbolic_context: SymbolicContext, ) -> List[sympy.Expr]: assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}" from torch._dynamo.source import TensorPropertySource, TensorProperty - assert isinstance(policy, FreshCreateSymbolicPolicy) - dynamic_dims = policy.dynamic_sizes - constraint_dims = policy.constraint_sizes + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes + constraint_dims = symbolic_context.constraint_sizes size = [] for i, val in enumerate(tensor_size): size.append(self.create_symbol( @@ -1948,7 +2001,7 @@ def create_symbolic_sizes_strides_storage_offset( ex: torch.Tensor, source: Source, *, - policy: Optional[CreateSymbolicPolicy] = None, + symbolic_context: Optional[SymbolicContext] = None, ): """ Returns a list of symbolic sizes and strides for the given tensor. @@ -2010,7 +2063,7 @@ def maybe_specialize_sym_int_with_hint(maybe_sym) -> int: ex_storage_offset, [_is_dim_dynamic(ex, i) for i in range(ex.dim())], source, - policy=policy, + symbolic_context=symbolic_context, ) @record_shapeenv_event() @@ -2022,12 +2075,12 @@ def _create_symbolic_sizes_strides_storage_offset( is_dim_dynamic: Sequence[bool], source: Source, *, - policy: Optional[CreateSymbolicPolicy] = None, + symbolic_context: Optional[SymbolicContext] = None, ): dim = len(ex_size) # Reimplement the legacy behavior - if policy is None: + if symbolic_context is None: constraint_dims = [None] * dim dynamic_dims = [] for i in range(dim): @@ -2041,13 +2094,14 @@ def _create_symbolic_sizes_strides_storage_offset( r = DimDynamic.DUCK dynamic_dims.append(r) dynamic_dims = [DimDynamic.DUCK] * dim - policy = FreshCreateSymbolicPolicy(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims) - - assert isinstance(policy, FreshCreateSymbolicPolicy) - constraint_dims = policy.constraint_sizes - dynamic_dims = policy.dynamic_sizes - - # TODO: make this configurable from outside policy; we made a policy + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_dims = symbolic_context.constraint_sizes + dynamic_dims = symbolic_context.dynamic_sizes + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context # decision here where if all sizes are static, we are going to # specialize all of the inner strides/offset too. We don't have to # do this. @@ -2058,7 +2112,7 @@ def _create_symbolic_sizes_strides_storage_offset( assert len(constraint_dims) == dim from torch._dynamo.source import TensorPropertySource, TensorProperty - size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, policy) + size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context) stride: List[Optional[sympy.Expr]] = [None] * len(size) for i, val in enumerate(ex_stride): if val in (0, 1): @@ -2096,7 +2150,12 @@ def _create_symbolic_sizes_strides_storage_offset( assert all(x is not None for x in stride) sym_sizes = [ - self.create_symintnode(sym, hint=hint, source=TensorPropertySource(source, TensorProperty.SIZE, i)) + self.create_symintnode( + sym, + hint=hint, + source=TensorPropertySource(source, TensorProperty.SIZE, i), + symbolic_context=symbolic_context + ) for i, (sym, hint) in enumerate(zip(size, ex_size)) ] sym_stride = [] @@ -2105,14 +2164,17 @@ def _create_symbolic_sizes_strides_storage_offset( # we computed assert stride_expr is not None sym_stride.append(self.create_symintnode( - stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i) - )) - sym_storage_offset = self.create_symintnode(self.create_symbol( - ex_storage_offset, - TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), - dynamic_dim=DimDynamic.DYNAMIC, - constraint_dim=None, - ), hint=ex_storage_offset, source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)) + stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i), + symbolic_context=symbolic_context)) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=DimDynamic.DYNAMIC, + constraint_dim=None, + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), symbolic_context=symbolic_context) return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset # If you know what the current hint value of the SymInt to be created @@ -2125,7 +2187,10 @@ def create_symintnode( *, hint: Optional[int], source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, ): + source_name = source.name() if source else None + if self._translation_validation_enabled and source is not None: # Create a new symbol for this source. symbol = self._create_symbol_for_source(source) @@ -2139,11 +2204,20 @@ def create_symintnode( else: fx_node = None + # see note [Tensor Fakification and Symbol Caching] + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + if source_name in symbolic_context.source_to_symint_node_cache: + return symbolic_context.source_to_symint_node_cache[source_name] + if isinstance(sym, sympy.Integer): if hint is not None: assert int(sym) == hint - return int(sym) - return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + out = int(sym) + else: + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.source_to_symint_node_cache[source_name] = out + return out @record_shapeenv_event() def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim): @@ -2238,7 +2312,7 @@ def create_symbol( assert isinstance(source, Source), f"{type(source)} {source}" assert not (positive and val < 0), f"positive set for negative value: {val}" # It's always sound to allocate a symbol as DYNAMIC. If the user - # constrained the symbol, force the policy to DYNAMIC, because our + # constrained the symbol, force the symbolic_context to DYNAMIC, because our # constraint code will do weird stuff if, e.g., it's duck shaped if constraint_dim is not None: dynamic_dim = DimDynamic.DYNAMIC From 624f2025229312bffd725775760dc9b147cc32d8 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 22 Nov 2023 13:32:42 -0800 Subject: [PATCH 163/221] [dtensor] add CommDebugMode for debugging (#113592) This PR adds a CommDebugMode debugging tool to record the number of distributed collectives, utilizing TorchDispatchMode, the idea borrows from the FlopCounterMode and we can expand this later to make it more feature complete like the FlopCounterMode This is useful for debugging with DTensor and testing, in general this fits for any complex distributed algorithms where it's non-trival to understand the algorithm, we can use this tool to understand what happened under the hood., we can later cover c10d collectives directly Not sure if it would be a good general distributed debug tool yet, so adding to the dtensor package first Pull Request resolved: https://github.com/pytorch/pytorch/pull/113592 Approved by: https://github.com/wconstab --- .../_tensor/debug/test_comm_mode.py | 57 +++++++++++++++ torch/distributed/_tensor/debug/__init__.py | 2 + torch/distributed/_tensor/debug/comm_mode.py | 69 +++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 test/distributed/_tensor/debug/test_comm_mode.py create mode 100644 torch/distributed/_tensor/debug/comm_mode.py diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py new file mode 100644 index 000000000000..c52fb84d52c7 --- /dev/null +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -0,0 +1,57 @@ +# Owner(s): ["oncall: distributed"] + +import torch +import torch.distributed as dist + +import torch.distributed._functional_collectives as funcol +import torch.nn as nn + +from torch.distributed._tensor.debug.comm_mode import CommDebugMode +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule +from torch.testing._internal.distributed.fake_pg import FakeStore + +c10d_functional = torch.ops.c10d_functional + + +class TestCommMode(TestCase): + def tearDown(self): + super().tearDown() + dist.destroy_process_group() + + def setUp(self): + super().setUp() + store = FakeStore() + dist.init_process_group(backend="fake", rank=1, world_size=2, store=store) + self.device_type = "cuda" if torch.cuda.is_available() else "cpu" + self.world_pg = dist.distributed_c10d._get_default_group() + + def test_comm_mode(self): + world_pg = self.world_pg + + class WrapperModel(nn.Module): + def __init__(self, device): + super().__init__() + self.model = MLPModule(device=device) + + def forward(self, x): + x = funcol.all_gather_tensor(x, 0, world_pg) + x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg) + out = self.model(x) + return funcol.all_reduce(out, "sum", world_pg) + + model = WrapperModel(self.device_type) + + comm_mode = CommDebugMode() + with comm_mode: + model(torch.randn(20, 10, device=self.device_type)) + + comm_counts = comm_mode.get_comm_counts() + self.assertEqual(comm_mode.get_total_counts(), 3) + self.assertEqual(comm_counts[c10d_functional.all_reduce], 1) + self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1) + self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index 47b38d54331f..2cd388cf93e4 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,5 +1,7 @@ from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.debug.comm_mode import CommDebugMode + def get_sharding_prop_cache_info(): """ diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py new file mode 100644 index 000000000000..25852a842352 --- /dev/null +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -0,0 +1,69 @@ +from collections import defaultdict +from typing import Any, Dict + +import torch +from torch.utils._python_dispatch import TorchDispatchMode + + +funcol = torch.ops.c10d_functional + + +class CommDebugMode(TorchDispatchMode): + """ + ``CommDebugMode`` is a context manager that counts the number of + functional collectives within its context. It does this using a + ``TorchDispatchMode``. + + NOTE: this mode only works for functional collective atm and the + distributed_c10d collectives are not supported yet. + + Example usage + + .. code-block:: python + + mod = ... + comm_mode = CommDebugMode() + with comm_mode: + mod.sum().backward() + + """ + + def __init__(self): + self.comm_counts: Dict[Any, int] = defaultdict(int) + self.comm_registry = { + funcol.all_gather_into_tensor, + funcol.all_gather_into_tensor_coalesced, + funcol.all_reduce, + funcol.all_to_all_single, + funcol.broadcast, + funcol.reduce_scatter_tensor, + funcol.reduce_scatter_tensor_coalesced, + } + + def get_total_counts(self) -> int: + return sum(self.comm_counts.values()) + + def get_comm_counts(self) -> Dict[Any, int]: + """Returns the communication counts as a dictionary. + + Returns: + Dict[Any, int]: The communication counts as a dictionary. + """ + return self.comm_counts + + def __enter__(self): + self.comm_counts.clear() + super().__enter__() + return self + + def __exit__(self, *args): + super().__exit__(*args) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + func_packet = func._overloadpacket + if func_packet in self.comm_registry: + self.comm_counts[func_packet] += 1 + + return out From b62c0d96bcbe5f354ddce930fbdcd992dbaf1ce8 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Mon, 27 Nov 2023 04:53:38 +0000 Subject: [PATCH 164/221] [export] Support user input mutation. [1/2] (#114496) Summary: Serialization not implemented yet. Will do in the next diff. Resolving Github issues: https://github.com/pytorch/pytorch/issues/112429 https://github.com/pytorch/pytorch/issues/114142 Test Plan: buck2 run mode/opt caffe2/test:test_export -- -r test_export_ input_mutation Differential Revision: D51556962 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114496 Approved by: https://github.com/tugsbayasgalan --- test/export/test_db.py | 14 ++- test/export/test_export.py | 55 +++++++-- test/export/test_passes.py | 42 ++----- test/export/test_serialize.py | 1 + test/export/test_unflatten.py | 4 +- torch/_export/__init__.py | 107 ++++++++++++++++-- .../db/examples/user_input_mutation.py | 4 +- torch/_export/verifier.py | 37 +++--- torch/export/exported_program.py | 33 ++++-- torch/export/graph_signature.py | 74 ++++++++---- 10 files changed, 258 insertions(+), 113 deletions(-) diff --git a/test/export/test_db.py b/test/export/test_db.py index d126684beb60..a2d2f36af42e 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -1,14 +1,15 @@ # Owner(s): ["module: dynamo"] +import copy import unittest import torch._dynamo as torchdynamo -from torch.export import export from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel from torch._export.db.examples import ( filter_examples_by_support_level, get_rewrite_cases, ) +from torch.export import export from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -28,18 +29,19 @@ class ExampleTests(TestCase): def test_exportdb_supported(self, name: str, case: ExportCase) -> None: model = case.model - inputs = normalize_inputs(case.example_inputs) + inputs_export = normalize_inputs(case.example_inputs) + inputs_model = copy.deepcopy(inputs_export) exported_program = export( model, - inputs.args, - inputs.kwargs, + inputs_export.args, + inputs_export.kwargs, dynamic_shapes=case.dynamic_shapes, ) exported_program.graph_module.print_readable() self.assertEqual( - exported_program(*inputs.args, **inputs.kwargs), - model(*inputs.args, **inputs.kwargs), + exported_program(*inputs_export.args, **inputs_export.kwargs), + model(*inputs_model.args, **inputs_model.kwargs), ) if case.extra_inputs is not None: diff --git a/test/export/test_export.py b/test/export/test_export.py index 221ea9ba075b..caa576bfa987 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] # flake8: noqa +import copy import dataclasses import unittest from contextlib import contextmanager @@ -1092,13 +1093,13 @@ def f(x, y): torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5)) ) with self.assertRaisesRegex( - RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time" + RuntimeError, "is specialized to be 5 at tracing time" ): _ = exported(torch.ones(8, 5), 6) exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes) with self.assertRaisesRegex( - RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time" + RuntimeError, "is specialized to be 5.0 at tracing time" ): _ = exported(torch.ones(7, 5), 6.0) @@ -1109,7 +1110,7 @@ def g(a, b, mode): inps = (torch.randn(4, 4), torch.randn(4), "trunc") exported = torch._export.export(g, inps) - with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"): + with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"): _ = exported(torch.randn(4, 4), torch.randn(4), "floor") self.assertTrue(torch.allclose(exported(*inps), g(*inps))) @@ -1190,7 +1191,7 @@ def forward(self, x): dim0_x = torch.export.Dim("dim0_x") exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}}) reexported = torch.export.export(exported, (inp,)) - with self.assertRaisesRegex(RuntimeError, "Input arg2_1\.shape\[0\] is specialized at 5"): + with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"): reexported(torch.ones(7, 5)) reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},)) @@ -1199,7 +1200,7 @@ def forward(self, x): # can't retrace with invalid inputs with respect to the original ExportedProgram dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3) exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}}) - with self.assertRaisesRegex(RuntimeError, "Input arg2_1"): + with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"): torch.export.export(exported_v2, (torch.randn(2, 2),)) def test_retrace_graph_level_meta_preservation(self): @@ -1472,8 +1473,8 @@ def f(x): ep = export(f, (torch.tensor([3]),)) self.assertExpectedInline(str(ep.graph_module.code).strip(), """\ -def forward(self, arg0_1): - _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None +def forward(self, l_x_): + _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None ge = _local_scalar_dense >= 0 scalar_tensor = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, inf].'); scalar_tensor = None @@ -1492,7 +1493,7 @@ def foo(a, b): self.assertEqual(ep(*test_inp), foo(*test_inp)) ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None)) - with self.assertRaisesRegex(RuntimeError, "Input arg1_1.shape\[0\] is specialized at 4"): + with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"): ep_v2(*test_inp) def test_non_arg_name_dynamic_shapes_api_with_kwarg(self): @@ -1540,7 +1541,7 @@ def dynamify_inp(x): test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4)) with self.assertRaisesRegex( RuntimeError, - "Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]" + "shape\[0\] is outside of specified dynamic range \[3, inf\]" ): ep(*test_inp) @@ -1721,6 +1722,42 @@ def forward(self, scores, mask): optimized_model = torch.compile(exported_model) optimized_model(tensor_cpu, mask_cpu) + def test_export_input_mutation_static_shape(self): + class MutationModel(torch.nn.Module): + def forward(self, x, y): + x.view(3, 2, -1).add_(y) + return x + inputs = (torch.randn(12), 2.0) + model = MutationModel() + ep = torch.export.export(model, inputs) + inputs_export = copy.deepcopy(inputs) + inputs_model = copy.deepcopy(inputs) + self.assertEqual(ep(*inputs_export), model(*inputs_model)) + self.assertEqual(inputs[0] + 2.0, inputs_model[0]) + self.assertEqual(inputs[0] + 2.0, inputs_export[0]) + + def test_export_input_mutation_dynamic_shape(self): + class MutationModel(torch.nn.Module): + def forward(self, x, y): + x[0].mul_(y) + return x + inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0) + model = MutationModel() + ep = torch.export.export( + model, + inputs, + dynamic_shapes={'x': ({0: torch.export.Dim("dim")}, None), "y": None} + ) + nodes = list(ep.graph.nodes) + self.assertEqual(nodes[0].op, "placeholder") + self.assertIsInstance(nodes[0].meta['val'], torch.Tensor) + self.assertIsInstance(nodes[0].meta['val'].shape[0], torch.SymInt) + + inputs_export = copy.deepcopy(inputs) + inputs_model = copy.deepcopy(inputs) + self.assertEqual(ep(*inputs_export), model(*inputs_model)) + self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0]) + self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0]) if __name__ == '__main__': run_tests() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 86b30cb05980..627f1d5ee98f 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -76,7 +76,7 @@ def forward(self, x): dim1_x = torch.export.Dim("dim1_x", min=2, max=6) ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}}) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(2, 7, 3)) self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))) @@ -99,10 +99,10 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}} ) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) - with self.assertRaisesRegex(RuntimeError, "Input arg1_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) def test_runtime_assert_some_dims_not_specified(self) -> None: @@ -123,12 +123,12 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None} ) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5" + RuntimeError, r"shape\[0\] is specialized at 5" ): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -152,12 +152,12 @@ def forward(self, x, y): dim1_y = torch.export.Dim("dim1_y", min=3, max=6) ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5" + RuntimeError, r"shape\[0\] is specialized at 5" ): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -302,34 +302,6 @@ def false_fn(x, y): with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."): ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) - def test_runtime_assert_equality_constraint(self): - class Adder(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y - - m = Adder() - x = torch.rand(3, 4) - y = torch.rand(3, 4) - dim1 = torch.export.Dim("dim1") - exported = torch.export.export( - m, (x, y), dynamic_shapes={"x": {1: dim1}, "y": {1: dim1}} - ) - - x = torch.rand(3, 5) - y = torch.rand(3, 6) - with self.assertRaisesRegex( - RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]" - ): - exported(x, y) - - y = torch.rand(3, 5) - dynamo_result = exported(x, y) - real_result = m(x, y) - self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) - def test_functionalize_inline_contraints(self) -> None: def f(x): a = x.item() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 7eaff4f75ce7..49b4e35b6130 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -44,6 +44,7 @@ def get_filtered_export_db_tests(): "dictionary", # Graph output must be a tuple() "fn_with_kwargs", # export doesn't support kwargs yet "scalar_output", # Tracing through 'f' must produce a single graph + "user_input_mutation", # TODO(zhxchen17) Support serializing user inputs mutation. } return [ diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index c6c9e678f015..444a64c2eb56 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -277,11 +277,11 @@ def forward(self, x): return a export_module = torch.export.export(Mod(), (torch.randn((2, 3)),)) - with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"): + with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"): export_module(torch.randn(6, 6)) unflattened = export_module.module(flat=False) - with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"): + with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"): unflattened(torch.randn(6, 6)) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 438b54c2b0bd..fd7bb2cff3f3 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -30,28 +30,28 @@ from torch._dynamo.exc import UserError, UserErrorType from torch._dynamo.source import ConstantSource from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass -from torch._functorch.aot_autograd import aot_export_module +from torch._functorch.aot_autograd import aot_export_module, GraphSignature from torch._functorch.eager_transforms import functionalize from torch._guards import detect_fake_mode from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.export import _create_constraint, _Dim, Constraint +from torch.export.exported_program import ( + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) from torch.export.graph_signature import ( - ExportGraphSignature, _sig_to_specs, ArgumentSpec, ConstantArgument, + ExportGraphSignature, InputKind, + InputSpec, OutputKind, OutputSpec, SymIntArgument, TensorArgument, - InputSpec -) -from torch.export.exported_program import ( - ExportedProgram, - ModuleCallEntry, - ModuleCallSignature, ) from torch.fx import traceback as fx_traceback from torch.fx._compatibility import compatibility @@ -559,6 +559,88 @@ def export( preserve_module_call_signature=preserve_module_call_signature, ) + +def _prepare_module( + gm_torch_level: torch.fx.GraphModule, + aot_export_args +) -> List[str]: + flat_args = pytree.tree_leaves(aot_export_args) + user_input_names = [] + with gm_torch_level.graph.inserting_before(): + for i, (arg, node) in enumerate(zip(flat_args, gm_torch_level.graph.nodes)): + assert node.op == "placeholder" + user_input_names.append(node.name) + if isinstance(arg, torch.Tensor): + assert not hasattr(gm_torch_level, node.name) + gm_torch_level.register_buffer(node.name, arg) + get_attr = gm_torch_level.graph.get_attr(node.name) + node.replace_all_uses_with(get_attr) + get_attr.meta = copy.copy(node.meta) + + for node in list(gm_torch_level.graph.nodes): + if node.op == "placeholder": + assert len(node.users) == 0 + gm_torch_level.graph.erase_node(node) + gm_torch_level.recompile() + return user_input_names + + +def _unwrap_user_inputs( + gm: torch.fx.GraphModule, + graph_signature: GraphSignature, + user_input_names: List[str] +) -> Dict[str, str]: + assert len(graph_signature.user_inputs) == 0 + assert graph_signature.backward_signature is None + names = set(user_input_names) + + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + # user inputs are always added in the end + start = len(graph_signature.parameters) + end = start + len(graph_signature.buffers) + buffer_nodes = placeholders[start:end] + last_placeholder_node = placeholders[-1] if len(placeholders) > 0 else None + old_nodes: Dict[str, torch.fx.Node] = {} + for node in buffer_nodes: + buffer_name = graph_signature.inputs_to_buffers[node.name] + if buffer_name not in names: + continue + old_nodes[buffer_name] = node + replaces = {} + new_node_names: Dict[str, str] = {} + with gm.graph.inserting_after(last_placeholder_node): + for name in reversed(user_input_names): + new_node = gm.graph.placeholder(name) + new_node.target = new_node.name + new_node_names[name] = new_node.name + if name in old_nodes: + old_node = old_nodes[name] + new_node.meta = copy.copy(old_node.meta) + old_node.replace_all_uses_with(new_node) + replaces[old_node.name] = new_node.name + + for old_node in old_nodes.values(): + gm.graph.erase_node(old_node) + + gm.recompile() + + graph_signature.buffers = [b for b in graph_signature.buffers if b not in names] + graph_signature.inputs_to_buffers = { + i: b for i, b in graph_signature.inputs_to_buffers.items() if b not in names + } + user_inputs_to_mutate = { + o: b for o, b in graph_signature.buffers_to_mutate.items() if b in names + } + graph_signature.buffers_to_mutate = { + o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names + } + graph_signature.user_inputs = list(reversed(new_node_names.values())) # type: ignore[arg-type] + graph_signature.user_outputs = [ + replaces[o] if o in replaces else o for o in graph_signature.user_outputs + ] + return user_inputs_to_mutate # type: ignore[return-value] + + def _disable_prexisiting_fake_mode(fn): @functools.wraps(fn) @@ -703,6 +785,10 @@ def _export( if isinstance(f, torch.nn.Module): _normalize_nn_module_stack(gm_torch_level, type(f)) + aot_export_args = (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()) + + user_input_names = _prepare_module(gm_torch_level, aot_export_args) + # Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict # to follow the order in orig_args and correctly call gm_torch_level @@ -712,9 +798,10 @@ def _export( with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers): gm, graph_signature = aot_export_module( gm_torch_level, - (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()), + (), trace_joint=False ) + user_inputs_to_mutate = _unwrap_user_inputs(gm, graph_signature, user_input_names) def to_str_list(sig_component: List[Any]): return [str(v) for v in sig_component] @@ -771,6 +858,7 @@ def to_str_dict(sig_component: Dict[Any, Any]): is_joint = graph_signature.backward_signature is not None def make_argument_spec(node) -> ArgumentSpec: + assert "val" in node.meta, f"{node} has no 'val' metadata field" val = node.meta["val"] if isinstance(val, FakeTensor): return TensorArgument(name=node.name) @@ -784,6 +872,7 @@ def make_argument_spec(node) -> ArgumentSpec: inputs_to_buffers=graph_signature.inputs_to_buffers, # type: ignore[arg-type] user_outputs=set(graph_signature.user_outputs), # type: ignore[arg-type] buffer_mutations=graph_signature.buffers_to_mutate, # type: ignore[arg-type] + user_input_mutations=user_inputs_to_mutate, # type: ignore[arg-type] grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {}, # type: ignore[arg-type, union-attr] grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}, # type: ignore[arg-type, union-attr] loss_output=graph_signature.backward_signature.loss_output if is_joint else None, # type: ignore[arg-type, union-attr] diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py index 56af08d0c359..2bb16cd64a56 100644 --- a/torch/_export/db/examples/user_input_mutation.py +++ b/torch/_export/db/examples/user_input_mutation.py @@ -6,11 +6,11 @@ @export_case( example_inputs=(torch.ones(3, 2),), tags={"torch.mutation"}, - support_level=SupportLevel.NOT_SUPPORTED_YET, + support_level=SupportLevel.SUPPORTED, ) class UserInputMutation(torch.nn.Module): """ - Can't directly mutate user input in forward + Directly mutate user input in forward """ def forward(self, x): diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index dae103724c5e..391d7f99f69b 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -230,12 +230,6 @@ def _verify_exported_program_signature(exported_program) -> None: # Check ExportedProgram signature matches gs = exported_program.graph_signature - bs_grad_to_param = {} - bs_grad_to_user_inputs = {} - if gs.backward_signature is not None: - bs_grad_to_param = gs.backward_signature.gradients_to_parameters - bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs - # Check every node in the signature exists in the graph input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] @@ -324,19 +318,28 @@ def _verify_exported_program_signature(exported_program) -> None: f"Number of user outputs: {len(gs.user_outputs)}. \n" ) - buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)] - user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)] + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + mutate_nodes: List[str] = output_nodes[:end] + user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] - for buffer_node in buffer_mutate_nodes: - if ( - buffer_node not in gs.buffers_to_mutate or - gs.buffers_to_mutate[buffer_node] not in gs.buffers - ): + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n") + else: raise SpecViolationError( - f"Buffer output {buffer_node} is not in buffer mutation dictionary " - "or, it does not point to a buffer that exists. \n" - f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" - f"Buffer nodes available: {gs.buffers} \n" + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" ) for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index dde89b4fdd9c..015d63ff8bbe 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -277,9 +277,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) if self.call_spec.out_spec is not None: - mutation = self.graph_signature.buffers_to_mutate - num_mutated = len(mutation) - mutated_buffers = res[:num_mutated] + buffer_mutation = self.graph_signature.buffers_to_mutate + user_input_mutation = self.graph_signature.user_inputs_to_mutate + num_mutated = len(buffer_mutation) + len(user_input_mutation) + mutated_values = res[:num_mutated] # Exclude dependency token from final result. assertion_dep_token = self.graph_signature.assertion_dep_token @@ -299,10 +300,27 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: f"{received_spec}" ) finally: - ix = 0 - for buffer in self.graph_signature.buffers_to_mutate.values(): - self.state_dict[buffer] = mutated_buffers[ix] - ix += 1 + user_inputs = [ + spec + for spec in self.graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + for i, value in enumerate(mutated_values): + output_spec = self.graph_signature.output_specs[i] + if output_spec.kind == OutputKind.BUFFER_MUTATION: + assert output_spec.target is not None + self.state_dict[output_spec.target] = value + elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: + assert output_spec.target is not None + index = next( + i + for i, spec in enumerate(user_inputs) + if spec.arg.name == output_spec.target + ) + args[index].copy_(value) + else: + raise AssertionError(f"Unexpected kind: {output_spec.kind}") + return res def __str__(self) -> str: @@ -365,7 +383,6 @@ def _get_placeholders(gm): decomp_table = decomp_table or core_aten_decompositions() old_placeholders = _get_placeholders(self.graph_module) - old_outputs = list(self.graph.nodes)[-1].args[0] fake_args = [node.meta["val"] for node in old_placeholders] buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()] diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index d208868d0d53..06c7f8e53e62 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -57,6 +57,7 @@ class OutputKind(Enum): BUFFER_MUTATION = auto() GRADIENT_TO_PARAMETER = auto() GRADIENT_TO_USER_INPUT = auto() + USER_INPUT_MUTATION = auto() @dataclasses.dataclass @@ -76,6 +77,7 @@ def _sig_to_specs( inputs_to_buffers: Mapping[str, str], user_outputs: Set[str], buffer_mutations: Mapping[str, str], + user_input_mutations: Mapping[str, str], grad_params: Mapping[str, str], grad_user_inputs: Mapping[str, str], loss_output: Optional[str], @@ -101,37 +103,49 @@ def to_input_spec(i: ArgumentSpec) -> InputSpec: else: raise AssertionError(f"Unknown tensor input kind: {name}") - def to_output_spec(o: ArgumentSpec) -> OutputSpec: + def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: if not isinstance(o, TensorArgument): return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) name = o.name - if name in user_outputs: - return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) - elif name in buffer_mutations: - return OutputSpec( - kind=OutputKind.BUFFER_MUTATION, - arg=o, - target=buffer_mutations[name], - ) - elif name in grad_params: - return OutputSpec( - kind=OutputKind.GRADIENT_TO_PARAMETER, - arg=o, - target=grad_params[name], - ) - elif name in grad_user_inputs: - return OutputSpec( - kind=OutputKind.GRADIENT_TO_USER_INPUT, - arg=o, - target=grad_user_inputs[name], - ) - elif name == loss_output: - return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + if idx < len(buffer_mutations) + len(user_input_mutations): + if name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], + ) + elif name in user_input_mutations: + return OutputSpec( + kind=OutputKind.USER_INPUT_MUTATION, + arg=o, + target=user_input_mutations[name], + ) + else: + raise AssertionError(f"Unknown tensor mutation kind: {name}") else: - raise AssertionError(f"Unknown tensor output kind: {name}") + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + + else: + raise AssertionError(f"Unknown tensor output kind: {name}") input_specs = [to_input_spec(i) for i in inputs] - output_specs = [to_output_spec(o) for o in outputs] + output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] return input_specs, output_specs @@ -304,6 +318,16 @@ def buffers_to_mutate(self) -> Mapping[str, str]: and isinstance(s.target, str) } + @property + def user_inputs_to_mutate(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.output_specs + if s.kind == OutputKind.USER_INPUT_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + # A dictionary mapping graph input node names to lifted tensor constants. @property def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: From 68a36d2faa73d07c1fc83427771aef73da1498ff Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 22 Nov 2023 13:32:45 -0800 Subject: [PATCH 165/221] [dtensor] refactor some existing test util to use comm mode (#114404) As titled, This is just a test util refactor: redistributed profiler is not good to use and we should use comm mode going forward Pull Request resolved: https://github.com/pytorch/pytorch/pull/114404 Approved by: https://github.com/wconstab ghstack dependencies: #113592 --- test/distributed/_tensor/test_view_ops.py | 9 ++- .../distributed/_tensor/common_dtensor.py | 56 +++---------------- 2 files changed, 14 insertions(+), 51 deletions(-) diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 9ed082cb07f4..855a7df79345 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -8,6 +8,7 @@ import torch.distributed as dist from torch import rand, randn, Tensor from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.ops.view_ops import ( Broadcast, Flatten, @@ -22,7 +23,6 @@ from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, - redistribute_profiler, with_comms, ) from torch.utils import _pytree as pytree @@ -166,10 +166,13 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): # print(f' |--- {in_shard}') in_dt = distribute_tensor(args[0], device_mesh, in_shard) - with redistribute_profiler() as profiler: + comm_mode = CommDebugMode() + with comm_mode: out_dt = op(in_dt, *args[1:], **kwargs) - self.assertEqual(profiler.num_calls, 0, "Expected no redistribution.") + self.assertEqual( + comm_mode.get_total_counts(), 0, "Expected no redistribution." + ) full_out = out_dt.full_tensor() diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 3b886c8c0c35..b3471f48f1c8 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -1,14 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from contextlib import contextmanager -from dataclasses import dataclass import itertools import sys from functools import wraps from typing import ( Any, Callable, - Generator, Iterator, Tuple, Dict, @@ -34,10 +31,8 @@ Shard, Replicate, distribute_tensor, - redistribute, ) -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.placement_types import Placement, DTensorSpec +from torch.distributed._tensor.placement_types import Placement DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu" PG_BACKEND = "nccl" if DEVICE_TYPE == "cuda" else "gloo" @@ -81,35 +76,6 @@ def skip_unless_torch_gpu(method: T) -> T: return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) -@dataclass -class RedistributeProfile: - num_calls: int - - -@contextmanager -def redistribute_profiler() -> Generator[RedistributeProfile, None, None]: - - orig_redistribute_local_tensor = redistribute.redistribute_local_tensor - profile: RedistributeProfile = RedistributeProfile(num_calls=0) - - # pyre-ignore[53] - def patched_redistribute_local_tensor( - local_tensor: torch.Tensor, - current_spec: DTensorSpec, - target_spec: DTensorSpec, - ) -> DTensor: - result = orig_redistribute_local_tensor(local_tensor, current_spec, target_spec) - profile.num_calls += 1 - return result - - try: - # pyre-ignore[9] - redistribute.redistribute_local_tensor = patched_redistribute_local_tensor - yield profile - finally: - redistribute.redistribute_local_tensor = orig_redistribute_local_tensor - - class DTensorTestBase(MultiProcessTestCase): @property def world_size(self) -> int: @@ -155,19 +121,13 @@ def setUp(self) -> None: # pyre-ignore[2]: def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: - with redistribute_profiler() as profile: - out = op_call(*args, **kwargs) - dtc = DTensorConverter(mesh, args, kwargs) - for d_args, d_kwargs in dtc: - # pyre can't find assertTrue anymore? - self.assertEqual(dtc.successful(), True) - d_out = op_call(*d_args, **d_kwargs) - self.assertEqual( - d_out.redistribute( - mesh, [Replicate()] * mesh.ndim - ).to_local(), - out, - ) + out = op_call(*args, **kwargs) + dtc = DTensorConverter(mesh, args, kwargs) + for d_args, d_kwargs in dtc: + # pyre can't find assertTrue anymore? + self.assertEqual(dtc.successful(), True) + d_out = op_call(*d_args, **d_kwargs) + self.assertEqual(d_out.full_tensor(), out) def run_subtests(self, *args, **kwargs): return run_subtests(self, *args, **kwargs) From 150aaf46cab1d4bc3e4e1cecfbb66d0612f73cbb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 27 Nov 2023 07:33:00 +0000 Subject: [PATCH 166/221] Revert "[opinfo][fix] conv3d & fix conv{1, 2}d for neg dilation|groups & add ErrorInputs for conv ops (#113885)" This reverts commit 4fa1ff8404b6c26c076288aa2a0aa77f0c24916a. Reverted https://github.com/pytorch/pytorch/pull/113885 on behalf of https://github.com/huydhn due to Sorry for reverting you change but its TestCommonCUDA::test_compare_cpu_nn_functional_conv3d test failing in trunk https://hud.pytorch.org/pytorch/pytorch/commit/4fa1ff8404b6c26c076288aa2a0aa77f0c24916a ([comment](https://github.com/pytorch/pytorch/pull/113885#issuecomment-1827268473)) --- aten/src/ATen/native/Convolution.cpp | 10 - test/functorch/test_ops.py | 2 - test/test_mps.py | 2 - .../_internal/common_methods_invocations.py | 240 ++---------------- 4 files changed, 26 insertions(+), 228 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 43ee07b41107..9c31026af54c 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -343,14 +343,6 @@ struct ConvParams { return is_non_neg; } - bool is_dilation_neg() const { - bool is_non_neg = false; - for (const auto& p : dilation) { - is_non_neg |= (p < 0); - } - return is_non_neg; - } - bool is_stride_nonpos() const { bool is_nonpos = false; for (auto s : stride) { @@ -660,7 +652,6 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported"); TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); - TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, @@ -982,7 +973,6 @@ static Tensor convolution_same( auto k = weight.dim(); TORCH_CHECK(k > 2, "weight should have at least three dimensions"); - TORCH_CHECK(groups > 0, "non-positive groups is not supported"); auto dim = static_cast(k - 2); auto weight_sizes = weight.sym_sizes(); auto input_sizes = input.sym_sizes(); diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index eb90ba83d512..945162ac69e4 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1808,8 +1808,6 @@ def fn(input, weight, bias): {torch.float32: tol(atol=2e-04, rtol=1e-04)}, device_type='cuda'), tol2('linalg.pinv', 'hermitian', {torch.float32: tol(atol=5e-06, rtol=5e-06)}), - tol1('nn.functional.conv3d', - {torch.float32: tol(atol=5e-04, rtol=9e-03)}), )) def test_vmap_autograd_grad(self, device, dtype, op): def is_differentiable(inp): diff --git a/test/test_mps.py b/test/test_mps.py index a1f3f5215eef..2a1bcdb30782 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -695,7 +695,6 @@ def mps_ops_modifier(ops): # Convolution for integral types is not supported on MPS 'nn.functional.conv1d': [torch.int64], 'nn.functional.conv2d': [torch.int64], - 'nn.functional.conv3d': [torch.int64], 'nn.functional.conv_transpose1d': [torch.int64], 'nn.functional.conv_transpose2d': [torch.int64], @@ -883,7 +882,6 @@ def mps_ops_error_inputs_modifier(ops): 'multinomial', 'nn.functional.conv1d', 'nn.functional.conv2d', - 'nn.functional.conv3d', 'gather', 'scatter', 'scatter_add', diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 54cce7dcf3c3..e6cd427a99e2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3848,6 +3848,10 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): ((1, 4, 5), (3, 4, 3), None, {}), ) + # TODO: (@krshrimali), add error_inputs_func once https://github.com/pytorch/pytorch/pull/67354 is merged + # Should replace test_conv_modules_raise_error_on_incorrect_input_size and test_conv_shapecheck + # in test/test_nn.py + for input_shape, weight, bias, kwargs in cases: # Batched yield SampleInput(make_arg(input_shape), args=( @@ -3862,112 +3866,33 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): def error_inputs_conv1d(opinfo, device, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=torch.float64) - - # error inputs for negative strides - yield ErrorInput( - SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), - kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") - - # error inputs for negative padding - yield ErrorInput( - SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), - kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") - - # error inputs for negative dilation - yield ErrorInput( - SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))), - kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") - - # FIXME: https://github.com/pytorch/pytorch/issues/85656 - # error inputs for bias shape not equal to the output channels - # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))), - # error_regex="expected bias to be 1-dimensional with 1 elements") - - # error inputs for input.ndim != weight.ndim - yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))), - error_regex="weight should have at least three dimensions") - - # error inputs for the weight[0] are less than the number of groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), - kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") - - # error inputs for the weight[0] are less than the number of groups + input = torch.randn(size=(33, 16, 30), device=device, dtype=torch.float64) + weight = torch.randn(size=(20, 16, 5), device=device, dtype=torch.float64) + groups = 0 yield ErrorInput( - SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), - kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") - - # error inputs for invalid groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), - kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") - - # error inputs for invalid groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), - kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + SampleInput(input, kwargs={"weight": weight, "groups": groups}), + error_regex="non-positive groups is not supported" + ) def error_inputs_conv2d(opinfo, device, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=torch.float64) - make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) - make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) - - # error inputs for different dtypes of input tensor and bias - yield ErrorInput( - SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))), - error_regex="should be the same") - - # error inputs for different dtypes of input tensor and bias - yield ErrorInput( - SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))), - error_regex="should be the same") - - # error inputs for negative strides - yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))), - kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") - - # error inputs for negative padding + weight = torch.randint(high=10, size=(3, 2, 3, 3), device=device) + input = torch.randint(high=10, size=(2, 4, 4), device=device) + bias = torch.rand((3,), dtype=torch.float32, device=device) + yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same") + + weight = torch.rand(size=(3, 2, 3, 3), device=device, dtype=torch.float64) + input = torch.rand(size=(2, 4, 4), device=device, dtype=torch.float64) + bias = torch.rand((3,), dtype=torch.complex128, device=device) + yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same") + + input = torch.randn(size=(1, 4, 5, 5), device=device, dtype=torch.float64) + weight = torch.randn(size=(8, 4, 3, 3), device=device, dtype=torch.float64) + groups = 0 yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))), - kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") - - # error inputs for negative dilation - yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))), - kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") - - # FIXME: https://github.com/pytorch/pytorch/issues/85656 - # error inputs for bias shape not equal to the output channels - # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))), - # error_regex="expected bias to be 1-dimensional with 1 elements") - - # error inputs for input.ndim != weight.ndim - yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))), - kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight") - - # error inputs for the weight[0] are less than the number of groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), - kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") - - # error inputs for groups the weight[0] are less than the number of groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), - kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") - - # error inputs for invalid groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))), - kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") - - # error inputs for invalid groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))), - kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + SampleInput(input, kwargs={"weight": weight, "groups": groups}), + error_regex="non-positive groups is not supported" + ) def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs): @@ -4015,90 +3940,6 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample= ), kwargs=kwargs) -def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - - # Ordered as shapes for input, weight, bias - # and dict of values of (stride, padding, dilation, groups) - cases: Tuple = ( - ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}), - ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}), - ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}), - ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), - ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}), - ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}), - ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}), - ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), - ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}), - ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}), - ) - - for input_shape, weight, bias, kwargs in cases: - # Batched - yield SampleInput(make_arg(input_shape), args=( - make_arg(weight), - make_arg(bias) if bias is not None else bias - ), kwargs=kwargs) - # Unbatched - yield SampleInput(make_arg(input_shape[1:]), args=( - make_arg(weight), - make_arg(bias) if bias is not None else bias - ), kwargs=kwargs) - - -def error_inputs_conv3d(opinfo, device, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=torch.float64) - - # error inputs for negative strides - yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), - kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") - - # error inputs for negative padding - yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), - kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") - - # error inputs for negative dilation - yield ErrorInput( - SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), - kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") - - # FIXME: https://github.com/pytorch/pytorch/issues/85656 - # error inputs for bias shape not equal to the output channels - # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))), - # error_regex="expected bias to be 1-dimensional with 1 elements") - - # error inputs for input.ndim != weight.ndim - yield ErrorInput( - SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))), - kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight") - - # error inputs for the weight[0] are less than the number of groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), - make_arg((2,))), kwargs={'groups': 3}), - error_regex="expected weight to be at least 3 at dimension 0") - - # error inputs for the weight[0] are less than the number of groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), - make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}), - error_regex="expected weight to be at least 3 at dimension 0") - - # error inputs for invalid groups - yield ErrorInput( - SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), - make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}), - error_regex="non-positive groups is not supported") - - # error inputs for padding='same' not supported by strided convolutions - yield ErrorInput( - SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)), - make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}), - error_regex="padding='same' is not supported for strided convolutions") - - def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -13235,35 +13076,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), supports_expanded_weight=True, supports_out=False,), - OpInfo('nn.functional.conv3d', - aliases=('conv3d',), - aten_name='conv3d', - dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), - dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), - sample_inputs_func=sample_inputs_conv3d, - error_inputs_func=error_inputs_conv3d, - gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, - gradcheck_fast_mode=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - decorators=( - DecorateInfo( - toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), - 'TestCommon', 'test_complex_half_reference_testing', - ), - ), - skips=( - # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at - # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. - DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), - # RuntimeError: UNSUPPORTED DTYPE: complex - DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', - 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), - # RuntimeError: Conv3D is not supported on MPS - DecorateInfo(unittest.expectedFailure, 'TestConsistency'), - ), - supports_expanded_weight=True, - supports_out=False,), OpInfo('nn.functional.group_norm', aten_name='group_norm', aliases=('group_norm',), From 8232d4d1c3a2f5468aa459ff823b041557dd1934 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 27 Nov 2023 07:36:08 +0000 Subject: [PATCH 167/221] Revert "[BE]: Enable Ruff + Flake8 G201,G202 logging format rule. (#114474)" This reverts commit d30497f6b62007c9d1e3c38179528e9d25ac1292. Reverted https://github.com/pytorch/pytorch/pull/114474 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but I see a bunch of inductor failure after the commit https://hud.pytorch.org/pytorch/pytorch/commit/d30497f6b62007c9d1e3c38179528e9d25ac1292, trying to revert to see if it helps fix the issues ([comment](https://github.com/pytorch/pytorch/pull/114474#issuecomment-1827271887)) --- .flake8 | 2 +- pyproject.toml | 2 +- torch/_dynamo/guards.py | 3 ++- torch/_dynamo/utils.py | 2 +- torch/distributed/elastic/multiprocessing/api.py | 3 ++- torch/distributed/elastic/timer/api.py | 9 +++++---- .../distributed/elastic/timer/file_based_local_timer.py | 8 ++++---- torch/distributed/elastic/timer/local_timer.py | 4 ++-- 8 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.flake8 b/.flake8 index 1e61b459df94..bca578ce563e 100644 --- a/.flake8 +++ b/.flake8 @@ -18,7 +18,7 @@ ignore = # these ignores are from flake8-comprehensions; please fix! C407, # these ignores are from flake8-logging-format; please fix! - G100,G101,G200 + G100,G101,G200,G201,G202 # these ignores are from flake8-simplify. please fix or ignore with commented reason SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, # flake8-simplify code styles diff --git a/pyproject.toml b/pyproject.toml index 71157c4f3cf3..279bd6fa058b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ ignore = [ "F821", "F841", # these ignores are from flake8-logging-format; please fix! - "G101", + "G101", "G201", "G202", # these ignores are from RUFF perf; please fix! "PERF203", "PERF4", # these ignores are from PYI; please fix! diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 0ef173155e2f..1b068402019b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1315,8 +1315,9 @@ def get_guard_fail_reason( GuardFail(reason_str or "unknown reason", orig_code_map[code]) ) except Exception as e: - log.exception( + log.error( "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", + exc_info=True, ) return reason_str diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 47275ea04185..ba876a0fbb82 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -400,7 +400,7 @@ def write_record_to_file(filename, exec_record): with open(filename, "wb") as f: exec_record.dump(f) except Exception: - log.exception("Unable to write execution record %s", filename) + log.error("Unable to write execution record %s", filename, exc_info=True) def count_calls(g: fx.Graph): diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index c7c870bdb073..32426be08010 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -477,13 +477,14 @@ def _poll(self) -> Optional[RunProcsResult]: failed_proc = self._pc.processes[failed_local_rank] error_filepath = self.error_files[failed_local_rank] - log.exception( + log.error( "failed (exitcode: %s)" " local_rank: %s (pid: %s)" " of fn: %s (start_method: %s)", failed_proc.exitcode, failed_local_rank, e.pid, fn_name, self.start_method, + exc_info=True, ) self.close() diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 566a3d4acbc7..6dd308891988 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -169,10 +169,11 @@ def _reap_worker_no_throw(self, worker_id: Any) -> bool: """ try: return self._reap_worker(worker_id) - except Exception: - log.exception( + except Exception as e: + log.error( "Uncaught exception thrown from _reap_worker(), " "check that the implementation correctly catches exceptions", + exc_info=e, ) return True @@ -180,8 +181,8 @@ def _watchdog_loop(self): while not self._stop_signaled: try: self._run_watchdog() - except Exception: - log.exception("Error running watchdog") + except Exception as e: + log.error("Error running watchdog", exc_info=e) def _run_watchdog(self): batch_size = max(1, self._request_queue.size()) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 26ebce33dcb5..597000c6d20d 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -225,8 +225,8 @@ def _watchdog_loop(self) -> None: self._run_watchdog(fd) if run_once: break - except Exception: - log.exception("Error running watchdog") + except Exception as e: + log.error("Error running watchdog", exc_info=e) def _run_watchdog(self, fd: io.TextIOWrapper) -> None: timer_requests = self._get_requests(fd, self._max_interval) @@ -328,6 +328,6 @@ def _reap_worker(self, worker_pid: int, signal: int) -> bool: except ProcessLookupError: log.info("Process with pid=%s does not exist. Skipping", worker_pid) return True - except Exception: - log.exception("Error terminating pid=%s", worker_pid) + except Exception as e: + log.error("Error terminating pid=%s", worker_pid, exc_info=e) return False diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index 05f467c807a5..240163f1bf6c 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -120,6 +120,6 @@ def _reap_worker(self, worker_id: int) -> bool: except ProcessLookupError: log.info("Process with pid=%s does not exist. Skipping", worker_id) return True - except Exception: - log.exception("Error terminating pid=%s", worker_id) + except Exception as e: + log.error("Error terminating pid=%s", worker_id, exc_info=e) return False From fa1ccc34c4f65756bc50c3e3ab135c88b175b18c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 27 Nov 2023 07:52:21 +0000 Subject: [PATCH 168/221] Revert "[export] Support user input mutation. [1/2] (#114496)" This reverts commit b62c0d96bcbe5f354ddce930fbdcd992dbaf1ce8. Reverted https://github.com/pytorch/pytorch/pull/114496 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114496#issuecomment-1827289635)) --- test/export/test_db.py | 14 +-- test/export/test_export.py | 55 ++------- test/export/test_passes.py | 42 +++++-- test/export/test_serialize.py | 1 - test/export/test_unflatten.py | 4 +- torch/_export/__init__.py | 107 ++---------------- .../db/examples/user_input_mutation.py | 4 +- torch/_export/verifier.py | 37 +++--- torch/export/exported_program.py | 33 ++---- torch/export/graph_signature.py | 74 ++++-------- 10 files changed, 113 insertions(+), 258 deletions(-) diff --git a/test/export/test_db.py b/test/export/test_db.py index a2d2f36af42e..d126684beb60 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -1,15 +1,14 @@ # Owner(s): ["module: dynamo"] -import copy import unittest import torch._dynamo as torchdynamo +from torch.export import export from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel from torch._export.db.examples import ( filter_examples_by_support_level, get_rewrite_cases, ) -from torch.export import export from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -29,19 +28,18 @@ class ExampleTests(TestCase): def test_exportdb_supported(self, name: str, case: ExportCase) -> None: model = case.model - inputs_export = normalize_inputs(case.example_inputs) - inputs_model = copy.deepcopy(inputs_export) + inputs = normalize_inputs(case.example_inputs) exported_program = export( model, - inputs_export.args, - inputs_export.kwargs, + inputs.args, + inputs.kwargs, dynamic_shapes=case.dynamic_shapes, ) exported_program.graph_module.print_readable() self.assertEqual( - exported_program(*inputs_export.args, **inputs_export.kwargs), - model(*inputs_model.args, **inputs_model.kwargs), + exported_program(*inputs.args, **inputs.kwargs), + model(*inputs.args, **inputs.kwargs), ) if case.extra_inputs is not None: diff --git a/test/export/test_export.py b/test/export/test_export.py index caa576bfa987..221ea9ba075b 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1,6 +1,5 @@ # Owner(s): ["module: dynamo"] # flake8: noqa -import copy import dataclasses import unittest from contextlib import contextmanager @@ -1093,13 +1092,13 @@ def f(x, y): torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5)) ) with self.assertRaisesRegex( - RuntimeError, "is specialized to be 5 at tracing time" + RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time" ): _ = exported(torch.ones(8, 5), 6) exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes) with self.assertRaisesRegex( - RuntimeError, "is specialized to be 5.0 at tracing time" + RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time" ): _ = exported(torch.ones(7, 5), 6.0) @@ -1110,7 +1109,7 @@ def g(a, b, mode): inps = (torch.randn(4, 4), torch.randn(4), "trunc") exported = torch._export.export(g, inps) - with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"): + with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"): _ = exported(torch.randn(4, 4), torch.randn(4), "floor") self.assertTrue(torch.allclose(exported(*inps), g(*inps))) @@ -1191,7 +1190,7 @@ def forward(self, x): dim0_x = torch.export.Dim("dim0_x") exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}}) reexported = torch.export.export(exported, (inp,)) - with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"): + with self.assertRaisesRegex(RuntimeError, "Input arg2_1\.shape\[0\] is specialized at 5"): reexported(torch.ones(7, 5)) reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},)) @@ -1200,7 +1199,7 @@ def forward(self, x): # can't retrace with invalid inputs with respect to the original ExportedProgram dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3) exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}}) - with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"): + with self.assertRaisesRegex(RuntimeError, "Input arg2_1"): torch.export.export(exported_v2, (torch.randn(2, 2),)) def test_retrace_graph_level_meta_preservation(self): @@ -1473,8 +1472,8 @@ def f(x): ep = export(f, (torch.tensor([3]),)) self.assertExpectedInline(str(ep.graph_module.code).strip(), """\ -def forward(self, l_x_): - _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None +def forward(self, arg0_1): + _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge = _local_scalar_dense >= 0 scalar_tensor = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, inf].'); scalar_tensor = None @@ -1493,7 +1492,7 @@ def foo(a, b): self.assertEqual(ep(*test_inp), foo(*test_inp)) ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None)) - with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"): + with self.assertRaisesRegex(RuntimeError, "Input arg1_1.shape\[0\] is specialized at 4"): ep_v2(*test_inp) def test_non_arg_name_dynamic_shapes_api_with_kwarg(self): @@ -1541,7 +1540,7 @@ def dynamify_inp(x): test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4)) with self.assertRaisesRegex( RuntimeError, - "shape\[0\] is outside of specified dynamic range \[3, inf\]" + "Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]" ): ep(*test_inp) @@ -1722,42 +1721,6 @@ def forward(self, scores, mask): optimized_model = torch.compile(exported_model) optimized_model(tensor_cpu, mask_cpu) - def test_export_input_mutation_static_shape(self): - class MutationModel(torch.nn.Module): - def forward(self, x, y): - x.view(3, 2, -1).add_(y) - return x - inputs = (torch.randn(12), 2.0) - model = MutationModel() - ep = torch.export.export(model, inputs) - inputs_export = copy.deepcopy(inputs) - inputs_model = copy.deepcopy(inputs) - self.assertEqual(ep(*inputs_export), model(*inputs_model)) - self.assertEqual(inputs[0] + 2.0, inputs_model[0]) - self.assertEqual(inputs[0] + 2.0, inputs_export[0]) - - def test_export_input_mutation_dynamic_shape(self): - class MutationModel(torch.nn.Module): - def forward(self, x, y): - x[0].mul_(y) - return x - inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0) - model = MutationModel() - ep = torch.export.export( - model, - inputs, - dynamic_shapes={'x': ({0: torch.export.Dim("dim")}, None), "y": None} - ) - nodes = list(ep.graph.nodes) - self.assertEqual(nodes[0].op, "placeholder") - self.assertIsInstance(nodes[0].meta['val'], torch.Tensor) - self.assertIsInstance(nodes[0].meta['val'].shape[0], torch.SymInt) - - inputs_export = copy.deepcopy(inputs) - inputs_model = copy.deepcopy(inputs) - self.assertEqual(ep(*inputs_export), model(*inputs_model)) - self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0]) - self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0]) if __name__ == '__main__': run_tests() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 627f1d5ee98f..86b30cb05980 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -76,7 +76,7 @@ def forward(self, x): dim1_x = torch.export.Dim("dim1_x", min=2, max=6) ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}}) - with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): + with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(2, 7, 3)) self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))) @@ -99,10 +99,10 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}} ) - with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): + with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) - with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): + with self.assertRaisesRegex(RuntimeError, "Input arg1_1"): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) def test_runtime_assert_some_dims_not_specified(self) -> None: @@ -123,12 +123,12 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None} ) - with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): + with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, r"shape\[0\] is specialized at 5" + RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5" ): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -152,12 +152,12 @@ def forward(self, x, y): dim1_y = torch.export.Dim("dim1_y", min=3, max=6) ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}) - with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"): + with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, r"shape\[0\] is specialized at 5" + RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5" ): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -302,6 +302,34 @@ def false_fn(x, y): with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."): ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) + def test_runtime_assert_equality_constraint(self): + class Adder(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + m = Adder() + x = torch.rand(3, 4) + y = torch.rand(3, 4) + dim1 = torch.export.Dim("dim1") + exported = torch.export.export( + m, (x, y), dynamic_shapes={"x": {1: dim1}, "y": {1: dim1}} + ) + + x = torch.rand(3, 5) + y = torch.rand(3, 6) + with self.assertRaisesRegex( + RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]" + ): + exported(x, y) + + y = torch.rand(3, 5) + dynamo_result = exported(x, y) + real_result = m(x, y) + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + def test_functionalize_inline_contraints(self) -> None: def f(x): a = x.item() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 49b4e35b6130..7eaff4f75ce7 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -44,7 +44,6 @@ def get_filtered_export_db_tests(): "dictionary", # Graph output must be a tuple() "fn_with_kwargs", # export doesn't support kwargs yet "scalar_output", # Tracing through 'f' must produce a single graph - "user_input_mutation", # TODO(zhxchen17) Support serializing user inputs mutation. } return [ diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 444a64c2eb56..c6c9e678f015 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -277,11 +277,11 @@ def forward(self, x): return a export_module = torch.export.export(Mod(), (torch.randn((2, 3)),)) - with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"): + with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"): export_module(torch.randn(6, 6)) unflattened = export_module.module(flat=False) - with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"): + with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"): unflattened(torch.randn(6, 6)) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index fd7bb2cff3f3..438b54c2b0bd 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -30,28 +30,28 @@ from torch._dynamo.exc import UserError, UserErrorType from torch._dynamo.source import ConstantSource from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass -from torch._functorch.aot_autograd import aot_export_module, GraphSignature +from torch._functorch.aot_autograd import aot_export_module from torch._functorch.eager_transforms import functionalize from torch._guards import detect_fake_mode from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.export import _create_constraint, _Dim, Constraint -from torch.export.exported_program import ( - ExportedProgram, - ModuleCallEntry, - ModuleCallSignature, -) from torch.export.graph_signature import ( + ExportGraphSignature, _sig_to_specs, ArgumentSpec, ConstantArgument, - ExportGraphSignature, InputKind, - InputSpec, OutputKind, OutputSpec, SymIntArgument, TensorArgument, + InputSpec +) +from torch.export.exported_program import ( + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, ) from torch.fx import traceback as fx_traceback from torch.fx._compatibility import compatibility @@ -559,88 +559,6 @@ def export( preserve_module_call_signature=preserve_module_call_signature, ) - -def _prepare_module( - gm_torch_level: torch.fx.GraphModule, - aot_export_args -) -> List[str]: - flat_args = pytree.tree_leaves(aot_export_args) - user_input_names = [] - with gm_torch_level.graph.inserting_before(): - for i, (arg, node) in enumerate(zip(flat_args, gm_torch_level.graph.nodes)): - assert node.op == "placeholder" - user_input_names.append(node.name) - if isinstance(arg, torch.Tensor): - assert not hasattr(gm_torch_level, node.name) - gm_torch_level.register_buffer(node.name, arg) - get_attr = gm_torch_level.graph.get_attr(node.name) - node.replace_all_uses_with(get_attr) - get_attr.meta = copy.copy(node.meta) - - for node in list(gm_torch_level.graph.nodes): - if node.op == "placeholder": - assert len(node.users) == 0 - gm_torch_level.graph.erase_node(node) - gm_torch_level.recompile() - return user_input_names - - -def _unwrap_user_inputs( - gm: torch.fx.GraphModule, - graph_signature: GraphSignature, - user_input_names: List[str] -) -> Dict[str, str]: - assert len(graph_signature.user_inputs) == 0 - assert graph_signature.backward_signature is None - names = set(user_input_names) - - placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] - # user inputs are always added in the end - start = len(graph_signature.parameters) - end = start + len(graph_signature.buffers) - buffer_nodes = placeholders[start:end] - last_placeholder_node = placeholders[-1] if len(placeholders) > 0 else None - old_nodes: Dict[str, torch.fx.Node] = {} - for node in buffer_nodes: - buffer_name = graph_signature.inputs_to_buffers[node.name] - if buffer_name not in names: - continue - old_nodes[buffer_name] = node - replaces = {} - new_node_names: Dict[str, str] = {} - with gm.graph.inserting_after(last_placeholder_node): - for name in reversed(user_input_names): - new_node = gm.graph.placeholder(name) - new_node.target = new_node.name - new_node_names[name] = new_node.name - if name in old_nodes: - old_node = old_nodes[name] - new_node.meta = copy.copy(old_node.meta) - old_node.replace_all_uses_with(new_node) - replaces[old_node.name] = new_node.name - - for old_node in old_nodes.values(): - gm.graph.erase_node(old_node) - - gm.recompile() - - graph_signature.buffers = [b for b in graph_signature.buffers if b not in names] - graph_signature.inputs_to_buffers = { - i: b for i, b in graph_signature.inputs_to_buffers.items() if b not in names - } - user_inputs_to_mutate = { - o: b for o, b in graph_signature.buffers_to_mutate.items() if b in names - } - graph_signature.buffers_to_mutate = { - o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names - } - graph_signature.user_inputs = list(reversed(new_node_names.values())) # type: ignore[arg-type] - graph_signature.user_outputs = [ - replaces[o] if o in replaces else o for o in graph_signature.user_outputs - ] - return user_inputs_to_mutate # type: ignore[return-value] - - def _disable_prexisiting_fake_mode(fn): @functools.wraps(fn) @@ -785,10 +703,6 @@ def _export( if isinstance(f, torch.nn.Module): _normalize_nn_module_stack(gm_torch_level, type(f)) - aot_export_args = (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()) - - user_input_names = _prepare_module(gm_torch_level, aot_export_args) - # Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict # to follow the order in orig_args and correctly call gm_torch_level @@ -798,10 +712,9 @@ def _export( with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers): gm, graph_signature = aot_export_module( gm_torch_level, - (), + (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()), trace_joint=False ) - user_inputs_to_mutate = _unwrap_user_inputs(gm, graph_signature, user_input_names) def to_str_list(sig_component: List[Any]): return [str(v) for v in sig_component] @@ -858,7 +771,6 @@ def to_str_dict(sig_component: Dict[Any, Any]): is_joint = graph_signature.backward_signature is not None def make_argument_spec(node) -> ArgumentSpec: - assert "val" in node.meta, f"{node} has no 'val' metadata field" val = node.meta["val"] if isinstance(val, FakeTensor): return TensorArgument(name=node.name) @@ -872,7 +784,6 @@ def make_argument_spec(node) -> ArgumentSpec: inputs_to_buffers=graph_signature.inputs_to_buffers, # type: ignore[arg-type] user_outputs=set(graph_signature.user_outputs), # type: ignore[arg-type] buffer_mutations=graph_signature.buffers_to_mutate, # type: ignore[arg-type] - user_input_mutations=user_inputs_to_mutate, # type: ignore[arg-type] grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {}, # type: ignore[arg-type, union-attr] grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}, # type: ignore[arg-type, union-attr] loss_output=graph_signature.backward_signature.loss_output if is_joint else None, # type: ignore[arg-type, union-attr] diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py index 2bb16cd64a56..56af08d0c359 100644 --- a/torch/_export/db/examples/user_input_mutation.py +++ b/torch/_export/db/examples/user_input_mutation.py @@ -6,11 +6,11 @@ @export_case( example_inputs=(torch.ones(3, 2),), tags={"torch.mutation"}, - support_level=SupportLevel.SUPPORTED, + support_level=SupportLevel.NOT_SUPPORTED_YET, ) class UserInputMutation(torch.nn.Module): """ - Directly mutate user input in forward + Can't directly mutate user input in forward """ def forward(self, x): diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 391d7f99f69b..dae103724c5e 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -230,6 +230,12 @@ def _verify_exported_program_signature(exported_program) -> None: # Check ExportedProgram signature matches gs = exported_program.graph_signature + bs_grad_to_param = {} + bs_grad_to_user_inputs = {} + if gs.backward_signature is not None: + bs_grad_to_param = gs.backward_signature.gradients_to_parameters + bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs + # Check every node in the signature exists in the graph input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] @@ -318,28 +324,19 @@ def _verify_exported_program_signature(exported_program) -> None: f"Number of user outputs: {len(gs.user_outputs)}. \n" ) - end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) - mutate_nodes: List[str] = output_nodes[:end] - user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] + buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)] + user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)] - for mutation_node in mutate_nodes: - if mutation_node in gs.buffers_to_mutate: - if gs.buffers_to_mutate[mutation_node] not in gs.buffers: - raise SpecViolationError( - f"Buffer output {mutation_node} does not point to a buffer that exists. \n" - f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" - f"Buffer nodes available: {gs.buffers} \n" - ) - elif mutation_node in gs.user_inputs_to_mutate: - if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: - raise SpecViolationError( - f"User input output {mutation_node} does not point to a user input that exists. \n" - f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" - f"User input nodes available: {gs.user_inputs} \n") - else: + for buffer_node in buffer_mutate_nodes: + if ( + buffer_node not in gs.buffers_to_mutate or + gs.buffers_to_mutate[buffer_node] not in gs.buffers + ): raise SpecViolationError( - f"Mutation node {mutation_node} is neither a buffer nor a user input. " - f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + f"Buffer output {buffer_node} is not in buffer mutation dictionary " + "or, it does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" ) for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 015d63ff8bbe..dde89b4fdd9c 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -277,10 +277,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) if self.call_spec.out_spec is not None: - buffer_mutation = self.graph_signature.buffers_to_mutate - user_input_mutation = self.graph_signature.user_inputs_to_mutate - num_mutated = len(buffer_mutation) + len(user_input_mutation) - mutated_values = res[:num_mutated] + mutation = self.graph_signature.buffers_to_mutate + num_mutated = len(mutation) + mutated_buffers = res[:num_mutated] # Exclude dependency token from final result. assertion_dep_token = self.graph_signature.assertion_dep_token @@ -300,27 +299,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: f"{received_spec}" ) finally: - user_inputs = [ - spec - for spec in self.graph_signature.input_specs - if spec.kind == InputKind.USER_INPUT - ] - for i, value in enumerate(mutated_values): - output_spec = self.graph_signature.output_specs[i] - if output_spec.kind == OutputKind.BUFFER_MUTATION: - assert output_spec.target is not None - self.state_dict[output_spec.target] = value - elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: - assert output_spec.target is not None - index = next( - i - for i, spec in enumerate(user_inputs) - if spec.arg.name == output_spec.target - ) - args[index].copy_(value) - else: - raise AssertionError(f"Unexpected kind: {output_spec.kind}") - + ix = 0 + for buffer in self.graph_signature.buffers_to_mutate.values(): + self.state_dict[buffer] = mutated_buffers[ix] + ix += 1 return res def __str__(self) -> str: @@ -383,6 +365,7 @@ def _get_placeholders(gm): decomp_table = decomp_table or core_aten_decompositions() old_placeholders = _get_placeholders(self.graph_module) + old_outputs = list(self.graph.nodes)[-1].args[0] fake_args = [node.meta["val"] for node in old_placeholders] buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()] diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index 06c7f8e53e62..d208868d0d53 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -57,7 +57,6 @@ class OutputKind(Enum): BUFFER_MUTATION = auto() GRADIENT_TO_PARAMETER = auto() GRADIENT_TO_USER_INPUT = auto() - USER_INPUT_MUTATION = auto() @dataclasses.dataclass @@ -77,7 +76,6 @@ def _sig_to_specs( inputs_to_buffers: Mapping[str, str], user_outputs: Set[str], buffer_mutations: Mapping[str, str], - user_input_mutations: Mapping[str, str], grad_params: Mapping[str, str], grad_user_inputs: Mapping[str, str], loss_output: Optional[str], @@ -103,49 +101,37 @@ def to_input_spec(i: ArgumentSpec) -> InputSpec: else: raise AssertionError(f"Unknown tensor input kind: {name}") - def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: + def to_output_spec(o: ArgumentSpec) -> OutputSpec: if not isinstance(o, TensorArgument): return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) name = o.name - if idx < len(buffer_mutations) + len(user_input_mutations): - if name in buffer_mutations: - return OutputSpec( - kind=OutputKind.BUFFER_MUTATION, - arg=o, - target=buffer_mutations[name], - ) - elif name in user_input_mutations: - return OutputSpec( - kind=OutputKind.USER_INPUT_MUTATION, - arg=o, - target=user_input_mutations[name], - ) - else: - raise AssertionError(f"Unknown tensor mutation kind: {name}") + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + elif name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], + ) + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) else: - if name in user_outputs: - return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) - - elif name in grad_params: - return OutputSpec( - kind=OutputKind.GRADIENT_TO_PARAMETER, - arg=o, - target=grad_params[name], - ) - elif name in grad_user_inputs: - return OutputSpec( - kind=OutputKind.GRADIENT_TO_USER_INPUT, - arg=o, - target=grad_user_inputs[name], - ) - elif name == loss_output: - return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) - - else: - raise AssertionError(f"Unknown tensor output kind: {name}") + raise AssertionError(f"Unknown tensor output kind: {name}") input_specs = [to_input_spec(i) for i in inputs] - output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] + output_specs = [to_output_spec(o) for o in outputs] return input_specs, output_specs @@ -318,16 +304,6 @@ def buffers_to_mutate(self) -> Mapping[str, str]: and isinstance(s.target, str) } - @property - def user_inputs_to_mutate(self) -> Mapping[str, str]: - return { - s.arg.name: s.target - for s in self.output_specs - if s.kind == OutputKind.USER_INPUT_MUTATION - and isinstance(s.arg, TensorArgument) - and isinstance(s.target, str) - } - # A dictionary mapping graph input node names to lifted tensor constants. @property def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: From ccb1de3595fad0d8dc1f9269130dede16547fb77 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 27 Nov 2023 12:20:52 +0000 Subject: [PATCH 169/221] Revert "[inductor] Fix torch.split bug on unbacked symint (#113406)" This reverts commit cd7d6938c18d90870356553d4631f1388d2bb699. Reverted https://github.com/pytorch/pytorch/pull/113406 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/113406#issuecomment-1827727411)) --- test/inductor/test_unbacked_symints.py | 15 --------------- torch/_inductor/codegen/common.py | 4 ---- torch/_inductor/codegen/triton.py | 3 +-- torch/_inductor/ir.py | 11 +++++------ torch/fx/experimental/validator.py | 2 -- 5 files changed, 6 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 4ab72c6721bc..2797bea8ceb1 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -54,21 +54,6 @@ def fn(x, y): torch.testing.assert_close(actual, expected) - def test_split_with_sizes(self): - def fn(x, y): - l = y.tolist() - s = torch.split(x, l) - d = l[0] + l[1] + l[2] - return s[0].sum(), d - - example_inputs = (torch.randn((32), device="cuda"), torch.tensor((7, 16, 9))) - - with dynamo_config.patch({"capture_scalar_outputs": True}): - actual = torch.compile(fn, fullgraph=True)(*example_inputs) - expected = fn(*example_inputs) - - torch.testing.assert_close(actual, expected) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 19fa6de3219e..be949e8f92a9 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -378,10 +378,6 @@ def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" - def _print_Min(self, expr): - assert len(expr.args) >= 2 - return f"min({', '.join(map(self._print, expr.args))})" - class OpOverrides: def __init__(self, parent): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 6b288cec254d..0f08f728330f 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -19,7 +19,6 @@ from torch._prims_common import is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import ValueRanges - from ..._dynamo.utils import counters from .. import config, ir, scheduler from ..codecache import code_hash, get_path, PyCodeCache @@ -1144,7 +1143,7 @@ def indexing( # indirect indexing cse_var = self.cse.varname_map[var.name] mask_vars.update(cse_var.mask_vars) - elif var.name.startswith(("s", "ps", "i")): + elif var.name.startswith(("s", "ps")): pass else: # var is one of xN, yN or rN diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index ae031b8f2818..0c5846153892 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2088,12 +2088,11 @@ def create(cls, x, dim, start, end, step=1): start = cls.handle_negative_index(start, new_size[dim]) end = cls.handle_negative_index(end, new_size[dim]) - if free_unbacked_symbols(start) or free_unbacked_symbols(end): - end = sympy.Min(end, new_size[dim]) - start = sympy.Min(start, end) - else: - end = sizevars.evaluate_min(end, new_size[dim]) - start = sizevars.evaluate_min(start, end) + end = sizevars.evaluate_min(end, new_size[dim]) + start = sizevars.evaluate_min(start, end) + if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1: + sizevars.guard_equals(end, new_size[dim]) + return x new_size[dim] = FloorDiv(end - start + (step - 1), step) diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 48ad07dd8559..8c795bb5b3de 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -363,8 +363,6 @@ def __getattr__(self, name: str) -> Any: "not_": z3.Not, "floor": self._ops.floor, "ceil": self._ops.ceil, - "minimum": self._ops.min, - "maximum": self._ops.max, } if name in REPLACEMENT: From cff84871ce5fd78fbb8b59a2adf3e1fdae3f257b Mon Sep 17 00:00:00 2001 From: Khushi Agrawal Date: Mon, 27 Nov 2023 14:45:44 +0000 Subject: [PATCH 170/221] [reland][opinfo][fix] conv3d & fix conv{1, 2}d for neg dilation|groups & add ErrorInputs for conv ops (#114589) Previous PR: #113885 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114589 Approved by: https://github.com/lezcano --- aten/src/ATen/native/Convolution.cpp | 10 + test/functorch/test_ops.py | 2 + test/test_mps.py | 2 + .../_internal/common_methods_invocations.py | 243 ++++++++++++++++-- 4 files changed, 231 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 9c31026af54c..43ee07b41107 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -343,6 +343,14 @@ struct ConvParams { return is_non_neg; } + bool is_dilation_neg() const { + bool is_non_neg = false; + for (const auto& p : dilation) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + bool is_stride_nonpos() const { bool is_nonpos = false; for (auto s : stride) { @@ -652,6 +660,7 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported"); TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); + TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, @@ -973,6 +982,7 @@ static Tensor convolution_same( auto k = weight.dim(); TORCH_CHECK(k > 2, "weight should have at least three dimensions"); + TORCH_CHECK(groups > 0, "non-positive groups is not supported"); auto dim = static_cast(k - 2); auto weight_sizes = weight.sym_sizes(); auto input_sizes = input.sym_sizes(); diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 945162ac69e4..eb90ba83d512 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1808,6 +1808,8 @@ def fn(input, weight, bias): {torch.float32: tol(atol=2e-04, rtol=1e-04)}, device_type='cuda'), tol2('linalg.pinv', 'hermitian', {torch.float32: tol(atol=5e-06, rtol=5e-06)}), + tol1('nn.functional.conv3d', + {torch.float32: tol(atol=5e-04, rtol=9e-03)}), )) def test_vmap_autograd_grad(self, device, dtype, op): def is_differentiable(inp): diff --git a/test/test_mps.py b/test/test_mps.py index 2a1bcdb30782..a1f3f5215eef 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -695,6 +695,7 @@ def mps_ops_modifier(ops): # Convolution for integral types is not supported on MPS 'nn.functional.conv1d': [torch.int64], 'nn.functional.conv2d': [torch.int64], + 'nn.functional.conv3d': [torch.int64], 'nn.functional.conv_transpose1d': [torch.int64], 'nn.functional.conv_transpose2d': [torch.int64], @@ -882,6 +883,7 @@ def mps_ops_error_inputs_modifier(ops): 'multinomial', 'nn.functional.conv1d', 'nn.functional.conv2d', + 'nn.functional.conv3d', 'gather', 'scatter', 'scatter_add', diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e6cd427a99e2..b1021491577f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3848,10 +3848,6 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): ((1, 4, 5), (3, 4, 3), None, {}), ) - # TODO: (@krshrimali), add error_inputs_func once https://github.com/pytorch/pytorch/pull/67354 is merged - # Should replace test_conv_modules_raise_error_on_incorrect_input_size and test_conv_shapecheck - # in test/test_nn.py - for input_shape, weight, bias, kwargs in cases: # Batched yield SampleInput(make_arg(input_shape), args=( @@ -3866,33 +3862,112 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): def error_inputs_conv1d(opinfo, device, **kwargs): - input = torch.randn(size=(33, 16, 30), device=device, dtype=torch.float64) - weight = torch.randn(size=(20, 16, 5), device=device, dtype=torch.float64) - groups = 0 + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + + # error inputs for negative strides yield ErrorInput( - SampleInput(input, kwargs={"weight": weight, "groups": groups}), - error_regex="non-positive groups is not supported" - ) + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))), + error_regex="weight should have at least three dimensions") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") def error_inputs_conv2d(opinfo, device, **kwargs): - weight = torch.randint(high=10, size=(3, 2, 3, 3), device=device) - input = torch.randint(high=10, size=(2, 4, 4), device=device) - bias = torch.rand((3,), dtype=torch.float32, device=device) - yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same") - - weight = torch.rand(size=(3, 2, 3, 3), device=device, dtype=torch.float64) - input = torch.rand(size=(2, 4, 4), device=device, dtype=torch.float64) - bias = torch.rand((3,), dtype=torch.complex128, device=device) - yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same") - - input = torch.randn(size=(1, 4, 5, 5), device=device, dtype=torch.float64) - weight = torch.randn(size=(8, 4, 3, 3), device=device, dtype=torch.float64) - groups = 0 + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias yield ErrorInput( - SampleInput(input, kwargs={"weight": weight, "groups": groups}), - error_regex="non-positive groups is not supported" - ) + SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for groups the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs): @@ -3940,6 +4015,90 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample= ), kwargs=kwargs) +def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and dict of values of (stride, padding, dilation, groups) + cases: Tuple = ( + ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}), + ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}), + ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}), + ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}), + error_regex="non-positive groups is not supported") + + # error inputs for padding='same' not supported by strided convolutions + yield ErrorInput( + SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)), + make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}), + error_regex="padding='same' is not supported for strided convolutions") + + def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -13076,6 +13235,38 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), supports_expanded_weight=True, supports_out=False,), + OpInfo('nn.functional.conv3d', + aliases=('conv3d',), + aten_name='conv3d', + dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), + sample_inputs_func=sample_inputs_conv3d, + error_inputs_func=error_inputs_conv3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: Conv3D is not supported on MPS + DecorateInfo(unittest.expectedFailure, 'TestConsistency'), + # AssertionError: Tensor-likes are not close! + # break slow tests + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + ), + supports_expanded_weight=True, + supports_out=False,), OpInfo('nn.functional.group_norm', aten_name='group_norm', aliases=('group_norm',), From 56a95afb422725b01ebb7407c94f4647c48da6a4 Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 27 Nov 2023 15:15:19 +0000 Subject: [PATCH 171/221] [RelEng] Pin disabled and slow test for release (#114515) Follow up for https://github.com/pytorch/pytorch/pull/114355 Pin disabled and slow tests when applying release only changes Pull Request resolved: https://github.com/pytorch/pytorch/pull/114515 Approved by: https://github.com/DanilBaibak --- scripts/release/apply-release-changes.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/release/apply-release-changes.sh b/scripts/release/apply-release-changes.sh index 4c07057e6661..613a0e4eff6b 100755 --- a/scripts/release/apply-release-changes.sh +++ b/scripts/release/apply-release-changes.sh @@ -39,12 +39,15 @@ sed -i -e s#.*#r"${RELEASE_VERSION}"# .github/ci_commit_pins/xla.txt export RELEASE_VERSION_TAG=${RELEASE_VERSION} ./.github/regenerate.sh -# Pin Unstable and disabled jobs +# Pin Unstable and disabled jobs and tests UNSTABLE_VER=$(aws s3api list-object-versions --bucket ossci-metrics --prefix unstable-jobs.json --query 'Versions[?IsLatest].[VersionId]' --output text) DISABLED_VER=$(aws s3api list-object-versions --bucket ossci-metrics --prefix disabled-jobs.json --query 'Versions[?IsLatest].[VersionId]' --output text) -sed -i -e s#unstable-jobs.json#"unstable-jobs.json?versionid=${UNSTABLE_VER}"# .github/scripts/filter_test_configs.py -sed -i -e s#disabled-jobs.json#"disabled-jobs.json?versionid=${DISABLED_VER}"# .github/scripts/filter_test_configs.py - +SLOW_VER=$(aws s3api list-object-versions --bucket ossci-metrics --prefix slow-tests.json --query 'Versions[?IsLatest].[VersionId]' --output text) +DISABLED_TESTS_VER=$(aws s3api list-object-versions --bucket ossci-metrics --prefix disabled-tests-condensed.json --query 'Versions[?IsLatest].[VersionId]' --output text) +sed -i -e s#unstable-jobs.json#"unstable-jobs.json?versionId=${UNSTABLE_VER}"# .github/scripts/filter_test_configs.py +sed -i -e s#disabled-jobs.json#"disabled-jobs.json?versionId=${DISABLED_VER}"# .github/scripts/filter_test_configs.py +sed -i -e s#slow-tests.json#"slow-tests.json?versionId=${SLOW_VER}"# tools/stats/import_test_stats.py +sed -i -e s#disabled-tests-condensed.json#"disabled-tests-condensed.json?versionId=${DISABLED_TESTS_VER}"# tools/stats/import_test_stats.py # Optional # git commit -m "[RELEASE-ONLY CHANGES] Branch Cut for Release {RELEASE_VERSION}" # git push origin "${RELEASE_BRANCH}" From 9fd447c346e24be2acb8b8f07625e3c1b55531a2 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 27 Nov 2023 08:30:45 -0800 Subject: [PATCH 172/221] [CI] Bump up the graph break count for DALLE2_pytorch temporarily (#114598) Summary: rotary-embedding-torch's version changing from 0.3.3 to 0.3.6 caused some new graph breaks for DALLE2_pytorch. A proper fix is to pin down rotary-embedding-torch's version in torchbench, and then update our torchbench pin to pick up that change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114598 Approved by: https://github.com/seemethere, https://github.com/aakhundov --- .../ci_expected_accuracy/aot_eager_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 2 +- .../ci_expected_accuracy/dynamo_eager_torchbench_inference.csv | 2 +- .../ci_expected_accuracy/inductor_torchbench_inference.csv | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 0854562b587a..6992dfe750d2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,21 +DALLE2_pytorch,fail_to_run,31 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 3d975ef1fe68..17be27a4a9ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,21 +DALLE2_pytorch,fail_to_run,31 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 3d975ef1fe68..17be27a4a9ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,21 +DALLE2_pytorch,fail_to_run,31 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 0854562b587a..6992dfe750d2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,21 +DALLE2_pytorch,fail_to_run,31 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index ac90d0bbb8ac..30ee17649f30 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,21 +DALLE2_pytorch,fail_to_run,31 From bcfca41a2a8d80ca186e2550e4bca1a52e7873ff Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Mon, 27 Nov 2023 17:11:59 +0000 Subject: [PATCH 173/221] [Inductor] fix wrong Inductor UTs (#114504) # Motivation These UTs seem wrong. Fix them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114504 Approved by: https://github.com/aakhundov --- test/inductor/test_torchinductor.py | 39 +++++++++++++++-------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 02787e019317..6877657d9c01 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1945,29 +1945,31 @@ def fn(a, b): def test_mixed_mm(self): def fn(a, b): return torch.mm(a, b.to(a.dtype)) - self.common( - fn, - ( - torch.randn(8, 8), - torch.randint(-128, 127, (8, 8), dtype=torch.int8), - ), - check_lowp=True, - ) + + self.common( + fn, + ( + torch.randn(8, 8), + torch.randint(-128, 127, (8, 8), dtype=torch.int8), + ), + check_lowp=True, + ) @config.patch(force_mixed_mm=True) def test_mixed_mm2(self): def fn(a, b, scale, bias): return torch.mm(a, b.to(a.dtype)) * scale + bias - self.common( - fn, - ( - torch.randn(8, 8), - torch.randint(-128, 127, (8, 8), dtype=torch.int8), - torch.randn(8), - torch.randn(8), - ), - check_lowp=True, - ) + + self.common( + fn, + ( + torch.randn(8, 8), + torch.randint(-128, 127, (8, 8), dtype=torch.int8), + torch.randn(8), + torch.randn(8), + ), + check_lowp=True, + ) @config.patch(use_mixed_mm=True) def test_uint4x2_mixed_mm(self): @@ -6844,7 +6846,6 @@ def test_rsqrt_dynamic_shapes(self): def fn(a, b): r = 1 / math.sqrt(a.size(1)) return torch.bmm(a, b) / r - return (r,) self.common( fn, From 3a4dea99dff8a01f490efdb2f50d2c2a8c5f7731 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Mon, 27 Nov 2023 17:29:23 +0000 Subject: [PATCH 174/221] ROCm triton commit pin update (#114348) Small bump in rocm triton commit pin to resolve reported issue on 7900XTX > RuntimeError: Triton Error [HIP]: Code: 719, Messsage: unspecified launch failure https://github.com/ROCmSoftwarePlatform/triton/issues/396 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114348 Approved by: https://github.com/jeffdaily --- .ci/docker/ci_commit_pins/triton-rocm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt index 8b0800990ac3..4a873428eaa6 100644 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ b/.ci/docker/ci_commit_pins/triton-rocm.txt @@ -1 +1 @@ -e8a35b3968780e48df1374482d56cc6cdbb9e351 +dafe1459823b9549417ed95e9720f1b594fab329 From 4bb3a02d02e7a15ffe65a91bc3dd80774f5dc0ff Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 27 Nov 2023 17:38:08 +0000 Subject: [PATCH 175/221] [BE]: Enable Ruff + Flake8 G201,G202 logging format rule. (#114474) Standardizes logging calls to always use logging.exception instead of logging.error where appropriate and enforces it with a lint. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114474 Approved by: https://github.com/jansel, https://github.com/malfet --- .flake8 | 2 +- pyproject.toml | 2 +- torch/_dynamo/guards.py | 3 +-- torch/_dynamo/utils.py | 2 +- torch/distributed/elastic/multiprocessing/api.py | 3 +-- torch/distributed/elastic/timer/api.py | 9 ++++----- .../distributed/elastic/timer/file_based_local_timer.py | 8 ++++---- torch/distributed/elastic/timer/local_timer.py | 4 ++-- 8 files changed, 15 insertions(+), 18 deletions(-) diff --git a/.flake8 b/.flake8 index bca578ce563e..1e61b459df94 100644 --- a/.flake8 +++ b/.flake8 @@ -18,7 +18,7 @@ ignore = # these ignores are from flake8-comprehensions; please fix! C407, # these ignores are from flake8-logging-format; please fix! - G100,G101,G200,G201,G202 + G100,G101,G200 # these ignores are from flake8-simplify. please fix or ignore with commented reason SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, # flake8-simplify code styles diff --git a/pyproject.toml b/pyproject.toml index 279bd6fa058b..71157c4f3cf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ ignore = [ "F821", "F841", # these ignores are from flake8-logging-format; please fix! - "G101", "G201", "G202", + "G101", # these ignores are from RUFF perf; please fix! "PERF203", "PERF4", # these ignores are from PYI; please fix! diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1b068402019b..0ef173155e2f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1315,9 +1315,8 @@ def get_guard_fail_reason( GuardFail(reason_str or "unknown reason", orig_code_map[code]) ) except Exception as e: - log.error( + log.exception( "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", - exc_info=True, ) return reason_str diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ba876a0fbb82..47275ea04185 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -400,7 +400,7 @@ def write_record_to_file(filename, exec_record): with open(filename, "wb") as f: exec_record.dump(f) except Exception: - log.error("Unable to write execution record %s", filename, exc_info=True) + log.exception("Unable to write execution record %s", filename) def count_calls(g: fx.Graph): diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 32426be08010..c7c870bdb073 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -477,14 +477,13 @@ def _poll(self) -> Optional[RunProcsResult]: failed_proc = self._pc.processes[failed_local_rank] error_filepath = self.error_files[failed_local_rank] - log.error( + log.exception( "failed (exitcode: %s)" " local_rank: %s (pid: %s)" " of fn: %s (start_method: %s)", failed_proc.exitcode, failed_local_rank, e.pid, fn_name, self.start_method, - exc_info=True, ) self.close() diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 6dd308891988..566a3d4acbc7 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -169,11 +169,10 @@ def _reap_worker_no_throw(self, worker_id: Any) -> bool: """ try: return self._reap_worker(worker_id) - except Exception as e: - log.error( + except Exception: + log.exception( "Uncaught exception thrown from _reap_worker(), " "check that the implementation correctly catches exceptions", - exc_info=e, ) return True @@ -181,8 +180,8 @@ def _watchdog_loop(self): while not self._stop_signaled: try: self._run_watchdog() - except Exception as e: - log.error("Error running watchdog", exc_info=e) + except Exception: + log.exception("Error running watchdog") def _run_watchdog(self): batch_size = max(1, self._request_queue.size()) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 597000c6d20d..26ebce33dcb5 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -225,8 +225,8 @@ def _watchdog_loop(self) -> None: self._run_watchdog(fd) if run_once: break - except Exception as e: - log.error("Error running watchdog", exc_info=e) + except Exception: + log.exception("Error running watchdog") def _run_watchdog(self, fd: io.TextIOWrapper) -> None: timer_requests = self._get_requests(fd, self._max_interval) @@ -328,6 +328,6 @@ def _reap_worker(self, worker_pid: int, signal: int) -> bool: except ProcessLookupError: log.info("Process with pid=%s does not exist. Skipping", worker_pid) return True - except Exception as e: - log.error("Error terminating pid=%s", worker_pid, exc_info=e) + except Exception: + log.exception("Error terminating pid=%s", worker_pid) return False diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index 240163f1bf6c..05f467c807a5 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -120,6 +120,6 @@ def _reap_worker(self, worker_id: int) -> bool: except ProcessLookupError: log.info("Process with pid=%s does not exist. Skipping", worker_id) return True - except Exception as e: - log.error("Error terminating pid=%s", worker_id, exc_info=e) + except Exception: + log.exception("Error terminating pid=%s", worker_id) return False From 1793ef77c62685083410fb205993fabb1ba4a389 Mon Sep 17 00:00:00 2001 From: Khushi Agrawal Date: Mon, 27 Nov 2023 18:30:59 +0000 Subject: [PATCH 176/221] [BC-breaking] conv1d & conv3d (#114594) As discussed here: https://github.com/pytorch/pytorch/pull/113885#discussion_r1404573875 #### TODO - [x] add error inputs after #114589 is merged Pull Request resolved: https://github.com/pytorch/pytorch/pull/114594 Approved by: https://github.com/lezcano --- aten/src/ATen/native/Convolution.cpp | 16 +++++++++++++ .../_internal/common_methods_invocations.py | 24 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 43ee07b41107..13de0c41f7ba 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -917,6 +917,14 @@ at::Tensor conv1d_symint( c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; + TORCH_CHECK( + !bias.defined() || bias.dtype() == input_.dtype(), + "Input type (", + input_.dtype().name(), + ") and bias type (", + bias.dtype().name(), + ") should be the same"); + Tensor input; bool is_batched; std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d"); @@ -963,6 +971,14 @@ at::Tensor conv3d_symint( c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; + TORCH_CHECK( + !bias.defined() || bias.dtype() == input_.dtype(), + "Input type (", + input_.dtype().name(), + ") and bias type (", + bias.dtype().name(), + ") should be the same"); + Tensor input; bool is_batched; std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d"); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b1021491577f..b030f3e9f1b9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3863,6 +3863,18 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): def error_inputs_conv1d(opinfo, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))), + error_regex="should be the same") # error inputs for negative strides yield ErrorInput( @@ -4048,6 +4060,18 @@ def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): def error_inputs_conv3d(opinfo, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))), + error_regex="should be the same") # error inputs for negative strides yield ErrorInput( From 7fa12510806899067c9af0664ceca32f42e71a04 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 27 Nov 2023 18:56:10 +0000 Subject: [PATCH 177/221] [BE][Easy]: Enable NPY lint rules for ruff (#114476) Enable NPY lint rules for ruff Pull Request resolved: https://github.com/pytorch/pytorch/pull/114476 Approved by: https://github.com/justinchuby, https://github.com/malfet --- pyproject.toml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 71157c4f3cf3..d59bed1e9187 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,9 @@ ignore = [ "F841", # these ignores are from flake8-logging-format; please fix! "G101", - # these ignores are from RUFF perf; please fix! + # these ignores are from ruff NPY; please fix! + "NPY002", + # these ignores are from ruff PERF; please fix! "PERF203", "PERF4", # these ignores are from PYI; please fix! "PYI019", @@ -74,7 +76,7 @@ select = [ "SIM1", "W", # Not included in flake8 - "UP", + "NPY", "PERF", "PGH004", "PIE794", @@ -95,6 +97,7 @@ select = [ "RUF017", "TRY200", "TRY302", + "UP", ] [tool.ruff.per-file-ignores] From 69024883fbd78ce914afa3ce46c3060b4add975d Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 22 Nov 2023 17:06:38 -0800 Subject: [PATCH 178/221] Make dynamo's test_logging print helpful error (#114428) BEFORE ``` expected torch._dynamo.backends.distributed is DEBUG, got 0 ``` (0 is both unhelpful and also not numerically the right value, getEffectiveLevel() returns 20 not 0 for this particular case) AFTER ``` expected torch._dynamo.backends.distributed is DEBUG, got INFO ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114428 Approved by: https://github.com/Skylion007 --- test/dynamo/test_logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 7e233263d3e2..59a1e33df8a7 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -252,13 +252,13 @@ def test_all(self, _): self.assertEqual( logger.getEffectiveLevel(), logging.INFO, - msg=f"expected {logger_qname} is INFO, got {logger.level}", + msg=f"expected {logger_qname} is INFO, got {logging.getLevelName(logger.getEffectiveLevel())}", ) else: self.assertEqual( logger.getEffectiveLevel(), logging.DEBUG, - msg=f"expected {logger_qname} is DEBUG, got {logger.level}", + msg=f"expected {logger_qname} is DEBUG, got {logging.getLevelName(logger.getEffectiveLevel())}", ) @make_logging_test(graph_breaks=True) From 800cf5f7cbd36b59d17e986d9783e49ded0b7b48 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 27 Nov 2023 19:55:26 +0000 Subject: [PATCH 179/221] Add USE_C10D_NCCL around NCCL trace utils (#114597) Fixes #114575 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114597 Approved by: https://github.com/malfet --- torch/csrc/distributed/c10d/TraceUtils.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 61eb5c8b7819..9d72e9960b2e 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -266,6 +266,8 @@ inline std::string retrieveDesyncReport( /* Note: this is only used by PGNCCL (could be generalized in an ideal world but * wasn't done that way, so isn't expected to be fully general at the moment) */ +#ifdef USE_C10D_NCCL + DebugInfoWriter::DebugInfoWriter(int rank) { std::string fileName = getCvarString( {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); @@ -516,4 +518,5 @@ struct NCCLTraceBuffer { } }; +#endif } // namespace c10d From e0d2a24967218d7c39e24f66bb6c4836c9d1d427 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Mon, 27 Nov 2023 20:19:04 +0000 Subject: [PATCH 180/221] Reland "[export] Support user input mutation. [1/2]" (#114496) (#114596) Summary: Serialization not implemented yet. Will do in the next diff. Resolving Github issues: https://github.com/pytorch/pytorch/issues/112429 https://github.com/pytorch/pytorch/issues/114142 Test Plan: onnx doc test ``` python -m xdoctest /opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py ONNXProgram.model_signature:0 ``` Differential Revision: D51588558 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114596 Approved by: https://github.com/angelayi --- test/export/test_db.py | 14 ++- test/export/test_export.py | 55 +++++++-- test/export/test_passes.py | 42 ++----- test/export/test_serialize.py | 1 + test/export/test_unflatten.py | 4 +- torch/_export/__init__.py | 107 ++++++++++++++++-- .../db/examples/user_input_mutation.py | 4 +- torch/_export/verifier.py | 37 +++--- torch/export/exported_program.py | 33 ++++-- torch/export/graph_signature.py | 74 ++++++++---- torch/onnx/_internal/exporter.py | 4 +- 11 files changed, 260 insertions(+), 115 deletions(-) diff --git a/test/export/test_db.py b/test/export/test_db.py index d126684beb60..a2d2f36af42e 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -1,14 +1,15 @@ # Owner(s): ["module: dynamo"] +import copy import unittest import torch._dynamo as torchdynamo -from torch.export import export from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel from torch._export.db.examples import ( filter_examples_by_support_level, get_rewrite_cases, ) +from torch.export import export from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -28,18 +29,19 @@ class ExampleTests(TestCase): def test_exportdb_supported(self, name: str, case: ExportCase) -> None: model = case.model - inputs = normalize_inputs(case.example_inputs) + inputs_export = normalize_inputs(case.example_inputs) + inputs_model = copy.deepcopy(inputs_export) exported_program = export( model, - inputs.args, - inputs.kwargs, + inputs_export.args, + inputs_export.kwargs, dynamic_shapes=case.dynamic_shapes, ) exported_program.graph_module.print_readable() self.assertEqual( - exported_program(*inputs.args, **inputs.kwargs), - model(*inputs.args, **inputs.kwargs), + exported_program(*inputs_export.args, **inputs_export.kwargs), + model(*inputs_model.args, **inputs_model.kwargs), ) if case.extra_inputs is not None: diff --git a/test/export/test_export.py b/test/export/test_export.py index 221ea9ba075b..caa576bfa987 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] # flake8: noqa +import copy import dataclasses import unittest from contextlib import contextmanager @@ -1092,13 +1093,13 @@ def f(x, y): torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5)) ) with self.assertRaisesRegex( - RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time" + RuntimeError, "is specialized to be 5 at tracing time" ): _ = exported(torch.ones(8, 5), 6) exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes) with self.assertRaisesRegex( - RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time" + RuntimeError, "is specialized to be 5.0 at tracing time" ): _ = exported(torch.ones(7, 5), 6.0) @@ -1109,7 +1110,7 @@ def g(a, b, mode): inps = (torch.randn(4, 4), torch.randn(4), "trunc") exported = torch._export.export(g, inps) - with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"): + with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"): _ = exported(torch.randn(4, 4), torch.randn(4), "floor") self.assertTrue(torch.allclose(exported(*inps), g(*inps))) @@ -1190,7 +1191,7 @@ def forward(self, x): dim0_x = torch.export.Dim("dim0_x") exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}}) reexported = torch.export.export(exported, (inp,)) - with self.assertRaisesRegex(RuntimeError, "Input arg2_1\.shape\[0\] is specialized at 5"): + with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"): reexported(torch.ones(7, 5)) reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},)) @@ -1199,7 +1200,7 @@ def forward(self, x): # can't retrace with invalid inputs with respect to the original ExportedProgram dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3) exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}}) - with self.assertRaisesRegex(RuntimeError, "Input arg2_1"): + with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"): torch.export.export(exported_v2, (torch.randn(2, 2),)) def test_retrace_graph_level_meta_preservation(self): @@ -1472,8 +1473,8 @@ def f(x): ep = export(f, (torch.tensor([3]),)) self.assertExpectedInline(str(ep.graph_module.code).strip(), """\ -def forward(self, arg0_1): - _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None +def forward(self, l_x_): + _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None ge = _local_scalar_dense >= 0 scalar_tensor = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, inf].'); scalar_tensor = None @@ -1492,7 +1493,7 @@ def foo(a, b): self.assertEqual(ep(*test_inp), foo(*test_inp)) ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None)) - with self.assertRaisesRegex(RuntimeError, "Input arg1_1.shape\[0\] is specialized at 4"): + with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"): ep_v2(*test_inp) def test_non_arg_name_dynamic_shapes_api_with_kwarg(self): @@ -1540,7 +1541,7 @@ def dynamify_inp(x): test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4)) with self.assertRaisesRegex( RuntimeError, - "Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]" + "shape\[0\] is outside of specified dynamic range \[3, inf\]" ): ep(*test_inp) @@ -1721,6 +1722,42 @@ def forward(self, scores, mask): optimized_model = torch.compile(exported_model) optimized_model(tensor_cpu, mask_cpu) + def test_export_input_mutation_static_shape(self): + class MutationModel(torch.nn.Module): + def forward(self, x, y): + x.view(3, 2, -1).add_(y) + return x + inputs = (torch.randn(12), 2.0) + model = MutationModel() + ep = torch.export.export(model, inputs) + inputs_export = copy.deepcopy(inputs) + inputs_model = copy.deepcopy(inputs) + self.assertEqual(ep(*inputs_export), model(*inputs_model)) + self.assertEqual(inputs[0] + 2.0, inputs_model[0]) + self.assertEqual(inputs[0] + 2.0, inputs_export[0]) + + def test_export_input_mutation_dynamic_shape(self): + class MutationModel(torch.nn.Module): + def forward(self, x, y): + x[0].mul_(y) + return x + inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0) + model = MutationModel() + ep = torch.export.export( + model, + inputs, + dynamic_shapes={'x': ({0: torch.export.Dim("dim")}, None), "y": None} + ) + nodes = list(ep.graph.nodes) + self.assertEqual(nodes[0].op, "placeholder") + self.assertIsInstance(nodes[0].meta['val'], torch.Tensor) + self.assertIsInstance(nodes[0].meta['val'].shape[0], torch.SymInt) + + inputs_export = copy.deepcopy(inputs) + inputs_model = copy.deepcopy(inputs) + self.assertEqual(ep(*inputs_export), model(*inputs_model)) + self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0]) + self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0]) if __name__ == '__main__': run_tests() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 86b30cb05980..627f1d5ee98f 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -76,7 +76,7 @@ def forward(self, x): dim1_x = torch.export.Dim("dim1_x", min=2, max=6) ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}}) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(2, 7, 3)) self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))) @@ -99,10 +99,10 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}} ) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) - with self.assertRaisesRegex(RuntimeError, "Input arg1_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) def test_runtime_assert_some_dims_not_specified(self) -> None: @@ -123,12 +123,12 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None} ) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5" + RuntimeError, r"shape\[0\] is specialized at 5" ): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -152,12 +152,12 @@ def forward(self, x, y): dim1_y = torch.export.Dim("dim1_y", min=3, max=6) ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): + with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5" + RuntimeError, r"shape\[0\] is specialized at 5" ): ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -302,34 +302,6 @@ def false_fn(x, y): with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."): ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) - def test_runtime_assert_equality_constraint(self): - class Adder(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y - - m = Adder() - x = torch.rand(3, 4) - y = torch.rand(3, 4) - dim1 = torch.export.Dim("dim1") - exported = torch.export.export( - m, (x, y), dynamic_shapes={"x": {1: dim1}, "y": {1: dim1}} - ) - - x = torch.rand(3, 5) - y = torch.rand(3, 6) - with self.assertRaisesRegex( - RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]" - ): - exported(x, y) - - y = torch.rand(3, 5) - dynamo_result = exported(x, y) - real_result = m(x, y) - self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) - def test_functionalize_inline_contraints(self) -> None: def f(x): a = x.item() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 7eaff4f75ce7..49b4e35b6130 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -44,6 +44,7 @@ def get_filtered_export_db_tests(): "dictionary", # Graph output must be a tuple() "fn_with_kwargs", # export doesn't support kwargs yet "scalar_output", # Tracing through 'f' must produce a single graph + "user_input_mutation", # TODO(zhxchen17) Support serializing user inputs mutation. } return [ diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index c6c9e678f015..444a64c2eb56 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -277,11 +277,11 @@ def forward(self, x): return a export_module = torch.export.export(Mod(), (torch.randn((2, 3)),)) - with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"): + with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"): export_module(torch.randn(6, 6)) unflattened = export_module.module(flat=False) - with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"): + with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"): unflattened(torch.randn(6, 6)) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 438b54c2b0bd..fd7bb2cff3f3 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -30,28 +30,28 @@ from torch._dynamo.exc import UserError, UserErrorType from torch._dynamo.source import ConstantSource from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass -from torch._functorch.aot_autograd import aot_export_module +from torch._functorch.aot_autograd import aot_export_module, GraphSignature from torch._functorch.eager_transforms import functionalize from torch._guards import detect_fake_mode from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.export import _create_constraint, _Dim, Constraint +from torch.export.exported_program import ( + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) from torch.export.graph_signature import ( - ExportGraphSignature, _sig_to_specs, ArgumentSpec, ConstantArgument, + ExportGraphSignature, InputKind, + InputSpec, OutputKind, OutputSpec, SymIntArgument, TensorArgument, - InputSpec -) -from torch.export.exported_program import ( - ExportedProgram, - ModuleCallEntry, - ModuleCallSignature, ) from torch.fx import traceback as fx_traceback from torch.fx._compatibility import compatibility @@ -559,6 +559,88 @@ def export( preserve_module_call_signature=preserve_module_call_signature, ) + +def _prepare_module( + gm_torch_level: torch.fx.GraphModule, + aot_export_args +) -> List[str]: + flat_args = pytree.tree_leaves(aot_export_args) + user_input_names = [] + with gm_torch_level.graph.inserting_before(): + for i, (arg, node) in enumerate(zip(flat_args, gm_torch_level.graph.nodes)): + assert node.op == "placeholder" + user_input_names.append(node.name) + if isinstance(arg, torch.Tensor): + assert not hasattr(gm_torch_level, node.name) + gm_torch_level.register_buffer(node.name, arg) + get_attr = gm_torch_level.graph.get_attr(node.name) + node.replace_all_uses_with(get_attr) + get_attr.meta = copy.copy(node.meta) + + for node in list(gm_torch_level.graph.nodes): + if node.op == "placeholder": + assert len(node.users) == 0 + gm_torch_level.graph.erase_node(node) + gm_torch_level.recompile() + return user_input_names + + +def _unwrap_user_inputs( + gm: torch.fx.GraphModule, + graph_signature: GraphSignature, + user_input_names: List[str] +) -> Dict[str, str]: + assert len(graph_signature.user_inputs) == 0 + assert graph_signature.backward_signature is None + names = set(user_input_names) + + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + # user inputs are always added in the end + start = len(graph_signature.parameters) + end = start + len(graph_signature.buffers) + buffer_nodes = placeholders[start:end] + last_placeholder_node = placeholders[-1] if len(placeholders) > 0 else None + old_nodes: Dict[str, torch.fx.Node] = {} + for node in buffer_nodes: + buffer_name = graph_signature.inputs_to_buffers[node.name] + if buffer_name not in names: + continue + old_nodes[buffer_name] = node + replaces = {} + new_node_names: Dict[str, str] = {} + with gm.graph.inserting_after(last_placeholder_node): + for name in reversed(user_input_names): + new_node = gm.graph.placeholder(name) + new_node.target = new_node.name + new_node_names[name] = new_node.name + if name in old_nodes: + old_node = old_nodes[name] + new_node.meta = copy.copy(old_node.meta) + old_node.replace_all_uses_with(new_node) + replaces[old_node.name] = new_node.name + + for old_node in old_nodes.values(): + gm.graph.erase_node(old_node) + + gm.recompile() + + graph_signature.buffers = [b for b in graph_signature.buffers if b not in names] + graph_signature.inputs_to_buffers = { + i: b for i, b in graph_signature.inputs_to_buffers.items() if b not in names + } + user_inputs_to_mutate = { + o: b for o, b in graph_signature.buffers_to_mutate.items() if b in names + } + graph_signature.buffers_to_mutate = { + o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names + } + graph_signature.user_inputs = list(reversed(new_node_names.values())) # type: ignore[arg-type] + graph_signature.user_outputs = [ + replaces[o] if o in replaces else o for o in graph_signature.user_outputs + ] + return user_inputs_to_mutate # type: ignore[return-value] + + def _disable_prexisiting_fake_mode(fn): @functools.wraps(fn) @@ -703,6 +785,10 @@ def _export( if isinstance(f, torch.nn.Module): _normalize_nn_module_stack(gm_torch_level, type(f)) + aot_export_args = (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()) + + user_input_names = _prepare_module(gm_torch_level, aot_export_args) + # Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict # to follow the order in orig_args and correctly call gm_torch_level @@ -712,9 +798,10 @@ def _export( with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers): gm, graph_signature = aot_export_module( gm_torch_level, - (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()), + (), trace_joint=False ) + user_inputs_to_mutate = _unwrap_user_inputs(gm, graph_signature, user_input_names) def to_str_list(sig_component: List[Any]): return [str(v) for v in sig_component] @@ -771,6 +858,7 @@ def to_str_dict(sig_component: Dict[Any, Any]): is_joint = graph_signature.backward_signature is not None def make_argument_spec(node) -> ArgumentSpec: + assert "val" in node.meta, f"{node} has no 'val' metadata field" val = node.meta["val"] if isinstance(val, FakeTensor): return TensorArgument(name=node.name) @@ -784,6 +872,7 @@ def make_argument_spec(node) -> ArgumentSpec: inputs_to_buffers=graph_signature.inputs_to_buffers, # type: ignore[arg-type] user_outputs=set(graph_signature.user_outputs), # type: ignore[arg-type] buffer_mutations=graph_signature.buffers_to_mutate, # type: ignore[arg-type] + user_input_mutations=user_inputs_to_mutate, # type: ignore[arg-type] grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {}, # type: ignore[arg-type, union-attr] grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}, # type: ignore[arg-type, union-attr] loss_output=graph_signature.backward_signature.loss_output if is_joint else None, # type: ignore[arg-type, union-attr] diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py index 56af08d0c359..2bb16cd64a56 100644 --- a/torch/_export/db/examples/user_input_mutation.py +++ b/torch/_export/db/examples/user_input_mutation.py @@ -6,11 +6,11 @@ @export_case( example_inputs=(torch.ones(3, 2),), tags={"torch.mutation"}, - support_level=SupportLevel.NOT_SUPPORTED_YET, + support_level=SupportLevel.SUPPORTED, ) class UserInputMutation(torch.nn.Module): """ - Can't directly mutate user input in forward + Directly mutate user input in forward """ def forward(self, x): diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index dae103724c5e..391d7f99f69b 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -230,12 +230,6 @@ def _verify_exported_program_signature(exported_program) -> None: # Check ExportedProgram signature matches gs = exported_program.graph_signature - bs_grad_to_param = {} - bs_grad_to_user_inputs = {} - if gs.backward_signature is not None: - bs_grad_to_param = gs.backward_signature.gradients_to_parameters - bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs - # Check every node in the signature exists in the graph input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] @@ -324,19 +318,28 @@ def _verify_exported_program_signature(exported_program) -> None: f"Number of user outputs: {len(gs.user_outputs)}. \n" ) - buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)] - user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)] + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + mutate_nodes: List[str] = output_nodes[:end] + user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] - for buffer_node in buffer_mutate_nodes: - if ( - buffer_node not in gs.buffers_to_mutate or - gs.buffers_to_mutate[buffer_node] not in gs.buffers - ): + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n") + else: raise SpecViolationError( - f"Buffer output {buffer_node} is not in buffer mutation dictionary " - "or, it does not point to a buffer that exists. \n" - f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" - f"Buffer nodes available: {gs.buffers} \n" + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" ) for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index dde89b4fdd9c..015d63ff8bbe 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -277,9 +277,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) if self.call_spec.out_spec is not None: - mutation = self.graph_signature.buffers_to_mutate - num_mutated = len(mutation) - mutated_buffers = res[:num_mutated] + buffer_mutation = self.graph_signature.buffers_to_mutate + user_input_mutation = self.graph_signature.user_inputs_to_mutate + num_mutated = len(buffer_mutation) + len(user_input_mutation) + mutated_values = res[:num_mutated] # Exclude dependency token from final result. assertion_dep_token = self.graph_signature.assertion_dep_token @@ -299,10 +300,27 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: f"{received_spec}" ) finally: - ix = 0 - for buffer in self.graph_signature.buffers_to_mutate.values(): - self.state_dict[buffer] = mutated_buffers[ix] - ix += 1 + user_inputs = [ + spec + for spec in self.graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + for i, value in enumerate(mutated_values): + output_spec = self.graph_signature.output_specs[i] + if output_spec.kind == OutputKind.BUFFER_MUTATION: + assert output_spec.target is not None + self.state_dict[output_spec.target] = value + elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: + assert output_spec.target is not None + index = next( + i + for i, spec in enumerate(user_inputs) + if spec.arg.name == output_spec.target + ) + args[index].copy_(value) + else: + raise AssertionError(f"Unexpected kind: {output_spec.kind}") + return res def __str__(self) -> str: @@ -365,7 +383,6 @@ def _get_placeholders(gm): decomp_table = decomp_table or core_aten_decompositions() old_placeholders = _get_placeholders(self.graph_module) - old_outputs = list(self.graph.nodes)[-1].args[0] fake_args = [node.meta["val"] for node in old_placeholders] buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()] diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index d208868d0d53..06c7f8e53e62 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -57,6 +57,7 @@ class OutputKind(Enum): BUFFER_MUTATION = auto() GRADIENT_TO_PARAMETER = auto() GRADIENT_TO_USER_INPUT = auto() + USER_INPUT_MUTATION = auto() @dataclasses.dataclass @@ -76,6 +77,7 @@ def _sig_to_specs( inputs_to_buffers: Mapping[str, str], user_outputs: Set[str], buffer_mutations: Mapping[str, str], + user_input_mutations: Mapping[str, str], grad_params: Mapping[str, str], grad_user_inputs: Mapping[str, str], loss_output: Optional[str], @@ -101,37 +103,49 @@ def to_input_spec(i: ArgumentSpec) -> InputSpec: else: raise AssertionError(f"Unknown tensor input kind: {name}") - def to_output_spec(o: ArgumentSpec) -> OutputSpec: + def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: if not isinstance(o, TensorArgument): return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) name = o.name - if name in user_outputs: - return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) - elif name in buffer_mutations: - return OutputSpec( - kind=OutputKind.BUFFER_MUTATION, - arg=o, - target=buffer_mutations[name], - ) - elif name in grad_params: - return OutputSpec( - kind=OutputKind.GRADIENT_TO_PARAMETER, - arg=o, - target=grad_params[name], - ) - elif name in grad_user_inputs: - return OutputSpec( - kind=OutputKind.GRADIENT_TO_USER_INPUT, - arg=o, - target=grad_user_inputs[name], - ) - elif name == loss_output: - return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + if idx < len(buffer_mutations) + len(user_input_mutations): + if name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], + ) + elif name in user_input_mutations: + return OutputSpec( + kind=OutputKind.USER_INPUT_MUTATION, + arg=o, + target=user_input_mutations[name], + ) + else: + raise AssertionError(f"Unknown tensor mutation kind: {name}") else: - raise AssertionError(f"Unknown tensor output kind: {name}") + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + + else: + raise AssertionError(f"Unknown tensor output kind: {name}") input_specs = [to_input_spec(i) for i in inputs] - output_specs = [to_output_spec(o) for o in outputs] + output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] return input_specs, output_specs @@ -304,6 +318,16 @@ def buffers_to_mutate(self) -> Mapping[str, str]: and isinstance(s.target, str) } + @property + def user_inputs_to_mutate(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.output_specs + if s.kind == OutputKind.USER_INPUT_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + # A dictionary mapping graph input node names to lifted tensor constants. @property def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index 807ef52a0483..f04f8510d177 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -776,8 +776,8 @@ def model_signature(self) -> Optional[torch.export.ExportGraphSignature]: InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target='fc2.weight'), InputSpec(kind=, arg=TensorArgument(name='arg4_1'), target='my_buffer2'), InputSpec(kind=, arg=TensorArgument(name='arg5_1'), target='my_buffer1'), - InputSpec(kind=, arg=TensorArgument(name='arg6_1'), target=None), - InputSpec(kind=, arg=TensorArgument(name='arg7_1'), target=None) + InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None), + InputSpec(kind=, arg=TensorArgument(name='arg1'), target=None) ], output_specs=[ OutputSpec(kind=, arg=TensorArgument(name='add'), target='my_buffer2'), From 4c794f2ef183c7bf28d6973cd0fd946dfe73a97c Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 27 Nov 2023 21:32:42 +0000 Subject: [PATCH 181/221] Reinplace foreach when safe and allow aliasing during lowering (#112440) This reduces compile time of Adam on 1k parameters from 180s to 140s (28%), the main reason being that thousands of buffers no longer get sent to the scheduler. The idea behind this is that if a destination buffer (from a copy_) has no users, it shouldn't matter if dst aliases src. This is implemented by reinplacing copy_ nodes when safe. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112440 Approved by: https://github.com/jansel --- test/inductor/test_foreach.py | 65 +++++++++++++++++++ torch/_inductor/fx_passes/post_grad.py | 86 +++++++++++++++++--------- torch/_inductor/ir.py | 26 ++++---- torch/_inductor/lowering.py | 67 +++++++++++++++++--- torch/_inductor/scheduler.py | 16 +++-- 5 files changed, 203 insertions(+), 57 deletions(-) diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index f9a18f48f6f7..bd74cc297dd0 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -29,6 +29,12 @@ sys.exit(0) raise +inplace_bin_ops_under_test = [ + torch._foreach_add_, + torch._foreach_mul_, + torch._foreach_sub_, + torch._foreach_div_, +] bin_ops_under_test = [ torch._foreach_add, @@ -54,6 +60,9 @@ "op", bin_ops_under_test + un_ops_under_test, name_fn=lambda f: f.__name__ ) bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__) +inplace_bin_ops = parametrize( + "op", inplace_bin_ops_under_test, name_fn=lambda f: f.__name__ +) scalar_bin_ops = parametrize("op", bin_ops_under_test[:4], name_fn=lambda f: f.__name__) scalar_tensor_bin_ops = parametrize( "op", bin_ops_under_test[:2], name_fn=lambda f: f.__name__ @@ -614,6 +623,62 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + @requires_cuda() + @inplace_bin_ops + def test_reinplacing(self, op): + def fn(a0, a1, b0, b1): + op([a0, a1], [b0, b1]) + return [a0, a1] + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs, check_lowp=False) + + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_cuda() + @inplace_bin_ops + def test_reinplacing_mut_before(self, op): + def fn(a0, a1, b0, b1): + a0.add_(torch.ones(10, 10, device="cuda:0")) + op([a0, a1], [b0, b1]) + return [a0, a1] + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs, check_lowp=False) + + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_cuda() + @inplace_bin_ops + def test_reinplacing_mut_after(self, op): + def fn(a0, a1, b0, b1): + op([a0, a1], [b0, b1]) + a0.add_(torch.ones(10, 10, device="cuda:0")) + return [a0, a1] + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs, check_lowp=False) + + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 86f84a63b6c8..03eae27afe26 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -20,7 +20,11 @@ from .. import config, inductor_prims, ir, pattern_matcher from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage -from ..lowering import lowerings as L + +from ..lowering import ( + inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, + lowerings as L, +) from ..pattern_matcher import ( _return_true, Arg, @@ -631,6 +635,34 @@ def remove_noop_ops(graph: torch.fx.Graph): InplaceableOp = namedtuple("InplaceableOp", ["inplace_op", "mutated_arg"]) +inplaceable_ops = { + aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), +} + +try: + c10d_functional = torch.ops._c10d_functional + inplaceable_collective_ops = { + c10d_functional.all_reduce.default: InplaceableOp( + c10d_functional.all_reduce_.default, 0 + ), + c10d_functional.all_reduce_coalesced.default: InplaceableOp( + c10d_functional.all_reduce_coalesced_.default, 0 + ), + } + inplaceable_ops.update(inplaceable_collective_ops) +except AttributeError: + # _c10d_functional ops are only available when torch + # is built with USE_DISTRIBUTED=1. + pass + +inplaceable_foreach_ops = {} +for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items(): + inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0) + + +inplaceable_triton_ops = {triton_kernel_wrapper_functional} + def reinplace_inplaceable_ops(graph): """ @@ -648,6 +680,7 @@ def reinplace_inplaceable_ops(graph): """ copy_args_to_copy_nodes = {} + foreach_node_to_copy_nodes = defaultdict(list) mutated_inputs = set() storage_to_nodes = defaultdict(list) node_order: Dict[Any, int] = {} @@ -659,13 +692,17 @@ def reinplace_inplaceable_ops(graph): src = node.args[1] # If the target is a getitem and it indexes a possible clone, # then skip over it - if ( - src.target == operator.getitem - and src.args[0].target == triton_kernel_wrapper_functional - and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + if src.target == operator.getitem and ( + ( + src.args[0].target == triton_kernel_wrapper_functional + and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + ) + or (src.args[0].target in inplaceable_foreach_ops) ): src = src.args[0] + copy_args_to_copy_nodes[(dst, src)] = node + assert node.args[0].op == "placeholder" mutated_inputs.add(node.args[0]) @@ -711,31 +748,6 @@ def can_inplace(node, mutated_arg): node, shared_view_nodes, copy_node=None ) - inplaceable_ops = { - aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), - aten._unsafe_index_put.default: InplaceableOp( - inductor_prims._unsafe_index_put_, 0 - ), - } - - try: - c10d_functional = torch.ops._c10d_functional - inplaceable_collective_ops = { - c10d_functional.all_reduce.default: InplaceableOp( - c10d_functional.all_reduce_.default, 0 - ), - c10d_functional.all_reduce_coalesced.default: InplaceableOp( - c10d_functional.all_reduce_coalesced_.default, 0 - ), - } - inplaceable_ops.update(inplaceable_collective_ops) - except AttributeError: - # _c10d_functional ops are only available when torch - # is built with USE_DISTRIBUTED=1. - pass - - inplaceable_triton_ops = {triton_kernel_wrapper_functional} - for node in graph.nodes: if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: mutated_arg = node.args[inplaceable_op.mutated_arg] @@ -765,6 +777,20 @@ def can_inplace(node, mutated_arg): kwargs = dict(node.kwargs) kwargs["tensors_to_clone"] = tensors_to_clone node.kwargs = immutable_dict(kwargs) + elif ( + inplaceable_op := inplaceable_foreach_ops.get(node.target, None) + ) is not None: + mutated_args = node.args[inplaceable_op.mutated_arg] + + if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args): + continue + + if can_inplace(node, mutated_args): + for arg in mutated_args: + copy_node = copy_args_to_copy_nodes[(arg, node)] + graph.erase_node(copy_node) + + node.target = inplaceable_op.inplace_op @register_lowering_pattern( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0c5846153892..bf9c9fff123e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2521,7 +2521,7 @@ def real_layout(self): return self.get_buffer().layout @classmethod - def realize_into(cls, src, dst): + def realize_into(cls, src, dst, unsafe_alias=False): dst.realize() # NOTE: We must realize users of `dst` before we realize `src`, since # realization order determines scheduling order. Otherwise, src's @@ -2535,20 +2535,22 @@ def realize_into(cls, src, dst): # be fused into a single kernel by the scheduler. # NOTE: We cannot change src's layout to mutate dst directly as this # would alias src to dst, which is not correct as further mutations to - # dst would effect users of src. + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. src.realize_hint() - src = Pointwise.create( - device=src.get_device(), - dtype=src.get_dtype(), - inner_fn=src.make_loader(), - ranges=[ - V.graph.sizevars.guard_equals(a, b) - for a, b in zip(src.get_size(), dst.get_size()) - ], - ).data - src.realize() + if not unsafe_alias: + src = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + src.realize() assert isinstance(src.data.layout, FlexibleLayout) src.data.layout = MutationLayout(dst) return src.data diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index cf30ef7a7aac..1ef717a301cf 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -66,6 +66,8 @@ prims = torch.ops.prims needs_realized_inputs = set() foreach_ops = set() +inplace_foreach_ops = set() +inplaceable_foreach_ops = dict() def assert_nyi(cond, msg): @@ -451,10 +453,13 @@ def group_args(arg_pairs): out[(device, use_foreach)].append((i, args)) return out - realize_outputs = False + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + ) for node in V.graph.current_node.users: for user in node.users: - if not (user.op == "call_function" and user.target in foreach_ops): + if not (user.op == "call_function" and (user.target in foreach_ops)): realize_outputs = True a_list_input = None @@ -4724,7 +4729,7 @@ def fn(idx): return pow_native(a, b) -def mutate_to(changed, val): +def mutate_to(changed, val, unsafe_alias=False): if isinstance(changed, TensorBox): changed_data = changed.data else: @@ -4750,7 +4755,7 @@ def mutate_to(changed, val): changed_data.data = val.data return changed - ir.MutationLayout.realize_into(val, changed_data) + ir.MutationLayout.realize_into(val, changed_data, unsafe_alias=unsafe_alias) return changed @@ -5022,19 +5027,23 @@ def register_pointwise_numeric_ldf64(op): register_pointwise_numeric(aten.log10) register_pointwise_numeric(aten.nextafter) -register_foreach_pointwise(aten._foreach_add.List, add, allow_alpha=True) -register_foreach_pointwise(aten._foreach_add.Scalar, add, allow_alpha=True) +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) -register_foreach_pointwise(aten._foreach_mul.List, mul) -register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) register_foreach_pointwise(aten._foreach_sub.List, sub) register_foreach_pointwise(aten._foreach_sub.Scalar, sub) register_foreach_pointwise(aten._foreach_neg.default, neg) register_foreach_pointwise(aten._foreach_abs.default, abs) register_foreach_pointwise(aten._foreach_pow.Scalar, pow) register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) -register_foreach_pointwise(aten._foreach_div.List, div) -register_foreach_pointwise(aten._foreach_div.Scalar, div) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) +foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) register_foreach_pointwise(aten._foreach_sqrt, sqrt) register_foreach_pointwise(aten._foreach_maximum.List, maximum) register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) @@ -5049,6 +5058,44 @@ def register_pointwise_numeric_ldf64(op): register_foreach_pointwise(aten._foreach_copy, copy) +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. +def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + def register_inplace(aten_op, outplace_op): @register_lowering(aten_op, type_promotion_kind=None) def fn(*args, **kwargs): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 4d1a48f73ffc..ef20893168c5 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -212,11 +212,11 @@ def should_prune(dep): def prune_redundant_deps(self, name_to_fused_node): """ - Prunes stardeps intended for mutation ordering + Prunes weakdeps intended for mutation ordering on an upstream fused node if after fusion there is another dependency - on the fused upstream node, making the stardep redundant + on the fused upstream node, making the weakdep redundant - In essence this enforces an ordering on fusions. As fusions occur, prunable stardeps will + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will be incrementally removed, enabling other fusions, ensuring they are fused in order. """ name_to_dep_count: Counter[str] = collections.Counter() @@ -239,8 +239,10 @@ def should_prune(dep): return False deps_to_prune = {dep for dep in self.unmet_dependencies if should_prune(dep)} - self.unmet_dependencies = self.unmet_dependencies - deps_to_prune - self.set_read_writes(self.read_writes.remove_reads(deps_to_prune)) + + if deps_to_prune: + self.unmet_dependencies = self.unmet_dependencies - deps_to_prune + self.set_read_writes(self.read_writes.remove_reads(deps_to_prune)) def get_name(self) -> str: return self.node.get_name() @@ -1082,6 +1084,10 @@ def get_nodes(self): def get_first_name(self): return self.snodes[0].get_first_name() + def prune_redundant_deps(self, name_to_fused_node): + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + def pick_loop_order(stride_lengths, sizes, priority_idx=()): """ From 2ac0b61e6062fd16e48faaabba56a1274fad9928 Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Mon, 27 Nov 2023 09:43:17 -0800 Subject: [PATCH 182/221] [HigherOrderOp] dedup repeated get_attr placeholders in branches of cond (#112874) We further de-duplicate the dupliacted get_attrs nodes. For code below: ```python def test_cond_free_variable_in_both_branches(self): backend = EagerAndRecordGraphs() cnt = CompileCounterWithBackend(backend) z = torch.ones(4, 4) class Foo(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("buffer", torch.ones(6, 4)) def forward(self, x, y): def true_fn(x): return x.sum() + self.buffer.sum() + z.sum() def false_fn(x): return x.sum() - z.sum() - self.buffer.sum() return control_flow.cond(y, true_fn, false_fn, [x]) mod_for_compile = torch.compile( Foo(), backend=cnt, dynamic=True, fullgraph=True ) ``` Before de-duplication, we have the following graph module: ```python class GraphModule(torch.nn.Module): def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor, s0 : torch.SymInt, L_z_ : torch.Tensor): l_y_ = L_y_ l_x_ = L_x_ l_z_ = L_z_ # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1243, code: return x.sum() + self.buffer.sum() + z.sum() l__self___buffer = self.L__self___buffer # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1246, code: return x.sum() - z.sum() - self.buffer.sum() l__self___buffer_1 = self.L__self___buffer # File: /home/yidi/local/pytorch/torch/_higher_order_ops/cond.py:118, code: return cond_op(pred, true_fn, false_fn, operands) cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_y_, cond_true_0, cond_false_0, [l_x_, l_z_, l__self___buffer, l__self___buffer_1]); l_y_ = cond_true_0 = cond_false_0 = l_x_ = l_z_ = l__self___buffer = l__self___buffer_1 = None return (cond,) class GraphModule(torch.nn.Module): def forward(self, l_x_, l_z_, l__self___buffer_true_branch, l__self___buffer_1_false_branch): l_x__1 = l_x_ l_z__1 = l_z_ # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1243, code: return x.sum() + self.buffer.sum() + z.sum() sum_1 = l_x__1.sum(); l_x__1 = None sum_2 = l__self___buffer_true_branch.sum(); l__self___buffer_true_branch = None add = sum_1 + sum_2; sum_1 = sum_2 = None sum_3 = l_z__1.sum(); l_z__1 = None add_1 = add + sum_3; add = sum_3 = None return add_1 class GraphModule(torch.nn.Module): def forward(self, l_x_, l_z_, l__self___buffer_true_branch, l__self___buffer_1_false_branch): l_x__1 = l_x_ l_z__1 = l_z_ # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1246, code: return x.sum() - z.sum() - self.buffer.sum() sum_1 = l_x__1.sum(); l_x__1 = None sum_2 = l_z__1.sum(); l_z__1 = None sub = sum_1 - sum_2; sum_1 = sum_2 = None sum_3 = l__self___buffer_1_false_branch.sum(); l__self___buffer_1_false_branch = None sub_1 = sub - sum_3; sub = sum_3 = None return sub_1 ``` After de-duplication, we have the following graph module: ```python class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, s0 : torch.SymInt, L_z_ : torch.Tensor): l_x_ = L_x_ l_y_ = L_y_ l_z_ = L_z_ # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1232, code: return x.sum() + self.buffer.sum() + z.sum() l__self___buffer = self.L__self___buffer # File: /home/yidi/local/pytorch/torch/_higher_order_ops/cond.py:118, code: return cond_op(pred, true_fn, false_fn, operands) cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_y_, cond_true_0, cond_false_0, [l__self___buffer, l_x_, l_z_]); l_y_ = cond_true_0 = cond_false_0 = l__self___buffer = l_x_ = l_z_ = None return (cond,) class GraphModule(torch.nn.Module): def forward(self, l__self___buffer, l_x_, l_z_): l__self___buffer_1 = l__self___buffer l_x__1 = l_x_ l_z__1 = l_z_ # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1232, code: return x.sum() + self.buffer.sum() + z.sum() sum_1 = l_x__1.sum(); l_x__1 = None sum_2 = l__self___buffer_1.sum(); l__self___buffer_1 = None add = sum_1 + sum_2; sum_1 = sum_2 = None sum_3 = l_z__1.sum(); l_z__1 = None add_1 = add + sum_3; add = sum_3 = None return add_1 class GraphModule(torch.nn.Module): def forward(self, l__self___buffer_1, l_x_, l_z_): l__self___buffer_2 = l__self___buffer_1 l_x__1 = l_x_ l_z__1 = l_z_ # File: /home/yidi/local/pytorch/test/dynamo/test_higher_order_ops.py:1235, code: return x.sum() - z.sum() - self.buffer.sum() sum_1 = l_x__1.sum(); l_x__1 = None sum_2 = l_z__1.sum(); l_z__1 = None sub = sum_1 - sum_2; sum_1 = sum_2 = None sum_3 = l__self___buffer_2.sum(); l__self___buffer_2 = None sub_1 = sub - sum_3; sub = sum_3 = None return sub_1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/112874 Approved by: https://github.com/zou3519 --- test/dynamo/test_export.py | 8 +-- test/dynamo/test_higher_order_ops.py | 8 +-- torch/_dynamo/variables/higher_order_ops.py | 64 +++++++++++++++++---- 3 files changed, 61 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index ca49583b7825..306d871dc0f6 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -4049,8 +4049,8 @@ def forward(self, x): self.assertExpectedInline( gm.true_graph_0.code.strip(), """\ -def forward(self, arg0_1, arg1_1, arg2_1): - out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg0_1, arg2_1); arg0_1 = arg2_1 = None +def forward(self, arg0_1, arg1_1): + out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None return (sum_1,)""", ) @@ -4058,8 +4058,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): self.assertExpectedInline( gm.false_graph_0.code.strip(), """\ -def forward(self, arg0_1, arg1_1, arg2_1): - out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg0_1, arg2_1); arg0_1 = arg2_1 = None +def forward(self, arg0_1, arg1_1): + out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None return (sum_1,)""", ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index ba5427d5271c..51b9786faa74 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1263,10 +1263,8 @@ def false_fn(x): and node.target == torch.ops.higher_order.cond ): _, _, _, operands = node.args - # Each branch takes 4 inputs (x, z, buffer_true_branch, buffer_false_branch) - # TODO: we should be able to de-duplicate the buffer accessed from two branches so that - # operands become (x, z, buffer) - self.assertEqual(len(operands), 4) + # Each branch takes 3 inputs (buffer, x, z) + self.assertEqual(len(operands), 3) if node.op == "get_attr": if str(node.target) in ("cond_true_0, cond_false_0"): num_placeholders = len( @@ -1278,7 +1276,7 @@ def false_fn(x): if node.op == "placeholder" ] ) - self.assertEqual(num_placeholders, 4) + self.assertEqual(num_placeholders, 3) def _check_cond_graph_and_extract(self, fn, args): backend = EagerAndRecordGraphs() diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index bd06d487652d..c9e4cca9ed48 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -543,22 +543,64 @@ def diff_meta(tensor_vars1, tensor_vars2): ) def dedup_and_sort_lifted_freevars(true_lifted_freevars, false_lifted_freevars): - shared_freevars = true_lifted_freevars.keys() & false_lifted_freevars.keys() - unique_true_freevars = true_lifted_freevars.keys() - shared_freevars - unique_false_freevars = false_lifted_freevars.keys() - shared_freevars + # The nn module attributes are guaranteed to be registered into the top-level graph module during + # higher order op speculation. Therefore, get_attr nodes in two branches with the same + # target refer to the same attribute and we can safely deduplicate them with their target. + # + # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But + # true_branch and false_branch belong to two separate tracing contexts, they may register the same + # attribute to top level seperately. This creates two get_attr proxies for the same attribute + # that have different meta data such as stack_trace (one stack trace for the true_branch, + # and the other for false_branch). It seems better to discard the proxy explicitly in cond + # than make dynamo create a single proxy for the same get_attr target. + def shared_getattrs(true_lifted_proxies, false_lifted_proxies): + true_targets = { + proxy.node.target: proxy + for proxy in true_lifted_proxies + if proxy.node.op == "get_attr" + } + true_fn_shared_getattrs = {} + false_fn_shared_getattrs = {} + + for false_proxy in false_lifted_proxies: + if ( + false_proxy.node.op == "get_attr" + and false_proxy.node.target in true_targets + ): + true_proxy = true_targets[false_proxy.node.target] + true_fn_shared_getattrs[true_proxy] = true_proxy + false_fn_shared_getattrs[false_proxy] = true_proxy + return true_fn_shared_getattrs, false_fn_shared_getattrs + + true_fn_shared_getattrs, false_fn_shared_getattrs = shared_getattrs( + true_lifted_freevars.keys(), false_lifted_freevars.keys() + ) + + true_shared_freevars = ( + true_lifted_freevars.keys() & false_lifted_freevars.keys() + ).union(true_fn_shared_getattrs.keys()) + false_shared_freevars = ( + true_lifted_freevars.keys() & false_lifted_freevars.keys() + ).union(false_fn_shared_getattrs.keys()) + unique_true_freevars = true_lifted_freevars.keys() - true_shared_freevars + unique_false_freevars = false_lifted_freevars.keys() - false_shared_freevars def _sort_by_name(vars): return sorted(vars, key=lambda var: var.node.name) return ( - list(_sort_by_name(list(shared_freevars))), + list(_sort_by_name(list(true_shared_freevars))), + list(_sort_by_name(list(false_shared_freevars))), list(_sort_by_name(list(unique_true_freevars))), list(_sort_by_name(list(unique_false_freevars))), ) - shared, unique_true, unique_false = dedup_and_sort_lifted_freevars( - true_lifted_freevars, false_lifted_freevars - ) + ( + true_shared, + false_shared, + unique_true, + unique_false, + ) = dedup_and_sort_lifted_freevars(true_lifted_freevars, false_lifted_freevars) # Let's say we capture cond(pred, true_fn, false_fn, (x,)) # With mannually_set_graph_input set to False, @@ -579,6 +621,7 @@ def fixup_branch_inps( def _insert_or_replace_phs(new_args, name_suffix): for arg in new_args: new_ph = graph.placeholder(arg.node.name + name_suffix) + # Override with new_ph if there exists a old placeholder. if arg in lifted_freevars: old_ph = lifted_freevars[arg].node old_ph.replace_all_uses_with(new_ph) @@ -595,10 +638,10 @@ def _insert_or_replace_phs(new_args, name_suffix): _insert_or_replace_phs(unique_false, "_false_branch") fixup_branch_inps( - true_graph, true_lifted_freevars, shared, unique_true, unique_false + true_graph, true_lifted_freevars, true_shared, unique_true, unique_false ) fixup_branch_inps( - false_graph, false_lifted_freevars, shared, unique_true, unique_false + false_graph, false_lifted_freevars, false_shared, unique_true, unique_false ) true_name = add_subgraph( @@ -621,7 +664,8 @@ def _insert_or_replace_phs(new_args, name_suffix): args[0].as_proxy(), true_node, false_node, - shared + unique_true + unique_false, + # We pick true_shared but it shouldn't matter + true_shared + unique_true + unique_false, ) flat_example_value = pytree.tree_map_only( torch.fx.Proxy, From 2ea2421b44e6e064d5c8407daba4152a6d706623 Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Mon, 27 Nov 2023 22:25:31 +0000 Subject: [PATCH 183/221] Skip unit tests that fail on MI210 runners (#114613) Taken from https://github.com/pytorch/pytorch/pull/105980 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114613 Approved by: https://github.com/malfet --- test/nn/test_convolution.py | 1 + test/run_test.py | 1 + test/test_matmul_cuda.py | 2 ++ torch/testing/_internal/common_methods_invocations.py | 2 ++ 4 files changed, 6 insertions(+) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index abfd240d4231..893292e7c487 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -1927,6 +1927,7 @@ def test_conv_noncontig_weights_and_bias(self, device): @onlyCUDA @largeTensorTest('12GB') + @skipIfRocmVersionLessThan((6, 0)) def test_conv_transposed_large(self, device): dtype = torch.half if self.device_type == 'cuda' else torch.float conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) diff --git a/test/run_test.py b/test/run_test.py index 69f001ddb879..c88f5bc19c4f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -298,6 +298,7 @@ def skip_test_p(name: str) -> bool: "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", + "dynamo/test_activation_checkpointing", ] # The tests inside these files should never be run in parallel with each other diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 6ee19a105699..140903b5e2b7 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -30,6 +30,7 @@ run_tests, skipIfRocmVersionLessThan, TEST_WITH_ROCM, + skipIfRocm, TestCase, ) @@ -160,6 +161,7 @@ def test_cublas_addmm_alignment(self, dtype): (1, 10000, 10000, 10000)], name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}", ) + @skipIfRocm def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype): cpu_dtype = dtype if dtype == torch.float16 or dtype == torch.bfloat16: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b030f3e9f1b9..d704534b00bb 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -13186,6 +13186,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # RuntimeError: UNSUPPORTED DTYPE: complex DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=[torch.complex32], active_if=TEST_WITH_ROCM), ), supports_out=False,), OpInfo('nn.functional.conv1d', From e4b1378a926b9727ec9ca41bfc0a428fa4a166ad Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 27 Nov 2023 08:44:37 -0800 Subject: [PATCH 184/221] Fix dynamo test_logging handling of partial qnames (#114429) if logger_qname is a.b.c and dynamo_qnames contains a.b, it still matches dynamo's INFO setting concretely, torch._dynamo.backends.distributed is implicitly part of the dynamo namespace since it is covered by `torch._dynamo` which is one of dynamo_qnames. However, it is not an exact match for any of dynamo_qnames, which made this test fail when adding a specific qname for backends.distributed in the subsequent PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114429 Approved by: https://github.com/Skylion007 ghstack dependencies: #114428 --- test/dynamo/test_logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 59a1e33df8a7..15d3c946806f 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -248,7 +248,8 @@ def test_all(self, _): for logger_qname in torch._logging._internal.log_registry.get_log_qnames(): logger = logging.getLogger(logger_qname) - if logger_qname in dynamo_qnames: + # if logger_qname is a.b.c and dynamo_qnames contains a.b, it still matches dynamo's INFO setting + if any(logger_qname.find(d) == 0 for d in dynamo_qnames): self.assertEqual( logger.getEffectiveLevel(), logging.INFO, From 7c98bac4a01128a5756c22f7aee6fc219d354d10 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 27 Nov 2023 23:32:59 +0000 Subject: [PATCH 185/221] [BE] Speedup register schema compilation (#114438) For some reason, inlining initializer list into a std::vector takes a lot of time using clang-15. But considering that there are only dozen or so distrinct tags, creating them once and pass as def argument should not affect runtime speed at all, but this significantly improves compilation time. On Mac M1 it reduces time needed to compiler RegisterSchema.cpp from 50 to 3 seconds. Special case empty tags, to keep torch_gen tests happy Before ``` % /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++ -ftime-report -DAT_PER_OPERATOR_HEADERS -DCAFFE2_BUILD_MAIN_LIB -DCPUINFO_SUPPORTED_PLATFORM=1 -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 -DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DUSE_EXTERNAL_MZCRC -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/Users/nshulga/git/pytorch/pytorch/build/aten/src -I/Users/nshulga/git/pytorch/pytorch/aten/src -I/Users/nshulga/git/pytorch/pytorch/build -I/Users/nshulga/git/pytorch/pytorch -I/Users/nshulga/git/pytorch/pytorch/cmake/../third_party/benchmark/include -I/Users/nshulga/git/pytorch/pytorch/third_party/onnx -I/Users/nshulga/git/pytorch/pytorch/build/third_party/onnx -I/Users/nshulga/git/pytorch/pytorch/third_party/foxi -I/Users/nshulga/git/pytorch/pytorch/build/third_party/foxi -I/Users/nshulga/git/pytorch/pytorch/torch/csrc/api -I/Users/nshulga/git/pytorch/pytorch/torch/csrc/api/include -I/Users/nshulga/git/pytorch/pytorch/caffe2/aten/src/TH -I/Users/nshulga/git/pytorch/pytorch/build/caffe2/aten/src/TH -I/Users/nshulga/git/pytorch/pytorch/build/caffe2/aten/src -I/Users/nshulga/git/pytorch/pytorch/build/caffe2/../aten/src -I/Users/nshulga/git/pytorch/pytorch/torch/csrc -I/Users/nshulga/git/pytorch/pytorch/third_party/miniz-2.1.0 -I/Users/nshulga/git/pytorch/pytorch/third_party/kineto/libkineto/include -I/Users/nshulga/git/pytorch/pytorch/third_party/kineto/libkineto/src -I/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/.. -I/Users/nshulga/git/pytorch/pytorch/third_party/FXdiv/include -I/Users/nshulga/git/pytorch/pytorch/c10/.. -I/Users/nshulga/git/pytorch/pytorch/third_party/pthreadpool/include -I/Users/nshulga/git/pytorch/pytorch/third_party/cpuinfo/include -I/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/Users/nshulga/git/pytorch/pytorch/third_party/cpuinfo/deps/clog/include -I/Users/nshulga/git/pytorch/pytorch/third_party/NNPACK/include -I/Users/nshulga/git/pytorch/pytorch/third_party/FP16/include -I/Users/nshulga/git/pytorch/pytorch/third_party/fmt/include -I/Users/nshulga/git/pytorch/pytorch/third_party/flatbuffers/include -isystem /Users/nshulga/git/pytorch/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /Users/nshulga/git/pytorch/pytorch/cmake/../third_party/googletest/googletest/include -isystem /Users/nshulga/git/pytorch/pytorch/third_party/protobuf/src -isystem /Users/nshulga/git/pytorch/pytorch/third_party/XNNPACK/include -isystem /Users/nshulga/git/pytorch/pytorch/cmake/../third_party/eigen -isystem /Users/nshulga/git/pytorch/pytorch/build/include -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=braced-scalar-init -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wvla-extension -Wsuggest-override -Wnewline-eof -Winconsistent-missing-override -Winconsistent-missing-destructor-override -Wno-pass-failed -Wno-error=pedantic -Wno-error=old-style-cast -Wno-error=inconsistent-missing-override -Wno-error=inconsistent-missing-destructor-override -Wconstant-conversion -Wno-invalid-partial-specialization -Wno-missing-braces -Qunused-arguments -fcolor-diagnostics -faligned-new -Werror -Wno-unused-but-set-variable -fno-math-errno -fno-trapping-math -Werror=format -DUSE_MPS -Wno-unused-private-field -Wno-missing-braces -O3 -DNDEBUG -DNDEBUG -arch arm64 -isysroot /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk -fPIC -D__NEON__ -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-unused-function -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-type-limits -Wno-array-bounds -Wno-strict-overflow -Wno-strict-aliasing -fvisibility=hidden -O2 -Wmissing-prototypes -Werror=missing-prototypes -Xpreprocessor -fopenmp -I/Users/nshulga/miniforge3/include -std=gnu++17 -Wno-missing-prototypes -Wno-error=missing-prototypes -o caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/RegisterSchema.cpp.o -c /Users/nshulga/git/pytorch/pytorch/build/aten/src/ATen/RegisterSchema.cpp ===-------------------------------------------------------------------------=== ... Pass execution timing report ... ===-------------------------------------------------------------------------=== Total Execution Time: 131.8054 seconds (132.5540 wall clock) ---User Time--- --System Time-- --User+System-- ---Wall Time--- ---Instr--- --- Name --- 43.6364 ( 33.2%) 0.0919 ( 30.1%) 43.7282 ( 33.2%) 43.9658 ( 33.2%) 536345245380 ModuleInlinerWrapperPass 43.6291 ( 33.2%) 0.0891 ( 29.2%) 43.7182 ( 33.2%) 43.9549 ( 33.2%) 536264096394 DevirtSCCRepeatedPass 42.3766 ( 32.2%) 0.0185 ( 6.1%) 42.3951 ( 32.2%) 42.6198 ( 32.2%) 523040901767 GVNPass 0.4085 ( 0.3%) 0.0040 ( 1.3%) 0.4125 ( 0.3%) 0.4195 ( 0.3%) 4106085945 SimplifyCFGPass 0.3611 ( 0.3%) 0.0115 ( 3.8%) 0.3726 ( 0.3%) 0.3779 ( 0.3%) 4864696407 InstCombinePass 0.1607 ( 0.1%) 0.0088 ( 2.9%) 0.1695 ( 0.1%) 0.1720 ( 0.1%) 1780986175 InlinerPass 0.0865 ( 0.1%) 0.0024 ( 0.8%) 0.0889 ( 0.1%) 0.0914 ( 0.1%) 1489982961 SROAPass 0.0750 ( 0.1%) 0.0013 ( 0.4%) 0.0763 ( 0.1%) 0.0764 ( 0.1%) 620016338 SCCPPass 0.0661 ( 0.1%) 0.0040 ( 1.3%) 0.0701 ( 0.1%) 0.0735 ( 0.1%) 592027163 EarlyCSEPass ... ===-------------------------------------------------------------------------=== Clang front-end time report ===-------------------------------------------------------------------------=== Total Execution Time: 48.2802 seconds (48.8638 wall clock) ... ``` After ``` % /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++ -ftime-report -DAT_PER_OPERATOR_HEADERS -DCAFFE2_BUILD_MAIN_LIB -DCPUINFO_SUPPORTED_PLATFORM=1 -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 -DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DUSE_EXTERNAL_MZCRC -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/Users/nshulga/git/pytorch/pytorch/build/aten/src -I/Users/nshulga/git/pytorch/pytorch/aten/src -I/Users/nshulga/git/pytorch/pytorch/build -I/Users/nshulga/git/pytorch/pytorch -I/Users/nshulga/git/pytorch/pytorch/cmake/../third_party/benchmark/include -I/Users/nshulga/git/pytorch/pytorch/third_party/onnx -I/Users/nshulga/git/pytorch/pytorch/build/third_party/onnx -I/Users/nshulga/git/pytorch/pytorch/third_party/foxi -I/Users/nshulga/git/pytorch/pytorch/build/third_party/foxi -I/Users/nshulga/git/pytorch/pytorch/torch/csrc/api -I/Users/nshulga/git/pytorch/pytorch/torch/csrc/api/include -I/Users/nshulga/git/pytorch/pytorch/caffe2/aten/src/TH -I/Users/nshulga/git/pytorch/pytorch/build/caffe2/aten/src/TH -I/Users/nshulga/git/pytorch/pytorch/build/caffe2/aten/src -I/Users/nshulga/git/pytorch/pytorch/build/caffe2/../aten/src -I/Users/nshulga/git/pytorch/pytorch/torch/csrc -I/Users/nshulga/git/pytorch/pytorch/third_party/miniz-2.1.0 -I/Users/nshulga/git/pytorch/pytorch/third_party/kineto/libkineto/include -I/Users/nshulga/git/pytorch/pytorch/third_party/kineto/libkineto/src -I/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/.. -I/Users/nshulga/git/pytorch/pytorch/third_party/FXdiv/include -I/Users/nshulga/git/pytorch/pytorch/c10/.. -I/Users/nshulga/git/pytorch/pytorch/third_party/pthreadpool/include -I/Users/nshulga/git/pytorch/pytorch/third_party/cpuinfo/include -I/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/Users/nshulga/git/pytorch/pytorch/third_party/cpuinfo/deps/clog/include -I/Users/nshulga/git/pytorch/pytorch/third_party/NNPACK/include -I/Users/nshulga/git/pytorch/pytorch/third_party/FP16/include -I/Users/nshulga/git/pytorch/pytorch/third_party/fmt/include -I/Users/nshulga/git/pytorch/pytorch/third_party/flatbuffers/include -isystem /Users/nshulga/git/pytorch/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /Users/nshulga/git/pytorch/pytorch/cmake/../third_party/googletest/googletest/include -isystem /Users/nshulga/git/pytorch/pytorch/third_party/protobuf/src -isystem /Users/nshulga/git/pytorch/pytorch/third_party/XNNPACK/include -isystem /Users/nshulga/git/pytorch/pytorch/cmake/../third_party/eigen -isystem /Users/nshulga/git/pytorch/pytorch/build/include -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=braced-scalar-init -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wvla-extension -Wsuggest-override -Wnewline-eof -Winconsistent-missing-override -Winconsistent-missing-destructor-override -Wno-pass-failed -Wno-error=pedantic -Wno-error=old-style-cast -Wno-error=inconsistent-missing-override -Wno-error=inconsistent-missing-destructor-override -Wconstant-conversion -Wno-invalid-partial-specialization -Wno-missing-braces -Qunused-arguments -fcolor-diagnostics -faligned-new -Werror -Wno-unused-but-set-variable -fno-math-errno -fno-trapping-math -Werror=format -DUSE_MPS -Wno-unused-private-field -Wno-missing-braces -O3 -DNDEBUG -DNDEBUG -arch arm64 -isysroot /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk -fPIC -D__NEON__ -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-unused-function -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-type-limits -Wno-array-bounds -Wno-strict-overflow -Wno-strict-aliasing -fvisibility=hidden -O2 -Wmissing-prototypes -Werror=missing-prototypes -Xpreprocessor -fopenmp -I/Users/nshulga/miniforge3/include -std=gnu++17 -Wno-missing-prototypes -Wno-error=missing-prototypes -o caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/RegisterSchema.cpp.o -c /Users/nshulga/git/pytorch/pytorch/build/aten/src/ATen/RegisterSchema.cpp ===-------------------------------------------------------------------------=== ... Pass execution timing report ... ===-------------------------------------------------------------------------=== Total Execution Time: 1.2920 seconds (1.3187 wall clock) ---User Time--- --System Time-- --User+System-- ---Wall Time--- ---Instr--- --- Name --- 0.3070 ( 27.6%) 0.0547 ( 30.2%) 0.3617 ( 28.0%) 0.3654 ( 27.7%) 3719690895 ModuleInlinerWrapperPass 0.3024 ( 27.2%) 0.0525 ( 29.0%) 0.3549 ( 27.5%) 0.3585 ( 27.2%) 3653363330 DevirtSCCRepeatedPass 0.0619 ( 5.6%) 0.0073 ( 4.0%) 0.0692 ( 5.4%) 0.0711 ( 5.4%) 868136227 InstCombinePass 0.0601 ( 5.4%) 0.0065 ( 3.6%) 0.0666 ( 5.2%) 0.0679 ( 5.1%) 696430647 InlinerPass 0.0363 ( 3.3%) 0.0033 ( 1.8%) 0.0396 ( 3.1%) 0.0425 ( 3.2%) 535426974 SimplifyCFGPass 0.0280 ( 2.5%) 0.0069 ( 3.8%) 0.0348 ( 2.7%) 0.0358 ( 2.7%) 378716394 BlockFrequencyAnalysis 0.0208 ( 1.9%) 0.0049 ( 2.7%) 0.0257 ( 2.0%) 0.0262 ( 2.0%) 283689627 BranchProbabilityAnalysis 0.0239 ( 2.1%) 0.0002 ( 0.1%) 0.0241 ( 1.9%) 0.0241 ( 1.8%) 219122704 OpenMPOptCGSCCPass 0.0174 ( 1.6%) 0.0015 ( 0.8%) 0.0189 ( 1.5%) 0.0192 ( 1.5%) 215583965 GVNPass 0.0153 ( 1.4%) 0.0025 ( 1.4%) 0.0178 ( 1.4%) 0.0187 ( 1.4%) 184232295 EarlyCSEPass ... ===-------------------------------------------------------------------------=== Clang front-end time report ===-------------------------------------------------------------------------=== Total Execution Time: 2.9128 seconds (3.1027 wall clock) ... ``` And the generated schema file looks as follows: ```cpp TORCH_LIBRARY(aten, m) { const std::vector tags_0 = {at::Tag::pt2_compliant_tag}; m.def("_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Char(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Double(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Float(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Int(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Long(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Short(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_cast_Half(Tensor self, bool non_blocking=False) -> Tensor", tags_0); m.def("_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()", tags_0); m.def("set_data(Tensor(a!) self, Tensor new_data) -> ()", tags_0); m.def("data(Tensor self) -> Tensor", tags_0); m.def("is_leaf(Tensor self) -> bool", tags_0); m.def("output_nr(Tensor self) -> int", tags_0); m.def("_version(Tensor self) -> int", tags_0); m.def("requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)", tags_0); m.def("retain_grad(Tensor(a!) self) -> ()", tags_0); m.def("retains_grad(Tensor self) -> bool", tags_0); m.def("_fw_primal(Tensor(a) self, int level) -> Tensor(a)", tags_0); m.def("_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)", tags_0); m.def("_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)", tags_0); m.def("_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor", tags_0); m.def("_has_same_storage_numel(Tensor self, Tensor other) -> bool", tags_0); const std::vector tags_1 = {at::Tag::inplace_view, at::Tag::pt2_compliant_tag}; m.def("rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)", tags_1); m.def("rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)", tags_0); m.def("align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)", tags_0); m.def("align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)", tags_0); m.def("align_as(Tensor self, Tensor other) -> Tensor", tags_0); ... ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114438 Approved by: https://github.com/zou3519 --- torchgen/gen.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchgen/gen.py b/torchgen/gen.py index 19fa667f33e2..6bb15c28555d 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -4,7 +4,7 @@ import os import pathlib from collections import defaultdict, namedtuple, OrderedDict -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ( Any, Callable, @@ -542,13 +542,21 @@ def static_dispatch( @dataclass(frozen=True) class RegisterSchema: selector: SelectiveBuilder + known_tags: Dict[str, int] = field(default_factory=dict) @method_with_native_function def __call__(self, f: NativeFunction) -> Optional[str]: if not self.selector.is_native_function_selected(f): return None tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}" - return f"m.def({cpp_string(str(f.func))}, {tags});\n" + if tags == "{}": + return f"m.def({cpp_string(str(f.func))}, {{}});\n" + maybe_tags = "" + if tags not in self.known_tags: + idx = len(self.known_tags) + self.known_tags[tags] = idx + maybe_tags = f"const std::vector tags_{idx} = {tags};\n" + return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n" # Generates Operators.h and Operators.cpp. From f505d764628f97182c7286f398fb5bc622471606 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Mon, 27 Nov 2023 23:52:35 +0000 Subject: [PATCH 186/221] Bug fixes to DDP _update_process_group API. (#114194) https://github.com/pytorch/pytorch/pull/113580 introduced the `DDP._update_process_group` API. However, the implementation did not correctly reset all of the necessary state in the reducer. In particular if an error occurred during backward, DDP would end up in an incorrect state. As a result, in this PR I've enhanced the unit test to test for this case and also appropriately fixed resetting Reducer state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114194 Approved by: https://github.com/rohan-varma --- torch/_C/_distributed_c10d.pyi | 2 +- torch/csrc/distributed/c10d/init.cpp | 6 +- torch/csrc/distributed/c10d/reducer.cpp | 11 +- torch/csrc/distributed/c10d/reducer.hpp | 4 +- torch/nn/parallel/distributed.py | 2 +- .../_internal/distributed/distributed_test.py | 137 ++++++++++++------ 6 files changed, 107 insertions(+), 55 deletions(-) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 6e16c3a4c1b0..8f3153d67a93 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -66,7 +66,7 @@ class Reducer: def _remove_autograd_hooks(self) -> None: ... def _check_reducer_finalized(self) -> None: ... def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ... - def _force_bucket_rebuild(self) -> None: ... + def _reset_state(self) -> None: ... def _update_process_group(self, new_process_group: ProcessGroup) -> None: ... class DDPLoggingData: diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 7d23ad1b6479..5e312e1d9887 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -598,10 +598,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO [](::c10d::Reducer& reducer) { return reducer.check_finalized(); }, py::call_guard()) .def( - "_force_bucket_rebuild", - [](::c10d::Reducer& reducer) { - return reducer.force_bucket_rebuild(); - }, + "_reset_state", + [](::c10d::Reducer& reducer) { return reducer.reset_state(); }, py::call_guard()) .def( "_update_process_group", diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 7acd1a8fcc7a..c5cf7e6c103c 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -2314,11 +2314,20 @@ void Reducer::update_process_group( process_group_ = std::move(new_process_group); } -void Reducer::force_bucket_rebuild() { +void Reducer::reset_state() { std::lock_guard lock(mutex_); + // Force rebuild of buckets. has_rebuilt_bucket_ = false; rebuilt_params_.clear(); rebuilt_param_indices_.clear(); + + // Ensure forward can run despite previous backward not succeeding. + expect_autograd_hooks_ = false; + require_finalize_ = false; + + // Unset allreduce division factor, as it may change in next backwards pass + // when running with DDP join mode. + div_factor_ = kUnsetDivFactor; } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index 710d17a323db..43782204be05 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -194,8 +194,8 @@ class TORCH_API Reducer { void update_process_group( c10::intrusive_ptr new_process_group); - // Forces a rebuild of buckets on next iteration. - void force_bucket_rebuild(); + // Resets reducer state. + void reset_state(); protected: // Forward declaration. diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 8732718aa87c..e186f1239864 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -2254,7 +2254,7 @@ def _update_process_group(self, new_process_group): # re-evaluates previous assumptions of buckets given the world size might have # changed. self._has_rebuilt_buckets = False - self.reducer._force_bucket_rebuild() + self.reducer._reset_state() if not _rank_not_in_group(new_process_group): self.process_group = new_process_group diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 5c007dcc98fe..f61fe8ef54b6 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -9566,79 +9566,124 @@ def abort(device): running = False t.join() - @skip_if_lt_x_gpu(4) - @require_world_size(4) - @skip_but_pass_in_sandcastle_if( - BACKEND not in DistTestCases.backend_feature["ddp"], - f"The {BACKEND} backend does not support DistributedDataParallel", - ) - def test_ddp_update_process_group(self): + def _run_ddp_update_process_group(self, new_pg): def get_num_torch_recompiles(): guard_failures = torch._dynamo.utils.guard_failures num_recompiles = [len(guard_failures[code]) for code in guard_failures] return 0 if len(num_recompiles) == 0 else max(num_recompiles) - input = torch.rand(10, 10).cuda(self.rank) + class SimulateError(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError() + + class MyModel(torch.nn.Module): + def __init__(self, device): + super().__init__() + # 4MB for multiple buckets. + self.fc1 = torch.nn.Linear(1024, 1024).cuda(device) + self.fc2 = torch.nn.Linear(1024, 1024).cuda(device) + self.fc3 = torch.nn.Linear(1024, 1024).cuda(device) + + def forward(self, inp, error): + if error: + return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp)))) + else: + return self.fc3(self.fc2(self.fc1(inp))) + + + input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank) ddp = torch.nn.parallel.DistributedDataParallel( - torch.nn.Linear(10, 10).cuda(self.rank), + MyModel(self.rank), device_ids=[self.rank], + find_unused_parameters=True, + bucket_cap_mb=1, ) model = torch.compile(ddp) def run_iteration(): - out = model(input) + # Run regular iteration. + out = model(input, error=False) out.sum().backward() torch.cuda.synchronize() - # Run regular iteration. - run_iteration() - num_compiles = get_num_torch_recompiles() - assert 0 == num_compiles - - # Now reduce world_size and run iteration. - group_size_2 = dist.new_group(ranks=[0, 1]) - ddp._update_process_group(group_size_2) - if self.rank in [0, 1]: - run_iteration() + # Run with error. + with self.assertRaises(RuntimeError): + out = model(input, error=True) + out.sum().backward() + torch.cuda.synchronize() - # Increase the world size and run iteration. - group_size_3 = dist.new_group(ranks=[1, 2, 3]) - ddp._update_process_group(group_size_3) - if self.rank in [1, 2, 3]: + run_iteration() + assert 0 == get_num_torch_recompiles() + + if new_pg: + # Now reduce world_size and run iteration. + group_size_2 = dist.new_group(ranks=[0, 1]) + ddp._update_process_group(group_size_2) + if self.rank in [0, 1]: + run_iteration() + + # Increase the world size and run iteration. + group_size_3 = dist.new_group(ranks=[1, 2, 3]) + ddp._update_process_group(group_size_3) + if self.rank in [1, 2, 3]: + run_iteration() + + # Back to default size. + ddp._update_process_group(_get_default_group()) run_iteration() + else: + # Create default pg of smaller size. + dist.destroy_process_group() - # Back to default size. - ddp._update_process_group(_get_default_group()) - run_iteration() + if self.rank in [1, 2, 3]: + dist.init_process_group( + init_method=self.init_method, + backend=BACKEND, + world_size=3, + rank=self.rank - 1, + timeout=timedelta(seconds=default_pg_timeout), + ) + ddp._update_process_group(_get_default_group()) + run_iteration() + dist.destroy_process_group() - # Now create default pg of smaller size. - dist.destroy_process_group() + # Need a barrier here to ensure ranks 1, 2 and 3 are done. + self._barrier(wait_for=4) - if self.rank in [1, 2, 3]: + # Need to init pg again for "_barrier" to succeed. dist.init_process_group( init_method=self.init_method, backend=BACKEND, - world_size=3, - rank=self.rank - 1, + world_size=4, + rank=self.rank, timeout=timedelta(seconds=default_pg_timeout), ) - ddp._update_process_group(_get_default_group()) - run_iteration() - dist.destroy_process_group() - - # Need to init pg again for "_barrier" to succeed. - dist.init_process_group( - init_method=self.init_method, - backend=BACKEND, - world_size=4, - rank=self.rank, - timeout=timedelta(seconds=default_pg_timeout), - ) # Validate no more recompiles. - num_compiles = get_num_torch_recompiles() - assert 0 == num_compiles + assert 0 == get_num_torch_recompiles() + + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_new_group(self): + self._run_ddp_update_process_group(new_pg=True) + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_default_group(self): + self._run_ddp_update_process_group(new_pg=False) @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( From a0be4b7ea73024d7979e89f905a959db3086c479 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Tue, 28 Nov 2023 00:18:41 +0000 Subject: [PATCH 187/221] [fx] Update symbolic_trace nn_module_stack (#114422) Summary: Fixed nn_module_stack dynamo produced by symbolic trace to align with the nn_module_stack metadata produced by dynamo. The key should be the module path, with the value being a unique name, and the type. Something like: `{'L__self___one_module': ("L['self'].one_module", .GraphModuleImpl'>)}` This was causing some tests to fail when using export + the old quantization flow (prepare_fx calls symbolic_trace). Test Plan: D51534471 `buck2 run @//mode/dev-nosan //executorch/backends/xnnpack/test:test_xnnpack_quantized -- -r "test_xnnpack_leaky_relu"` Differential Revision: D51539118 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114422 Approved by: https://github.com/JacobSzwejbka, https://github.com/jerryzh168 --- test/test_fx.py | 4 ++-- torch/fx/_symbolic_trace.py | 2 +- torch/fx/proxy.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 30c5f838f127..8de7c3dd6a9c 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1752,8 +1752,8 @@ def forward(self, x): gm = torch.fx.symbolic_trace(m) mod_stack = {} - expected_stack = [('sub_mod', type(m.sub_mod)), - ('sub_mod.conv_mod', type(m.sub_mod.conv_mod))] + expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))), + ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))] for node in gm.graph.nodes: mod_stack = node.meta.get('nn_module_stack', {}) if mod_stack: diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 27f6d1d281ac..5db2cf7db224 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -478,7 +478,7 @@ def call_module( with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: # module_stack is an ordered dict so writing then deleting the # entry is equivalent to push/pop on a list - self.module_stack[_scope.module_path] = _scope.module_type + self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) if not self.is_leaf_module(m, module_qualified_name): ret_val = forward(*args, **kwargs) else: diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 3147deeecd7d..66b785b8b03f 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -107,7 +107,7 @@ class TracerBase: scope : Scope # Records the module call stack - module_stack: OrderedDict[str, str] + module_stack: OrderedDict[str, Tuple[str, Any]] # Mapping of node name to module scope node_name_to_scope: Dict[str, Tuple[str, type]] From dffa5f3f23535cdf86b57f7476ee69bf76da03af Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Tue, 28 Nov 2023 00:27:23 +0000 Subject: [PATCH 188/221] [dynamo][reland] `ExecutorchCallDelegateHigherOrderVariable` - add sanity check that input and output tensors are disjoint (#114167) Summary: Reland of https://github.com/pytorch/pytorch/pull/111960, Fixes https://github.com/pytorch/pytorch/issues/111917 Original PR broke some internal tests which the current diff has resolved. Test Plan: CI Differential Revision: D51473196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114167 Approved by: https://github.com/jon-chuang, https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 12 ++++++++++++ torch/_dynamo/variables/higher_order_ops.py | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 51b9786faa74..fbebeaf3baba 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -3606,6 +3606,18 @@ def false_fn(x): with self.assertRaises(AssertionError): opt_test(True, False, inp) + def test_non_aliasing_util(self): + from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing + + a = [torch.tensor(1), {"a": torch.tensor(1)}] + b = (torch.tensor(1),) + _assert_tensors_nonaliasing(a, b) + + with self.assertRaisesRegex( + AssertionError, "inputs to function body cannot alias outputs" + ): + _assert_tensors_nonaliasing(a, a) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index c9e4cca9ed48..3d2fdf80e0e0 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -94,6 +94,18 @@ def inline_call(*args, **kwargs): return inline_call +def _assert_tensors_nonaliasing(inputs, outputs): + input_tensor_ids = { + id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor) + } + output_tensor_ids = { + id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor) + } + assert input_tensor_ids.isdisjoint( + output_tensor_ids + ), "inputs to function body cannot alias outputs" + + def validate_args_and_maybe_create_graph_inputs( sub_args, tracer, tx, manually_set_subgraph_inputs ): @@ -819,7 +831,14 @@ def call_function( real_sub_args = pytree.tree_map_only( torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args ) + example_res = lowered_module.original_module(*real_sub_args) + + # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: + # executorch modules promise not to alias inputs and outputs. + # Thus, output FakeTensors will correctly not alias input FakeTensors. + _assert_tensors_nonaliasing(real_sub_args, example_res) + example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) p_args = (lowered_node,) + p_args From b1fb5912728d7480896fb17d10acef950482a5e4 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 27 Nov 2023 09:52:59 -0800 Subject: [PATCH 189/221] [replicate] Simplify replicate() init logic and remove unnecessary variables in _ReplicateState (#113679) Many variables _ReplicateState are created because replicate() was lazy initialized. This PR removes these variables and simplifes the logic.y Differential Revision: [D51317874](https://our.internmc.facebook.com/intern/diff/D51317874/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113679 Approved by: https://github.com/awgu --- test/distributed/_composable/test_compose.py | 4 +- torch/distributed/_composable/replicate.py | 97 +++++++++++--------- 2 files changed, 55 insertions(+), 46 deletions(-) diff --git a/test/distributed/_composable/test_compose.py b/test/distributed/_composable/test_compose.py index 769cc3d9a599..720caeaa5241 100644 --- a/test/distributed/_composable/test_compose.py +++ b/test/distributed/_composable/test_compose.py @@ -187,7 +187,7 @@ def test_fully_shard_replicate_correct_replicate_params(self): # Ensure replicate param names are as expected, i.e. # immediate parameters of model and parameters of model's non-UnitModule # submodules are replicated - param_names = replicate.state(model)._replicate_param_names + param_names = replicate.state(model)._param_names replicated_modules = [ (name, mod) for (name, mod) in model.named_children() @@ -255,7 +255,7 @@ def test_composable_fsdp_replicate(self): # `replicate` are applied on the same module, it should raise exception. model = CompositeModel(device=torch.device("cuda")) fully_shard(model.l1) - with self.assertRaisesRegex(AssertionError, "Cannot apply .*replicate"): + with self.assertRaisesRegex(RuntimeError, "Cannot apply .*replicate"): replicate(model.l1) replicate(model.l2) # should not raise diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index 4e253e702be6..b3205f9aff03 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn from torch.distributed._composable_state import _State - from torch.nn.parallel import DistributedDataParallel from .contract import _get_registry, contract @@ -15,83 +14,85 @@ class _ReplicateState(_State): def __init__(self) -> None: super().__init__() - self.module: Optional[nn.Module] = None + self.module: nn.Module = nn.ParameterList() self.has_initialized: bool = False self._param_list: nn.ParameterList = nn.ParameterList() - self.kwargs: dict = {} - self.ignored_modules: Set[torch.nn.Module] = set() - self.ignored_params: Set[torch.nn.Parameter] = set() - # Only used for testing + # TODO(@fegin): this variable is originally create for testing, we + # should remove this if possible. self._param_names: List[str] = [] - def mark_module( + def _collect_params( self, module: nn.Module, - ignored_modules: Optional[Iterable[torch.nn.Module]], - **kwargs, - ) -> None: - if _is_fully_sharded(module): - raise AssertionError( - "Cannot apply `replicate()` on a Module already managed by `fully_shard`" - ) - self.module = module - self.ignored_modules = set(ignored_modules) if ignored_modules else set() - self.ignored_params = {p for m in self.ignored_modules for p in m.parameters()} - module.register_forward_pre_hook(self.forward_pre_hook, with_kwargs=True) - # TODO(@yhcharles): fix type error - module.register_forward_hook(self.forward_post_hook) # type: ignore[arg-type] - self.kwargs = kwargs - - def _collect_params( - self, module: nn.Module, prefix: str = _ROOT_MODULE_PREFIX + ignored_modules: Set[nn.Module], + ignored_params: Set[nn.Parameter], + prefix: str = _ROOT_MODULE_PREFIX, ) -> None: # skip if managed by fully_sharded API if _is_fully_sharded(module): return - if module in self.ignored_modules: - return # if module A is ignored, all of A's children are also ignored. + # if a module is ignored, all descendants of the module are ignored. + if module in ignored_modules: + return recurse_prefix = ( f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX ) for n, p in module.named_parameters(recurse=False): - if p not in self.ignored_params: + if p not in ignored_params: self._param_list.append(p) self._param_names.append(f"{recurse_prefix}{n}") for name, child_module in module.named_children(): - self._collect_params(module=child_module, prefix=f"{recurse_prefix}{name}") + self._collect_params( + child_module, + ignored_modules, + ignored_params, + prefix=f"{recurse_prefix}{name}", + ) + + def init( + self, + module: nn.Module, + ignored_modules: Set[nn.Module], + **kwargs, + ) -> None: + if _is_fully_sharded(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) - def init_helper(self) -> None: if self.has_initialized: return self.has_initialized = True + self.module = module + ignored_params = {p for m in ignored_modules for p in m.parameters()} + self._collect_params(module, ignored_modules, ignored_params) + module.register_forward_pre_hook(self.forward_pre_hook, with_kwargs=True) + module.register_forward_hook(self.forward_post_hook) # type: ignore[arg-type] - self._collect_params(self.module) # type: ignore[arg-type] - # Only saved for testing - replicate.state(self.module)._replicate_param_names = self._param_names - if "device_id" in self.kwargs: + if "device_id" in kwargs: # replicate() supports a small usability enhancement where # user can pass in device_id as a Union[int, torch.device] even for # CPU devices so users don't have to change code for CPU/GPU runs. # We derive the right device_ids to feed into DDP to support this. - if self.kwargs["device_id"] is not None: - device_id = self.kwargs["device_id"] + if kwargs["device_id"] is not None: + device_id = kwargs["device_id"] # Convert to device_ids that DDP expects. if isinstance(device_id, torch.device) and device_id.type == "cpu": # CPU modules receive device_ids None - self.kwargs["device_ids"] = None + kwargs["device_ids"] = None else: # GPU modules expect device_ids=[cuda_device] - self.kwargs["device_ids"] = [device_id] + kwargs["device_ids"] = [device_id] else: - self.kwargs["device_ids"] = None - self.kwargs.pop("device_id") + kwargs["device_ids"] = None + kwargs.pop("device_id") - self._ddp = DistributedDataParallel(self._param_list, **self.kwargs) + self._ddp = DistributedDataParallel(self._param_list, **kwargs) # Weakref to the DDP instance is currently only used for testing. replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp) @@ -111,7 +112,7 @@ def forward_post_hook( @contract(state_cls=_ReplicateState) def replicate( - module: nn.Module, # NOTE: contract now supports single module only + module: nn.Module, ignored_modules: Optional[Iterable[torch.nn.Module]] = None, **kwargs, ) -> nn.Module: @@ -126,14 +127,22 @@ def replicate( >>> replicate(module) """ torch._C._log_api_usage_once("torch.distributed.replicate") + + # TODO(fegin): using kwargs is not a good idea if we would like to make + # replicate a formal API to replace DDP. if "device_id" in kwargs: if not isinstance(kwargs["device_id"], (int, torch.device)): raise RuntimeError( - f"Expected device_id to be int or torch.device, but got {type(kwargs['device_id'])}" + "Expected device_id to be int or torch.device, " + f"but got {type(kwargs['device_id'])}" ) - replicate.state(module).mark_module(module, ignored_modules, **kwargs) - replicate.state(module).init_helper() + if ignored_modules is None: + ignored_modules = {} + else: + ignored_modules = set(ignored_modules) + replicate.state(module).init(module, ignored_modules, **kwargs) + return module From e592b9a469e83defb92c8d8167940b13db2843de Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sun, 26 Nov 2023 09:09:59 +0800 Subject: [PATCH 190/221] [Quant] [PT2] Fix an issue in Conv Binary Quantization Annotation (#114540) **Summary** To annotate a conv-binary pattern, should skip the pattern if the conv node has more than one user. **Test Plan** ``` python -m pytest test_x86inductor_quantizer.py -k test_conv2d_binary2 python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_binary2 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114540 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 --- .../pt2e/test_x86inductor_quantizer.py | 106 ++++++++++++++++++ .../quantizer/x86_inductor_quantizer.py | 14 ++- 2 files changed, 118 insertions(+), 2 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 9c62bb1b74fc..c1616f960b21 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -258,6 +258,30 @@ def __init__(self, use_bias, postop, inplace_postop) -> None: def forward(self, x): return self.postop(self.linear(x)) + class Conv2dAddModule2(torch.nn.Module): + def __init__(self, + inplace_add: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + self.inplace_add = inplace_add + self.bn = torch.nn.BatchNorm2d(3) + self.bn2 = torch.nn.BatchNorm2d(3) + + def forward(self, x): + if self.inplace_add: + tmp = self.bn(self.conv(x)) + tmp += self.bn2(self.conv2(tmp)) + return tmp + else: + tmp = self.bn(self.conv(x)) + return tmp + self.bn2(self.conv2(tmp)) + class X86InductorQuantTestCase(QuantizationTestCase): def _test_quantizer( self, @@ -418,6 +442,46 @@ def test_conv2d_binary(self): node_list, ) + + @skipIfNoX86 + def test_conv2d_binary2(self): + """ + Test Pattern: + tmp = conv2d_1(x) + tmp2 = conv2d_2(tmp) + return tmp + tmp2 + Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + inplace_add_list = [True, False] + with override_quantized_engine("x86"), torch.no_grad(): + for inplace_add in inplace_add_list: + m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval() + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + @skipIfNoX86 def test_conv2d_binary_unary(self): """ @@ -1006,6 +1070,48 @@ def test_qat_conv2d_binary(self): is_qat=True, ) + @skipIfNoX86 + def test_qat_conv2d_binary2(self): + """ + Test qat Pattern: + tmp = bn1(conv2d_1(x)) + tmp2 = bn2(conv2d_2(tmp)) + return tmp + tmp2 + Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_qat=True) + ) + inplace_add_list = [True, False] + with override_quantized_engine("x86"), torch.no_grad(): + for inplace_add in inplace_add_list: + m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + @skipIfNoX86 def test_qat_conv2d_binary_unary(self): """ diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 97ab14f08258..cdbc85f0fb12 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -445,7 +445,9 @@ def _annotate_qat_conv2d_bn_binary_unary( ) = self._get_output_nodes_of_partitions( [conv_partition, bn_partition, binary_partition, unary_partition] ) - + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. + continue ( bn_output_node_idx, extra_input_node_idx, @@ -502,7 +504,9 @@ def _annotate_qat_conv2d_bn_binary( ) = self._get_output_nodes_of_partitions( [conv_partition, bn_partition, binary_partition] ) - + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. + continue ( bn_output_node_idx, extra_input_node_idx, @@ -634,6 +638,9 @@ def _annotate_conv2d_binary_unary( conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions( [conv_partition, binary_partition, unary_partition] ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( conv_node, binary_node ) @@ -676,6 +683,9 @@ def _annotate_conv2d_binary( conv_node, binary_node = self._get_output_nodes_of_partitions( [conv_partition, binary_partition] ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( conv_node, binary_node ) From cf9f3ae8d82afd1f2608637821acb1b386975cdf Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 28 Nov 2023 01:11:19 +0000 Subject: [PATCH 191/221] Skip an example of test_instance_norm when running internally due to its size (#114452) After https://github.com/pytorch/pytorch/pull/113420, `torch.unique` now includes a call to `torch.sort` and that call is slow when running in dev mode, i.e. `@fbcode//mode/dev`. This causes the test to take more than 10 minutes and time out internally [T170720856](https://www.internalfb.com/intern/tasks/?t=170720856). Running the test in `@fbcode//mode/opt` is fine, so please let me know if there is a way to set that. Otherwise, this change will skip the largest example when running in sandcastle internally. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114452 Approved by: https://github.com/malfet --- test/quantization/core/test_quantized_op.py | 23 ++++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 59784a63d3ef..3c294211c494 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -22,7 +22,7 @@ hu.assert_deadline_disabled() from torch.testing._internal.common_utils import TestCase -from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2 +from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ override_quantized_engine, supported_qengines, override_qengines, _snr @@ -2574,15 +2574,18 @@ def test_instance_norm(self): combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list] test_cases_product = itertools.product(*combined) test_cases = list(test_cases_product) - # add just one test case to test overflow - test_cases.append([ - [1, 4, 224, 224, 160], # shape, - torch.qint8, # torch_type - 0.1, # scale - 0, # zero_point - False, # channels_last - True, # affine - ]) + # NB: Add just one test case to test overflow, but this case is too slow to run + # internally in @fbcode//mode/dev, the long pole is the 4x calls to torch.sort + # inside torch.unique current implementation + if not IS_SANDCASTLE: + test_cases.append([ + [1, 4, 224, 224, 160], # shape, + torch.qint8, # torch_type + 0.1, # scale + 0, # zero_point + False, # channels_last + True, # affine + ]) with override_quantized_engine("fbgemm"): for test_case in test_cases: From 304ea761f577c673bd92f8ef5cd01fc2e7d15829 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 28 Nov 2023 01:25:46 +0000 Subject: [PATCH 192/221] [executorch][be] update test_emit to use export (#114294) Summary: exir.capture is deprecated. Switch to blessed path Test Plan: fbsource/fbcode/executorch/exir/emit/test (c40a7a0d2)]$ buck test : Differential Revision: D51503120 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114294 Approved by: https://github.com/zhxchen17 --- torch/_export/verifier.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 391d7f99f69b..225ac16344c6 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -203,12 +203,23 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]: if isinstance(attr, torch.nn.Module): def _is_type(name, ty): return isinstance(getattr(attr, name, None), ty) - if type(attr).__name__ == "LoweredBackendModule" \ - and _is_type("backend_id", str) \ - and _is_type("processed_bytes", bytes) \ - and _is_type("compile_specs", list) \ - and hasattr(attr, "original_module"): - continue + if type(attr).__name__ == "LoweredBackendModule": + if _is_type("backend_id", str) \ + and _is_type("processed_bytes", bytes) \ + and _is_type("compile_specs", list) \ + and hasattr(attr, "original_module"): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) if not isinstance(attr, _allowed_getattr_types()): raise SpecViolationError( From e25b146b8c85a3da964986195c6f3f338db45428 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Tue, 28 Nov 2023 01:27:51 +0000 Subject: [PATCH 193/221] [BE][Easy]: Enable flake8-exe rules in ruff too. (#114521) Enable flake8-exe rules in ruff too. RUFF requires EXE rules to enabled separately from the E prefix. This fixes a parity bug between flake8 and ruff. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114521 Approved by: https://github.com/kit1980 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d59bed1e9187..a929d58e7a3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ select = [ "C4", "G", "E", + "EXE", "F", "SIM1", "W", From 74370a8a9d97bc8738b22b1e49eabc8d424f93f4 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 24 Nov 2023 09:53:55 +0800 Subject: [PATCH 194/221] Add adaptive_avg_pool2d and flatten into x86 Inductor Quantizer recipe (#114442) **Summary** Add adaptive_avg_pool2d and flatten into x86 Inductor Quantizer recipe **Test Plan** ``` python -m pytest test_x86inductor_quantizer.py -k test_adaptive_avg_pool2d_recipe python -m pytest test_x86inductor_quantizer.py -k test_flatten_recipe ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114442 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 --- .../pt2e/test_x86inductor_quantizer.py | 79 +++++++++++++------ .../quantizer/x86_inductor_quantizer.py | 20 +++-- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index c1616f960b21..1158d72c5bbb 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -155,15 +155,15 @@ def forward(self, x): else: return self.relu2(self.conv(x) + self.conv2(x)) - class Conv2dMaxpoolPowModule(nn.Module): - def __init__(self): + class Conv2dSingleOpPowModule(nn.Module): + def __init__(self, single_op): super().__init__() self.conv = nn.Conv2d(2, 2, 1) - self.pool = nn.MaxPool2d(1, 1) + self.single_op = single_op def forward(self, x): x = self.conv(x) - x = self.pool(x) + x = self.single_op(x) return torch.pow(x, 2) class SerialsConv2dAddReLUModule(torch.nn.Module): @@ -569,14 +569,7 @@ def test_conv2d_serials_binary_unary(self): node_list, ) - @skipIfNoX86 - def test_maxpool2d_recipe(self): - r""" - Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow) - Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow. - """ - m = TestHelperModules.Conv2dMaxpoolPowModule().eval() - x = torch.rand(1, 2, 14, 14) + def _single_op_share_observer_recipe_test_helper(self, m, x, single_op): quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) @@ -595,7 +588,7 @@ def test_maxpool2d_recipe(self): torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.max_pool2d.default, + single_op, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] @@ -610,14 +603,14 @@ def test_maxpool2d_recipe(self): for node in prepare_model.graph.nodes: if ( node.op == "call_function" - and node.target is torch.ops.aten.max_pool2d.default + and node.target is single_op ): - maxpool_node = node - input_obs_of_maxpool = getattr( - prepare_model, maxpool_node.args[0].target + single_op_node = node + input_obs_of_single_op = getattr( + prepare_model, single_op_node.args[0].target ) - output_obs_of_maxpool = getattr( - prepare_model, list(maxpool_node.users)[0].target + output_obs_of_single_op = getattr( + prepare_model, list(single_op_node.users)[0].target ) elif ( node.op == "call_function" @@ -625,11 +618,51 @@ def test_maxpool2d_recipe(self): ): conv_node = node input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) - self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase)) - self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase)) + self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase)) self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) - self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool) - self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv) + self.assertTrue(input_obs_of_single_op is output_obs_of_single_op) + self.assertTrue(input_obs_of_single_op is not input_obs_of_conv) + + + @skipIfNoX86 + def test_maxpool2d_recipe(self): + r""" + Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow) + Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow. + """ + self._single_op_share_observer_recipe_test_helper( + TestHelperModules.Conv2dSingleOpPowModule(nn.MaxPool2d(1, 1)).eval(), + torch.rand(1, 2, 14, 14), + torch.ops.aten.max_pool2d.default, + ) + + + @skipIfNoX86 + def test_adaptive_avg_pool2d_recipe(self): + r""" + Test pattern: int8_in_int8_out_ops(adaptive_avg_pool2d) - non_quantizable op(pow) + Since adaptive_avg_pool2d is a int8_in_int8_out_op, there is obs between adaptive_avg_pool2d and pow. + """ + self._single_op_share_observer_recipe_test_helper( + TestHelperModules.Conv2dSingleOpPowModule(nn.AdaptiveAvgPool2d((1, 1))).eval(), + torch.rand(1, 2, 14, 14), + torch.ops.aten.adaptive_avg_pool2d.default, + ) + + + @skipIfNoX86 + def test_flatten_recipe(self): + r""" + Test pattern: int8_in_int8_out_ops(flatten) - non_quantizable op(pow) + Since flatten is a int8_in_int8_out_op, there is obs between flatten and pow. + """ + self._single_op_share_observer_recipe_test_helper( + TestHelperModules.Conv2dSingleOpPowModule(lambda x: torch.flatten(x, 1)).eval(), + torch.rand(1, 2, 14, 14), + torch.ops.aten.flatten.using_ints, + ) + @skipIfNoX86 def test_cat_recipe(self): diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index cdbc85f0fb12..69ad2c7af004 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -53,24 +53,22 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): _is_output_of_quantized_pattern: bool = False -# Ops support int8 data type and excludes ops like conv, linear. -quantizable_ops_pt2e: Set = { - torch.ops.aten.max_pool2d.default, - torch.ops.aten.cat.default, - torch.ops.aten.avg_pool2d.default, -} - - -# Ops that: -# 1. Ops prefer to run with int8 when int8 input is given. -# 2. Ops don't support int8 in and fp32 out. +# Operations that: +# 1. Operations are optimized to run with int8 when int8 input provided. +# 2. Operations do not support int8 input and produce fp32 output. int8_in_int8_out_ops_pt2e: Set = { torch.ops.aten.max_pool2d.default, torch.ops.aten.cat.default, torch.ops.aten.avg_pool2d.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.flatten.using_ints, } +# Operations support the int8 data type and exclude operations such as conv and linear. +# A superset of int8_in_int8_out_ops_pt2e incorporating additional operators. +quantizable_ops_pt2e = copy.deepcopy(int8_in_int8_out_ops_pt2e) + QUANT_ANNOTATION_KEY = "quantization_annotation" From 4ba3e6758d9d94be111cf3bc735421dfe4dc2b0a Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 27 Nov 2023 22:24:37 +0000 Subject: [PATCH 195/221] Canonicalize runtime asserts (#114509) This allows us to remove quite a few redundant runtime asserts, and potentially a number of guards as well. On ``` python test/dynamo/test_subclasses.py -k test_unbind ``` we go from ``` inserting runtime assert i0 <= s0 inserting runtime assert 0 <= -i0 + s0 inserting runtime assert i0 + i1 <= s0 inserting runtime assert i0 <= -i1 + s0 inserting runtime assert i0 + i1 + i2 <= s0 inserting runtime assert i0 + i1 <= -i2 + s0 inserting runtime assert Eq(i0 + i1 + i2 + i3, s0) inserting runtime assert i0 + i1 + i2 + i3 <= s0 inserting runtime assert i0 + i1 + i2 <= -i3 + s0 ``` to ``` inserting runtime assert i0 - s0 <= 0 inserting runtime assert i0 + i1 - s0 <= 0 inserting runtime assert i0 + i1 + i2 - s0 <= 0 inserting runtime assert Eq(i0 + i1 + i2 + i3, s0) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114509 Approved by: https://github.com/voznesenskym --- docs/source/conf.py | 1 + test/test_dynamic_shapes.py | 2 +- torch/fx/experimental/symbolic_shapes.py | 79 ++++++++++++++++++++---- 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index dcd3c7694674..b3eb2442ac27 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -974,6 +974,7 @@ "parallel_and", "parallel_or", "sym_eq", + "canonicalize_bool_expr", # torch.fx.experimental.unification.core "reify", # torch.fx.experimental.unification.match diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index bf843587af50..227dc0ce8da2 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -497,7 +497,7 @@ def test_expect_true_with_s0(self): self.assertTrue(expect_true(i0 <= s0)) self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), - """[i0 <= s0]""" + """[i0 - s0 <= 0]""" ) self.assertTrue(i0 <= s0) self.assertFalse(i0 > s0) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 7f056ec9d5a7..a3f7d0f476ae 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -60,7 +60,7 @@ class GuardOnDataDependentSymNode(RuntimeError): __all__ = [ "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", - "guard_int", "guard_float", "guard_scalar", + "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", "StatefulSymbolicContext" @@ -138,6 +138,54 @@ def is_concrete_int(a: Union[int, SymInt]): return False +def canonicalize_bool_expr(expr: sympy.Expr): + r""" Canonicalize a boolean expression by transforming it into a lt / le + inequality and moving all the non-constant terms to the rhs. + We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr + recursively + nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 + + Args: + expr (sympy.Expr): Expression to canonicalize + """ + # Canonicalise an inequality by transforming it into a lt / le + # inequality and moving all the non-constant terms to the rhs + # We canonicalise And / Ors / Not via cnf + # nb. Relational.canonical in sympy is broken + # https://github.com/sympy/sympy/issues/25924 + + if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)): + return expr + + if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): + expr = sympy.logic.boolalg.to_cnf(expr) + return _canonicalize_bool_expr_impl(expr) + +def _canonicalize_bool_expr_impl(expr: sympy.Expr): + if isinstance(expr, (sympy.And, sympy.Or)): + return type(expr)(*map(canonicalize_bool_expr, expr.args)) + + opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + if isinstance(expr, tuple(opposite.keys())): + lhs = expr.rhs - expr.lhs + t = opposite[type(expr)] + else: + assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) + lhs = expr.lhs - expr.rhs + t = type(expr) + rhs = 0 + if isinstance(lhs, sympy.Add): + cts = [] + variables = [] + for term in lhs.args: + if term.is_number: + cts.append(term) + else: + variables.append(term) + lhs = sympy.Add(*variables) + rhs = -sympy.Add(*cts) + return t(lhs, rhs) + def is_concrete_bool(a: Union[bool, SymBool]): r""" Utility to check if underlying object in SymBool is concrete value. Also returns @@ -3049,6 +3097,8 @@ def _maybe_evaluate_static( if compute_hint: expr = expr.xreplace(self.var_to_val) + expr = canonicalize_bool_expr(expr) + symbols = list(expr.free_symbols) # Apply known runtime asserts @@ -3057,17 +3107,20 @@ def _maybe_evaluate_static( if s in self.var_to_val: continue subst = {} - if s in self.deferred_runtime_asserts: - for ra in self.deferred_runtime_asserts[s]: - if compute_hint: - e = ra.expr.xreplace(self.var_to_val) - else: - e = ra.expr - subst[e] = sympy.true - subst[sympy.Not(e)] = sympy.false - # NB: this doesn't match relations if they're flipped; e.g., - # if you have x < 5, we won't get 5 > x. Holler if this is - # a problem + for ra in self.deferred_runtime_asserts.get(s, ()): + if compute_hint: + e = canonicalize_bool_expr(ra.expr.xreplace(self.var_to_val)) + else: + e = ra.expr + # e is already canonical + subst[e] = sympy.true + subst[canonicalize_bool_expr(sympy.Not(e))] = sympy.false + if isinstance(e, sympy.Eq): + subst[sympy.Le(e.lhs, e.rhs)] = sympy.true + subst[sympy.Le(-e.lhs, -e.rhs)] = sympy.true + subst[sympy.Lt(e.lhs, e.rhs)] = sympy.false + subst[sympy.Lt(-e.lhs, -e.rhs)] = sympy.false + # NB: this helps us deal with And/Or connectives expr = expr.subs(subst) @@ -3628,6 +3681,8 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): self._maybe_guard_eq(expr, True) if not self._suppress_guards_tls(): + # canonicalise to remove equations that are trivially equal + expr = canonicalize_bool_expr(expr) stack = CapturedTraceback.extract(skip=1) ra = RuntimeAssert(expr, msg, stack) # TODO: Do this in a way that is less janky than int(s.name[1:]) From 6ae0554d11b973930d7b8ec1e937b27ac961d7bf Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 24 Nov 2023 09:53:56 +0800 Subject: [PATCH 196/221] Enable the lowering of quantized reshape (#114443) **Summary** Enable the lowering of `dq->reshape->q` into a `qreshape` **Test Plan** ``` python -m pytest test_mkldnn_pattern_matcher.py -k test_qflatten ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114443 Approved by: https://github.com/jgong5, https://github.com/eellison, https://github.com/jerryzh168 ghstack dependencies: #114442 --- test/inductor/test_mkldnn_pattern_matcher.py | 36 +++++++++++++++++++ torch/_inductor/fx_passes/quantization.py | 37 ++++++++++++++++++-- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index dc7bfe54c5e1..d2528fa86499 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1176,6 +1176,42 @@ def forward(self, x): check_quantization=True, ) + @skipIfNoDynamoSupport + @skipIfRocm + def test_qflatten(self): + r""" + This testcase will quantize Conv2d->AdaptiveAvgPool2d->flatten pattern. + """ + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 64, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.relu = torch.nn.ReLU() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return torch.flatten( + self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1 + ) + + mod = M().eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) + + def matcher_check_fn(): + self.assertEqual(counters["inductor"]["qreshape_matcher_count"], 1) + + self._test_common( + mod, + (v,), + check_quantization=True, + matcher_check_fn=matcher_check_fn, + ) + @skipIfNoDynamoSupport @skipIfRocm def test_qcat(self): diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 7d35f78c8619..2bb4824dc95f 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -771,7 +771,7 @@ def _register_quantization_maxpool2d(): ) -def _is_valid_quantized_cat_optimization_pattern(): +def _is_input_output_same_scale_zp(check_node): def fn(match): # Ensure all the inputs and output has same scale and zero point # Step 1: Check inputs/output zero point @@ -790,7 +790,7 @@ def fn(match): scales = [ ( mul_node.args[1] - if mul_node.args[0].target is aten.cat.default + if mul_node.args[0].target is check_node else 1.0 / mul_node.args[1] ) for mul_node in mul_nodes @@ -809,7 +809,7 @@ def _register_quantized_cat_lowering( ): @register_lowering_pattern( pattern, - extra_check=_is_valid_quantized_cat_optimization_pattern(), + extra_check=_is_input_output_same_scale_zp(aten.cat.default), ) def qcat(match: Match, inputs, dim, **kwargs): # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] @@ -846,11 +846,42 @@ def _register_quantization_cat(): ) +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + dequantize_per_tensor_activation_pattern, + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + def _register_quantization_lowerings(): _register_quantization_unary_fusion() _register_quantization_binary_fusion() _register_quantization_maxpool2d() _register_quantization_cat() + _register_quantization_reshape() def _is_valid_dequant_promotion_pattern(dtype=torch.float32): From ae40a3ebcfa233bd66f0e6a11656c0bea69904dc Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Mon, 27 Nov 2023 09:35:18 -0800 Subject: [PATCH 197/221] [inductor] added a config to dump profiling results to a file (#114587) Currently, we print out profile bandwidth result for each triton kernel to stdout after each profiling run finishes. Consequently, the profiling results are mixed with other debug outputs. This PR adds a config, profile_bandwidth_output, to specify a file where we can dump the results in a sorted order. The new config can be set by setting "TORCHINDUCTOR_PROFILE_OUTPUT" environment variable. Hopefully it would offer a slightly better way to navigate the profiling results. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114587 Approved by: https://github.com/Chillee --- torch/_inductor/config.py | 3 +++ torch/_inductor/triton_heuristics.py | 26 ++++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 3912c65a4d1e..5f91c96af38c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -347,6 +347,9 @@ def decide_compile_threads(): _profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") profile_bandwidth = _profile_var != "" profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. +profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None) # TODO: remove later disable_cpp_codegen = False diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py index 4f0a8c7aff94..d39eee24eee8 100644 --- a/torch/_inductor/triton_heuristics.py +++ b/torch/_inductor/triton_heuristics.py @@ -591,11 +591,33 @@ def end_graph(): overall_time = sum(call[0] for call in collected_calls) overall_gb = sum(call[1] for call in collected_calls) cur_file = inspect.stack()[1].filename - print(f"SUMMARY ({cur_file})") - print( + summary_str = ( + f"SUMMARY ({cur_file})\n" f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s" ) + print(summary_str) print() + output_file = config.profile_bandwidth_output + if output_file is not None: + # sort perf numbers in descending order, i.e. placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.debug("Save profile bandwidth results to %s", output_file) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms/overall_time*100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, num_gb, gb_per_s, suffix=suffix + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception as e: + log.warning("failed to write profile bandwidth result into %s", output_file) class DebugAutotuner(CachingAutotuner): From 2333d381b214ff46eb08faaf319ae395abfb4132 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 27 Nov 2023 15:10:38 -0800 Subject: [PATCH 198/221] Make 'distributed' TORCH_LOGS include ddpoptimizer (#114376) There are now 3 ways to see logs from ddpoptimzer. 1) TORCH_LOGS="distributed" 2) TORCH_LOGS="dynamo" 3) TORCH_LOGS="torch._dynamo.backends.distributed" (1 and 2 are different supersets of 3 that also include other content) Note: ddp_graphs is still a separate 'artifact' logger, which just includes graph dumps from the graph-splitting process. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114376 Approved by: https://github.com/wanchaol --- torch/_dynamo/backends/distributed.py | 3 +++ torch/_logging/_registrations.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index adc68bb30bff..85e2e066290a 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -9,6 +9,9 @@ from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode from torch.fx.node import Node +# Regular log messages should go through 'log'. +# ddp_graph_log is a separate artifact logger reserved for dumping graphs. +# See docs/source/logging.rst for more info. log = logging.getLogger(__name__) ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs") diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 5f66f9ac4514..afb8507f0826 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -2,13 +2,14 @@ from ._internal import register_artifact, register_log DYNAMIC = ["torch.fx.experimental.symbolic_shapes", "torch.fx.experimental.sym_node"] +DISTRIBUTED = ["torch.distributed", "torch._dynamo.backends.distributed"] register_log("dynamo", ["torch._dynamo", *DYNAMIC]) register_log("aot", "torch._functorch.aot_autograd") register_log("inductor", "torch._inductor") register_log("dynamic", DYNAMIC) register_log("torch", "torch") -register_log("distributed", "torch.distributed") +register_log("distributed", DISTRIBUTED) register_log("onnx", "torch.onnx") register_artifact( From 4abf2b22615afe8c24057dd0678cc99303203d12 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Mon, 27 Nov 2023 12:22:34 -0800 Subject: [PATCH 199/221] [dynamo] fixed record_replayer issue when TORCH_COMPILE_DEBUG=1 (#114623) In https://github.com/pytorch/pytorch/pull/113432, we changed the behavior of _is_allowed_module_prefix, where we moved the '.' from the module perfixes. Consequently, 'LOAD_ATTR submodule' (e.g. LOAD_ATTR fx) is turned into PythonModuleVariable instead of TorchVariable. This caused some issue for record_replayer.record_module_access , which is enabled by setting TORCH_COMPILER_DEBUG=1, because 'torch.fx' doesn't exist in record_replayer's name_to_modrec dictionary when record_module_access is called. This PR fixed the issue by adding "torch.fx" into record_replayer's EXCLUDES list. The fix is likely to be a workaround to unblock internal workflow. There might be some fundamental changes to the relevant pieces along with Yanbo's refactoring PRs for tracing in-graph functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114623 Approved by: https://github.com/mlazos, https://github.com/yanboliang --- torch/_dynamo/replay_record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py index e106e09253f6..0a388ece0233 100644 --- a/torch/_dynamo/replay_record.py +++ b/torch/_dynamo/replay_record.py @@ -40,7 +40,7 @@ def load(cls, f): @dataclasses.dataclass class ExecutionRecorder: - MOD_EXCLUDES = ["torch", "torch.fx.passes"] + MOD_EXCLUDES = ["torch", "torch.fx", "torch.fx.passes"] LOCAL_MOD_PREFIX = "___local_mod_" code: CodeType From 8556a09d44e3e65ef360ba6a64f724d264ee35cb Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 28 Nov 2023 02:40:38 +0000 Subject: [PATCH 200/221] Require less alignment for attn bias (#114173) # Summary Improved Fix for Attention Mask Alignment Issue (#112577) This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention. ## Changes Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users. Should this be warn_once? We only call expand, once on the aligned mask. Reference https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115 @albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173 Approved by: https://github.com/danthe3rd --- .../ATen/native/transformers/attention.cpp | 45 +++++++++--------- .../cuda/mem_eff_attention/kernel_backward.h | 14 ++++-- .../cuda/mem_eff_attention/kernel_forward.h | 12 +++-- test/inductor/test_torchinductor.py | 47 +++++++++++++++++++ ...st_torchinductor_codegen_dynamic_shapes.py | 1 + test/test_transformers.py | 18 +++++++ torch/_inductor/lowering.py | 41 ++++++++++++---- torch/_meta_registrations.py | 16 ++++--- 8 files changed, 147 insertions(+), 47 deletions(-) diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 63b4a52d8c07..a0772ef52f66 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -522,9 +522,14 @@ c10::optional convert_boolean_attn_mask(const c10::optional& att // We apply this function to the top level SDPA so that // if padding is done it will be tracked for backward automatically -template -bool is_aligned(const SymInt& size){ - return size % alignment == 0; +template +bool aligned_tensor(const at::Tensor& tensor){ + for(const auto i : c10::irange(tensor.dim() - 1)){ + if(tensor.sym_stride(i) % alignment != 0){ + return false; + } + } + return tensor.sym_stride(-1) == 1; } template @@ -540,31 +545,23 @@ at::Tensor preprocess_mask( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value) { - constexpr int mem_eff_alignment = 16; - // Expand to 4d case - at::Tensor attn_mask = mask.expand_symint( + constexpr int mem_eff_alignment = 8; + at::Tensor result_mask = mask; + if (!aligned_tensor(mask)) { + TORCH_WARN_ONCE( + "Memory Efficient Attention requires the attn_mask to be aligned to, ", + mem_eff_alignment, + " elements. " + "Prior to calling SDPA, pad the last dimension of the attn_mask " + "to be at least a multiple of ", mem_eff_alignment, + " and then slice the attn_mask to the original size."); + result_mask = pad_bias(mask); + } + return result_mask.expand_symint( {query.sym_size(0), query.sym_size(1), query.sym_size(2), key.sym_size(2)}); - - bool aligned_last_dim = is_aligned(attn_mask.sym_size(-1)); - // Apply pad_bias and store the result in attn_mask - if (!aligned_last_dim) { - return pad_bias(attn_mask); - } - // Check and make the tensor contiguous if needed - auto needs_contig = [](const c10::SymInt& stride) { - return (stride % 16 != 0) || (stride == 0); - }; - if (needs_contig(attn_mask.sym_stride(0)) || - needs_contig(attn_mask.sym_stride(1)) || - needs_contig(attn_mask.sym_stride(2)) || - needs_contig(attn_mask.sym_stride(3))) { - return attn_mask.contiguous(); - } - - return attn_mask; } // FlashAttentionV2 requires that head dimension be a multiple of 8 // This was previously done within the kernel, however diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index b032b83326ef..987c223fa942 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -1183,7 +1183,7 @@ struct AttentionBackwardKernel { "value is not correctly aligned (strideH)"); TORCH_CHECK( p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, - "query is not correctly aligned (strideB)"); + "query is not correctly aligned (strideB)."); TORCH_CHECK( p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, "key is not correctly aligned (strideB)"); @@ -1202,13 +1202,19 @@ struct AttentionBackwardKernel { if (p.bias_ptr) { TORCH_CHECK( p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, - "attn_bias is not correctly aligned (strideB)"); + "attn_bias is not correctly aligned (strideB). ", + "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kMinimumAlignment, "."); TORCH_CHECK( p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, - "attn_bias is not correctly aligned (strideH)"); + "attn_bias is not correctly aligned (strideH) ." + "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kMinimumAlignment, "."); TORCH_CHECK( p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0, - "attn_bias is not correctly aligned (strideM)"); + "attn_bias is not correctly aligned (strideM). " + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ", + "multiple of ", kMinimumAlignment, "."); } if (p.grad_bias_ptr) { TORCH_CHECK( diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index 0a661fbfe817..02ce33ce262c 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -577,13 +577,19 @@ struct AttentionKernel { CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); TORCH_CHECK( p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, - "attn_bias is not correctly aligned (strideB)"); + "attn_bias is not correctly aligned (strideB). ", + "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kAlignmentQ, "."); TORCH_CHECK( p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, - "attn_bias is not correctly aligned (strideH)"); + "attn_bias is not correctly aligned (strideH). " + "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kAlignmentQ, "."); TORCH_CHECK( p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0, - "attn_bias is not correctly aligned (strideM)"); + "attn_bias is not correctly aligned (strideM). " + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a " + "multiple of ", kAlignmentQ, "."); } TORCH_CHECK( p.q_strideM % kAlignmentQ == 0, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 6877657d9c01..5eaff6d7fbdf 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -46,6 +46,7 @@ from torch.testing import FileCheck, make_tensor from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM80OrLater, TEST_CUDNN, with_tf32_off, @@ -7111,6 +7112,52 @@ def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): rtol=1e4, ) + @requires_cuda() + @unittest.skipIf( + not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + "Does not support mem_eff_attention", + ) + @skipIfRocm + def test_sdpa_unaligned_mask(self): + def foo( + arg0_1: "f32[8, 8, 16, 16]", + arg1_1: "f32[8, 8, 15, 16]", + arg2_1: "f32[8, 8, 15, 16]", + arg3_1: "f32[1, 1, 16, 15]", + ): + constant_pad_nd: "f32[1, 1, 16, 16]" = ( + torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0) + ) + arg3_1 = None + slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor( + constant_pad_nd, -1, 0, 15 + ) + constant_pad_nd = None + expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default( + slice_1, [8, 8, 16, 15] + ) + slice_1 = None + _scaled_dot_product_efficient_attention = ( + torch.ops.aten._scaled_dot_product_efficient_attention.default( + arg0_1, arg1_1, arg2_1, expand, False + ) + ) + arg0_1 = arg1_1 = arg2_1 = expand = None + getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0] + _scaled_dot_product_efficient_attention = None + return (getitem,) + + query = torch.rand(8, 8, 16, 16, device="cuda") + key = torch.rand(8, 8, 15, 16, device="cuda") + value = torch.rand(8, 8, 15, 16, device="cuda") + bias = torch.rand(1, 1, 16, 15, device="cuda") + self.common( + foo, + (query, key, value, bias), + atol=0.02, + rtol=1e4, + ) + def test_where_with_logical_op(self): def fn_and(x, y): return torch.where(torch.logical_and(x, y), 1.0, 0.0) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 1813588c21e5..463867cf2f85 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -265,6 +265,7 @@ def run(*ex, **kwargs): ("cpu", "cuda"), is_skip=True ), "test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True), + "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True), # # The following tests do not support dynamic shapes yet: # diff --git a/test/test_transformers.py b/test/test_transformers.py index 5785fedca0e1..81e574b75655 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1898,6 +1898,24 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype): out = F.scaled_dot_product_attention(query, key, value, mask) out.sum().backward() + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") + def test_mem_eff_attention_non_contig_mask_bug(self, device): + dtype = torch.float32 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) + batch, num_heads, head_dim = 1, 16, 128 + seq_len_q, seq_len_kv = 1, 16 + query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2) + kv_shape = (batch, seq_len_kv, head_dim) + key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1) + key = key.expand(-1, num_heads, -1, -1) + value = value.expand(-1, num_heads, -1, -1) + mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool) + with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): + out = F.scaled_dot_product_attention(query, key, value, mask) + out_no_mask = F.scaled_dot_product_attention(query, key, value, None) + max_diff = (out - out_no_mask).abs().mean() + assert max_diff.item() < 1e-9 + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 1ef717a301cf..3f237cc43dd3 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2045,21 +2045,44 @@ def apply_constraint(arg, fx_arg): # contiguous stride order stride_order = list(reversed(range(len(arg.get_size())))) - ALIGNMENT = 16 + # This is the minimum alignment required by SDPA kernels for attention_bias. + # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + is_backward = fx_node.target in ( + aten._scaled_dot_product_efficient_attention_backward.default, + aten._scaled_dot_product_flash_attention_backward.default, + ) def is_aligned(x): return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 assert isinstance(arg, TensorBox) - unaligned_input_shape = isinstance(arg.data, ir.SliceView) and not is_aligned( - arg - ) - aligned_input_view = unaligned_input_shape and is_aligned(arg.unwrap_view()) - # input is padded, requiring_stride_order will unwrap the view and unpad. - # Would be nice to be able to require certain padding from inductor ir, nyi - if aligned_input_view: - return arg + # This correctly handles the forward case: + if isinstance(arg.data, (ir.SliceView, ir.ExpandView)): + if not is_aligned(arg): + # input is padded, requiring_stride_order will unwrap the view and unpad. + # Would be nice to be able to require certain padding from inductor ir, nyi + if is_aligned(arg.unwrap_view()): + return arg + + def is_aligned_backward(x): + aligned_strides = all( + (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 + for i in range(len(x.get_stride()) - 1) + ) + return ( + V.graph.sizevars.size_hint(x.get_stride()[-1]) + ) == 1 and aligned_strides + + if ( + isinstance(arg.data, ir.StorageBox) + and arg.data.is_input_buffer() + and is_backward + ): + if len(arg.data.get_size()) == 4 and is_aligned_backward(arg): + return arg return ir.ExternKernel.require_stride_order(arg, stride_order) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 4c54df447bdb..bb10a34c4c06 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5221,12 +5221,14 @@ def meta__scaled_dot_product_efficient_backward( ) grad_bias = None if attn_bias is not None and grad_input_mask[3]: - grad_bias = torch.empty_strided( - attn_bias.size(), - attn_bias.stride(), - dtype=attn_bias.dtype, - device=attn_bias.device, + lastDim = attn_bias.size(-1) + lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 + new_sizes = list(attn_bias.size()) + new_sizes[-1] = lastDimAligned + grad_bias = torch.empty( + new_sizes, dtype=attn_bias.dtype, device=attn_bias.device ) + grad_bias = grad_bias[..., :lastDim] return grad_q, grad_k, grad_v, grad_bias @@ -5303,12 +5305,12 @@ def meta__efficient_attention_backward( grad_value = torch.empty_like(value) if bias is not None: - assert bias is not None lastDim = bias.size(-1) - lastDimAligned = 16 * ((lastDim + 15) // 16) + lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 new_sizes = list(bias.size()) new_sizes[-1] = lastDimAligned grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device) + grad_bias = grad_bias[..., :lastDim] else: grad_bias = torch.empty((), device=query.device) From 11f11e95df9c205d427fe4dd7e63c9adb91ea03f Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Mon, 27 Nov 2023 15:14:37 +0800 Subject: [PATCH 201/221] [Quant] [Inductor] Fix an issue in QConv Binary Pattern Match (#114541) **Summary** Add the `extra_check` in `_register_quantized_conv_binary_lowering` to skip the pattern which matched unexpected. To match a Conv-Binary pattern, we should expect the extra input of binary node comes from a dequant pattern instead of a constant scalar. **Test Plan** ``` python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_add_2 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114541 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 ghstack dependencies: #114540 --- test/inductor/test_mkldnn_pattern_matcher.py | 46 ++++++++++++++++++++ torch/_inductor/fx_passes/quantization.py | 29 ++++++++++++ 2 files changed, 75 insertions(+) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index d2528fa86499..032557fed8eb 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -644,6 +644,52 @@ def test_qconv2d_add_relu_cpu(self): def test_qconv2d_add_relu_int8_mixed_bf16(self): self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_add_2(self): + r""" + This testcase prevents this pattern be matched as a conv_binary fusion by mistake. + Conv(X) 3 + \ / + Add + We see this pattern in Mobilenet v3 large which add is decomposed from torch.nn.Hardswish or torch.nn.Hardsigmoid. + """ + + class M(torch.nn.Module): + def __init__( + self, + post_op, + ): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.post_op = post_op + + def forward(self, x): + return self.post_op(self.conv(x)) + + for post_op in [ + torch.nn.Hardswish(inplace=True), + torch.nn.Hardsigmoid(inplace=True), + ]: + mod = M(post_op).eval() + v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( + 1 + ) + + def matcher_check_fn(): + # Shouldn't hit conv binary fusion + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], 0 + ) + + self._test_common( + mod, + (v,), + check_quantization=True, + matcher_check_fn=matcher_check_fn, + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 2bb4824dc95f..7a41f97638e7 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -389,6 +389,34 @@ def qlinear(match: Match, *args, **kwargs): return qlinear +def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype): + # Check if it's a valid Conv Binary Pattern: + # * qconv2d_pointwise should only has one users + # * Extra input of binary node comes from dequant pattern + def fn(match): + qconv2d_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv2d_pointwise + )[0] + if len(qconv2d_node_after_weight_prepack.users) != 1: + return False + if output_dtype is not None: + binary_node_inputs = list(qconv2d_node_after_weight_prepack.users)[0].args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + extra_input_node = None + for arg in binary_node_inputs: + if arg != qconv2d_node_after_weight_prepack: + extra_input_node = arg + break + assert extra_input_node is not None + if (not isinstance(extra_input_node, torch.fx.Node)) or ( + extra_input_node.target != aten.mul.Tensor + ): + return False + return True + + return fn + + def _register_quantized_conv_binary_lowering( pattern, pass_number, @@ -398,6 +426,7 @@ def _register_quantized_conv_binary_lowering( ): @register_lowering_pattern( pattern, + extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype), pass_number=pass_number, ) def qconv_binary(match: Match, *args, **kwargs): From 0de67e7949b899a90911f7249d5ccece34e46211 Mon Sep 17 00:00:00 2001 From: "Liao, Xuan" Date: Tue, 28 Nov 2023 04:03:20 +0000 Subject: [PATCH 202/221] [cpu] Modify inductor opt flag (#113347) Fixes https://github.com/pytorch/pytorch/issues/113014, https://github.com/pytorch/pytorch/issues/113012, https://github.com/pytorch/pytorch/issues/93598. For CPU inductor path, remove `-funsafe-math-optimizations` from optimization flags to fix functional issues. ### Validation on 3 benchmark suites **FP32** image - No accuracy problem - Slight geomean perf drop - 3 outlier models (speed up < 0.8). Could be solved by adding vectorizations later. **BF16** image - No accuracy problem - Slight geomean perf drop - 4 outlier models (speed up < 0.8). Could be solved by adding vectorizations later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113347 Approved by: https://github.com/jgong5, https://github.com/desertfire --- torch/_inductor/codecache.py | 2 ++ torch/_inductor/config.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index ce674a82c28a..b4738def9de1 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1183,6 +1183,8 @@ def cpp_wrapper_flags() -> str: def optimization_flags() -> str: base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG" base_flags += " -ffast-math -fno-finite-math-only" + if not config.cpp.enable_unsafe_math_opt_flag: + base_flags += " -fno-unsafe-math-optimizations" if config.is_fbcode(): # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies. diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5f91c96af38c..a3f0e4b5a3eb 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -417,6 +417,9 @@ class cpp: # using atomic_add. fallback_scatter_reduce_sum = True + # Use funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = False + # config specific to codegen/triton.py class triton: From 2f875c74bfb3c4d8af95b63455e4a9e31efee8c3 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 28 Nov 2023 04:38:13 +0000 Subject: [PATCH 203/221] Print ghcr docker pull during build/test (#114510) To make debugging easier to external devs Test plan: Copy and run command from [`Use the following to pull public copy of the image`](https://github.com/pytorch/pytorch/actions/runs/7012511180/job/19077533416?pr=114510#step:6:9): ``` docker pull ghcr.io/pytorch/ci-image:pytorch-linux-jammy-py3.8-gcc11-0d0042fd2e432ea07301ad6f6a474d36a581f0dc ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114510 Approved by: https://github.com/atalman, https://github.com/huydhn --- .github/workflows/_linux-build.yml | 9 +++++++++ .github/workflows/_linux-test.yml | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 895507ff40ea..9a88ed70b7f2 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -93,6 +93,15 @@ jobs: with: docker-image-name: ${{ inputs.docker-image-name }} + - name: Use following to pull public copy of the image + id: print-ghcr-mirror + env: + ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + shell: bash + run: | + tag=${ECR_DOCKER_IMAGE##*/} + echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" + - name: Pull docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 685c05580611..4bd0e38a0f75 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -77,6 +77,15 @@ jobs: with: docker-image-name: ${{ inputs.docker-image }} + - name: Use following to pull public copy of the image + id: print-ghcr-mirror + env: + ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + shell: bash + run: | + tag=${ECR_DOCKER_IMAGE##*/} + echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" + - name: Pull docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: From 8933ff35953908545cb410d736a6c4a3a7761c54 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 28 Nov 2023 05:03:34 +0000 Subject: [PATCH 204/221] Make torch::jit::module movable (#114041) This PR makes torch::jit::module movable to improve performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114041 Approved by: https://github.com/huydhn --- torch/csrc/jit/api/module.h | 2 ++ torch/csrc/jit/api/object.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 83d46e0ce205..6c49b695cb6b 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -90,6 +90,8 @@ struct TORCH_API Module : public Object { Module() = default; Module(const Module&) = default; Module& operator=(const Module&) = default; + Module(Module&&) noexcept = default; + Module& operator=(Module&&) noexcept = default; Module( c10::QualifiedName, std::shared_ptr cu, diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 6a3d38a8292d..7ccacf385be5 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -25,6 +25,8 @@ struct TORCH_API Object { Object() = default; Object(const Object&) = default; Object& operator=(const Object&) = default; + Object(Object&&) noexcept = default; + Object& operator=(Object&&) noexcept = default; Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {} Object(std::shared_ptr cu, const c10::ClassTypePtr& type); Object( From 6636c2b1787f8e7933be92ac57cb4b6a28d36246 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Tue, 28 Nov 2023 05:41:33 +0000 Subject: [PATCH 205/221] [executorch hash update] update the pinned executorch hash (#114648) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/_update-commit-hash.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114648 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 4d5c9db5b489..5853ef2cffa6 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -1f584d77e610b98ed38138fce9922f9f4b7d9e21 +5159de436ced71c78bc1c22e3c1d93654c429227 From 3f574eadb4d8a4c9cf9eb2fcd91a2944f3555886 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Tue, 28 Nov 2023 06:29:40 +0000 Subject: [PATCH 206/221] [dynamo / DDP] - lazily compile submodules - to propagate real tensor strides to backend compiler (#114154) Fixes https://github.com/pytorch/pytorch/issues/113812, https://github.com/pytorch/pytorch/issues/102591, Probably fixes: https://github.com/pytorch/pytorch/issues/113740, https://github.com/pytorch/pytorch/issues/113786, https://github.com/pytorch/pytorch/issues/113788 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114154 Approved by: https://github.com/wconstab --- test/distributed/test_dynamo_distributed.py | 41 ++++++ torch/_dynamo/backends/distributed.py | 141 +++++++------------- 2 files changed, 90 insertions(+), 92 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 82d4248fb6cb..1547e595c924 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -544,6 +544,7 @@ def test_ddp_baseline_inductor(self): @patch.object(config, "optimize_ddp", True) def test_graph_split(self): + assert config.optimize_ddp """ Just ensures that the appropriate number of splits happen (based on bucket size and model parameters) - verifies the number of times @@ -625,6 +626,7 @@ def opt_fn(inputs): @patch.object(config, "optimize_ddp", True) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor(self): + assert config.optimize_ddp """ Same as above, but using inductor backend. We observed issues with inductor/fx interface in the past. @@ -639,6 +641,45 @@ def opt_fn(inputs): opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) + @torch._inductor.config.patch({"layout_optimization": True, "keep_output_stride": False}) + @patch.object(config, "optimize_ddp", True) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_graph_split_inductor_layout_optimizations(self): + assert config.optimize_ddp + channel_dim = 512 + # channel dim must be > 64 for inductor to do layout optimization and use NHWC + + class ToyModelConv(nn.Module): + def __init__(self): + super().__init__() + self.net = nn.Sequential( + *[nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()] + + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()] + + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()] + + [nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU()] + ) + + def forward(self, inputs): + return self.net(inputs) + + def get_model(): + m = ToyModelConv().to(self.device) + m.apply(init_weights) + inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device) + outputs = m(inputs) + return m, inputs, outputs + + m, inputs, correct_outputs = get_model() + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) + + @torch._dynamo.optimize("inductor") + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + + @patch.object(config, "optimize_ddp", True) def test_no_split(self): """ diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index 85e2e066290a..46cb70cefe4d 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -6,7 +6,8 @@ import torch from torch import fx from torch._dynamo.output_graph import GraphCompileReason -from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode +from torch._dynamo.utils import detect_fake_mode +from torch._subclasses.fake_tensor import is_fake from torch.fx.node import Node # Regular log messages should go through 'log'. @@ -217,23 +218,6 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]): and returns its callable. """ - # Today, optimize_ddp=True and keep_output_stride=False can lead to silent - # correctness issues. The problem is that ddp_optimizer works by partitioning - # the dynamo graph, sending each subgraph through aot autograd to inductor, - # and creates example inputs by eagerly interpreting each subgraph to get - # an output that with the same metadata that we'd get from eager mode. - # This is a problem though, for torch._inductor.config.keep_output_stride. - # The above config can cause the outputs of the first graph to have - # **different** strides from eager, causing the inputs that we pass - # to the second graph to be wrong. - # To really fix this, we would need to faithfully ask inductor - # what the outputs to each graph it expects are. - assert torch._inductor.config.keep_output_stride, """\ -Detected that you are running DDP with torch.compile, along with these two flags: -- torch._dynamo.config.optimize_ddp = True -- torch._inductor.config.keep_output_stride = False -This combination of flags is incompatible. Please set keep_output_stride to False, -or file a github issue.""" fake_mode = detect_fake_mode(example_inputs) if fake_mode is None: fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() @@ -332,32 +316,54 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]): debug_str += "\n---------------\n" ddp_graph_log.debug(debug_str) - # 3: compile each of the partitioned submodules using the user-provided compiler - class SubmodCompiler(torch.fx.interpreter.Interpreter): + # 3: Replace submodules with lazily compiling submodule + class SubmoduleReplacer(torch.fx.interpreter.Interpreter): def __init__(self, module, compiler): super().__init__(module) self.compiler = compiler - def compile_submod(self, input_mod, args, kwargs): + def lazily_compiled_submod(self, input_mod): """ - Compile the submodule, - using a wrapper to make sure its output is always a tuple, - which is required by AotAutograd based compilers + Create a wrapper around submodules which: + - lazily compiles each of the partitioned submodules using the user-provided compiler + - unpacks singleton tuples/lists into flat arg """ - assert len(kwargs) == 0, "We assume only args for these modules" - class WrapperModule(torch.nn.Module): - def __init__(self, submod, unwrap_singleton_tuple): + class LazilyCompiledModule(torch.nn.Module): + def __init__(self, submod, compiler, unwrap_singleton_tuple): super().__init__() self.submod = submod + self.compiler = compiler + self.compiled = False self.unwrap_singleton_tuple = unwrap_singleton_tuple def forward(self, *args): + if not self.compiled: + assert ( + fake_mode + ), "fake mode must have been available when creating lazy submod" + fake_args = [] + for arg in args: + if isinstance(arg, torch.Tensor) and not is_fake(arg): + fake_args.append( + torch._dynamo.utils.to_fake_tensor( + arg, fake_mode + ) + ) + else: + fake_args.append(arg) + # First trace with fake args + new_submod = self.compiler(self.submod, tuple(fake_args)) + del self.submod + self.submod = new_submod + self.compiled = True + self.compiler = None + x = self.submod(*args) - # TODO(whc) - # for some reason the isinstance check is necessary if I split one node per submod - # - even though I supposedly wrapped the output in a tuple in those cases, the real - # compiled module was still returning a tensor + # we must let 'input_mod' return a tuple, to make AOT happy. + # (aot_autograd compile_fn literally requires that the output of a graph it compiles is a tuple). + # however, we don't acutally want this tuple to be returned, since the fx logic that calls the submod + # will again wrap outputs from the submod in a tuple. So we unwrap it, and count on it being re-wrapped if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)): return x[0] return x @@ -378,84 +384,35 @@ def forward(self, *args): traceback.FrameSummary(__file__, 0, DDPOptimizer), ], ) - wrapper = WrapperModule( - self.compiler(input_mod, args), + wrapper = LazilyCompiledModule( + input_mod, + self.compiler, unwrap_singleton_tuple, ) return wrapper - # Note: - # - # The way distributed works today around fake tensors can be somewhat confusing. - # Some of these codepaths are shared in both runtime, and compile time. The presence - # of a fake_mode, read off of fake tensor inputs, dictates how we will operate. - # - # A few things to keep in mind: - # - # 1) We invoke `compile_submod` with a real module. The output of that gets stored - # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`. - # - # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the - # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it. - # - # 3) Fake tensors should always be around during compile time. - # - # 4) Fake tensors should never be around at runtime. - # - # 5) We end up with a compilation mode that takes a real submodule and fake tensors, - # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd] + # We replace the submodules with lazy submodules which compile + # the corresponding submodules when they are run with real values + # Always returns `None` - we do not need to propagate values in order + # to replace submodules. def run_node(self, n: Node) -> Any: - args, kwargs = self.fetch_args_kwargs_from_env(n) - new_args = [] - assert fake_mode - for arg in args: - if isinstance(arg, torch.Tensor) and not isinstance( - arg, torch._subclasses.FakeTensor - ): - new_args.append( - torch._dynamo.utils.to_fake_tensor(arg, fake_mode) - ) - else: - new_args.append(arg) - - log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args)) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - if n.op == "call_module": real_mod = self.fetch_attr(n.target) - if fake_mode: - curr_submod = deepcopy_to_fake_tensor(real_mod, fake_mode) - else: - curr_submod = real_mod ddp_graph_log.debug( - "\n---%s graph---\n%s", n.target, curr_submod.graph + "\n---%s graph---\n%s", n.target, real_mod.graph ) - # When calling the compiler on the submod, inputs (new_args) are expected to - # be FakeTensors already since Dynamo would have made them FakeTensors in the - # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors, - # since this wrapping happens during compilation - compiled_submod_real = self.compile_submod( - real_mod, new_args, kwargs - ) + assert len(n.kwargs) == 0, "We assume only args for these modules" + lazily_compiled_submod = self.lazily_compiled_submod(real_mod) # We update the original (outer) graph with a call into the compiled module # instead of the uncompiled one. self.module.delete_submodule(n.target) n.target = "compiled_" + n.target - self.module.add_submodule(n.target, compiled_submod_real) - - # Finally, we have to produce inputs for use compiling the next submodule, - # and these need to be FakeTensors, so we execute the module under fake_mode - with fake_mode: - return curr_submod(*new_args, **kwargs) - else: - # placeholder or output nodes don't need to get compiled, just executed - return getattr(self, n.op)(n.target, new_args, kwargs) + self.module.add_submodule(n.target, lazily_compiled_submod) - submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn) + submod_compiler = SubmoduleReplacer(split_gm, self.backend_compile_fn) submod_compiler.run(*example_inputs) split_gm.recompile() From 71b742b42c12f9d90d1def70fc414735497492eb Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Mon, 27 Nov 2023 18:16:18 -0800 Subject: [PATCH 207/221] [inductor] Remove more type: ignore comments (#114162) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114162 Approved by: https://github.com/Skylion007, https://github.com/eellison --- torch/_dynamo/guards.py | 6 +-- torch/_inductor/codegen/common.py | 60 ++++++++++++--------- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/codegen/cuda/cuda_kernel.py | 9 ++-- torch/_inductor/codegen/triton.py | 2 +- torch/_inductor/lowering.py | 2 +- 6 files changed, 47 insertions(+), 34 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 0ef173155e2f..47a65f703643 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -18,7 +18,7 @@ import types import weakref from inspect import currentframe, getframeinfo -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from weakref import ReferenceType @@ -368,7 +368,7 @@ def EQUALS_MATCH(self, guard: Guard): val = self.get(guard.name) t = type(val) if np: - np_types = ( + np_types: Tuple[Type[Any], ...] = ( np.int8, np.int16, np.int32, @@ -382,7 +382,7 @@ def EQUALS_MATCH(self, guard: Guard): np.float64, ) else: - np_types = () # type: ignore[assignment] + np_types = () ok_types = ( int, float, diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index be949e8f92a9..0b6fa9732bc0 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -289,28 +289,6 @@ def all_in_parens(string): return string return f"({string})" - def _print_Pow(self, expr): - # Pow() confuses triton - base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) # type: ignore[attr-defined] - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) # type: ignore[attr-defined] - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" - def _print_Infinity(self, expr): return "math.inf" @@ -329,8 +307,11 @@ def _print_Add(self, expr): def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + def _print_CleanDiv(self, expr): - return self._print_FloorDiv(expr) # type: ignore[attr-defined] + return self._print_FloorDiv(expr) def _print_GreaterThan(self, expr): # GreaterThan: >= @@ -362,6 +343,28 @@ def _print_FloorDiv(self, expr): def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" + def _print_Pow(self, expr): + # Pow() confuses triton + base, exp = expr.args + # NB: Remember this is sizevar computation! You don't typically + # expect to have to do floating point computation including exponents + # in sizevar compute. Instead of adding support for floating + # point pow, you should make upstream retranslate the Sympy expression + # into Tensor expressions earlier and do that instead. + if exp == 0.5: + return self._helper_sqrt(base) + elif exp == -0.5: + return "1/" + self._helper_sqrt(base) + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + if exp > 0: + return "*".join([self.paren(base)] * exp) + elif exp < 0: + return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + return "1" + def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" @@ -982,6 +985,13 @@ def bucketize( """ raise NotImplementedError() + @property + def assert_function(self) -> str: + raise NotImplementedError() + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError() + def __enter__(self): class CSEProxy: self.name = "CSEProxy" @@ -1055,14 +1065,14 @@ def indirect_indexing(var, size, check=True): self.compute.writeline( IndirectAssertLine( line, - self.assert_function, # type: ignore[attr-defined] + self.assert_function, var, mask, self.indirect_max_sizes, ) ) - self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) # type: ignore[attr-defined] + self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) return sympy_symbol(str(var)) @staticmethod diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 525bec5374a6..2054dd284b5e 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1502,7 +1502,7 @@ def codegen_loops(self, code, worksharing): self.codegen_loops_impl(loop_nest, code, worksharing) @property - def assert_function(self): + def assert_function(self) -> str: return "TORCH_CHECK" def decide_parallel_depth(self, ranges, threads): diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index cf5eea9484ae..c365590be59a 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, TYPE_CHECKING from ... import ir from ...autotune_process import CUDABenchmarkRequest @@ -11,6 +11,9 @@ from ..common import IndentedBuffer, Kernel, OpOverrides from ..cpp import CppPrinter, DTYPE_TO_CPP +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + log = logging.getLogger(__name__) cexpr = CppPrinter().doprint @@ -134,7 +137,7 @@ def def_kernel( return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" def call_kernel( - self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] # type: ignore[name-defined] + self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] ) -> None: """ Generates code to call the kernel through V.graph.wrapper_code. @@ -295,7 +298,7 @@ def __init__( layout: Layout, make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str], bmreq: CUDABenchmarkRequest, - template: "CUDATemplate", # type: ignore[name-defined] + template: "CUDATemplate", ): super().__init__(name, input_nodes, layout) self.category = category diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0f08f728330f..2e9707420846 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1268,7 +1268,7 @@ def load_mask(self, var): return mask @property - def assert_function(self): + def assert_function(self) -> str: return "tl.device_assert" def get_strides_of_load(self, index: sympy.Expr): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 3f237cc43dd3..d7808c55b63e 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -445,7 +445,7 @@ def group_args(arg_pairs): device = None for t in args: if isinstance(t, TensorBox): - device = t.data.get_device() # type: ignore[attr-defined] + device = t.data.get_device() break assert ( device is not None From 5cfa0647a7702248f0c1be08a59cc8b1349677da Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Mon, 27 Nov 2023 18:16:18 -0800 Subject: [PATCH 208/221] Update mypy to 1.7.0 (#114160) It appears that `mypy` is now checking a few more previously-unchecked files; these files are being found via import-following. Not sure exactly why they weren't being checked before. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114160 Approved by: https://github.com/eellison ghstack dependencies: #114162 --- .ci/docker/requirements-ci.txt | 4 ++-- .lintrunner.toml | 2 +- mypy-inductor.ini | 2 +- torch/ao/pruning/_experimental/pruner/prune_functions.py | 4 ++-- torch/ao/quantization/fuser_method_mappings.py | 4 ++-- torch/autograd/profiler_util.py | 6 +++--- torch/distributed/fsdp/_optim_utils.py | 2 +- torch/utils/_cxx_pytree.py | 6 +++--- torch/utils/_pytree.py | 4 ++-- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 25be26621985..6405e7aa8726 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.6.0 +mypy==1.7.0 # Pin MyPy version because new errors are likely to appear with each release #Description: linter -#Pinned versions: 1.6.0 +#Pinned versions: 1.7.0 #test that import: test_typing.py, test_type_hints.py networkx==2.8.8 diff --git a/.lintrunner.toml b/.lintrunner.toml index a991cb8d4c4e..16df3e1cb32b 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -165,7 +165,7 @@ init_command = [ 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.1.6', - 'mypy==1.6.0', + 'mypy==1.7.0', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', 'types-tabulate==0.8.8', diff --git a/mypy-inductor.ini b/mypy-inductor.ini index 91403e0e2cdd..ea95d844e564 100644 --- a/mypy-inductor.ini +++ b/mypy-inductor.ini @@ -1,7 +1,7 @@ [mypy] plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin -cache_dir = .mypy_cache/nofollow +cache_dir = .mypy_cache/inductor allow_redefinition = True warn_unused_configs = True warn_redundant_casts = True diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index c4c94e0887ad..8278ec642e9d 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -2,7 +2,7 @@ Collection of conversion functions for linear / conv2d structured pruning Also contains utilities for bias propagation """ -from typing import cast, Optional, Callable, Tuple +from typing import cast, List, Optional, Callable, Tuple import torch from torch import nn, Tensor @@ -13,7 +13,7 @@ # BIAS PROPAGATION def _remove_bias_handles(module: nn.Module) -> None: if hasattr(module, "_forward_hooks"): - bias_hooks = [] + bias_hooks: List[int] = [] for key, hook in module._forward_hooks.items(): if isinstance(hook, BiasHook): bias_hooks.append(key) diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index ede5a45cbe14..7381c0571415 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -1,7 +1,7 @@ import torch.nn as nn import torch.ao.nn.intrinsic as nni -from typing import Union, Callable, Tuple, Dict, Optional, Type +from typing import Any, Union, Callable, List, Tuple, Dict, Optional, Type from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode import itertools @@ -231,7 +231,7 @@ def _get_valid_patterns(op_pattern): (MatchAllNode, (MatchAllNode, MatchAllNode)), ] """ - result = [] + result: List[Any] if isinstance(op_pattern, (tuple, list)): sub_combs = [] for sub_pattern in op_pattern: diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index fc2379b5ea3a..5e35b8604b86 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -774,11 +774,11 @@ class MemRecordsAcc: def __init__(self, mem_records): self._mem_records = mem_records - self._start_uses = [] - self._indices = [] + self._start_uses: List[int] = [] + self._indices: List[int] = [] if len(mem_records) > 0: tmp = sorted([(r[0].start_us(), i) for i, r in enumerate(mem_records)]) - self._start_uses, self._indices = zip(*tmp) + self._start_uses, self._indices = zip(*tmp) # type: ignore[assignment] def in_interval(self, start_us, end_us): start_idx = bisect.bisect_left(self._start_uses, start_us) diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 04c8d81b8bea..4e179d75aeb3 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -340,7 +340,7 @@ def _broadcast_processed_state( if fsdp_state.rank == 0: objects[0] = tree_map_only( torch.Tensor, - lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr] optim_state, ) dist.broadcast_object_list(objects, src=0, group=group) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 392c0e2688db..06309499ec49 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -65,9 +65,9 @@ Context = Optional[Any] PyTree = Any TreeSpec = PyTreeSpec -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[Iterable, Context], PyTree] -OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree] +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree] def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc: diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index f74d4a76e5b8..6821a3acb495 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -77,8 +77,8 @@ Context = Any PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[Iterable, Context], PyTree] +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] DumpableContext = Any # Any json dumpable text ToDumpableContextFn = Callable[[Context], DumpableContext] FromDumpableContextFn = Callable[[DumpableContext], Context] From 06abac971a8fc5cc24550d957a74dfd4729f7d6a Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 27 Nov 2023 10:29:25 -0800 Subject: [PATCH 209/221] [FSDP] Simplified FSDP wrapping in ignored module test (#114611) This saves some verbosity. There is no change to functionality. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114611 Approved by: https://github.com/wanchaol --- .../fsdp/test_fsdp_ignored_modules.py | 33 +++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_ignored_modules.py b/test/distributed/fsdp/test_fsdp_ignored_modules.py index 0acbbb8043d0..44fef4c2f369 100644 --- a/test/distributed/fsdp/test_fsdp_ignored_modules.py +++ b/test/distributed/fsdp/test_fsdp_ignored_modules.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import functools import sys import torch @@ -228,33 +229,17 @@ def _test_ignored_modules_nested( # sequential's second linear layer (`layer1[1]`) and then wraps the # overall model while ignoring the nested sequential (`layer1`) model = Model().cuda() - model.layer1[1] = ( - FSDP(model.layer1[1], use_orig_params=use_orig_params) - if not composable - else fully_shard(model.layer1[1]) + fsdp_fn = ( + fully_shard + if composable + else functools.partial(FSDP, use_orig_params=use_orig_params) ) + model.layer1[1] = fsdp_fn(model.layer1[1]) if ignore_modules: - wrapped_model = ( - FSDP( - model, - ignored_modules=[model.layer1], - use_orig_params=use_orig_params, - ) - if not composable - else fully_shard(model, ignored_modules=[model.layer1]) - ) + wrapped_model = fsdp_fn(model, ignored_modules=[model.layer1]) else: - wrapped_model = ( - FSDP( - model, - ignored_states=[model.layer1], - use_orig_params=use_orig_params, - ) - if not composable - else fully_shard( - model, - ignored_states=[model.layer1], - ) + wrapped_model = fsdp_fn( + model, ignored_states=list(model.layer1.parameters()) ) # Check that the wrapped model's flattened parameter does not include # the ignored nested sequential's parameters From 8a35a68bb75b22570bce327bb282c797a7c8dd8a Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Tue, 28 Nov 2023 10:55:20 +0800 Subject: [PATCH 210/221] [Quant] Enable QConv2d with hardtanh post op (#114578) **Summary** Enable QConv2d implementation with post op `hardtanh` **Test Plan** ``` python -m pytest test_quantized_op.py -k test_qconv2d_hardtanh_pt2e ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114578 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 --- aten/src/ATen/native/quantized/cpu/qconv.cpp | 22 ++++++-- test/quantization/core/test_quantized_op.py | 55 ++++++++++++++++++++ 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 149d3defcedc..63ba67e03940 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1630,7 +1630,23 @@ static at::Tensor _quantized_convolution_onednn( dst.set_scale(accum_ideep_scale); dst.set_zero_point(accum_ideep_zero_points); } else { - op_attr = (has_unary_post_op && unary_attr.value()=="relu") ? ideep::attr_t::fuse_relu() : ideep::attr_t(); + if (has_unary_post_op && unary_attr.value()=="relu") { + op_attr = ideep::attr_t::fuse_relu(); + } else if (has_unary_post_op && unary_attr.value()=="hardtanh") { + TORCH_CHECK( + unary_scalars.size() == 2 && + unary_scalars[0].get().toOptional().has_value() && + unary_scalars[1].get().toOptional().has_value(), + "hardtanh is expected to have two scalar input: min_val and max_val"); + + auto lower_bound_value = + unary_scalars[0].get().toOptional().value().to(); + auto upper_bound_value = + unary_scalars[1].get().toOptional().value().to(); + op_attr = ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value); + } else { + op_attr = ideep::attr_t(); + } } // Weight Reorder @@ -1821,8 +1837,8 @@ class QConvoneDNN final { } else { // Conv2D post op check TORCH_CHECK( - attr == "none" || attr == "relu", - "none post_op or post_op relu is supported for quantized pointwise conv2d. Got unary_post_op: ", + attr == "none" || attr == "relu" || attr == "hardtanh", + "none post_op or post_op relu/hardtanh is supported for quantized pointwise conv2d. Got unary_post_op: ", attr, ".") } diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 3c294211c494..128f1ff2c403 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -6552,6 +6552,11 @@ def _test_qconv_impl_cpu_tensor( assert not use_transpose, "Cannot fuse ReLU with ConvTranspose" relu = torch.nn.ReLU() result_ref = relu(result_ref) + elif post_op.unary_attr == "hardtanh": + assert not use_transpose, "Cannot fuse hardtanh with ConvTranspose" + assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in" + hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1]) + result_ref = hardtanh(result_ref) # Quantize reference results for comparison result_ref_q = torch.quantize_per_tensor( @@ -6894,6 +6899,56 @@ def test_qconv2d_relu_pt2e(self): qconv_output_dtype=output_dtype, ) + # Test qconv with post op hardtanh + @skipIfNoONEDNN + def test_qconv2d_hardtanh_pt2e(self): + input_channels_per_group = 2 + output_channels_per_group = 2 + groups_list = [1, 10] + input_feature_map_shape = (10, 10) + kernels = (3, 3) + strides = (2, 2) + pads = (1, 1) + dilations = (1, 1) + W_scale = [1.5] + W_zero_point = [0] + use_bias_list = [False, True] + use_channelwise_list = [False, True] + output_dtype_list = [None, torch.float32, torch.bfloat16] + options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) + for groups, use_bias, use_channelwise, output_dtype in options: + qconv = torch.ops.onednn.qconv2d_pointwise + qconv_prepack = torch.ops.onednn.qconv_prepack + conv_op = torch.nn.Conv2d( + input_channels_per_group * groups, + output_channels_per_group * groups, + kernels, + strides, + pads, + dilations, + groups, + ) + pointwise_post_op = PointwisePostOp(unary_attr="hardtanh", scalars=[0.0, 6.0]) + self._test_qconv_impl_cpu_tensor( + qconv, + qconv_prepack, + conv_op, + input_channels_per_group=input_channels_per_group, + input_feature_map_shape=input_feature_map_shape, + output_channels_per_group=output_channels_per_group, + groups=groups, + kernels=kernels, + strides=strides, + pads=pads, + dilations=dilations, + W_scale=W_scale, + W_zero_point=W_zero_point, + use_bias=use_bias, + post_op=pointwise_post_op, + use_channelwise=use_channelwise, + qconv_output_dtype=output_dtype, + ) + # Test qconv with post op add @skipIfNoONEDNN def test_qconv2d_add_pt2e(self): From 8c1f65dc2ba54bcb6766413c28ee558cd047ce83 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Tue, 28 Nov 2023 10:55:21 +0800 Subject: [PATCH 211/221] [Quant] [PT2] Add Hardtanh and ReLU6 into X86InductorQuantizer Conv2d Unary Annotation (#114579) **Summary** Add `Hardtanh` and `ReLU6` into X86InductorQuantizer Conv2d Unary Annotation **TestPlan** ``` python -m pytest test_x86inductor_quantizer.py -k test_conv2d_unary python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_unary ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114579 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 ghstack dependencies: #114578 --- .../pt2e/test_x86inductor_quantizer.py | 42 ++++++++++++------- .../quantizer/x86_inductor_quantizer.py | 30 ++++++++++--- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 1158d72c5bbb..7bb9fbb08607 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -42,11 +42,11 @@ def forward(self, x): x = self.bn(x) return x - class Conv2dReLUModule(torch.nn.Module): - def __init__(self, inplace_relu: bool = False, use_bias: bool = False, with_bn=False) -> None: + class Conv2dUnaryModule(torch.nn.Module): + def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None: super().__init__() self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias) - self.relu = nn.ReLU(inplace=inplace_relu) + self.post_op = post_op self.bn = torch.nn.BatchNorm2d(6) self.with_bn = with_bn @@ -54,7 +54,7 @@ def forward(self, x): x = self.conv(x) if self.with_bn: x = self.bn(x) - x = self.relu(x) + x = self.post_op(x) return x class Conv2dAddModule(torch.nn.Module): @@ -358,14 +358,20 @@ def test_conv2d(self): @skipIfNoX86 def test_conv2d_unary(self): """ - Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. - Currently, only relu as unary post op is supported. + Test pattern of conv2d with unary post ops (such as relu, hardtanh, relu6) with X86InductorQuantizer. """ - inplace_relu_list = [True, False] + unary_map = { + "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], + "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], + "hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default], + "hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default], + "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], + "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default] + } use_bias_list = [True, False] with override_quantized_engine("x86"), torch.no_grad(): - for inplace_relu, use_bias in itertools.product(inplace_relu_list, use_bias_list): - m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, use_bias=use_bias).eval() + for unary_op, use_bias in itertools.product(unary_map.keys(), use_bias_list): + m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], use_bias=use_bias).eval() example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() @@ -382,7 +388,7 @@ def test_conv2d_unary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, - torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default, + unary_map[unary_op][1], ] self._test_quantizer( m, @@ -1026,10 +1032,18 @@ def test_qat_conv2d_unary(self): Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. Currently, only relu as unary post op is supported. """ - inplace_relu_list = [True, False] + unary_map = { + "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], + "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], + "hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default], + "hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default], + "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], + "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default] + } + with override_quantized_engine("x86"): - for inplace_relu in itertools.product(inplace_relu_list): - m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, with_bn=True) + for unary_op in unary_map.keys(): + m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], with_bn=True) example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) @@ -1048,7 +1062,7 @@ def test_qat_conv2d_unary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, - torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default, + unary_map[unary_op][1], torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 69ad2c7af004..da8fd9782da3 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -546,9 +546,18 @@ def _annotate_qat_conv2d_bn_binary( def _annotate_qat_conv2d_bn_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - fused_partitions = find_sequential_partitions( - gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU] - ) + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + for fused_partition in fused_partitions: conv_partition, bn_partition, unary_partition = fused_partition ( @@ -715,9 +724,18 @@ def _annotate_conv2d_binary( def _annotate_conv2d_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - fused_partitions = find_sequential_partitions( - gm, [torch.nn.Conv2d, torch.nn.ReLU] - ) + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.ReLU6], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + for fused_partition in fused_partitions: conv_partition, unary_partition = fused_partition conv_node, unary_node = self._get_output_nodes_of_partitions( From 95aec251aa9aec4502658dc7599c0a4d6a4eff73 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Tue, 28 Nov 2023 10:55:22 +0800 Subject: [PATCH 212/221] [Quant] [Inductor] Enable the Inductor Lowering of QConv2d post op hardtanh (#114580) **Summary** Enable the fusion pattern of `QConv2d -> hardtanh` lowering to `hardtanh` as `QConv2d` post operator. **Test Plan** ``` python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_relu6_cpu python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_hardtanh_cpu python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d_relu6 python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d_hardtanh ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114580 Approved by: https://github.com/jgong5, https://github.com/jerryzh168 ghstack dependencies: #114578, #114579 --- test/inductor/test_mkldnn_pattern_matcher.py | 91 +++++++++++++++----- torch/_inductor/fx_passes/quantization.py | 29 ++++++- 2 files changed, 95 insertions(+), 25 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 032557fed8eb..9d0ec862beef 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import contextlib +import copy import itertools import unittest @@ -495,7 +496,11 @@ def test_qconv2d_int8_mixed_bf16(self): """ self._qconv2d_cpu_test_helper(int8_mixed_bf16=True) - def _qconv2d_unary_cpu_test_helper(self, int8_mixed_bf16=False): + def _qconv2d_unary_cpu_test_helper( + self, + int8_mixed_bf16=False, + unary_op=torch.nn.ReLU(), + ): class M(torch.nn.Module): def __init__( self, @@ -503,9 +508,9 @@ def __init__( ): super().__init__() self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) - self.unary_fn = torch.nn.ReLU() + self.unary_fn = copy.deepcopy(unary_op) self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) - self.unary_fn2 = torch.nn.ReLU() + self.unary_fn2 = copy.deepcopy(unary_op) def forward(self, x): tmp = self.unary_fn(self.conv(x)) @@ -549,6 +554,24 @@ def test_qconv2d_relu_int8_mixed_bf16(self): """ self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_relu6_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern. + """ + self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_hardtanh_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) + def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): r""" This testcase will quantize a Conv2d->Add pattern as: @@ -735,26 +758,26 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfRocm - def test_qat_qconv2d_relu(self): - r""" - This testcase will quantize Conv2d->ReLU pattern with qat flow. - """ - + def _qat_qconv2d_unary_cpu_test_helper( + self, + unary_op=torch.nn.ReLU(), + ): class M(torch.nn.Module): def __init__( self, **kwargs, ): super().__init__() - self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) - self.unary_fn = torch.nn.ReLU() - self.bn = torch.nn.BatchNorm2d(128) + self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.bn = torch.nn.BatchNorm2d(3) + self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.unary_fn2 = copy.deepcopy(unary_op) + self.bn2 = torch.nn.BatchNorm2d(3) def forward(self, x): - return self.unary_fn(self.bn(self.conv(x))) + tmp = self.unary_fn(self.bn(self.conv(x))) + return self.unary_fn2(self.bn2(self.conv2(tmp))) mod = M() v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) @@ -763,15 +786,11 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 - ) - self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6 + counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1) - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 8) + self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2) self._test_common( mod, @@ -781,6 +800,36 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_relu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper() + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_relu6(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_hardtanh(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) + @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 7a41f97638e7..66429373f536 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -169,10 +169,17 @@ def generate_pattern_with_binary( def generate_pattern_with_unary(computation_call, unary_post_op): if unary_post_op is not None: - return CallFunction( - unary_post_op, - computation_call, - ) + if unary_post_op == aten.hardtanh.default: + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + else: + return CallFunction( + unary_post_op, + computation_call, + ) return computation_call @@ -286,6 +293,11 @@ def qconv(match: Match, *args, **kwargs): assert ( kwargs["attr"] == "none" ) # Expected no post op fused in weight prepack phase + if unary_attr.op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + unary_attr.scalars_attr = [min_value, max_value] + computation_args = ( x, x_scale, @@ -506,6 +518,12 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): ), dtype=original_pattern_output_dtype, ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + dequantize_qconv_pt2e_pattern, aten.hardtanh.default + ), + dtype=original_pattern_output_dtype, + ), } for unary_attr, patterns in conv_unary_replace_patterns.items(): @@ -524,6 +542,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): UnaryAttr("relu", [], ""): generate_pattern_with_unary( dequantize_qconv_pt2e_pattern, aten.relu.default ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary( + dequantize_qconv_pt2e_pattern, aten.hardtanh.default + ), } for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): From 00412e6dfacdec5a8508c841b3ea0846388dc872 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 28 Nov 2023 07:40:15 +0000 Subject: [PATCH 213/221] [export] Add meta to params (#114622) The graph from `capture_pre_autograd_graph` doesn't have `meta["val"]` on the param nodes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114622 Approved by: https://github.com/frank-wei, https://github.com/zhxchen17, https://github.com/khabinov --- torch/_dynamo/eval_frame.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 7fff2c3392fc..b2510e6eb82e 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1422,6 +1422,14 @@ def graph_with_interpreter(*args): case_name="cond_operands", ) + for node in graph.graph.nodes: + if node.op == "get_attr" and isinstance( + getattr(graph, node.target), torch.Tensor + ): + node.meta["val"] = fake_mode.from_tensor( + getattr(graph, node.target), static_shapes=True + ) + if same_signature: flat_args_dynamic_dims = [ {c.dim for c in (constraints or ()) if c.w_tensor() is x} From 088fc7779eeb4c7690ce6df0cdd390f53b287cff Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Mon, 27 Nov 2023 15:22:01 +0000 Subject: [PATCH 214/221] Eliminate unnecessary copy in CUDA addmm with sparse compressed block operand (#114484) As in the title. As a result, `nn.linear(, , bias=)` performance increases as follows (`float16`, `NVIDIA A100-SXM4-80GB`): - 256x256 weights, speed up is 14..27 % - 512x512 weights, speed up is 9..25 % - 1024x1024 weights, speed up is 5..20 % - 2048x2048 weights, speed up is 3..16 % - 4092x4092 weights, speed up is 2..9 % Pull Request resolved: https://github.com/pytorch/pytorch/pull/114484 Approved by: https://github.com/cpuhrsch --- .../ATen/native/sparse/cuda/SparseBlas.cpp | 27 ++++++--- .../native/sparse/cuda/SparseBlasImpl.cpp | 57 +++++++++++++------ .../ATen/native/sparse/cuda/SparseBlasImpl.h | 1 + 3 files changed, 59 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp index 297e4b601cbc..6cac383ac60c 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseBlas.cpp @@ -126,13 +126,12 @@ Tensor& addmm_out_sparse_compressed_cuda( "x", self_->size(1)); - if (&result != &self) { + if (!result.is_same(self)) { if (result.layout() == kStrided) { at::native::resize_output(result, self_->sizes()); } else { result.resize_as_sparse_(*self_); } - result.copy_(*self_); } if (result.numel() == 0) { @@ -142,15 +141,21 @@ Tensor& addmm_out_sparse_compressed_cuda( if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) { // According to docs, when beta==0 values in self should be ignored. // nans and infs should not propagate - if (beta.toComplexDouble() == 0.) { + const auto beta_val = beta.toComplexDouble(); + if (beta_val == 0.) { result.zero_(); } else { - result.mul_(beta); + if (!result.is_same(self)) { + result.copy_(*self_); + } + if (beta_val != 1.) { + result.mul_(beta); + } } return result; } - sparse::impl::cuda::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result); + sparse::impl::cuda::addmm_out_sparse_csr(*self_, mat1, mat2, beta, alpha, result); return result; } @@ -167,9 +172,8 @@ Tensor& baddbmm_out_sparse_csr_cuda( TORCH_CHECK(mat2.layout() == kStrided, "torch.baddbmm: Expect mat2 to be strided, but got ", mat2.layout()); TORCH_CHECK(result.layout() == kStrided, "torch.baddbmm: Expect result to be strided, but got ", result.layout()); - if (&result != &self) { + if (!result.is_same(self)) { at::native::resize_output(result, self.sizes()); - result.copy_(self); } if (mat1._nnz() == 0) { @@ -178,12 +182,17 @@ Tensor& baddbmm_out_sparse_csr_cuda( if (beta.toComplexDouble() == 0.) { result.zero_(); } else { - result.mul_(beta); + if (!result.is_same(self)) { + result.copy_(self); + } + if (beta.toComplexDouble() != 1.) { + result.mul_(beta); + } } return result; } - sparse::impl::cuda::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result); + sparse::impl::cuda::addmm_out_sparse_csr(self, mat1, mat2, beta, alpha, result); return result; } diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp index 2309a40d6555..408d25b79f27 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp @@ -464,6 +464,7 @@ void block_sparse_mv( } void block_sparse_mm( + const Tensor& input, const at::sparse_csr::SparseCsrTensor& mat1, const Tensor& mat2, const Scalar& beta, @@ -486,7 +487,7 @@ void block_sparse_mm( // especially for not very sparse inputs. if (mat1.scalar_type() == ScalarType::Half || mat1.scalar_type() == ScalarType::BFloat16) { at::native::sparse::impl::_compressed_row_strided_addmm_out( - result, + input, mat1, mat2, /*beta=*/beta, @@ -497,6 +498,10 @@ void block_sparse_mm( return; } + if (beta.toComplexDouble() != 0. && !result.is_same(input)) { + result.copy_(input); + } + const cusparseDirection_t block_layout = mat1.values().is_contiguous() ? CUSPARSE_DIRECTION_ROW : CUSPARSE_DIRECTION_COLUMN; @@ -838,6 +843,7 @@ void spgemm( } // anonymous namespace void addmm_out_sparse_csr( + const Tensor& input, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, @@ -853,6 +859,39 @@ void addmm_out_sparse_csr( // Valid combinations terminate in a return // Invalid combinations are omitted and will fall though to the TORCH check // generating an informative error message + + // mm functions that copy input to result when needed (e.g. mm + // triton kernels do not require result being initialized with + // input): + if (mat1.layout() == kSparseBsr) { + if (mat2.layout() == kStrided) { + if (result.layout() == kStrided) + return block_sparse_mm(input, mat1, mat2, beta, alpha, result); + } + } + + if (mat1.layout() == kStrided) { + if (mat2.layout() == kSparseBsc) { + if (result.layout() == kStrided) { + auto result_t = result.transpose(-2, -1); + auto input_t = (result.is_same(input) ? result_t : input.transpose(-2, -1)); + return block_sparse_mm( + input_t, + mat2.transpose(-2, -1), + mat1.transpose(-2, -1), + beta, + alpha, + result_t); + } + } + } + + // copy input to result: + if (beta.toComplexDouble() != 0. && !result.is_same(input)) { + result.copy_(input); + } + + // mm functions that assume that result contains input: if (mat1.layout() == kStrided) { if (mat2.layout() == kSparseCsr) { if (result.layout() == kStrided) { @@ -875,16 +914,6 @@ void addmm_out_sparse_csr( result.transpose(-2, -1)); } } - if (mat2.layout() == kSparseBsc) { - if (result.layout() == kStrided) { - return block_sparse_mm( - mat2.transpose(-2, -1), - mat1.transpose(-2, -1), - beta, - alpha, - result.transpose(-2, -1)); - } - } } if (mat1.layout() == kSparseCsr) { if (mat2.layout() == kStrided) { @@ -933,12 +962,6 @@ void addmm_out_sparse_csr( } } } - if (mat1.layout() == kSparseBsr) { - if (mat2.layout() == kStrided) { - if (result.layout() == kStrided) - return block_sparse_mm(mat1, mat2, beta, alpha, result); - } - } TORCH_CHECK( false, "addmm: computation on CUDA is not implemented for ", diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h index 4bd7281dea3a..b2bae735dfd6 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h +++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h @@ -11,6 +11,7 @@ namespace impl { namespace cuda { void addmm_out_sparse_csr( + const Tensor& input, const at::sparse_csr::SparseCsrTensor& mat1, const Tensor& mat2, const Scalar& beta, From 89a1fe69667c16b76de1cd87b8dda8ffd77762d3 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 28 Nov 2023 15:27:07 +0800 Subject: [PATCH 215/221] [pytree] register pytree node type in both C++ pytree and Python pytree (#112111) Changes: 1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree. 2. Do not allow registering a type as pytree node twice in the Python pytree. 3. Add thread lock to the Python pytree node register API. 4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning. 5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations. 6. Add tests to ensure a warning will be raised when the old private function is called. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111 Approved by: https://github.com/zou3519 --- test/export/test_export.py | 15 +- test/test_fx.py | 2 +- test/test_pytree.py | 71 ++++++- torch/_export/utils.py | 4 +- torch/_functorch/aot_autograd.py | 15 +- torch/fx/experimental/proxy_tensor.py | 2 +- torch/fx/immutable_collections.py | 6 +- .../_internal/fx/dynamo_graph_extractor.py | 13 +- torch/return_types.py | 2 +- torch/utils/_cxx_pytree.py | 201 +++++++++++++++++- torch/utils/_pytree.py | 144 ++++++++++--- 11 files changed, 410 insertions(+), 65 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index caa576bfa987..27e44f27aea7 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -624,16 +624,23 @@ class MyDataClass: roundtrip_spec = treespec_loads(treespec_dumps(spec)) self.assertEqual(roundtrip_spec, spec) + @dataclass + class MyOtherDataClass: # the pytree registration don't allow registering the same class twice + x: int + y: int + z: int = None + # Override the registration with keep none fields - register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass") + register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass") + dt = MyOtherDataClass(x=3, y=4) flat, spec = tree_flatten(dt) self.assertEqual( spec, TreeSpec( - MyDataClass, + MyOtherDataClass, ( - MyDataClass, + MyOtherDataClass, ['x', 'y', 'z'], [], ), @@ -643,7 +650,7 @@ class MyDataClass: self.assertEqual(flat, [3, 4, None]) orig_dt = tree_unflatten(flat, spec) - self.assertTrue(isinstance(orig_dt, MyDataClass)) + self.assertTrue(isinstance(orig_dt, MyOtherDataClass)) self.assertEqual(orig_dt.x, 3) self.assertEqual(orig_dt.y, 4) self.assertEqual(orig_dt.z, None) diff --git a/test/test_fx.py b/test/test_fx.py index 8de7c3dd6a9c..fa63e79cb46a 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3529,7 +3529,7 @@ def f_dict_add(x): def f_namedtuple_add(x): return x.x + x.y - pytree._register_pytree_node( + pytree.register_pytree_node( Foo, lambda x: ([x.a, x.b], None), lambda x, _: Foo(x[0], x[1]), diff --git a/test/test_pytree.py b/test/test_pytree.py index 0c0120397eea..ab96a9e1f3ce 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -1,7 +1,7 @@ # Owner(s): ["module: pytree"] import unittest -from collections import namedtuple, OrderedDict +from collections import namedtuple, OrderedDict, UserDict import torch import torch.utils._cxx_pytree as cxx_pytree @@ -26,6 +26,45 @@ def __init__(self, x, y): class TestGenericPytree(TestCase): + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_register_pytree_node(self, pytree_impl): + class MyDict(UserDict): + pass + + d = MyDict(a=1, b=2, c=3) + + # Custom types are leaf nodes by default + values, spec = pytree_impl.tree_flatten(d) + self.assertEqual(values, [d]) + self.assertIs(values[0], d) + self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + self.assertTrue(spec.is_leaf()) + + # Register MyDict as a pytree node + pytree_impl.register_pytree_node( + MyDict, + lambda d: (list(d.values()), list(d.keys())), + lambda values, keys: MyDict(zip(keys, values)), + ) + + values, spec = pytree_impl.tree_flatten(d) + self.assertEqual(values, [1, 2, 3]) + self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + + # Do not allow registering the same type twice + with self.assertRaisesRegex(ValueError, "already registered"): + pytree_impl.register_pytree_node( + MyDict, + lambda d: (list(d.values()), list(d.keys())), + lambda values, keys: MyDict(zip(keys, values)), + ) + @parametrize( "pytree_impl", [ @@ -407,6 +446,28 @@ def test_pytree_serialize_bad_input(self, pytree_impl): class TestPythonPytree(TestCase): + def test_deprecated_register_pytree_node(self): + class DummyType: + def __init__(self, x, y): + self.x = x + self.y = y + + with self.assertWarnsRegex( + UserWarning, "torch.utils._pytree._register_pytree_node" + ): + py_pytree._register_pytree_node( + DummyType, + lambda dummy: ([dummy.x, dummy.y], None), + lambda xs, _: DummyType(*xs), + ) + + with self.assertWarnsRegex(UserWarning, "already registered"): + py_pytree._register_pytree_node( + DummyType, + lambda dummy: ([dummy.x, dummy.y], None), + lambda xs, _: DummyType(*xs), + ) + def test_treespec_equality(self): self.assertTrue( py_pytree.LeafSpec() == py_pytree.LeafSpec(), @@ -540,7 +601,7 @@ def __init__(self, x, y): self.x = x self.y = y - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -560,7 +621,7 @@ def __init__(self, x, y): self.x = x self.y = y - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -585,7 +646,7 @@ def __init__(self, x, y): with self.assertRaisesRegex( ValueError, "Both to_dumpable_context and from_dumpable_context" ): - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -599,7 +660,7 @@ def __init__(self, x, y): self.x = x self.y = y - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), diff --git a/torch/_export/utils.py b/torch/_export/utils.py index afee8efc5946..d8344783a0a3 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node( flatten_fn: Optional[FlattenFunc] = None, unflatten_fn: Optional[UnflattenFunc] = None, *, + serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, - serialized_type_name: Optional[str] = None, return_none_fields: bool = False, ) -> None: assert dataclasses.is_dataclass( cls ), f"Only dataclasses can be registered with this function: {cls}" - serialized_type = f"{cls.__module__}.{cls.__name__}" + serialized_type = f"{cls.__module__}.{cls.__qualname__}" SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index db83c84e8a6b..ff48fd2bb1b3 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -29,7 +29,7 @@ from torch._subclasses import FakeTensor, FakeTensorMode from torch._subclasses.fake_tensor import is_fake from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode -from torch.fx import immutable_collections, Interpreter +from torch.fx import Interpreter from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types from torch.fx.experimental.symbolic_shapes import ( ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq @@ -95,19 +95,6 @@ def strict_zip(*iterables, strict=True, **kwargs): ) ) -pytree._register_pytree_node( - immutable_collections.immutable_list, - lambda x: (list(x), None), - lambda x, c: immutable_collections.immutable_list(x), -) -pytree._register_pytree_node( - immutable_collections.immutable_dict, - lambda x: (list(x.values()), list(x.keys())), - lambda x, c: immutable_collections.immutable_dict( - dict(zip(c, x)) - ), -) - def partial_asdict(obj: Any) -> Any: if dataclasses.is_dataclass(obj): return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)} diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index dd3520f541aa..e3d8bd673a4d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -49,7 +49,7 @@ # We currently convert all SymInt to proxies before we use them. # This could plausibly be handled at the Dynamo level. -pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs)) +pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs)) def fake_signature(fn, nargs): """FX gets confused by varargs, de-confuse it""" diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 616555015f0e..a359335f6ece 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Iterable, List, Tuple from ._compatibility import compatibility -from torch.utils._pytree import Context, _register_pytree_node +from torch.utils._pytree import Context, register_pytree_node __all__ = ["immutable_list", "immutable_dict"] @@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A return immutable_list(values) -_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) +register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) +register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index f55afefd1bbd..79a690f5f48a 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -40,7 +40,11 @@ def __init__(self): def __enter__(self): for class_type, (flatten_func, unflatten_func) in self._extensions.items(): - pytree._register_pytree_node(class_type, flatten_func, unflatten_func) + pytree._private_register_pytree_node( + class_type, + flatten_func, + unflatten_func, + ) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -93,8 +97,11 @@ def model_output_unflatten( # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. named_model_output_classes = inspect.getmembers( modeling_outputs, - lambda x: inspect.isclass(x) - and issubclass(x, modeling_outputs.ModelOutput), + lambda x: ( + inspect.isclass(x) + and issubclass(x, modeling_outputs.ModelOutput) + and x is not modeling_outputs.ModelOutput + ), ) for _, class_type in named_model_output_classes: diff --git a/torch/return_types.py b/torch/return_types.py index 9f8c85285279..b1284c813387 100644 --- a/torch/return_types.py +++ b/torch/return_types.py @@ -13,7 +13,7 @@ def structseq_flatten(structseq): def structseq_unflatten(values, context): return cls(values) - torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten) + torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten) for name in dir(return_types): if name.startswith('__'): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 06309499ec49..6e55c21a511c 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -13,6 +13,7 @@ """ import functools +import warnings from typing import ( Any, Callable, @@ -26,6 +27,11 @@ Union, ) +import torch + +if torch._running_with_deploy(): + raise ImportError("C++ pytree utilities do not work with torch::deploy.") + import optree from optree import PyTreeSpec # direct import for type annotations @@ -35,6 +41,9 @@ "Context", "FlattenFunc", "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", "TreeSpec", "LeafSpec", "register_pytree_node", @@ -68,6 +77,9 @@ FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc: @@ -84,9 +96,11 @@ def register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, namespace: str = "torch", ) -> None: - """Extend the set of types that are considered internal nodes in pytrees. + """Register a container-like type as pytree node. The ``namespace`` argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix @@ -109,6 +123,13 @@ def register_pytree_node( The function should return an instance of ``cls``. serialized_type_name (str, optional): A keyword argument used to specify the fully qualified name used when serializing the tree spec. + to_dumpable_context (callable, optional): An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable representation. This is + used for json serialization, which is being used in :mod:`torch.export` right now. + from_dumpable_context (callable, optional): An optional keyword argument to custom specify + how to convert the custom json dumpable representation of the context back to the + original context. This is used for json deserialization, which is being used in + :mod:`torch.export` right now. namespace (str, optional): A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type. (default: :const:`"torch"`) @@ -193,24 +214,192 @@ def register_pytree_node( ) ) """ - from ._pytree import _register_pytree_node + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + namespace=namespace, + ) + + from . import _pytree as python - _register_pytree_node( + python._private_register_pytree_node( cls, flatten_fn, unflatten_fn, serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, ) - optree.register_pytree_node( + +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + namespace: str = "torch", +) -> None: + """Register a container-like type as pytree node for the C++ pytree only. + + The ``namespace`` argument is used to avoid collisions that occur when different libraries + register the same Python type with different behaviors. It is recommended to add a unique prefix + to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify + the same class in different namespaces for different use cases. + + .. warning:: + For safety reasons, a ``namespace`` must be specified while registering a custom type. It is + used to isolate the behavior of flattening and unflattening a pytree node type. This is to + prevent accidental collisions between different libraries that may register the same type. + + Args: + cls (type): A Python type to treat as an internal pytree node. + flatten_fn (callable): A function to be used during flattening, taking an instance of + ``cls`` and returning a pair, with (1) an iterable for the children to be flattened + recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be + passed to the ``unflatten_fn``. + unflatten_fn (callable): A function taking two arguments: the auxiliary data that was + returned by ``flatten_fn`` and stored in the treespec, and the unflattened children. + The function should return an instance of ``cls``. + serialized_type_name (str, optional): A keyword argument used to specify the fully + qualified name used when serializing the tree spec. + to_dumpable_context (callable, optional): An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable representation. This is + used for json serialization, which is being used in :mod:`torch.export` right now. + from_dumpable_context (callable, optional): An optional keyword argument to custom specify + how to convert the custom json dumpable representation of the context back to the + original context. This is used for json deserialization, which is being used in + :mod:`torch.export` right now. + namespace (str, optional): A non-empty string that uniquely identifies the namespace of the + type registry. This is used to isolate the registry from other modules that might + register a different custom behavior for the same type. (default: :const:`"torch"`) + + Example:: + + >>> # xdoctest: +SKIP + >>> # Registry a Python type with lambda functions + >>> register_pytree_node( + ... set, + ... lambda s: (sorted(s), None, None), + ... lambda children, _: set(children), + ... namespace='set', + ... ) + + >>> # xdoctest: +SKIP + >>> # Register a Python type into a namespace + >>> import torch + >>> register_pytree_node( + ... torch.Tensor, + ... flatten_func=lambda tensor: ( + ... (tensor.cpu().detach().numpy(),), + ... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, + ... ), + ... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata), + ... namespace='torch2numpy', + ... ) + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))} + >>> tree + {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])} + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # Flatten without specifying the namespace + >>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP + ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *})) + + >>> # xdoctest: +SKIP + >>> # Flatten with the namespace + >>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP + ( + [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)], + PyTreeSpec( + { + 'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]), + 'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]) + }, + namespace='torch2numpy' + ) + ) + + >>> # xdoctest: +SKIP + >>> # Register the same type with a different namespace for different behaviors + >>> def tensor2flatparam(tensor): + ... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None + ... + >>> def flatparam2tensor(children, metadata): + ... return children[0].reshape(metadata) + ... + >>> register_pytree_node( + ... torch.Tensor, + ... flatten_func=tensor2flatparam, + ... unflatten_func=flatparam2tensor, + ... namespace='tensor2flatparam', + ... ) + + >>> # xdoctest: +SKIP + >>> # Flatten with the new namespace + >>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP + ( + [ + Parameter containing: tensor([0., 0.], requires_grad=True), + Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True) + ], + PyTreeSpec( + { + 'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]), + 'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*]) + }, + namespace='tensor2flatparam' + ) + ) + """ + warnings.warn( + "torch.utils._cxx_pytree._register_pytree_node is deprecated. " + "Please use torch.utils._cxx_pytree.register_pytree_node instead.", + stacklevel=2, + ) + + _private_register_pytree_node( cls, flatten_fn, - _reverse_args(unflatten_fn), + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, namespace=namespace, ) -_register_pytree_node = register_pytree_node +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + namespace: str = "torch", +) -> None: + """This is an internal function that is used to register a pytree node type + for the C++ pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support + # PyStructSequence types + if not optree.is_structseq_class(cls): + optree.register_pytree_node( + cls, + flatten_fn, + _reverse_args(unflatten_fn), + namespace=namespace, + ) def tree_flatten( diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 6821a3acb495..4e085121ef41 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -17,6 +17,7 @@ import dataclasses import json +import threading import warnings from collections import deque, namedtuple, OrderedDict from typing import ( @@ -99,6 +100,7 @@ class NodeDef(NamedTuple): unflatten_fn: UnflattenFunc +_NODE_REGISTRY_LOCK = threading.Lock() SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} @@ -120,6 +122,63 @@ class _SerializeNodeDef(NamedTuple): SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} +def register_pytree_node( + cls: Any, + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + try: + from . import _cxx_pytree as cxx + except ImportError: + pass + else: + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + def _register_pytree_node( cls: Any, flatten_fn: FlattenFunc, @@ -131,7 +190,8 @@ def _register_pytree_node( to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, ) -> None: - """ + """Register a container-like type as pytree node for the Python pytree only. + Args: cls: the type to register flatten_fn: A callable that takes a pytree and returns a flattened @@ -150,39 +210,73 @@ def _register_pytree_node( back to the original context. This is used for json deserialization, which is being used in torch.export right now. """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + if to_str_fn is not None or maybe_from_str_fn is not None: warnings.warn( "to_str_fn and maybe_from_str_fn is deprecated. " "Please use to_dumpable_context and from_dumpable_context instead." ) - node_def = NodeDef( + _private_register_pytree_node( cls, flatten_fn, unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, ) - SUPPORTED_NODES[cls] = node_def - if (to_dumpable_context is None) ^ (from_dumpable_context is None): - raise ValueError( - f"Both to_dumpable_context and from_dumpable_context for {cls} must " - "be None or registered." - ) - if serialized_type_name is None: - serialized_type_name = f"{cls.__module__}.{cls.__name__}" +def _private_register_pytree_node( + cls: Any, + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) - serialize_node_def = _SerializeNodeDef( - cls, - serialized_type_name, - to_dumpable_context, - from_dumpable_context, - ) - SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def - SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + node_def = NodeDef( + cls, + flatten_fn, + unflatten_fn, + ) + SUPPORTED_NODES[cls] = node_def + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + if serialized_type_name is None: + serialized_type_name = f"{cls.__module__}.{cls.__qualname__}" -register_pytree_node = _register_pytree_node + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: @@ -243,25 +337,25 @@ def _odict_unflatten( return OrderedDict((key, value) for key, value in zip(context, values)) -_register_pytree_node( +_private_register_pytree_node( dict, _dict_flatten, _dict_unflatten, serialized_type_name="builtins.dict", ) -_register_pytree_node( +_private_register_pytree_node( list, _list_flatten, _list_unflatten, serialized_type_name="builtins.list", ) -_register_pytree_node( +_private_register_pytree_node( tuple, _tuple_flatten, _tuple_unflatten, serialized_type_name="builtins.tuple", ) -_register_pytree_node( +_private_register_pytree_node( namedtuple, _namedtuple_flatten, _namedtuple_unflatten, @@ -269,7 +363,7 @@ def _odict_unflatten( from_dumpable_context=_namedtuple_deserialize, serialized_type_name="collections.namedtuple", ) -_register_pytree_node( +_private_register_pytree_node( OrderedDict, _odict_flatten, _odict_unflatten, @@ -729,7 +823,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: if treespec.type not in SUPPORTED_SERIALIZED_TYPES: raise NotImplementedError( - f"Serializing {treespec.type} in pytree is not registered." + f"Serializing {treespec.type} in pytree is not registered.", ) serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] From 0bb2600c2885f1d16bb5bc57bae5bab9c3f99eba Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 28 Nov 2023 08:18:50 +0000 Subject: [PATCH 216/221] Allow to differentiate through NumPy code (#114608) With this PR it is possible to differentiate through NumPy code modulo the usual caveats that apply to differentiation: - That there are no graphbreaks - That the decomposition in `torch._numpy` is differentiable @ev-br and I were somewhat careful to achieve the second point, but it is not tested though and through, so YMMV Pull Request resolved: https://github.com/pytorch/pytorch/pull/114608 Approved by: https://github.com/voznesenskym --- test/dynamo/test_misc.py | 2 ++ test/inductor/test_torchinductor.py | 26 ++++++++++++++++++++++++++ torch/_dynamo/variables/tensor.py | 19 +++++++++---------- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 61928d4abd84..945bbe4078f9 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1692,6 +1692,7 @@ def fn(x): opt_fn = torch._dynamo.optimize(cnts)(fn) x = torch.randn(3) res = opt_fn(x) + self.assertEqual(type(res), np.ndarray) self.assertEqual(cnts.frame_count, 1) def fn(x): @@ -1701,6 +1702,7 @@ def fn(x): opt_fn = torch._dynamo.optimize(cnts)(fn) x = torch.randn(3, requires_grad=True) res = opt_fn(x) + self.assertEqual(type(res), np.ndarray) self.assertEqual(cnts.frame_count, 1) def test_numpy_recompilation_scalar(self): diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 5eaff6d7fbdf..82b68b37db0a 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -8166,6 +8166,32 @@ def fn_cuda(x): self.assertEqual(type(r), np.ndarray) self.assertEqual(r, np.sin(x)) + def test_numpy_autograd(self): + def my_torch(x): + y = torch.cat([torch.sin(x) ** 2, torch.max(x)[None]]) + return y.sum() + + def my_np(x): + y = np.concatenate([np.sin(x) ** 2, np.max(x)[None]]) + return np.sum(y) + + @torch.compile + def wrapper(x): + x = x.numpy() + y = my_np(x) + return torch.as_tensor(y) + + x_np = torch.arange(8, dtype=torch.float32, requires_grad=True) + x = torch.arange(8, dtype=torch.float32, requires_grad=True) + + out_np = wrapper(x_np) + out = my_torch(x) + self.assertEqual(out, out_np) + + out_np.backward() + out.backward() + self.assertEqual(x.grad, x_np.grad) + # Disable constant propagation, so we isolate value range analysis @patch.object(config, "constant_and_index_propagation", False) @patch.object(config, "joint_graph_constant_folding", False) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index ca8d34988d56..d12872a9201e 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -523,19 +523,18 @@ def make_const_size_variable(x, **options): f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first" ) # We don't check that the tensor is on CPU when force is False, as this - # allows us to execute NumPy code on CUDA. - # We don't check that requires_grad=False as we are currently doing an - # unconditional detach. - # TODO: We may want to avoid detaching if `requires_grad=True` - # and `force=False` to allow computing gradients. + # allows us to execute NumPy code on CUDA. Same for requires_grad=True force = "force" in kwargs and kwargs["force"].as_python_constant() - proxy = tx.output.create_proxy( - "call_method", "detach", *proxy_args_kwargs([self], {}) - ) if force: - # TODO Add resolve_conj and resolve_neg once we support complex tensors + # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...) + t = self.call_method(tx, "detach", [], {}) + proxy = tx.output.create_proxy( + "call_method", "cpu", (t.as_proxy(),), {} + ) + else: + # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable proxy = tx.output.create_proxy( - "call_method", "cpu", *proxy_args_kwargs([self], {}) + "call_method", "view_as", *proxy_args_kwargs([self, self], {}) ) return NumpyNdarrayVariable.create(tx, proxy) elif name == "tolist": From 79ee99e6d2454612cd8c3fc93f468a8830bbb8fc Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 28 Nov 2023 08:18:50 +0000 Subject: [PATCH 217/221] [easy] Dispatch torch.from_numpy to torch.as_tensor (#114609) ...rather than detaching the tensor Pull Request resolved: https://github.com/pytorch/pytorch/pull/114609 Approved by: https://github.com/larryliu0820, https://github.com/voznesenskym ghstack dependencies: #114608 --- test/inductor/test_torchinductor.py | 14 +++++++++++++- torch/_dynamo/variables/torch.py | 29 ++++++++++------------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 82b68b37db0a..70f95bd0fb8b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -8181,17 +8181,29 @@ def wrapper(x): y = my_np(x) return torch.as_tensor(y) + @torch.compile + def wrapper2(x): + x = x.numpy() + y = my_np(x) + return torch.from_numpy(y) + x_np = torch.arange(8, dtype=torch.float32, requires_grad=True) x = torch.arange(8, dtype=torch.float32, requires_grad=True) - out_np = wrapper(x_np) out = my_torch(x) self.assertEqual(out, out_np) + x2_np = torch.arange(8, dtype=torch.float32, requires_grad=True) + out2_np = wrapper2(x2_np) + self.assertEqual(out, out2_np) + out_np.backward() out.backward() self.assertEqual(x.grad, x_np.grad) + out2_np.backward() + self.assertEqual(x.grad, x2_np.grad) + # Disable constant propagation, so we isolate value range analysis @patch.object(config, "constant_and_index_propagation", False) @patch.object(config, "joint_graph_constant_folding", False) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e20ba3abee9f..ab4a9e0eec70 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -386,25 +386,16 @@ def call_function( unimplemented("torch.from_numpy. config.trace_numpy is False") if not np: unimplemented("torch.from_numpy. NumPy is not available") - assert len(args) == 1, f"Got arguments {args}" - assert not kwargs - t = args[0] - from .tensor import NumpyNdarrayVariable - - if isinstance(t, NumpyNdarrayVariable): - # TODO: mark the tensor as non-resizable - return wrap_fx_proxy_cls( - target_cls=TensorVariable, - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - torch.detach, - *proxy_args_kwargs(args, {}), - ), - example_value=None, - ) - else: - unimplemented(f"torch.from_numpy(<{type(t)}>)") + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.as_tensor, + *proxy_args_kwargs(args, {}), + ), + example_value=None, + ) elif can_dispatch_torch_function(tx, args, kwargs): return dispatch_torch_function(tx, self, args, kwargs) elif self.value is torch.autograd._profiler_enabled: From cc7a969bb38e434848ec7a5187ef6f4d97886092 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 27 Nov 2023 15:07:45 -0800 Subject: [PATCH 218/221] [FSDP] Added test for `ignored_states` + auto wrap (#114612) This adds some unit testing for the `ignored_states` argument and auto wrapping. There is some ongoing discussion with @erhoo82 about his particular use case, but it should not block this PR. (We can land a separate PR if needed.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114612 Approved by: https://github.com/wanchaol ghstack dependencies: #114611 --- .../fsdp/test_fsdp_ignored_modules.py | 56 ++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/test/distributed/fsdp/test_fsdp_ignored_modules.py b/test/distributed/fsdp/test_fsdp_ignored_modules.py index 44fef4c2f369..dc3eee5761b4 100644 --- a/test/distributed/fsdp/test_fsdp_ignored_modules.py +++ b/test/distributed/fsdp/test_fsdp_ignored_modules.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import functools +import math import sys import torch @@ -10,7 +11,7 @@ from torch.distributed._composable import fully_shard from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import _get_module_fsdp_state -from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( CUDAInitMode, @@ -92,6 +93,10 @@ def __init__(self, num_ignored: int) -> None: class TestFSDPIgnoredModules(FSDPTest): + @property + def world_size(self): + return min(torch.cuda.device_count(), 2) + def _train_model(self, model, optim, num_iters, device=torch.device("cuda")): for _ in range(num_iters): module = model.module if isinstance(model, FSDP) else model @@ -270,6 +275,55 @@ def _test_ignored_modules_nested( optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) self._train_model(wrapped_model, optim, 3) + @skip_if_lt_x_gpu(2) + def test_ignored_states_auto_wrap(self): + transformer_policy = functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls={nn.Sequential} + ) + self.run_subtests( + { + "policy": [transformer_policy, ModuleWrapPolicy((nn.Sequential,))], + "ignore_bias": [True, False], + }, + self._test_ignored_states_auto_wrap, + ) + + def _test_ignored_states_auto_wrap(self, policy, ignore_bias: bool): + model = Model().cuda() + ignored_states = [model.layer1[1].weight] + if ignore_bias: + ignored_states.append(model.layer1[1].bias) + # Construct 2 flat parameters: one for `layer1` and one for the model + fsdp_model = FSDP( + model, + # Use `False` to avoid complexity of intra-flat-parameter padding + use_orig_params=False, + auto_wrap_policy=policy, + ignored_states=ignored_states, + ) + ref_model = Model() + expected_layer1_unsharded_numel = ( + sum(p.numel() for p in ref_model.layer1.parameters()) + - ref_model.layer1[1].weight.numel() + ) + if ignore_bias: + expected_layer1_unsharded_numel -= ref_model.layer1[1].bias.numel() + expected_model_unsharded_numel = sum( + p.numel() for p in ref_model.parameters() + ) - sum(p.numel() for p in ref_model.layer1.parameters()) + expected_layer1_sharded_numel = math.ceil( + expected_layer1_unsharded_numel / self.world_size + ) + expected_model_sharded_numel = math.ceil( + expected_model_unsharded_numel / self.world_size + ) + self.assertLessEqual( + fsdp_model.layer1.module._flat_param.numel(), expected_layer1_sharded_numel + ) + self.assertLessEqual( + fsdp_model.module._flat_param.numel(), expected_model_sharded_numel + ) + @skip_if_lt_x_gpu(2) @parametrize("composable", [True, False]) def test_ignored_modules_invalid(self, composable): From 0bef97fac35bdffc15a4c142bae6a5516799f862 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 27 Nov 2023 18:23:05 +0000 Subject: [PATCH 219/221] [dynamo] Support itertools.groupby (#114192) Summary: for https://github.com/pytorch/pytorch/issues/108698 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114192 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 30 ++++++++++++++++++ torch/_dynamo/variables/misc.py | 54 +++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 945bbe4078f9..a734fe4a0828 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -7995,6 +7995,36 @@ def fn(it): self.assertEqual(list(eager), list(compiled)) self.assertEqual(counter.frame_count, 1) + def test_itertools_groupby_pure_python_default_identify_func(self): + counters.clear() + + def fn(l): + return [(k, list(g)) for k, g in itertools.groupby(l)] + + l = [1, 2, 2, 3, 4, 4, 4, 1, 2] + eager = fn(l) + + compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) + compiled = compiled_fn(l) + + self.assertEqual(eager, compiled) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_itertools_groupby_pure_python_key_func(self): + counters.clear() + + def fn(l): + return [(k, list(g)) for k, g in itertools.groupby(l, key=operator.neg)] + + l = [1, 2, -2, 3, 4, 4, -4, 0, -2] + eager = fn(l) + + compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) + compiled = compiled_fn(l) + + self.assertEqual(eager, compiled) + self.assertEqual(len(counters["graph_break"]), 0) + def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c2b1bedce853..5aca0349762a 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -848,6 +848,60 @@ def call_function( for item in itertools.combinations(iterable, r): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) + elif self.value is itertools.groupby: + if any(kw != "key" for kw in kwargs.keys()): + unimplemented( + "Unsupported kwargs for itertools.groupby: " + f"{','.join(set(kwargs.keys()) - {'key'})}" + ) + + def retrieve_const_key(key): + if isinstance(key, variables.SymNodeVariable): + return key.evaluate_expr() + elif isinstance(key, variables.ConstantVariable): + return key.as_python_constant() + else: + raise unimplemented( + "Unsupported key type for itertools.groupby: " + str(type(key)) + ) + + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + keyfunc = ( + ( + lambda x: ( + retrieve_const_key( + kwargs.get("key").call_function(tx, [x], {}) + ) + ) + ) + if "key" in kwargs + else None + ) + else: + unimplemented("Unsupported arguments for itertools.groupby") + + result = [] + try: + for k, v in itertools.groupby(seq, key=keyfunc): + result.append( + variables.TupleVariable( + [ + variables.ConstantVariable.create(k) + if variables.ConstantVariable.is_literal(k) + else k, + variables.ListIteratorVariable( + list(v), mutable_local=MutableLocal() + ), + ], + mutable_local=MutableLocal(), + ) + ) + except Exception: + raise unimplemented( # noqa: TRY200 + "Unexpected failure when calling itertools.groupby" + ) + return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) elif ( self.value is functools.wraps and not kwargs From b060694088bf90f66d667b274c56854c8426ee3a Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:21:58 +0000 Subject: [PATCH 220/221] Add `bits` dtypes to `torch._C` stubs (#114661) As defined https://github.com/pytorch/pytorch/blob/6ae0554d11b973930d7b8ec1e937b27ac961d7bf/c10/core/ScalarType.h#L54-L58 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114661 Approved by: https://github.com/ngimel --- tools/pyi/gen_pyi.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 208003438f58..eb8a997ac855 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1264,6 +1264,11 @@ def replace_special_case(hint: str) -> str: "bool", "quint4x2", "quint2x4", + "bits1x8", + "bits2x4", + "bits4x2", + "bits8", + "bits16", ] ] From e6a8052051e0ade18deaa5758d7327c6ab6d9fa3 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 27 Nov 2023 16:55:13 -0800 Subject: [PATCH 221/221] [C10D] Flight recorder - disable c++ stacktrace by default (#114651) CPP Stacktrace processing (symbolizer) takes a long time on some systems using a particular version of addr2line. In slow systems, this makes flight-recorder dumping slow enough to time out on even toy programs. TORCH_NCCL_TRACE_CPP_STACK=True will re-enable CPP stacktrace collection as part of the flight recorder. CPP stacktrace is fast enough for use on certain combinations of OS. We can investigate moving to llvm's symbolizer as a replacement. On devserver with C++ stacktraces disabled/enabled: ``` python test/distributed/test_c10d_nccl.py -k test_short Ran 1 test in 12.175s TORCH_NCCL_TRACE_CPP_STACK=1 python test/distributed/test_c10d_nccl.py -k test_short Ran 1 test in 53.338s ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114651 Approved by: https://github.com/zdevito --- torch/csrc/distributed/c10d/TraceUtils.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 9d72e9960b2e..5653b796e846 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -325,6 +325,7 @@ struct NCCLTraceBuffer { } NCCLTraceBuffer() { max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0); + capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false); enabled_ = max_entries_ > 0; } using EventList = std::vector; @@ -351,6 +352,7 @@ struct NCCLTraceBuffer { }; bool enabled_ = false; + bool capture_cpp_stack_ = false; std::mutex mutex_; std::vector entries_; size_t max_entries_ = 0; @@ -368,7 +370,8 @@ struct NCCLTraceBuffer { if (!enabled_) { return c10::nullopt; } - auto traceback = torch::CapturedTraceback::gather(true, true, true); + auto traceback = + torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); std::lock_guard guard(mutex_); auto te = Entry{