From 4333e122d4b74cdf84351ed2907045c6a767b4cd Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 16 May 2024 10:27:01 +0000 Subject: [PATCH 001/116] [Traceable FSDP2] Add all_gather_into_tensor out variant (#126334) This PR adds `torch.ops._c10d_functional.all_gather_into_tensor_out`. It's important for tracing FSDP2, because FSDP2 pre-allocates the output buffer of AllGather, and makes input buffer an alias of the output buffer, and expects both of them to be used to achieve lower memory usage. If we don't preserve this behavior and instead functionalize the AllGather op, AllGather op will then create a brand-new output buffer (instead of reusing), thus significantly increasing the memory usage. The expectation is that we will "re-inplace" the AllGather op by switching to the out variant in Inductor post-grad stage via an FX pass, so this API is not expected to be directly used by users. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126334 Approved by: https://github.com/yifuwang, https://github.com/wanchaol --- .../test_c10d_functional_native.py | 12 +++++++++++ torch/csrc/distributed/c10d/Functional.cpp | 20 +++++++++++++++++++ torch/distributed/_functional_collectives.py | 9 +++++++++ 3 files changed, 41 insertions(+) diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 54030d1f1d42..775d3f9cc03d 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -195,6 +195,18 @@ def test_all_gather_into_tensor_single(self) -> None: assert torch.allclose(output, expect) assert output.eq(expect).all() + # Test out-variant of all_gather_into_tensor + output = torch.empty(expect.shape, device=self.device) + output = torch.ops._c10d_functional.all_gather_into_tensor_out( + input, + self.world_size, + "default", + out=output, + ) + output = torch.ops._c10d_functional.wait_tensor(output) + assert torch.allclose(output, expect) + assert output.eq(expect).all() + # Test Python API and AsyncCollectiveTensor output = all_gather_tensor( input, diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index d392c0213b84..9d525f0d5640 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -196,6 +196,19 @@ at::Tensor all_gather_into_tensor( inputs, group_size, std::move(group_name))[0]; } +at::Tensor& all_gather_into_tensor_out( + at::Tensor& input, + int64_t group_size, + std::string group_name, + at::Tensor& output) { + c10d::AllgatherOptions opts; + + auto group = c10d::resolve_process_group(group_name); + auto work = group->_allgather_base(output, input, opts); + c10d::RankLocal::get().register_work(output, work); + return output; +} + at::Tensor allocate_reduce_scatter_output( const at::Tensor& input, const int64_t group_size) { @@ -321,6 +334,13 @@ TORCH_LIBRARY(_c10d_functional, m) { c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_), {at::Tag::pt2_compliant_tag}); + m.def( + "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)", + torch::dispatch( + c10::DispatchKey::CompositeExplicitAutograd, + ::all_gather_into_tensor_out), + {at::Tag::pt2_compliant_tag}); + m.def( "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor", torch::dispatch( diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index b1250eddf037..8d598713cf50 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -894,6 +894,12 @@ def _all_to_all_single_meta( return input.new_empty(out_size) +def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + def _all_gather_into_tensor_native_meta(input, group_size, group_name): shape = list(input.size()) shape[0] *= group_size @@ -932,6 +938,9 @@ def _reduce_scatter_tensor_coalesced_native_meta( lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" + ) lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") lib_impl.impl( "all_gather_into_tensor_coalesced", From 691af57fbc8b4a5cc4d53d0b0bad6e17e8d36276 Mon Sep 17 00:00:00 2001 From: yuanx749 Date: Thu, 16 May 2024 11:46:32 +0000 Subject: [PATCH 002/116] Fix broken link of scikit-learn (#120972) The link is broken in https://pytorch.org/docs/main/community/design.html Pull Request resolved: https://github.com/pytorch/pytorch/pull/120972 Approved by: https://github.com/Skylion007 --- docs/source/community/design.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/community/design.rst b/docs/source/community/design.rst index 73ed7e1447b8..16b1500afcdd 100644 --- a/docs/source/community/design.rst +++ b/docs/source/community/design.rst @@ -119,7 +119,7 @@ This principle began as **Python First**: PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use `NumPy `__, - `SciPy `__, `scikit-learn <(https://scikit-learn.org/>`__, + `SciPy `__, `scikit-learn `__, or other Python libraries. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as `Cython `__ and From c2f8c75129e0837b5e3ecd9cb5139d635a8f2f7b Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 16 May 2024 12:00:16 +0000 Subject: [PATCH 003/116] [Reopen] Upgrade submodule oneDNN to v3.4.2 (#126137) Reopen of https://github.com/pytorch/pytorch/pull/122472 ## Improvements This upgrade fixes the following issues: - https://github.com/pytorch/pytorch/issues/120982 This upgrade brings the following new features: - Introduced memory descriptor serialization API. This API is needed to support freezing on CPU in AOTInductor (https://github.com/pytorch/pytorch/issues/114450) ## Validation results on CPU Original results with oneDNN v3.4.1 are here: https://github.com/pytorch/pytorch/pull/122472#issue-2201602846 Need to rerun validation and update results. Co-authored-by: Sunita Nadampalli Pull Request resolved: https://github.com/pytorch/pytorch/pull/126137 Approved by: https://github.com/jgong5, https://github.com/snadampal, https://github.com/atalman --- .ci/docker/common/install_acl.sh | 2 +- third_party/ideep | 2 +- third_party/mkl-dnn.BUILD | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.ci/docker/common/install_acl.sh b/.ci/docker/common/install_acl.sh index f5e5ce92af4a..8a6dc4d1c79c 100644 --- a/.ci/docker/common/install_acl.sh +++ b/.ci/docker/common/install_acl.sh @@ -1,6 +1,6 @@ set -euo pipefail -readonly version=v23.08 +readonly version=v24.04 readonly src_host=https://review.mlplatform.org/ml readonly src_repo=ComputeLibrary diff --git a/third_party/ideep b/third_party/ideep index 8a6cc4e09dc5..55ca0191687a 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 8a6cc4e09dc509f04f83c085e38786b1fb44e14d +Subproject commit 55ca0191687aaf19aca5cdb7881c791e3bea442b diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index dac4f9e3e8cf..9a688a52b1cf 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -63,9 +63,9 @@ template_rule( out = "include/oneapi/dnnl/dnnl_version.h", substitutions = { "@DNNL_VERSION_MAJOR@": "3", - "@DNNL_VERSION_MINOR@": "3", - "@DNNL_VERSION_PATCH@": "6", - "@DNNL_VERSION_HASH@": "86e6af5974177e513fd3fee58425e1063e7f1361", + "@DNNL_VERSION_MINOR@": "4", + "@DNNL_VERSION_PATCH@": "2", + "@DNNL_VERSION_HASH@": "1137e04ec0b5251ca2b4400a4fd3c667ce843d67", }, ) From ab07867084779cd95f9bc9e60c9376a9a6eefd6b Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 14 May 2024 05:46:12 -0700 Subject: [PATCH 004/116] [FSDP2] Supported `set_all_reduce_gradients=False` for HSDP (#126166) **Context** For FSDP, gradient accumulation across microbatches has two flavors: (1) reduce-scatter or (2) no reduce-scatter. (1) incurs the collective per microbatch backward but saves gradient memory (storing the sharded gradients), while (2) avoids the communication but uses more gradient memory (storing the unsharded gradients). - FSDP2 offers (1) without any intervention. The user should simply make sure to run the optimizer step after `K` microbatches for `K > 1`. - FSDP2 offers (2) via `module.set_requires_gradient_sync()` (e.g. `module.set_requires_gradient_sync(is_last_microbatch)`. For HSDP, since we reduce-scatter and then all-reduce, we have additional flexibility and get three flavors: (1) reduce-scatter and all-reduce, (2) reduce-scatter but no all-reduce, and (3) no reduce-scatter and no all-reduce. This PR adds support for (2). - FSDP2 offers (1) without any intervention like mentioned above. - FSDP2 offers (3) via `module.set_requires_gradient_sync()` like mentioned above. - FSDP2 offers (2) via `module.set_requires_all_reduce()` similar to `set_requires_gradient_sync()`. **Overview** For HSDP, to reduce-scatter but not all-reduce during gradient accumulation, the user can do something like: ``` for microbatch_idx, microbatch in enumerate(microbatches): is_last_microbatch = microbatch_idx == len(microbatches) - 1 model.set_requires_all_reduce(is_last_microbatch) # Run forward/backward ``` This PR also makes the minor change of making the `recurse: bool` argument in these setter methods to be kwarg only. **Developer Notes** We choose to implement this by saving the partial reduce output to the `FSDPParamGroup` for simplicity, where we assume that the set of parameters that receive gradients does not change across microbatches. An alternative would be to view into the partial reduce output per parameter and save the view to each parameter. We prefer to avoid this alternative for now because it introduces more complexity to do extra viewing when saving the partial reduce output to each parameter, accumulating into them, and accumulating back to the last microbatch's reduce output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126166 Approved by: https://github.com/weifengpy, https://github.com/wanchaol ghstack dependencies: #126067, #126070, #126161 --- .../_composable/fsdp/test_fully_shard_comm.py | 6 +- .../fsdp/test_fully_shard_training.py | 72 ++++++++++++------- .../_composable/fsdp/_fsdp_collectives.py | 55 ++++++++------ .../_composable/fsdp/_fsdp_param_group.py | 20 +++--- .../_composable/fsdp/fully_shard.py | 9 +-- 5 files changed, 98 insertions(+), 64 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 115c1f93227c..283b8ab2b944 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -244,7 +244,7 @@ def _test_reduce_scatter( group = fsdp_param_group.mesh_info.shard_process_group self.assertEqual(group.size(), self.world_size) all_reduce_stream = torch.cuda.Stream() - view_out_event = foreach_reduce( + post_reduce_event, _ = foreach_reduce( fsdp_params, unsharded_grads, group, @@ -254,8 +254,10 @@ def _test_reduce_scatter( device=self.device, all_reduce_group=None, all_reduce_stream=all_reduce_stream, + all_reduce_grads=True, + partial_reduce_output=None, ) - torch.cuda.current_stream().wait_event(view_out_event) + torch.cuda.current_stream().wait_event(post_reduce_event) # Check reduce-scatter correctness predivide_factor, postdivide_factor = _get_gradient_divide_factors( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index eec060d3004c..392596549d77 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -672,6 +672,12 @@ def test_gradient_accumulation(self): "mode": ["all", "root_only", "some_mlps"], "reshard_after_backward": [False, True], "offload_policy": [OffloadPolicy(), CPUOffloadPolicy()], + # For HSDP only: + # `True`: reduce-scatter only (no all-reduce) each microbatch + # until the last microbatch + # `False`: neither reduce-scatter nor all-reduce each + # microbatch until the last microbatch + "reduce_scatter_only": [False, True], }, self._test_gradient_accumulation, ) @@ -683,15 +689,20 @@ def _test_gradient_accumulation( mode: str, reshard_after_backward: bool, offload_policy: OffloadPolicy, + reduce_scatter_only: bool, # for HSDP ): if ( - not reshard_after_backward - and (reshard_after_forward is not False or mode == "some_mlps") - ) or ( - isinstance(offload_policy, CPUOffloadPolicy) - and reshard_after_forward is not True + ( + not reshard_after_backward + and (reshard_after_forward is not False or mode == "some_mlps") + ) + or ( + isinstance(offload_policy, CPUOffloadPolicy) + and reshard_after_forward is not True + ) + or (mesh.ndim != 2 and reduce_scatter_only) ): - return # skip since not common + return # skip since not common or applicable torch.manual_seed(42) batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3) @@ -713,29 +724,35 @@ def _test_gradient_accumulation( ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) optim = torch.optim.Adam(model.parameters(), lr=1e-2) + def set_grad_sync_flag( + module: nn.Module, is_last_microbatch: bool, recurse: bool = True + ): + if reduce_scatter_only: + module.set_requires_all_reduce(is_last_microbatch, recurse=recurse) + else: + module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse) + + def set_backward_flags(_model: nn.Module, is_last_microbatch: bool): + if mode == "all": + set_grad_sync_flag(_model, is_last_microbatch) + if not reshard_after_backward: + _model.set_reshard_after_backward(is_last_microbatch) + elif mode == "some_mlps": + for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]: + set_grad_sync_flag(mlp, is_last_microbatch) + if not reshard_after_backward: + mlp.set_reshard_after_backward(is_last_microbatch) + elif mode == "root_only": + set_grad_sync_flag(model, is_last_microbatch, recurse=False) + if not reshard_after_backward: + model.set_reshard_after_backward(is_last_microbatch, recurse=False) + torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): with CommDebugMode() as comm_mode: for microbatch_idx in range(num_microbatches): is_last_microbatch = microbatch_idx == num_microbatches - 1 - if mode == "all": - model.set_requires_gradient_sync(is_last_microbatch) - if not reshard_after_backward: - model.set_reshard_after_backward(is_last_microbatch) - elif mode == "some_mlps": - for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]: - mlp.set_requires_gradient_sync(is_last_microbatch) - if not reshard_after_backward: - mlp.set_reshard_after_backward(is_last_microbatch) - elif mode == "root_only": - model.set_requires_gradient_sync( - is_last_microbatch, recurse=False - ) - if not reshard_after_backward: - model.set_reshard_after_backward( - is_last_microbatch, recurse=False - ) - + set_backward_flags(model, is_last_microbatch) inp = torch.randn(batch_size, lin_dim, device="cuda") losses: List[torch.Tensor] = [] for _model in (ref_model, model): @@ -760,10 +777,15 @@ def _test_gradient_accumulation( elif mode == "root_only": # Expect additional reduce-scatters for all MLPs expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1) - self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count) expected_all_reduce_count = ( expected_reduce_scatter_count if mesh.ndim == 2 else 0 ) + if reduce_scatter_only: + # Specially for HSDP if only reduce-scattering but not + # all-reducing until the last microbatch, expect one + # reduce-scatter per MLP plus for the root per microbatch + expected_reduce_scatter_count = (num_mlps + 1) * num_microbatches + self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count) self.assertEqual(all_reduce_count, expected_all_reduce_count) # Expect one all-gather per MLP plus one for the root's linear in diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index f27970315159..b7264cb34d6d 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -125,9 +125,11 @@ def foreach_reduce( orig_dtype: torch.dtype, reduce_dtype: Optional[torch.dtype], device: torch.device, - all_reduce_group: Optional[dist.ProcessGroup], + all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP all_reduce_stream: torch.cuda.Stream, -) -> torch.cuda.Event: + all_reduce_grads: bool, + partial_reduce_output: Optional[torch.Tensor], # only used for HSDP +) -> Tuple[torch.cuda.Event, Optional[torch.Tensor]]: """ ``unsharded_grads`` owns the references to the gradients computed by autograd, so clearing the list frees the gradients. @@ -163,36 +165,43 @@ def foreach_reduce( # computed in the default stream current_stream.wait_stream(reduce_scatter_stream) unsharded_grads.clear() - post_reduce_output = reduce_scatter_input.new_empty( - (reduce_scatter_output_numel,) - ) + reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,)) _div_if_needed(reduce_scatter_input, predivide_factor) dist.reduce_scatter_tensor( - output=post_reduce_output, + output=reduce_output, input=reduce_scatter_input, group=reduce_scatter_group, op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, ) - view_out_stream = reduce_scatter_stream - if all_reduce_group is not None: - view_out_stream = all_reduce_stream - all_reduce_stream.wait_stream(reduce_scatter_stream) - with torch.cuda.stream(all_reduce_stream): - dist.all_reduce( - post_reduce_output, - group=all_reduce_group, - op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, - ) - with torch.cuda.stream(view_out_stream): - _div_if_needed(post_reduce_output, postdivide_factor) - post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype) - # - View out and accumulate + post_reduce_stream = reduce_scatter_stream + if all_reduce_group is not None: # HSDP + # Accumulations must run in the reduce-scatter stream + if not all_reduce_grads: + if partial_reduce_output is not None: + partial_reduce_output += reduce_output + else: + partial_reduce_output = reduce_output + return post_reduce_stream.record_event(), partial_reduce_output + if partial_reduce_output is not None: + reduce_output += partial_reduce_output + post_reduce_stream = all_reduce_stream + all_reduce_stream.wait_stream(reduce_scatter_stream) + with torch.cuda.stream(all_reduce_stream): + dist.all_reduce( + reduce_output, + group=all_reduce_group, + op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, + ) + with torch.cuda.stream(post_reduce_stream): + _div_if_needed(reduce_output, postdivide_factor) + reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype) + # View out and accumulate sharded gradients flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] for padded_unsharded_size, fsdp_param in zip( padded_unsharded_sizes, fsdp_params ): new_sharded_grad = torch.as_strided( - post_reduce_output, + reduce_output, size=fsdp_param.sharded_size, stride=fsdp_param.contiguous_sharded_stride, storage_offset=flat_grad_offset, @@ -220,12 +229,12 @@ def foreach_reduce( fsdp_param.sharded_param.grad = new_sharded_dtensor_grad padded_sharded_numel = padded_unsharded_size.numel() // world_size flat_grad_offset += padded_sharded_numel - post_reduce_view_out_event = view_out_stream.record_event() + post_reduce_event = post_reduce_stream.record_event() # The RS output is allocated in the RS stream and used in the default # stream (for optimizer). To ensure its memory is not reused for later # RSs, we do not need extra synchronization since the sharded parameters # hold refs through the end of backward. - return post_reduce_view_out_event + return post_reduce_event, None def foreach_reduce_scatter_copy_in( diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 9e9813102db3..569858e92656 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -138,11 +138,15 @@ def __init__( # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which # should be waited on at the end of backward - self._post_reduce_view_out_event: Optional[torch.cuda.Event] = None + self._post_reduce_event: Optional[torch.cuda.Event] = None # Holds the reshard-after-forward CUDA event when resharding to a # different world size, which should be waited on in the next unshard self._reshard_after_forward_event: Optional[torch.cuda.Event] = None + # Only for HSDP, if accumulating gradients without all-reduce, save the + # partial reduce output (only reduce-scattered but not all-reduced) + self._partial_reduce_output: Optional[torch.Tensor] = None + # Initialization # def _init_mp_dtypes(self) -> None: for fsdp_param in self.fsdp_params: @@ -311,7 +315,7 @@ def post_backward(self, *unused: Any): if len(fsdp_params_with_grad) == 0: return with torch.profiler.record_function("FSDP::post_backward_reduce"): - self._post_reduce_view_out_event = foreach_reduce( + self._post_reduce_event, self._partial_reduce_output = foreach_reduce( fsdp_params_with_grad, unsharded_grads, self._reduce_scatter_process_group, @@ -319,16 +323,16 @@ def post_backward(self, *unused: Any): self._orig_dtype, self._reduce_dtype, self.device, - self._all_reduce_process_group - if self._is_hsdp and self.all_reduce_grads - else None, + self._all_reduce_process_group if self._is_hsdp else None, self.comm_ctx.all_reduce_stream, + self.all_reduce_grads, + self._partial_reduce_output, ) def finalize_backward(self): - if self._post_reduce_view_out_event is not None: - torch.cuda.current_stream().wait_event(self._post_reduce_view_out_event) - self._post_reduce_view_out_event = None + if self._post_reduce_event is not None: + torch.cuda.current_stream().wait_event(self._post_reduce_event) + self._post_reduce_event = None for fsdp_param in self.fsdp_params: if fsdp_param.grad_offload_event is not None: fsdp_param.grad_offload_event.synchronize() diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index a5204701731c..981b82987462 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -208,7 +208,7 @@ def set_is_last_backward(self, is_last_backward: bool) -> None: state._state_ctx.is_last_backward = is_last_backward def set_requires_gradient_sync( - self, requires_gradient_sync: bool, recurse: bool = True + self, requires_gradient_sync: bool, *, recurse: bool = True ) -> None: """ Sets if the module should sync gradients. This can be used to implement @@ -231,16 +231,13 @@ def set_requires_gradient_sync( fsdp_param_group.all_reduce_grads = requires_gradient_sync def set_requires_all_reduce( - self, requires_all_reduce: bool, recurse: bool = True + self, requires_all_reduce: bool, *, recurse: bool = True ) -> None: """ Sets if the module should all-reduce gradients. This can be used to implement gradient accumulation with only reduce-scatter but not all-reduce for HSDP. """ - # TODO: post_reduce_output += fsdp_param.sharded_param.grad - # after reduce-scatter and before all-reduce - raise NotImplementedError("requires_all_reduce is not yet supported in HSDP") self_module = cast(nn.Module, self) modules = list(self_module.modules()) if recurse else [self_module] for module in modules: @@ -250,7 +247,7 @@ def set_requires_all_reduce( fsdp_param_group.all_reduce_grads = requires_all_reduce def set_reshard_after_backward( - self, reshard_after_backward: bool, recurse: bool = True + self, reshard_after_backward: bool, *, recurse: bool = True ) -> None: """ Sets if the module should reshard parameters after backward. This can From 91bf952d10e9524a9b078900d9807efa5d252f5c Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Thu, 16 May 2024 13:41:45 +0000 Subject: [PATCH 005/116] Fix aarch64 debug build with GCC (#126290) By working around GCCs quirks in instantiating templates that require immediate values. Provide alternative implementation for scaling the output if compiled without any optimizations (both GCC and clang define `__OPTIMIZE__` if invoked with anything but `-O0`) Fixes https://github.com/pytorch/pytorch/issues/126283 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126290 Approved by: https://github.com/atalman, https://github.com/seemethere --- aten/src/ATen/native/cpu/int8mm_kernel.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp index bd266030b256..9eaf43ec5f00 100644 --- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp @@ -250,11 +250,19 @@ inline void tinygemm_kernel_( }); } +#if __OPTIMIZE__ float32x4_t scale_val = load_as_float32x4(scales); c10::ForcedUnroll{}([&](auto i) { C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i); }); } +#else + // Workaround GCCs inability to infer lane index at compile time + // See https://github.com/pytorch/pytorch/issues/126283 + c10::ForcedUnroll{}([&](auto i) { + C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]); + }); +#endif } template From 14d8e3aec0f755c26ba17fd81f9144c8691fe116 Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Thu, 16 May 2024 16:38:09 +0000 Subject: [PATCH 006/116] Add distributed/_tensor/test_attention to ROCM_BLOCKLIST (#126336) Fixes #125504 Fixes #126252 Fixes #126296 Fixes #126330 This PR doesn't really fix the RingAttentionTest tests for ROCm, but explicitly adds the whole test file to ROCM_BLOCKLIST to get a clean signal on ROCm distributed CI. We will enable these tests in a follow-up PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126336 Approved by: https://github.com/huydhn, https://github.com/pruthvistony --- test/run_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_test.py b/test/run_test.py index 5b24a0078996..cbee11b37a7a 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -181,6 +181,7 @@ def __contains__(self, item): "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", + "distributed/_tensor/test_attention", ] XPU_BLOCKLIST = [ From f155ed6bf28e3a44f7808959378c5346aa713c8e Mon Sep 17 00:00:00 2001 From: Andres Lugo-Reyes Date: Thu, 16 May 2024 16:40:31 +0000 Subject: [PATCH 007/116] [ROCm] amax hipblaslt integration (#125921) AMAX is coming as part of rocm6.2. This code adds that functionality Pull Request resolved: https://github.com/pytorch/pytorch/pull/125921 Approved by: https://github.com/eqy, https://github.com/lezcano --- aten/src/ATen/cuda/CUDABlas.cpp | 9 ++++++--- aten/src/ATen/native/cuda/Blas.cpp | 14 ++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 2502456e285b..ce991a9bcad4 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1422,10 +1422,13 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200) + // Amax support in ROCm as of 6.2 + if (isFloat8Type(result_dtype)) { + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr); + } +#endif #ifndef USE_ROCM -if (isFloat8Type(result_dtype)) { - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr); -} computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode); #endif CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't'); diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index c0ed650cf021..84c59a4fd0d7 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #endif @@ -988,6 +989,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, else #endif { +#if defined(USE_ROCM) && ROCM_VERSION >= 60200 + // hipBlasLT requires scaleD to be set to something in order to use AMAX + auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA); + auto dummy_scale = at::ones(1, dummy_options); +#endif at::cuda::blas::scaled_gemm( args.transa, args.transb, @@ -1005,15 +1011,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), +#if defined(USE_ROCM) && ROCM_VERSION >= 60200 + scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(), +#else scale_result ? scale_result->data_ptr() : nullptr, +#endif args.result_ld, out_dtype_, amax.data_ptr(), use_fast_accum); } -#if defined(USE_ROCM) - // rocm's hipblaslt does not yet support amax, so calculate separately +#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200 + // ROCm's hipBLASLt does not support amax before 6.2, so calculate separately amax = at::max(at::abs(out.to(kFloat))); #endif From a55d63659ad0b9a14cbf5b495464994a9180c988 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Thu, 16 May 2024 16:50:02 +0000 Subject: [PATCH 008/116] Add 2nd shard to ROCm trunk workflow for core distributed UTs (#121716) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121716 Approved by: https://github.com/ezyang, https://github.com/huydhn --- .github/workflows/trunk.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 6c2bd277166c..00813edd3d91 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -194,6 +194,7 @@ jobs: { include: [ { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, + { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" }, ]} linux-focal-rocm6_1-py3_8-test: @@ -209,4 +210,4 @@ jobs: build-environment: linux-focal-rocm6.1-py3.8 docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} - tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor" + tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" \ No newline at end of file From c5f926ab87751490e39bb99f16d48ad21a075aab Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 15 May 2024 06:37:06 -0700 Subject: [PATCH 009/116] [AOTI][torchgen] Support at::Generator via C shim (#126181) Summary: Support at::Generator which is used by many random number generator ops Pull Request resolved: https://github.com/pytorch/pytorch/pull/126181 Approved by: https://github.com/chenyang78 --- torch/csrc/inductor/aoti_torch/c/shim.h | 3 +++ .../inductor/aoti_torch/generated/c_shim_cpu.h | 5 +++++ .../aoti_torch/generated/c_shim_cuda.h | 5 +++++ torch/csrc/inductor/aoti_torch/utils.h | 18 ++++++++++++++++++ torchgen/gen_aoti_c_shim.py | 5 ++++- 5 files changed, 35 insertions(+), 1 deletion(-) diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index b05c52c6a387..6fa7df75c056 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -72,6 +72,9 @@ extern "C" { struct AtenTensorOpaque; using AtenTensorHandle = AtenTensorOpaque*; +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque*; + struct AOTIProxyExecutorOpaque; using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 8058618f9748..bbd8cbc9d31a 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -58,6 +58,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummax(AtenTensorHandle self, in AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -94,9 +95,12 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTens AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); @@ -113,6 +117,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int6 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index df0099d37bed..2905aa810d3c 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -66,6 +66,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummax(AtenTensorHandle self, i AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -101,9 +102,12 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTen AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); @@ -120,6 +124,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 0964479caabd..44ca34b1c6e8 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -32,6 +33,16 @@ inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor* tensor) { return reinterpret_cast(tensor); } +inline at::Generator* generator_handle_to_generator_pointer( + AtenGeneratorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenGeneratorHandle generator_pointer_to_generator_handle( + at::Generator* generator) { + return reinterpret_cast(generator); +} + inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) { at::Tensor* new_tensor = new at::Tensor(std::move(tensor)); return tensor_pointer_to_tensor_handle(new_tensor); @@ -61,6 +72,13 @@ inline std::optional pointer_to_optional( : c10::nullopt; } +template <> +inline std::optional pointer_to_optional( + AtenGeneratorHandle* ptr) { + return ptr ? c10::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : c10::nullopt; +} + inline std::optional pointer_to_optional_device( int32_t* device_type, int32_t device_index) { diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 1f99e3a9f3fa..5bc29e514a27 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -34,6 +34,7 @@ BaseTy.Layout: "int32_t", # Represent enum as int BaseTy.MemoryFormat: "int32_t", # Represent enum as int BaseTy.ScalarType: "int32_t", # Represent enum as int + BaseTy.Generator: "AtenGeneratorHandle", } base_type_to_aten_type = { @@ -48,6 +49,7 @@ BaseTy.Layout: "c10::Layout", BaseTy.MemoryFormat: "c10::MemoryFormat", BaseTy.ScalarType: "c10::ScalarType", + BaseTy.Generator: "at::Generator", } base_type_to_callsite_expr = { @@ -62,6 +64,7 @@ BaseTy.Layout: "static_cast", BaseTy.MemoryFormat: "static_cast", BaseTy.ScalarType: "static_cast", + BaseTy.Generator: "*generator_handle_to_generator_pointer", } @@ -89,7 +92,7 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str ], ) else: - # TODO: BaseTy.Dimname, BaseTy.Generator, etc. + # TODO: BaseTy.Dimname, etc. raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}") elif isinstance(typ, OptionalType): c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name( From 5792bc3c3e8ab7322e44f04e59b0ed3a27132e85 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 15 May 2024 07:05:09 -0700 Subject: [PATCH 010/116] [AOTI] Refactor some fallback op util functions (#126182) Summary: Move some util functions for cpp kernel naming and missing arg filling from FallbackKernel to ExternKernel, since they are useful for ExternKernel in general. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126182 Approved by: https://github.com/chenyang78 ghstack dependencies: #126181 --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 2 +- torch/_inductor/ir.py | 176 +++++++++------------ 2 files changed, 73 insertions(+), 105 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6ce230714632..899332ef5646 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1120,7 +1120,7 @@ def g(args): ) def get_c_shim_func_name(self, kernel): - if not config.abi_compatible: + if not config.abi_compatible or kernel.startswith("aoti_torch_"): return kernel assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index bd092791eb7c..cbf990ea0b77 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3930,6 +3930,21 @@ def should_allocate(self): return True +def get_aten_cpp_kernel_name(kernel): + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=c10::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) + if not isinstance(kernel, torch._ops.OpOverload) or kernel.namespace != "aten": + return None + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + return f"at::_ops::{opname}::call" + + @dataclasses.dataclass class ExternKernel(InputsKernel): constant_args: Tuple[Any, ...] = () @@ -3973,7 +3988,8 @@ def __init__( self.kwargs = kwargs if kwargs else {} self.output_view = output_view self.python_kernel_name = python_kernel_name - self.cpp_kernel_name = cpp_kernel_name + # If cpp_kernel_name is None, we will try to construct it from op_overload + self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload) self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel self.op_overload = op_overload self.collect_arg_kwarg_properties() @@ -4016,6 +4032,40 @@ def collect_arg_kwarg_properties(self): else {} ) + def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being set. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert self.arg_properties, "ExternKernel.arg_properties should not be empty" + + n_args = len(args) + n_pos_args = len(self.arg_properties) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + for i in range(n_args, n_pos_args): + arg_name = self.arg_properties[i]["name"] + args.append( + kwargs[arg_name] + if arg_name in kwargs + else self.arg_properties[i]["default_value"] + ) + return args + def decide_layout(self): if isinstance(self.layout, FlexibleLayout): self.apply_constraint() @@ -4030,7 +4080,15 @@ def codegen(self, wrapper): raise NotImplementedError def get_kernel_name(self): - return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name + return ( + ( + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] + if config.abi_compatible + else self.cpp_kernel_name + ) + if V.graph.cpp_wrapper + else self.python_kernel_name + ) @staticmethod def copy_input(x): @@ -5128,25 +5186,7 @@ class ExternKernelNode: } -def get_aten_cpp_kernel_name(kernel): - # Calling with the default kernel name can lead to ambiguous behavior like the following example. - # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=c10::nullopt) - # repeat_interleave(const at::Tensor & self, int64_t repeats, - # c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) - assert ( - isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten" - ), "Invalid aten kernel" - opname = ( - kernel.__name__.split(".")[0] - if kernel._overloadname == "default" - else kernel.__name__.replace(".", "_") - ) - return f"at::_ops::{opname}::call" - - class FallbackKernel(ExternKernelAlloc): - args_default_value: List[Dict[str, Any]] - def __init__( self, layout, @@ -5158,12 +5198,23 @@ def __init__( *, unbacked_bindings=None, ): + if ( + kernel == aten.mul.Tensor + and len(tensor_args) == 1 + and len(nontensor_args) == 1 + ): + # When aten.mul.Tensor's second arg is constant, cpp wrapper expects + # to call mul_Scalar. A more proper fix is to do it in decomposition. + # See https://github.com/pytorch/pytorch/issues/123478 + kernel = aten.mul.Scalar + super().__init__( layout, tuple(tensor_args), tuple(nontensor_args), op_overload=kernel, ) + # We need output buffers for generating kernel arguments in the # abi-compatible mode, where we retrieve outputs by pass each individual # output through the abi-compatible interface. @@ -5179,7 +5230,6 @@ def __init__( ), ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" self.op_overload = kernel - self.unflatten_args = unflatten_args self.kwargs = {} if kwargs is None else kwargs V.graph.warn_fallback(self.python_kernel_name) @@ -5341,41 +5391,6 @@ def is_not_write(arg): self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] self.cpp_op_schema = get_cpp_op_schema(kernel) - self.init_args_default_value(kernel._schema) - - def init_args_default_value(self, schema): - self.args_default_value = [ - { - "name": x.name, - "type": x.real_type, - "value": x.default_value, - } - for x in schema.arguments - if not x.kwarg_only - ] - - def get_pos_arg_value(self, pos, kwargs): - # positional args may be provided in kwargs - pos_arg_name = self.args_default_value[pos]["name"] - if pos_arg_name in kwargs: - log.debug( - "Found argument %s with value %s from kwargs", - pos_arg_name, - kwargs[pos_arg_name], - ) - return kwargs[pos_arg_name] - - assert hasattr( - self, "args_default_value" - ), "self.args_default_value has to be provided" - assert pos < len( - self.args_default_value - ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}" - arg_default_value = self.args_default_value[pos]["value"] - log.debug( - "Use default value %s for argument %s", arg_default_value, pos_arg_name - ) - return arg_default_value def codegen_args(self): @dataclasses.dataclass @@ -5388,6 +5403,7 @@ def __repr__(self): tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] args, kwargs = self.unflatten_args(tensor_args, self.constant_args) if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = self.fill_non_provided_args(args, kwargs) args = [ V.graph.wrapper_code.val_to_cpp_arg_str(param.real_type, x) for param, x in zip(self.op_overload._schema.arguments, args) @@ -5395,17 +5411,6 @@ def __repr__(self): else: args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] - # Previously, we want to maintain forward-compatibility by skipping - # default args in the serialized artifacts in fbcode. However, - # some of our shim interfaces require default values being set. - # Discussed with Sherlock offline and we decided to allow serializing - # default args into the C++ wrapper code for now. We will refine this - # part if we see real FC requirement. More details related to FC - # can be found at: - # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing - if V.graph.cpp_wrapper and hasattr(self, "args_default_value"): - self.fill_non_provided_args(args, kwargs, convert_val_to_str=True) - # let self.codegen_kwargs handle kwargs self.kwargs.update(kwargs) return args @@ -5441,30 +5446,6 @@ def get_mutation_names(self): assert len(self.mutation_names) <= 1 return self.mutation_names - def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): - assert isinstance(args, (list, tuple)) - if isinstance(args, tuple): - args = list(args) - assert hasattr(self, "args_default_value") - n_args = len(args) - n_pos_args = len(self.args_default_value) - # For cpp wrapper, if some positional args are not provided, we need to check - # if they're in the kwargs or use their default value - if n_args < n_pos_args: - log.debug( - "%s has %d unprovided positional arguments. " - "Will check if they are in the keyword arguments or will use default values.", - self.op_overload, - n_pos_args - n_args, - ) - pos_args = [ - self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args) - ] - if convert_val_to_str: - pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args] - args.extend(pos_args) - return args - # ProxyExecutor Design Note # We export the ExternFallbackNodes (for custom ops) into a serialized file # and run it with a host side proxy executor to address the ABI problem @@ -5539,15 +5520,6 @@ def codegen(self, wrapper): if kernel.namespace == "aten": # type: ignore[union-attr] # Aten Fallback Ops assert isinstance(kernel, torch._ops.OpOverload) - - if ( - kernel == aten.mul.Tensor - and len(self.inputs) == 1 - and len(self.constant_args) == 1 - ): - # When aten.mul.Tensor's second arg is constant, cpp wrapper expects to call mul_Scalar - kernel = aten.mul.Scalar - if V.graph.cpp_wrapper: if ( config.is_fbcode() @@ -5562,10 +5534,6 @@ def codegen(self, wrapper): ) self.use_runtime_dispatch = True self.set_cpp_kernel(kernel) - else: - self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel) - schema = kernel._schema # type: ignore[union-attr] - self.init_args_default_value(schema) else: self.python_kernel_name = str(kernel) elif kernel.namespace == "_quantized": # type: ignore[union-attr] From 0332b5812ed3ebdeb5d28efdde0939d13d991e75 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 15 May 2024 07:05:09 -0700 Subject: [PATCH 011/116] [AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183) Summary: Update the torchgen rule for inplace ops like bernoulli_, and update InplaceBernoulliFallback to codegen in the ABI-compatible mode. Fixes https://github.com/pytorch/pytorch/issues/121809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126183 Approved by: https://github.com/angelayi ghstack dependencies: #126181, #126182 --- test/inductor/test_cpu_cpp_wrapper.py | 1 - test/inductor/test_cuda_cpp_wrapper.py | 1 - torch/_inductor/ir.py | 25 ++++++++++++------- torch/_inductor/lowering.py | 7 +++++- .../aoti_torch/generated/c_shim_cpu.h | 6 +++-- .../aoti_torch/generated/c_shim_cuda.h | 6 +++-- torchgen/aoti/fallback_ops.py | 2 ++ torchgen/gen_aoti_c_shim.py | 12 ++++----- 8 files changed, 38 insertions(+), 22 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index b8fdbc49bd38..66b92eedc97c 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -71,7 +71,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): if config.abi_compatible: xfail_list = [ - "test_bernoulli1_cpu", # cpp fallback op naming issue "test_conv2d_binary_inplace_fusion_failed_cpu", "test_conv2d_binary_inplace_fusion_pass_cpu", "test_dynamic_qlinear_cpu", diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 5bbe588d3a84..5cb8af9db165 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -97,7 +97,6 @@ class DynamicShapesCudaWrapperCudaTests(InductorTestCase): if config.abi_compatible: xfail_list = [ - "test_bernoulli1_cuda", # cpp fallback op naming issue "test_profiler_mark_wrapper_call_cuda", "test_scaled_dot_product_attention_cuda_dynamic_shapes", ] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index cbf990ea0b77..7b7e8e567a0b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4784,9 +4784,17 @@ class InplaceBernoulliFallback(ExternKernel): def codegen(self, wrapper): (x,) = (t.codegen_reference() for t in self.inputs) - wrapper.writeline( - f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" - ) + + if V.graph.cpp_wrapper and config.abi_compatible: + # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, + # which needs to be explicitly generated for cpp wrapper + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" + ) + else: + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) def should_allocate(self): return False @@ -4797,20 +4805,19 @@ def get_mutation_names(self): def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() - def __init__(self, x, *constant_args): + def __init__(self, op_overload, x, *constant_args): super().__init__( None, NoneLayout(x.get_device()), # type: ignore[arg-type] self.unwrap_storage([x]), constant_args, + op_overload=op_overload, ) self.name = V.graph.register_buffer(self) self.python_kernel_name = "aten.bernoulli_" - self.cpp_kernel_name = ( - "aoti_torch_bernoulli_" - if config.abi_compatible - else "at::native::bernoulli_" - ) + if not config.abi_compatible: + # TODO: this should be simplified once we switch to ABI-compatible only + self.cpp_kernel_name = "at::native::bernoulli_" mark_node_as_mutating(self, x) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 389ff16e3902..77d7b6c046de 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1788,7 +1788,12 @@ def bernoulli_(x, *args): "cpu" ), "this should be handled in decomps unless config.fallback_random or the device is CPU" x.realize() - ir.InplaceBernoulliFallback(x, *args) + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) return x diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index bbd8cbc9d31a..2c7f05dd84cd 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -47,6 +47,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); @@ -105,8 +107,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dty AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 2905aa810d3c..1dceac240e40 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -55,6 +55,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle sel AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); @@ -112,8 +114,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dt AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index f77527a156be..4a300c3cc301 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -25,6 +25,8 @@ "aten.avg_pool2d.default", "aten.avg_pool3d_backward.default", "aten.avg_pool3d.default", + "aten.bernoulli_.float", + "aten.bernoulli_.Tensor", "aten.bmm.out", "aten.bucketize.Tensor", "aten.cat.default", diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 5bc29e514a27..f123bc879cd3 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -249,18 +249,18 @@ def gen_declaration_and_definition( return declaration_definition_cache[(func_name, device, backend_call)] if schema.is_out_fn(): - # out_variant has out arguments in the front, and it's ok to ignore return value + # out_variant has out arguments in the front, and it's ok to ignore return values # because C shim functions only return AOTITorchError - # Somehow at::native out-variant functions have out arguments in the back args, callsite_exprs = gen_arguments( - [*schema.arguments.flat_non_out, *schema.arguments.out] - if "at::native" in backend_call - else [*schema.arguments.out, *schema.arguments.flat_non_out], + [*schema.arguments.out, *schema.arguments.flat_non_out] ) ret_assignments: List[str] = [] else: args, callsite_exprs = gen_arguments(schema.arguments.flat_all) - ret_declarations, ret_assignments = gen_returns(schema) + # ignore return values for inplace ops + ret_declarations, ret_assignments = ( + ([], []) if schema.name.name.inplace else gen_returns(schema) + ) args.extend(ret_declarations) declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})" From 9fbf2696d7d7de46cd21716b34dd8ceb5e1da56b Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 15 May 2024 18:59:47 -0700 Subject: [PATCH 012/116] [AOTI][refactor] Add aoti_torch_item as a util function (#126352) Summary: The logic has been repeated several times in the code, so it's worth to write a common util function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126352 Approved by: https://github.com/chenyang78 ghstack dependencies: #126181, #126182, #126183 --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 42 ++++++++++++---------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 899332ef5646..9595f1da6f95 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -552,10 +552,8 @@ def write_wrapper_decl(self): ), "Fails to get the dtype of the sympy.Expr" cpp_dtype = DTYPE_TO_CPP[dtype] if config.abi_compatible: - self.prefix.writeline(f"{cpp_dtype} {input_key};") - dtype_str = str(dtype).split(".")[-1] - self.prefix.writeline( - f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});" + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix ) else: self.prefix.writeline( @@ -890,6 +888,19 @@ def codegen_scalar_to_tensor(self, output: str): ) return name + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + assert ( + config.abi_compatible + ), "codegen_tensor_item is only used for the ABI-compatible mode" + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + @cache_on_self def get_output_refs(self): return [ @@ -1376,10 +1387,9 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) if config.abi_compatible: - dtype = node.inputs[0].get_dtype() - dtype_str = str(dtype).split(".")[-1] - self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym}_raw;") - self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym}_raw);") + self.codegen_tensor_item( + node.inputs[0].get_dtype(), data, f"{node.sym}_raw" + ) else: convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace( "at::k", "to" @@ -1763,12 +1773,13 @@ def codegen_conditional(self, conditional): outer_outputs.append(out.get_name()) if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): - predicate = f"{conditional.predicate.get_name()}_scalar" - self.writeline(f"bool {predicate};") # in ABI-compatible mode, we need to use the ABI shim function # to extract a C++ bool from the unrelying scalar bool Tensor - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));" + predicate = f"{conditional.predicate.get_name()}_scalar" + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, ) else: # the predicate is not a Tensor: SymBool or Python bool @@ -1847,12 +1858,7 @@ def codegen_while_loop(self, while_loop): if config.abi_compatible: cond_result = f"{cond_result_name}_scalar" - self.writeline(f"bool {cond_result};") - # in ABI-compatible mode, we need to use the ABI shim function - # to extract a C++ bool from the unrelying scalar bool Tensor - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({cond_result_name}, &{cond_result}));" - ) + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) else: cond_result = f"{cond_result_name}.item()" self.writeline(f"if (!{cond_result}) break;") From 0dd53650dddd098eefa4553f478466d91ea0b459 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 May 2024 18:06:55 -0700 Subject: [PATCH 013/116] [BE][FSDP] Change the logging level to info (#126362) As title Differential Revision: [D57419445](https://our.internmc.facebook.com/intern/diff/D57419445/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126362 Approved by: https://github.com/awgu, https://github.com/Skylion007 --- torch/distributed/fsdp/_debug_utils.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 4ed76476e56b..a41a817724e5 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -57,7 +57,7 @@ def dump_and_reset(cls, msg: str) -> None: # This cannot be combined with DETAIL distributed log # as the profiling will be very incorrect. if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO: - logger.warning("%s %s", msg, cls.results) + logger.info("%s %s", msg, cls.results) cls.reset() diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 163cde70b3f9..b066f930ebaf 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1511,7 +1511,7 @@ def _allgather_orig_param_states( """ fsdp_state = fsdp_param_info.state if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL: - logger.warning( + logger.info( "Memory Summary before calling to _allgather_orig_param_states %s", fsdp_state._device_handle.memory_summary(), ) From a0429c01ad665ffb2faa04a411913ecee9962566 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 May 2024 18:15:27 -0700 Subject: [PATCH 014/116] [BE][FSDP] Remove unnecessary warnings (#126365) As title Differential Revision: [D57419704](https://our.internmc.facebook.com/intern/diff/D57419704/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126365 Approved by: https://github.com/awgu, https://github.com/Skylion007 ghstack dependencies: #126362 --- torch/distributed/fsdp/_common_utils.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index 7d9394ef1fbd..c1d77bf410b5 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -421,29 +421,14 @@ def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs): # ``named_children`` + `named_parameter(recurse=False)``. # This hack is a must to make the traversal work. # TODO: Remove this hack once DMP + FSDP is not supported. + # It turns out that recursive wrapping may trigger this as + # well. if ( submodule_name == "_fsdp_wrapped_module" or submodule_name == "_dmp_wrapped_module" ): - if ( - not torch.distributed._functional_collectives.is_torchdynamo_compiling() - ): - # TODO(voz): Don't graph break on this - warnings.warn( - "An unexpected prefix is detected. This case " - " should only happen when using DMP with FSDP. " - f"prefix = {prefix}, " - f"submodule_name = {submodule_name}" - ) new_prefix = prefix elif submodule_name == "module": - warnings.warn( - "An unexpected prefix is detected. This case " - " should only happen when DDP wraps the outer " - " modules while FSDP wraps the inner ones." - f"prefix = {prefix}, " - f"submodule_name = {submodule_name}" - ) new_prefix = prefix f(submodule, new_prefix, new_tree_level, *args, **kwargs) From 5862521ad15cfb79c3c9f5cf7aef33eb5bd6e08d Mon Sep 17 00:00:00 2001 From: Gustav Larsson Date: Thu, 16 May 2024 18:48:56 +0000 Subject: [PATCH 015/116] [onnx.export] Cache SetGraphInputTypeReliable (#124912) This PR is part of an effort to speed up torch.onnx.export (https://github.com/pytorch/pytorch/issues/121422). - For each node that is processed in onnx.export, a check is run to see if all inputs are "reliable" (static shape, etc.). This value does not change, so it is much faster to cache it on the first computation. The caching is added to the ConstantMap state. - Resolves (6) in #121422. - Also see #123028 with a similar addition of a cache state. (partial fix of #121545) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124912 Approved by: https://github.com/justinchuby --- torch/csrc/jit/passes/onnx/constant_map.cpp | 9 +++++++++ torch/csrc/jit/passes/onnx/constant_map.h | 5 +++++ torch/csrc/jit/passes/onnx/shape_type_inference.cpp | 12 +++++++++--- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 8fd1bed0b7a1..e249d0a83a64 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -48,6 +48,14 @@ c10::optional ConstantValueMap::GetAllGraphInputsStatic() { return ConstantValueMap::getInstance().allGraphInputsStatic; } +void ConstantValueMap::SetAllGraphInputsReliableComputed(bool computed) { + ConstantValueMap::getInstance().allGraphInputsReliableComputed = computed; +} + +bool ConstantValueMap::GetAllGraphInputsReliableComputed() { + return ConstantValueMap::getInstance().allGraphInputsReliableComputed; +} + void ConstantValueMap::SetShape( const std::string& tensorName, const c10::SymbolicShape& shapeValue) { @@ -277,6 +285,7 @@ void ConstantValueMap::ClearMaps() { ConstantValueMap::getInstance().symbolDimMap.clear(); ConstantValueMap::getInstance().dimSymbolMap.clear(); ConstantValueMap::getInstance().allGraphInputsStatic = c10::nullopt; + ConstantValueMap::getInstance().allGraphInputsReliableComputed = false; } // For debug only. diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h index 303d373eea56..4261e45cc56c 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.h +++ b/torch/csrc/jit/passes/onnx/constant_map.h @@ -29,6 +29,9 @@ class ConstantValueMap { static void SetAllGraphInputsStatic(bool all_static); static c10::optional GetAllGraphInputsStatic(); + static void SetAllGraphInputsReliableComputed(bool computed); + static bool GetAllGraphInputsReliableComputed(); + static void SetShape( const std::string& tensorName, const c10::SymbolicShape& shapeValue); @@ -108,6 +111,8 @@ class ConstantValueMap { DimSymbolMap dimSymbolMap; // Stores if all graph-level inputs have static shape c10::optional allGraphInputsStatic; + // True if reliable has been computed for all graph inputs + bool allGraphInputsReliableComputed; }; } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index dd79754f4c01..65d065adeb2b 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -2035,11 +2035,17 @@ void UpdateReliable(Node* n) { } } +// Traverse the graph inputs and compute reliability (e.g., are shapes static). +// Since the inputs do not change during export, we save computation time by +// marking it as computed and subsequently skipping. void SetGraphInputTypeReliable(const Graph* g) { - for (auto graph_input : g->inputs()) { - if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) { - ConstantValueMap::SetTypeReliable(graph_input->debugName(), true); + if (!ConstantValueMap::GetAllGraphInputsReliableComputed()) { + for (auto graph_input : g->inputs()) { + if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) { + ConstantValueMap::SetTypeReliable(graph_input->debugName(), true); + } } + ConstantValueMap::SetAllGraphInputsReliableComputed(true); } } From aab448e381366d4cf499145adffe9fcb1ac2b28d Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Thu, 16 May 2024 19:22:16 +0000 Subject: [PATCH 016/116] Remove redundant serialization code (#126249) After https://github.com/pytorch/pytorch/pull/123308, we no longer need separate serialization path to handle different types that exist in the `nn_module` metadata. This PR cleans up the redundant code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126249 Approved by: https://github.com/angelayi --- torch/_export/serde/serialize.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index e0d76920157c..5ff02a787690 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -539,21 +539,9 @@ def export_nn_module_stack(val): path, ty = val assert isinstance(path, str) + assert isinstance(ty, str) - # node.meta["nn_module_stack"] could have two forms: - # 1. (path: str, module_type: 'type'), e.g. - # ('', ) - # 2. (path: str, module_type: str), e.g. - # ('', 'sigmoid.inference.MySimpleModel') - # ExportedProgram directly produced by torch.export() has form 1 - # ExportedProgram deserialized from disk has form 2 - # TODO: This is not ideal, we should fix this. - if isinstance(ty, str): - normalized_ty = ty - else: - normalized_ty = ty.__module__ + "." + ty.__qualname__ - - return path + "," + normalized_ty + return path + "," + ty # Serialize to "key,orig_path,type_str" nn_module_list = [ From da9bf77f0abaa6707055b9b8cb7c6f3deb794caf Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 16 May 2024 20:05:29 +0000 Subject: [PATCH 017/116] [Dynamo] Support SET_UPDATE (#126243) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/126243 Approved by: https://github.com/anijain2305, https://github.com/Skylion007, https://github.com/jansel --- test/dynamo/test_functions.py | 26 ++++++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 8 ++++++++ torch/_dynamo/variables/dicts.py | 20 ++++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index fdb23e3f590f..472e9c56bae6 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1164,6 +1164,32 @@ def test_tuple_contains(a, b): return a + b return a - b + @unittest.skipIf( + sys.version_info < (3, 9), + "SET_UPDATE was added at Python 3.9", + ) + @make_test + def test_set_update_bytecode(x): + # This produces bytecode SET_UPDATE since python 3.9 + var = {"apple", "banana", "cherry"} + if isinstance(var, set): + return x + 1 + else: + return x - 1 + + @unittest.skipIf( + sys.version_info < (3, 9), + "SET_UPDATE was added at Python 3.9", + ) + @make_test + def test_set_update_list_with_duplicated_items(x): + list1 = ["apple", "banana", "apple"] + list2 = ["orange", "banana"] + if len({*list1, *list2}) == 3: + return x + 1 + else: + return x - 1 + @make_test def test_set_contains(a, b): vals = set(["a", "b", "c"]) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index d6fb3e2145b7..4b4d6d3de675 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1504,6 +1504,14 @@ def SET_ADD(self, inst): assert obj.mutable_local return obj.call_method(self, "add", [v], {}) + def SET_UPDATE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.mutable_local + obj.call_method(self, "update", [v], {}) + def LIST_APPEND(self, inst): v = self.pop() assert inst.argval > 0 diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index c8eabc2c8879..0724a80621f7 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -407,6 +407,8 @@ def call_method( args: List[VariableTracker], kwargs: Dict[str, VariableTracker], ) -> "VariableTracker": + from . import ListVariable, TupleVariable + # We foward the calls to the dictionary model if name == "add": assert not kwargs @@ -426,6 +428,24 @@ def call_method( return variables.UserFunctionVariable( polyfill.set_isdisjoint ).call_function(tx, [self, args[0]], {}) + elif ( + name == "update" + and len(args) == 1 + and isinstance( + args[0], + ( + SetVariable, + ListVariable, + TupleVariable, + ), + ) + and self.mutable_local + ): + if isinstance(args[0], (ListVariable, TupleVariable)): + arg = SetVariable(args[0].unpack_var_sequence(tx)) + else: + arg = args[0] + return super().call_method(tx, "update", (arg,), kwargs) return super().call_method(tx, name, args, kwargs) def getitem_const(self, arg: VariableTracker): From 8f0c207e187d3e62943af1598f4545771763f634 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Thu, 16 May 2024 20:22:17 +0000 Subject: [PATCH 018/116] xpu: implement xpu serialization (#125530) Fixes: #125529 BC-breaking note: The deprecated "async" argument to the Storage.cuda and Storage.hpu has been removed. Use non_blocking instead. CC: @jbschlosser, @frank-wei @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @albanD Pull Request resolved: https://github.com/pytorch/pytorch/pull/125530 Approved by: https://github.com/guangyey, https://github.com/albanD --- test/test_serialization.py | 40 +++++++++---- test/test_xpu.py | 35 +++++++++++ torch/_utils.py | 79 ++++++++---------------- torch/serialization.py | 120 +++++++++---------------------------- torch/storage.py | 60 +++++++++++++++---- 5 files changed, 164 insertions(+), 170 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 49f8880885ec..5c6b78b44564 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -493,6 +493,15 @@ def test_serialization_map_location(self): def map_location(storage, loc): return storage + def generate_map_locations(device_type): + return [ + {'cuda:0': device_type + ':0'}, + device_type, + device_type + ':0', + torch.device(device_type), + torch.device(device_type, 0) + ] + def load_bytes(): with open(test_file_path, 'rb') as f: return io.BytesIO(f.read()) @@ -504,34 +513,39 @@ def load_bytes(): 'cpu', torch.device('cpu'), ] - gpu_0_map_locations = [ - {'cuda:0': 'cuda:0'}, - 'cuda', - 'cuda:0', - torch.device('cuda'), - torch.device('cuda', 0) - ] + gpu_0_map_locations = generate_map_locations('cuda') gpu_last_map_locations = [ f'cuda:{torch.cuda.device_count() - 1}', ] + xpu_0_map_locations = generate_map_locations('xpu') + xpu_last_map_locations = [ + f'xpu:{torch.xpu.device_count() - 1}', + ] - def check_map_locations(map_locations, tensor_class, intended_device): + def check_map_locations(map_locations, dtype, intended_device): for fileobject_lambda in fileobject_lambdas: for map_location in map_locations: tensor = torch.load(fileobject_lambda(), map_location=map_location) self.assertEqual(tensor.device, intended_device) - self.assertIsInstance(tensor, tensor_class) - self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]])) + self.assertEqual(tensor.dtype, dtype) + self.assertEqual(tensor, torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=dtype, device=intended_device)) - check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu')) + check_map_locations(cpu_map_locations, torch.float, torch.device('cpu')) if torch.cuda.is_available(): - check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0)) + check_map_locations(gpu_0_map_locations, torch.float, torch.device('cuda', 0)) check_map_locations( gpu_last_map_locations, - torch.cuda.FloatTensor, + torch.float, torch.device('cuda', torch.cuda.device_count() - 1) ) + if torch.xpu.is_available(): + check_map_locations(xpu_0_map_locations, torch.float, torch.device('xpu', 0)) + check_map_locations( + xpu_last_map_locations, + torch.float, + torch.device('xpu', torch.xpu.device_count() - 1) + ) @unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine") def test_load_nonexistent_device(self): diff --git a/test/test_xpu.py b/test/test_xpu.py index 74cc891a9e62..a3838f1d5a05 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,6 +1,7 @@ # Owner(s): ["module: intel"] import sys +import tempfile import unittest import torch @@ -270,6 +271,40 @@ def convert_boolean_tensors(x): self.assertEqual(expect, actual) + def test_serialization_array_with_storage(self): + x = torch.randn(5, 5).xpu() + y = torch.zeros(2, 5, dtype=torch.int, device="xpu") + q = [x, y, x, y.storage()] + with tempfile.NamedTemporaryFile() as f: + torch.save(q, f) + f.seek(0) + q_copy = torch.load(f) + self.assertEqual(q_copy, q, atol=0, rtol=0) + q_copy[0].fill_(5) + self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0) + self.assertEqual(q_copy[0].dtype, torch.float) + self.assertEqual(q_copy[1].dtype, torch.int) + self.assertEqual(q_copy[2].dtype, torch.float) + self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) + self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage)) + q_copy[1].fill_(10) + y.fill_(10) + self.assertEqual(q_copy[3], y.storage()) + + def test_serialization_array_with_empty(self): + x = [ + torch.randn(4, 4).xpu(), + torch.tensor([], dtype=torch.float, device=torch.device("xpu")), + ] + with tempfile.NamedTemporaryFile() as f: + torch.save(x, f) + f.seek(0) + x_copy = torch.load(f) + for original, copy in zip(x, x_copy): + self.assertEqual(copy, original) + self.assertIs(type(copy), type(original)) + self.assertEqual(copy.get_device(), original.get_device()) + instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") diff --git a/torch/_utils.py b/torch/_utils.py index 2e48fe9a1a9d..1bb726252dee 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -52,71 +52,40 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): return dtype(self.size()).copy_(self, non_blocking) -def _hpu(self, device=None, non_blocking=False, **kwargs): - """Returns a copy of this object in HPU memory. +def _to(self, device, non_blocking=False): + """Returns a copy of this object in device memory. - If this object is already in HPU memory and on the correct device, then - no copy is performed and the original object is returned. + If this object is already on the correct device, then no copy is performed + and the original object is returned. Args: - device (int): The destination HPU id. Defaults to the current device. + device (int): The destination device. non_blocking (bool): If ``True`` and the source is in pinned memory, the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. - **kwargs: For compatibility, may contain the key ``async`` in place of - the ``non_blocking`` argument. """ - non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs) - hpu = getattr(torch, "hpu", None) - assert hpu is not None, "HPU device module is not loaded" - if self.is_hpu: - if device is None: - device = hpu.current_device() - if self.get_device() == device: - return self - else: - if device is None: - device = -1 - with hpu.device(device): - assert not self.is_sparse, "sparse storage is not supported for HPU tensors" - untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu")) - untyped_storage.copy_(self, non_blocking) - return untyped_storage - - -def _cuda(self, device=None, non_blocking=False, **kwargs): - """Returns a copy of this object in CUDA memory. - - If this object is already in CUDA memory and on the correct device, then - no copy is performed and the original object is returned. + if self.device == device: + return self - Args: - device (int): The destination GPU id. Defaults to the current device. - non_blocking (bool): If ``True`` and the source is in pinned memory, - the copy will be asynchronous with respect to the host. Otherwise, - the argument has no effect. - **kwargs: For compatibility, may contain the key ``async`` in place of - the ``non_blocking`` argument. - """ - non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs) - if self.is_cuda: - if device is None: - device = torch.cuda.current_device() - if self.get_device() == device: - return self - else: - if device is None: - device = -1 - with torch.cuda.device(device): - if self.is_sparse: - new_type = getattr(torch.cuda.sparse, self.__class__.__name__) - indices = torch.Tensor._indices(self).cuda(device, non_blocking) - values = torch.Tensor._values(self).cuda(device, non_blocking) + device_module = getattr(torch, device.type, None) + assert ( + device_module is not None + ), f"{device.type.upper()} device module is not loaded" + with device_module.device(device): + if self.is_sparse and hasattr(device_module, "sparse"): + new_type = getattr(device_module.sparse, self.__class__.__name__) + indices = getattr(torch.Tensor._indices(self), device.type)( + device, non_blocking + ) + values = getattr(torch.Tensor._values(self), device.type)( + device, non_blocking + ) return new_type(indices, values, self.size()) else: - untyped_storage = torch.UntypedStorage( - self.size(), device=torch.device("cuda") - ) + assert ( + not self.is_sparse + ), f"sparse storage is not supported for {device.type.upper()} tensors" + untyped_storage = torch.UntypedStorage(self.size(), device=device) untyped_storage.copy_(self, non_blocking) return untyped_storage diff --git a/torch/serialization.py b/torch/serialization.py index df839408ee77..616c21e80d7f 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1,4 +1,5 @@ import difflib +import functools import os import io import shutil @@ -252,14 +253,6 @@ def _cpu_tag(obj): return 'cpu' -def _cuda_tag(obj): - if obj.device.type == 'cuda': - return 'cuda:' + str(obj.device.index) - -def _hpu_tag(obj): - if obj.device.type == 'hpu': - return 'hpu:' + str(obj.device.index) - def _mps_tag(obj): if obj.device.type == 'mps': return 'mps' @@ -270,8 +263,9 @@ def _meta_tag(obj): return 'meta' -def _privateuse1_tag(obj): - backend_name = torch._C._get_privateuse1_backend_name() +def _backend_tag(backend_name, obj): + if backend_name == 'privateuse1': + backend_name = torch._C._get_privateuse1_backend_name() if obj.device.type == backend_name: if obj.device.index is None: return backend_name @@ -284,66 +278,6 @@ def _cpu_deserialize(obj, location): return obj -def validate_cuda_device(location): - device = torch.cuda._utils._get_device_index(location, True) - - if not torch.cuda.is_available(): - raise RuntimeError('Attempting to deserialize object on a CUDA ' - 'device but torch.cuda.is_available() is False. ' - 'If you are running on a CPU-only machine, ' - 'please use torch.load with map_location=torch.device(\'cpu\') ' - 'to map your storages to the CPU.') - device_count = torch.cuda.device_count() - if device >= device_count: - raise RuntimeError('Attempting to deserialize object on CUDA device ' - f'{device} but torch.cuda.device_count() is {device_count}. Please use ' - 'torch.load with map_location to map your storages ' - 'to an existing device.') - return device - - -def _cuda_deserialize(obj, location): - if location.startswith('cuda'): - device = validate_cuda_device(location) - if getattr(obj, "_torch_load_uninitialized", False): - with torch.cuda.device(device): - return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) - else: - return obj.cuda(device) - - -def validate_hpu_device(location): - hpu = getattr(torch, "hpu", None) - assert hpu is not None, "HPU device module is not loaded" - device = hpu._utils._get_device_index(location, optional=True) - - if not hpu.is_available(): - raise RuntimeError('Attempting to deserialize object on a HPU ' - 'device but torch.hpu.is_available() is False. ' - 'If you are running on a CPU-only machine, ' - 'please use torch.load with map_location=torch.device(\'cpu\') ' - 'to map your storages to the CPU.') - device_count = hpu.device_count() - if device >= device_count: - raise RuntimeError('Attempting to deserialize object on HPU device ' - f'{device} but torch.hpu.device_count() is {device_count}. Please use ' - 'torch.load with map_location to map your storages ' - 'to an existing device.') - return device - - -def _hpu_deserialize(obj, location): - if location.startswith('hpu'): - hpu = getattr(torch, "hpu", None) - assert hpu is not None, "HPU device module is not loaded" - device = validate_hpu_device(location) - if getattr(obj, "_torch_load_uninitialized", False): - with hpu.device(device): - return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) - else: - return obj.hpu(device) - - def _mps_deserialize(obj, location): if location.startswith('mps'): return obj.mps() @@ -354,18 +288,18 @@ def _meta_deserialize(obj, location): return torch.UntypedStorage(obj.nbytes(), device='meta') -def _validate_privateuse1_device(location, backend_name): +def _validate_device(location, backend_name): ''' - Check whether the device index of privateuse1 is valid + Check whether the device index of specified backend is valid - Register a device_module of privateuse1 by torch._register_device_module. - Implement the following methods in device_module like cuda: - device_module._utils._get_device_index(location, True), + In case of privateuse1 backend, your must first register a device_module for + privateuse1 using torch._register_device_module. Implement the following + methods in device_module like cuda: device_module._utils._get_device_index(location, True), device_module.device_count(). Args: location: string of device - backend_name: the name of privateuse1, which can be renamed + backend_name: the backend name or the name of privateuse1, which can be renamed Returns: device_index: int @@ -378,6 +312,7 @@ def _validate_privateuse1_device(location, backend_name): device_module = getattr(torch, backend_name) if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'): device_index = device_module._utils._get_device_index(location, True) + device = torch.device(backend_name, device_index) else: device = torch.device(location) device_index = device.index if device.index else 0 @@ -394,29 +329,32 @@ def _validate_privateuse1_device(location, backend_name): f'{device_index} but torch.{backend_name}.device_count() is {device_count}. ' 'Please use torch.load with map_location to map your storages ' 'to an existing device.') - return device_index + return device + + +def validate_cuda_device(location): + return _validate_device(location, 'cuda').index -def _privateuse1_deserialize(obj, location): - backend_name = torch._C._get_privateuse1_backend_name() +def validate_hpu_device(location): + return _validate_device(location, 'hpu').index + + +def _deserialize(backend_name, obj, location): + if backend_name == 'privateuse1': + backend_name = torch._C._get_privateuse1_backend_name() if location.startswith(backend_name): - if not hasattr(obj, backend_name): - raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device ' - f'but torch.storage._StorageBase.{backend_name}() or ' - f'torch.storage.TypedStorage.{backend_name}() is not generated. ' - 'Please use torch.utils.generate_methods_for_privateuse1_backend ' - f'to generate storage.{backend_name}() method first.') - device_index = _validate_privateuse1_device(location, backend_name) - return getattr(obj, backend_name)(device_index) + device = _validate_device(location, backend_name) + return obj.to(device=device) register_package(10, _cpu_tag, _cpu_deserialize) -register_package(20, _cuda_tag, _cuda_deserialize) +register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda')) register_package(21, _mps_tag, _mps_deserialize) register_package(22, _meta_tag, _meta_deserialize) -register_package(23, _privateuse1_tag, _privateuse1_deserialize) -register_package(24, _hpu_tag, _hpu_deserialize) - +register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1')) +register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu')) +register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu')) def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]): for _, tagger, _ in _package_registry: diff --git a/torch/storage.py b/torch/storage.py index 306dd99a93ad..32070783f494 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -1,7 +1,7 @@ import io import torch -from ._utils import _type, _cuda, _hpu +from ._utils import _type, _to from torch.types import Storage from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union import copy @@ -38,8 +38,37 @@ def size(self) -> int: return self.nbytes() def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 - def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 - def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 + + def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + """Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination GPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = torch.device('cuda', device) if device else torch.device('cuda') + return self.to(device=device2, non_blocking=non_blocking) + + def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + """Returns a copy of this object in HPU memory. + + If this object is already in HPU memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination HPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = torch.device('hpu', device) if device else torch.device('hpu') + return self.to(device=device2, non_blocking=non_blocking) + def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704 def get_device(self) -> int: @@ -153,6 +182,9 @@ def _to(self, dtype): storage = storage.clone() return storage + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] # noqa: E704 + return _to(self, device, non_blocking) + def double(self): """Casts this storage to double type.""" return self._to(torch.double) @@ -382,8 +414,6 @@ def _load_from_bytes(b): _StorageBase.type = _type # type: ignore[assignment] -_StorageBase.cuda = _cuda # type: ignore[assignment] -_StorageBase.hpu = _hpu # type: ignore[assignment] @lru_cache(maxsize=None) @@ -812,20 +842,27 @@ def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Unio else: return self._untyped_storage.type(dtype, non_blocking) - def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var] + def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[misc, type-var] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") - cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs) + cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking) return self._new_wrapped_storage(cuda_storage) - def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var] + def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[misc, type-var] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create HPU storage with quantized dtype") - hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking, **kwargs) + hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking) return self._new_wrapped_storage(hpu_storage) + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] + _warn_typed_storage_removal() + if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: + raise RuntimeError(f"Cannot create {device.type.upper()} storage with quantized dtype") + to_storage: torch.UntypedStorage = self._untyped_storage.to(device=device, non_blocking=non_blocking) + return self._new_wrapped_storage(to_storage) + def element_size(self): _warn_typed_storage_removal() return self._element_size() @@ -1209,8 +1246,9 @@ def _get_legacy_storage_class(self): return None TypedStorage.type.__doc__ = _type.__doc__ -TypedStorage.cuda.__doc__ = _cuda.__doc__ -TypedStorage.hpu.__doc__ = _hpu.__doc__ +TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__ +TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__ +TypedStorage.to.__doc__ = _to.__doc__ class _LegacyStorageMeta(type): dtype: torch.dtype From 866ca4630c6391f9774873b083bbacb84116bc74 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 16 May 2024 08:23:32 -0700 Subject: [PATCH 019/116] Don't install inplace_methods on MockHandler, not needed (#126398) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126398 Approved by: https://github.com/jansel, https://github.com/peterbell10 --- torch/_inductor/ops_handler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 88f9d406c2e1..6da386709997 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -17,7 +17,7 @@ import torch import torch.utils._pytree as pytree -from torch.fx.graph import inplace_methods, magic_methods +from torch.fx.graph import magic_methods from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str T = TypeVar("T") @@ -578,9 +578,7 @@ def inner(*args): return inner - for name, format_string in itertools.chain( - magic_methods.items(), inplace_methods.items() - ): + for name, format_string in itertools.chain(magic_methods.items()): setattr(cls, name, make_handler(format_string)) From 82c66bc41a025c3b6b06083129f3d09949ff93d5 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 16 May 2024 08:20:07 -0700 Subject: [PATCH 020/116] Make 'pytest test/inductor/test_memory_planning.py' work (#126397) There's still another naughty direct test_* import, I'm out of patience right now though. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126397 Approved by: https://github.com/peterbell10, https://github.com/int3 --- test/inductor/test_memory_planning.py | 6 +++--- test/inductor/test_torchinductor.py | 24 +----------------------- torch/_inductor/utils.py | 23 +++++++++++++++++++++++ 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 1bd546e5b4df..1ec1dd9f89e9 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -2,6 +2,8 @@ import sys +import unittest + from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA @@ -13,14 +15,12 @@ sys.exit(0) raise unittest.SkipTest("requires sympy/functorch/filelock") # noqa: F821 -import unittest - import torch -from test_torchinductor import run_and_get_cpp_code from torch._C import FileCheck from torch._dynamo.utils import same from torch._inductor import config from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_cpp_code from torch.export import Dim from torch.utils._triton import has_triton diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3a7b66d66065..695e3ebfe896 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -46,6 +46,7 @@ aoti_eager_cache_dir, load_aoti_eager_cache, run_and_get_code, + run_and_get_cpp_code, run_and_get_triton_code, ) from torch._inductor.virtualized import V @@ -342,29 +343,6 @@ def clone_preserve_strides(x, device=None): return out -def run_and_get_cpp_code(fn, *args, **kwargs): - # We use the patch context manager instead of using it as a decorator. - # In this way, we can ensure that the attribute is patched and unpatched correctly - # even if this run_and_get_cpp_code function is called multiple times. - with patch.object(config, "debug", True): - torch._dynamo.reset() - import io - import logging - - log_capture_string = io.StringIO() - ch = logging.StreamHandler(log_capture_string) - from torch._inductor.graph import output_code_log - - output_code_log.addHandler(ch) - prev_level = output_code_log.level - output_code_log.setLevel(logging.DEBUG) - result = fn(*args, **kwargs) - s = log_capture_string.getvalue() - output_code_log.setLevel(prev_level) - output_code_log.removeHandler(ch) - return result, s - - def check_model( self: TestCase, model, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index bcf586862a71..caca6eaf2e21 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1696,3 +1696,26 @@ def aoti_compile_with_persistent_cache( return kernel_lib_path except Exception as e: return "" + + +def run_and_get_cpp_code(fn, *args, **kwargs): + # We use the patch context manager instead of using it as a decorator. + # In this way, we can ensure that the attribute is patched and unpatched correctly + # even if this run_and_get_cpp_code function is called multiple times. + with unittest.mock.patch.object(config, "debug", True): + torch._dynamo.reset() + import io + import logging + + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + from torch._inductor.graph import output_code_log + + output_code_log.addHandler(ch) + prev_level = output_code_log.level + output_code_log.setLevel(logging.DEBUG) + result = fn(*args, **kwargs) + s = log_capture_string.getvalue() + output_code_log.setLevel(prev_level) + output_code_log.removeHandler(ch) + return result, s From 4f1a56cd425fd0959f5a0b51b07550e9f0449336 Mon Sep 17 00:00:00 2001 From: Tobias Ringwald Date: Thu, 16 May 2024 20:58:24 +0000 Subject: [PATCH 021/116] Switched from parameter in can_cast to from_. (#126030) Fixes #126012. `from` is a reserved keyword in Python, thus we can't make the C++ impl available with `from` as function parameter. This PR changes the name to `from_` and also adjusts the docs. If we want to preserve backwards compatibility, we can leave the C++ name as-is and only fix the docs. However, `torch.can_cast(from_=torch.int, to=torch.int)` won't work then. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126030 Approved by: https://github.com/albanD --- aten/src/ATen/native/TypeProperties.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 2 +- .../check_forward_backward_compatibility.py | 2 ++ torch/_torch_docs.py | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 7091e4f78aef..4afc7619c2eb 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -191,8 +191,8 @@ ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) { return result_type(state); } -bool can_cast(const at::ScalarType from, const at::ScalarType to) { - return at::canCast(from, to); +bool can_cast(const at::ScalarType from_, const at::ScalarType to) { + return at::canCast(from_, to); } ScalarType promote_types(ScalarType type1, ScalarType type2) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8cf229c69c23..10d8b1ad79ca 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7714,7 +7714,7 @@ - func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType -- func: can_cast(ScalarType from, ScalarType to) -> bool +- func: can_cast(ScalarType from_, ScalarType to) -> bool variants: function - func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 285e410a79ed..81b85a4fe42f 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -140,6 +140,8 @@ ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)), + # BC-breaking change in can_cast signature: 'from' -> 'from_' + ("aten::can_cast", datetime.date(2024, 5, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ba8e899dc943..6d22f9dcf984 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2195,13 +2195,13 @@ def merge_dicts(*dicts): add_docstr( torch.can_cast, r""" -can_cast(from, to) -> bool +can_cast(from_, to) -> bool Determines if a type conversion is allowed under PyTorch casting rules described in the type promotion :ref:`documentation `. Args: - from (dtype): The original :class:`torch.dtype`. + from\_ (dtype): The original :class:`torch.dtype`. to (dtype): The target :class:`torch.dtype`. Example:: From f5abf28e414f46a32c8626c7954f40dc67cb3253 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 16 May 2024 21:36:56 +0000 Subject: [PATCH 022/116] [Traceable FSDP2] Use DTensor.from_local() in _from_local_no_grad when compile (#126346) As discussed before, for now Dynamo is not able to support DTensor constructor, and instead we have to use `DTensor.from_local()`. This won't affect eager and it's a compile-only change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126346 Approved by: https://github.com/awgu --- .../_composable/fsdp/_fsdp_common.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 94b024917769..1395e3487847 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -117,20 +117,29 @@ def _from_local_no_grad( global_stride: Tuple[int, ...], ) -> DTensor: """ - This method is similar to ``DTensor.from_local()`` except it avoids some - CPU overhead by avoiding default args and not being differentiable. + This method is similar to ``DTensor.from_local()`` except that in eager mode + it avoids some CPU overhead by avoiding default args and not being differentiable. """ - return DTensor( - # Use the local tensor directly instead of constructing a new tensor - # variable, e.g. with `view_as()`, since this is not differentiable - local_tensor, - device_mesh, - placements, - shape=global_size, - dtype=local_tensor.dtype, - requires_grad=local_tensor.requires_grad, - stride=global_stride, - ) + if not torch._dynamo.compiled_autograd.compiled_autograd_enabled: + return DTensor( + # Use the local tensor directly instead of constructing a new tensor + # variable, e.g. with `view_as()`, since this is not differentiable + local_tensor, + device_mesh, + placements, + shape=global_size, + dtype=local_tensor.dtype, + requires_grad=local_tensor.requires_grad, + stride=global_stride, + ) + else: + return DTensor.from_local( + local_tensor, + device_mesh, + placements, + shape=global_size, + stride=global_stride, + ) def _to_dtype_if_needed( From e9719aec30008cefde4a933ed53717469ab95a41 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Thu, 16 May 2024 21:42:53 +0000 Subject: [PATCH 023/116] Fix strict default value in StateDictOptions (#125998) Fixes #125992 The default value of the parameter `strict` should be `True`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125998 Approved by: https://github.com/fegin --- torch/distributed/checkpoint/state_dict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index e0a4d8886fc7..e7072d623012 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -135,7 +135,6 @@ class StateDictOptions: - ``strict``: the ``strict`` option when ``set_state_dict`` calls model.load_state_dict(). - The default value is False. - ``broadcast_from_rank0``: when the option is True, rank0 should receive a full state_dict and will broadcast the tensors in the state_dict/ From 4b7eee34509bbf43280bd1fc0167ae8320fec52a Mon Sep 17 00:00:00 2001 From: Tarun Karuturi Date: Thu, 16 May 2024 21:55:11 +0000 Subject: [PATCH 024/116] Print export warning only once in capture_pre_autograd (#126403) Summary: Missed this in D57163341 Test Plan: CI Differential Revision: D57442088 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126403 Approved by: https://github.com/zhxchen17 --- torch/_export/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 05aee5f28f66..105a7ee2594b 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -149,7 +149,10 @@ def capture_pre_autograd_graph( kwargs = {} if export_api_rollout_check(): - log.warning("Using torch.export._trace._export") + @lru_cache + def print_export_warning(): + log.warning("Using torch.export._trace._export") + print_export_warning() module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module() else: log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) From 4cd4463c1cf4d3d5cdaed61bd8563977d989e283 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 16 May 2024 10:30:19 -0700 Subject: [PATCH 025/116] [compiled autograd] Fix LoggingTensor flaky test (#126144) LoggingTensor fails consistently when root logger level is INFO or lower By default, root logger should be WARNING But, triton driver initialization will overwrite root logger to INFO, which causes flakiness: https://github.com/pytorch/pytorch/issues/126143 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126144 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 56 +++++++++++++++++++++++ torch/_dynamo/compiled_autograd.py | 7 +++ torch/testing/_internal/logging_tensor.py | 6 +-- 3 files changed, 66 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 074d075fc848..201dd4a3c77d 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import functools +import logging import re import sys import unittest @@ -51,6 +52,14 @@ def hook3(gI, gO): class TestCompiledAutograd(TestCase): + def setUp(self) -> None: + super().setUp() + compiled_autograd.reset() + + def tearDown(self) -> None: + super().tearDown() + compiled_autograd.reset() + def check_output_and_recompiles( self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False ): @@ -322,6 +331,7 @@ def bytecode_hook(code, out_code): handle.remove() def test_inputs_aliasing_bytecode_stack_restore(self): + logging.getLogger().setLevel(logging.WARNING) from torch.testing._internal.logging_tensor import LoggingTensor # Create a graph that allows inputs stealing @@ -752,6 +762,52 @@ def backward(ctx, gO_1, gO_2, gO_3): self.check_output_and_recompiles(fn, count=2) + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_logging_tensor_flaky(self) -> None: + # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore + # resulting in: + # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'` + # - python: `TypeError: not all arguments converted during string formatting` + + # 1. some triton involving test + def fn(): + def _fn(x): + return x + + x = torch.arange( + 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" + ) + out = _fn(x) + loss = out.sum() + loss.backward() + + with compiled_autograd.enable(compiler_fn): + fn() + + logging.getLogger().setLevel( + logging.WARNING + ) # triton setup overwrote it to INFO + # 2. test_inputs_aliasing_bytecode_stack_restore + from torch.testing._internal.logging_tensor import LoggingTensor + + def forward(inputs): + add = inputs[0] + 1 + add_1 = add + inputs[1] + out = add_1.cpu() + return (out,) + + gm = torch.fx.symbolic_trace(forward) + print(gm.print_readable()) + torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"]) + compiled_fn = torch.compile(gm) + + inputs = [ + torch.ones(1000000, dtype=torch.float32), + LoggingTensor(torch.ones(1)), + ] + + compiled_fn(inputs) + @unittest.skipIf(not HAS_CUDA, "requires cuda") def test_custom_fn_output_metadata(self): def my_compiler_fn(gm): diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e8e61042d474..386d0b4dd4ae 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -319,3 +319,10 @@ def disable(): if prior: compiled_autograd_enabled = True torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + + +# return to starting state of a new process +def reset() -> None: + compiled_autograd_enable = False + assert compiled_autograd_enabled_count == 0 + torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) diff --git a/torch/testing/_internal/logging_tensor.py b/torch/testing/_internal/logging_tensor.py index 5ddd53747440..8b7faf45b3c3 100644 --- a/torch/testing/_internal/logging_tensor.py +++ b/torch/testing/_internal/logging_tensor.py @@ -11,6 +11,7 @@ import functools from torch._C._profiler import gather_traceback, symbolize_tracebacks +logger = logging.getLogger("LoggingTensor") _dtype_abbrs = { torch.bfloat16: "bf16", @@ -135,8 +136,8 @@ def emit(self, record): if self.tracebacks_list is not None: self.tracebacks_list.append(record.traceback) -def log_input(name: str, var: object): - logging.getLogger("LoggingTensor").info("input", (name,), {}, var) # noqa: PLE1205 +def log_input(name: str, var: object) -> None: + logger.info("input", (name,), {}, var) # noqa: PLE1205 class GatherTraceback(logging.Filter): def __init__(self, python=True, script=True, cpp=False): @@ -151,7 +152,6 @@ def filter(self, record): @contextlib.contextmanager def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]: collect_traceback = python_tb or script_tb or cpp_tb - logger = logging.getLogger("LoggingTensor") log_list: List[str] = [] tracebacks_list: List[str] = [] handler = LoggingTensorHandler( From cef7756c9cf1a5ff53241e9db47b06fbdef17ad8 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 16 May 2024 10:30:20 -0700 Subject: [PATCH 026/116] [inductor] Clear cache on ctx manager exit (#126146) FIXES https://github.com/pytorch/pytorch/issues/126128. Right now, we only clear the cache on ctx manager enter. So state is bad unless we call fresh_inductor_cache again, usually fine in tests. Cue compiled autograd tests when going from TestCompiledAutograd -> TestAutogradWithCompiledAutograd. TestCompiledAutograd uses the ctx manager, but TestAutogradWithCompiledAutograd don't Pull Request resolved: https://github.com/pytorch/pytorch/pull/126146 Approved by: https://github.com/jgong5, https://github.com/oulgen ghstack dependencies: #126144 --- test/dynamo/test_repros.py | 16 ++++++++++++++++ torch/_inductor/utils.py | 2 ++ 2 files changed, 18 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index d28b67f3aa94..ff229c06432f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4971,6 +4971,22 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager") opt_fn(np.ones([3, 3])) + def test_issue126128(self): + def fn(): + x = torch.randn(1, 10) + y = torch.randn(10, 1) + return torch.mm(x, y).sum() + + def fn2(): + x = torch.randn(10, 100) + y = torch.randn(100, 10) + return torch.mm(x, y).sum() + + with torch._inductor.utils.fresh_inductor_cache(): + torch.compile(fn)() + + torch.compile(fn2)() + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index caca6eaf2e21..f0487c9025d5 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -726,6 +726,8 @@ def fresh_inductor_cache(cache_entries=None): except Exception: log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) raise + finally: + clear_inductor_caches() def argsort(seq) -> List[int]: From 93524cf5ffd4950d93474c1430c0373810ec1034 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 16 May 2024 10:30:20 -0700 Subject: [PATCH 027/116] [compiled autograd] clear compiled_autograd_verbose once test is done (#126148) verbose flag leaks into tests ran after Pull Request resolved: https://github.com/pytorch/pytorch/pull/126148 Approved by: https://github.com/jansel ghstack dependencies: #126144, #126146 --- torch/_dynamo/compiled_autograd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 386d0b4dd4ae..f9cf03947a8c 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -326,3 +326,4 @@ def reset() -> None: compiled_autograd_enable = False assert compiled_autograd_enabled_count == 0 torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + torch._C._dynamo.compiled_autograd.set_verbose_logging(False) From f17572fcf63168d4d27ec794a03091c0e0acdbea Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 15 May 2024 11:39:51 -0700 Subject: [PATCH 028/116] add 3.12 inductor CI tests (#126218) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126218 Approved by: https://github.com/huydhn, https://github.com/desertfire --- .ci/docker/build.sh | 15 +++++++++++++++ .github/workflows/docker-builds.yml | 1 + .github/workflows/inductor.yml | 21 +++++++++++++++++++++ test/inductor/test_torchinductor.py | 5 +++++ torch/_dynamo/testing.py | 6 ++++++ 5 files changed, 48 insertions(+) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 73e3f09394b7..8786471a7bdd 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -149,6 +149,21 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; + pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks) + CUDA_VERSION=12.1.1 + CUDNN_VERSION=8 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=9 + PROTOBUF=yes + DB=yes + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + CONDA_CMAKE=yes + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9) CUDA_VERSION=11.8.0 CUDNN_VERSION=8 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 9f0dfe973dc9..bb356dce5da9 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -42,6 +42,7 @@ jobs: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9, pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9, pytorch-linux-focal-py3.8-clang10, pytorch-linux-focal-py3.11-clang10, diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index c00630b5e8b6..3d1c3a539686 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -107,6 +107,27 @@ jobs: secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + linux-focal-cuda12_1-py3_12-gcc9-inductor-build: + name: cuda12.1-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_1-py3_12-gcc9-inductor-test: + name: cuda12.1-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_1-py3_12-gcc9-inductor-build + with: + build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.test-matrix }} + linux-jammy-cpu-py3_8-gcc11-inductor-build: name: linux-jammy-cpu-py3.8-gcc11-inductor uses: ./.github/workflows/_linux-build.yml diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 695e3ebfe896..73779d22bd42 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -36,6 +36,8 @@ expectedFailureCodegenDynamic, rand_strided, same, + skipIfPy312, + xfailIfPy312, ) from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext from torch._inductor.fx_passes import pad_mm @@ -2721,6 +2723,7 @@ def fn(a, b): check_lowp=False, ) + @skipIfPy312 # segfaults @config.patch(force_mixed_mm=True) def test_mixed_mm(self): def fn(a, b): @@ -2735,6 +2738,7 @@ def fn(a, b): check_lowp=True, ) + @skipIfPy312 # segfaults @config.patch(force_mixed_mm=True) def test_mixed_mm2(self): def fn(a, b, scale, bias): @@ -9426,6 +9430,7 @@ def fn(inp, offsets): self.common(fn, (inp, offsets), check_lowp=False) + @xfailIfPy312 @requires_gpu() @config.patch(assume_aligned_inputs=False) def test_config_option_dont_assume_alignment(self): diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index b4c022e8d8c2..9e9abe84228b 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -349,6 +349,12 @@ def xfailIfPy312(fn): return fn +def skipIfPy312(fn): + if sys.version_info >= (3, 12): + return unittest.skip(fn) + return fn + + # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py def expectedFailureDynamic(fn): From c226839f5cc464559ea0df9a82a7787ef3d5f71b Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Thu, 16 May 2024 22:37:45 +0000 Subject: [PATCH 029/116] Eliminate some C++11 checks (#126308) Test Plan: Sandcastle Reviewed By: palmje Differential Revision: D57246912 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126308 Approved by: https://github.com/Skylion007 --- c10/util/Float8_e4m3fn.h | 2 +- c10/util/Float8_e4m3fnuz.h | 2 +- c10/util/Float8_e5m2fnuz.h | 2 +- c10/util/Half.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/c10/util/Float8_e4m3fn.h b/c10/util/Float8_e4m3fn.h index 8e05e2e43bb0..e7a59e343c1f 100644 --- a/c10/util/Float8_e4m3fn.h +++ b/c10/util/Float8_e4m3fn.h @@ -19,7 +19,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #include #elif !defined(__OPENCL_VERSION__) diff --git a/c10/util/Float8_e4m3fnuz.h b/c10/util/Float8_e4m3fnuz.h index 86ece9ebdadb..cf73b322e899 100644 --- a/c10/util/Float8_e4m3fnuz.h +++ b/c10/util/Float8_e4m3fnuz.h @@ -22,7 +22,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include diff --git a/c10/util/Float8_e5m2fnuz.h b/c10/util/Float8_e5m2fnuz.h index f63773914c11..145464e2cfff 100644 --- a/c10/util/Float8_e5m2fnuz.h +++ b/c10/util/Float8_e5m2fnuz.h @@ -21,7 +21,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include diff --git a/c10/util/Half.h b/c10/util/Half.h index 3d5a38cb365c..af3435941e48 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -16,7 +16,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include From 62403b57b9810cf47955a101a1d642e244722358 Mon Sep 17 00:00:00 2001 From: Hongyang Zhao Date: Thu, 16 May 2024 22:38:05 +0000 Subject: [PATCH 030/116] Add prefix option to CapabilityBasedPartitioner (#126382) Summary: Add prefix arg so that users can provide the submodule name to partitioner. Test Plan: https://fburl.com/anp/2kue4qp9 Differential Revision: D57416926 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126382 Approved by: https://github.com/SherlockNoMad --- torch/fx/passes/infra/partitioner.py | 12 ++++++++---- torch/fx/passes/utils/fuser_utils.py | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 7b36918928d3..3952bb652517 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -262,10 +262,14 @@ def _update_partition_map(node: Node, id: int): return [partition for partition in partitions_by_id.values() if partition.size() > 0] - def fuse_partitions(self, partitions: List[Partition]) -> GraphModule: + def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule: logger.debug("Fusing partitions...") # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] - return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions]) + return fuse_by_partitions( + self.graph_module, + [list(partition.nodes) for partition in partitions], + prefix=prefix, + ) # remove non-compute-ops that sits at the boundary of a partition. def remove_bookend_non_compute_ops(self, partitions: List[Partition]): @@ -323,7 +327,7 @@ def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: if len(remove_node) != 0: partition.nodes = partition.nodes - remove_node - def partition_and_fuse(self) -> GraphModule: + def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: partitions = self.propose_partitions() - fused_gm = self.fuse_partitions(partitions) + fused_gm = self.fuse_partitions(partitions, prefix=prefix) return fused_gm diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 8976690ed73a..3423ea3dad5a 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -218,11 +218,11 @@ def erase_nodes(gm: GraphModule, nodes: NodeList): @compatibility(is_backward_compatible=False) -def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule: +def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule: for partition_id, nodes in enumerate(partitions): sorted_nodes = topo_sort(nodes) - submodule_name = "fused_" + str(partition_id) + submodule_name = prefix + str(partition_id) sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) From 796dff7147e491ecc13ee66c6ca49ce00d965509 Mon Sep 17 00:00:00 2001 From: Matthias Braun Date: Thu, 16 May 2024 22:51:26 +0000 Subject: [PATCH 031/116] Import MKL via //third-party/mkl targets (#126371) Summary: This is a step towards upgrading the MKL library and using a buckified targets rather than importing from TP2. - Add new `//third-party/mkl:mkl_xxx` targets that are currently aliases to `third-party//IntelComposerXE:mkl_xxx`. - Switch usage of `external_deps = [("IntelComposerXE", None, "mkl_xxx")]` to `deps = ["fbsource//third-party/mkl:mkl_xxx"]` Note that this only changes references to `mkl_xxx` references in `IntelComposerXE` but not references to "svml" or "ipp*". Test Plan: sandcastle Differential Revision: D57360438 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126371 Approved by: https://github.com/bertmaher --- defs.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/defs.bzl b/defs.bzl index 6c32f5f9c8b4..d2978f3bfb97 100644 --- a/defs.bzl +++ b/defs.bzl @@ -1,7 +1,7 @@ def get_blas_gomp_arch_deps(): return [ ("x86_64", [ - "third-party//IntelComposerXE:{}".format(native.read_config("fbcode", "mkl_lp64", "mkl_lp64_omp")), + "fbsource//third-party/mkl:{}".format(native.read_config("fbcode", "mkl_lp64", "mkl_lp64_omp")), ]), ("aarch64", [ "third-party//OpenBLAS:OpenBLAS", From 55628624b872bd2ce51903841a65cc385ed7c526 Mon Sep 17 00:00:00 2001 From: Shuqiang Zhang Date: Thu, 16 May 2024 09:21:47 -0700 Subject: [PATCH 032/116] [c10d] add pg_name and pg_desc to logger (#126409) Summary: This should further improve our debuggability Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/126409 Approved by: https://github.com/XilunWu --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index a609da1654b9..7ff75e1bd7f5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1567,6 +1567,8 @@ void ProcessGroupNCCL::watchdogHandler() { data.strings["last_enqueued_work_name"] = lastEnqueuedWorkName_; data.strings["last_started_work_name"] = lastStartedWorkName_; data.strings["last_completed_work_name"] = lastCompletedWorkName_; + data.strings["pg_name"] = pg_name_; + data.strings["pg_desc"] = pg_desc_; logger->log(data); lastStatusUpdateTime = std::chrono::steady_clock::now(); } From cb3b8cd0d330723f649262bfd8ce697ac1c91060 Mon Sep 17 00:00:00 2001 From: David Berard Date: Wed, 15 May 2024 14:03:15 -0700 Subject: [PATCH 033/116] Use object identity for deepcopy memo (#126126) Copy of #126089, with some additional fixes & tests Partial fix for #125635: previously, the deepcopy implementation would group together any tensors with any aliasing relationship and assign them to the same tensor. This was sort of good if you have two tensors `b = a.detach()`, because then if you deepcopy `list = [a, b]` to `list2 = list.deepcopy()`, then writes to `list2[0]` will also modify `list2[1]`. But for the most part, it's bad; (1) if you have `b = a.as_strided((4, 4), (16, 1), 16)`, then it'll make `b == a` in the deepcopied implementation, which is completely wrong; and (2) even if you have `b = a.detach()`, these are still initially two different tensors which become the same tensor after the old deepcopy implementation. The new implementation only groups together tensors that have the same identity. This is a partial fix, but it's more reasonable. What changes: * (becomes more correct): different views of the same base tensor will no longer all become equal after deepcopying * (still kind of wrong): views won't actually alias each other after deepcopying. * (arguably a minor regression): equivalent views of the same tensor will no longer be copied to the same tensor - so they won't alias. BC breaking: C++ deepcopy interface changes from accepting `IValue::HashAliasedIValueMap memo` to accepting `IValue::HashIdentityIValueMap memo`. If there are objections, we can keep the old API. However, it seems likely that users generally won't try to deepcopy from C++. Differential Revision: [D57406306](https://our.internmc.facebook.com/intern/diff/D57406306) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126126 Approved by: https://github.com/ezyang --- aten/src/ATen/core/ivalue.cpp | 8 +-- aten/src/ATen/core/ivalue.h | 19 +++++- aten/src/ATen/core/ivalue_inl.h | 2 +- test/cpp/api/CMakeLists.txt | 1 + test/cpp/api/ivalue.cpp | 63 +++++++++++++++++++ torch/csrc/jit/api/module.cpp | 6 +- torch/csrc/jit/api/module.h | 2 +- .../passes/quantization/insert_observers.cpp | 4 +- torch/csrc/jit/python/script_init.cpp | 4 +- 9 files changed, 95 insertions(+), 14 deletions(-) create mode 100644 test/cpp/api/ivalue.cpp diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6c505f8b656c..3086fa18add6 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -887,12 +887,12 @@ c10::intrusive_ptr ivalue::Object::create( } IValue IValue::deepcopy(std::optional device) const { - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return deepcopy(memo, device); } IValue IValue::deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device) const { if (memo.count(*this)) { return memo.at(*this); @@ -1028,12 +1028,12 @@ c10::intrusive_ptr ivalue::Object::copy_to_weak_compilation_ref( c10::intrusive_ptr ivalue::Object::deepcopy( std::optional device) const { - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return deepcopy(memo, device); } c10::intrusive_ptr ivalue::Object::deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device) const { auto cu = type_.cu_; auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes()); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 7715ffbe3c31..922b10b8efeb 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -1117,6 +1117,23 @@ struct TORCH_API IValue final { using HashAliasedIValueMap = std::unordered_map; + struct HashIdentityIValue { + size_t operator()(const IValue& val) const { + return val.payload.u.as_int; + } + }; + + struct CompIdentityIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.is(rhs); + } + }; + + using HashIdentityIValues = + std::unordered_set; + using HashIdentityIValueMap = + std::unordered_map; + // Chechs if this and rhs has a subvalues in common. // [t1,t2] and [t2, t3] returns true. bool overlaps(const IValue& rhs) const; @@ -1130,7 +1147,7 @@ struct TORCH_API IValue final { void visit(const std::function& visitor) const; IValue deepcopy(std::optional device = c10::nullopt) const; IValue deepcopy( - HashAliasedIValueMap& memo, + HashIdentityIValueMap& memo, std::optional device = c10::nullopt) const; private: diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index b1124c12cfb3..b99229f2759c 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { std::optional device = c10::nullopt) const; c10::intrusive_ptr deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device = c10::nullopt) const; bool is_weak_compilation_ref() const { diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 42b67d8cb25c..b0e296ad2309 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -10,6 +10,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/functional.cpp ${TORCH_API_TEST_DIR}/init.cpp ${TORCH_API_TEST_DIR}/integration.cpp + ${TORCH_API_TEST_DIR}/ivalue.cpp ${TORCH_API_TEST_DIR}/jit.cpp ${TORCH_API_TEST_DIR}/memory.cpp ${TORCH_API_TEST_DIR}/meta_tensor.cpp diff --git a/test/cpp/api/ivalue.cpp b/test/cpp/api/ivalue.cpp new file mode 100644 index 000000000000..fa8dcc25cd4d --- /dev/null +++ b/test/cpp/api/ivalue.cpp @@ -0,0 +1,63 @@ +#include + +#include + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +using namespace torch::test; +using namespace torch::nn; +using namespace torch::optim; + +TEST(IValueTest, DeepcopyTensors) { + torch::Tensor t0 = torch::randn({2, 3}); + torch::Tensor t1 = torch::randn({3, 4}); + torch::Tensor t2 = t0.detach(); + torch::Tensor t3 = t0; + torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2); + std::vector tensor_vector = {t0, t1, t2, t3, t4}; + c10::List tensor_list(tensor_vector); + torch::IValue tensor_list_ivalue(tensor_list); + + c10::IValue::CompIdentityIValues ivalue_compare; + + // Make sure our setup configuration is correct + ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get())); + ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get())); + + c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy(); + c10::List copied_list = copied_ivalue.toList(); + + // Make sure our setup configuration is correct + ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get())); + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get())); + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get())); + ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get())); + // NOTE: this is actually incorrect. Ideally, these _should_ be aliases. + ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get())); + + ASSERT_TRUE(copied_list[0].get().toTensor().allclose( + tensor_list[0].get().toTensor())); + ASSERT_TRUE(copied_list[1].get().toTensor().allclose( + tensor_list[1].get().toTensor())); + ASSERT_TRUE(copied_list[2].get().toTensor().allclose( + tensor_list[2].get().toTensor())); + ASSERT_TRUE(copied_list[3].get().toTensor().allclose( + tensor_list[3].get().toTensor())); + ASSERT_TRUE(copied_list[4].get().toTensor().allclose( + tensor_list[4].get().toTensor())); +} diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 1b9932ed34d4..45b99eb8e47a 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -323,7 +323,7 @@ Module Module::deepcopy(std::optional device) const { Module Module::clone(bool inplace) const { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; const std::unordered_set ignored_methods; const std::unordered_set ignored_attributes; return clone_impl( @@ -335,7 +335,7 @@ Module Module::clone( const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return clone_impl( type_remap, inplace, memo, ignored_methods, ignored_attributes); } @@ -343,7 +343,7 @@ Module Module::clone( Module Module::clone_impl( std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo, + IValue::HashIdentityIValueMap memo, const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const { // Create a new _ivalue in the same compilation unit. diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 0787210a4aef..e779542e315f 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -301,7 +301,7 @@ struct TORCH_API Module : public Object { Module clone_impl( std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo, + IValue::HashIdentityIValueMap memo, const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const; diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index e5df64f1929c..de1cff1ba9d1 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -92,7 +92,7 @@ class ModuleCloneHelper { const ModuleQConfigMap& module_qconfig_map, bool inplace = false) { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return clone_impl( module, module_qconfig_map, type_remap, inplace, std::move(memo)); } @@ -103,7 +103,7 @@ class ModuleCloneHelper { const ModuleQConfigMap& module_qconfig_map, std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo) { + IValue::HashIdentityIValueMap memo) { auto qconfig = module_qconfig_map.at(module._ivalue()); auto type = module.type(); // Create a new _ivalue in the same compilation unit. diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 971b6c76ca47..c46762a88615 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -668,13 +668,13 @@ static constexpr std::array magic_method_names = { }; struct DeepCopyMemoTable { - std::shared_ptr map; + std::shared_ptr map; }; IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) { if (!memo.contains(py::str("__torch_script_memo_table"))) { memo["__torch_script_memo_table"] = - DeepCopyMemoTable{std::make_shared()}; + DeepCopyMemoTable{std::make_shared()}; } auto& ivalue_memo = *py::cast(memo["__torch_script_memo_table"]).map; From 59ca0d8c141d016f5dcfaf76dfcb63949ee888fd Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 May 2024 00:14:59 +0000 Subject: [PATCH 034/116] Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)" This reverts commit 927e631dc2356c0cb600dbdf9e8f84ce792a8ba1. Reverted https://github.com/pytorch/pytorch/pull/126068 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the dependency PR https://github.com/pytorch/pytorch/pull/124021 is going to be revert ([comment](https://github.com/pytorch/pytorch/pull/126019#issuecomment-2116408137)) --- test/inductor/test_cpu_select_algorithm.py | 15 +-- torch/_inductor/codegen/cpp.py | 3 +- torch/_inductor/codegen/cpp_gemm_template.py | 69 +++-------- torch/_inductor/codegen/cpp_micro_gemm.py | 90 ++++---------- .../_inductor/codegen/cpp_template_kernel.py | 113 ++++++++---------- torch/_inductor/codegen/cpp_utils.py | 62 +--------- torch/_inductor/ir.py | 8 +- torch/_inductor/mkldnn_lowerings.py | 100 ++-------------- torch/_inductor/utils.py | 10 +- 9 files changed, 112 insertions(+), 358 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 505ae2f69a4e..75bdff1cba96 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -77,11 +77,11 @@ class TestSelectAlgorithm(TestCase): @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 1000)) - @parametrize("out_features", (1, 1024)) + @parametrize("in_features", (1, 2, 1000)) + @parametrize("out_features", (1, 32, 1024)) @parametrize("bias", (True, False)) @parametrize("input_3d", (True, False)) - @dtypes(torch.float, torch.bfloat16, torch.half) + @dtypes(torch.float) def test_linear_static_shapes( self, batch_size, in_features, out_features, bias, input_3d, dtype ): @@ -97,14 +97,7 @@ def forward(self, x): mod = M(bias=bias).to(dtype=dtype).eval() B = (2, batch_size) if input_3d else (batch_size,) v = torch.randn(*B, in_features).to(dtype=dtype) - # For bfloat16 and half, we have to relax the tolerance - # due to the difference associave orders in different - # kernel implementations - atol, rtol = 1e-4, 1e-4 - if dtype == torch.half or dtype == torch.bfloat16: - atol, rtol = 1e-2, 1e-2 - with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): - self.common(mod, (v,), atol=atol, rtol=rtol) + self.common(mod, (v,)) self.assertEqual( counters["inductor"]["select_algorithm_autotune"], 1 if out_features != 1 else 0, diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 7a026b9b3c6d..40d4f53989af 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2748,8 +2748,9 @@ def store_reduction(self, name, index, value): return self.simd_vec def __exit__(self, exc_type, exc_val, exc_tb): + assert self._orig_wrapper_code is not None # Restore the wrapper_code - V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment] + V.graph.wrapper_code = self._orig_wrapper_code self.exit_stack.__exit__(exc_type, exc_val, exc_tb) def __enter__(self): diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 4d2a640515f5..c664ba7fae45 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -146,7 +146,6 @@ def __init__( beta=1, alpha=1, ): - assert layout.dtype in [torch.float, torch.bfloat16, torch.half] super().__init__("packed_gemm", input_nodes, layout) self.beta = beta self.alpha = alpha @@ -212,13 +211,7 @@ def cache_blocking(self) -> GemmBlocking: @staticmethod def add_choices( - choices, - layout, - input_nodes, - beta=1, - alpha=1, - trans_w=False, - input_indices=None, + choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None ): if input_indices is None: input_indices = list(range(len(input_nodes))) @@ -238,58 +231,28 @@ def reorder_and_filter(inputs, layout_or_out): w_idx = input_indices[2] return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out - def maybe_to_dense(inputs, layout_or_out): - new_inputs = list(inputs) - if isinstance(inputs[1], torch.Tensor): - W = inputs[1] - new_inputs[1] = W.to_dense() if W.is_mkldnn else W - return new_inputs, layout_or_out - - def normalize_shapes(inputs, layout_or_out): + def transpose_weight(inputs, layout_or_out): if not trans_w: return inputs, layout_or_out new_inputs = list(inputs) - X = inputs[0] W = inputs[1] - B = inputs[2] if len(inputs) > 2 else None if isinstance(W, ir.IRNode): - if trans_w: - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - W = L.permute(W, [1, 0]) + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + new_inputs[1] = L.permute(W, [1, 0]) + return new_inputs, layout_or_out else: - if trans_w: - assert isinstance(W, torch.Tensor) - W = W.transpose(0, 1) - if B is not None: - if isinstance(B, ir.IRNode): - if not isinstance(B, ir.TensorBox): - B = ir.TensorBox(B) - B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) - else: - assert isinstance(B, torch.Tensor) - B = B.expand(X.shape[0], B.shape[-1]) - new_inputs[1] = W - if B is not None: - new_inputs[2] = B + assert isinstance(W, torch.Tensor) + new_inputs[1] = W.transpose(0, 1) return new_inputs, layout_or_out # TODO(jgong5): decide proper number of threads per problem size num_threads = parallel_num_threads() - new_inputs, _ = normalize_shapes( - *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) - ) + new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) micro_gemm = create_micro_gemm( - "micro_gemm", - m, - n, - k, - input_dtype=layout.dtype, - output_dtype=torch.float, - alpha=alpha, - num_threads=num_threads, + "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads ) assert micro_gemm is not None _, block_n, _ = micro_gemm.register_blocking @@ -336,9 +299,7 @@ def pack_weight(inputs, layout_or_out): return new_inputs, layout_or_out def preprocessor(inputs, layout): - return pack_weight( - *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) - ) + return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) def postprocessor(output): if isinstance(output, ir.TensorBox): @@ -353,7 +314,7 @@ def postprocessor(output): W = V.graph.constants[W_node.get_name()] new_input_nodes[1] = W new_input_nodes, _ = pack_weight( - *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + *transpose_weight(new_input_nodes, layout) ) W_packed = new_input_nodes[1] W_packed_constant = V.graph.add_tensor_constant(W_packed) @@ -396,7 +357,8 @@ def render( # type: ignore[override] template_buffer = Y Y_is_transposed = False - use_local_acc = self.layout.dtype != torch.float + # TODO(jgong5): support local accumulation + use_local_acc = False if epilogue_nodes: Y = cast(ir.Buffer, epilogue_nodes[-1]) assert Y.get_name() in V.kernel.inplace_update_buffers @@ -408,8 +370,7 @@ def render( # type: ignore[override] self.m, self.n, self.k, - input_dtype=self.layout.dtype, - output_dtype=torch.float, + self.layout.dtype, alpha=self.alpha, num_threads=self.num_threads, ) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 375da4ec1258..353562923c91 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -59,11 +59,7 @@ def __init__( def get_common_options(self): return { - "torch": torch, "kernel_name": self.name, - "input_dtype": self.input_dtype, - "output_dtype": self.output_dtype, - "compute_dtype": self.compute_dtype, "input_t": DTYPE_TO_CPP[self.input_dtype], "output_t": DTYPE_TO_CPP[self.output_dtype], "compute_t": DTYPE_TO_CPP[self.compute_dtype], @@ -140,29 +136,6 @@ def inner(cls): return inner -def generate_gemm_config( - vec_isa_cls, - register_blockings, - input_dtype=torch.float, - output_dtype=None, - compute_dtype=None, -): - if output_dtype is None: - output_dtype = input_dtype - if compute_dtype is None: - compute_dtype = output_dtype - return [ - CppMicroGemmConfig( - input_dtype, - output_dtype, - compute_dtype, - vec_isa_cls, - GemmBlocking(*blocking), - ) - for blocking in register_blockings - ] - - class CppMicroGemmRef(CppMicroGemm): """ A reference implementation of the CppMicroGemm class with naive C++ code. @@ -197,41 +170,28 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str: @register_micro_gemm( - *generate_gemm_config( - VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1) ), - *generate_gemm_config( - VecAVX512, - [(8, 48, 1), (8, 32, 1), (16, 16, 1)], - input_dtype=torch.bfloat16, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1) ), - *generate_gemm_config( - VecAVX512, - [(8, 48, 1), (8, 32, 1), (16, 16, 1)], - input_dtype=torch.half, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1) ), - *generate_gemm_config( - VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1) ), - *generate_gemm_config( - VecAVX2, - [(4, 24, 1), (4, 16, 1), (8, 8, 1)], - input_dtype=torch.bfloat16, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1) ), - *generate_gemm_config( - VecAVX2, - [(4, 24, 1), (4, 16, 1), (8, 8, 1)], - input_dtype=torch.half, - output_dtype=torch.float, + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1) ), ) class CppMicroGemmFP32Vec(CppMicroGemm): """ - This class generates the code for micro gemm using fp32 vec instructions for compute. - It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. + This class generates the code for fp32 micro gemm using vec instructions. """ TEMPLATE_ENTRY = r""" @@ -279,23 +239,22 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_KERNEL = r""" template inline void {{kernel_name}}_kernel( - const {{input_t}}* __restrict__ A, - const {{input_t}}* __restrict__ B, - {{output_t}}* __restrict__ C, + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc ) { - using Vectorized = at::vec::Vectorized<{{compute_t}}>; - using VectorizedIn = at::vec::Vectorized<{{input_t}}>; + using Vectorized = at::vec::Vectorized; constexpr auto VLEN = Vectorized::size(); constexpr auto ROWS = BLOCK_M; constexpr auto COLS = BLOCK_N / VLEN; Vectorized va; - at::vec::VectorizedN<{{compute_t}}, COLS> vb; - at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + at::vec::VectorizedN vb; + at::vec::VectorizedN vc; auto loadc = [&](auto i) { if constexpr (accum) { @@ -314,19 +273,14 @@ class CppMicroGemmFP32Vec(CppMicroGemm): if constexpr (col == 0) { {%- if alpha != 1 %} - va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); + va = Vectorized(A[row * lda + k] * {{alpha}}); {%- else %} - va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); + va = Vectorized(A[row * lda + k]); {%- endif %} } if constexpr (row == 0) { - {%- if input_dtype == torch.bfloat16 or input_dtype == torch.float16 %} - auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); - vb[col] = at::vec::convert<{{compute_t}}>(b); - {%- else %} vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); - {%- endif %} } constexpr int idx = row * COLS + col; @@ -395,7 +349,7 @@ def create_from_config(cls, config: CppMicroGemmConfig): if output_dtype is None: output_dtype = input_dtype if compute_dtype is None: - compute_dtype = output_dtype + compute_dtype = input_dtype if num_threads < 0: num_threads = parallel_num_threads() vec_isa = pick_vec_isa() diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index deff54e10eb9..55444d7b2bbd 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -13,7 +13,7 @@ from ..virtualized import V from .common import Kernel, OpOverrides from .cpp import CppKernelProxy, KernelGroup -from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope +from .cpp_utils import cexpr_index, DTYPE_TO_CPP def parse_expr_with_index_symbols(expr): @@ -110,13 +110,7 @@ def index(self, node: ir.Buffer, indices: List[Any]) -> str: indexer = node.make_indexer() index = indexer(parse_expr_with_index_symbols(indices)) index = self.rename_indexing(index) - outer_name = node.get_name() - inner_name = ( - outer_name - if outer_name in self.local_buffers - else self.args.input(node.get_name()) - ) - return f"{inner_name}[{cexpr_index(index)}]" + return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: """ @@ -175,50 +169,6 @@ def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: numel = f"{cexpr_index(buf.get_numel())}" return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" - def store_pointwise_nodes( - self, - dst: ir.Buffer, - nodes: List[ir.IRNode], - offsets: Optional[List[sympy.Expr]] = None, - reindexer: Optional[Callable[[List[Any]], List[Any]]] = None, - ) -> str: - var_sizes = (tuple(dst.get_size()), ()) - var_ranges = {sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0])} - if not offsets: - offsets = [sympy.Integer(0)] * len(var_sizes[0]) - assert len(offsets) == len(var_sizes[0]) - output_index = dst.get_layout().make_indexer()(var_ranges.keys()) - kernel_group = KernelGroup() - kernel_group.args = self.args - cpp_kernel_proxy = CppKernelProxy(kernel_group) - bodies = [] - var_sizes_list = [] - for i, node in enumerate(nodes): - output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() - node = node.data if isinstance(node, ir.ComputedBuffer) else node - assert isinstance(node, ir.Pointwise), node - - def fn(*args): - assert len(args) == 2 - assert len(args[0]) == len(var_sizes[0]) - assert len(args[1]) == 0 - new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] - if reindexer is not None: - new_args = reindexer(new_args) - V.ops.store( - output_name, - output_index, - node.make_loader()(new_args).value, - ) - - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) - bodies.append(body) - var_sizes_list.append(var_sizes) - - cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) - return kernel_group.loops_code.getvalue() - def store_output( self, dst: ir.Buffer, @@ -246,20 +196,55 @@ def store_output( needed on the indices to `epilogue_nodes` to match the indexing of `dst`. """ assert dst.get_size() == src.get_size() - if offsets: - offsets = parse_expr_with_index_symbols(offsets) if epilogue_nodes: - return self.store_pointwise_nodes(dst, epilogue_nodes, offsets, reindexer) + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = { + sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0]) + } + + # epilogues are all pointwises, hence all indexed the same way as dst + output_index = dst.get_layout().make_indexer()(var_ranges.keys()) + + if not offsets: + offsets = [0] * len(var_sizes[0]) + assert len(offsets) == len(var_sizes[0]) + offsets = parse_expr_with_index_symbols(offsets) + + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(epilogue_nodes): + assert isinstance(node, ir.ComputedBuffer) + output_name = ( + node.get_name() if i < len(epilogue_nodes) - 1 else dst.get_name() + ) + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexer is not None: + new_args = reindexer(new_args) + V.ops.store( + output_name, + output_index, + node.data.make_loader()(new_args).value, + ) + + body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() else: - if dst.get_name() != src.get_name(): - # src is local - copy = L.copy(dst, src).data.data - with LocalBufferScope(self) as scope: - scope.add_local_buffer(src) - return self.store_pointwise_nodes(dst, [copy]) - else: - assert dst.layout == src.layout - return "" + # TODO(jgong5): support local acc buffer to avoid assertion below + assert dst.get_name() == src.get_name() and dst.layout == src.layout + return "" class CppTemplateCaller(ir.ChoiceCaller): diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 54904f33f20b..a3b4fd3206b6 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,15 +1,10 @@ -import contextlib import math from collections import namedtuple -from typing import Dict -from unittest.mock import patch import torch -from .. import ir -from ..virtualized import V -from .common import ExprPrinter, Kernel +from .common import ExprPrinter DTYPE_TO_CPP = { torch.float32: "float", @@ -241,58 +236,3 @@ def value_to_cpp(value, cpp_type): return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" else: return f"static_cast<{cpp_type}>({repr(value)})" - - -class LocalBufferScope: - """ - This class creates a context that helps to generate code involving Inductor IR with - function local buffers. These buffers are constructed during the codegen process and - are used to store intermediate results such as local accumulators. We do not want to - add them to `V.graph` since they are not global and we do not want to add them as - function arguments either. So we patch the codegen processes under this scope to support - these buffers without exposure to the outside world. - """ - - def __init__(self, kernel: Kernel): - self.kernel = kernel - self.exit_stack = contextlib.ExitStack() - self.local_buffers: Dict[str, ir.Buffer] = {} - - def __enter__(self): - self.exit_stack.__enter__() - original_get_dtype = V.graph.get_dtype - - def get_dtype(name): - if name in self.local_buffers: - return self.local_buffers[name].get_dtype() - return original_get_dtype(name) - - self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) - - original_input = self.kernel.args.input - - def input(name): - if name in self.local_buffers: - return name - return original_input(name) - - self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input)) - - original_output = self.kernel.args.output - - def output(name): - if name in self.local_buffers: - return name - return original_output(name) - - self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output)) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.local_buffers.clear() - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - def add_local_buffer(self, buffer: ir.Buffer): - assert buffer.get_name() not in self.local_buffers - self.local_buffers[buffer.get_name()] = buffer diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 7b7e8e567a0b..513709a18f32 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6314,7 +6314,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, w, B, attr, scalars, algorithm): + def create(cls, x, w, b, attr, scalars, algorithm): x = cls.require_contiguous(cls.realize_input(x)) w = cls.require_contiguous(cls.realize_input(w)) @@ -6322,9 +6322,9 @@ def create(cls, x, w, B, attr, scalars, algorithm): oc, ic = w.get_size() inputs = [x, w] constant_args = [attr, scalars if scalars else [-1], algorithm] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) else: constant_args.insert(0, None) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 075a9b8b709e..399cb1668dad 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -13,25 +13,14 @@ permute, register_lowering, to_dtype, - view, -) -from .select_algorithm import ( - autotune_select_algorithm, - ChoiceCaller, - ExternKernelChoice, ) +from .select_algorithm import autotune_select_algorithm, ExternKernelChoice from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune from .virtualized import V def register_onednn_fusion_ops(): if torch._C._has_mkldnn: - aten_mkldnn_linear_unary = ExternKernelChoice( - torch.ops.mkldnn._linear_pointwise, - "mkldnn::_linear_pointwise", - has_out_variant=False, - kernel_creator=ir.LinearUnary.create, - ) cpu_needs_realized_inputs = [ torch.ops.mkldnn._convolution_pointwise, torch.ops.mkldnn._convolution_pointwise_, @@ -139,75 +128,11 @@ def convolution_binary_inplace( @register_lowering(torch.ops.mkldnn._linear_pointwise) def linear_unary( - x: TensorBox, - w: TensorBox, - b: TensorBox, - attr, - scalars, - algorithm, - layout=None, + x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm ): - x_size = x.get_size() - if len(x_size) > 2: - # GEMM template needs 2D input, normalize input shape here - x = view(x, [-1, x_size[-1]]) - choices: List[ChoiceCaller] = [] - if len(choices) == 0 or use_aten_gemm_kernels(): - choices.append( - aten_mkldnn_linear_unary.bind( - (x, w), - layout, - B=None, - attr=attr, - scalars=scalars, - algorithm=algorithm, - ) - if b is None - else aten_mkldnn_linear_unary.bind( - (x, w, b), - layout, - attr=attr, - scalars=scalars, - algorithm=algorithm, - ) - ) - if use_max_autotune(): - transposed_w = permute(w, [1, 0]) - *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) - # TODO(jgong5): support epilogue fusion - if ( - use_cpp_packed_gemm_template(layout, x, transposed_w) - and attr == "none" - ): - if b is None: - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, w], - trans_w=True, - ) - else: - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, w, b], - trans_w=True, - input_indices=[2, 0, 1], - ) - assert w.get_name() in V.graph.constants - input_gen_fns = { - 1: lambda x: V.graph.constants[x.get_name()], - } - result = autotune_select_algorithm( - "linear_unary", - choices, - [x, w] if b is None else [x, w, b], - layout, - input_gen_fns=input_gen_fns, + return TensorBox.create( + ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) ) - if len(x_size) > 2: - result = view(result, (*x_size[:-1], result.get_size()[-1])) - return result @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): @@ -444,7 +369,15 @@ def mkl_packed_linear( *, layout=None, ): - choices: List[ChoiceCaller] = [] + choices = ( + [ + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ] + if use_aten_gemm_kernels() + else [] + ) if use_max_autotune(): transposed_w = permute(orig_w, [1, 0]) *_, layout, x, transposed_w = mm_args( @@ -459,13 +392,6 @@ def mkl_packed_linear( input_indices=[0, 2], ) - if len(choices) == 0 or use_aten_gemm_kernels(): - choices.append( - aten_mkl_linear.bind( - (x, packed_w, orig_w), layout, B=None, batch_size=batch_size - ) - ) - assert packed_w.get_name() in V.graph.constants assert orig_w.get_name() in V.graph.constants # packed_w is a mkldnn tensor which we can't generate directly diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f0487c9025d5..5ff1d951bb42 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1002,7 +1002,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): if not config.cpp.weight_prepack: return False - layout_dtypes = [torch.float32, torch.bfloat16, torch.half] + layout_dtypes = [torch.float32] m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) # TODO(jgong5): support dynamic shapes for n or k if has_free_symbols((n, k)): @@ -1010,13 +1010,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() micro_gemm = create_micro_gemm( - "micro_gemm", - m, - n, - k, - input_dtype=layout.dtype, - output_dtype=torch.float, - num_threads=parallel_num_threads(), + "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() ) # TODO(jgong5): support n % n_block_size != 0 return ( From 4a5ef0b7938f9b5b6cb2f1cd200f925059ba3f59 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 May 2024 00:15:00 +0000 Subject: [PATCH 035/116] Revert "[inductor][cpp] epilogue support for gemm template (#126019)" This reverts commit 7844c202b2076ec3efa23264226f3eaef11a6fcb. Reverted https://github.com/pytorch/pytorch/pull/126019 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the dependency PR https://github.com/pytorch/pytorch/pull/124021 is going to be revert ([comment](https://github.com/pytorch/pytorch/pull/126019#issuecomment-2116408137)) --- test/inductor/test_cpu_select_algorithm.py | 135 +---------------- torch/_inductor/codegen/common.py | 2 +- torch/_inductor/codegen/cpp.py | 91 ++++------- torch/_inductor/codegen/cpp_gemm_template.py | 61 +++----- torch/_inductor/codegen/cpp_template.py | 6 +- .../_inductor/codegen/cpp_template_kernel.py | 142 ++---------------- torch/_inductor/select_algorithm.py | 3 +- 7 files changed, 73 insertions(+), 367 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 75bdff1cba96..5377b1f8f7e5 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -1,7 +1,5 @@ # Owner(s): ["oncall: cpu inductor"] import functools - -import sys import unittest from unittest.mock import patch @@ -19,18 +17,6 @@ from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL -try: - try: - from . import test_torchinductor - except ImportError: - import test_torchinductor -except unittest.SkipTest: - if __name__ == "__main__": - sys.exit(0) - raise - -check_model = test_torchinductor.check_model - aten = torch.ops.aten @@ -38,14 +24,7 @@ def patches(fn): def skip_cache(self, choices, name, key, benchmark): if benchmark is None: return {} - timings = benchmark(choices) - for choice, timing in timings.items(): - if isinstance(choice, select_algorithm.ExternKernelCaller): - # we intentionally make ATEN kernel slower to cover the cases - # where template kernels are always chosen with fusions applied - # and correctness checks at runtime. - timings[choice] = timing * 1000 - return timings + return benchmark(choices) for patcher in [ dynamo_config.patch(verbose=True), @@ -70,8 +49,6 @@ def wrapped(*args, **kwargs): class TestSelectAlgorithm(TestCase): - common = check_model - @inductor_config.patch({"freezing": True}) @patches @torch.no_grad @@ -90,6 +67,7 @@ def __init__(self, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias) + @torch.compile def forward(self, x): return self.linear(x) @@ -97,7 +75,7 @@ def forward(self, x): mod = M(bias=bias).to(dtype=dtype).eval() B = (2, batch_size) if input_3d else (batch_size,) v = torch.randn(*B, in_features).to(dtype=dtype) - self.common(mod, (v,)) + mod(v) self.assertEqual( counters["inductor"]["select_algorithm_autotune"], 1 if out_features != 1 else 0, @@ -126,108 +104,10 @@ def forward(self, x): counters.clear() mod = M(bias=bias).to(dtype=dtype).eval() v = torch.randn(in_features, batch_size).to(dtype=dtype) - self.common(mod, (v.transpose(0, 1),)) + mod(v.transpose(0, 1)) # TODO(jgong5): support transposed input self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("bias", (True, False)) - @parametrize( - "epilogue", - ( - "relu", - "gelu", - "silu", - "sigmoid", - "tanh", - "hardswish", - "hardsigmoid", - "leaky_relu", - "hardtanh", - "add", - "sub", - "mul", - "div", - ), - ) - @dtypes(torch.float) - def test_linear_with_pointwise(self, bias, epilogue, dtype): - batch_size = 384 - in_features = 196 - out_features = 384 - - class M(torch.nn.Module): - def __init__(self, bias, epilogue, other): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - if epilogue == "relu": - self.epilogue = torch.nn.ReLU() - elif epilogue == "gelu": - self.epilogue = torch.nn.GELU() - elif epilogue == "silu": - self.epilogue = torch.nn.SiLU() - elif epilogue == "sigmoid": - self.epilogue = torch.nn.Sigmoid() - elif epilogue == "tanh": - self.epilogue = torch.nn.Tanh() - elif epilogue == "hardswish": - self.epilogue = torch.nn.Hardswish() - elif epilogue == "hardsigmoid": - self.epilogue = torch.nn.Hardsigmoid() - elif epilogue == "leaky_relu": - self.epilogue = torch.nn.LeakyReLU() - elif epilogue == "hardtanh": - self.epilogue = torch.nn.Hardtanh() - elif epilogue == "add": - self.epilogue = lambda x: x + other - elif epilogue == "sub": - self.epilogue = lambda x: x - other - elif epilogue == "mul": - self.epilogue = lambda x: x * other - elif epilogue == "div": - self.epilogue = lambda x: x / other - - def forward(self, x): - return self.epilogue(self.linear(x)) - - counters.clear() - v = torch.randn(batch_size, in_features).to(dtype=dtype) - u = torch.randn(batch_size, out_features).to(dtype=dtype) - mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval() - self.common(mod, (v,)) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) - self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) - - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("bias", (True, False)) - @dtypes(torch.float) - def test_linear_with_transpose(self, bias, dtype): - batch_size = 384 - in_features = 196 - out_features = 128 - - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - - def forward(self, x, y): - return self.linear(x).transpose(0, 1) + y - - counters.clear() - mod = M(bias=bias).to(dtype=dtype).eval() - v = torch.randn(batch_size, in_features).to(dtype=dtype) - u = torch.randn(out_features, batch_size).to(dtype=dtype) - self.common(mod, (v, u)) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) - self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) - @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) class _DynamicShapesTestBase(TestCase): @@ -235,14 +115,7 @@ class _DynamicShapesTestBase(TestCase): class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): - common = check_model test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes - test_linear_with_pointwise_dynamic_shapes = ( - TestSelectAlgorithm.test_linear_with_pointwise - ) - test_linear_with_transpose_dynamic_shapes = ( - TestSelectAlgorithm.test_linear_with_transpose - ) instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a6ac6234ab08..0d90e474d04b 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1508,7 +1508,7 @@ def _bound_variable(name, *args, **kwargs): return ValueRanges.unknown() fx_node = V.interpreter.current_node - if fx_node.target == name and self.node_to_bounds is not None: + if fx_node.target == name: assert isinstance(self.node_to_bounds, dict) return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 40d4f53989af..e12a72d11601 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -46,7 +46,7 @@ sympy_subs, ) -from ..virtualized import NullKernelHandler, ops, OpsValue, V +from ..virtualized import ops, OpsValue, V from .common import ( BracesBuffer, CppWrapperKernelArgs, @@ -3148,11 +3148,27 @@ def is_memory_copy_scheduler_node(node: SchedulerNode): body: ir.LoopBody = node._body _legalize_lowp_fp(body) - def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float): - # TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes - assert len(fn_list) == len(var_sizes_list) + def codegen_nodes(self, nodes: List[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + + assert len(nodes) >= 1 + first_node = nodes[0] + vec_dtype = ( + first_node._lowp_fp_type # type: ignore[attr-defined] + if all( + hasattr(_node, "_lowp_fp_type") + and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined] + for _node in nodes + ) + else torch.float + ) + kernel_group = self.kernel_group - group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group self.set_ranges(group, reduction_group) @@ -3168,22 +3184,22 @@ def codegen_kernel(cls, *args): def run(kernel): vars, reduction_vars = kernel.set_ranges(group, reduction_group) in_suffix = False - for fn, var_sizes in zip(fn_list, var_sizes_list): - if var_sizes in [ + for node in nodes: + if node.group[1] in [ (group, reduction_group), (group + reduction_group, ()), ]: assert not in_suffix - fn(vars, reduction_vars) + node.run(vars, reduction_vars) else: in_suffix = True - assert var_sizes == ( + assert node.group[1] == ( group, (), - ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" # we can fuse in some extra pointwise into the suffix with kernel.write_to_suffix(): - fn(vars, ()) + node.run(vars, ()) scalar_kernel = codegen_kernel(CppKernel) V.graph.removed_buffers |= scalar_kernel.removed_buffers @@ -3195,8 +3211,8 @@ def run(kernel): def select_tiling_indices(tiling_factor): all_index = [] - for fn, var_sizes in zip(fn_list, var_sizes_list): - rw = dependencies.extract_read_writes(fn, *var_sizes) + for node in nodes: + rw = dependencies.extract_read_writes(node._body, *node._sizes) all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] contig_vars = set() contig_vars_list = [] @@ -3310,41 +3326,6 @@ def select_tiling(dtype: torch.dtype = torch.float): inner_main_loop.set_kernel(tile2d_kernel) inner_tail_loop.set_kernel(vec_kernel) - def codegen_loop_bodies(self, loop_bodies, var_sizes_list): - # TODO(jgong5): support lowp legalization - for body in loop_bodies: - DataTypePropagation.propagate_loopbody(body) - self.codegen_functions(loop_bodies, var_sizes_list) - - def codegen_nodes(self, nodes: List[SchedulerNode]): - # Legalize BF16 node by adding to_dtype explicitly - self.legalize_lowp_fp_dtype(nodes) - self.data_type_propagation(nodes) - - assert len(nodes) >= 1 - first_node = nodes[0] - vec_dtype = ( - first_node._lowp_fp_type # type: ignore[attr-defined] - if all( - hasattr(_node, "_lowp_fp_type") - and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined] - for _node in nodes - ) - else torch.float - ) - - def fn(node, *index_vars): - node.decide_inplace_update() - node.mark_run() - if isinstance(V.kernel, NullKernelHandler): - return node._body(*index_vars) - else: - return node.codegen(index_vars) - - fn_list = [functools.partial(fn, node) for node in nodes] - var_sizes_list = [node.group[1] for node in nodes] - self.codegen_functions(fn_list, var_sizes_list, vec_dtype) - def codegen_loops(self, code, worksharing): self.codegen_loops_impl(self.loop_nest, code, worksharing) @@ -3409,9 +3390,6 @@ def reset_kernel_group(self): def fuse(self, node1, node2): if node1.is_foreach() or node2.is_foreach(): return ForeachKernelSchedulerNode.fuse(node1, node2) - elif node1.is_template(): - assert not node2.is_template() - return FusedSchedulerNode.fuse(node1, node2) else: if ( self._why_fuse_nodes(node1, node2) @@ -3610,9 +3588,7 @@ def _get_outer_loop_fusion_depth(self, node1, node2): def can_fuse_vertical_outer_loop(self, node1, node2): return ( - not node1.is_template() - and not node2.is_template() - and node1.get_names() & node2.ancestors + node1.get_names() & node2.ancestors and not ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() @@ -3628,11 +3604,9 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): - if node2.is_template(): - # TODO(jgong5): support pre-op fusion with template + # TODO(jgong5): support vertical fusion for template nodes + if node1.is_template() or node2.is_template(): return False - if node1.is_template(): - return not node2.is_reduction() return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3715,7 +3689,6 @@ def codegen_template( kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) with kernel: for node in [template_node, *epilogue_nodes]: - node.decide_inplace_update() node.mark_run() src_code = render() diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index c664ba7fae45..c623f262b015 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -24,7 +24,7 @@ { {{kernel.maybe_codegen_profile()}} constexpr int64_t num_threads = {{num_threads}}; - constexpr int64_t N = {{kernel.size(GemmOut, 1)}}; + constexpr int64_t N = {{kernel.size(Y, 1)}}; constexpr int64_t K = {{kernel.size(X, 1)}}; constexpr int64_t M0 = {{micro_gemm.register_blocking.block_m}}; constexpr int64_t N0 = {{micro_gemm.register_blocking.block_n}}; @@ -36,7 +36,7 @@ // TODO(jgong5): improve cache blocking with CPU info (Mc, Kc) {%- if is_dynamic_M %} - const int64_t M = {{kernel.size(GemmOut, 0)}}; + const int64_t M = {{kernel.size(Y, 0)}}; const int64_t M0_blocks = (M + M0 - 1) / M0; {%- if num_threads > 1 %} const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); @@ -48,7 +48,7 @@ const int64_t Mc_blocks = Mt_blocks; const int64_t Kc_blocks = Kt_blocks; {%- else %} - constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; + constexpr int64_t M = {{kernel.size(Y, 0)}}; constexpr int64_t M0_blocks = (M + M0 - 1) / M0; constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; @@ -83,23 +83,16 @@ int64_t k_block_end = K0_blocks; {%- endif %} for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { - const int64_t m_start = mc * M0; - const int64_t m_end = std::min((mc + Mc_blocks) * M0, M); - const int64_t m_size = m_end - m_start; + int64_t m_start = mc * M0; + int64_t m_end = std::min((mc + Mc_blocks) * M0, M); for (int64_t nc = n_block_start; nc < n_block_end; ++nc) { - const int64_t n_start = nc * N0; - const int64_t n_size = N0; - {%- if use_local_acc %} - {{ kernel.define_buffer("acc_local_buf", ["m_end - m_start", "N0"]) }} - {%- set acc = kernel.local_buffers["acc_local_buf"] %} - {%- else %} - {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} - {%- endif %} + int64_t n_start = nc * N0; + // TODO(jgong5): use float32 temporary buffer to support bfloat16/float16 gemm {%- if inp is not none and beta != 0 %} - for (int64_t m = 0; m < m_size; ++m) { + for (int64_t m = m_start; m < m_end; ++m) { #pragma omp simd - for (int64_t n = 0; n < n_size; ++n) { - {{kernel.index(acc, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m + m_start", "n + n_start"])}}; + for (int64_t n = n_start; n < n_start + N0; ++n) { + {{kernel.index(Y, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m", "n"])}}; } } {%- endif %} @@ -109,26 +102,17 @@ {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} {%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %} {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + {%- set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} {%- if inp is not none and beta != 0 %} - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(20, false) }} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(20, false) }} {%- else %} if (kc == k_block_start) { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False)|indent(24, false) }} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False)|indent(24, false) }} } else { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(24, false) }} } {%- endif %} } - {%- if reindexer is not none %} - {%- set Y_maybe_transposed = kernel.permute(Y, reindexer([0,1])) %} - {%- else %} - {%- set Y_maybe_transposed = Y %} - {%- endif %} - {%- set tile_Y = kernel.slice_nd(Y_maybe_transposed, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} - {{ kernel.store_output( - tile_Y, acc, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer - )|indent(16, false) - }} } } } @@ -344,6 +328,7 @@ def render( # type: ignore[override] epilogue_nodes: Optional[List[ir.IRNode]] = None, **kwargs, ) -> str: + assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." assert len(self.input_nodes) >= 2 X, W = self.input_nodes[0], self.input_nodes[1] @@ -354,16 +339,9 @@ def render( # type: ignore[override] # Use the updated prepacked weight buffer W = template_buffer_node.inputs[1] Y = template_buffer_node - - template_buffer = Y - Y_is_transposed = False - # TODO(jgong5): support local accumulation - use_local_acc = False - if epilogue_nodes: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: Y = cast(ir.Buffer, epilogue_nodes[-1]) - assert Y.get_name() in V.kernel.inplace_update_buffers - if Y.get_stride() == list(reversed(template_buffer.get_stride())): - Y_is_transposed = True + assert self.output_node is not None micro_gemm = create_micro_gemm( f"{kernel.kernel_name}_micro_gemm", @@ -382,7 +360,6 @@ def render( # type: ignore[override] W=W, inp=inp, Y=Y, - GemmOut=template_buffer, beta=self.beta, alpha=self.alpha, num_threads=self.num_threads, @@ -390,8 +367,6 @@ def render( # type: ignore[override] is_dynamic_M=self.is_dynamic_M, template=self, kernel=kernel, - epilogue_nodes=epilogue_nodes, - reindexer=(lambda x: list(reversed(x))) if Y_is_transposed else None, - use_local_acc=use_local_acc, + epilogues=epilogue_nodes, ) return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 222d6a2e57ba..3d15010a8838 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -39,7 +39,7 @@ def generate(self, **kwargs): ), CppTemplateKernel( kernel_name=kernel_name, ) as kernel: - code = kernel.render(self, **kwargs) + code = self.render(kernel=kernel, **kwargs) _, call_args, _ = kernel.args.python_argdefs() log.debug("Generated Code:\n%s", code) log.debug( @@ -79,8 +79,8 @@ def make_kernel_render( kernel_name=str(Placeholder.KERNEL_NAME), ) render = functools.partial( - kernel.render, - self, + self.render, + kernel=kernel, template_buffer_node=template_node, epilogue_nodes=epilogue_nodes, **kwargs, diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 55444d7b2bbd..6a978c45fa28 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -5,26 +5,19 @@ from sympy.parsing.sympy_parser import parse_expr import torch -from .. import codecache, config, ir, lowering as L -from ..autotune_process import CppBenchmarkRequest -from ..select_algorithm import PartialRender -from ..utils import sympy_index_symbol +from torch._inductor.autotune_process import CppBenchmarkRequest +from torch._inductor.utils import sympy_index_symbol +from .. import codecache, config, ir, lowering as L from ..virtualized import V from .common import Kernel, OpOverrides -from .cpp import CppKernelProxy, KernelGroup from .cpp_utils import cexpr_index, DTYPE_TO_CPP -def parse_expr_with_index_symbols(expr): - if isinstance(expr, sympy.Expr): - return expr - elif isinstance(expr, (list, tuple)): - return [parse_expr_with_index_symbols(e) for e in expr] - else: - expr = parse_expr(str(expr)) - int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} - return expr.subs(int_symbols) +def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: + expr = parse_expr(expr_str) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) def wrap_with_tensorbox(node) -> ir.TensorBox: @@ -39,13 +32,6 @@ class CppTemplateKernel(Kernel): def __init__(self, kernel_name): super().__init__() self.kernel_name = kernel_name - self.render_hooks = {} - self.local_buffers = {} - - def render(self, template, **kwargs): - return PartialRender( - template.render(kernel=self, **kwargs), self.render_hooks - ).finalize() def def_kernel( self, @@ -56,8 +42,7 @@ def def_kernel( if inp is not None: self.args.input_buffers[inp.get_name()] = name for name, out in outputs.items(): - if out.get_name() not in self.args.inplace_buffers: - self.args.output_buffers[out.get_name()] = name + self.args.output_buffers[out.get_name()] = name unique_sizevars = { s for input in inputs.values() @@ -76,15 +61,8 @@ def def_kernel( sizevars = sorted(unique_sizevars, key=str) for sizevar in sizevars: self.args.sizevars[sizevar] = f"k{sizevar}" - - def hook(): - cpp_argdefs, _, _ = self.args.cpp_argdefs() - return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" - - placeholder = "" - assert placeholder not in self.render_hooks - self.render_hooks[placeholder] = hook - return placeholder + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper = V.graph.wrapper_code @@ -108,11 +86,11 @@ def stride(self, node: ir.Buffer, dim: int) -> str: def index(self, node: ir.Buffer, indices: List[Any]) -> str: indexer = node.make_indexer() - index = indexer(parse_expr_with_index_symbols(indices)) + index = indexer([parse_expr_with_index_symbols(str(idx)) for idx in indices]) index = self.rename_indexing(index) return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" - def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: + def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ir.ReinterpretView: """ Slice the given node with a list of ranges (start and end) corresponding to its dims. The dim is not sliced if the corresponding range is empty. @@ -123,22 +101,16 @@ def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: if len(_range) == 0: continue assert len(_range) == 2 - start, end = parse_expr_with_index_symbols(_range) + start, end = (parse_expr_with_index_symbols(str(r)) for r in _range) sliced = L.slice_(sliced, dim, start, end, clamp=False) assert isinstance(sliced.data, ir.ReinterpretView) return sliced.data def view(self, node, sizes: List[Any]) -> ir.View: node = wrap_with_tensorbox(node) - sizes = parse_expr_with_index_symbols(sizes) + sizes = [parse_expr_with_index_symbols(str(s)) for s in sizes] return L.view(node, sizes).data - def permute(self, node, dims): - node = wrap_with_tensorbox(node) - permuted = L.permute(node, dims).data - assert isinstance(permuted, ir.ReinterpretView) - return permuted - @property def assert_function(self) -> str: if V.graph.aot_mode: @@ -160,92 +132,6 @@ def unroll_pragma(self, unroll): else: return f"#pragma unroll {unroll}" - def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: - """Define kernel local buffer""" - sizes = parse_expr_with_index_symbols(sizes) - buf = ir.Buffer(name, ir.FixedLayout(torch.device("cpu"), dtype, sizes)) - self.local_buffers[name] = buf - ctype = f"{DTYPE_TO_CPP[dtype]}" - numel = f"{cexpr_index(buf.get_numel())}" - return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" - - def store_output( - self, - dst: ir.Buffer, - src: ir.Buffer, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - offsets: Optional[List[Any]] = None, - reindexer: Optional[Callable[[List[Any]], List[Any]]] = None, - ): - """ - Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match. - If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues - before stored to `dst`. The `epilogues_nodes` are all pointwise. - - Notes: - 1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute - and stores. In case `epilogue_nodes` are not provided, we do nothing. - 2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since - they come form the original Inductor IR, they might need to be adjusted before working with - `src` and `dst` as outlined below: - a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on. - In this case, the `offsets` could be provided to adjust the indices passed to - `epilogue_nodes` during codegen and the data ranges are also configured according to - the sizes of `src` and `dst`. - b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is - needed on the indices to `epilogue_nodes` to match the indexing of `dst`. - """ - assert dst.get_size() == src.get_size() - if epilogue_nodes: - var_sizes = (tuple(dst.get_size()), ()) - var_ranges = { - sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0]) - } - - # epilogues are all pointwises, hence all indexed the same way as dst - output_index = dst.get_layout().make_indexer()(var_ranges.keys()) - - if not offsets: - offsets = [0] * len(var_sizes[0]) - assert len(offsets) == len(var_sizes[0]) - offsets = parse_expr_with_index_symbols(offsets) - - kernel_group = KernelGroup() - kernel_group.args = self.args - cpp_kernel_proxy = CppKernelProxy(kernel_group) - bodies = [] - var_sizes_list = [] - for i, node in enumerate(epilogue_nodes): - assert isinstance(node, ir.ComputedBuffer) - output_name = ( - node.get_name() if i < len(epilogue_nodes) - 1 else dst.get_name() - ) - - def fn(*args): - assert len(args) == 2 - assert len(args[0]) == len(var_sizes[0]) - assert len(args[1]) == 0 - new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] - if reindexer is not None: - new_args = reindexer(new_args) - V.ops.store( - output_name, - output_index, - node.data.make_loader()(new_args).value, - ) - - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) - bodies.append(body) - var_sizes_list.append(var_sizes) - - cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) - return kernel_group.loops_code.getvalue() - else: - # TODO(jgong5): support local acc buffer to avoid assertion below - assert dst.get_name() == src.get_name() and dst.layout == src.layout - return "" - class CppTemplateCaller(ir.ChoiceCaller): """ diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 39d8334fe7d5..d1550529bb8e 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1019,8 +1019,7 @@ def __call__( # Templates selected with input_gen_fns require specific input data to avoid IMA # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection - # TODO(jgong5): support multi-template on CPU - if input_gen_fns is not None or layout.device.type == "cpu": + if input_gen_fns is not None: return_multi_template = False # TODO - assert that we have not mutating kernels here From 337830f6574057b5e04a7f2c0aa671a575058add Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 May 2024 00:22:40 +0000 Subject: [PATCH 036/116] Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)" This reverts commit f060b0c6e608436997a1dc229c82ce26c1e6676f. Reverted https://github.com/pytorch/pytorch/pull/124021 on behalf of https://github.com/huydhn due to Unfortunately, the new tests are still failing internally ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2116415398)) --- test/inductor/test_cpu_select_algorithm.py | 131 ------ torch/_inductor/codegen/cpp.py | 49 +-- torch/_inductor/codegen/cpp_gemm_template.py | 372 ---------------- torch/_inductor/codegen/cpp_micro_gemm.py | 401 ------------------ torch/_inductor/codegen/cpp_prefix.h | 98 ----- torch/_inductor/codegen/cpp_template.py | 116 ----- .../_inductor/codegen/cpp_template_kernel.py | 200 --------- torch/_inductor/codegen/cpp_utils.py | 4 - torch/_inductor/config.py | 5 +- torch/_inductor/ir.py | 15 +- torch/_inductor/kernel/mm.py | 18 - torch/_inductor/mkldnn_lowerings.py | 65 +-- torch/_inductor/select_algorithm.py | 90 +--- torch/_inductor/utils.py | 36 -- 14 files changed, 14 insertions(+), 1586 deletions(-) delete mode 100644 test/inductor/test_cpu_select_algorithm.py delete mode 100644 torch/_inductor/codegen/cpp_gemm_template.py delete mode 100644 torch/_inductor/codegen/cpp_micro_gemm.py delete mode 100644 torch/_inductor/codegen/cpp_template.py delete mode 100644 torch/_inductor/codegen/cpp_template_kernel.py diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py deleted file mode 100644 index 5377b1f8f7e5..000000000000 --- a/test/inductor/test_cpu_select_algorithm.py +++ /dev/null @@ -1,131 +0,0 @@ -# Owner(s): ["oncall: cpu inductor"] -import functools -import unittest -from unittest.mock import patch - -import torch -import torch._dynamo.config -import torch._dynamo.config as dynamo_config -import torch._inductor.config as inductor_config -import torch._inductor.select_algorithm as select_algorithm -from torch._dynamo.utils import counters -from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_device_type import ( - dtypes, - instantiate_device_type_tests, -) - -from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL - -aten = torch.ops.aten - - -def patches(fn): - def skip_cache(self, choices, name, key, benchmark): - if benchmark is None: - return {} - return benchmark(choices) - - for patcher in [ - dynamo_config.patch(verbose=True), - inductor_config.patch( - debug=True, - max_autotune=True, - epilogue_fusion=True, - max_autotune_gemm_backends="CPP,ATEN", - ), - patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), - patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), - ]: - fn = patcher(fn) - - @functools.wraps(fn) - def wrapped(*args, **kwargs): - counters.clear() - torch.manual_seed(12345) - return fn(*args, **kwargs) - - return wrapped - - -class TestSelectAlgorithm(TestCase): - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 2, 1000)) - @parametrize("out_features", (1, 32, 1024)) - @parametrize("bias", (True, False)) - @parametrize("input_3d", (True, False)) - @dtypes(torch.float) - def test_linear_static_shapes( - self, batch_size, in_features, out_features, bias, input_3d, dtype - ): - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - - @torch.compile - def forward(self, x): - return self.linear(x) - - counters.clear() - mod = M(bias=bias).to(dtype=dtype).eval() - B = (2, batch_size) if input_3d else (batch_size,) - v = torch.randn(*B, in_features).to(dtype=dtype) - mod(v) - self.assertEqual( - counters["inductor"]["select_algorithm_autotune"], - 1 if out_features != 1 else 0, - ) - - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("bias", (True, False)) - @dtypes(torch.float) - def test_linear_input_transpose(self, bias, dtype): - batch_size = 384 - in_features = 196 - out_features = 384 - - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - - @torch.compile - def forward(self, x): - return self.linear(x) - - counters.clear() - mod = M(bias=bias).to(dtype=dtype).eval() - v = torch.randn(in_features, batch_size).to(dtype=dtype) - mod(v.transpose(0, 1)) - # TODO(jgong5): support transposed input - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) - - -@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) -class _DynamicShapesTestBase(TestCase): - pass - - -class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): - test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes - - -instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") -instantiate_device_type_tests( - TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu" -) - - -if __name__ == "__main__": - from torch.testing._internal.inductor_utils import HAS_CPU - - if HAS_CPU and not IS_MACOS: - run_tests() diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e12a72d11601..c0aad2d27428 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -8,7 +8,7 @@ import sys from copy import copy, deepcopy from enum import Enum -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import sympy @@ -20,7 +20,6 @@ from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges -from ..._dynamo.utils import counters from .. import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen @@ -3522,8 +3521,6 @@ def _can_fuse_horizontal_impl(self, node1, node2): return self._why_fuse_nodes(node1, node2) is not None def can_fuse_horizontal(self, node1, node2): - if node1.is_template() or node2.is_template(): - return False if ( len(node1.get_nodes()) + len(node2.get_nodes()) > config.cpp.max_horizontal_fusion_size @@ -3604,9 +3601,6 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): - # TODO(jgong5): support vertical fusion for template nodes - if node1.is_template() or node2.is_template(): - return False return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3663,42 +3657,6 @@ def codegen_node( if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: self._set_flush_status(True) - def is_cpp_template(self, node: BaseSchedulerNode) -> bool: - return isinstance(node, SchedulerNode) and isinstance( - node.node, ir.CppTemplateBuffer - ) - - def codegen_template( - self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] - ): - """ - Codegen a CPP template, possibly with fused epilogues - """ - counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) - assert self.is_cpp_template( - template_node - ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" - template_node = cast(SchedulerNode, template_node) - _, (_, rnumel) = template_node.group - assert rnumel == () - ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) - epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] - assert all( - isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes - ), "Epilogue nodes must all be instances of ir.ComputedBuffer" - kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) - with kernel: - for node in [template_node, *epilogue_nodes]: - node.mark_run() - src_code = render() - - with V.set_kernel_handler(kernel): - node_schedule = [template_node, *epilogue_nodes] - kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) - kernel.call_kernel(kernel_name, ctb) - V.graph.removed_buffers |= kernel.removed_buffers - self.scheduler.free_buffers() - def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() @@ -3708,7 +3666,7 @@ def ready_to_flush(self): def codegen_sync(self): pass - def define_kernel(self, src_code, nodes, kernel_args=None): + def define_kernel(self, src_code, nodes): wrapper = V.graph.wrapper_code fused_name = ( get_fused_kernel_name(nodes, config.cpp.descriptive_names) @@ -3724,8 +3682,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): src_code = src_code.replace("#pragma CMT", "//") compile_wrapper = IndentedBuffer() - args = self.kernel_group.args if kernel_args is None else kernel_args - _, _, arg_types = args.cpp_argdefs() + _, _, arg_types = self.kernel_group.args.cpp_argdefs() if not V.graph.cpp_wrapper: compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") compile_wrapper.splice(src_code, strip=True) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py deleted file mode 100644 index c623f262b015..000000000000 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ /dev/null @@ -1,372 +0,0 @@ -from typing import cast, List, Optional - -import torch -import torch.utils -from .. import ir, lowering as L - -from ..kernel.mm_common import mm_args -from ..select_algorithm import DataProcessorTemplateWrapper -from ..utils import cache_on_self, has_free_symbols, parallel_num_threads -from ..virtualized import V -from .cpp_micro_gemm import create_micro_gemm -from .cpp_template import CppTemplate - -from .cpp_template_kernel import CppTemplateKernel -from .cpp_utils import GemmBlocking - -GEMM_TEMPLATE = r""" -{{template.header().getvalue()}} - -{{micro_gemm.codegen_define(kernel)}} - -extern "C" -{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}} -{ - {{kernel.maybe_codegen_profile()}} - constexpr int64_t num_threads = {{num_threads}}; - constexpr int64_t N = {{kernel.size(Y, 1)}}; - constexpr int64_t K = {{kernel.size(X, 1)}}; - constexpr int64_t M0 = {{micro_gemm.register_blocking.block_m}}; - constexpr int64_t N0 = {{micro_gemm.register_blocking.block_n}}; - constexpr int64_t K0 = {{micro_gemm.register_blocking.block_k}}; - constexpr int64_t N0_blocks = (N + N0 - 1) / N0; - constexpr int64_t K0_blocks = (K + K0 - 1) / K0; - - static_assert(N % N0 == 0, "N dimension must be multiple of N0"); - - // TODO(jgong5): improve cache blocking with CPU info (Mc, Kc) - {%- if is_dynamic_M %} - const int64_t M = {{kernel.size(Y, 0)}}; - const int64_t M0_blocks = (M + M0 - 1) / M0; - {%- if num_threads > 1 %} - const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); - {%- else %} - const auto Mt_blocks = M0_blocks; - const auto Nt_blocks = N0_blocks; - const auto Kt_blocks = K0_blocks; - {%- endif %} - const int64_t Mc_blocks = Mt_blocks; - const int64_t Kc_blocks = Kt_blocks; - {%- else %} - constexpr int64_t M = {{kernel.size(Y, 0)}}; - constexpr int64_t M0_blocks = (M + M0 - 1) / M0; - constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; - constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; - constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}}; - constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}}; - constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}}; - {%- endif %} - - // TODO(jgong5): support k-slicing - {{kernel.assert_function}}(Kt_blocks == K0_blocks, "Do not support k slicing yet."); - // make sure all partitions are assigned - {{kernel.assert_function}}( - Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks, - "Not all partitions are assigned." - ); - - {%- if num_threads > 1 %} - #pragma omp parallel num_threads({{num_threads}}) - { - int tid = omp_get_thread_num(); - int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; - mm_get_thread_blocks( - tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks, - m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); - {%- else %} - { - int64_t m_block_start = 0; - int64_t m_block_end = M0_blocks; - int64_t n_block_start = 0; - int64_t n_block_end = N0_blocks; - int64_t k_block_start = 0; - int64_t k_block_end = K0_blocks; - {%- endif %} - for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { - int64_t m_start = mc * M0; - int64_t m_end = std::min((mc + Mc_blocks) * M0, M); - for (int64_t nc = n_block_start; nc < n_block_end; ++nc) { - int64_t n_start = nc * N0; - // TODO(jgong5): use float32 temporary buffer to support bfloat16/float16 gemm - {%- if inp is not none and beta != 0 %} - for (int64_t m = m_start; m < m_end; ++m) { - #pragma omp simd - for (int64_t n = n_start; n < n_start + N0; ++n) { - {{kernel.index(Y, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m", "n"])}}; - } - } - {%- endif %} - for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { - int64_t k_start = kc * K0; - int64_t k_end = std::min((kc + Kc_blocks) * K0, K); - {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} - {%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %} - {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} - {%- set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} - {%- if inp is not none and beta != 0 %} - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(20, false) }} - {%- else %} - if (kc == k_block_start) { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False)|indent(24, false) }} - } else { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(24, false) }} - } - {%- endif %} - } - } - } - } -} -""" - - -class CppPackedGemmTemplate(CppTemplate): - def __init__( - self, - input_nodes, - layout: ir.Layout, - num_threads: int, - register_blocking: GemmBlocking, - beta=1, - alpha=1, - ): - super().__init__("packed_gemm", input_nodes, layout) - self.beta = beta - self.alpha = alpha - self.num_threads = num_threads - self.register_blocking = register_blocking - m, n = layout.size - _, k = input_nodes[0].get_size() - self.m, self.n, self.k = m, n, k - self.is_dynamic_M = has_free_symbols((m,)) - - @cache_on_self - def thread_blocking(self) -> GemmBlocking: - # TODO(jgong5): allow tuning various blocking options - def get_factors(number): - factors = [] - # priorize more evenly divided factors - for i in range(int(number**0.5), 0, -1): - if number % i == 0: - factors.append(number // i) - factors.append(i) - return factors - - def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks): - thread_block_n = (n_blocks + factor - 1) // factor - cofactor = num_threads // factor - thread_block_m = (m_blocks + cofactor - 1) // cofactor - return GemmBlocking(thread_block_m, thread_block_n, k_blocks) - - assert ( - not self.is_dynamic_M - ), "Unable to determine thread blocking for dynamic M." - register_blocking = self.register_blocking - m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m - n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n - k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k - factors = get_factors(self.num_threads) - assert len(factors) > 0 - for factor in factors: - if n_blocks % factor == 0 and m_blocks % (self.num_threads // factor) == 0: - return get_blocking( - self.num_threads, factor, m_blocks, n_blocks, k_blocks - ) - for factor in factors: - if n_blocks % factor == 0: - return get_blocking( - self.num_threads, factor, m_blocks, n_blocks, k_blocks - ) - cofactor = self.num_threads // factor - if m_blocks % cofactor == 0: - return get_blocking( - self.num_threads, factor, m_blocks, n_blocks, k_blocks - ) - raise AssertionError("Should not reach here.") - - @cache_on_self - def cache_blocking(self) -> GemmBlocking: - # TODO(jgong5): improve cache blocking with CPU info - assert ( - not self.is_dynamic_M - ), "Unable to determine cache blocking for dynamic M." - thread_blocking = self.thread_blocking() - return GemmBlocking(thread_blocking.block_m, 1, thread_blocking.block_k) - - @staticmethod - def add_choices( - choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None - ): - if input_indices is None: - input_indices = list(range(len(input_nodes))) - - def reorder_and_filter(inputs, layout_or_out): - if len(input_indices) == 2: - x_idx = input_indices[0] - w_idx = input_indices[1] - return [inputs[x_idx], inputs[w_idx]], layout_or_out - else: - assert ( - len(input_indices) == 3 - ), "Cpp Packed GEMM template requires 2 or 3 input nodes." - # assume the input order is [inp, x, w] and we reorder it to [x, w, inp] - inp_idx = input_indices[0] - x_idx = input_indices[1] - w_idx = input_indices[2] - return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out - - def transpose_weight(inputs, layout_or_out): - if not trans_w: - return inputs, layout_or_out - - new_inputs = list(inputs) - W = inputs[1] - if isinstance(W, ir.IRNode): - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - new_inputs[1] = L.permute(W, [1, 0]) - return new_inputs, layout_or_out - else: - assert isinstance(W, torch.Tensor) - new_inputs[1] = W.transpose(0, 1) - return new_inputs, layout_or_out - - # TODO(jgong5): decide proper number of threads per problem size - num_threads = parallel_num_threads() - new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) - m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) - micro_gemm = create_micro_gemm( - "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads - ) - assert micro_gemm is not None - _, block_n, _ = micro_gemm.register_blocking - - def pack_weight(inputs, layout_or_out): - W = inputs[1] - new_inputs = list(inputs) - if isinstance(W, ir.IRNode): - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - k, n = W.get_size() - assert ( - n % block_n == 0 - ), f"The last dimension of W must be a multiple of {block_n}." - blocked_w = L.permute( - L.view(W, (k, n // block_n, block_n)), - [1, 0, 2], - ) - blocked_w = ir.ExternKernel.realize_input(blocked_w) - blocked_w = ir.ExternKernel.require_contiguous(blocked_w) - if isinstance(blocked_w, ir.ReinterpretView): - # normalize stride to be "contiguous_strides" per size - # this avoids the problems in L.view during template codegen - assert isinstance(blocked_w.layout, ir.FixedLayout) - blocked_w.layout = ir.FixedLayout( - blocked_w.layout.device, - blocked_w.layout.dtype, - blocked_w.layout.size, - ir.FlexibleLayout.contiguous_strides(blocked_w.layout.size), - blocked_w.layout.offset, - ) - else: - k, n = list(W.shape) - blocked_w = ( - W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous() - ) - # normalize stride to be "contiguous_strides" per size - # this avoids the problems in L.view during template codegen - new_stride = [1] - for sz in reversed(blocked_w.shape[1:]): - new_stride.insert(0, new_stride[0] * sz) - blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) - new_inputs[1] = blocked_w - return new_inputs, layout_or_out - - def preprocessor(inputs, layout): - return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) - - def postprocessor(output): - if isinstance(output, ir.TensorBox): - # prepack the weight as input to the template buffer - # TODO(jgong5): prune the unused constants in V.graph - # Should we implement it with constant folding in the scheduler instead? - template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) - assert isinstance(template_buffer, ir.CppTemplateBuffer) - new_input_nodes, _ = reorder_and_filter(input_nodes, layout) - W_node = new_input_nodes[1] - assert W_node.get_name() in V.graph.constants - W = V.graph.constants[W_node.get_name()] - new_input_nodes[1] = W - new_input_nodes, _ = pack_weight( - *transpose_weight(new_input_nodes, layout) - ) - W_packed = new_input_nodes[1] - W_packed_constant = V.graph.add_tensor_constant(W_packed) - template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( - W_packed_constant - ) - return output - - template = DataProcessorTemplateWrapper( - CppPackedGemmTemplate, - preprocessor, - postprocessor, - input_nodes=input_nodes, - layout=layout, - num_threads=num_threads, - register_blocking=micro_gemm.register_blocking, - beta=beta, - alpha=alpha, - ) - template.maybe_append_choice(choices) - return template - - def render( # type: ignore[override] - self, - kernel: CppTemplateKernel, - template_buffer_node: Optional[ir.CppTemplateBuffer] = None, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - **kwargs, - ) -> str: - assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." - assert len(self.input_nodes) >= 2 - - X, W = self.input_nodes[0], self.input_nodes[1] - inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None - Y = self.output_node - - if template_buffer_node is not None: - # Use the updated prepacked weight buffer - W = template_buffer_node.inputs[1] - Y = template_buffer_node - if epilogue_nodes is not None and len(epilogue_nodes) > 0: - Y = cast(ir.Buffer, epilogue_nodes[-1]) - assert self.output_node is not None - - micro_gemm = create_micro_gemm( - f"{kernel.kernel_name}_micro_gemm", - self.m, - self.n, - self.k, - self.layout.dtype, - alpha=self.alpha, - num_threads=self.num_threads, - ) - assert micro_gemm is not None - assert self.register_blocking == micro_gemm.register_blocking - - options = dict( - X=X, - W=W, - inp=inp, - Y=Y, - beta=self.beta, - alpha=self.alpha, - num_threads=self.num_threads, - micro_gemm=micro_gemm, - is_dynamic_M=self.is_dynamic_M, - template=self, - kernel=kernel, - epilogues=epilogue_nodes, - ) - return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py deleted file mode 100644 index 353562923c91..000000000000 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ /dev/null @@ -1,401 +0,0 @@ -from collections import namedtuple -from typing import Dict, List, Optional, Type - -import sympy - -import torch - -from .. import ir -from ..codecache import pick_vec_isa, VecAVX2, VecAVX512 -from ..utils import IndentedBuffer, parallel_num_threads -from ..virtualized import V -from .common import KernelTemplate -from .cpp_template_kernel import CppTemplateKernel -from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp - - -class CppMicroGemm: - """ - A class that codegens a kernel that computes small-sized matrix multiplication. - - A micro GEMM kernel is responsible for register blocking, instruction selection, - and other CPU architecture-specific optimizations. - - The subclasses need to override `codegen_define` to define the kernel function - that is called by the code generated by `codegen_call`. - """ - - # TODO(jgong5): support constant shapes and lds as template args. - DECLARE_KERNEL = r""" -template -inline void {{kernel_name}}( - const {{input_t}}* __restrict__ A, - const {{input_t}}* __restrict__ B, - {{output_t}}* __restrict__ C, - int64_t M, - int64_t N, - int64_t K, - int64_t lda, - int64_t ldb, - int64_t ldc -) -""" - - def __init__( - self, - name, - input_dtype, - output_dtype, - compute_dtype, - register_blocking, - alpha=1, - ): - self.name = name - self.input_dtype = input_dtype - self.output_dtype = output_dtype - self.compute_dtype = compute_dtype - self.register_blocking = register_blocking - self.alpha = alpha - - def get_common_options(self): - return { - "kernel_name": self.name, - "input_t": DTYPE_TO_CPP[self.input_dtype], - "output_t": DTYPE_TO_CPP[self.output_dtype], - "compute_t": DTYPE_TO_CPP[self.compute_dtype], - "alpha": self.alpha, - } - - def get_kernel_declaration(self): - options = self.get_common_options() - return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) - - def codegen_define(self, kernel: CppTemplateKernel) -> str: - raise NotImplementedError - - def codegen_call( - self, - kernel: CppTemplateKernel, - A: ir.Buffer, - B: ir.Buffer, - C: ir.Buffer, - accum: bool, - ) -> str: - """ - Generate the code for calling the templated kernel that computes - `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. - """ - A_ptr = f"&({kernel.index(A, [0, 0])})" - B_ptr = f"&({kernel.index(B, [0, 0])})" - C_ptr = f"&({kernel.index(C, [0, 0])})" - M = kernel.size(C, 0) - N = kernel.size(C, 1) - K = kernel.size(A, 1) - lda = kernel.stride(A, 0) - ldb = kernel.stride(B, 0) - ldc = kernel.stride(C, 0) - res = IndentedBuffer() - res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") - with res.indent(): - res.writeline(f"{A_ptr},") - res.writeline(f"{B_ptr},") - res.writeline(f"{C_ptr},") - res.writeline(f"{M},") - res.writeline(f"{N},") - res.writeline(f"{K},") - res.writeline(f"{lda},") - res.writeline(f"{ldb},") - res.writeline(f"{ldc}") - res.writeline(");") - return res.getvalue() - - -CppMicroGemmConfig = namedtuple( - "CppMicroGemmConfig", - [ - "input_dtype", - "output_dtype", - "compute_dtype", - "vec_isa_cls", - "register_blocking", - ], -) - -micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} - - -def register_micro_gemm(*configs): - def inner(cls): - assert ( - cls not in micro_gemm_configs - ), f"Duplicate micro_gemm registration for {cls}" - assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" - micro_gemm_configs[cls] = list(configs) - return cls - - return inner - - -class CppMicroGemmRef(CppMicroGemm): - """ - A reference implementation of the CppMicroGemm class with naive C++ code. - It is used for correctness debugging. - """ - - TEMPLATE_ENTRY = r""" -{{declare_kernel}} { - for (int64_t m = 0; m < M; ++m) { - for (int64_t n = 0; n < N; ++n) { - {{compute_t}} result = accum ? C[m * ldc + n] : 0; - for (int64_t k = 0; k < K; ++k) { - result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; - } - C[m * ldc + n] = result; - } - } -} -""" - - def __init__(self, name, input_dtype, output_dtype, compute_dtype, alpha): - super().__init__( - name, input_dtype, output_dtype, compute_dtype, GemmBlocking(1, 1, 1), alpha - ) - - def codegen_define(self, kernel: CppTemplateKernel) -> str: - options = { - "declare_kernel": self.get_kernel_declaration(), - **self.get_common_options(), - } - return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) - - -@register_micro_gemm( - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1) - ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1) - ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1) - ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1) - ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1) - ), - CppMicroGemmConfig( - torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1) - ), -) -class CppMicroGemmFP32Vec(CppMicroGemm): - """ - This class generates the code for fp32 micro gemm using vec instructions. - """ - - TEMPLATE_ENTRY = r""" -{{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); - // TODO(jgong5): loop unroll for M and N - for (int64_t m = 0; m < M; m += {{block_m}}) { - int64_t block_m = std::min(M - m, {{block_m}}); - for (int64_t n = 0; n < N; n += {{block_n}}) { - if (block_m == {{block_m}}) { - {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( - A + m * lda, - B + n, - C + m * ldc + n, - K, - lda, - ldb, - ldc - ); - } else { - switch (block_m) { - {%- for b in range(block_m - 1, 0, -1) %} - case {{b}}: - {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( - A + m * lda, - B + n, - C + m * ldc + n, - K, - lda, - ldb, - ldc - ); - break; - {%- endfor %} - default: - {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); - } - } - } - } -} -""" - - TEMPLATE_KERNEL = r""" -template -inline void {{kernel_name}}_kernel( - const float* __restrict__ A, - const float* __restrict__ B, - float* __restrict__ C, - int64_t K, - int64_t lda, - int64_t ldb, - int64_t ldc -) { - using Vectorized = at::vec::Vectorized; - constexpr auto VLEN = Vectorized::size(); - constexpr auto ROWS = BLOCK_M; - constexpr auto COLS = BLOCK_N / VLEN; - - Vectorized va; - at::vec::VectorizedN vb; - at::vec::VectorizedN vc; - - auto loadc = [&](auto i) { - if constexpr (accum) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); - } else { - vc[i] = Vectorized(0.0f); - } - }; - c10::ForcedUnroll{}(loadc); - - auto compute = [&, COLS](auto i, int k) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - - if constexpr (col == 0) { - {%- if alpha != 1 %} - va = Vectorized(A[row * lda + k] * {{alpha}}); - {%- else %} - va = Vectorized(A[row * lda + k]); - {%- endif %} - } - - if constexpr (row == 0) { - vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); - } - - constexpr int idx = row * COLS + col; - vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); - }; - - {{kernel.unroll_pragma(4)}} - for (int k = 0; k < K; ++k) { - c10::ForcedUnroll{}(compute, k); - } - - // store to C - auto storec = [&](auto i) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - vc[i].store(C + row * ldc + col * VLEN); - }; - c10::ForcedUnroll{}(storec); -} -""" - - def codegen_define(self, kernel: CppTemplateKernel) -> str: - options = { - "declare_kernel": self.get_kernel_declaration(), - "kernel": kernel, - "block_m": self.register_blocking.block_m, - "block_n": self.register_blocking.block_n, - "block_k": self.register_blocking.block_k, - **self.get_common_options(), - } - result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( - options - ) - result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( - options - ) - return result - - -def create_micro_gemm( - name, - m, - n, - k, - input_dtype, - output_dtype=None, - compute_dtype=None, - alpha=1, - num_threads=-1, - use_ref=False, -) -> Optional[CppMicroGemm]: - def create_from_config(cls, config: CppMicroGemmConfig): - return cls( - name, - config.input_dtype, - config.output_dtype, - config.compute_dtype, - config.register_blocking, - alpha, - ) - - assert isinstance(n, int) or n.is_number, n - assert isinstance(k, int) or k.is_number, k - m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m - assert isinstance(m, int), m - if output_dtype is None: - output_dtype = input_dtype - if compute_dtype is None: - compute_dtype = input_dtype - if num_threads < 0: - num_threads = parallel_num_threads() - vec_isa = pick_vec_isa() - matched_configs = [] - for cls, configs in micro_gemm_configs.items(): - for config in configs: - if not isinstance(vec_isa, config.vec_isa_cls): - continue - if ( - config.input_dtype == input_dtype - and config.output_dtype == output_dtype - and config.compute_dtype == compute_dtype - ): - block_m, block_n, block_k = config.register_blocking - # TODO(jgong5): support n % n_block_size != 0 - if n % block_n != 0: - continue - # Criteria on the ranking of configurations - # 1. Dividable by block sizes (block_m, block_k) - # 2. Number of mxn blocks is large enough to occupy all the threads - # 3. Register blocks are larger - dividable_score = 0 - if k % block_k == 0: - dividable_score += 1 - if m % block_m == 0: - dividable_score += 1 - occupancy_score = 0 - n_blocks = n // block_n - total_mxn_blocks = n // block_n * ((m + block_m - 1) // block_m) - if n_blocks >= num_threads: - occupancy_score += 1 - if total_mxn_blocks >= num_threads: - occupancy_score += 1 - matched_configs.append( - ( - (dividable_score, occupancy_score, block_m * block_n * block_k), - cls, - config, - ) - ) - if len(matched_configs) == 0: - if use_ref: - return CppMicroGemmRef( - name, input_dtype, output_dtype, compute_dtype, alpha - ) - else: - return None - # TODO(jgong5): allow autotuning on choices of configs - return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 45f874fc4d26..096f716bc8da 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -5,7 +5,6 @@ #include #include #include -#include #include // WARNING: be extra careful when including more ATen/c10 header files here! @@ -294,100 +293,3 @@ atomic_add(volatile T *addr, T offset) { std::atomic *atomic_addr = (std::atomic *)addr; atomic_addr->fetch_add(offset, std::memory_order_relaxed); } - -std::tuple mm_get_thread_blocking( - int64_t M, - int64_t N, - int64_t K, - int64_t M0, - int64_t N0, - int64_t K0, - int num_threads) { - auto get_factors = [](int64_t number) { - int count = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - count += 2; - } - } - auto factors = std::make_unique(count); - int index = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - factors[index++] = number / i; - factors[index++] = i; - } - } - return std::make_tuple(std::move(factors), count); - }; - - auto get_blocking = [](int64_t num_threads, - int64_t factor, - int64_t m_blocks, - int64_t n_blocks, - int64_t k_blocks) { - int64_t thread_block_n = (n_blocks + factor - 1) / factor; - int64_t cofactor = num_threads / factor; - int64_t thread_block_m = (m_blocks + cofactor - 1) / cofactor; - return std::make_tuple(thread_block_m, thread_block_n, k_blocks); - }; - - int64_t m_blocks = (M + M0 - 1) / M0; - int64_t n_blocks = (N + N0 - 1) / N0; - int64_t k_blocks = (K + K0 - 1) / K0; - - auto [factors, count] = get_factors(num_threads); - assert(count > 0); - - for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - if (n_blocks % factor == 0 && - m_blocks % (num_threads / factor) == 0) { - return get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); - } - } - - for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - if (n_blocks % factor == 0) { - return get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); - } - int64_t cofactor = num_threads / factor; - if (m_blocks % cofactor == 0) { - return get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); - } - } - - assert(false && "Should not reach here."); - // Dummy return to avoid compiler warning - return std::make_tuple(0, 0, 0); -} - -inline void mm_get_thread_blocks( - int thread_id, - int64_t M_blocks, - int64_t N_blocks, - int64_t K_blocks, - int64_t Mt_blocks, - int64_t Nt_blocks, - int64_t Kt_blocks, - int64_t& m_block_start, - int64_t& m_block_end, - int64_t& n_block_start, - int64_t& n_block_end, - int64_t& k_block_start, - int64_t& k_block_end) { - int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; - k_block_start = (thread_id % num_Kt) * Kt_blocks; - k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); - thread_id /= num_Kt; - int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; - n_block_start = (thread_id % num_Nt) * Nt_blocks; - n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); - thread_id /= num_Nt; - m_block_start = std::min(thread_id * Mt_blocks, M_blocks); - m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); -} diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py deleted file mode 100644 index 3d15010a8838..000000000000 --- a/torch/_inductor/codegen/cpp_template.py +++ /dev/null @@ -1,116 +0,0 @@ -import functools -import itertools -import logging - -import sys -from typing import List, Optional -from unittest.mock import patch - -import sympy - -from .. import codecache, config, ir -from ..autotune_process import CppBenchmarkRequest, TensorMeta -from ..utils import IndentedBuffer, Placeholder, unique -from ..virtualized import V -from .common import KernelTemplate -from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel - -log = logging.getLogger(__name__) - - -class CppTemplate(KernelTemplate): - index_counter = itertools.count() - - def __init__( - self, - name: str, - input_nodes, - layout: ir.Layout, - ): - super().__init__(name) - self.input_nodes = input_nodes - self.output_node: ir.Buffer = ir.Buffer("buf_out", layout) - self.layout = layout - - def generate(self, **kwargs): - kernel_name = f"cpp_{self.name}" - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(self.output_node) - ), CppTemplateKernel( - kernel_name=kernel_name, - ) as kernel: - code = self.render(kernel=kernel, **kwargs) - _, call_args, _ = kernel.args.python_argdefs() - log.debug("Generated Code:\n%s", code) - log.debug( - "Args: cpp_argdefs: %s, python_argdefs: %s", - kernel.args.cpp_argdefs(), - kernel.args.python_argdefs(), - ) - - expected_args = list( - unique(input_node.get_name() for input_node in self.input_nodes) - ) - expected_args.extend([self.output_node.get_name()]) - assert list(call_args)[: len(expected_args)] == expected_args, ( - call_args, - expected_args, - ) - extra_args = V.graph.sizevars.size_hints( - map(sympy.expand, call_args[len(expected_args) :]) - ) - - kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}" - - # Create the BenchmarkRequest for CPP - bmreq = CppBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), - extra_args=extra_args, - source_code=code, - ) - - def make_kernel_render( - template_node: ir.CppTemplateBuffer, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - ): - kernel = CppTemplateKernel( - kernel_name=str(Placeholder.KERNEL_NAME), - ) - render = functools.partial( - self.render, - kernel=kernel, - template_buffer_node=template_node, - epilogue_nodes=epilogue_nodes, - **kwargs, - ) - return kernel, render - - return CppTemplateCaller( - kernel_hash_name, - self.name, - self.input_nodes, - self.output_node.get_layout(), - make_kernel_render, - bmreq, - self, - ) - - def header(self) -> IndentedBuffer: - res = IndentedBuffer() - res.writeline(codecache.cpp_prefix()) - res.splice( - """ - #include "c10/util/Unroll.h" - """ - ) - enable_kernel_profile = ( - config.cpp.enable_kernel_profile and sys.platform == "linux" - ) - if enable_kernel_profile: - res.writelines(["#include "]) - return res - - def render(self, **kwargs) -> str: - raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py deleted file mode 100644 index 6a978c45fa28..000000000000 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ /dev/null @@ -1,200 +0,0 @@ -import itertools -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import sympy -from sympy.parsing.sympy_parser import parse_expr - -import torch - -from torch._inductor.autotune_process import CppBenchmarkRequest -from torch._inductor.utils import sympy_index_symbol -from .. import codecache, config, ir, lowering as L -from ..virtualized import V -from .common import Kernel, OpOverrides -from .cpp_utils import cexpr_index, DTYPE_TO_CPP - - -def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: - expr = parse_expr(expr_str) - int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} - return expr.subs(int_symbols) - - -def wrap_with_tensorbox(node) -> ir.TensorBox: - return ( - ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) - ) - - -class CppTemplateKernel(Kernel): - overrides = OpOverrides - - def __init__(self, kernel_name): - super().__init__() - self.kernel_name = kernel_name - - def def_kernel( - self, - inputs: Dict[str, ir.Buffer], - outputs: Dict[str, ir.Buffer], - ) -> str: - for name, inp in inputs.items(): - if inp is not None: - self.args.input_buffers[inp.get_name()] = name - for name, out in outputs.items(): - self.args.output_buffers[out.get_name()] = name - unique_sizevars = { - s - for input in inputs.values() - if input is not None - for sym in itertools.chain(input.get_size(), input.get_stride()) - if isinstance(sym, sympy.Expr) - for s in sym.free_symbols - } - unique_sizevars |= { - s - for output in outputs.values() - for sym in itertools.chain(output.get_size(), output.get_stride()) - if isinstance(sym, sympy.Expr) - for s in sym.free_symbols - } - sizevars = sorted(unique_sizevars, key=str) - for sizevar in sizevars: - self.args.sizevars[sizevar] = f"k{sizevar}" - cpp_argdefs, _, _ = self.args.cpp_argdefs() - return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" - - def call_kernel(self, name: str, node: ir.CppTemplateBuffer): - wrapper = V.graph.wrapper_code - _, call_args, arg_types = self.args.cpp_argdefs() - wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) - - def dtype(self, node: ir.Buffer) -> str: - return DTYPE_TO_CPP[node.get_dtype()] - - def acc_dtype(self, node: ir.Buffer) -> str: - if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: - return "float" - else: - raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") - - def size(self, node: ir.Buffer, dim: int) -> str: - return cexpr_index(self.rename_indexing(node.get_size()[dim])) - - def stride(self, node: ir.Buffer, dim: int) -> str: - return cexpr_index(self.rename_indexing(node.get_stride()[dim])) - - def index(self, node: ir.Buffer, indices: List[Any]) -> str: - indexer = node.make_indexer() - index = indexer([parse_expr_with_index_symbols(str(idx)) for idx in indices]) - index = self.rename_indexing(index) - return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" - - def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ir.ReinterpretView: - """ - Slice the given node with a list of ranges (start and end) corresponding to its dims. - The dim is not sliced if the corresponding range is empty. - """ - assert len(ranges) == len(node.get_size()) - sliced = wrap_with_tensorbox(node) - for dim, _range in enumerate(ranges): - if len(_range) == 0: - continue - assert len(_range) == 2 - start, end = (parse_expr_with_index_symbols(str(r)) for r in _range) - sliced = L.slice_(sliced, dim, start, end, clamp=False) - assert isinstance(sliced.data, ir.ReinterpretView) - return sliced.data - - def view(self, node, sizes: List[Any]) -> ir.View: - node = wrap_with_tensorbox(node) - sizes = [parse_expr_with_index_symbols(str(s)) for s in sizes] - return L.view(node, sizes).data - - @property - def assert_function(self) -> str: - if V.graph.aot_mode: - return "AOTI_TORCH_CHECK" - else: - return "TORCH_CHECK" - - def maybe_codegen_profile(self) -> str: - if config.cpp.enable_kernel_profile: - graph_id = V.graph.graph_id - prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" - return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' - else: - return "" - - def unroll_pragma(self, unroll): - if codecache.is_gcc(): - return f"#pragma GCC unroll {unroll}" - else: - return f"#pragma unroll {unroll}" - - -class CppTemplateCaller(ir.ChoiceCaller): - """ - CppTemplateCaller - - This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. - Attributes: - name (str): The name of the caller. - category (str): The category of the caller. - bmreq (CppBenchmarkRequest): The benchmark request for the caller. - template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. - """ - - def __init__( - self, - name: str, - category: str, - input_nodes: List[ir.Buffer], - layout: ir.Layout, - make_kernel_render: Callable[ - [ir.CppTemplateBuffer, Optional[List[ir.IRNode]]], str - ], - bmreq: CppBenchmarkRequest, - template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 - info_kwargs: Optional[ - Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]] - ] = None, - ): - super().__init__(name, input_nodes, layout) - self.category = category - self.make_kernel_render = make_kernel_render - self.bmreq = bmreq - self.template = template - self.info_kwargs = info_kwargs - - def precompile(self) -> None: - assert self.bmreq is not None - self.bmreq.precompile() - - def benchmark(self, *args, out) -> float: - assert self.bmreq is not None - return self.bmreq.benchmark(*args, output_tensor=out) - - def hash_key(self) -> str: - return "-".join( - [ - self.category, - self.bmreq.hash_key, - ] - ) - - def info_dict( - self, - ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]: - return {"backend": "CPP", "op_type": "unknown"} - - def output_node(self) -> ir.TensorBox: - return ir.TensorBox.create( - ir.CppTemplateBuffer( - layout=self.layout, - inputs=self.input_nodes, - make_kernel_render=self.make_kernel_render, - template=self.template, - choice=self, - ) - ) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index a3b4fd3206b6..7e6f06b9e507 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,7 +1,5 @@ import math -from collections import namedtuple - import torch from .common import ExprPrinter @@ -57,8 +55,6 @@ INDEX_TYPE = "long" -GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) - class CppPrinter(ExprPrinter): def _print_Integer(self, expr): diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 79af641514bd..640c0d25c264 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -233,13 +233,12 @@ def is_fbcode(): True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" ) # Specify candidate backends for gemm autotune. -# Possible choices are combinations of: ATen, Triton, CUTLASS, CPP. +# Possible choices are combinations of: ATen, Triton, CUTLASS. # ATen: default Pytorch ATen kernels. # Triton: Triton templates defined in torch inductor. # CUTLASS: Cutlass templates and kernels. -# CPP: CPP templates and kernels for CPU. max_autotune_gemm_backends = os.environ.get( - "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" ).upper() # the value used as a fallback for the unbacked SymInts diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 513709a18f32..d00a01184d1e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3718,13 +3718,6 @@ def get_workspace_size(self): return self.workspace_size if self.workspace_size is not None else 0 -class CppTemplateBuffer(TemplateBuffer): - def __init__(self, layout, inputs, make_kernel_render, template, choice): - super().__init__(layout, inputs, make_kernel_render) - self.template = template - self.choice = choice - - @dataclasses.dataclass class InputsKernel(Buffer): inputs: List[Buffer] @@ -6255,7 +6248,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, packed_w, orig_w, B, batch_size): + def create(cls, x, packed_w, orig_w, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() @@ -6263,11 +6256,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): output_size = list(m) + [oc] output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] - constant_args = [batch_size] - if B is not None: - inputs += [B] - else: - constant_args.insert(0, None) + constant_args = [None, batch_size] return MKLPackedLinear( layout=FixedLayout( diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index fa14b4406de6..593da39d2bf6 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional import torch -from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from torch._inductor.virtualized import V from .. import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate @@ -18,7 +17,6 @@ ) from ..utils import ( use_aten_gemm_kernels, - use_cpp_packed_gemm_template, use_cutlass_template, use_max_autotune, use_triton_template, @@ -158,13 +156,6 @@ def tuned_mm(mat1, mat2, *, layout=None): if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): CUTLASSGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) - if use_cpp_packed_gemm_template(layout, mat1, mat2): - CppPackedGemmTemplate.add_choices( - choices, - layout, - [mat1, mat2], - ) - if len(choices) == 0 and not use_aten_gemm_kernels(): log.warning("No choices for GEMM, using ATen backend as fallback") choices.append(aten_mm.bind((mat1, mat2), aten_layout)) @@ -320,15 +311,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): beta=beta, ) - if use_cpp_packed_gemm_template(layout, mat1, mat2): - CppPackedGemmTemplate.add_choices( - choices, - layout, - [inp_expanded, mat1, mat2], - alpha=alpha, - beta=beta, - ) - add_aten_fallback = False if len(choices) == 0: log.warning("No choices for GEMM, using ATen backend as fallback") diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 399cb1668dad..0ebccbf27ea3 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,22 +1,10 @@ -from typing import List, Optional +from typing import List import torch import torch.utils._pytree as pytree -from torch._inductor.kernel.mm_common import mm_args from . import ir -from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox -from .lowering import ( - add, - add_needs_realized_inputs, - aten, - permute, - register_lowering, - to_dtype, -) -from .select_algorithm import autotune_select_algorithm, ExternKernelChoice -from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune -from .virtualized import V +from .lowering import add, add_needs_realized_inputs, aten, register_lowering, to_dtype def register_onednn_fusion_ops(): @@ -351,12 +339,6 @@ def qlinear_unary( ) if torch._C.has_mkl: - aten_mkl_linear = ExternKernelChoice( - torch.ops.mkl._mkl_linear, - "mkl::_mkl_linear", - has_out_variant=False, - kernel_creator=ir.MKLPackedLinear.create, - ) cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) @register_lowering(torch.ops.mkl._mkl_linear) @@ -364,48 +346,11 @@ def mkl_packed_linear( x: TensorBox, packed_w: TensorBox, orig_w: TensorBox, - b: Optional[TensorBox], + b: TensorBox, batch_size, - *, - layout=None, ): - choices = ( - [ - aten_mkl_linear.bind( - (x, packed_w, orig_w), layout, B=None, batch_size=batch_size - ) - ] - if use_aten_gemm_kernels() - else [] - ) - if use_max_autotune(): - transposed_w = permute(orig_w, [1, 0]) - *_, layout, x, transposed_w = mm_args( - x, transposed_w, layout=layout - ) - if use_cpp_packed_gemm_template(layout, x, transposed_w): - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, packed_w, orig_w], - trans_w=True, - input_indices=[0, 2], - ) - - assert packed_w.get_name() in V.graph.constants - assert orig_w.get_name() in V.graph.constants - # packed_w is a mkldnn tensor which we can't generate directly - # so we use the weights from the original tensor in autotune. - input_gen_fns = { - 1: lambda x: V.graph.constants[x.get_name()], - 2: lambda x: V.graph.constants[x.get_name()], - } - result: TensorBox = autotune_select_algorithm( - "packed_linear", - choices, - [x, packed_w, orig_w], - layout, - input_gen_fns=input_gen_fns, + result = TensorBox.create( + ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) ) if b is not None: result = add(result, b) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index d1550529bb8e..0999e6ce3b21 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -696,19 +696,17 @@ def __init__( has_out_variant=True, op_overload=None, use_fallback_kernel=False, - kernel_creator=None, ): super().__init__() name = name or kernel.__name__ assert callable(kernel) - assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + assert not hasattr(extern_kernels, name), "duplicate extern kernel" self.name = name self.cpp_kernel_name = cpp_kernel self.has_out_variant = has_out_variant setattr(extern_kernels, name, kernel) self.op_overload = op_overload self.use_fallback_kernel = use_fallback_kernel - self.kernel_creator = kernel_creator def to_callable(self): return getattr(extern_kernels, self.name) @@ -875,8 +873,6 @@ def output_node(self): inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) - elif self.choice.kernel_creator is not None: - inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) else: cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc inner = cls( @@ -899,86 +895,6 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } -class DataProcessorChoiceCallerWrapper: - def __init__(self, wrapped, preprocessor, postprocessor): - self._wrapped = wrapped - if preprocessor is not None: - self._preprocessor = preprocessor - else: - self._preprocessor = lambda x, y: (x, y) - if postprocessor is not None: - self._postprocessor = postprocessor - else: - self._postprocessor = lambda x: x - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def benchmark(self, *args, out) -> float: - new_args, new_out = self._preprocessor(args, out) - result = self._wrapped.benchmark(*new_args, out=new_out) - new_out = self._postprocessor(new_out) - if out is not new_out: - out.copy_(new_out) - return result - - def output_node(self) -> ir.TensorBox: - result = self._wrapped.output_node() - return self._postprocessor(result) - - def __repr__(self) -> str: - return f"DataProcessorChoiceCallerWrapper({self._wrapped})" - - -class DataProcessorTemplateWrapper: - """ - A wrapper class for a kernel template. - - This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to - preprocess and postprocess data before and after using the wrapped template. A typical - usage is to reorder or filter the input nodes in order to match the expected input of other - kernel choices like a ATen kernel. A more complicated usage is to prepack the weights. - See the example from :mod:`cpp_gemm_template` for more details. - """ - - def __init__( - self, - wrapped_template_cls, - preprocessor, - postprocessor, - **kwargs, - ): - if preprocessor is not None: - self._preprocessor = preprocessor - else: - self._preprocessor = lambda x, y: (x, y) - if postprocessor is not None: - self._postprocessor = postprocessor - else: - self._postprocessor = lambda x: x - assert "input_nodes" in kwargs - assert "layout" in kwargs - kwargs["input_nodes"], kwargs["layout"] = preprocessor( - kwargs["input_nodes"], kwargs["layout"] - ) - self._wrapped = wrapped_template_cls(**kwargs) - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def maybe_append_choice(self, choices, **kwargs): - return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) - - def generate(self, **kwargs): - choice_caller = self._wrapped.generate(**kwargs) - return DataProcessorChoiceCallerWrapper( - choice_caller, self._preprocessor, self._postprocessor - ) - - def __repr__(self) -> str: - return f"DataProcessorTemplateWrapper({self._wrapped})" - - class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" @@ -1257,9 +1173,7 @@ def get_inputs(): } example_inputs = list(unique_example_inputs.values()) example_inputs_extern = [ - unique_example_inputs[input_node.get_name()] - if unique_example_inputs[input_node.get_name()].is_mkldnn - else torch.as_strided( + torch.as_strided( unique_example_inputs[input_node.get_name()], V.graph.sizevars.size_hints( input_node.get_size(), diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5ff1d951bb42..c10e3cc512f0 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -987,42 +987,6 @@ def use_cutlass_template(layout, m, n, k): return res -def _use_template_for_cpu(layout): - return use_max_autotune() and layout.device.type == "cpu" - - -def use_cpp_packed_gemm_template(layout, mat1, mat2): - from . import ir - from .codegen.cpp_micro_gemm import create_micro_gemm - from .kernel.mm_common import mm_args - - if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): - return False - - if not config.cpp.weight_prepack: - return False - - layout_dtypes = [torch.float32] - m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) - # TODO(jgong5): support dynamic shapes for n or k - if has_free_symbols((n, k)): - return False - if isinstance(mat2, ir.BaseView): - mat2 = mat2.unwrap_view() - micro_gemm = create_micro_gemm( - "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() - ) - # TODO(jgong5): support n % n_block_size != 0 - return ( - layout.dtype in layout_dtypes - and micro_gemm is not None - and n % micro_gemm.register_blocking[1] == 0 - and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input - and isinstance(mat2, ir.StorageBox) - and mat2.is_module_buffer() - ) - - def use_aten_gemm_kernels(): return not use_max_autotune() or _use_autotune_backend("ATEN") From 762ce6f062f387e963a49b3a631392dd0fd51586 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 17 May 2024 00:41:55 +0000 Subject: [PATCH 037/116] Add Lowering for FlexAttention Backwards (#125515) # Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: https://github.com/pytorch/pytorch/pull/123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 |
Full Data | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515 Approved by: https://github.com/Chillee --- benchmarks/transformer/score_mod.py | 171 ++++-- test/inductor/test_flex_attention.py | 196 ++++--- test/run_test.py | 5 +- torch/_higher_order_ops/flex_attention.py | 44 +- torch/_inductor/ir.py | 5 +- torch/_inductor/kernel/flex_attention.py | 640 +++++++++++++++++----- torch/_inductor/select_algorithm.py | 63 ++- torch/_inductor/utils.py | 2 +- torch/nn/attention/_flex_attention.py | 4 +- torch/testing/_internal/hop_db.py | 6 +- 10 files changed, 828 insertions(+), 308 deletions(-) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index 2c5f41502f7e..57088c45f8a0 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import asdict, dataclass from functools import partial -from typing import Callable, List +from typing import Callable, List, Optional, Tuple import numpy as np import torch @@ -29,28 +29,32 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> @dataclass(frozen=True) class ExperimentConfig: - batch_size: int - num_heads: int - q_seq_len: int - k_seq_len: int - head_dim: int + shape: Tuple[int] score_mod: Callable dtype: torch.dtype + calculate_bwd_time: bool + + def __post_init__(self): + assert len(self.shape) == 4, "Shape must be of length 4" def asdict(self): - return asdict(self) + # Convert the dataclass instance to a dictionary + d = asdict(self) + # Remove the 'calculate_bwd_time' key + d.pop("calculate_bwd_time", None) + return d @dataclass(frozen=True) -class ExperimentResults: +class Times: eager_time: float compiled_time: float - def get_entries(self) -> List: - return [ - f"{self.eager_time:2f}", - f"{self.compiled_time:2f}", - ] + +@dataclass(frozen=True) +class ExperimentResults: + fwd_times: Times + bwd_times: Optional[Times] @dataclass(frozen=True) @@ -58,29 +62,31 @@ class Experiment: config: ExperimentConfig results: ExperimentResults - def get_entries(self) -> List: - return self.config.get_entries() + self.results.get_entries() - def asdict(self): - dict1 = asdict(self.config) + dict1 = self.config.asdict() dict2 = asdict(self.results) return {**dict1, **dict2} def generate_inputs( - batch_size, - num_heads, - q_sequence_length, - kv_sequence_length, - head_dim, - dtype, - device, + batch_size: int, + num_heads: int, + q_sequence_length: int, + kv_sequence_length: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, ): q_shape = (batch_size, q_sequence_length, num_heads * head_dim) kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim) - make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) - make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) + make_q = partial( + torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + make_kv = partial( + torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) query = ( make_q() .view(batch_size, q_sequence_length, num_heads, head_dim) @@ -101,14 +107,16 @@ def generate_inputs( def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults: device = torch.device("cuda") + batch_size, num_heads, q_seq_len, head_dim = config.shape query, key, value = generate_inputs( - config.batch_size, - config.num_heads, - config.q_seq_len, - config.k_seq_len, - config.head_dim, + batch_size, + num_heads, + q_seq_len, + q_seq_len, + head_dim, config.dtype, device, + requires_grad=config.calculate_bwd_time, ) def eager_sdpa(query, key, value, _): @@ -125,23 +133,47 @@ def eager_sdpa(query, key, value, _): compiled_sdpa, query, key, value, score_mod ) - return ExperimentResults( - eager_time=forward_eager_time, - compiled_time=forward_compiled_time, - ) + if config.calculate_bwd_time: + out_eager = eager_sdpa(query, key, value, score_mod) + dOut = torch.randn_like(out_eager) + backward_eager_time = benchmark_torch_function_in_microseconds( + out_eager.backward, dOut, retain_graph=True + ) + + out_compile = compiled_sdpa(query, key, value, score_mod) + dOut = torch.randn_like(out_eager) + backward_compile_time = benchmark_torch_function_in_microseconds( + out_compile.backward, dOut, retain_graph=True + ) + + return ExperimentResults( + fwd_times=Times(forward_eager_time, forward_compiled_time), + bwd_times=Times(backward_eager_time, backward_compile_time), + ) + else: + return ExperimentResults( + fwd_times=Times(forward_eager_time, forward_compiled_time), + bwd_times=None, + ) -def calculate_speedup(results: ExperimentResults) -> float: - return results.eager_time / results.compiled_time +def calculate_speedup(results: ExperimentResults, type: str) -> float: + if type == "fwd": + return results.fwd_times.eager_time / results.fwd_times.compiled_time + elif type == "bwd": + assert results.bwd_times is not None + return results.bwd_times.eager_time / results.bwd_times.compiled_time + else: + raise ValueError(f"Invalid type {type}") def get_func_name(func): return func.__name__.split(".")[-1].split(" at ")[0] -def get_average_speedups(results: List[Experiment]): +def get_average_speedups(results: List[Experiment], type: str): # Calculate speedups - speedups = [calculate_speedup(r.results) for r in results] + speedups = [calculate_speedup(r.results, type) for r in results] # Find indices of max and min speedups max_speedup_index = np.argmax(speedups) @@ -177,20 +209,39 @@ def print_results(results: List[Experiment]): table_data = defaultdict(list) for experiment in results: for key, value in experiment.asdict().items(): - if key == "eager_time" or key == "compiled_time": - value = float(value) - table_data[key].append(value) + if key == "fwd_times": + for name, time in value.items(): + table_data[f"fwd_{name}"].append(float(time)) + elif key == "bwd_times": + if experiment.config.calculate_bwd_time: + for name, time in value.items(): + table_data[f"bwd_{name}"].append(float(time)) + else: + table_data[key].append(value) # Calculate speedups - speedups = [calculate_speedup(r.results) for r in results] - table_data["speedup"] = speedups + fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results] + table_data["fwd_speedup"] = fwd_speedups + if results[0].config.calculate_bwd_time: + bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results] + table_data["bwd_speedup"] = bwd_speedups table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]] print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f")) - average_data = get_average_speedups(results) + print("\n") + print("FWD Speedups".center(125, "=")) + print("\n") + average_data = get_average_speedups(results, type="fwd") print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) + if results[0].config.calculate_bwd_time: + print("\n") + print("BWD Speedups".center(125, "=")) + print("\n") + average_data = get_average_speedups(results, type="bwd") + print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) + def generate_score_mods() -> List[Callable]: def noop(score, b, h, m, n): @@ -208,8 +259,8 @@ def head_bias(score, b, h, m, n): return [noop, causal_mask, relative_bias, head_bias] -def generate_experiment_configs() -> List[ExperimentConfig]: - batch_sizes = [1, 8, 16] +def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]: + batch_sizes = [2, 8, 16] num_heads = [16] q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)] head_dims = [64, 128, 256] @@ -228,41 +279,49 @@ def generate_experiment_configs() -> List[ExperimentConfig]: ) in itertools.product( batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes ): + assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now." all_configs.append( ExperimentConfig( - batch_size=bsz, - num_heads=n_heads, - q_seq_len=q_seq_len, - k_seq_len=kv_seq_len, - head_dim=head_dim, + shape=(bsz, n_heads, q_seq_len, head_dim), score_mod=score_mod, dtype=dtype, + calculate_bwd_time=calculate_bwd, ) ) return all_configs -def main(dynamic=False): +def main(dynamic: bool, calculate_bwd: bool): seed = 123 np.random.seed(seed) torch.manual_seed(seed) results = [] - for config in tqdm(generate_experiment_configs()): + for config in tqdm(generate_experiment_configs(calculate_bwd)): results.append( Experiment(config, run_single_experiment(config, dynamic=dynamic)) ) + for config in tqdm(generate_experiment_configs(calculate_bwd)): + results.append(Experiment(config, run_single_experiment(config))) print_results(results) if __name__ == "__main__": - parser = argparse.ArgumentParser() + # Set up the argument parser + parser = argparse.ArgumentParser( + description="Run sweep over sizes and score mods for flex attention" + ) parser.add_argument( "--dynamic", action="store_true", help="Runs a dynamic shapes version of compiled flex attention.", ) + parser.add_argument( + "--calculate-bwd", action="store_true", help="Calculate backward pass times" + ) + # Parse arguments args = parser.parse_args() - main(args.dynamic) + + main(args.dynamic, args.calculate_bwd) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 9df905d2ad54..f3a9026a3c80 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1,8 +1,9 @@ # Owner(s): ["module: inductor"] import functools +import unittest from collections import namedtuple -from typing import Callable +from typing import Callable, Optional from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -58,14 +59,8 @@ def create_attention(score_mod): # --------- Useful score mod functions for testing --------- - -test_score_mods = [ - _identity, - _causal, - _rel_bias, - _rel_causal, - _generate_alibi_bias(8), -] +def _inverse_causal(score, b, h, m, n): + return torch.where(m <= n, score, float("-inf")) def _times_two(score, b, h, m, n): @@ -79,13 +74,11 @@ def _squared(score, b, h, m, n): def _head_offset(dtype: torch.dtype): - """Captured Buffer - Note: this builds a score_mod with index of a type - """ + """Captured Buffer""" head_offset = torch.rand(H, device="cuda", dtype=dtype) def score_mod(score, b, h, m, n): - return score * index(head_offset, [h]) + return score * head_offset[h] return score_mod @@ -103,20 +96,19 @@ def _trig2(score, b, h, m, n): return z -def _buffer_reduced(dtype: torch.dtype): - """Reduction in captured buffer""" - batch_offsets = torch.rand(B, 8, device="cuda", dtype=dtype) - - def score_mod(score, b, h, m, n): - batch_vals = index(batch_offsets, [b]) - return score + batch_vals.sum() - - return score_mod - +test_score_mods = [ + _identity, + _times_two, + _squared, + _causal, + _inverse_causal, + _rel_bias, + _rel_causal, + _generate_alibi_bias(8), +] captured_buffers_map = { "_head_offset": _head_offset, - "_buffer_reduced": _buffer_reduced, } B = 4 @@ -125,18 +117,35 @@ def score_mod(score, b, h, m, n): D = 64 -class TestTemplatedSDPA(InductorTestCase): - def _check_equal(self, golden_out, ref_out, compiled_out, dtype): +def query_key_value_clones( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dtype: torch.dtype = None, +): + """Clones the query, key, and value tensors and moves them to the specified dtype.""" + if dtype is None: + dtype = query.dtype + query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad) + key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad) + value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad) + return query_ref, key_ref, value_ref + + +class TestFlexAttention(InductorTestCase): + def _check_equal( + self, + golden_out: torch.Tensor, + ref_out: torch.Tensor, + compiled_out: torch.Tensor, + fudge_factor: float, + tensor_name: Optional[str] = None, + ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 if compiled_error > ref_error * fudge_factor: - msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." + name = tensor_name if tensor_name is not None else "" + msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." self.assertTrue(False, msg) def run_test( @@ -150,15 +159,45 @@ def run_test( ): sdpa_partial = create_attention(score_mod) compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out = sdpa_partial( - q.to(torch.float64), k.to(torch.float64), v.to(torch.float64) - ) - ref_out = sdpa_partial(q, k, v) + q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) + q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) + golden_out = sdpa_partial(q_gold, k_gold, v_gold) + ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - self._check_equal(golden_out, ref_out, compiled_out, dtype) + + backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + compiled_out.backward(backward_grad) + + with torch.no_grad(): + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + # Checkout output + self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") + + # Check gradients + q_fudge_factor = 2.5 * fudge_factor + self._check_equal( + q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" + ) + k_fudge_factor = 4 * fudge_factor + self._check_equal( + k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" + ) + v_fudge_factor = 8 * fudge_factor + self._check_equal( + v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" + ) def run_dynamic_test( self, @@ -196,12 +235,20 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version. compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( @@ -251,20 +298,28 @@ def run_automatic_dynamic_test( # 2, the second batch is compiled with dynamic shape # 3, no re-compilation in the third batch torch._dynamo.reset() + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + # The first batch. compiled_sdpa = torch.compile(sdpa_partial) compiled_out1 = compiled_sdpa(q1, k1, v1) - self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # The second batch (automatic dynamic). compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) # The third batch (no re-compilation). compiled_out3 = compiled_sdpa(q3, k3, v3) - self._check_equal(golden_out3, ref_out3, compiled_out3, dtype) + self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) @supported_platform @@ -318,6 +373,21 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod, dtype) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes) + def test_captured_buffers_all_dims(self, dtype: torch.dtype): + head_scale = torch.randn(H, device="cuda") + batch_scale = torch.randn(B, device="cuda") + tok_scale = torch.randn(S, device="cuda") + + def all_bias(score, batch, head, token_q, token_kv): + score = score + tok_scale[token_q] + score = score + batch_scale[batch] + score = score + head_scale[head] + return score + + self.run_test(all_bias, dtype) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) def test_seq_masking(self, dtype): @@ -422,7 +492,7 @@ def score_mod_func(score, b, h, q, kv): make_tensor = functools.partial( torch.randn, - (2, 2, 8, 4), + (2, 2, 128, 4), device="cuda", dtype=torch.float64, requires_grad=True, @@ -458,6 +528,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) + @unittest.skip("Silu decomp failing for full in backwards") def test_silu_on_score(self, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) @@ -597,23 +668,6 @@ def njt_score_mod(qk, b, h, q, kv): self.run_test(causal_njt, dtype) - @supported_platform - def test_backwards_fails(self): - make_tensor = functools.partial( - torch.randn, - (B, H, S, D), - dtype=torch.float32, - device="cuda", - requires_grad=True, - ) - q, k, v = make_tensor(), make_tensor(), make_tensor() - func = torch.compile(_flex_attention, backend="inductor", fullgraph=True) - with self.assertRaisesRegex( - AssertionError, "flex_attention_backward is not an OpOverload" - ): - out = func(q, k, v, _identity) - out.backward(torch.ones_like(out)) - @supported_platform def test_mixed_dtypes_fails(self): query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") @@ -641,6 +695,7 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod) @supported_platform + @skip("TODO: Figure out why this is erroring") @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune_with_captured(self): head_scale = torch.randn(H, device="cuda") @@ -776,7 +831,7 @@ def test_aot_eager_gradcheck(self, score_mod): ) @supported_platform - @common_utils.parametrize("score_mod_name", ["_head_offset", "_buffer_reduced"]) + @common_utils.parametrize("score_mod_name", ["_head_offset"]) @common_utils.parametrize("mode", ["eager", "aot_eager"]) def test_captured_score_mod_aot_eager_gradcheck( self, score_mod_name: str, mode: str @@ -864,13 +919,10 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", """ - + """alias_5: "f64[2, 2, 8, 4]", alias_7: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): fw_graph = self.fw_graph joint_graph = self.joint_graph - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, """ - + """primals_3, alias_5, alias_7, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_5 """ - + """= alias_7 = tangents_1 = fw_graph = joint_graph = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0] getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1] getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None @@ -888,11 +940,11 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None return [add, None, None, None, None] -""", +""", # noqa: B950 ) -common_utils.instantiate_parametrized_tests(TestTemplatedSDPA) +common_utils.instantiate_parametrized_tests(TestFlexAttention) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/run_test.py b/test/run_test.py index cbee11b37a7a..71ab08199f7a 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -240,7 +240,8 @@ def __contains__(self, item): "test_native_mha", # OOM "test_module_hooks", # OOM "inductor/test_max_autotune", - "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps + "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps, + "inductor/test_flex_attention", # OOM ] # A subset of onnx tests that cannot run in parallel due to high memory usage. ONNX_SERIAL_LIST = [ @@ -407,7 +408,7 @@ def run_test( stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}" if options.verbose: - unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest + unittest_args.append(f'-{"v" * options.verbose}') # in case of pytest if test_file in RUN_PARALLEL_BLOCKLIST: unittest_args = [ diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index b5e1385da346..f4586a0a57b0 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -406,17 +406,20 @@ def flex_attention_autograd( score_mod: Callable, *other_buffers: Tuple[torch.Tensor, ...], ) -> Tuple[torch.Tensor, torch.Tensor]: - input_requires_grad = any(t.requires_grad for t in (query, key, value)) - if torch.is_grad_enabled() and input_requires_grad: - example_vals = [ - torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) - ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] - fw_graph, bw_graph = create_fw_bw_graph(score_mod, example_vals, other_buffers) - else: - fw_graph, bw_graph = score_mod, None - out, logsumexp = FlexAttentionAutogradOp.apply( - query, key, value, fw_graph, bw_graph, *other_buffers - ) + with TransformGetItemToIndex(): + input_requires_grad = any(t.requires_grad for t in (query, key, value)) + if torch.is_grad_enabled() and input_requires_grad: + example_vals = [ + torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) + ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] + fw_graph, bw_graph = create_fw_bw_graph( + score_mod, example_vals, other_buffers + ) + else: + fw_graph, bw_graph = score_mod, None + out, logsumexp = FlexAttentionAutogradOp.apply( + query, key, value, fw_graph, bw_graph, *other_buffers + ) return out, logsumexp @@ -449,9 +452,10 @@ def sdpa_dense_backward( score_mod = torch.vmap(score_mod, in_dims=(0, None, 0, None, None) + in_dim_buffers) score_mod = torch.vmap(score_mod, in_dims=(0, 0, None, None, None) + in_dim_buffers) - post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to( - working_precision - ) + with TransformGetItemToIndex(): + post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to( + working_precision + ) softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) @@ -485,9 +489,10 @@ def sdpa_dense_backward( in_dims=(0, 0, None, None, None, 0) + in_dim_buffers, out_dims=out_dims, ) - grad_scores, *_ = joint_score_mod( - scores, b, h, m, n, grad_score_mod, *other_buffers - ) + with TransformGetItemToIndex(): + grad_scores, *_ = joint_score_mod( + scores, b, h, m, n, grad_score_mod, *other_buffers + ) grad_scores = grad_scores.to(query.dtype) grad_query = grad_scores @ key @@ -524,8 +529,9 @@ def trace_flex_attention_backward( torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad) ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)] - fw_graph = make_fx(fw_graph)(*fw_example_vals, *other_buffers) - joint_graph = make_fx(joint_graph)(*bw_example_vals, *other_buffers) + with TransformGetItemToIndex(): + fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *other_buffers) + joint_graph = reenter_make_fx(joint_graph)(*bw_example_vals, *other_buffers) proxy_mode.tracer.root.register_module("fw_graph", fw_graph) proxy_mode.tracer.root.register_module("joint_graph", joint_graph) node_args = ( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d00a01184d1e..be730decda14 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3595,7 +3595,10 @@ def __init__( self.mutated_inputs = mutated_inputs if mutated_inputs is not None: # Ensure that the mutated inputs are only allowed for certain nodes - allowed_set = {torch.ops.higher_order.flex_attention} + allowed_set = { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + } current_node = V.graph.current_node.target assert ( current_node in allowed_set diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index a780d3709cb0..32dff9d46668 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,17 +1,39 @@ """ Triton Implementation of the flex_attention Kernel""" + import logging -from typing import Any, List +import math +from enum import auto, Enum +from typing import Any, List, Tuple import torch +from torch._prims_common import make_contiguous_strides_for from .. import config -from ..lowering import empty_strided, lowerings, register_lowering +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + IRNode, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import empty_strided, full, lowerings, register_lowering from ..select_algorithm import autotune_select_algorithm, TritonTemplate log = logging.getLogger(__name__) aten = torch.ops.aten -def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): +class SubgraphType(Enum): + """The type of subgraph for which we want to generate an output buffer.""" + + FWD = auto() # Forward pass + JOINT_FWD = auto() # The recompute step fo the of the bwds kernel + JOINT_BWD = auto() # The bwd pass of the joint + + +def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta): """How is this kernel parallelized? We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) Each block is responsible for iterating over blocks of keys and values calculating @@ -22,9 +44,117 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1) -sdpa_template = TritonTemplate( - name="sdpa", - grid=sdpa_grid, +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer(name, FixedLayout(device, dtype, [1], [1])) + return TensorBox.create(input_buffer) + + +def index_to_other_buffers(cnt: int, graph_type: SubgraphType) -> int: + """This function needs to be aware of the signatures for flex_attention_forward + and flex_attention_backward. If new args are added, or the signature changes + be sure to update the indexing math + + Args: + cnt (int): The current index of the placeholder node + is_joint_graph (bool): Whether or not this subgraph represents the joint graph + """ + # Current fwd_args = [query, key, value, score_mod, *other_buffers] + # For fwd_graphs we have 5 dummy values this when the first lifted args + # is seen cnt = 5 and the start of the index_buffers is at args[4] + # thus we subtract 1 from the current cnt + if graph_type == SubgraphType.FWD: + return cnt - 1 + + # Current bwd_args = [q, k, v, out, lse, grad_out, fw_graph, joint_graph, *other_buffers] + # We have 5 dummy values but the start of other_buffers is at index 8 + if graph_type == SubgraphType.JOINT_FWD: + return cnt + 3 + + # Same bwd args but now with 6 dummy values while other_buffers still start at 8 + if graph_type == SubgraphType.JOINT_BWD: + return cnt + 2 + + +def build_subgraph_buffer( + args: Tuple[IRNode], + placeholder_inps: List[TensorBox], + subgraph: Subgraph, + graph_type: SubgraphType, +) -> ComputedBuffer: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that were passed into the flex_attention kernel + placeholder_inps: The list of scalar inputs, these were created on the fly through `create_placeholder` + subgraph: The Subgraph ir for which to produce the output node + graph_type: The type of subgraph for which we want to produce the output node, see enum above for details + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + # There are two classes of placeholder inpts that we need + # to handle differently. For the first n_scalar_inps inputs + # we expect that these placeholders were generated by the make_fx call + # in the flex Attention HOP. So we need to create a new placeholder + # TensorBox for each of these inputs. For the rest of the inputs we + # expect that these are lifted inputs that fill up the '*other_buffers' + # tuple and already have corresponding TensorBoxes passed in as args. + if node.op == "placeholder": + is_lifted_input = cnt >= len(placeholder_inps) + lifted_input_index = index_to_other_buffers(cnt, graph_type) + env[node] = ( + args[lifted_input_index] if is_lifted_input else placeholder_inps[cnt] + ) + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + from torch.utils._pytree import tree_map + + env[node] = lowerings[node.target]( + *tree_map(lambda x: env[x] if x in env else x, node.args) + ) + elif node.op == "output": + # For the output node we need to create a ComputedBuffer + # which represents the actual score modification + # The joint_graph's output should be of the form[grad_score, None, None, None, None] + # This is because only the 'score' requires grad and the other outputs are + # the non-differentiable index scalars + if graph_type == SubgraphType.FWD or graph_type == SubgraphType.JOINT_FWD: + output_node = node.args[0] + else: + output_node = node.args[0][0] + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + # Create the ComputedBuffer directly that will be inlined into the modification block + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + raise ValueError("TemplatedAttention was passed a subgraph with no output node!") + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, source=r""" {{def_kernel("Q", "K", "V", "LSE")}} # Sub notation for this kernel: @@ -118,6 +248,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): m = offs_m[:, None] n = start_n + offs_n[None, :] {{ modification( + subgraph_number=0, score="qk", b="off_hz // H", h="off_hz % H", @@ -192,7 +323,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): } -def _get_default_config(query): +def _get_default_config_fwd(query) -> Tuple[int, int, int, int]: dtype = query.get_dtype() head_dim = query.get_size()[-1] default_config = None @@ -218,143 +349,394 @@ def _get_default_config(query): return default_config +def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: + head_dim = query.get_size()[-1] + dtype = query.get_dtype() + + if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100 + if dtype == torch.float32: + return (64, 64, 4, 1) + return (128, 128, 4, 3) + elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100 + return (32, 32, 4, 1) + else: # modest hardware or extremely large head_dim + return (32, 32, 4, 1) + + # TODO: We probably also need a layout constraint? @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) def flex_attention(*args, **kwargs): - from torch._prims_common import make_contiguous_strides_for - from ..ir import ( - ComputedBuffer, - FixedLayout, - FlexibleLayout, - InputBuffer, - StorageBox, - TensorBox, - ) - query, key, value, subgraph, *other_buffers = args + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + args, placeholder_inps, subgraph, graph_type=SubgraphType.FWD + ) + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + query.get_size(), + make_contiguous_strides_for(query.get_size()), + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = query.get_size()[:-1] # [B, H, M] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_fwd(query)) + if config.max_autotune: + configs += [ + (128, 64, 4, 3), + (128, 128, 4, 3), + (128, 128, 8, 2), + (64, 128, 4, 3), + (64, 64, 4, 3), + ] - def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer: - return TensorBox.create( - InputBuffer( - name, - FixedLayout( - query.get_device(), - dtype, - [ - 1, - ], - [ - 1, - ], - ), - ) + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[query, key, value, logsumexp], + layout=layout, + subgraphs=[ + subgraph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + num_stages=num_stages, + num_warps=num_warps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=query.get_size()[-1], + # For now, we always assume the "sound" option + SCORE_MOD_IS_LINEAR=False, + ROWS_GUARANTEED_SAFE=False, + OUTPUT_LOGSUMEXP=True, ) + inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers) + return ( + autotune_select_algorithm( + "flex_attention", choices, inputs_for_autotuning, layout + ), + logsumexp, + ) - scalar_inps = ["score", "b", "h", "m", "n"] - env = {} - cnt = 0 - placeholder_inps = [ - create_placeholder(name, dtype) + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): + """How is this kernel parallelized? + Currently this is only parallelizing over batch * num_heads, but we can, and want to + parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require + atomic updates to some grad values or to have a two pass kernel design. + """ + return (batch_size * num_heads, 1, 1) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "OUT", "LSE", "DELTA", "DO", "DQ", "DV")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # OUT: Forward output, LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT* DO, axis=1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values, D: Model dimension + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # (Modifiable) Config options: + # BLOCK_M + # BLOCK_N + # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + + # Define Q Strides + stride_qz = {{stride("Q", 0)}} + stride_qh = {{stride("Q", 1)}} + stride_qm = {{stride("Q", 2)}} + stride_qk = {{stride("Q", 3)}} + # Define K Strides + stride_kz = {{stride("K", 0)}} + stride_kh = {{stride("K", 1)}} + stride_kn = {{stride("K", 2)}} + stride_kk = {{stride("K", 3)}} + # Define V Strides + stride_vz = {{stride("V", 0)}} + stride_vh = {{stride("V", 1)}} + stride_vn = {{stride("V", 2)}} + stride_vk = {{stride("V", 3)}} + + Z = {{size("Q", 0)}} + H = {{size("Q", 1)}} + N_CTX = {{size("Q", 2)}} + + qk_scale = 1.0 + MATMUL_PRECISION = Q.dtype.element_ty + + off_hz = tl.program_id(0) + off_z = off_hz // H # batch idx + off_h = off_hz % H # head idx + + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + + # Asserting contiguous for now... + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_vz + off_h * stride_vh + + # TODO I think that this should be N_CTX/BLOCK_N blocks + for start_n in range(0, NUM_Q_BLOCKS): + # We are not doing the causal optimization yet allowing us to start further down the + # kv column + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) + do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + + # pointer to row-wise quantities in value-like data + D_ptrs = DELTA + off_hz * N_CTX + l_ptrs = LSE + off_hz * N_CTX + + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + # Key and Value stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + + for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + + if SCORE_MOD_IS_LINEAR: + qk_scale *= 1.44269504 + q = (q * qk_scale).to(MATMUL_PRECISION) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) + pre_mod_scores = qk + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m_curr[:, None] + n = offs_n[None, :] + {{ modification( + subgraph_number=0, + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(3) }} + # TODO: In the case that score_mod is linear, this can be LICMed + if not SCORE_MOD_IS_LINEAR: + qk *= 1.44269504 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) + + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] + + # compute ds = p * (dp - delta[:, None]) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + ds = p * dp + + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + score="pre_mod_scores", + b="off_z", + h="off_h", + m="m", + n="n", + out="ds" + ) | indent_except_first(3) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(MATMUL_PRECISION), k) + + # Store grad_query + tl.store(dq_ptrs, dq) + + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + + # write-back + index_n = offs_n[:, None] + index_k = offs_k[None, :] + + # Store grad_key and grad_value + dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) + tl.store(dv_ptrs, dv) + + # TODO generalize and add proper mask support + mask = (index_n != -1) & (index_k != -1) + {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + + """, +) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + fw_graph, + joint_graph, + *other_buffers, + ) = args + + device = query.get_device() + dtype = query.get_dtype() + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) for name, dtype in [ - ("score", query.get_dtype()), + ("score", dtype), ("b", torch.int32), ("h", torch.int32), ("m", torch.int32), ("n", torch.int32), ] ] - for node in subgraph.graph_module.graph.nodes: - # There are two classes of placeholder inpts that we need - # to handle differently. For the first n_scalar_inps inputs - # we expect that these placeholders were generated by the make_fx call - # in the flex Attention HOP. So we need to create a new placeholder - # TensorBox for each of these inputs. For the rest of the inputs we - # expect that these are lifted inputs that fill up the '*other_buffers' - # tuple and already have corresponding TensorBoxes passed in as args. - if node.op == "placeholder": - is_lifted_input = cnt >= len(scalar_inps) - env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt] - cnt += 1 - elif node.op == "call_function": - # For call_function we use the defulat lowerings and pass in the - # already created TensorBoxes as args - from torch.utils._pytree import tree_map + fw_subgraph_buffer = build_subgraph_buffer( + args, fwd_placeholder_inps, fw_graph, graph_type=SubgraphType.JOINT_FWD + ) - env[node] = lowerings[node.target]( - *tree_map(lambda x: env[x] if x in env else x, node.args) - ) - elif node.op == "output": - # For the output node we need to create a ComputedBuffer - # which represents the actual score modification + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("out", dtype, device) + ] + joint_subgraph_buffer = build_subgraph_buffer( + args, joint_placeholder_inps, joint_graph, graph_type=SubgraphType.JOINT_BWD + ) - output_buffer = env[node.args[0]] - assert isinstance(output_buffer.data, StorageBox), ( - "The output node for the flex attention subgraph must be a StorageBox, but got: ", - type(output_buffer), - ) - # Create the ComputedBuffer directly that will be inlined into the modification block - subgraph_buffer = ComputedBuffer( - name=None, - layout=FlexibleLayout( - device=output_buffer.data.get_device(), - dtype=output_buffer.data.get_dtype(), - size=output_buffer.data.get_size(), - ), - data=output_buffer.data.data, # type: ignore[arg-type] - ) + layout_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key.get_size(), + make_contiguous_strides_for(key.get_size()), + ) - layout = FixedLayout( - output_buffer.get_device(), - query.get_dtype(), - query.get_size(), - make_contiguous_strides_for(query.get_size()), - ) - # see NOTE:[TritonTemplates with multiple outputs] - logsumexp_shape = query.get_size()[:-1] # [B, H, M] - logsumexp = empty_strided( - logsumexp_shape, - None, - dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype - device=output_buffer.get_device(), - ) - choices: List[Any] = [] - configs: List[Any] = [] - configs.append(_get_default_config(query)) - if config.max_autotune: - configs += [ - (128, 64, 4, 3), - (128, 128, 4, 3), - (128, 128, 8, 2), - (64, 128, 4, 3), - (64, 64, 4, 3), - ] - # Note, we don't need to pass in the captured buffers explicitly - # because they're implicitly added by the score_mod function - # We do need to explicitly pass it in for autotuning though. - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: - sdpa_template.maybe_append_choice( - choices=choices, - input_nodes=[query, key, value, logsumexp], - layout=layout, - subgraphs=subgraph_buffer, - mutated_inputs=[ - logsumexp, - ], - num_stages=num_stages, - num_warps=num_warps, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=query.get_size()[-1], - # For now, we always assume the "sound" option - SCORE_MOD_IS_LINEAR=False, - ROWS_GUARANTEED_SAFE=False, - OUTPUT_LOGSUMEXP=True, - ) - inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers) - return ( - autotune_select_algorithm( - "sdpa", choices, inputs_for_autotuning, layout - ), + # Create delta which will is needed for the bwd's kernel + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + + # see NOTE:[TritonTemplates with multiple outputs] + grad_query = full( + query.get_size(), 0.0, dtype=dtype, device=device + ) # torch.zeros equivalent + grad_query.realize() + grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device) + + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_bwd(query)) + if config.max_autotune: + configs += [ + (128, 128, 4, 3), + (128, 128, 8, 1), + (64, 64, 4, 3), + (64, 64, 8, 1), + ] + + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + out, logsumexp, - ) - raise ValueError("TemplatedAttention was passed a subgraph with no output node!") + delta, + grad_out, + grad_query, + grad_value, + ], + layout=layout_k, # We use store_output only for grad_key + subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], + mutated_inputs=[grad_query, grad_value], + num_stages=num_stages, + num_warps=num_warps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=query.get_size()[-1], + NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), + # For now, we always assume the "sound" option + SCORE_MOD_IS_LINEAR=False, + ) + inputs_for_autotuning = [ + query, + key, + value, + out, + logsumexp, + delta, + grad_out, + grad_query, + grad_value, + ] + list(other_buffers) + + grad_key = autotune_select_algorithm( + "flex_attention_backward", choices, inputs_for_autotuning, layout_k + ) + return ( + grad_query, + grad_key, + grad_value, + ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 0999e6ce3b21..eadb3d8159ab 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -103,7 +103,7 @@ def __init__( prefix_args=0, suffix_args=0, epilogue_fn=identity, - subgraphs=None, + subgraphs: Optional[List[ir.ComputedBuffer]] = None, *, index_dtype, ): @@ -114,7 +114,7 @@ def __init__( ) self.input_nodes = input_nodes self.output_node = output_node - self.named_input_nodes = {} + self.named_input_nodes = {} # type: ignore[var-annotated] self.defines = defines self.kernel_name = kernel_name self.template_mask = None @@ -128,10 +128,10 @@ def __init__( self.prefix_args = prefix_args self.suffix_args = suffix_args self.epilogue_fn = epilogue_fn - self.render_hooks = dict() + self.render_hooks = dict() # type: ignore[var-annotated] self.triton_meta: Optional[Dict[str, object]] = None - # For Templated Attention - self.subgraphs = subgraphs + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs def need_numel_args(self): return False @@ -271,19 +271,28 @@ def stride(self, name, index): val = self.named_input_nodes[name].get_stride()[index] return texpr(self.rename_indexing(val)) - def modification(self, **fixed_inputs) -> str: - """This function generates the code body to populate - a 'modification' placeholder within a template + def modification(self, subgraph_number: int, **fixed_inputs) -> str: + """This creates a modification function for a subgraph. + To use this inside a template, the first argument should specify which subgraph to codegen for - TODO come up with standardized way to modify templates, with - potential multiple modifications + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs """ + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + + subgraph = self.subgraphs[subgraph_number] def add_input(name): return self.args.input(name) + name = f"PlaceholderSubstitution_{subgraph_number}" + class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = "PlaceholderSubstitution" + self.name = name def load(self, name: str, index: sympy.Expr): if name not in fixed_inputs: @@ -297,15 +306,14 @@ def load(self, name: str, index: sympy.Expr): def indirect_indexing(self, index_var, size, check): return sympy_index_symbol(str(index_var)) - # if self.modification_cache is None: with V.set_ops_handler(PlaceholderSubstitution(V.ops)): assert isinstance( - self.subgraphs, ir.ComputedBuffer - ), "Expected the subgraph to be a ComputedBuffer" - if isinstance(self.subgraphs.data, ir.InputBuffer): - out = self.subgraphs.data.make_loader()((1,)) + subgraph, ir.ComputedBuffer + ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" + if isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()((1,)) else: - out = self.subgraphs.data.inner_fn((1,)) + out = subgraph.data.inner_fn((1,)) self.codegen_body() self.body.writeline(f"{fixed_inputs['out']} = {out.value}") @@ -320,11 +328,18 @@ def store_output( indices: Union[List[Any], Tuple[Any]], val: str, mask: Optional[str] = None, + indent_width: int = 4, ): - """ - Hook called from template code to store the final output - (if the buffer hasn't been optimized away), then append any - epilogue fusions. + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. """ assert isinstance(indices, (list, tuple)) assert isinstance(val, str) @@ -348,7 +363,7 @@ def store_output( self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name( "xindex" ) - self.template_mask = mask + self.template_mask = mask # type: ignore[assignment] self.template_indices = indices output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) @@ -373,7 +388,7 @@ def store_output( def hook(): # more stuff might have been added since the codegen_body above self.codegen_body() - return textwrap.indent(self.body.getvalue(), " ").strip() + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -1334,7 +1349,7 @@ def log_results( result = timings[choice] if result: sys.stderr.write( - f" {choice.name} {result:.4f} ms {best_time/result:.1%}\n" + f" {choice.name} {result:.4f} ms {best_time / result:.1%}\n" ) else: sys.stderr.write( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index c10e3cc512f0..1bbeac16e21e 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -339,7 +339,7 @@ def print_performance( ): timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) took = torch.median(timings) / times - print(f"{took/baseline:.6f}") + print(f"{took / baseline:.6f}") return took diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index 6f158c20db6d..c56374fcbc40 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -96,6 +96,8 @@ def score_mod( raise ValueError( "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor." ) + if query.size(-2) % 128 != 0: + raise ValueError("NYI: S and L must be a multiple of 128") if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support.") @@ -149,7 +151,7 @@ def _rel_causal( token_q: torch.Tensor, token_kv: torch.Tensor, ) -> torch.Tensor: - return torch.where(token_q <= token_kv, score + (token_q - token_kv), float("-inf")) + return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf")) def _generate_alibi_bias(num_heads: int): diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 4772fb42a963..df78812a6504 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -118,9 +118,9 @@ def score_mod(score, b, h, m, n): return score + h yield SampleInput( - make_arg(2, 2, 64, 8, low=0.1, high=2), - make_arg(2, 2, 64, 8, low=0.1, high=2), - make_arg(2, 2, 64, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), score_mod, ) From 8bb7a2f46da40569589197bf1ae730ccbba6bddf Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 16 May 2024 11:32:16 -0700 Subject: [PATCH 038/116] Fix documentation for register_fake_class (#126422) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126422 Approved by: https://github.com/angelayi --- torch/_library/fake_class_registry.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index 32a9aa8c8711..d77989cd829b 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -134,8 +134,10 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] returns an instance of the fake class. All tensors in the fake object should also be properly fakified with to_fake_tensor() in from_real. + Examples: # For a custom class Foo defined in test_custom_class_registration.cpp: + TORCH_LIBRARY(_TorchScriptTesting, m) { m.class_("_TensorQueue") .def(torch::init()) @@ -144,6 +146,7 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] .def("top", &TensorQueue::top) .def("size", &TensorQueue::size) .def("clone_queue", &TensorQueue::clone_queue) + .def("__obj_flatten__", &TensorQueue::__obj_flatten__) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) @@ -166,8 +169,7 @@ def __init__(self, queue): @classmethod def __obj_unflatten__(cls, flattened_ctx): - ctx = {flattened_ctx[0]: flattened_ctx[1]} - return cls(**ctx) + return cls(**dict(ctx)) def push(self, x): self.queue.append(x) @@ -178,6 +180,11 @@ def pop(self): def size(self): return len(self.queue) + In this example, the original TensorQeue need to addd a __obj_flatten__ method + to the class TensorQueue and the flattend result is passed into FakeTensorQueue's + __obj_unflatten__ as inputs to create a fake class. This protocol allows pytorch to look + at the contents of the script object and properly handle them in the subsystems + like dynamo, aot_aotugrad or more. """ def inner(fake_class: HasStaticMethodFromReal): From 1018a68e3121fe259011046860b7a20ec714ee83 Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 17 May 2024 00:48:32 +0000 Subject: [PATCH 039/116] [export] Delete predispatch tests (#126459) Deleting predispatch tests as we moved export to predispatch already Pull Request resolved: https://github.com/pytorch/pytorch/pull/126459 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 3 -- test/export/test_export_predispatch.py | 50 -------------------------- test/export/test_serdes.py | 23 ------------ 3 files changed, 76 deletions(-) delete mode 100644 test/export/test_export_predispatch.py diff --git a/test/export/test_export.py b/test/export/test_export.py index cec463fa3dc0..406e1f55dd80 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1522,7 +1522,6 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): self._test_export_same_as_eager(kw_func, args, kwargs) @testing.expectedFailureSerDer # we don't save placeholder metadata - @testing.expectedFailureSerDerPreDispatch @testing.expectedFailureNonStrict def test_linear_conv(self): class MyLinear(torch.nn.Module): @@ -2902,7 +2901,6 @@ def forward(self, xs, y): ) @testing.expectedFailureSerDer # We don't preserve metadata on graph module - @testing.expectedFailureSerDerPreDispatch @testing.expectedFailureNonStrict def test_retrace_graph_level_meta_preservation(self): class Foo(torch.nn.Module): @@ -3692,7 +3690,6 @@ def forward(self, q, k, v): self.assertEqual(ep.module()(*inputs), m(*inputs)) @testing.expectedFailureSerDer # symfloat nyi - @testing.expectedFailureSerDerPreDispatch # symfloat nyi def test_sym_sqrt(self): import math diff --git a/test/export/test_export_predispatch.py b/test/export/test_export_predispatch.py deleted file mode 100644 index 2075cba58ca6..000000000000 --- a/test/export/test_export_predispatch.py +++ /dev/null @@ -1,50 +0,0 @@ -# Owner(s): ["oncall: export"] - -try: - from . import test_export, testing -except ImportError: - import test_export - import testing -from torch.export._trace import _export - -test_classes = {} - - -def mocked_predispatch_export(*args, **kwargs): - # If user already specified strict, don't make it non-strict - ep = _export(*args, **kwargs, pre_dispatch=True) - return ep.run_decompositions() - - -def make_dynamic_cls(cls): - suffix = "_pre_dispatch" - - cls_prefix = "PreDispatchExport" - - test_class = testing.make_test_cls_with_mocked_export( - cls, - cls_prefix, - suffix, - mocked_predispatch_export, - xfail_prop="_expected_failure_pre_dispatch", - ) - - test_classes[test_class.__name__] = test_class - # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING - globals()[test_class.__name__] = test_class - test_class.__module__ = __name__ - return test_class - - -tests = [ - test_export.TestDynamismExpression, - test_export.TestExport, -] -for test in tests: - make_dynamic_cls(test) -del test - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index 253c6db81819..bd11cd7f8366 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -9,7 +9,6 @@ import testing from torch.export import export, load, save -from torch.export._trace import _export test_classes = {} @@ -23,21 +22,10 @@ def mocked_serder_export(*args, **kwargs): return loaded_ep -def mocked_serder_export_pre_dispatch(*args, **kwargs): - ep = _export(*args, **kwargs, pre_dispatch=True) - buffer = io.BytesIO() - save(ep, buffer) - buffer.seek(0) - loaded_ep = load(buffer) - return loaded_ep - - def make_dynamic_cls(cls): suffix = "_serdes" - suffix_pre_dispatch = "_serdes_pre_dispatch" cls_prefix = "SerDesExport" - cls_prefix_pre_dispatch = "SerDesExportPreDispatch" test_class = testing.make_test_cls_with_mocked_export( cls, @@ -47,21 +35,10 @@ def make_dynamic_cls(cls): xfail_prop="_expected_failure_serdes", ) - test_class_pre_dispatch = testing.make_test_cls_with_mocked_export( - cls, - cls_prefix_pre_dispatch, - suffix_pre_dispatch, - mocked_serder_export_pre_dispatch, - xfail_prop="_expected_failure_serdes_pre_dispatch", - ) - test_classes[test_class.__name__] = test_class - test_classes[test_class_pre_dispatch.__name__] = test_class_pre_dispatch # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING globals()[test_class.__name__] = test_class - globals()[test_class_pre_dispatch.__name__] = test_class_pre_dispatch test_class.__module__ = __name__ - test_class_pre_dispatch.__module__ = __name__ tests = [ From 697ed6f5b3484a09410af075c34419e94fa42592 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 16 May 2024 07:48:29 -0700 Subject: [PATCH 040/116] [DeviceMesh] Supported N groups in `from_group` (#126258) **Overview** This PR supports constructing an ND mesh with `from_group()` by passing in `group: List[ProcessGroup]` and `mesh: Union[torch.Tensor, "ArrayLike"]` together. The `ndim` of the device mesh returned from `from_group()` is equal to the number of `ProcessGroup`s passed. If the `ndim` is greater than 1, then the `mesh` argument is required (since there is no simple way to recover the `mesh` tensor from the process groups otherwise). This PR also adds `mesh_dim_names` as an argument to forward to the device mesh for convenience.
Old Approach **Overview** - This PR mainly adds `mesh_shape` to `from_group()` so that the user can construct an ND (N > 1) device mesh from a process group. This is to unblock HSDP, where we can pass the overall data parallel process group to `from_group()` with `mesh_shape = (replicate_dim_size, shard_dim_size)` and `from_group()` will construct subgroups for the user. (The user can then get the subgroups from the submeshes.) - Constructing the 2D `DeviceMesh` from an existing shard process group and replicate process group is hard because we cannot easily recover the array of ranks in their parent group on each rank in general. - This PR also adds `mesh_dim_names` to `from_group()` so that the user can name the mesh dimensions of the constructed device mesh.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126258 Approved by: https://github.com/wanchaol --- .../_composable/fsdp/test_fully_shard_init.py | 94 +++++++++++++++++-- test/distributed/test_device_mesh.py | 64 ++++++++++++- torch/distributed/device_mesh.py | 66 +++++++++++-- 3 files changed, 205 insertions(+), 19 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 3dfaab80dbe1..73e078c0b2f2 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -24,6 +24,10 @@ Shard, ) from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp._init_utils import ( + _init_inter_node_process_group, + _init_intra_node_process_group, +) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -672,7 +676,7 @@ def world_size(self) -> int: return 4 @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_process_group_init(self): + def test_1d_process_group_init(self): assert self.world_size == 4, f"{self.world_size}" # For convenience, use device mesh's infra to construct the DP PG # (in practice, the trainer would do it manually via `new_group()`) @@ -684,11 +688,10 @@ def test_process_group_init(self): dp_pg = ref_dp_mesh.get_group(0) # Check the `from_group()` API for correctness - dp_mesh = DeviceMesh.from_group(dp_pg, "cuda") - # We only compare the mesh tensors instead of the DeviceMesh objects - # since mesh_dim_names attributes and parent mesh are different. + dp_mesh = DeviceMesh.from_group(dp_pg, "cuda", mesh_dim_names=("dp",)) + # Only compare the mesh tensors, not `DeviceMesh` objects themselves, + # since the ref has a parent mesh, while the `from_group` one does not self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh) - # self.assertFalse(hasattr(dp_mesh, "_coordinate_on_dim")) self.assertEqual(dp_mesh._coordinate_on_dim, ref_dp_mesh._coordinate_on_dim) self.assertEqual(dp_mesh._dim_group_infos, ref_dp_mesh._dim_group_infos) @@ -722,7 +725,9 @@ def test_process_group_init(self): loss.backward() self.assertEqual(loss, ref_loss) for param, ref_param in zip(model.parameters(), ref_model.parameters()): - # we cannot directly compare param and ref_param because their parent mesh is different. + # Cannot compare `DTensor`s directly since their meshes are not + # equal due to the ref parameter's mesh having a parent mesh while + # the other's mesh does not self.assertEqual(param.to_local(), ref_param.to_local()) self.assertEqual(param.device_mesh.mesh, ref_param.device_mesh.mesh) self.assertEqual(param.grad.to_local(), ref_param.grad.to_local()) @@ -730,6 +735,83 @@ def test_process_group_init(self): param.grad.device_mesh.mesh, ref_param.grad.device_mesh.mesh ) + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_2d_process_group_init(self): + shard_mesh_dim_size = 2 + assert ( + self.world_size % shard_mesh_dim_size == 0 + ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size + mesh_dim_names = ("replicate", "shard") + ref_mesh = init_device_mesh( + "cuda", + (replicate_mesh_dim_size, shard_mesh_dim_size), + mesh_dim_names=mesh_dim_names, + ) + + # Use the global PG as the parent group (in practice, this could be a + # subgroup of the global PG) + dp_group = dist.distributed_c10d._get_default_group() + dp_shard_group = _init_intra_node_process_group(shard_mesh_dim_size) + dp_replicate_group = _init_inter_node_process_group( + dp_group, replicate_mesh_dim_size + ) + mesh_tensor = torch.tensor( + dist.get_process_group_ranks(dp_group), dtype=torch.int + ).view(replicate_mesh_dim_size, shard_mesh_dim_size) + + # Check the `from_group()` API for correctness + mesh = DeviceMesh.from_group( + [dp_replicate_group, dp_shard_group], + "cuda", + mesh_dim_names=mesh_dim_names, + mesh=mesh_tensor, + ) + self.assertEqual(mesh.mesh, ref_mesh.mesh) + self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim) + for (tag, ranks, group_name), (ref_tag, ref_ranks, ref_group_name) in zip( + mesh._dim_group_infos, ref_mesh._dim_group_infos + ): + # Since we manually constructed new subgroups, the test and ref + # groups are not the same + self.assertEqual(ranks, ref_ranks) + for mesh_dim_name in mesh_dim_names: + child_mesh = mesh[mesh_dim_name] + ref_child_mesh = ref_mesh[mesh_dim_name] + self.assertEqual(child_mesh, ref_child_mesh) + child_ranks = dist.distributed_c10d.get_process_group_ranks( + child_mesh.get_group() + ) + ref_child_ranks = dist.distributed_c10d.get_process_group_ranks( + ref_child_mesh.get_group() + ) + self.assertEqual(child_ranks, ref_child_ranks) + + # Check HSDP forward/backward parity + torch.manual_seed(42) + mlp_dim = 8 + ref_model = MLP(mlp_dim) + for param in ref_model.parameters(): + dist.broadcast(param.detach(), src=0) + model = copy.deepcopy(ref_model) + + # Parallelize the test model with the ref mesh + for module in (ref_model.in_proj, ref_model.out_proj, ref_model): + fully_shard(module, mesh=ref_mesh) + # Parallelize the test model with the new mesh from the PG + for module in (model.in_proj, model.out_proj, model): + fully_shard(module, mesh=mesh) + + inp = torch.randn((4, mlp_dim), device="cuda") + ref_loss = ref_model(inp).sum() + ref_loss.backward() + loss = model(inp).sum() + loss.backward() + self.assertEqual(loss, ref_loss) + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + self.assertEqual(param, ref_param) + self.assertEqual(param.grad, ref_param.grad) + class TestFullyShardHSDPBroadcast(FSDPTestMultiThread): @property diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index e6c1e27e23ce..8f70ee2f0b7d 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -168,7 +168,7 @@ def test_fake_pg_device_mesh(self): self.assertEqual(global_tensor.shape, (self.world_size * 2, 8)) @with_comms - def test_from_group(self): + def test_from_group_with_global_pg(self): # Simple test: check `from_group` for a global PG vs. directly # initializing via `init_device_mesh` global_pg = _get_default_group() @@ -180,6 +180,23 @@ def test_from_group(self): ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim ) + @with_comms + def test_from_group_with_invalid_mesh(self): + global_pg = _get_default_group() + global_pg_size = global_pg.size() + assert global_pg_size == 4, "Test assumes global world size of 4" + invalid_mesh = [[0, 1], [2, 3]] # 2D mesh when we need 1D + regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]" + with self.assertRaisesRegex(ValueError, regex): + DeviceMesh.from_group(global_pg, "cuda", invalid_mesh) + + device_mesh = init_device_mesh(self.device_type, (2, 2)) + groups = device_mesh.get_group() + invalid_mesh = (0, 1, 2, 3) # 1D mesh when we need 2D + regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups" + with self.assertRaisesRegex(ValueError, regex): + DeviceMesh.from_group(groups, self.device_type, invalid_mesh) + def test_raises_invalid_device_type(self): with self.assertRaisesRegex( RuntimeError, @@ -280,8 +297,8 @@ def test_device_mesh_parent_child_hash(self): ep_mesh_1 = DeviceMesh(self.device_type, mesh_group_1) ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2) ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2 - # # ep_mesh is considered different from mesh_2d["TP"] - # # since mesh_2d["TP"] has a parent mesh while ep_mesh does not. + # ep_mesh is considered different from mesh_2d["TP"] + # since mesh_2d["TP"] has a parent mesh while ep_mesh does not. self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list) self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape) self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type) @@ -307,6 +324,47 @@ def test_device_mesh_parent_child_hash(self): self.assertEqual(hash(ep_mesh), hash(another_mesh)) self.assertEqual(ep_mesh, another_mesh) + @with_comms + def test_from_group_with_mesh_shape(self): + """Tests ``from_group`` when passing ``mesh_shape`` as 2D.""" + # Consider two different logical views of the same mesh: + # - (4, 2) ("dp", "tp") mesh + # - (2, 2, 2) ("dp_replicate", "dp_shard", "tp") mesh + mesh_shape = (2, 2, 2) + mesh_dim_names = ("dp_replicate", "dp_shard", "tp") + ref_mesh = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + dp_shard_group = ref_mesh["dp_shard"].get_group() + dp_replicate_group = ref_mesh["dp_replicate"].get_group() + + dp_mesh = DeviceMesh.from_group( + [dp_replicate_group, dp_shard_group], + self.device_type, + mesh=ref_mesh.mesh[:, :, ref_mesh.get_local_rank(2)], + mesh_dim_names=mesh_dim_names[:2], + ) + + ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2] + for (_, ref_ranks, _), (_, ranks, _) in zip( + ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos + ): + self.assertEqual(ref_ranks, ranks) + # Cannot check directly for mesh equality since parent meshes are not + # the same since the ref's parent mesh is 3D + self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh) + for (_, ref_ranks, _), (_, ranks, _) in zip( + dp_mesh["dp_replicate"]._dim_group_infos, + ref_mesh["dp_replicate"]._dim_group_infos, + ): + self.assertEqual(ref_ranks, ranks) + self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh) + for (_, ref_ranks, _), (_, ranks, _) in zip( + dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos + ): + self.assertEqual(ref_ranks, ranks) + class InitDeviceMeshTest(DTensorTestBase): @property diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 05dc7710215d..c0981a549c6b 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -209,7 +209,7 @@ def __init__( self.mesh = ( mesh.detach().to(dtype=torch.int) if isinstance(mesh, torch.Tensor) - else torch.tensor(mesh, dtype=torch.int) + else torch.tensor(mesh, device="cpu", dtype=torch.int) ) self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None @@ -451,21 +451,67 @@ def get_group( return dim_groups @staticmethod - def from_group(group: ProcessGroup, device_type: str) -> "DeviceMesh": + def from_group( + group: Union[ProcessGroup, List[ProcessGroup]], + device_type: str, + mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> "DeviceMesh": """ Contstructs a :class:`DeviceMesh` with ``device_type`` from an existing :class:`ProcessGroup`. - The constructed device mesh is assumed to be 1D. + The constructed device mesh has number of dimensions equal to the + number of groups passed. If more than one group is passed, then the + ``mesh`` argument is required. """ - # Manually define `_dim_group_infos` instead of relying on the - # normal logic since we already have the PG - group_ranks = get_process_group_ranks(group) - mesh = DeviceMesh(device_type, group_ranks, _init_backend=False) - mesh._dim_group_infos = [ - (_get_group_tag(group), group_ranks, group.group_name) + if isinstance(group, ProcessGroup): + group_ranks = get_process_group_ranks(group) + if ( + isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks + ) or (mesh is not None and mesh != group_ranks): + raise ValueError( + f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" + ) + mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) + device_mesh = DeviceMesh( + device_type, + mesh, + mesh_dim_names=mesh_dim_names, + _init_backend=False, + ) + device_mesh._dim_group_infos = [ + (_get_group_tag(group), group_ranks, group.group_name) + ] + return device_mesh + groups = list(group) + if len(groups) == 0: + raise ValueError("Expects at least one ProcessGroup to be passed") + if mesh is None: + raise ValueError("Must pass mesh if passing multiple ProcessGroups") + mesh = ( + mesh.detach().to(dtype=torch.int, device="cpu") + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + if mesh.ndim != len(groups): + raise ValueError( + "Expects mesh with ndim equal to number of ProcessGroups but got " + f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" + ) + device_mesh = DeviceMesh( + device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False + ) + device_mesh._dim_group_infos = [ + ( + _get_group_tag(group), + get_process_group_ranks(group), + group.group_name, + ) + for group in groups ] - return mesh + return device_mesh def size(self, mesh_dim: Optional[int] = None) -> int: return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) From 776b87891714f6fb8da34db49f3e2fcd7a3c6090 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 16 May 2024 22:26:05 +0000 Subject: [PATCH 041/116] [easy] Fix typing for `map_location` docs in torch.load (#125473) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently it incorrectly has `Callable[[Tensor, str], Tensor]` as a possible type signature, this should be `Callable[[Storage, str], Storage]` Screenshot 2024-05-03 at 12 09 54 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/125473 Approved by: https://github.com/albanD --- torch/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/serialization.py b/torch/serialization.py index 616c21e80d7f..64a1e6e0ce06 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -32,7 +32,7 @@ STORAGE_KEY_SEPARATOR = ',' FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] -MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] +MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]] STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] IS_WINDOWS = sys.platform == "win32" From d2f5a8ac99b4f345eac95611688dbbaf210659be Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 16 May 2024 14:21:22 +0000 Subject: [PATCH 042/116] [doc] expose torch.Tensor.xpu API to doc (#126383) # Motivation The doc string related `torch.Tensor.xpu` has been added [here](https://github.com/pytorch/pytorch/blob/d61a81a9e76688ac8f338a6cfba932bf7779e5ce/torch/_tensor_docs.py#L1434) but not expose it to public doc, like [torch.Tensor.cuda](https://pytorch.org/docs/stable/generated/torch.Tensor.cuda.html#torch.Tensor.cuda). This PR intends to expose the document of `torch.Tensor.xpu` to public doc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126383 Approved by: https://github.com/albanD --- tools/pyi/gen_pyi.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index d5f6837cba01..a7eb81341eb5 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1130,6 +1130,18 @@ def replace_special_case(hint: str) -> str: ) ) ], + "xpu": [ + "def xpu({}) -> Tensor: ...".format( + ", ".join( + [ + "self", + "device: Optional[Union[_device, _int, str]] = None", + "non_blocking: _bool = False", + "memory_format: torch.memory_format = torch.preserve_format", + ] + ) + ) + ], "cpu": [ "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..." ], From da1fc85d60fcf0bd1e8638d643a7c0c6560c3a5f Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 16 May 2024 13:27:23 -0700 Subject: [PATCH 043/116] Add symbolic_shape_specialization structured trace (#126450) This is typically the information you want when diagnosing why something overspecialized in dynamic shapes. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126450 Approved by: https://github.com/albanD --- torch/fx/experimental/symbolic_shapes.py | 27 ++++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index be1be24137f8..e310d490b77c 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4397,6 +4397,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No Use this instead of `self.replacements[a] = tgt`. """ + if tgt == self.replacements.get(a, None): + return + # Precondition: a == tgt assert isinstance(a, sympy.Symbol) @@ -4487,14 +4490,24 @@ def issubset(x, y): "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return - if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)): - # specializing to a constant, which is likely unexpected + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources[a]], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + "user_stack": structured.from_traceback(user_tb) if user_tb else None, + } + ) - # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g., - # when adding a to self.replacements, and again when simplifying an expression containing a. - # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is, - # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant. - if a not in self.replacements or tgt != self.replacements[a]: + if config.print_specializations: self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) From 1a27e24ff510ad7ff27485a1a9afca197bee81e6 Mon Sep 17 00:00:00 2001 From: Alex Denisov Date: Fri, 17 May 2024 04:19:23 +0000 Subject: [PATCH 044/116] Make inductor scheduler graph extension configurable (#125578) This patch makes the inductor scheduler graph extension configurable. It enables ease of debugging by changing the graph format (dot, png, etc.). Particularly, it's very convenient to work with the graph interactively using tools like https://github.com/tintinweb/vscode-interactive-graphviz Pull Request resolved: https://github.com/pytorch/pytorch/pull/125578 Approved by: https://github.com/Chillee --- torch/_functorch/config.py | 4 ++++ torch/_functorch/partitioners.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 5749477c6e98..c559951f3809 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -136,6 +136,10 @@ # of tensors in question. fake_tensor_propagate_real_tensors = False +# Controls the default graph output format used by draw_graph +# Supported formats are defined here https://graphviz.org/docs/outputs/ +torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index ba549e5bd6e2..5a43cd5e7bf3 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1342,7 +1342,7 @@ def draw_graph( node.meta = {} base, ext = os.path.splitext(fname) if not ext: - ext = ".svg" + ext = "." + config.torch_compile_graph_format print(f"Writing FX graph to file: {base}{ext}") g = graph_drawer.FxGraphDrawer( traced, From 88582195fde1c921b9c3f364cf5c11e83f797913 Mon Sep 17 00:00:00 2001 From: wz337 Date: Fri, 17 May 2024 04:29:17 +0000 Subject: [PATCH 045/116] [FSDP2][Test] Fix _test_clip_grad_norm (#126457) Fixes #ISSUE_NUMBER We need to compare ref_total_norm to total_norm.full_tensor(). Example: ``` iter_idx:0, rank:0,\ ref_total_norm=tensor(1052.5934, device='cuda:0'),\ total_norm=DTensor(local_tensor=482.0861511230469, device_mesh=DeviceMesh([0, 1]), placements=(_NormPartial(reduce_op='sum', norm_type=2.0),)),\ total_norm.full_tensor()=tensor(1052.5934, device='cuda:0') ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126457 Approved by: https://github.com/awgu --- .../_composable/fsdp/test_fully_shard_clip_grad_norm_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index 4c22ea347156..d3978febec09 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -73,7 +73,7 @@ def _test_clip_grad_norm( norm_type=norm_type, foreach=True, ) - self.assertEqual(ref_total_norm, total_norm) + self.assertEqual(ref_total_norm, total_norm.full_tensor()) # Expect one all-reduce per mesh dim for partial -> replicate expected_all_reduces = len(total_norm.placements) self.assertEqual( From a8c41e06784308441c67a265f6fd04ecd8c2afb7 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 16 May 2024 16:16:53 -0700 Subject: [PATCH 046/116] dont pad 0 dim mm inputs (#126475) Otherwise you get an error in constant_pad_nd. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126475 Approved by: https://github.com/huydhn ghstack dependencies: #125772, #125773, #125780 --- test/inductor/test_pad_mm.py | 10 ++++++++++ torch/_inductor/fx_passes/pad_mm.py | 15 ++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index b16e5e5d62ed..bb37368f9567 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -169,6 +169,16 @@ def forward(self, a, b): res2, (code,) = run_and_get_code(compiled_fn, a, b) self.assertEqual(res1, res2) + @inductor_config.patch(force_shape_pad=True) + def test_zero_dim(self): + def addmm(x, a, b): + return torch.addmm(x, a, b) + + x = torch.randn(100).cuda() + a = torch.randn(0, 10).cuda() + b = torch.randn(10, 100).cuda() + self.assertEqual(torch.compile(addmm)(x, a, b), addmm(x, a, b)) + @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON") def test_pad_bmm_dyn_b(self): B = 10 diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index df282629e2ce..43f7e009af83 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -1,4 +1,5 @@ import functools +import itertools import operator from typing import List, Optional, Union @@ -325,6 +326,17 @@ def should_pad_bench( if m_padded_length == k_padded_length == n_padded_length == 0: return False + def realize_symbols(ds): + return [d if isinstance(d, int) else d.node.hint for d in ds] + + if any( + dim == 0 + for dim in itertools.chain( + realize_symbols(mat1.shape), realize_symbols(mat2.shape) + ) + ): + return False + if torch._inductor.config.force_shape_pad: return True @@ -342,9 +354,6 @@ def should_pad_bench( if cached_pad is not None: return cached_pad - def realize_symbols(ds): - return [d if isinstance(d, int) else d.node.hint for d in ds] - def realize_tensor(t): if isinstance(t, FakeTensor): size_hints = realize_symbols(t.size()) From 4b2ae2ac338f3a0de340c9711b03989b8ce66fc6 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 17 May 2024 05:09:06 +0000 Subject: [PATCH 047/116] c10d: add Collectives abstraction (#125978) This adds a new `Collectives` API for doing distributed collectives operations. This is intended to replace the [current Elastic store abstraction](https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/utils/store.py) with more performant and debugable primitives. Design doc: https://docs.google.com/document/d/147KcKJXEHvk1Q6tISLbJVvLejHg_1kIhBQeu-8RQxhY/edit The standard implementation is using `StoreCollectives` but other more performant backends will be added in a follow up PR. Test plan: ``` python test/distributed/test_collectives.py -v ``` This tests both functionality using multiple threads as well as timeout behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125978 Approved by: https://github.com/shuqiangzhang --- BUILD.bazel | 2 +- build_variables.bzl | 1 + test/distributed/test_control_collectives.py | 189 +++++++++++ torch/_C/_distributed_c10d.pyi | 14 + torch/csrc/distributed/c10d/HashStore.hpp | 2 +- torch/csrc/distributed/c10d/Store.hpp | 29 ++ .../ControlCollectives.hpp | 59 ++++ .../control_collectives/StoreCollectives.cpp | 222 +++++++++++++ .../control_collectives/StoreCollectives.hpp | 68 ++++ torch/csrc/distributed/c10d/init.cpp | 302 +++++++++++++++--- torch/distributed/__init__.py | 2 + 11 files changed, 837 insertions(+), 53 deletions(-) create mode 100644 test/distributed/test_control_collectives.py create mode 100644 torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp create mode 100644 torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp create mode 100644 torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp diff --git a/BUILD.bazel b/BUILD.bazel index 3f7e6327452c..831d64b44c2f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -772,7 +772,7 @@ cc_library( [ "torch/*.h", "torch/csrc/**/*.h", - "torch/csrc/distributed/c10d/*.hpp", + "torch/csrc/distributed/c10d/**/*.hpp", "torch/lib/libshm/*.h", ], exclude = [ diff --git a/build_variables.bzl b/build_variables.bzl index 3f16f9b847c1..152324a4d90c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -487,6 +487,7 @@ libtorch_core_sources = sorted( # These files are the only ones that are supported on Windows. libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Backend.cpp", + "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py new file mode 100644 index 000000000000..fb0067f2dd2e --- /dev/null +++ b/test/distributed/test_control_collectives.py @@ -0,0 +1,189 @@ +# Owner(s): ["oncall: distributed"] + +from datetime import timedelta +from multiprocessing.pool import ThreadPool + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestCollectives(TestCase): + def test_barrier(self) -> None: + store = dist.HashStore() + + world_size = 2 + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + collectives.barrier("foo", timedelta(seconds=10), True) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_broadcast(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + if rank == 2: + collectives.broadcast_send("foo", b"data", timeout) + else: + out = collectives.broadcast_recv("foo", timeout) + self.assertEqual(out, b"data") + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_gather(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + if rank == 2: + out = collectives.gather_recv("foo", str(rank), timeout) + self.assertEqual(out, [b"0", b"1", b"2", b"3"]) + else: + collectives.gather_send("foo", str(rank), timeout) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_scatter(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + if rank == 2: + out = collectives.scatter_send( + "foo", [str(i) for i in range(world_size)], timeout + ) + else: + out = collectives.scatter_recv("foo", timeout) + self.assertEqual(out, str(rank).encode()) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_all_sum(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + out = collectives.all_sum("foo", rank, timeout) + self.assertEqual(out, sum(range(world_size))) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_broadcast_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex(Exception, "Wait timeout"): + collectives.broadcast_recv("foo", timeout) + + def test_gather_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "gather failed -- missing ranks: 0, 2, 3" + ): + collectives.gather_recv("foo", "data", timeout) + + def test_scatter_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex(Exception, "Wait timeout"): + collectives.scatter_recv("foo", timeout) + + def test_all_gather_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "all_gather failed -- missing ranks: 0, 2, 3" + ): + collectives.all_gather("foo", "data", timeout) + + def test_barrier_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "barrier failed -- missing ranks: 0, 2, 3" + ): + collectives.barrier("foo", timeout, True) + + def test_all_sum_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "barrier failed -- missing ranks: 0, 2, 3" + ): + collectives.all_sum("foo", 1, timeout) + + def test_unique(self) -> None: + store = dist.HashStore() + + collectives = dist._StoreCollectives(store, 1, 1) + collectives.broadcast_send("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.broadcast_send("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.broadcast_recv("foo") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.gather_send("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.gather_recv("foo", "asdf") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.scatter_send("foo", ["asdf"]) + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.scatter_recv("foo") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.all_gather("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.all_sum("foo", 2) + + +if __name__ == "__main__": + assert ( + not torch.cuda._initialized + ), "test_distributed must not have initialized CUDA context on main process" + + run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 28d790e3d690..74a73a3ddaa4 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -210,6 +210,20 @@ class PrefixStore(Store): @property def underlying_store(self) -> Store: ... +class _ControlCollectives: + def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ... + def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def broadcast_recv(self, key: str, timeout: timedelta) -> str: ... + def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def gather_recv(self, key: str, timeout: timedelta) -> str: ... + def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def scatter_recv(self, key: str, timeout: timedelta) -> str: ... + def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... + def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ... + +class _StoreCollectives(_ControlCollectives): + def __init__(self, store: Store, rank: int, world_size: int) -> None: ... + class _DistributedBackendOptions: def __init__(self): ... @property diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp index 1453c0a72808..3697d62301ba 100644 --- a/torch/csrc/distributed/c10d/HashStore.hpp +++ b/torch/csrc/distributed/c10d/HashStore.hpp @@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store { std::vector get(const std::string& key) override; void wait(const std::vector& keys) override { - wait(keys, Store::kDefaultTimeout); + wait(keys, timeout_); } void wait( diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index af715ba98a79..993284fa7cc5 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -97,4 +97,33 @@ class TORCH_API Store : public torch::CustomClassHolder { std::chrono::milliseconds timeout_; }; +/* +StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it +when it returns. +*/ +class StoreTimeoutGuard { + public: + explicit StoreTimeoutGuard( + Store& store, + const std::chrono::milliseconds& timeout) + : store_(store) { + oldTimeout_ = store.getTimeout(); + store.setTimeout(timeout); + } + + ~StoreTimeoutGuard() { + store_.setTimeout(oldTimeout_); + } + + /* Disabling copy and move semantics */ + StoreTimeoutGuard(const StoreTimeoutGuard&) = delete; + StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete; + StoreTimeoutGuard(StoreTimeoutGuard&&) = delete; + StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete; + + private: + Store& store_; + std::chrono::milliseconds oldTimeout_; +}; + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp new file mode 100644 index 000000000000..b98f9a71fb02 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace c10d { + +using namespace std::chrono_literals; + +class TORCH_API ControlCollectives : public torch::CustomClassHolder { + public: + virtual void barrier( + const std::string& key, + std::chrono::milliseconds timeout = 5min, + bool block = true) = 0; + + virtual void broadcastSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + virtual std::vector broadcastRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual void gatherSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + virtual std::vector> gatherRecv( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual std::vector scatterSend( + const std::string& key, + const std::vector>& data, + std::chrono::milliseconds timeout = 5min) = 0; + virtual std::vector scatterRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual std::vector> allGather( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual int64_t allSum( + const std::string& key, + int64_t data, + std::chrono::milliseconds timeout = 5min) = 0; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp new file mode 100644 index 000000000000..995899441d46 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp @@ -0,0 +1,222 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace { +std::string getRankKey(const std::string& key, int rank) { + return fmt::format("{}/{}", key, rank); +} +} // namespace + +namespace c10d { + +StoreCollectives::StoreCollectives( + c10::intrusive_ptr<::c10d::Store> store, + int rank, + int worldSize) + : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {} + +void StoreCollectives::barrier( + const std::string& key, + std::chrono::milliseconds timeout, + bool blocking) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto num_members_key = fmt::format("{}/num_members", key); + auto last_members_key = fmt::format("{}/last_members", key); + + auto idx = store_->add(num_members_key, 1); + store_->set(getRankKey(key, rank_), "joined"); + + if (idx == worldSize_) { + store_->set(last_members_key, ""); + } else if (blocking) { + try { + store_->wait({last_members_key}); + } catch (const std::exception& e) { + std::string msg = "barrier failed -- missing ranks: "; + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + if (!store_->check({rank_key})) { + msg += fmt::format("{}, ", i); + } + } + throw std::runtime_error(msg + e.what()); + } + } +} + +void StoreCollectives::broadcastSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + store_->set(key, data); +} + +std::vector StoreCollectives::broadcastRecv( + const std::string& key, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + return store_->get(key); +} + +void StoreCollectives::gatherSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto rank_key = getRankKey(key, rank_); + store_->set(rank_key, data); +} + +std::vector> StoreCollectives::gatherRecv( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + std::vector keys; + keys.reserve(worldSize_); + + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + keys.emplace_back(rank_key); + } + + std::vector> results; + results.reserve(worldSize_); + + try { + results = store_->multiGet(keys); + } catch (const std::exception& e) { + std::string msg = "gather failed -- missing ranks: "; + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + if (!store_->check({rank_key})) { + msg += fmt::format("{}, ", i); + } + } + throw std::runtime_error(msg + e.what()); + } + + // insert local data + results.insert(results.begin() + rank_, data); + return results; +} + +std::vector StoreCollectives::scatterSend( + const std::string& key, + const std::vector>& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + std::vector keys; + keys.reserve(worldSize_); + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + keys.emplace_back(rank_key); + } + auto local = data.at(rank_); + + std::vector> toSend{data}; + + toSend.erase(toSend.begin() + rank_); + + store_->multiSet(keys, toSend); + + return local; +} + +std::vector StoreCollectives::scatterRecv( + const std::string& key, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto rank_key = getRankKey(key, rank_); + return store_->get(rank_key); +} + +std::vector> StoreCollectives::allGather( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto localKey = getRankKey(key, rank_); + store_->set(localKey, data); + + std::vector keys; + keys.reserve(worldSize_); + + for (int i = 0; i < worldSize_; i++) { + auto rank_key = getRankKey(key, i); + keys.emplace_back(rank_key); + } + + try { + return store_->multiGet(keys); + } catch (const std::exception& e) { + std::string msg = "all_gather failed -- missing ranks: "; + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + if (!store_->check({rank_key})) { + msg += fmt::format("{}, ", i); + } + } + throw std::runtime_error(msg + e.what()); + } +} + +int64_t StoreCollectives::allSum( + const std::string& key, + int64_t value, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + store_->add(key, value); + + barrier(key + "/barrier", timeout); + + return store_->add(key, 0); +} + +void StoreCollectives::enforceUnique(const std::string& key) { + auto it = seenKeys_.find(key); + TORCH_INTERNAL_ASSERT( + it == seenKeys_.end(), "Key ", key, " has already been used."); + seenKeys_.emplace(key); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp new file mode 100644 index 000000000000..7d3eb5038565 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10d { + +class TORCH_API StoreCollectives : public ControlCollectives { + public: + explicit StoreCollectives( + c10::intrusive_ptr store, + int rank, + int worldSize); + + void barrier( + const std::string& key, + std::chrono::milliseconds timeout = 5min, + bool block = true) override; + + void broadcastSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + std::vector broadcastRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) override; + + void gatherSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + std::vector> gatherRecv( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + + std::vector scatterSend( + const std::string& key, + const std::vector>& data, + std::chrono::milliseconds timeout = 5min) override; + std::vector scatterRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) override; + + std::vector> allGather( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + + int64_t allSum( + const std::string& key, + int64_t data, + std::chrono::milliseconds timeout = 5min) override; + + private: + void enforceUnique(const std::string& key); + + private: + c10::intrusive_ptr store_; + int rank_; + int worldSize_; + + c10::FastSet seenKeys_{}; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 483becbce009..505b64e2a697 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -6,6 +6,9 @@ #include #include #include +#include +#include +#include #ifndef _WIN32 #include #include @@ -136,6 +139,34 @@ namespace torch::distributed::c10d { namespace { +py::bytes toPyBytes(const std::vector& data) { + return py::bytes(reinterpret_cast(data.data()), data.size()); +} + +std::vector toPyBytes( + const std::vector>& data) { + std::vector out; + out.reserve(data.size()); + for (const std::vector& data_ : data) { + out.emplace_back(reinterpret_cast(data_.data()), data_.size()); + } + return out; +} + +std::vector toVec8(const std::string& data) { + std::vector out{data.begin(), data.end()}; + return out; +} + +std::vector> toVec8(const std::vector& data) { + std::vector> out; + out.reserve(data.size()); + for (auto& data_ : data) { + out.emplace_back(toVec8(data_)); + } + return out; +} + template using shared_ptr_class_ = py::class_>; @@ -166,8 +197,7 @@ class PythonStore : public ::c10d::Store { pybind11::get_overload(static_cast(this), "set"); TORCH_INTERNAL_ASSERT(fn, "Not implemented."); // Call function with a py::bytes object for the value. - fn(key, - py::bytes(reinterpret_cast(value.data()), value.size())); + fn(key, toPyBytes(value)); } // Note: this function manually calls the Python-side overload @@ -184,7 +214,7 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast(fn(key)); - return std::vector(str.begin(), str.end()); + return toVec8(str); } // Note: this function manually calls the Python-side overload @@ -204,14 +234,8 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast( - fn(key, - py::bytes( - reinterpret_cast(expectedValue.data()), - expectedValue.size()), - py::bytes( - reinterpret_cast(desiredValue.data()), - desiredValue.size()))); - return std::vector(str.begin(), str.end()); + fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue))); + return toVec8(str); } int64_t add(const std::string& key, int64_t value) override { @@ -253,8 +277,7 @@ class PythonStore : public ::c10d::Store { return Store::append(key, value); } // Call function with a py::bytes object for the value. - fn(key, - py::bytes(reinterpret_cast(value.data()), value.size())); + fn(key, toPyBytes(value)); } std::vector> multiGet( @@ -287,14 +310,7 @@ class PythonStore : public ::c10d::Store { return Store::multiSet(keys, values); } - std::vector bytes; - bytes.reserve(values.size()); - for (auto& value : values) { - bytes.emplace_back( - reinterpret_cast(value.data()), value.size()); - } - - fn(keys, bytes); + fn(keys, toPyBytes(values)); } bool hasExtendedApi() const override { @@ -973,10 +989,7 @@ and :class:`~torch.distributed.HashStore`). "set", [](::c10d::Store& store, const std::string& key, - const std::string& value) { - std::vector value_(value.begin(), value.end()); - store.set(key, value_); - }, + const std::string& value) { store.set(key, toVec8(value)); }, py::call_guard(), R"( Inserts the key-value pair into the store based on the supplied ``key`` and @@ -1001,14 +1014,9 @@ Example:: const std::string& key, const std::string& expected_value, const std::string& desired_value) -> py::bytes { - std::vector expectedValue_( - expected_value.begin(), expected_value.end()); - std::vector desiredValue_( - desired_value.begin(), desired_value.end()); - auto value = - store.compareSet(key, expectedValue_, desiredValue_); - return py::bytes( - reinterpret_cast(value.data()), value.size()); + auto value = store.compareSet( + key, toVec8(expected_value), toVec8(desired_value)); + return toPyBytes(value); }, py::call_guard(), R"( @@ -1040,8 +1048,7 @@ Example:: py::gil_scoped_release guard; return store.get(key); }(); - return py::bytes( - reinterpret_cast(value.data()), value.size()); + return toPyBytes(value); }, R"( Retrieves the value associated with the given ``key`` in the store. If ``key`` is not @@ -1240,8 +1247,7 @@ Example:: [](::c10d::Store& store, const std::string& key, const std::string& value) { - std::vector value_(value.begin(), value.end()); - store.append(key, value_); + store.append(key, toVec8(value)); }, py::call_guard(), R"( @@ -1268,14 +1274,7 @@ Example:: py::gil_scoped_release guard; return store.multiGet(keys); }(); - std::vector res; - for (auto& value : values) { - auto bytes = py::bytes( - reinterpret_cast(value.data()), - value.size()); - res.push_back(bytes); - } - return res; + return toPyBytes(values); }, R"( Retrieve all values in ``keys``. If any key in ``keys`` is not @@ -1298,12 +1297,7 @@ Example:: [](::c10d::Store& store, const std::vector& keys, const std::vector& values) { - std::vector> vals; - vals.reserve(values.size()); - for (auto& value : values) { - vals.emplace_back(value.begin(), value.end()); - } - store.multiSet(keys, vals); + store.multiSet(keys, toVec8(values)); }, py::call_guard(), R"( @@ -1487,6 +1481,212 @@ that adds a prefix to each key inserted to the store. &::c10d::PrefixStore::getUnderlyingNonPrefixStore, R"(Recursively to get the store before layers of wrapping with PrefixStore.)"); + using namespace std::chrono_literals; + + auto collectives = + py::class_< + ::c10d::ControlCollectives, + c10::intrusive_ptr<::c10d::ControlCollectives>>( + module, + "_ControlCollectives", + R"( +Base class for all ControlCollectives implementations. +)") + .def( + "barrier", + &::c10d::ControlCollectives::barrier, + py::arg("key"), + py::arg("timeout") = 5min, + py::arg("block") = true, + py::call_guard(), + R"( +Blocks until all workers have entered this function. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. + block (bool): whether to block this working waiting on the results of the barrier. +)") + .def( + "all_sum", + &::c10d::ControlCollectives::allSum, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + py::call_guard(), + R"( +Computes a sum across all workers and returns the final value. + +Arguments: + key (str): The unique key used to identify this operation. + data (int): The data to sum. + timeout (duration): The timeout for this operation. +)") + .def( + "broadcast_send", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + collectives.broadcastSend(key, toVec8(data), timeout); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + py::call_guard(), + R"( +Sends data to all other workers. Must be only called from one worker. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)") + .def( + "broadcast_recv", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.broadcastRecv(key, timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("timeout") = 5min, + R"( +Receives data broadcasted from 1 worker. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. +)") + .def( + "gather_send", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + collectives.gatherSend(key, toVec8(data), timeout); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + py::call_guard(), + R"( +Sends data to one other worker. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)") + .def( + "gather_recv", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.gatherRecv(key, toVec8(data), timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + R"( +Receives data broadcasted from all workers. Must only be called by one worker. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. +)") + + .def( + "scatter_send", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.scatterSend(key, toVec8(data), timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + R"( +Sends rank specific data to all other workers. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)") + .def( + "scatter_recv", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.scatterRecv(key, timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("timeout") = 5min, + R"( +Receives rank specific data from one worker. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. +)") + + .def( + "all_gather", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.allGather(key, toVec8(data), timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + R"( +Sends data to all workers and receives data from all other workers. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)"); + + intrusive_ptr_class_<::c10d::StoreCollectives>( + module, + "_StoreCollectives", + collectives, + R"( +An implementation of ControlCollectives that uses the provided store as the underlying +communication mechanism. + )") + .def( + py::init, int, int>(), + py::arg("store"), + py::arg("rank"), + py::arg("world_size")); + auto processGroup = py::class_< ::c10d::ProcessGroup, diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index eb7a690fa958..3e7dce97b54c 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -54,6 +54,8 @@ def is_available() -> bool: set_debug_level, set_debug_level_from_env, _make_nccl_premul_sum, + _ControlCollectives, + _StoreCollectives, ) class _DistributedPdb(pdb.Pdb): From a0df40f195b3b3ad559dfa8e31ee7066013ef550 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 16 May 2024 18:18:39 -0700 Subject: [PATCH 048/116] Add dist_pp shortcut to TORCH_LOGS (#126322) distributed log category already includes pipelining since its under the torch.distributed umbrella. So both TORCH_LOGS=distributed and TORCH_LOGS=dist_pp will enable PP logs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126322 Approved by: https://github.com/kwen2501 --- torch/_logging/_registrations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 0530c12df304..d76b5610e97e 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -31,6 +31,7 @@ register_log( "dist_ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] ) +register_log("dist_pp", ["torch.distributed.pipelining"]) register_log("dist_fsdp", ["torch.distributed.fsdp"]) register_log("onnx", "torch.onnx") register_log("export", ["torch._dynamo", "torch.export", *DYNAMIC]) From 9edf54df4d12c8ec0ab007d1b28758f3c6dfbddf Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 16 May 2024 13:13:02 -0700 Subject: [PATCH 049/116] [dtensor] refactor view ops to use OpStrategy (#126011) As titled. Some ops require adjustment of output shape argument. In rule-based sharding prop, global output shape was inferred in the rule (in `view_ops.py`). In strategy-based sharding prop, it is now obtained from propagated out_tensor_meta (in `sharding_prop.py`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/126011 Approved by: https://github.com/wanchaol, https://github.com/XilunWu --- test/distributed/_tensor/test_view_ops.py | 15 +- torch/distributed/_spmd/batch_dim_utils.py | 10 +- torch/distributed/_tensor/op_schema.py | 8 + torch/distributed/_tensor/ops/view_ops.py | 302 ++++++++------------- torch/distributed/_tensor/sharding_prop.py | 55 ++-- 5 files changed, 160 insertions(+), 230 deletions(-) diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 429e62588651..2ea89e34789b 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -11,9 +11,9 @@ from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.ops.view_ops import ( Broadcast, + dim_maps, Flatten, InputDim, - ops, Repeat, Singleton, Split, @@ -130,8 +130,8 @@ def world_size(self) -> int: return 6 def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): - spec = ops[op] - rules = spec.dim_map(*args, **kwargs) + dim_map = dim_maps[op] + rules = dim_map(*args, **kwargs) outputs = op(*args, **kwargs) flat_args = pytree.arg_tree_leaves(*args) in_shape = flat_args[0].shape @@ -163,7 +163,6 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): ) for in_shard in all_sharding_choices: - # print(f' |--- {in_shard}') in_dt = distribute_tensor(args[0], device_mesh, in_shard) comm_mode = CommDebugMode() @@ -180,7 +179,7 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): self.assertEqual(outputs, full_out) def dimmap_test(self, op, args, expected_rule_output): - rules = ops[op].dim_map(*args) + rules = dim_maps[op](*args) self.assertEqual(rules, expected_rule_output) self.call_dt_test(op, args, {}, self.device_mesh) @@ -229,7 +228,7 @@ def test_view_ops(self): ) with self.assertRaises(AssertionError): - ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4)) + dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4)) self.dimmap_test( torch.broadcast_to, @@ -495,14 +494,14 @@ def test_complex_view_ops(self): InputDim(0), Flatten((InputDim(1), InputDim(2))), ) - view_as_complex_rule = ops[torch.view_as_complex].dim_map(inp) + view_as_complex_rule = dim_maps[torch.view_as_complex](inp) self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule) expected_view_as_real_rule = ( InputDim(0), Split(InputDim(1), (13, 2), 0), Split(InputDim(1), (13, 2), 1), ) - view_as_real_rule = ops[torch.view_as_real].dim_map(intermediate) + view_as_real_rule = dim_maps[torch.view_as_real](intermediate) self.assertEqual(view_as_real_rule, expected_view_as_real_rule) # test sharded computation correctness diff --git a/torch/distributed/_spmd/batch_dim_utils.py b/torch/distributed/_spmd/batch_dim_utils.py index afb9dd2e7d3b..6d36b2e38118 100644 --- a/torch/distributed/_spmd/batch_dim_utils.py +++ b/torch/distributed/_spmd/batch_dim_utils.py @@ -9,11 +9,7 @@ from torch import Tensor from torch.distributed._tensor import DeviceMesh, Replicate, Shard -from torch.distributed._tensor.ops.view_ops import ( - DimSpec, - InputDim, - ops as view_op_rules, -) +from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim from torch.distributed._tensor.placement_types import _Partial, DTensorSpec aten = torch.ops.aten @@ -80,12 +76,12 @@ def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int: return self.batch_dim_map[node] if node.target in self.dim_rule_map: - view_op_rule = view_op_rules[self.dim_rule_map[node.target]] # type: ignore[index] + dim_map = dim_maps[self.dim_rule_map[node.target]] # type: ignore[index] args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args) kwargs_val = pytree.tree_map_only( fx.Node, lambda n: n.meta["val"], node.kwargs ) - output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val) + output_dim_rules = dim_map(*args_val, **kwargs_val) def collect_input_dim(cmd: DimSpec, input_dims: Set[int]): if isinstance(cmd, InputDim): diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py index 7d5bd691395b..4918bffec621 100644 --- a/torch/distributed/_tensor/op_schema.py +++ b/torch/distributed/_tensor/op_schema.py @@ -161,6 +161,14 @@ def output_ndim(self): def output_shape(self): return self.strategies[0].output_spec.shape + @property + def ndim(self): + return self.output_ndim + + @property + def shape(self): + return self.output_shape + class TupleStrategy(StrategyType): """ diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index be72cc9509f5..598c973170f4 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -16,23 +16,24 @@ import torch from torch import Tensor -from torch._subclasses.fake_tensor import unset_fake_temporarily -from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.api import Shard from torch.distributed._tensor.op_schema import ( OpSchema, - OutputSharding, + OpStrategy, + PlacementStrategy, RuntimeSchemaInfo, + StrategyType, ) from torch.distributed._tensor.ops.utils import ( + generate_redistribute_costs, normalize_dim, normalize_dims, prod, - register_prop_rule, + register_op_strategy, ) from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate -from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing +from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten @@ -454,68 +455,41 @@ def dim_reduction( ) -@dataclass -class Op: - dim_map: Callable[..., DimMap] - shape_argnum: Optional[int] = None - - -ops: Dict[Callable[..., torch.Tensor], Op] = { - torch.atleast_1d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 1)), - torch.atleast_2d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 2)), - torch.atleast_3d: Op(dim_map=lambda x: dim_atleast_3d(x.ndim)), - torch.broadcast_to: Op( - dim_map=lambda input, shape: expand(input.shape, shape), shape_argnum=1 - ), - Tensor.expand: Op( - dim_map=lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), - shape_argnum=1, - ), - torch.flatten: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)), - torch.movedim: Op( - dim_map=lambda input, source, destination: dim_movedim( - input.ndim, source, destination - ) - ), - torch.permute: Op( - dim_map=lambda input, dims: tuple( - InputDim(i) for i in normalize_dims(dims, input.ndim) - ) - ), - torch.ravel: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)), - Tensor.repeat: Op(dim_map=lambda self, *sizes: dim_repeat(self.ndim, sizes)), - torch.reshape: Op( - dim_map=lambda input, shape: view_groups(input.shape, shape), - shape_argnum=1, +dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination ), - torch.squeeze: Op(dim_map=lambda input, dim=None: dim_squeeze(input.shape, dim)), - torch.tile: Op(dim_map=lambda input, dims: dim_tile(input.ndim, dims)), - torch.transpose: Op( - dim_map=lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1) + torch.permute: lambda input, dims: tuple( + InputDim(i) for i in normalize_dims(dims, input.ndim) ), - torch.unsqueeze: Op(dim_map=lambda input, dim: dim_unsqueeze(input.ndim, dim)), - Tensor.view: Op( - dim_map=lambda input, *shape: view_groups(input.shape, shape), - shape_argnum=1, - ), - torch.view_as_complex: Op( - dim_map=lambda input: dim_flatten(input.ndim, input.ndim - 2) - ), - torch.view_as_real: Op(dim_map=lambda input: dim_view_as_real(input.shape)), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), } def propagate_shape_and_sharding( - in_shard: Sequence[Placement], + input_src_placements: Sequence[Placement], local_in_shape: Shape, rule: DimMap, mesh_sizes: Shape, -) -> Tuple[Shape, Optional[Sequence[Placement]], torch.Tensor]: +) -> Tuple[Sequence[Placement], Sequence[Placement]]: """ - Determine output sharding and tensor shape based on given global tensor shape and input sharding. - - Takes as input the global shape of the tensor, and the input sharding, - and produce corresponding output sharding and shape of the output tensor. + Determine input target sharding and output sharding based on + given global tensor shape and input source sharding. Sharding propagation follows mapped dimensions: - An output dimension that maps directly to an input dimension is sharded equally @@ -524,16 +498,13 @@ def propagate_shape_and_sharding( - An output dimension that is a split of the input dimension can only be sharded if the leftmost split size is divisible by the mesh dimension """ - assert len(in_shard) == len(mesh_sizes) - sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)} + assert len(input_src_placements) == len(mesh_sizes) # for each input dim, for each mesh dim, provides a list of possible shardable dimensions - shardable_dims: torch.Tensor = torch.ones( - (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool - ) + mesh_ndim = len(mesh_sizes) + shardable_dims: Dict[int, List[bool]] = {} # in case an input dimension disappears (e.g. collapsing, reduction) # we cannot shard in that dimension (we need a replication fall-back rule) - seen_input_dims: Set[int] = set() def collect_used_inputs(cmd: DimSpec) -> None: @@ -545,28 +516,19 @@ def collect_used_inputs(cmd: DimSpec) -> None: for cmd in rule: collect_used_inputs(cmd) for dim in range(len(local_in_shape)): - shardable_dims[dim, :] = dim in seen_input_dims + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim - def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: if isinstance(cmd, InputDim): - seen_input_dims.add(cmd.input_dim) - return ( - local_in_shape[cmd.input_dim], - cmd if cmd.input_dim in sharded_in_dims else None, - ) + return cmd elif isinstance(cmd, Flatten): for dim in cmd.input_dims[1:]: if isinstance(dim, InputDim): - shardable_dims[dim.input_dim, :] = False + shardable_dims[dim.input_dim] = [False] * mesh_ndim dim0 = cmd.input_dims[0] - return ( - prod(get_dim_size(a)[0] for a in cmd.input_dims), - dim0 - if isinstance(dim0, InputDim) and dim0.input_dim in sharded_in_dims - else None, - ) + return dim0 if isinstance(dim0, InputDim) else None elif isinstance(cmd, Split): - _, in_dim = get_dim_size(cmd.input_dim) + in_dim = get_in_dim_to_shard(cmd.input_dim) out_size = cmd.group_shape[cmd.split_id] if cmd.split_id == 0 and in_dim is not None: # we need to check that the input dimension is divisible @@ -579,14 +541,13 @@ def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: # but we will allow it if that's the input and it's compatible # 1. is this dimension shardable on each individual mesh dim? - for mesh_dim, mesh_dim_size in enumerate(mesh_sizes): - shardable_dims[in_dim.input_dim, mesh_dim] = ( - out_size % mesh_dim_size == 0 - ) + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] # 2. here we special case things like [Shard(0), Shard(0)] submesh_size = 1 - for size, shard in zip(mesh_sizes, in_shard): + for size, shard in zip(mesh_sizes, input_src_placements): if isinstance(shard, Shard) and shard.dim == in_dim: submesh_size *= size assert ( @@ -594,158 +555,113 @@ def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." # we will only shard our first component of the split - return out_size, in_dim if cmd.split_id == 0 else None - elif isinstance(cmd, Singleton): - return 1, None - elif isinstance(cmd, Broadcast): - return cmd.dim_size, None - elif isinstance(cmd, NewDim): - return cmd.size, None + return in_dim if cmd.split_id == 0 else None elif isinstance(cmd, Repeat): - size, in_dim = get_dim_size(cmd.input_dim) + in_dim = get_in_dim_to_shard(cmd.input_dim) if in_dim is not None: - shardable_dims[in_dim.input_dim, :] = False - return size * cmd.times, None + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None else: - raise RuntimeError(f"cmd not found: {cmd}, in rule: {rule}") + return None - dim_map = {} - out_shape = [] + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} for dim, cmd in enumerate(rule): - out_size, in_dim = get_dim_size(cmd) - out_shape.append(out_size) + in_dim = get_in_dim_to_shard(cmd) if in_dim is not None: - dim_map[in_dim.input_dim] = dim + shard_dim_map[in_dim.input_dim] = dim - needs_reshard = any( - isinstance(placement, Shard) and not shardable_dims[placement.dim][mesh_dim] - for mesh_dim, placement in enumerate(in_shard) - ) - - output_placements = ( - None - if needs_reshard - else [Shard(dim_map[s.dim]) if isinstance(s, Shard) else s for s in in_shard] - ) + input_tgt_placements = [ + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] - return (tuple(out_shape), output_placements, shardable_dims) + return input_tgt_placements, output_placements -def register_prop_rule_map( +def register_op_strategy_map( aten_op_overload: torch._ops.OpOverload, local_op_name: Callable[..., torch.Tensor], schema_info: Optional[RuntimeSchemaInfo] = None, ) -> None: - spec: Op = ops[local_op_name] - - @register_prop_rule(aten_op_overload, schema_info=schema_info) - def reshape_prop(op_schema: OpSchema) -> OutputSharding: - rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) - input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0]) - mesh = input_dtensor_spec.mesh - - assert isinstance( - input_dtensor_spec, DTensorSpec - ), "Expected first input to be a DTensorSpec" - global_in_shape = input_dtensor_spec.shape + dim_map: Callable[..., DimMap] = dim_maps[local_op_name] + + @register_op_strategy(aten_op_overload, schema_info=schema_info) + def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + global_in_shape = input_strategy.output_shape assert global_in_shape is not None, "Shape required." - with disable_proxy_modes_tracing(), unset_fake_temporarily(): - ( - global_out_shape, - shard_out, - shardable_dims, - ) = propagate_shape_and_sharding( - input_dtensor_spec.placements, + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, tuple(global_in_shape), rules, mesh.shape, ) - if shard_out is not None: - # no reshard needed - output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out)) - - # We only need the local shape to lower the call into the local op - args = op_schema.args_schema - shape_argnum = spec.shape_argnum - if shape_argnum is not None: - # compute the local shape from the global shape, then return - # a resharding even if we don't really reshard, the only reason - # for this type of resharding is to lower the global shape to - # local shape - local_out_shape = compute_local_shape( - list(global_out_shape), mesh, shard_out - ) - - suggested_schema = OpSchema( - op=op_schema.op, - args_schema=args[:shape_argnum] - + (tuple(local_out_shape),) - + args[shape_argnum + 1 :], - kwargs_schema=op_schema.kwargs_schema, - ) - return OutputSharding( - output_spec=output_dtensor_spec, - redistribute_schema=suggested_schema, - needs_redistribute=True, - ) - - return OutputSharding(output_spec=output_dtensor_spec) - - else: # TODO: optimize this. we shouldn't simply blindly replicate # unshardable dims ... # FIXME: this can be wrong for situations where we have # [Shard(0), Shard(0)] - suggested_placements = [ - p - if not isinstance(p, Shard) or shardable_dims[p.dim][mesh_dim] - else Replicate() - for mesh_dim, p in enumerate(input_dtensor_spec.placements) + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs = [ + generate_redistribute_costs(input_strategy, input_tgt_spec) ] - return OutputSharding( - output_spec=None, - redistribute_schema=OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - placements=tuple(suggested_placements), - mesh=input_dtensor_spec.mesh, - tensor_meta=input_dtensor_spec.tensor_meta, - ), - ) - + op_schema.args_schema[1:], - kwargs_schema=op_schema.kwargs_schema, - ), + + output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_spec, + input_specs=(input_tgt_spec,), + redistribute_cost=redistribute_costs, + ) ) + return output_strategy -register_prop_rule_map(aten.squeeze.default, torch.squeeze) -register_prop_rule_map( + +register_op_strategy_map(aten.squeeze.default, torch.squeeze) +register_op_strategy_map( aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map(aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)) -register_prop_rule_map( +register_op_strategy_map( + aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map(aten.view_as_complex.default, torch.view_as_complex) -register_prop_rule_map(aten.view_as_real.default, torch.view_as_real) +register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) +register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) diff --git a/torch/distributed/_tensor/sharding_prop.py b/torch/distributed/_tensor/sharding_prop.py index 9acf6aa0c919..d173a91a771c 100644 --- a/torch/distributed/_tensor/sharding_prop.py +++ b/torch/distributed/_tensor/sharding_prop.py @@ -45,15 +45,21 @@ def __init__(self) -> None: # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] - # op map to save indices of size (and stride) args which may need to be modified in sharding prop - self.op_to_size_and_stride_idx: Dict[ + # op map to save indices of shape (and stride) args which may need to be modified in sharding prop + self.op_to_shape_and_stride_idx: Dict[ OpOverload, Union[int, Tuple[int, int]] ] = { + # new factory ops aten.new_empty.default: 1, aten.new_full.default: 1, aten.new_ones.default: 1, aten.new_zeros.default: 1, aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, } def register_sharding_prop_rule( @@ -260,16 +266,19 @@ def spec_to_strategy(spec: object) -> object: ) suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) - # size and stride args need to be modified for new factory ops, potentially - if op_schema.op in self.op_to_size_and_stride_idx: + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: assert isinstance(output_strategy.output_spec, DTensorSpec) # It happens when the output has the same shape as the input # and the input placements are not all Replicate(). if output_strategy.output_spec.is_sharded(): - needs_redistribute = True - suggestion_schema = self._adjust_size_and_stride_args( - op_schema, output_strategy.output_spec, mesh + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec, mesh ) + needs_redistribute = True # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): @@ -442,29 +451,31 @@ def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: # for eager execution, we just select the one with the minimal redistribute cost return strategy.strategies[strategy_costs.index(min(strategy_costs))] - def _adjust_size_and_stride_args( - self, op_schema: OpSchema, spec: DTensorSpec, mesh: DeviceMesh + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + mesh: DeviceMesh, ) -> OpSchema: - size_stride_idx = self.op_to_size_and_stride_idx[op_schema.op] - if isinstance(size_stride_idx, tuple): - size_idx, stride_idx = size_stride_idx + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx else: - size_idx = size_stride_idx + shape_idx = shape_stride_idx stride_idx = None - expected_input_schema = list(op_schema.args_schema) - size = cast(list, expected_input_schema[size_idx]) - # # adjust size to be the same as that of the _local_tensor - # # of the DTensor input arg at index 0, which is inferred - expected_input_schema[size_idx] = compute_local_shape( - size, mesh, spec.placements + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx] = compute_local_shape( + out_tensor_meta.shape, mesh, spec.placements ) # adjust the stride arg for aten.new_empty_strided.default if stride_idx: - stride = cast(list, expected_input_schema[stride_idx]) expected_input_schema[stride_idx] = compute_local_stride( - stride, mesh, spec.placements + out_tensor_meta.stride, mesh, spec.placements ) - return OpSchema(op_schema.op, tuple(expected_input_schema), {}) + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) From 5756b53dd886583d2565a206e1abf94984f0c5f5 Mon Sep 17 00:00:00 2001 From: Stonepia Date: Fri, 17 May 2024 06:05:51 +0000 Subject: [PATCH 050/116] [XPU] call empty_cache for dynamo tests (#126377) When running a batch of models, lacking `empty_cache()` would result in OOM for subsequent models. This PR unifies the `empty_cache` call for both CUDA and XPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126377 Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire --- benchmarks/dynamo/common.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 096dbc48ec7d..2b877e43447e 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -354,6 +354,24 @@ def deterministic_torch_manual_seed(*args, **kwargs): torch.manual_seed = deterministic_torch_manual_seed +def empty_gpu_cache(device): + """ + Explicitly empty gpu cache to avoid OOM in subsequent run. + """ + + if device not in ["cuda", "xpu"]: + log.warning( + "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]", + device, + ) + return + + if device == "cuda": + torch.cuda.empty_cache() + elif device == "xpu": + torch.xpu.empty_cache() + + def synchronize(): pass @@ -2278,7 +2296,7 @@ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2): def batch_size_finder(self, device, model_name, initial_batch_size=1024): batch_size = initial_batch_size while batch_size >= 1: - torch.cuda.empty_cache() + empty_gpu_cache(current_device) try: device, name, model, example_inputs, _ = self.load_model( device, @@ -2468,7 +2486,7 @@ def record_status(accuracy_status, dynamo_start_stats): fp64_outputs = None finally: del model_fp64, inputs_fp64 - torch.cuda.empty_cache() + empty_gpu_cache(current_device) tolerance, cos_similarity = self.get_tolerance_and_cosine_flag( self.args.training, current_device, name @@ -2497,7 +2515,7 @@ def record_status(accuracy_status, dynamo_start_stats): return record_status(accuracy_status, dynamo_start_stats=start_stats) finally: del model_copy - torch.cuda.empty_cache() + empty_gpu_cache(current_device) # Rerun native pytorch reset_rng_state() @@ -2518,7 +2536,7 @@ def record_status(accuracy_status, dynamo_start_stats): return record_status(accuracy_status, dynamo_start_stats=start_stats) finally: del model_copy - torch.cuda.empty_cache() + empty_gpu_cache(current_device) # Two eager runs should have exactly same result is_same = True @@ -2719,7 +2737,7 @@ def warmup(fn, model, example_inputs, mode, niters=5): try: if current_device == "cuda": torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() + empty_gpu_cache(current_device) t0 = time.perf_counter() for _ in range(niters): fn(model, example_inputs) @@ -2949,7 +2967,7 @@ def run_one_model( name, model, example_inputs, optimize_ctx, experiment, tag ) print(status) - torch.cuda.empty_cache() + empty_gpu_cache(current_device) self.maybe_preserve_compile_debug(name, status) From f9a70331941951d759f77d96d87593c0fb96e9cc Mon Sep 17 00:00:00 2001 From: chilli Date: Thu, 16 May 2024 14:44:55 -0700 Subject: [PATCH 051/116] Refactor partitioner and clean it up (#126318) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126318 Approved by: https://github.com/anijain2305 --- functorch/compile/__init__.py | 1 - test/functorch/test_aotdispatch.py | 64 -- torch/_functorch/compile_utils.py | 4 +- torch/_functorch/partitioners.py | 962 ++++++++++++++++------------- 4 files changed, 547 insertions(+), 484 deletions(-) diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py index 96b853cd2e27..e7548a5ff6b9 100644 --- a/functorch/compile/__init__.py +++ b/functorch/compile/__init__.py @@ -25,7 +25,6 @@ from torch._functorch.partitioners import ( default_partition, draw_graph, - draw_joint_graph, min_cut_rematerialization_partition, ) from torch._functorch.python_key import pythonkey_decompose diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ffa71a7e905b..cfbd96e7368d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -4835,70 +4835,6 @@ def f(a, b, c, d): self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) - @unittest.skipIf(not USE_NETWORKX, "networkx not available") - def test_min_cut_partitioner_recomputable_ops(self): - def f(x): - return x * x * x - - recomputable_ops = [] - partition_fn = partial( - min_cut_rematerialization_partition, recomputable_ops=recomputable_ops - ) - - fw_graph, bw_graph = get_fw_bw_graph( - f, [torch.randn(3, requires_grad=True)], partition_fn - ) - # Expected forward graph: - # opcode name target args kwargs - # ------------- --------- --------------- -------------------------- -------- - # placeholder primals_1 primals_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} - # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} - # output output output ([mul_1, primals_1, mul],) {} - self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) - # Expected backward graph: - # opcode name target args kwargs - # ------------- ---------- --------------- ----------------------- -------- - # placeholder primals_1 primals_1 () {} - # placeholder mul mul () {} - # placeholder tangents_1 tangents_1 () {} - # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} - # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} - # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} - # call_function add aten.add.Tensor (mul_2, mul_4) {} - # call_function add_1 aten.add.Tensor (add, mul_4) {} - # output output output ([add_1],) {} - self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) - - recomputable_ops = [torch.ops.aten.mul] - partition_fn = partial( - min_cut_rematerialization_partition, recomputable_ops=recomputable_ops - ) - fw_graph, bw_graph = get_fw_bw_graph( - f, [torch.randn(3, requires_grad=True)], partition_fn - ) - # Expected forward graph: - # opcode name target args kwargs - # ------------- --------- --------------- ---------------------- -------- - # placeholder primals_1 primals_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} - # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} - # output output output ([mul_1, primals_1],) {} - self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) - # Expected backward graph: - # opcode name target args kwargs - # ------------- ---------- --------------- ----------------------- -------- - # placeholder primals_1 primals_1 () {} - # placeholder tangents_1 tangents_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} # RECOMPUTED - # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} - # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} - # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} - # call_function add aten.add.Tensor (mul_2, mul_4) {} - # call_function add_1 aten.add.Tensor (add, mul_4) {} - # output output output ([add_1],) {} - self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) - def test_contiguous(self): # The test simulates the condition where transpose followed by view # happens in the backward pass. diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index ffa37e59f04d..c9c750835a9f 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -1,6 +1,8 @@ # mypy: ignore-errors +from typing import Callable + import torch import torch.fx as fx from torch.utils import _pytree as pytree @@ -9,7 +11,7 @@ aten = torch.ops.aten -def get_aten_target(node): +def get_aten_target(node: fx.Node) -> Callable: if hasattr(node.target, "overloadpacket"): return node.target.overloadpacket return node.target diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 5a43cd5e7bf3..d104247b3f63 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - import copy import functools import heapq @@ -9,7 +7,10 @@ import operator import os from collections import defaultdict -from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union +from dataclasses import dataclass, replace +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import sympy import torch import torch._inductor.inductor_prims @@ -28,19 +29,84 @@ from . import config from .compile_utils import fx_graph_cse, get_aten_target -if TYPE_CHECKING: - import sympy - AOT_PARTITIONER_DEBUG = config.debug_partitioner log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + + +@dataclass +class OpTypes: + """Class for keeping track of different operator categories""" + + fusible_ops: Set[Callable] + compute_intensive_ops: Set[Callable] + random_ops: Set[Callable] + view_ops: Set[Callable] + recomputable_ops: Set[Callable] + + def is_fusible(self, node: fx.Node): + return get_aten_target(node) in self.fusible_ops + + def is_compute_intensive(self, node: fx.Node): + return get_aten_target(node) in self.compute_intensive_ops + + def is_random(self, node: fx.Node): + return get_aten_target(node) in self.random_ops + + def is_view(self, node: fx.Node): + return get_aten_target(node) in self.view_ops + + def is_recomputable(self, node: fx.Node): + return get_aten_target(node) in self.recomputable_ops + + +@dataclass +class NodeInfo: + # Be careful about iterating over these explicitly, as their order may not + # be deterministic + inputs: List[fx.Node] + _required_fw_nodes: Set[fx.Node] + required_bw_nodes: Set[fx.Node] + unclaimed_nodes: Set[fx.Node] + fw_order: Dict[fx.Node, int] + + @property + def required_fw_nodes(self) -> List[fx.Node]: + return sorted( + (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] + ) + + def is_required_fw(self, n: fx.Node) -> bool: + return n in self._required_fw_nodes -def must_recompute(node): + def is_required_bw(self, n: fx.Node) -> bool: + return n in self.required_bw_nodes + + def is_unclaimed(self, n: fx.Node) -> bool: + return n in self.unclaimed_nodes + + def get_fw_order(self, n: fx.Node) -> int: + assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" + return self.fw_order[n] + + +@dataclass +class MinCutOptions: + ban_if_used_far_apart: bool + ban_if_long_fusible_chains: bool + ban_if_materialized_backward: bool + ban_if_not_in_allowlist: bool + ban_if_reduction: bool + + +def must_recompute(node: fx.Node) -> bool: return node.meta.get("recompute", False) -def has_recomputable_ops(fx_g): +def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: found = False for node in fx_g.graph.nodes: if must_recompute(node): @@ -48,7 +114,7 @@ def has_recomputable_ops(fx_g): return False -def has_recomputable_rng_ops(fx_g): +def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: for node in fx_g.graph.nodes: if ( must_recompute(node) @@ -59,7 +125,7 @@ def has_recomputable_rng_ops(fx_g): return False -def sym_node_size(node): +def sym_node_size(node: fx.Node) -> int: if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): return 1 assert isinstance(node.meta["val"], torch.SymFloat) @@ -74,7 +140,9 @@ def __repr__(self): InvalidNode = InvalidNodeBase() -def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): +def _extract_graph_with_inputs_outputs( + joint_graph: fx.Graph, inputs: List[fx.Node], outputs: List[fx.Node] +) -> fx.Graph: """ Given a graph, extracts out a subgraph that takes the specified nodes as inputs and returns the specified outputs. @@ -136,36 +204,38 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): return new_graph -def _is_primal(node): +def _is_primal(node: fx.Node) -> bool: return ( node.op == "placeholder" - and "tangents" not in node.target + and "tangents" not in str(node.target) and not _is_bwd_seed_offset(node) and not _is_fwd_seed_offset(node) ) -def _is_tangent(node): - return node.op == "placeholder" and "tangents" in node.target +def _is_tangent(node: fx.Node) -> bool: + return node.op == "placeholder" and "tangents" in str(node.target) -def _is_bwd_seed_offset(node): +def _is_bwd_seed_offset(node: fx.Node) -> bool: return node.op == "placeholder" and ( - "bwd_seed" in node.target or "bwd_base_offset" in node.target + "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) ) -def _is_fwd_seed_offset(node): +def _is_fwd_seed_offset(node: fx.Node) -> bool: return node.op == "placeholder" and ( - "fwd_seed" in node.target or "fwd_base_offset" in node.target + "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) ) -def _is_backward_state(node): +def _is_backward_state(node: fx.Node) -> bool: return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) -def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): +def _extract_fwd_bwd_outputs( + joint_module: fx.GraphModule, *, num_fwd_outputs +) -> Tuple[List[fx.Node], List[fx.Node]]: outputs = pytree.arg_tree_leaves( *(node.args for node in joint_module.graph.find_nodes(op="output")) ) @@ -174,7 +244,7 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): return fwd_outputs, bwd_outputs -def _remove_by_name(saved_values, name): +def _remove_by_name(saved_values: List[fx.Node], name: str): for saved_value in saved_values: if saved_value.name == name: saved_values.remove(saved_value) @@ -182,8 +252,12 @@ def _remove_by_name(saved_values, name): def _extract_fwd_bwd_modules( - joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs -): + joint_module: fx.GraphModule, + saved_values: List[fx.Node], + saved_sym_nodes: List[fx.Node], + *, + num_fwd_outputs: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( joint_module, num_fwd_outputs=num_fwd_outputs ) @@ -359,14 +433,10 @@ def default_partition( ) -def _prod(x): - s = 1 - for i in x: - s *= i - return s +INT_INF = int(1e6) -def _tensor_nbytes(numel, dtype): +def _tensor_nbytes(numel: int, dtype) -> int: return numel * dtype.itemsize @@ -374,10 +444,7 @@ def _size_of(node: fx.Node) -> int: if "val" in node.meta: val = node.meta["val"] if isinstance(val, py_sym_types): - if isinstance(val, torch.SymInt): - return 1 - else: - return 999999 + return 1 # NB: The fallback values here are meaningless, maybe we should respect # torch._inductor.config.unbacked_symint_fallback (but this is a # layering violation) @@ -391,28 +458,18 @@ def _size_of(node: fx.Node) -> int: return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") - - # Only needed since we don't always trace with fake tensors. - if "tensor_meta" in node.meta: - metadata = node.meta["tensor_meta"] - # TODO: What is to_size_hint suppose to be? - numel = _prod(map(to_size_hint, metadata.shape)) # noqa: F821 - dtype = metadata.dtype - else: - return 0 - - return _tensor_nbytes(numel, dtype) + raise RuntimeError("We should always have `val` metadata on the nodes") # Used for some investigative purposes -def _count_ops(graph): +def _count_ops(graph: fx.Graph): from collections import defaultdict - cnt = defaultdict(int) + cnt: Dict[str, int] = defaultdict(int) for node in graph.nodes: if node.op == "call_function": cnt[node.target.__name__] += 1 - print(sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) + print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) @functools.lru_cache(None) @@ -433,14 +490,14 @@ def pointwise_ops(): return ops -def sort_depths(args, depth_map): +def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: arg_depths = { arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) } - return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True) + return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) -def reordering_to_mimic_autograd_engine(gm): +def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: """ This pass finds the first bwd node in the graph (by looking at users of tangents) and then reorders the graph by walking from this node to all the @@ -464,7 +521,7 @@ def reordering_to_mimic_autograd_engine(gm): """ new_graph = fx.Graph() - env = {} + env: Dict[fx.Node, fx.Node] = {} # Add new placeholder nodes in the order specified by the inputs for node in gm.graph.find_nodes(op="placeholder"): @@ -517,7 +574,12 @@ def insert_node_in_graph(node): return new_gm -def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes): +def functionalize_rng_ops( + joint_module: fx.GraphModule, + fw_module: fx.GraphModule, + bw_module: fx.GraphModule, + num_sym_nodes: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: # During user-driven activation checkpointing, we have to ensure that a rng # op in fwd yields the same output as the recomputed rng op in the bwd. To # do this, we use functionalize wrappers to wrap the random ops and share @@ -591,11 +653,15 @@ def get_sample_rng_state(device): run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state run_with_rng_state = torch._prims.rng_prims.run_with_rng_state - + bw_tangent_start_node = None for node in bw_module.graph.find_nodes(op="placeholder"): if "tangent" in node.name: bw_tangent_start_node = node break + if bw_tangent_start_node is None: + raise RuntimeError( + "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" + ) fw_rng_state_outputs = [] for base_node, node_pair in recomputable_rng_ops_map.items(): @@ -665,7 +731,7 @@ def get_sample_rng_state(device): return fw_module, bw_module -def cleanup_recompute_tags(joint_module): +def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -683,332 +749,50 @@ def cleanup_recompute_tags(joint_module): return joint_module -def min_cut_rematerialization_partition( - joint_module: fx.GraphModule, - _joint_inputs, - compiler="inductor", - recomputable_ops=None, - *, - num_fwd_outputs, -) -> Tuple[fx.GraphModule, fx.GraphModule]: - """ - Partitions the joint graph such that the backward recomputes the forward. - Recomputing helps in trading off memory bandwidth with computation. - - To create the fwd and bwd graph, we copy the joint graph, manually set the - outputs to just original forward or backward outputs. And then we run the - resulting graphs through dead code elimination. - - .. warning:: - This API is experimental and likely to change. - - Args: - joint_module(fx.GraphModule): The joint forward and backward graph. This - is the result of AOT Autograd tracing. - _joint_inputs: The inputs to the joint graph. This is unused. - compiler: This option determines the default set of recomputable ops. - Currently, there are two options: ``nvfuser`` and ``inductor``. - recomputable_ops: This is an optional set of recomputable ops. If this - is not None, then this set of ops will be used instead of the - default set of ops. - num_fwd_outputs: The number of outputs from the forward graph. - - Returns: - Returns the generated forward and backward Fx graph modules. - """ - try: - import networkx as nx - except ImportError as e: - raise RuntimeError( - "Need networkx installed to perform smart recomputation " "heuristics" - ) from e - - joint_module.graph.eliminate_dead_code() - joint_module.recompile() - - fx_g = joint_module.graph - - # add the CSE pass - if config.cse: - cse_graph = fx_graph_cse(fx_g) - joint_module.graph = cse_graph - joint_graph = joint_module.graph - - graph_has_recomputable_ops = has_recomputable_ops(joint_module) - graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) - if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) - - name_to_node = {} - for node in joint_module.graph.nodes: - name_to_node[node.name] = node - - def classify_nodes(joint_module): - required_bw_nodes = set() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) - ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( - joint_module, num_fwd_outputs=num_fwd_outputs - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs - ) - required_fw_nodes = { - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - } - unclaimed_nodes = { - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - } - return ( - fwd_outputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - inputs, - ) - - ( - orig_fw_outputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - inputs, - ) = classify_nodes(joint_module) - - # networkx blows up on graphs with no required backward nodes - # Since there's nothing to partition anyway, and the default partitioner can "handle" - # this case, send our graph over to the default partitioner. - if len(required_bw_nodes) == 0: - return default_partition( - joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs - ) - - def is_fusible(a, b): - # We can perform "memory fusion" into a cat, but cat cannot be a - # producer to a fusion - if get_aten_target(b) == aten.cat: - return True - return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops - - fw_order = 0 - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - node.fw_order = fw_order - fw_order += 1 - - for node in reversed(joint_module.graph.nodes): - if node not in required_fw_nodes: - node.dist_from_bw = 0 - else: - node.dist_from_bw = int(1e9) - for user in node.users: - node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - - aten = torch.ops.aten - prims = torch.ops.prims - - # compiler == "nvfuser" is the default set of recomputable ops - default_recomputable_ops = [ - aten.add, - aten.sub, - aten.div, - aten.atan2, - aten.mul, - aten.max, - aten.min, - aten.pow, - aten.remainder, - aten.fmod, - aten.__and__, - aten.__or__, - aten.__xor__, - aten.__lshift__, - aten.__rshift__, - aten.eq, - aten.ne, - aten.ge, - aten.gt, - aten.le, - aten.lt, - aten.abs, - aten.bitwise_not, - aten.ceil, - aten.floor, - aten.frac, - aten.neg, - aten.relu, - aten.round, - aten.silu, - aten.trunc, - aten.log, - aten.log10, - aten.log1p, - aten.log2, - aten.lgamma, - aten.exp, - aten.expm1, - aten.erf, - aten.erfc, - aten.cos, - aten.acos, - aten.cosh, - aten.sin, - aten.asin, - aten.sinh, - aten.tan, - aten.atan, - aten.tanh, - aten.atanh, - aten.sqrt, - aten.rsqrt, - aten.reciprocal, - aten.sigmoid, - aten.softplus, - aten.threshold, - aten.threshold_backward, - aten.clamp, - aten.where, - aten.lerp, - aten.addcmul, - aten.gelu, - aten.gelu_backward, - aten.sum, - aten.mean, - aten._grad_sum_to_size, - aten.sum_to_size, - aten.amax, - aten.to, - aten.type_as, - operator.getitem, - aten.squeeze, - aten.unsqueeze, - aten.rsub, - aten._to_copy, - ] # noqa: E501,B950 - view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] - if compiler == "inductor": - default_recomputable_ops += [ - prims.div, - prims.convert_element_type, - aten.clone, - aten._to_copy, - aten.full_like, - prims.var, - prims.sum, - aten.var, - aten.std, - prims.broadcast_in_dim, - aten.select, - aten._unsafe_view, - aten.view, - aten.expand, - aten.slice, - aten.reshape, - aten.broadcast_tensors, - aten.scalar_tensor, - aten.ones, - aten.new_zeros, - aten.lift_fresh_copy, - aten.arange, - aten.triu, - aten.var_mean, - aten.isinf, - aten.any, - aten.full, - aten.as_strided, - aten.zeros, - aten.argmax, - aten.maximum, - prims.iota, - prims._low_memory_max_pool2d_offsets_to_indices, - ] # noqa: E501,B950 - view_ops += [ - aten.view, - aten.slice, - aten.t, - prims.broadcast_in_dim, - aten.expand, - aten.as_strided, - aten.permute, - ] - # Natalia said that we should allow recomputing indexing :) - default_recomputable_ops += [aten.index, aten.gather] - default_recomputable_ops += view_ops - - default_recomputable_ops += pointwise_ops() - - default_recomputable_ops += [ - aten.zeros_like, - ] - - default_recomputable_ops += [method_to_operator(m) for m in magic_methods] - recomputable_ops = ( - set(recomputable_ops) - if recomputable_ops is not None - else set(default_recomputable_ops) - ) - - random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] - compute_intensive_ops = [ - aten.mm, - aten.convolution, - aten.convolution_backward, - aten.bmm, - aten.addmm, - aten._scaled_dot_product_flash_attention, - aten._scaled_dot_product_efficient_attention, - aten.upsample_bilinear2d, - ] # noqa: E501,B950 +def get_saved_values( + joint_graph: fx.Graph, + node_info: NodeInfo, + min_cut_options: MinCutOptions, + dont_ban=None, +): + if dont_ban is None: + dont_ban = set() + op_types = get_default_op_list() - fusible_ops = recomputable_ops | set(random_ops) if AOT_PARTITIONER_DEBUG: joint_module_ops = { str(node.target._overloadpacket) - for node in joint_module.graph.nodes + for node in joint_graph.nodes if node.op == "call_function" and hasattr(node.target, "_overloadpacket") } - ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops} + ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} print("Ops banned from rematerialization: ", ops_ignored) print() - BAN_IF_USED_FAR_APART = config.ban_recompute_used_far_apart - BAN_IF_LONG_FUSIBLE_CHAINS = config.ban_recompute_long_fusible_chains - BAN_IF_MATERIALIZED_BACKWARDS = config.ban_recompute_materialized_backward - BAN_IF_NOT_IN_ALLOWLIST = config.ban_recompute_not_in_allowlist - BAN_IF_REDUCTION = config.ban_recompute_reductions + def is_fusible(a, b): + # We can perform "memory fusion" into a cat, but cat cannot be a + # producer to a fusion + if get_aten_target(b) == aten.cat: + return True + return op_types.is_fusible(a) and op_types.is_fusible(b) - if config.aggressive_recomputation: - BAN_IF_MATERIALIZED_BACKWARDS = False - BAN_IF_USED_FAR_APART = False - BAN_IF_LONG_FUSIBLE_CHAINS = False - BAN_IF_NOT_IN_ALLOWLIST = False + try: + import networkx as nx + except ImportError as e: + raise RuntimeError( + "Need networkx installed to perform smart recomputation " "heuristics" + ) from e def is_materialized_backwards(node): - if get_aten_target(node) in view_ops: + if op_types.is_view(node): return False cur_nodes = {node} while len(cur_nodes) > 0: cur = cur_nodes.pop() for user in cur.users: - if user not in required_fw_nodes and not is_fusible(cur, user): + if not node_info.is_required_fw(user) and not is_fusible(cur, user): return True - if get_aten_target(user) in view_ops: + if op_types.is_view(user): cur_nodes.add(user) return False @@ -1020,17 +804,15 @@ def should_ban_recomputation(node): return False if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: return False - # NB: "recompute" == 0 means that must save this node. if node.meta.get("recompute", None) == 0: return True - if BAN_IF_NOT_IN_ALLOWLIST: - if get_aten_target(node) not in recomputable_ops: + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): return True else: - ignored_ops = random_ops + compute_intensive_ops - if get_aten_target(node) in ignored_ops: + if op_types.is_random(node) or op_types.is_compute_intensive(node): return True # If a node *must* be materialized in the backwards pass, then we @@ -1038,7 +820,9 @@ def should_ban_recomputation(node): # general, the assumption we make is that recomputing a node in the # backwards pass is "free". However, if a node must be materialized # in the backwards pass, then recomputing it is never free. - if is_materialized_backwards(node) and BAN_IF_MATERIALIZED_BACKWARDS: + if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( + node + ): log.info("materialized backwards: %s %s", node, tuple(node.users)) return True @@ -1046,16 +830,15 @@ def should_ban_recomputation(node): # modification appears to have made this heuristic a lot less critical # for performance. # NB: As of PR #121692, this hack no longer seems necessary. - if not graph_has_recomputable_ops: - if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw: - return True + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return True # If the output of an op is 4x smaller (arbitrary choice), # then we don't allow recomputation. The idea here is that for # things like reductions, saving the output of the reduction is very # cheap/small, and it makes sure we don't do things like recompute # normalizations in the backwards. - if BAN_IF_REDUCTION: + if min_cut_options.ban_if_reduction: input_tensors_size = sum( _size_of(i) for i in node.args if isinstance(i, fx.Node) ) @@ -1069,9 +852,14 @@ def is_materialized(node): return not all(is_fusible(node, user) for user in node.users) - def get_node_weight(node) -> int: + def get_node_weight(node) -> float: mem_sz = _size_of(node) + if isinstance(node.meta["val"], py_sym_types): + # We never want to save symfloats + if not isinstance(node.meta["val"], torch.SymInt): + return INT_INF + # Heuristic to bias towards nodes closer to the backwards pass # Complete guess about current value mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) @@ -1084,6 +872,11 @@ def get_node_weight(node) -> int: banned_nodes = set() def ban_recomputation_if_allowed(node): + if op_types.is_view(node): + return False + if node in dont_ban: + return False + # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1106,8 +899,8 @@ def ban_recomputation_if_allowed(node): if node.op == "output": continue - if node in required_bw_nodes: - if node not in inputs: + if node in node_info.required_bw_nodes: + if node not in node_info.inputs: nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) continue # If someone saves a input for backward as-is and backward @@ -1126,7 +919,7 @@ def ban_recomputation_if_allowed(node): # If a node can't be recomputed (too expensive or involves randomness), # we prevent it from being recomputed by adding an inf edge to the source # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. - if node in required_fw_nodes and should_ban_recomputation(node): + if node_info.is_required_fw(node) and should_ban_recomputation(node): ban_recomputation_if_allowed(node) # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. @@ -1135,12 +928,13 @@ def ban_recomputation_if_allowed(node): ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) if is_sym_node(node): - weight = sym_node_size(node) + weight = float(sym_node_size(node)) elif is_non_tensor_node: - weight = 0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + weight = ( + 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + ) else: weight = get_node_weight(node) - # Creates the weights on the "node" edge nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) for user in node.users: @@ -1168,35 +962,40 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: Finds the first unfusible node in the chain of nodes starting from `start_nodes` and returns its position. """ - sorted_nodes = [] + sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] for n in start_nodes: - heapq.heappush(sorted_nodes, (n.fw_order, n, True)) + heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) while len(sorted_nodes) > 0: _, node, node_is_fusible = heapq.heappop(sorted_nodes) if not node_is_fusible: - return node.fw_order + return node_info.get_fw_order(node) for user in node.users: - if user in required_fw_nodes: - if user.fw_order > max_range: + if node_info.is_required_fw(user): + if node_info.get_fw_order(user) > max_range: continue heapq.heappush( - sorted_nodes, (user.fw_order, user, is_fusible(node, user)) + sorted_nodes, + (node_info.get_fw_order(user), user, is_fusible(node, user)), ) return max_range - if BAN_IF_USED_FAR_APART: - for used_node in required_fw_nodes: + if min_cut_options.ban_if_used_far_apart: + for used_node in node_info.required_fw_nodes: orders = [ - user.fw_order for user in used_node.users if user in required_fw_nodes + node_info.get_fw_order(user) + for user in used_node.users + if user in node_info.required_fw_nodes + ] + fw_users = [ + user for user in used_node.users if node_info.is_required_fw(user) ] - fw_users = [user for user in used_node.users if user in required_fw_nodes] if len(orders) > 0: first_unfusible_use = find_first_unfusible(fw_users, max(orders)) for user in tuple(used_node.users): if ( - user in required_fw_nodes - and user.fw_order > first_unfusible_use + user in node_info.required_fw_nodes + and node_info.get_fw_order(user) > first_unfusible_use and is_fusible(used_node, user) ): if user in banned_nodes: @@ -1204,10 +1003,10 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: log.info( "used above/below fusible %s:(%s) -> %s -> %s:(%s)", used_node, - used_node.fw_order, + node_info.get_fw_order(used_node), first_unfusible_use, user, - user.fw_order, + node_info.get_fw_order(user), ) ban_recomputation_if_allowed(user) @@ -1222,47 +1021,51 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 - if BAN_IF_LONG_FUSIBLE_CHAINS: + if min_cut_options.ban_if_long_fusible_chains: visited = set() for start_node in joint_graph.nodes: - if start_node not in required_fw_nodes: + if start_node not in node_info.required_fw_nodes: continue - fusible = [(start_node.fw_order, start_node)] - start_order = start_node.fw_order + fusible = [(node_info.get_fw_order(start_node), start_node)] + start_order = node_info.get_fw_order(start_node) while len(fusible) > 0: _, cur = heapq.heappop(fusible) if cur in visited: continue visited.add(cur) # 100 is arbitrary choice to try and prevent degenerate cases - if cur.fw_order > start_order + 100 and len(fusible) == 0: + if ( + node_info.get_fw_order(cur) > start_order + 100 + and len(fusible) == 0 + ): log.info( "too long %s %s %s %s", cur, start_node, - cur.fw_order, - start_node.fw_order, + node_info.get_fw_order(cur), + node_info.get_fw_order(start_node), ) ban_recomputation_if_allowed(cur) break for user in cur.users: if ( - user in required_fw_nodes + user in node_info.required_fw_nodes and is_fusible(cur, user) and user not in banned_nodes ): - heapq.heappush(fusible, (user.fw_order, user)) + heapq.heappush(fusible, (node_info.get_fw_order(user), user)) try: cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") except Exception: print("Failed to compute min-cut on following graph:") print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + visualize_min_cut_graph(nx_graph) raise reachable, non_reachable = partition - cutset = set() + cutset: Set[Tuple[str, str]] = set() for u, nbrs in ((n, nx_graph[n]) for n in reachable): cutset.update((u, v) for v in nbrs if v in non_reachable) @@ -1272,14 +1075,347 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: node_name = node_in[:-3] cut_nodes.add(node_name) + name_to_node = get_name_to_node(joint_graph) # To make this stuff deterministic - node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} saved_values = sorted( (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] ) + return saved_values, banned_nodes + + +def visualize_min_cut_graph(nx_graph): + import networkx as nx + import pydot + + dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() + dot_graph = pydot.graph_from_dot_data(dot_format)[0] + for edge in dot_graph.get_edges(): + weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] + # Set edge label to weight + edge.set_label(str(weight)) + # Color edges with weight 'inf' as red + if weight == float("inf"): + edge.set_color("red") + print("Visualizing the failed graph to min_cut_failed.svg") + dot_graph.write_svg("min_cut_failed.svg") + + +def get_default_op_list() -> OpTypes: + default_recomputable_ops: List[Callable] = [ + aten.add, + aten.sub, + aten.div, + aten.atan2, + aten.mul, + aten.max, + aten.min, + aten.pow, + aten.remainder, + aten.fmod, + aten.__and__, + aten.__or__, + aten.__xor__, + aten.__lshift__, + aten.__rshift__, + aten.eq, + aten.ne, + aten.ge, + aten.gt, + aten.le, + aten.lt, + aten.abs, + aten.bitwise_not, + aten.ceil, + aten.floor, + aten.frac, + aten.neg, + aten.relu, + aten.round, + aten.silu, + aten.trunc, + aten.log, + aten.log10, + aten.log1p, + aten.log2, + aten.lgamma, + aten.exp, + aten.expm1, + aten.erf, + aten.erfc, + aten.cos, + aten.acos, + aten.cosh, + aten.sin, + aten.asin, + aten.sinh, + aten.tan, + aten.atan, + aten.tanh, + aten.atanh, + aten.sqrt, + aten.rsqrt, + aten.reciprocal, + aten.sigmoid, + aten.softplus, + aten.threshold, + aten.threshold_backward, + aten.clamp, + aten.where, + aten.lerp, + aten.addcmul, + aten.gelu, + aten.gelu_backward, + aten.sum, + aten.mean, + aten._grad_sum_to_size, + aten.sum_to_size, + aten.amax, + aten.to, + aten.type_as, + operator.getitem, + aten.squeeze, + aten.unsqueeze, + aten.rsub, + aten._to_copy, + ] # noqa: E501,B950 + recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + recomputable_view_ops += [ + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + view_ops = recomputable_view_ops + default_recomputable_ops += [ + prims.div, + prims.convert_element_type, + aten.clone, + aten._to_copy, + aten.full_like, + prims.var, + prims.sum, + aten.var, + aten.std, + prims.broadcast_in_dim, + aten.select, + aten._unsafe_view, + aten.view, + aten.expand, + aten.slice, + aten.reshape, + aten.broadcast_tensors, + aten.scalar_tensor, + aten.ones, + aten.new_zeros, + aten.lift_fresh_copy, + aten.arange, + aten.triu, + aten.var_mean, + aten.isinf, + aten.any, + aten.full, + aten.as_strided, + aten.zeros, + aten.argmax, + aten.maximum, + prims.iota, + prims._low_memory_max_pool2d_offsets_to_indices, + ] # noqa: E501,B950 + # Natalia said that we should allow recomputing indexing :) + default_recomputable_ops += [aten.index, aten.gather] + default_recomputable_ops += view_ops + + default_recomputable_ops += pointwise_ops() + + default_recomputable_ops += [ + aten.zeros_like, + ] + + default_recomputable_ops += [method_to_operator(m) for m in magic_methods] + recomputable_ops = set(default_recomputable_ops) + + random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + compute_intensive_ops = [ + aten.mm, + aten.convolution, + aten.convolution_backward, + aten.bmm, + aten.addmm, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_efficient_attention, + aten.upsample_bilinear2d, + ] # noqa: E501,B950 + + fusible_ops = recomputable_ops | set(random_ops) + return OpTypes( + set(fusible_ops), + set(compute_intensive_ops), + set(random_ops), + set(view_ops), + set(recomputable_ops), + ) + + +def get_name_to_node(graph: fx.Graph): + name_to_node = {} + for node in graph.nodes: + name_to_node[node.name] = node + return name_to_node + + +def choose_saved_values_set( + joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 +) -> List[fx.Node]: + min_cut_options = MinCutOptions( + ban_if_used_far_apart=config.ban_recompute_used_far_apart, + ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, + ban_if_materialized_backward=config.ban_recompute_materialized_backward, + ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, + ban_if_reduction=config.ban_recompute_reductions, + ) + + if config.aggressive_recomputation: + min_cut_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ban_if_not_in_allowlist=False, + ) + + if memory_budget == 0: + return node_info.inputs + + runtime_optimized_saved_values, _ = get_saved_values( + joint_graph, + node_info, + min_cut_options, + ) + return runtime_optimized_saved_values + + +def min_cut_rematerialization_partition( + joint_module: fx.GraphModule, + _joint_inputs, + compiler="inductor", + *, + num_fwd_outputs, +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimination. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + _joint_inputs: The inputs to the joint graph. This is unused. + compiler: This option determines the default set of recomputable ops. + Currently, there are two options: ``nvfuser`` and ``inductor``. + recomputable_ops: This is an optional set of recomputable ops. If this + is not None, then this set of ops will be used instead of the + default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + joint_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + + def classify_nodes(joint_module): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list( + filter(_is_fwd_seed_offset, joint_module.graph.nodes) + ) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs + ) + required_fw_nodes: Set[fx.Node] = { + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + } + unclaimed_nodes = { + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + } + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order + ) + + node_info = classify_nodes(joint_module) + + # networkx blows up on graphs with no required backward nodes + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. + if len(node_info.required_bw_nodes) == 0: + return default_partition( + joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs + ) + + for node in reversed(joint_module.graph.nodes): + if node.op == "output": + node.dist_from_bw = int(1e9) + elif node not in node_info.required_fw_nodes: + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, @@ -1312,7 +1448,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: } remat_nodes = fw_module_nodes & bw_module_nodes - counts = defaultdict(int) + counts: Dict[str, int] = defaultdict(int) for node in fw_module.graph.nodes: if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): counts[str(node.target._overloadpacket)] += 1 @@ -1321,7 +1457,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: ) print( "Count of Ops Rematerialized: ", - sorted(counts.items(), key=operator.itemgetter(1), reverse=True), + sorted(counts.items(), key=lambda x: x[1], reverse=True), ) return fw_module, bw_module @@ -1331,7 +1467,7 @@ def draw_graph( fname: str, figname: str = "fx_graph", clear_meta: bool = True, - prog: Union[str, List[str]] = None, + prog: Optional[Union[str, List[str]]] = None, parse_stack_trace: bool = False, dot_graph_shape: Optional[str] = None, ) -> None: @@ -1357,13 +1493,3 @@ def draw_graph( write_method(fname) else: write_method(fname, prog=prog) - - -def draw_joint_graph( - graph: torch.fx.GraphModule, - joint_inputs, - file_name: str = "full_graph.png", - dot_graph_shape: Optional[str] = None, -): - draw_graph(graph, file_name, dot_graph_shape=dot_graph_shape) - return default_partition(graph, joint_inputs) From 15ca562f863ffe69d76c0ccaf448b27d18ceb2e8 Mon Sep 17 00:00:00 2001 From: wz337 Date: Fri, 17 May 2024 06:57:49 +0000 Subject: [PATCH 052/116] [DTensor] Turn on foreach implementation for clip_grad_norm_ for DTensor by default (#126423) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/126423 Approved by: https://github.com/awgu --- .../fsdp/test_fully_shard_clip_grad_norm_.py | 2 +- torch/distributed/_tensor/__init__.py | 18 +++++++++++++----- torch/utils/_foreach_utils.py | 4 +++- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index d3978febec09..9139b62f1367 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -67,11 +67,11 @@ def _test_clip_grad_norm( ) comm_mode = CommDebugMode() with comm_mode: + # foreach is default to turn on so we don't need to specify it. total_norm = torch.nn.utils.clip_grad_norm_( model.parameters(), max_norm=max_norm, norm_type=norm_type, - foreach=True, ) self.assertEqual(ref_total_norm, total_norm.full_tensor()) # Expect one all-reduce per mesh dim for partial -> replicate diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index f7afe41e753c..7c391c4821aa 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -10,7 +10,12 @@ from torch.distributed._tensor.ops.utils import normalize_to_torch_size from torch.distributed._tensor.placement_types import Placement, Replicate, Shard from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh -from torch.optim.optimizer import _foreach_supported_types +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) # All public APIs from dtensor package @@ -25,10 +30,13 @@ ] -# Append DTensor to the list of supported types for foreach implementation of optimizer -# so that we will try to use foreach over the for-loop implementation on CUDA. -if DTensor not in _foreach_supported_types: - _foreach_supported_types.append(DTensor) +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) def _dtensor_init_helper( diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index 840df0432b53..6f8a9b5b7e23 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -15,6 +15,8 @@ def _get_fused_kernels_supported_devices() -> List[str]: TensorListList: TypeAlias = List[List[Optional[Tensor]]] Indices: TypeAlias = List[int] +_foreach_supported_types = [torch.Tensor] + # This util function splits tensors into groups by device and dtype, which is useful before sending # tensors off to a foreach implementation, which requires tensors to be on one device and dtype. @@ -44,4 +46,4 @@ def _device_has_foreach_support(device: torch.device) -> bool: def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool: - return _device_has_foreach_support(device) and all(t is None or type(t) == torch.Tensor for t in tensors) + return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors) From 2edaae436ac570949b3595f0c72f6808037555d2 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Thu, 16 May 2024 21:57:57 +0000 Subject: [PATCH 053/116] Fix cummax and cummin lowering for empty case (#126461) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126461 Approved by: https://github.com/peterbell10 --- test/inductor/test_torchinductor_opinfo.py | 2 ++ torch/_inductor/lowering.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index fdb9a8c37a47..9bd873ac747b 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -438,6 +438,8 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "mH", "rsub", "triu", + "cummax", + "cummin", } diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 77d7b6c046de..80eee352458d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -5323,7 +5323,7 @@ def log_add_exp_helper(a_tuple, b_tuple): def cummax(x, axis=None): if len(x.get_size()) == 0: assert axis in [0, -1] - return clone(x), torch.empty_like(x, dtype=torch.int64) + return clone(x), empty_like(x, dtype=torch.int64) dtype = x.get_dtype() combine_fn = ir.get_reduction_combine_fn( @@ -5353,7 +5353,7 @@ def cummax(x, axis=None): def cummin(x, axis=None): if len(x.get_size()) == 0: assert axis in [0, -1] - return clone(x), torch.empty_like(x, dtype=torch.int64) + return clone(x), empty_like(x, dtype=torch.int64) dtype = x.get_dtype() combine_fn = ir.get_reduction_combine_fn( From 45f2d0945230ff00095413dcd570500eca0bcd6e Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 17 May 2024 10:42:21 +0800 Subject: [PATCH 054/116] [Quant][Inductor] Enable lowering of qlinear-binary(-unary) fusion for X86Inductor (#122593) **Description** Lower the qlinear binary post op pattern to Inductor. Use post op sum (in-place) if the extra input has the same dtype as output. Otherwise, it uses binary add. **Supported linear-binary(-unary) patterns** ``` linear(X) extra input \ / Add | Optional(relu) | Y 1. int8-mixed-fp32 +---+---------------+-----------+------------------------------+---------+ | # | Add type | Quant out | Pattern | Post op | +---+---------------+-----------+------------------------------+---------+ | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | +---+---------------+-----------+------------------------------+---------+ | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | +---+---------------+-----------+------------------------------+---------+ 2. int8-mixed-bf16 +---+----------+---------------+-----------+--------------------------------------------------+---------+ | # | X2 dtype | Add type | Quant out | Pattern | Post op | +---+----------+---------------+-----------+--------------------------------------------------+---------+ | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> to_fp32 -> q | add | +---+----------+---------------+-----------+--------------------------------------------------+---------+ | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | +---+----------+---------------+-----------+--------------------------------------------------+---------+ | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | | | | In-place right| | | | +---+----------+---------------+-----------+--------------------------------------------------+---------+ | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | | | | In-place right| | | | +---+----------+---------------+-----------+--------------------------------------------------+---------+ | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> relu -> to_fp32 -> q | add | +---+----------+---------------+-----------+--------------------------------------------------+---------+ | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | +---+----------+---------------+-----------+--------------------------------------------------+---------+ ``` Note (1) The positions of linear and the extra input can be swapped. (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. **Test plan** python test/inductor/test_mkldnn_pattern_matcher.py -k test_qlinear_add python test/inductor/test_cpu_cpp_wrapper.py -k test_qlinear_add Pull Request resolved: https://github.com/pytorch/pytorch/pull/122593 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/eellison --- test/inductor/test_cpu_cpp_wrapper.py | 18 + test/inductor/test_mkldnn_pattern_matcher.py | 138 +++++++ torch/_inductor/fx_passes/quantization.py | 375 ++++++++++++++++++- torch/_inductor/graph.py | 2 + torch/_inductor/ir.py | 226 +++++++++++ torch/_inductor/mkldnn_lowerings.py | 64 ++++ 6 files changed, 804 insertions(+), 19 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 66b92eedc97c..0888f3ad47a1 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -296,6 +296,24 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), ), + BaseTest( + "test_qlinear_gelu", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), + condition=torch.backends.mkldnn.is_available(), + ), + BaseTest( + "test_qlinear_add", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), + condition=torch.backends.mkldnn.is_available(), + ), + BaseTest( + "test_qlinear_add_relu", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), + condition=torch.backends.mkldnn.is_available(), + ), BaseTest( "test_qlinear_dequant_promotion", "cpu", diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index cbf9dd89c506..756de35df84c 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1589,6 +1589,144 @@ def test_qlinear_gelu_int8_mixed_bf16(self): (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True ) + def _qlinear_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): + r""" + This testcase will quantize two consecutive Linear->Add(->relu) patterns as: + X + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + def fake_quant(x): + # to produce a float32 result as extra input + qlib = torch.ops.quantized_decomposed + x = qlib.quantize_per_tensor.default(x, 0.0166785, 42, 0, 255, torch.uint8) + x = qlib.dequantize_per_tensor.default( + x, 0.0166785, 42, 0, 255, torch.uint8 + ) + return x + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + fake_quant_before_extra_input, + ): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + self.fake_quant_before_extra_input = fake_quant_before_extra_input + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x) + if self.fake_quant_before_extra_input: + x2 = fake_quant(x2) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.linear3(tmp) + tmp2 = self.linear4(tmp) + if self.fake_quant_before_extra_input: + tmp2 = fake_quant(tmp2) + res = self.add_fn2(tmp1, tmp2) + if self.use_relu: + res = self.relu2(res) + return res + + add_fn_list = [ + lambda x, y: x + y, + lambda x, y: y + x, + lambda x, y: x.add_(y), + lambda x, y: y.add_(x), + ] + fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False] + cases = itertools.product(add_fn_list, fake_quant_x2_list) + for add_fn, fq_x2 in cases: + mod = M(add_fn, use_relu, fq_x2).eval() + v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1) + + def matcher_check_fn(): + # 1. Dequant-linear pattern matched in quantization weight prepack * 4 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 + ) + # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] + nodes_per_match = 6 if int8_mixed_bf16 else 4 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 4 * nodes_per_match, + ) + # 2. Qlinear Binary Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qlinear_binary_matcher_count"], 2 + ) + # Two linear-binary patterns are matched + # matched patter1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor] + # matched patter2 = [qlinear, add, (convert dtype), (relu)] + # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary + to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2) + self.assertEqual( + counters["inductor"]["qlinear_binary_matcher_nodes"], + 5 + 2 * use_relu + to_bf16_after_binary, + ) + + for is_qat in [False, True]: + self._test_common( + mod, + (v,), + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + matcher_check_fn=matcher_check_fn, + is_qat=is_qat, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_cpu(self): + self._qlinear_add_cpu_test_helper() + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_int8_mixed_bf16(self): + self._qlinear_add_cpu_test_helper(int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_relu_cpu(self): + self._qlinear_add_cpu_test_helper(use_relu=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_relu_int8_mixed_bf16(self): + self._qlinear_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True) + def _qlinear_dequant_promotion_cpu_test_helper( self, inputs, diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 0d4fc3b42933..4476a9ccd512 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -224,17 +224,26 @@ def generate_pattern_with_binary( binary_post_op, computation_call, extra_input_pattern, - int8_mixed_bf16_with_inplace_add=False, + dtype_convert=False, + swap_inputs=False, ): - binary_pattern = CallFunction( - binary_post_op, - computation_call, - extra_input_pattern, + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) ) return _may_generate_pattern_with_dtype_convert( binary_pattern, KeywordArg("convert_dtype_after_inplace_add"), - int8_mixed_bf16_with_inplace_add, + dtype_convert, ) @@ -435,10 +444,109 @@ def qlinear(match: Match, *args, **kwargs): return qlinear -def _is_valid_quantized_conv_binary_optimization_pattern(): - # Check if it's a valid Conv Binary Pattern: - # * qconv2d_pointwise should only has one users - # * Extra input of binary node comes from dequant pattern +def _register_quantized_linear_binary_lowering( + pattern, + pass_number, + computation_op, + binary_unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_binary_optimization_pattern(), + pass_number=pass_number, + ) + def qlinear_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + x2 = ( + kwargs["accum"] + if binary_unary_attr.binary_op_name == "sum" + else kwargs["other"] + ) + x2_scale = 1.0 + x2_zp = 0 + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # bias + b = kwargs["b"] if "b" in kwargs else None + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + + x2.realize() + from .mkldnn_fusion import _can_be_inplace + + if binary_unary_attr.binary_op_name == "sum": + assert _can_be_inplace( + x2 + ), "QLinear Binary Inplace Fusion requires accum is not an alias or mutation." + + # if the binary post op is sum but output dtype is not the same as accum, + # use accum's dtype as output dtype + out_dtype = output_dtype + if ( + output_dtype + and binary_unary_attr.binary_op_name == "sum" + and output_dtype != x2.dtype + ): + out_dtype = x2.dtype + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + out_dtype, + x2, + x2_scale, + x2_zp, + binary_unary_attr.binary_op_name, + binary_unary_attr.alpha, + binary_unary_attr.unary_op_name, + binary_unary_attr.scalars_attr, + binary_unary_attr.algorithm_attr, + ) + counters["inductor"]["qlinear_binary_matcher_count"] += 1 + counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear_binary + + +def _is_valid_qconv_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qconv2d_pointwise + ) + + +def _is_valid_qlinear_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qlinear_pointwise, + # we don't insert q-dq for extra input due to accuracy issues + extra_input_from_dequant=False, + ) + + +def _is_valid_quantized_op_binary_optimization_pattern( + qop, extra_input_from_dequant=True +): + # Check if it's a valid Binary Pattern for qconv2d and qlinear: + # * qop_pointwise should only has one users + # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern # * the two inputs of binary node should have attribute "meta" and should be tensors # * the two inputs of binary node should have the same shape # * All users of the extra input in this pattern should be @@ -446,8 +554,8 @@ def _is_valid_quantized_conv_binary_optimization_pattern(): # connected to the compute node. def fn(match): output_dtype = _get_pattern_output_dtype(match) - compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0] - # qconv2d_pointwise should only have one user + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user if len(compute_node.users) != 1: return False binary_node_inputs = next(iter(compute_node.users)).args @@ -460,9 +568,12 @@ def fn(match): break assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern - if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( - extra_input_of_binary_node.target - != quantized_decomposed.dequantize_per_tensor.default + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) ): return False @@ -489,9 +600,13 @@ def fn(match): from .mkldnn_fusion import _get_remaining_users extra_input_of_pattern = ( - match.kwargs["accum"] - if output_dtype is None - else match.kwargs["accum_after_dequant"] + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if output_dtype is None or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) ) if ( len( @@ -517,7 +632,7 @@ def _register_quantized_conv_binary_lowering( ): @register_lowering_pattern( pattern, - extra_check=_is_valid_quantized_conv_binary_optimization_pattern(), + extra_check=_is_valid_qconv_binary_optimization_pattern(), pass_number=pass_number, ) def qconv_binary(match: Match, *args, **kwargs): @@ -884,6 +999,228 @@ def __init__( binary_unary_attr, # binary_unary_attr ) + # QLinear + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. + """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 0, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 1, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 1, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 2, # pass_number + qlinear_binary_op, # computation_op + # Output dtype should be the same as accum's dtype but we don't know + # its dtype. So, leave it to be determined in the lowering function + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 2, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + def _is_valid_quantized_maxpool2d_optimization_pattern(): def fn(match): diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 672509b20a56..bfb7b8dea7eb 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1303,6 +1303,8 @@ def debug(msg): torch.ops.aten.mkldnn_rnn_layer.default, torch.ops.onednn.qlinear_pointwise.default, torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, ] need_fixed_channels_last_layout += [ torch.ops.mkldnn._convolution_pointwise.default, diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index be730decda14..689877ba6928 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7158,6 +7158,232 @@ def create( ) +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp, x2] + - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp, x2] + - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.binary_tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.binary" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = ( + "binary_tensor" if x_scale_zp_are_tensors else "binary" + ) + self.cpp_kernel_key = "qlinear_pointwise_binary" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::optional other, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, + double binary_alpha, + c10::string_view unary_post_op, + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp, other = args[-3], args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 5 + x_scale, x_zp = args[-5], args[-4] + ( + o_inv_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_inv_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zp, + output_dtype, + other, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale: float, + x_zp: int, + weight: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zp: "TensorBox", + bias: "TensorBox", + o_inv_scale: float, + output_zero_point: int, + output_dtype, + other: "TensorBox", + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_linear_fusion_create( + cls, + x, + weight, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): + x_scale.realize() + x_zp.realize() + inputs = inputs + [x_scale, x_zp] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zp, int) + constant_args = constant_args + [x_scale, x_zp] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zp.realize() + inputs = inputs + [w_scale, w_zp] + if binary_attr == "sum": + other = cls.require_stride_order(other, req_stride_order) + inputs.append(other) + constant_args = constant_args + [ + o_inv_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + if binary_attr == "sum": + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + mark_node_as_mutating(packed, other) + # Return other since it has been inplace changed. + return packed.inputs[-1] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + @dataclasses.dataclass class MutableBox(IRNode): """ diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 0ebccbf27ea3..5a12a5c090bf 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -338,6 +338,70 @@ def qlinear_unary( ) ) + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None + ) + def qlinear_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_inv_scale, + o_zero_point, + output_dtype, + x2: TensorBox, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if binary_attr == "sum": + if output_dtype in [ + torch.float32, + torch.bfloat16, + ] and x2.get_dtype() in [torch.float32, torch.bfloat16]: + if x2.get_dtype() != output_dtype: + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. + x2 = to_dtype(x2, output_dtype) + else: + assert ( + x2.get_dtype() == output_dtype + ), "dtype of accum for qlinear post op sum should be the same as output" + return TensorBox.create( + ir.QLinearPointwiseBinaryPT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zero_point, + output_dtype, + x2, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + if torch._C.has_mkl: cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) From 8619fe6214cd8f31345ae73c5b90024a0233dc40 Mon Sep 17 00:00:00 2001 From: Nicolas Macchioni Date: Fri, 17 May 2024 08:09:48 +0000 Subject: [PATCH 055/116] variable search spaces for gemm autotuning (#126220) add a switch to change the gemm autotuning search space between the default (the current set of hardcoded configs) and an exhaustive search space that enumerates all block sizes in [16, 32, 64, 128, 256], stages in [1, 2, 3, 4, 5], and warps in [2, 4, 6] Pull Request resolved: https://github.com/pytorch/pytorch/pull/126220 Approved by: https://github.com/eellison --- torch/_inductor/config.py | 8 +++ torch/_inductor/kernel/mm_common.py | 78 +++++++++++++++++------------ 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 640c0d25c264..7f14dc62de93 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -232,6 +232,7 @@ def is_fbcode(): force_same_precision = ( True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" ) + # Specify candidate backends for gemm autotune. # Possible choices are combinations of: ATen, Triton, CUTLASS. # ATen: default Pytorch ATen kernels. @@ -241,6 +242,13 @@ def is_fbcode(): "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" ).upper() +# Specify the size of the search space for GEMM autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_gemm_search_space = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" +).upper() + # the value used as a fallback for the unbacked SymInts # that can appear in the input shapes (e.g., in autotuning) unbacked_symint_fallback = 8192 diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 26d08183b0e5..76511e19a49d 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,4 +1,5 @@ import functools +import itertools import logging from typing import cast, List, Tuple @@ -113,39 +114,50 @@ def filtered_configs( # List of dictionaries to store the kernel configs. Configs that evaluate to true -# will be utilised on the target platform -mm_kernel_configs = [ - # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" - {"config": (16, 32, 16, 3, 2), "cond": True}, - {"config": (16, 32, 32, 4, 2), "cond": True}, - {"config": (16, 32, 32, 5, 2), "cond": True}, - {"config": (32, 32, 16, 1, 2), "cond": True}, - {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, - {"config": (32, 64, 32, 5, 8), "cond": True}, - {"config": (64, 32, 32, 5, 8), "cond": True}, - {"config": (64, 32, 128, 5, 4), "cond": True}, - {"config": (64, 64, 16, 2, 4), "cond": True}, - {"config": (64, 64, 32, 2, 4), "cond": True}, - {"config": (64, 64, 64, 3, 8), "cond": True}, - {"config": (64, 64, 128, 3, 4), "cond": True}, - {"config": (64, 64, 128, 5, 4), "cond": True}, - {"config": (64, 128, 32, 3, 4), "cond": True}, - {"config": (64, 128, 32, 4, 8), "cond": True}, - {"config": (64, 128, 64, 4, 4), "cond": True}, - {"config": (64, 128, 128, 4, 4), "cond": True}, - {"config": (128, 64, 32, 2, 2), "cond": True}, - {"config": (128, 64, 32, 3, 4), "cond": True}, - {"config": (128, 64, 32, 4, 8), "cond": True}, - {"config": (128, 64, 64, 3, 8), "cond": True}, - {"config": (128, 64, 128, 4, 8), "cond": True}, - {"config": (128, 128, 32, 2, 8), "cond": True}, - {"config": (128, 128, 32, 3, 4), "cond": True}, - {"config": (128, 128, 32, 4, 4), "cond": True}, - {"config": (128, 128, 64, 3, 4), "cond": True}, - {"config": (128, 128, 64, 3, 8), "cond": True}, - {"config": (128, 128, 64, 5, 4), "cond": True}, - {"config": (128, 128, 64, 5, 8), "cond": True}, -] +# will be utilised on the target platform. The configs are as follows: +# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) +mm_kernel_configs = ( + [ + {"config": (16, 32, 16, 3, 2), "cond": True}, + {"config": (16, 32, 32, 4, 2), "cond": True}, + {"config": (16, 32, 32, 5, 2), "cond": True}, + {"config": (32, 32, 16, 1, 2), "cond": True}, + {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (64, 32, 128, 5, 4), "cond": True}, + {"config": (64, 64, 16, 2, 4), "cond": True}, + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + {"config": (64, 64, 128, 3, 4), "cond": True}, + {"config": (64, 64, 128, 5, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (64, 128, 64, 4, 4), "cond": True}, + {"config": (64, 128, 128, 4, 4), "cond": True}, + {"config": (128, 64, 32, 2, 2), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (128, 64, 64, 3, 8), "cond": True}, + {"config": (128, 64, 128, 4, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (128, 128, 32, 3, 4), "cond": True}, + {"config": (128, 128, 32, 4, 4), "cond": True}, + {"config": (128, 128, 64, 3, 4), "cond": True}, + {"config": (128, 128, 64, 3, 8), "cond": True}, + {"config": (128, 128, 64, 5, 4), "cond": True}, + {"config": (128, 128, 64, 5, 8), "cond": True}, + ] + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else [ + {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True} + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + ] +) int8_mm_kernel_configs = [ {"config": (64, 64, 32, 2, 4), "cond": True}, From 6c503f1dbbf9ef1bf99f19f0048c287f419df600 Mon Sep 17 00:00:00 2001 From: CaoE Date: Thu, 16 May 2024 18:57:16 -0700 Subject: [PATCH 056/116] save the reciprocal of weights for welford_reduce (#125148) Save the reciprocal of weights for welford_reduce to avoid redundant divisions for improving performance, and `weight_recps` will be inserted into the generated vec kernel. Generated code: - Before: ``` for(long x1=static_cast(0L); x1(1024L); x1+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x1 + (1024L*x0)), 16); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } ``` - After:: ``` static WeightRecp> weight_recps(64); for(long x1=static_cast(0L); x1(1024L); x1+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x1 + (1024L*x0)), 16); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &weight_recps); } ``` Performance: - Single core: Op | shape | eager/ms | inductor/ms | optimized inductor/ms -- | -- | -- | -- | -- layernorm | (56, 384, 1024) | 16.825 | 22.338 | 15.208 var | (56, 384, 1024) | 21.752 | 13.258 | 13.102 - 4 cores: Op | shape | eager/ms | inductor/ms | optimized inductor/ms -- | -- | -- | -- | -- layernorm | (56, 384, 1024) | 4.249 | 5.899 | 4.223 var | (56, 384, 1024) | 5.3152 | 3.278 | 2.163 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125148 Approved by: https://github.com/jgong5, https://github.com/peterbell10 --- torch/_inductor/codegen/cpp.py | 54 +++++++++++++++++++++--- torch/_inductor/codegen/cpp_prefix.h | 62 +++++++++++++++++----------- 2 files changed, 88 insertions(+), 28 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index c0aad2d27428..a0beddbf9bd3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -17,7 +17,7 @@ from torch._inductor import dependencies from torch._prims_common import is_float_dtype from torch.utils import _pytree as pytree -from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges @@ -1502,6 +1502,7 @@ def __init__(self, args, num_threads): self.local_reduction_init = IndentedBuffer() self.local_reduction_stores = IndentedBuffer() self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") self.preloads = IndentedBuffer() self.poststores = IndentedBuffer() @@ -1516,6 +1517,7 @@ def _gen_parallel_reduction_buffers( dtype, reduction_combine_fn=reduction_combine, reduction_init_fn=reduction_init, + welford_weight_reciprocal_vec_fn=None, ): if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: self.parallel_reduction_prefix.writeline( @@ -1552,6 +1554,15 @@ def _gen_parallel_reduction_buffers( "}", ], ) + if ( + reduction_type == "welford_reduce" + and welford_weight_reciprocal_vec_fn + and hasattr(self, "weight_recp_vec_range") + and "vec" in f"{acc_type}" + ): + self.local_reduction_init.writeline( + welford_weight_reciprocal_vec_fn(dtype, num_threads) + ) def get_reduction_var_pattern(self, line: str): return re.search("tmp_acc[0-9]+", line) @@ -1880,6 +1891,8 @@ def get_reduction_code_buffer(loops, buffer="prefix"): prefix = kernel.reduction_prefix if loop.parallel: prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix return prefix def gen_loops(loops: List[LoopLevel], in_reduction=False): @@ -2318,9 +2331,25 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.reduction_prefix.writeline( f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" ) - self.stores.writeline( - f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};" + # save the reciprocal of weights for welford reduce if using static shape + reduction_size = functools.reduce( + lambda x, y: x * y, self.ranges[self.reduction_depth :] ) + if reduction_type == "welford_reduce": + reduction_factor = ( + self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1 + ) + self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor) + self.non_parallel_reduction_prefix.writeline( + self.welford_weight_reciprocal_vec(dtype, None) + ) + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value, True)};" + ) + else: + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};" + ) self._gen_parallel_reduction_buffers( acc, acc_type, @@ -2334,6 +2363,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): dtype, reduction_combine_fn=self.reduction_combine_vec, reduction_init_fn=self.reduction_init_vec, + welford_weight_reciprocal_vec_fn=self.welford_weight_reciprocal_vec, ) tmpvar: Union[str, CSEVariable] if self.tiling_idx >= self.reduction_depth: @@ -2435,7 +2465,18 @@ def reduction_acc_type_vec(self, reduction_type, dtype): return vec_type - def reduction_combine_vec(self, reduction_type, var, next_value): + def welford_weight_reciprocal_vec(self, dtype, num_threads=None): + vec_num_range_thread = ( + CeilDiv(self.weight_recp_vec_range, num_threads) + if num_threads + else self.weight_recp_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + return f"static WeightRecp<{self._get_vec_type(dtype)}> weight_recps({vec_num_range_thread_expr});" + + def reduction_combine_vec( + self, reduction_type, var, next_value, use_weight_recps=False + ): if reduction_type == "max": return f"at::vec::maximum({var}, {next_value})" elif reduction_type == "min": @@ -2447,7 +2488,10 @@ def reduction_combine_vec(self, reduction_type, var, next_value): elif reduction_type == "xor_sum": return f"{var} ^ {next_value}" elif reduction_type == "welford_reduce": - return f"welford_combine({var}, {next_value})" + if use_weight_recps: + return f"welford_combine({var}, {next_value}, &weight_recps)" + else: + return f"welford_combine({var}, {next_value})" elif reduction_type == "welford_combine": if isinstance(next_value, tuple): # When reading a value from Inductor IR we have a tuple of variable names diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 096f716bc8da..7e3483ca9994 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -45,7 +45,7 @@ template struct Welford { T mean = T(0); T m2 = T(0); - T weight = T(0); + int64_t index = 0; }; @@ -58,41 +58,57 @@ struct IsVecType>: std::true_type {}; #endif template -Welford welford_combine(const Welford &a, const Welford &b) { - if constexpr (!IsVecType::value) { - if (a.weight == 0) { - return b; - } - if (b.weight == 0) { - return a; +struct WeightRecp { + using scalar_t = typename T::value_type; + int64_t N; + std::vector weight_recps; + WeightRecp(int64_t N) : N(N) { + weight_recps.reserve(N); + for (const auto i : c10::irange(N)) { + weight_recps.push_back( + scalar_t(static_cast(1) / static_cast(i + 1))); } } - auto delta = b.mean - a.mean; - auto new_weight = a.weight + b.weight; - auto wb_over_w = b.weight / new_weight; - if constexpr (IsVecType::value) { - // Guard against division by zero - wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0)); +}; + +template +Welford welford_combine(const Welford &a, const Welford &b) { + if (a.index == 0) { + return b; + } + if (b.index == 0) { + return a; } + auto delta = b.mean - a.mean; + auto new_index = a.index + b.index; + auto wb_over_w = T(b.index) / T(new_index); auto result = Welford{ a.mean + delta * wb_over_w, - a.m2 + b.m2 + delta * delta * a.weight * wb_over_w, - new_weight + a.m2 + b.m2 + delta * delta * T(a.index) * wb_over_w, + new_index, }; return result; } template -Welford welford_combine(const Welford &acc, T data) { +Welford welford_combine(const Welford &acc, T data, const WeightRecp* w=nullptr) { // Add a single data point + int64_t index = acc.index + 1; auto delta = data - acc.mean; - auto new_weight = acc.weight + T(1); - auto new_mean = acc.mean + delta / new_weight; + T new_mean; + if constexpr (!IsVecType::value) { + new_mean = acc.mean + delta / T(index); + } else { + new_mean = acc.mean + + ((w == nullptr || acc.index >= w->weight_recps.size()) + ? delta / T(index) + : delta * T(w->weight_recps[acc.index])); + } auto new_delta = data - new_mean; auto result = Welford{ new_mean, acc.m2 + delta * new_delta, - new_weight + index }; return result; } @@ -177,10 +193,11 @@ template Welford welford_vec_reduce_all(Welford> acc) { using Vec = at::vec::Vectorized; for (size_t n = 1; n < Vec::size(); n *= 2) { + auto index = acc.index; auto shuffled = Welford{ vec_shuffle_down(acc.mean, n), vec_shuffle_down(acc.m2, n), - vec_shuffle_down(acc.weight, n) + index, }; acc = welford_combine(acc, shuffled); } @@ -193,8 +210,7 @@ Welford welford_vec_reduce_all(Welford> acc.m2.store(array); result.m2 = array[0]; - acc.weight.store(array); - result.weight = array[0]; + result.index = acc.index; return result; } From 4ed93d6e0c5deb543ba5a3bd103728f00d39b1a6 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 17 May 2024 12:49:23 +0000 Subject: [PATCH 057/116] [Submodule] Remove zstd dependency (#126485) After searching in the codebase, it seems that zstd is not in use now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126485 Approved by: https://github.com/ezyang --- .gitmodules | 4 ---- CMakeLists.txt | 5 ----- cmake/Dependencies.cmake | 19 ------------------- setup.py | 3 --- third_party/zstd | 1 - 5 files changed, 32 deletions(-) delete mode 160000 third_party/zstd diff --git a/.gitmodules b/.gitmodules index c9b84a370167..4443eace838d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -50,10 +50,6 @@ ignore = dirty path = third_party/psimd url = https://github.com/Maratyszcza/psimd.git -[submodule "third_party/zstd"] - ignore = dirty - path = third_party/zstd - url = https://github.com/facebook/zstd.git [submodule "third_party/cpuinfo"] ignore = dirty path = third_party/cpuinfo diff --git a/CMakeLists.txt b/CMakeLists.txt index f7561d606cbd..1925bd8636f4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -283,7 +283,6 @@ option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) # option USE_XNNPACK: try to enable xnnpack by default. option(USE_XNNPACK "Use XNNPACK" ON) -option(USE_ZSTD "Use ZSTD" OFF) option(USE_ROCM_KERNEL_ASSERT "Use Kernel Assert for ROCm" OFF) # Ensure that an ITT build is the default for x86 CPUs cmake_dependent_option( @@ -413,7 +412,6 @@ option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." OFF) option(USE_SYSTEM_BENCHMARK "Use system-provided google benchmark." OFF) option(USE_SYSTEM_ONNX "Use system-provided onnx." OFF) option(USE_SYSTEM_XNNPACK "Use system-provided xnnpack." OFF) -option(USE_SYSTEM_ZSTD "Use system-provided zstd." OFF) option(USE_GOLD_LINKER "Use ld.gold to link" OFF) if(USE_SYSTEM_LIBS) set(USE_SYSTEM_CPUINFO ON) @@ -435,9 +433,6 @@ if(USE_SYSTEM_LIBS) if(USE_TBB) set(USE_SYSTEM_TBB ON) endif() - if(USE_ZSTD) - set(USE_SYSTEM_ZSTD ON) - endif() endif() # Used when building Caffe2 through setup.py diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a9a3aab8c510..4670ebadf2f5 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1426,25 +1426,6 @@ if(USE_NNAPI AND NOT ANDROID) caffe2_update_option(USE_NNAPI OFF) endif() -if(USE_ZSTD) - if(USE_SYSTEM_ZSTD) - find_package(zstd REQUIRED) - if(TARGET zstd::libzstd_shared) - set(ZSTD_TARGET zstd::libzstd_shared) - else() - set(ZSTD_TARGET zstd::libzstd_static) - endif() - list(APPEND Caffe2_DEPENDENCY_LIBS ${ZSTD_TARGET}) - get_property(ZSTD_INCLUDE_DIR TARGET ${ZSTD_TARGET} PROPERTY INTERFACE_INCLUDE_DIRECTORIES) - include_directories(SYSTEM ${ZSTD_INCLUDE_DIR}) - else() - list(APPEND Caffe2_DEPENDENCY_LIBS libzstd_static) - include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/zstd/lib) - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/zstd/build/cmake) - set_property(TARGET libzstd_static PROPERTY POSITION_INDEPENDENT_CODE ON) - endif() -endif() - # ---[ Onnx if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) if(EXISTS "${CAFFE2_CUSTOM_PROTOC_EXECUTABLE}") diff --git a/setup.py b/setup.py index 84f3d48c958e..93245d971be8 100644 --- a/setup.py +++ b/setup.py @@ -151,9 +151,6 @@ # USE_REDIS # Whether to use Redis for distributed workflows (Linux only) # -# USE_ZSTD -# Enables use of ZSTD, if the libraries are found -# # USE_ROCM_KERNEL_ASSERT=1 # Enable kernel assert in ROCm platform # diff --git a/third_party/zstd b/third_party/zstd deleted file mode 160000 index aec56a52fbab..000000000000 --- a/third_party/zstd +++ /dev/null @@ -1 +0,0 @@ -Subproject commit aec56a52fbab207fc639a1937d1e708a282edca8 From 55033ab43a09b140c1491fa2bbf6aea854955591 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 16 May 2024 22:13:35 -0700 Subject: [PATCH 058/116] Update ops handler documentation some more (#126480) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126480 Approved by: https://github.com/peterbell10 ghstack dependencies: #126292, #126299 --- torch/_inductor/codegen/common.py | 2 + torch/_inductor/ops_handler.py | 192 ++++++++++++++++++++++++++++-- 2 files changed, 183 insertions(+), 11 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 0d90e474d04b..8641f89a7d3a 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -601,6 +601,8 @@ class OverridesData: ) +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too pointwise_overrides_data: Dict[str, OverridesData] = dict( airy_ai=OverridesData( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 6da386709997..3df6749083c1 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -146,6 +146,12 @@ def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> """ ... + def identity(self, x: T) -> T: + """ + Returns x as is. This is used to trigger CSE. + """ + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operations are only available in a "kernel" context. Check # torch._inductor.codegen.common.CSEProxy for their typical implementation @@ -408,9 +414,6 @@ def to_int(self, x0: T) -> T: def trunc(self, x0: T) -> T: ... - def truncdiv(self, x0: T, x1: T) -> T: - ... - def ceil(self, x0: T) -> T: ... @@ -447,28 +450,195 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... - def floordiv(self, x0: T, x1: T) -> T: + def pow(self, x0: T, x1: T) -> T: ... - def truediv(self, x0: T, x1: T) -> T: + def and_(self, x0: T, x1: T) -> T: ... - def div(self, x0: T, x1: T) -> T: + def or_(self, x0: T, x1: T) -> T: ... - def mod(self, x0: T, x1: T) -> T: + def xor(self, x0: T, x1: T) -> T: ... - def pow(self, x0: T, x1: T) -> T: + # These are metaprogrammed by MockHandler._init_cls + def lshift(self, x0: T, x1: T) -> T: ... - def and_(self, x0: T, x1: T) -> T: + def rshift(self, x0: T, x1: T) -> T: ... - def or_(self, x0: T, x1: T) -> T: + def getitem(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol ... - def xor(self, x0: T, x1: T) -> T: + def matmul(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol + ... + + def invert(self, x0: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These are "special" operators. These only exist if the target + # language actually supports the operator. Keep this in sync with + # pointwise_overrides_data. + + def airy_ai(self, x: T) -> T: + ... + + def bessel_j0(self, x: T) -> T: + ... + + def bessel_j1(self, x: T) -> T: + ... + + def bessel_y0(self, x: T) -> T: + ... + + def bessel_y1(self, x: T) -> T: + ... + + def digamma(self, x: T) -> T: + ... + + def erfcx(self, x: T) -> T: + ... + + def fma(self, x: T, y: T, z: T) -> T: + ... + + def igamma(self, x: T, y: T) -> T: + ... + + def igammac(self, x: T, y: T) -> T: + ... + + def gammainc(self, x: T, y: T) -> T: + ... + + def gammaincc(self, x: T, y: T) -> T: + ... + + def i0(self, x: T) -> T: + ... + + def i0e(self, x: T) -> T: + ... + + def i1(self, x: T) -> T: + ... + + def i1e(self, x: T) -> T: + ... + + def log_ndtr(self, x: T) -> T: + ... + + def modified_bessel_i0(self, x: T) -> T: + ... + + def modified_bessel_i1(self, x: T) -> T: + ... + + def modified_bessel_k0(self, x: T) -> T: + ... + + def modified_bessel_k1(self, x: T) -> T: + ... + + def ndtr(self, x: T) -> T: + ... + + def ndtri(self, x: T) -> T: + ... + + def polygamma(self, x: T, y: T) -> T: + ... + + def scaled_modified_bessel_k0(self, x: T) -> T: + ... + + def scaled_modified_bessel_k1(self, x: T) -> T: + ... + + def spherical_bessel_j0(self, x: T) -> T: + ... + + def zeta(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_t(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_u(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_v(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_w(self, x: T, y: T) -> T: + ... + + def legendre_polynomial_p(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: + ... + + def hermite_polynomial_h(self, x: T, y: T) -> T: + ... + + def hermite_polynomial_he(self, x: T, y: T) -> T: + ... + + def laguerre_polynomial_l(self, x: T, y: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operators are a bit special, because they are conventionally + # natively supported in both Python and C, but the semantics differ so + # care must be taken + + def truncdiv(self, x0: T, x1: T) -> T: + """C-style trunc division between integers only. Computes the true + division of two numbers and rounds the result to zero. + """ + ... + + def floordiv(self, x0: T, x1: T) -> T: + """Python-style floor division between integers only. Computes the + true division of two numbers and floors the result. + """ + ... + + def truediv(self, x0: T, x1: T) -> T: + """True division between floats. Integer inputs are NOT valid: to do + Python style (int, int) -> float division, promote the inputs to float + first.""" + ... + + def div(self, x0: T, x1: T) -> T: + """TODO: to be removed. This renders as / no matter what the backend is + which is incoherent.""" + ... + + def mod(self, x0: T, x1: T) -> T: + """C-style modulus, take sign from LHS (x0).""" + ... + + def remainder(self, x0: T, x1: T) -> T: + """Python-style modulus, take sign from RHS (x1).""" ... # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 3f289063117673650db868c978bf3cb8125a22dc Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 16 May 2024 18:57:20 -0700 Subject: [PATCH 059/116] [FSDP2] Fixed 2D clip grad norm test (#126497) This fixes https://github.com/pytorch/pytorch/issues/126484. We change from transformer to MLP stack since transformer seems to introduce slight numeric differences when using TP. We include a sequence parallel layer norm module in the MLP stack to exercise `(S(0), R)` placement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126497 Approved by: https://github.com/weifengpy, https://github.com/wz337 --- .ci/pytorch/test.sh | 1 + .../fsdp/test_fully_shard_clip_grad_norm_.py | 36 ++++++----- torch/testing/_internal/common_fsdp.py | 61 ++++++++++++------- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9483bb630d4e..1953b314ec83 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -326,6 +326,7 @@ test_inductor_distributed() { python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_frozen.py --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose + python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py -k test_clip_grad_norm_2d --verbose python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index 9139b62f1367..4e1e897a11be 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -12,7 +12,7 @@ from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_fsdp import FSDPTest, MLPStack from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, @@ -30,12 +30,12 @@ def _test_clip_grad_norm( ref_optim: torch.optim.Optimizer, model: nn.Module, optim: torch.optim.Optimizer, + inp: torch.Tensor, dp_mesh: Optional[DeviceMesh] = None, ): vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type) dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,)) torch.manual_seed(42 + dp_mesh.get_local_rank() + 1) - inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda") for iter_idx in range(10): ref_optim.zero_grad() ref_model(inp).sum().backward() @@ -53,11 +53,11 @@ def _test_clip_grad_norm( continue self.assertEqual(ref_grad, param.grad.full_tensor()) - # Check that all gradients have norm greater than the max norm - # before clipping to ensure the clipping is not vacuous - self.assertTrue(all(vector_norm_fn(g).item() > max_norm for g in ref_grads)) + # Check that at least one gradient has norm greater than the max + # norm before clipping to ensure the clipping is not vacuous + self.assertTrue(any(vector_norm_fn(g).item() > max_norm for g in ref_grads)) self.assertTrue( - all(vector_norm_fn(g).item() > max_norm for g in local_grads) + any(vector_norm_fn(g).item() > max_norm for g in local_grads) ) # Check gradient norm clipping via total norm and individual @@ -111,7 +111,10 @@ def test_clip_grad_norm_1d(self): fully_shard(module) fully_shard(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) - self._test_clip_grad_norm(1, norm_type, ref_model, ref_optim, model, optim) + inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda") + self._test_clip_grad_norm( + 1, norm_type, ref_model, ref_optim, model, optim, inp + ) class TestClipGradNormWorldSize4(_TestClipGradNormBase): @@ -130,20 +133,23 @@ def test_clip_grad_norm_2d(self): ) dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"] torch.manual_seed(42) - model_args = ModelArgs(dropout_p=0.0) - model = Transformer(model_args) + # Test using an MLP stack, not a transformer, since the transformer + # has some more significant numeric differences from the TP + model = MLPStack(16, with_seq_parallel=True) ref_model = replicate( copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group() ) ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) - model = Transformer.parallelize(model, tp_mesh, use_seq_parallel=True) - for module in model.modules(): - if isinstance(module, TransformerBlock): - fully_shard(module, mesh=dp_mesh) - fully_shard(model, mesh=dp_mesh) + model.parallelize( + tp_mesh, + dp_mesh, + use_activation_checkpointing=False, + reshard_after_forward=True, + ) optim = torch.optim.Adam(model.parameters(), lr=1e-2) + inp = torch.randn(2, 16, device="cuda") self._test_clip_grad_norm( - 1, norm_type, ref_model, ref_optim, model, optim, dp_mesh + 0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh ) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 283982b2ba44..94b6a68f931c 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -10,7 +10,17 @@ from copy import deepcopy from enum import auto, Enum from functools import partial, wraps -from typing import Any, Callable, Dict, no_type_check, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + no_type_check, + Optional, + Tuple, + Type, + Union, +) from unittest import mock import torch @@ -39,6 +49,7 @@ ColwiseParallel, parallelize_module, RowwiseParallel, + SequenceParallel, ) from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP @@ -865,15 +876,17 @@ def reset_parameters(self): class MLPStack(nn.Sequential): - def __init__(self, mlp_dim: int): - modules = [ - nn.LayerNorm(mlp_dim, bias=False), + def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False): + modules: List[nn.Module] = [ # Use multiplier of 3 to exercise uneven case MLP(mlp_dim, dim_multiplier=3), MLP(mlp_dim), MLP(mlp_dim, dim_multiplier=3), ] + if with_seq_parallel: + modules.append(nn.LayerNorm(mlp_dim, bias=False)) super().__init__(*modules) + self.with_seq_parallel = with_seq_parallel def parallelize( self, @@ -882,25 +895,29 @@ def parallelize( use_activation_checkpointing: bool, reshard_after_forward: bool, ) -> "MLPStack": - parallelize_module( - self, - device_mesh=tp_mesh, - # Leave the layer norm as implicitly replicated - parallelize_plan={ - # Pass `use_local_output=False` to keep as DTensor to preserve - # uneven activation dims - "1.in_proj": ColwiseParallel(use_local_output=False), - "1.out_proj": RowwiseParallel(use_local_output=False), - "2.in_proj": ColwiseParallel(use_local_output=False), - "2.out_proj": RowwiseParallel(use_local_output=False), - "3.in_proj": ColwiseParallel(use_local_output=False), - "3.out_proj": RowwiseParallel(), - }, - ) - for mlp in self: + parallelize_plan = { + # Pass `use_local_output=False` to keep as DTensor to preserve + # uneven activation dims + "0.in_proj": ColwiseParallel(use_local_output=False), + "0.out_proj": RowwiseParallel(use_local_output=False), + "1.in_proj": ColwiseParallel(use_local_output=False), + "1.out_proj": RowwiseParallel(use_local_output=False), + "2.in_proj": ColwiseParallel(use_local_output=False), + "2.out_proj": RowwiseParallel(output_layouts=Shard(1)) + if self.with_seq_parallel + else RowwiseParallel(), + } + if self.with_seq_parallel: + parallelize_plan["3"] = SequenceParallel(sequence_dim=1) + parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan) + for module in self: + if isinstance(module, nn.LayerNorm): + continue if use_activation_checkpointing: - checkpoint(mlp) - fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward) + checkpoint(module) + fully_shard( + module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward + ) fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward) return self From ab307a8992a6d67457ac7698cb7d75ccd5579cf3 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 16 May 2024 15:44:02 -0700 Subject: [PATCH 060/116] Default to env variable instead of config value for precompile parallelism (#126333) Previously, we would default to the config `compile_threads`. That controls the number of forks we use for async compile. It defaults to 1 in fbcode because fork() has known issues with safety. In precompilation, we are using threads, which have no safety issues and should strictly improve compile time. there isn't really any reason to reduce except for testing, and it doesn't make sense to share the same value as for determining forks. This changes so we default it to use as many threads as needed unless the env variable is set. Differential Revision: [D57473023](https://our.internmc.facebook.com/intern/diff/D57473023) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126333 Approved by: https://github.com/nmacchioni --- torch/_inductor/select_algorithm.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index eadb3d8159ab..4e341f358c45 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -6,6 +6,7 @@ import math import operator +import os import sys import textwrap import time @@ -921,6 +922,13 @@ class NoValidChoicesError(RuntimeError): pass +@functools.lru_cache(None) +def get_env_num_workers() -> Optional[int]: + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + return None + + class AlgorithmSelectorCache(PersistentCache): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -985,11 +993,10 @@ def no_op(*args, **kwargs): or precompilation_timeout_seconds <= 0 ): return no_op - num_workers = min( - config.compile_threads, - torch.get_num_threads(), - len(choices), - ) + + env_workers = get_env_num_workers() + num_workers = env_workers if env_workers is not None else (len(choices)) + if num_workers <= 0: return no_op From 078e530446a4399dc01559afcccbb9f3113b59c1 Mon Sep 17 00:00:00 2001 From: James Wu Date: Thu, 16 May 2024 12:21:15 -0700 Subject: [PATCH 061/116] Delete refactored function, move changes over (#126407) Oops, in https://github.com/pytorch/pytorch/pull/125610 I moved this function to runtime_wrappers.py, but forgot to delete the old one. https://github.com/pytorch/pytorch/pull/126234 then modified it which would do nothing, so I'm applying the change correctly now and deleting the function as I intended. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126407 Approved by: https://github.com/eellison --- .../jit_compile_runtime_wrappers.py | 20 ------------------ .../_aot_autograd/runtime_wrappers.py | 21 +++++++++++-------- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 320a899e6b64..0b6e02da80d2 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -83,26 +83,6 @@ def _force_contiguous(x): return x -def _compute_output_meta_with_inductor_strides(fw_module, fwd_output_strides): - out = [n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])] - # will only be set for inductor - if not fwd_output_strides: - return out - - from torch.fx.experimental.symbolic_shapes import statically_known_true - - for i in range(len(out)): - if not isinstance(out[i], Tensor): - continue - if all( - statically_known_true(s1 == s2) - for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) - ): - continue - out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) - return out - - # See Note [Tangents must be contiguous, Part 2] def coerce_runtime_tangent(x, metadata_tensor): if not isinstance(x, torch.Tensor): diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index a1fb2980ed1d..c1b9a3b29f2e 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -486,15 +486,18 @@ def _compute_output_meta_with_inductor_strides(self): fwd_output_strides = self.fwd_output_strides if not fwd_output_strides: return out - with TracingContext.get().fake_mode.shape_env.suppress_guards(): - for i in range(len(out)): - if not isinstance(out[i], Tensor): - continue - if all( - s1 == s2 for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) - ): - continue - out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) + + from torch.fx.experimental.symbolic_shapes import statically_known_true + + for i in range(len(out)): + if not isinstance(out[i], Tensor): + continue + if all( + statically_known_true(s1 == s2) + for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) + ): + continue + out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) return out # To be called post compile From 7e166e805740766268b4297621961a4afdd6c41b Mon Sep 17 00:00:00 2001 From: David Chiu Date: Fri, 17 May 2024 15:46:36 +0000 Subject: [PATCH 062/116] [optim] Fix: wrong ASGD implementation (#126375) This PR is based on #125440, additionally merging the latest main branch and fixing the lint failures from #126361. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126375 Approved by: https://github.com/janeyx99 --- test/test_optim.py | 10 +++++- torch/_meta_registrations.py | 10 ++++++ torch/optim/asgd.py | 37 +++++++------------- torch/testing/_internal/common_optimizers.py | 8 +++++ 4 files changed, 40 insertions(+), 25 deletions(-) diff --git a/test/test_optim.py b/test/test_optim.py index 717e89224672..7fa612e89da0 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -604,8 +604,16 @@ def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, as for input, model, optimizer in zip(inputs, models, optimizers): optimizer.zero_grad() + if i == 3: + # Freeze a layer to test if the step of this layer in 'fused' or 'foreach' + # is same as the step in 'forloop'. + model[2].requires_grad_(False) + if i == 5: + # Unfreeze the layer after 2 iters. + model[2].requires_grad_(True) + # Test that step behaves as expected (a no-op) when grads are set to None - if i != 3: + if i != 2: output = model(input) loss = output.sum() loss.backward() diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index a4edd839c651..93e45bfb1d84 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -19,6 +19,7 @@ corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, IntLike, make_contiguous_strides_for, Number, @@ -3286,6 +3287,15 @@ def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs): return +@register_meta([aten._foreach_pow_.Scalar]) +def meta__foreach_pow__scalar(self, exponent): + torch._check( + isinstance(exponent, FloatLike), + lambda: f"exponent must be a float but got {type(exponent)}", + ) + return + + @register_meta([aten._foreach_pow.ScalarAndTensor]) def meta__foreach_pow_scalar_and_tensor(self, exponent): # Only foreach_pow has a ScalarAndTensor method and needs special diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index a87aadc81803..f53f8b427e9f 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -22,13 +22,6 @@ __all__ = ["ASGD", "asgd"] -def _to_tensor(x, device=None): - if not isinstance(x, torch.Tensor): - return torch.tensor(x, device=device) - - return x - - class ASGD(Optimizer): def __init__( self, @@ -264,9 +257,9 @@ def _single_tensor_asgd( mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) else: step = _get_value(step_t) - new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha)) + new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha)) eta.copy_(new_eta) - new_mu = _to_tensor(1 / max(1, step - t0)) + new_mu = torch.as_tensor(1 / max(1, step - t0)) mu.copy_(new_mu) @@ -381,27 +374,23 @@ def _multi_tensor_asgd( torch._foreach_copy_(grouped_mus, new_mus) del new_mus - # update eta = lr / (1 + lambd * lr * step^alpha) - new_etas = torch._foreach_pow(grouped_state_steps, alpha) - torch._foreach_mul_(new_etas, lambd) + # update eta = lr / ((1 + lambd * lr * step)^alpha) + new_etas = torch._foreach_mul(grouped_state_steps, lambd) torch._foreach_mul_(new_etas, lr) torch._foreach_add_(new_etas, 1) + torch._foreach_pow_(new_etas, alpha) torch._foreach_reciprocal_(new_etas) torch._foreach_mul_(new_etas, lr) torch._foreach_copy_(grouped_etas, new_etas) else: - step = grouped_state_steps[0].item() - new_etas = [] - new_mus = [] - - for i in range(len(grouped_mus)): - new_eta = _to_tensor( - lr / (1 + lambd * lr * step**alpha), device=device - ) - new_etas.append(new_eta) - new_mu = _to_tensor(1 / max(1, step - t0), device=device) - new_mus.append(new_mu) - + new_etas = [ + torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device) + for step in grouped_state_steps + ] + new_mus = [ + torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device) + for step in grouped_state_steps + ] torch._foreach_copy_(grouped_etas, new_etas) torch._foreach_copy_(grouped_mus, new_mus) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 5a66923373f7..c81efb093cd8 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -590,6 +590,7 @@ def optim_inputs_func_asgd(device, dtype=None): ] return [ OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"), OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"), OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"), OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), @@ -1450,6 +1451,13 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_defaults_changed_to_foreach", ), + DecorateInfo( + unittest.skip( + "ASGD internally changes the weights even with zero grad" + ), + "TestOptimRenewed", + "test_step_is_noop_for_zero_grads", + ), ), ), OptimizerInfo( From 402170b22fb8331d9f15ba236d22e2e4de03b9c8 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 16 May 2024 17:08:51 -0300 Subject: [PATCH 063/116] Early return in _recursive_build if obj is a Tensor (#125639) Fix issue #125551 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125639 Approved by: https://github.com/ezyang --- test/dynamo/test_unspec.py | 30 +++++++++++++++++++ ...l_regression_mechanism_functional_call_cpu | 0 ...l_regression_mechanism_make_functional_cpu | 0 ..._regression_mechanism_functional_call_cuda | 0 ..._regression_mechanism_make_functional_cuda | 0 torch/_refs/__init__.py | 18 +++++++++-- 6 files changed, 45 insertions(+), 3 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_functional_call_cpu delete mode 100644 test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_make_functional_cpu delete mode 100644 test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_functional_call_cuda delete mode 100644 test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_make_functional_cuda diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 9b49f5ff8bb6..cb47a0b728a3 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -394,6 +394,36 @@ def f4(v): self.assertEqual(f3(r), optimize(f3)(r)) self.assertEqual(f4(r), optimize(f4)(r)) + def test_to_tensor(self): + def f1(): + a = np.random.uniform(low=-1, high=1, size=(20, 1)) + return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu") + + def f2(): + a = torch.tensor([[[123]]]) + return torch.tensor([a, a]) + + def f3(): + a = torch.tensor(123) + return torch.tensor([a, a]) + + def f4(): + a = torch.tensor(123) + b = torch.tensor([[[456]]]) + return torch.tensor([a, b]) + + def f5(): + a = np.array([1, 2]) + return torch.tensor([a, a]) + + optimize = torch.compile(backend="aot_eager", fullgraph=True) + + self.assertEqual(f1().shape, optimize(f1)().shape) + self.assertEqual(f2(), optimize(f2)()) + self.assertEqual(f3(), optimize(f3)()) + self.assertEqual(f4(), optimize(f4)()) + self.assertEqual(f5(), optimize(f5)()) + def test_sym_int_conversion(self): def f(x): y = x.size(0) diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_functional_call_cpu b/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_functional_call_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_make_functional_cpu b/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_make_functional_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_functional_call_cuda b/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_functional_call_cuda deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_make_functional_cuda b/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_make_functional_cuda deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index ac6a60d0078c..6772a4dff4a7 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3830,7 +3830,7 @@ def _check_stack_inputs(tensors: TensorSequenceType) -> None: entry_shape = tensors[0].shape for i in range(1, len(tensors)): assert tensors[i].shape == entry_shape, ( - f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0" + f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 " f"and {tensors[i].shape} at entry {i}" ) @@ -6358,12 +6358,24 @@ def _infer_scalar_type(obj): # Analogous to recursive_store # xref: recursive_store in torch/csrc/utils/tensor_new.cpp -def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType): - if isinstance(obj, Tensor) and obj.ndim <= 1: +def _recursive_build( + scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType] +): + if isinstance(obj, Tensor) and obj.numel() == 1: return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(()) + elif isinstance(obj, Tensor): + # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode + # >>> torch.tensor([torch.randn(2)]) + # ValueError: only one element tensors can be converted to Python scalars + # + # But it is possible with a NumPy array + # >>> torch.tensor([np.random.uniform(size=(2,))]).shape + # torch.Size([1, 2]) + return obj.detach().to(dtype=scalarType, device="cpu", copy=True) elif isinstance(obj, Number): return torch.scalar_tensor(obj, dtype=scalarType) + # seq can be a list of tensors seq = obj return torch.stack([_recursive_build(scalarType, item) for item in seq]) From 81277baa0ca4c24f5e60c8348984802638436cb4 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Fri, 17 May 2024 16:31:01 +0000 Subject: [PATCH 064/116] Remove removed ruff rule TRY200 (#126256) My TOML linter is complaining that "TRY200" is not acceptable for the `tool.ruff.lint` schema. From the ruff docs: https://docs.astral.sh/ruff/rules/reraise-no-cause/ > This rule has been removed and its documentation is only available for historical reasons. > > This rule is identical to [B904](https://docs.astral.sh/ruff/rules/raise-without-from-inside-except/) which should be used instead. and we are currently explicitly ignoring B904. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126256 Approved by: https://github.com/Skylion007 --- benchmarks/dynamo/common.py | 2 +- pyproject.toml | 3 +-- test/dynamo/test_repros.py | 2 +- test/inductor/test_coordinate_descent_tuner.py | 2 +- test/inductor/test_cuda_repro.py | 2 +- test/inductor/test_triton_heuristics.py | 2 +- test/test_autograd.py | 4 ++-- test/test_fx.py | 2 +- test/torch_np/numpy_tests/core/test_indexing.py | 4 ++-- test/torch_np/numpy_tests/core/test_multiarray.py | 4 ++-- test/torch_np/numpy_tests/core/test_scalar_methods.py | 2 +- test/torch_np/numpy_tests/lib/test_function_base.py | 2 +- test/torch_np/numpy_tests/linalg/test_linalg.py | 2 +- torch/_dynamo/debug_utils.py | 2 +- torch/_dynamo/eval_frame.py | 2 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/utils.py | 2 +- torch/_dynamo/variables/sdpa.py | 4 ++-- torch/_dynamo/variables/tensor.py | 2 +- torch/_export/utils.py | 2 +- torch/_inductor/__init__.py | 2 +- torch/_inductor/select_algorithm.py | 6 +++--- torch/_numpy/_util.py | 2 +- torch/_numpy/linalg.py | 2 +- torch/_numpy/testing/utils.py | 10 +++++----- torch/_refs/__init__.py | 2 +- torch/distributed/distributed_c10d.py | 2 +- torch/distributed/pipelining/_IR.py | 2 +- torch/export/_trace.py | 6 +++--- torch/export/_unlift.py | 2 +- torch/export/dynamic_shapes.py | 2 +- torch/export/exported_program.py | 2 +- torch/fx/graph_module.py | 2 +- torch/fx/passes/net_min_base.py | 2 +- torch/nn/utils/_named_member_accessor.py | 2 +- torch/testing/_internal/opinfo/definitions/sparse.py | 4 ++-- torch/utils/_sympy/value_ranges.py | 4 ++-- torch/utils/_traceback.py | 2 +- 38 files changed, 52 insertions(+), 53 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 2b877e43447e..6ea7a31a3915 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1252,7 +1252,7 @@ def wrapper(self, *args, **kwargs) -> Any: ) time.sleep(wait) else: - raise RuntimeError( # noqa: TRY200 + raise RuntimeError( # noqa: B904 f"Failed to load model '{args}' with following error(s): {str(e)}." ) diff --git a/pyproject.toml b/pyproject.toml index 3ff4b94447f9..07f075082097 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ ignore = [ "B019", "B023", "B028", # No explicit `stacklevel` keyword argument found - "B904", # Migrate from TRY200 "E402", "C408", # C408 ignored because we like the dict keyword argument syntax "E501", # E501 is not flexible enough, we're using B950 instead @@ -90,6 +89,7 @@ ignore = [ ] select = [ "B", + "B904", # Re-raised error without specifying the cause via the from keyword "C4", "G", "E", @@ -133,7 +133,6 @@ select = [ "RUF017", "RUF018", # no assignment in assert "TRY002", # ban vanilla raise (todo fix NOQAs) - "TRY200", # TODO: migrate from deprecated alias "TRY302", "TRY401", # verbose-log-message "UP", diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index ff229c06432f..510b32be905b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -400,7 +400,7 @@ def _iter_ex(self, resolve: bool) -> Iterator[Any]: try: return ListConfig.ListIterator(self, resolve) except Exception: - raise AssertionError + raise AssertionError from None def __init__(self): self._content = [ diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py index 70618c06e9ec..fdd3abb14392 100644 --- a/test/inductor/test_coordinate_descent_tuner.py +++ b/test/inductor/test_coordinate_descent_tuner.py @@ -16,7 +16,7 @@ except ImportError: if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires triton") # noqa: TRY200 + raise unittest.SkipTest("requires triton") # noqa: B904 from torch._inductor import config from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index db02d1931009..f303330bc114 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -32,7 +32,7 @@ import triton from triton import language as tl except ImportError: - raise unittest.SkipTest("requires triton") # noqa: TRY200 + raise unittest.SkipTest("requires triton") # noqa: B904 try: from . import test_torchinductor diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 6375512cc128..d8c74c0a3841 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -14,7 +14,7 @@ except ImportError: if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires triton") # noqa: TRY200 + raise unittest.SkipTest("requires triton") # noqa: B904 from torch._inductor import config from torch._inductor.runtime.hints import TRITON_MAX_BLOCK diff --git a/test/test_autograd.py b/test/test_autograd.py index e20e8b18ebae..3ae37e18e7a3 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2612,7 +2612,7 @@ def coro_no_grad(n=10): except UnrecoverableException: self.assertFalse(torch.is_grad_enabled()) - raise SecondaryException + raise SecondaryException from None @torch.enable_grad() def coro_enable_grad(n=10): @@ -2624,7 +2624,7 @@ def coro_enable_grad(n=10): except UnrecoverableException: self.assertTrue(torch.is_grad_enabled()) - raise SecondaryException + raise SecondaryException from None with torch.enable_grad(): coro = coro_no_grad() diff --git a/test/test_fx.py b/test/test_fx.py index eadcd750aede..a58abb906d89 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4078,7 +4078,7 @@ def test_function_back_compat(self): f"unintended, please revert it. If it was intended, check with the FX " \ f"team to ensure that the proper deprecation protocols have been followed " \ f"and subsequently --accept the change." - raise AssertionError(msg) # noqa: TRY200 + raise AssertionError(msg) # noqa: B904 def test_class_member_back_compat(self): """ diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index f0d9023ddbff..77844e77a6e0 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -740,7 +740,7 @@ def _get_multi_index(self, arr, indices): try: indx = np.array(indx, dtype=np.intp) except ValueError: - raise IndexError + raise IndexError from None in_indices[i] = indx elif indx.dtype.kind != "b" and indx.dtype.kind != "i": raise IndexError( @@ -902,7 +902,7 @@ def _get_multi_index(self, arr, indices): arr = arr.reshape(arr.shape[:ax] + mi.shape + arr.shape[ax + 1 :]) except ValueError: # too many dimensions, probably - raise IndexError + raise IndexError from None ax += mi.ndim continue diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index 38e2df73d5b8..bf9aab8ebcee 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -409,7 +409,7 @@ def make_array(size, offset, strides): try: r = np.ndarray([size], dtype=int, buffer=x, offset=offset * x.itemsize) except Exception as e: - raise RuntimeError(e) # noqa: TRY200 + raise RuntimeError(e) # noqa: B904 r.strides = strides = strides * x.itemsize return r @@ -6304,7 +6304,7 @@ def test_flat_element_deletion(self): except TypeError: pass except Exception: - raise AssertionError + raise AssertionError from None class TestConversion(TestCase): diff --git a/test/torch_np/numpy_tests/core/test_scalar_methods.py b/test/torch_np/numpy_tests/core/test_scalar_methods.py index addc550ed337..2e763c6636a8 100644 --- a/test/torch_np/numpy_tests/core/test_scalar_methods.py +++ b/test/torch_np/numpy_tests/core/test_scalar_methods.py @@ -132,7 +132,7 @@ def test_roundtrip(self, ftype, frac_vals, exp_vals): df = np.longdouble(d) except (OverflowError, RuntimeWarning): # the values may not fit in any float type - raise SkipTest("longdouble too small on this platform") # noqa: TRY200 + raise SkipTest("longdouble too small on this platform") # noqa: B904 assert_equal(nf / df, f, f"{n}/{d}") diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index fa1168840635..d0eda87b0108 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -1435,7 +1435,7 @@ def test_keywords_no_func_code(self): try: vectorize(random.randrange) # Should succeed except Exception: - raise AssertionError # noqa: TRY200 + raise AssertionError # noqa: B904 def test_keywords2_ticket_2100(self): # Test kwarg support: enhancement ticket 2100 diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py index 616c7b95f5c9..3a5c21745e24 100644 --- a/test/torch_np/numpy_tests/linalg/test_linalg.py +++ b/test/torch_np/numpy_tests/linalg/test_linalg.py @@ -1958,7 +1958,7 @@ def test_xerbla_override(self): pid = os.fork() except (OSError, AttributeError): # fork failed, or not running on POSIX - raise SkipTest("Not POSIX or fork failed.") # noqa: TRY200 + raise SkipTest("Not POSIX or fork failed.") # noqa: B904 if pid == 0: # child; close i/o file handles diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 67dd492fe851..4b4b37a34da9 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -360,7 +360,7 @@ def same_two_models( fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) except Exception: if require_fp64: - raise RuntimeError("Could not generate fp64 outputs") # noqa: TRY200 + raise RuntimeError("Could not generate fp64 outputs") # noqa: B904 log.warning("Could not generate fp64 outputs") try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 391bdfcf0202..db35c0f631e8 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1387,7 +1387,7 @@ def graph_with_interpreter(*args): )(*example_fake_inputs) except CondOpArgsMismatchError as e: # Wrap the internal error to the user-facing error - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.DYNAMIC_CONTROL_FLOW, str(e), case_name="cond_operands", diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 4b4d6d3de675..093809703405 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2502,7 +2502,7 @@ def inline_call_( sub_locals, closure_cells = func.bind_args(parent, args, kwargs) except TypeError as e: # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info - raise ArgsMismatchError( # noqa: TRY200 + raise ArgsMismatchError( # noqa: B904 "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( reason=str(e), func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 21f3da2f61e6..fcfbde1a6a79 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1751,7 +1751,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): elif isinstance( cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode ): - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.CONSTRAINT_VIOLATION, "Tried to use data-dependent value in the subsequent computation. " "This can happen when we encounter unbounded dynamic value that is unknown during tracing time. " diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 0a6af76690df..c5b0d9f586c8 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -65,9 +65,9 @@ def var_getattr(self, tx, name: str) -> VariableTracker: getattr_static(torch._C._SDPAParams, name) except AttributeError: # Using raise from is too verbose here - raise Unsupported( # noqa: TRY200 + raise Unsupported( f"Unsupported torch._C._SDPAParams attribute {name}" - ) + ) from None proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) if self.source is not None: diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 361527050b16..a1adbcf614bc 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1019,7 +1019,7 @@ def evaluate_expr(self, output_graph=None): try: return guard_scalar(self.sym_num) except GuardOnDataDependentSymNode as e: - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.ANTI_PATTERN, f"Consider annotating your code using torch._check*(). {str(e)}", case_name="constrain_as_size_example", diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 59648ccadab2..19fc4e9bdc4d 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -118,7 +118,7 @@ def get_keystr(key_path: KeyPath) -> str: sympy.Eq(node_dim.node.expr, arg_dim), symbol ) if solution is None: - raise RuntimeError( # noqa: TRY200 + raise RuntimeError( # noqa: B904 f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " f"of the form {node_dim.node.expr}, where {symbol} is an integer" ) diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 0516fc55e074..0d7cd8cece49 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -81,7 +81,7 @@ def aot_compile( ) if in_spec is not None and received_spec != in_spec: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 "Trying to flatten user inputs with exported input tree spec: \n" f"{in_spec}\n" "but actually got inputs with tree spec of: \n" diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 4e341f358c45..4940f53b1e79 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1088,7 +1088,7 @@ def wait_on_futures(): else: raise e except ImportError: - raise e + raise e from None executor.shutdown(wait=True) @@ -1284,7 +1284,7 @@ def benchmark_in_current_process(choices): ) timing = float("inf") except AssertionError as e: - raise AssertionError( # noqa: TRY200 + raise AssertionError( # noqa: B904 f"Incorrect result from choice {choice}\n\n{e}" ) except Exception as e: @@ -1297,7 +1297,7 @@ def benchmark_in_current_process(choices): else: raise e except ImportError: - raise e + raise e from None timings[choice] = timing diff --git a/torch/_numpy/_util.py b/torch/_numpy/_util.py index ff219d930731..477d3d44671a 100644 --- a/torch/_numpy/_util.py +++ b/torch/_numpy/_util.py @@ -178,7 +178,7 @@ def _try_convert_to_tensor(obj): tensor = torch.as_tensor(obj) except Exception as e: mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}." - raise NotImplementedError(mesg) # noqa: TRY200 + raise NotImplementedError(mesg) # noqa: B904 return tensor diff --git a/torch/_numpy/linalg.py b/torch/_numpy/linalg.py index 2232419db1b2..093851142dbc 100644 --- a/torch/_numpy/linalg.py +++ b/torch/_numpy/linalg.py @@ -38,7 +38,7 @@ def wrapped(*args, **kwds): try: return func(*args, **kwds) except torch._C._LinAlgError as e: - raise LinAlgError(*e.args) # noqa: TRY200 + raise LinAlgError(*e.args) # noqa: B904 return wrapped diff --git a/torch/_numpy/testing/utils.py b/torch/_numpy/testing/utils.py index cd3d3407f582..f757860e1218 100644 --- a/torch/_numpy/testing/utils.py +++ b/torch/_numpy/testing/utils.py @@ -247,7 +247,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True): assert_equal(actualr, desiredr) assert_equal(actuali, desiredi) except AssertionError: - raise AssertionError(msg) # noqa: TRY200 + raise AssertionError(msg) # noqa: B904 # isscalar test to check cases such as [np.nan] != np.nan if isscalar(desired) != isscalar(actual): @@ -279,7 +279,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True): except (DeprecationWarning, FutureWarning) as e: # this handles the case when the two types are not even comparable if "elementwise == comparison" in e.args[0]: - raise AssertionError(msg) # noqa: TRY200 + raise AssertionError(msg) # noqa: B904 else: raise @@ -426,7 +426,7 @@ def _build_err_msg(): assert_almost_equal(actualr, desiredr, decimal=decimal) assert_almost_equal(actuali, desiredi, decimal=decimal) except AssertionError: - raise AssertionError(_build_err_msg()) # noqa: TRY200 + raise AssertionError(_build_err_msg()) # noqa: B904 if isinstance(actual, (ndarray, tuple, list)) or isinstance( desired, (ndarray, tuple, list) @@ -726,7 +726,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"): names=("x", "y"), precision=precision, ) - raise ValueError(msg) # noqa: TRY200 + raise ValueError(msg) # noqa: B904 def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): @@ -2272,7 +2272,7 @@ def check_free_memory(free_bytes): try: mem_free = _parse_size(env_value) except ValueError as exc: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 f"Invalid environment variable {env_var}: {exc}" ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 6772a4dff4a7..68675c751736 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -6298,7 +6298,7 @@ def _compute_sizes(seq, scalar_type): try: handle = seq[0] except Exception: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 f"could not determine the shape of object type '{type(seq).__name__}'" ) seq = handle diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9e8655151422..21d7abd93837 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -737,7 +737,7 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log ) if timedelta(seconds=(time.time() - start)) > timeout: - raise DistStoreError( # noqa: TRY200 + raise DistStoreError( # noqa: B904 "Timed out initializing process group in store based barrier on " f"rank {rank}, for key: {store_key} (world_size={world_size}, " f"num_workers_joined={worker_count}, timeout={timeout} error={e})" diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 799ec6d5e0d1..1beb75e9df66 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1281,7 +1281,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): except AttributeError as e: raise AttributeError( f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}' - ) + ) from e mod_to_wrap = getattr(predecessor_module, atoms[-1]) mod_to_wrap._orig_forward = mod_to_wrap.forward diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 72ae886a392c..55c7442380f9 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -452,9 +452,9 @@ def _export_to_torch_ir( **kwargs, ) except (ConstraintViolationError, ValueRangeError) as e: - raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 except GuardOnDataDependentSymNode as e: - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.ANTI_PATTERN, f"Consider annotating your code using torch._check*(). {str(e)}", case_name="constrain_as_size_example", @@ -1101,7 +1101,7 @@ def forward(self, *args, **kwargs): _disable_forced_specializations=_disable_forced_specializations, ) except (ConstraintViolationError, ValueRangeError) as e: - raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 combined_args = _combine_args(mod, args, kwargs) range_constraints = make_constraints( diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 52ce64e4dcad..2fdb7916eeeb 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -22,7 +22,7 @@ def _check_input_constraints_pre_hook(self, *args, **kwargs): flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args) if received_spec != self._in_spec: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 "Trying to flatten user inputs with exported input tree spec: \n" f"{self._in_spec}\n" "but actually got inputs with tree spec of: \n" diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index eba833332344..f3fecb1043fb 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -720,7 +720,7 @@ def root_value(): if solution is not None: return int(solution[1]) # type: ignore[call-overload] else: - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.CONSTRAINT_VIOLATION, f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " f"of the form {expr}, where {symbol} is an integer", diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index f5635038c4e2..ffb3467055b3 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -448,7 +448,7 @@ def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs): res = pytree.tree_unflatten(res, self.call_spec.out_spec) except Exception: _, received_spec = pytree.tree_flatten(res) - raise error.InternalError( # noqa: TRY200 + raise error.InternalError( # noqa: B904 "Trying to flatten user outputs with exported output tree spec: \n" f"{self.call_spec.out_spec}\n" "but actually got outputs with tree spec of: \n" diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 95a9f568443e..fa44b6306786 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -310,7 +310,7 @@ def __call__(self, obj, *args, **kwargs): _WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr, ) - raise e.with_traceback(None) # noqa: TRY200 + raise e.with_traceback(None) # noqa: B904 else: raise e diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 9d24162500ac..6d050c78f754 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -373,7 +373,7 @@ def _run_and_compare( self._store_outputs(a_result, b_result, submodule) except Exception as e: report.append(f"Exception raised when running {submod_name}: {e}") - raise FxNetMinimizerRunFuncError( # noqa: TRY200 + raise FxNetMinimizerRunFuncError( # noqa: B904 f"Exception raised when running {submod_name}: {e}" ) diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index 3a82b2b426aa..e46318b0d3ac 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -147,7 +147,7 @@ def get_submodule(self, name: str) -> "torch.nn.Module": f"{module._get_name()} has no attribute `{attr}`" ) from ex if not isinstance(submodule, torch.nn.Module): - raise TypeError( # noqa: TRY200 + raise TypeError( # noqa: B904 f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module" ) self.memo[name] = submodule diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py index e6f0ad0e6f51..3e1f816d9f73 100644 --- a/torch/testing/_internal/opinfo/definitions/sparse.py +++ b/torch/testing/_internal/opinfo/definitions/sparse.py @@ -25,7 +25,7 @@ def _check_fail(sample): except sample.error_type: pass except Exception as msg: - raise AssertionError( # noqa: TRY200 + raise AssertionError( # noqa: B904 f"{op_info.name} on {sample.sample_input=} expected exception " f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}" ) @@ -39,7 +39,7 @@ def _check_success(sample): try: op_info(sample.input, *sample.args, **sample.kwargs) except Exception as msg: - raise AssertionError( # noqa: TRY200 + raise AssertionError( # noqa: B904 f"{op_info.name} on {sample=} expected to succeed " f", got {type(msg).__name__}: {msg}" ) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index eae126b1b4dc..504fe757d4f2 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -137,8 +137,8 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: try: if not sympy_generic_le(lower, upper): raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]") - except TypeError: - raise TypeError(f"Could not compare {lower} <= {upper}") # noqa: TRY200 + except TypeError as e: + raise TypeError(f"Could not compare {lower} <= {upper}") from e # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index fa73b9f41cd6..9f4d04c55105 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -128,7 +128,7 @@ def report_compile_source_on_error(): tb.tb_next = tb_next tb_next = tb - raise exc.with_traceback(tb_next) # noqa: TRY200 + raise exc.with_traceback(tb_next) # noqa: B904 def shorten_filename(fn, *, base=None): """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.""" From 7e9a037b47e7aad973aea02f0908f6d6dcbf748a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 17 May 2024 16:34:19 +0000 Subject: [PATCH 065/116] [Perf] Vectorize more dtype for int4mm (#126512) It used to be vectorized only for f16, but no reason not to do the same for bf16 or f32 Spiritual followup of https://github.com/pytorch/pytorch/pull/125290 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126512 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cpu/int4mm_kernel.cpp | 94 +++++++++++++++++++--- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index acb4b927f23f..2ffef25a10ff 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -341,12 +341,46 @@ inline void tinygemm_kernel( #if !defined(C10_MOBILE) && defined(__aarch64__) #include -template -inline void tinygemm_kernel( - const Half* RESTRICT A, + +inline float32x4x2_t load_as_float32x4x2(const Half* ptr) { + float16x4x2_t f16_val = vld2_f16(reinterpret_cast(ptr)); + auto val_low = vcvt_f32_f16(f16_val.val[0]); + auto val_high = vcvt_f32_f16(f16_val.val[1]); + return {val_low, val_high}; +} + +inline void store_float32x4(Half* ptr, float32x4_t val) { + vst1_f16(reinterpret_cast(ptr), vcvt_f16_f32(val)); +} + +inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) { + int32x4_t shift = vdupq_n_s32(16); + uint16x4x2_t u16_val = vld2_u16(reinterpret_cast(ptr)); + uint32x4_t int_low = vmovl_u16(u16_val.val[0]); + uint32x4_t int_high = vmovl_u16(u16_val.val[1]); + return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))}; +} + +inline void store_float32x4(BFloat16* ptr, float32x4_t val) { + int32x4_t shift = vdupq_n_s32(-16); + uint32x4_t uint32_val = vshlq_u32(vreinterpretq_u32_f32(val), shift); + vst1_u16(reinterpret_cast(ptr), vmovn_u32(uint32_val)); +} + +inline float32x4x2_t load_as_float32x4x2(const float* ptr) { + return vld2q_f32(ptr); +} + +inline void store_float32x4(float* ptr, float32x4_t val) { + vst1q_f32(ptr, val); +} + +template +inline void tinygemm_kernel_( + const T* RESTRICT A, const uint8_t* RESTRICT B, - const Half* RESTRICT ScaleAndZeros, - Half* RESTRICT C, + const T* RESTRICT ScaleAndZeros, + T* RESTRICT C, int lda, int ldb, int ldc, @@ -368,9 +402,9 @@ inline void tinygemm_kernel( if (is_block_start(k, BLOCK_K)) { int kb = k / BLOCK_K; c10::ForcedUnroll<4>{}([&](auto i) { - auto scales_and_zeros = vld2_f16(reinterpret_cast(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8)); - scales[i] = vcvt_f32_f16(scales_and_zeros.val[0]); - zeros[i] = vcvt_f32_f16(scales_and_zeros.val[1]); + auto scales_and_zeros = load_as_float32x4x2(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8); + scales[i] = scales_and_zeros.val[0]; + zeros[i] = scales_and_zeros.val[1]; }); } c10::ForcedUnroll<4>{}([&](auto i) { @@ -383,11 +417,53 @@ inline void tinygemm_kernel( }); } c10::ForcedUnroll<4>{}([&](auto i) { - vst1_f16(reinterpret_cast(C + m * ldc + n + i * 4), vcvt_f16_f32(c_val[i])); + store_float32x4(C + m * ldc + n + i * 4, c_val[i]); }); } } } + +template +inline void tinygemm_kernel( + const Half* RESTRICT A, + const uint8_t* RESTRICT B, + const Half* RESTRICT ScaleAndZeros, + Half* RESTRICT C, + int lda, + int ldb, + int ldc, + int K, + int BLOCK_K) { + tinygemm_kernel_(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K); +} + +template +inline void tinygemm_kernel( + const BFloat16* RESTRICT A, + const uint8_t* RESTRICT B, + const BFloat16* RESTRICT ScaleAndZeros, + BFloat16* RESTRICT C, + int lda, + int ldb, + int ldc, + int K, + int BLOCK_K) { + tinygemm_kernel_(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K); +} + +template +inline void tinygemm_kernel( + const float* RESTRICT A, + const uint8_t* RESTRICT B, + const float* RESTRICT ScaleAndZeros, + float* RESTRICT C, + int lda, + int ldb, + int ldc, + int K, + int BLOCK_K) { + tinygemm_kernel_(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K); +} #endif template From 6bcf15669ed2c2dcaa94f87682e504a493397994 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Thu, 16 May 2024 15:26:39 -0700 Subject: [PATCH 066/116] [inductor] fix unbacked case in pointwise + reduction vertical fusion (#125982) ``` $ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once for node1, node2 in self.get_possible_fusions(): File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions check_all_pairs(node_grouping) File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs if self.can_fuse(node1, node2): File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse return self.get_backend(device).can_fuse_vertical(node1, node2) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical return self._triton_scheduling.can_fuse_vertical(node1, node2) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse if not all( File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges()) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible cls._split_iteration_ranges(groups, lengths) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1: File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint return int(out) File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__ raise TypeError("Cannot convert symbols to int") torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: TypeError: Cannot convert symbols to int ``` Where the unbacked symints show up at. ``` > /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges() (Pdb) print(groups) (1, 512*u0) (Pdb) print(lengths) ([u0, 32, 16], []) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125982 Approved by: https://github.com/jansel --- test/inductor/test_unbacked_symints.py | 26 ++++++++++++++++++++++++++ torch/_inductor/codegen/triton.py | 9 +++++---- torch/_inductor/sizevars.py | 16 ++++++++++++++++ 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 43d1307fcfa5..60ce45317238 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -188,6 +188,32 @@ def fn(x, w, a, b): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipCUDAIf(not HAS_CUDA, "requires cuda") + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_vertical_pointwise_reduction_fusion(self, device): + # Tests fusing a pointwise & reduction op with unbacked numel/rnumel. + def fn(x, y, repeats): + u0 = repeats.item() + unbacked = y.expand(u0, *y.shape) # [u0, 1, 16] + + # Note: We add x to both pointwise and reduction. Otherwise, the + # scheduler will refuse to fuse ops whose only common buffer has + # unbacked symints. + pointwise = unbacked + x + reduction = torch.sum(pointwise + x) + return pointwise, reduction + + example_inputs = ( + torch.randn(32, 16).cuda(), + torch.randn(1, 16).cuda(), + torch.tensor(32).cuda(), + ) + + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + instantiate_device_type_tests( TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e015cd7dbf6a..b1b7d951b99a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1501,14 +1501,15 @@ def getter(flat_vars): return_getters.append(lambda _: sympy.Integer(0)) continue - while ( - current_group < len(remaining) - and sv.size_hint(remaining[current_group]) == 1 + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], 1 # type: ignore[arg-type] ): # scroll to next group with remaining elements current_group += 1 - if sv.size_hint(size) > sv.size_hint(remaining[current_group]): + if current_group + 1 < len(remaining) and sv.statically_known_gt( + size, remaining[current_group] + ): # need to break size in two if not sv.statically_known_multiple_of( size, remaining[current_group] diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 65a9cb837907..b6288b34fafa 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -310,6 +310,14 @@ def statically_known_leq(self, left: Expr, right: Expr) -> bool: expr = left <= right return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] + def statically_known_geq(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. + """ + expr = left >= right + return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] def statically_known_lt(self, left: Expr, right: Expr) -> bool: """ @@ -318,6 +326,14 @@ def statically_known_lt(self, left: Expr, right: Expr) -> bool: expr = left < right return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] + def statically_known_gt(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than right. + """ + expr = left > right + return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool: """ From 31ea8290e7973441bcc9517f4f322fe131149e82 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 17 May 2024 17:29:44 +0000 Subject: [PATCH 067/116] Workflow for uploading additional test stats on workflow dispatch (#126080) This kind of an experiment for uploading test stats during the run, and also for test dashboard stuff so it can re calculate the info Add workflow that is callable via workflow dispatch for uploading additional test stats Adds script that only calculates the additional info Pull Request resolved: https://github.com/pytorch/pytorch/pull/126080 Approved by: https://github.com/ZainRizvi --- .../upload_test_stats_intermediate.yml | 43 +++++++++++++++++++ tools/stats/upload_test_stats_intermediate.py | 29 +++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 .github/workflows/upload_test_stats_intermediate.yml create mode 100644 tools/stats/upload_test_stats_intermediate.py diff --git a/.github/workflows/upload_test_stats_intermediate.yml b/.github/workflows/upload_test_stats_intermediate.yml new file mode 100644 index 000000000000..14b65f6a75ef --- /dev/null +++ b/.github/workflows/upload_test_stats_intermediate.yml @@ -0,0 +1,43 @@ +name: Upload test stats intermediate + +on: + workflow_dispatch: + inputs: + workflow_id: + description: workflow_id of the run + required: true + workflow_run_attempt: + description: workflow_run_attempt of the run + required: true + +jobs: + intermediate_upload_test_stats: + name: Intermediate upload test stats for ${{ inputs.workflow_id }} + runs-on: ubuntu-22.04 + environment: upload-stats + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + fetch-depth: 1 + submodules: false + + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: pip + + - run: | + pip3 install requests==2.26 rockset==1.0.3 boto3==1.19.12 + + - name: Upload test stats + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WORKFLOW_RUN_ID: ${{ inputs.workflow_id }} + WORKFLOW_RUN_ATTEMPT: ${{ inputs.workflow_run_attempt }} + run: | + python3 -m tools.stats.upload_test_stats_intermediate \ + --workflow-run-id "${WORKFLOW_RUN_ID}" \ + --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" \ diff --git a/tools/stats/upload_test_stats_intermediate.py b/tools/stats/upload_test_stats_intermediate.py new file mode 100644 index 000000000000..77cab472367b --- /dev/null +++ b/tools/stats/upload_test_stats_intermediate.py @@ -0,0 +1,29 @@ +import argparse +import sys + +from tools.stats.test_dashboard import upload_additional_info +from tools.stats.upload_test_stats import get_tests + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload test stats to Rockset") + parser.add_argument( + "--workflow-run-id", + required=True, + help="id of the workflow to get artifacts from", + ) + parser.add_argument( + "--workflow-run-attempt", + type=int, + required=True, + help="which retry of the workflow this is", + ) + args = parser.parse_args() + + print(f"Workflow id is: {args.workflow_run_id}") + + test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt) + + # Flush stdout so that any errors in Rockset upload show up last in the logs. + sys.stdout.flush() + + upload_additional_info(args.workflow_run_id, args.workflow_run_attempt, test_cases) From 66dc8fb7ff822033c4b161fc216e21d6886568c7 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 17 May 2024 07:54:46 -0700 Subject: [PATCH 068/116] Allow tensor subclasses and add `torch.serialization.add_safe_globals` that allows users to allowlist classes for `weights_only` load (#124331) #### Conditions for allowlisting tensor subclasses We allow tensor subclasses types that (1) Do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` does not have a definition of `__getattr__`, `__get__` or `__set__` so we check that these are `None`) (2) Use the generic `tp_alloc` (3) Are in a module that *has been imported by the user* to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict The strings will be converted to the classes as appropriate when executing `REBUILD` with `_rebuild_from_type_v2` *Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function as this method claims to have no code execution. The rationale for the 3 conditions above is as follows: The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `getattr`, `Tensor.__setstate__` and the call to `as_subclass` as well as the call to `_set_obj_state` which calls `setattr`) https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71 `as_subclass` is implemented with a call to `THPVariable_NewWithVar` that will eventually call `tp_alloc` here https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053 The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc` **Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling** ### How do we check something is a tensor subclass/constraints around imports In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys[modules], name), torch.Tensor)` This PR also allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`) ### API for allow listing This PR also added `torch.serialization.{add/get/clear}_safe_globals` that enables user to allowlist globals they have deemed safe and manipulate this list (for example they could allowlist a tensor subclass with a custom `__setstate__` if they have checked that this is safe). Next steps: - Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124331 Approved by: https://github.com/albanD --- docs/source/notes/serialization.rst | 3 + ...ialization.test_allowlist_for_weights_only | 0 test/test_serialization.py | 221 +++++++++++++++++- torch/_C/__init__.pyi.in | 1 + torch/_weights_only_unpickler.py | 216 +++++++++++++++-- torch/csrc/Module.cpp | 17 ++ torch/serialization.py | 28 ++- 7 files changed, 468 insertions(+), 18 deletions(-) create mode 100644 test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 09fd9e858b87..225486cdedac 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -394,3 +394,6 @@ The following utility functions are related to serialization: .. autofunction:: set_default_load_endianness .. autofunction:: get_default_mmap_options .. autofunction:: set_default_mmap_options +.. autofunction:: add_safe_globals +.. autofunction:: clear_safe_globals +.. autofunction:: get_safe_globals diff --git a/test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only b/test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/test_serialization.py b/test/test_serialization.py index 5c6b78b44564..1be1b06ab786 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -15,8 +15,10 @@ import shutil import pathlib import platform +from collections import OrderedDict from copy import deepcopy from itertools import product +from types import ModuleType from torch._utils_internal import get_file_path_2 from torch._utils import _rebuild_tensor @@ -27,9 +29,10 @@ from torch.testing._internal.common_utils import ( IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName, TestCase, IS_FBCODE, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName, - parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest) + parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest, skipIfTorchDynamo) from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_dtype import all_types_and_complex_and +from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 if not IS_WINDOWS: from mmap import MAP_SHARED, MAP_PRIVATE @@ -1038,7 +1041,7 @@ def __reduce__(self): self.assertIsNone(torch.load(f, weights_only=False)) f.seek(0) # Safe load should assert - with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"): + with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"): torch.load(f, weights_only=True) @parametrize('weights_only', (False, True)) @@ -4108,6 +4111,23 @@ def __setstate__(self, state): class TestEmptySubclass(torch.Tensor): ... +# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them +# Cannot define locally in test or pickle will fail. +class TestEmptySubclassSpoof(TestEmptySubclass): + ... + +class TestWrapperSubclassSpoof(TestWrapperSubclass): + ... + +class RebuildFromTypeV2Spoof(torch.Tensor): + def __new__(cls, elem, naughty, **kwargs): + if naughty: + raise RuntimeError("naughty") + return super().__new__(cls, elem) + + def __reduce_ex__(self, protocol): + return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {})) + class TestSubclassSerialization(TestCase): def test_tensor_subclass_wrapper_serialization(self): @@ -4187,6 +4207,203 @@ def test_empty_class_serialization(self): f.seek(0) tensor2 = torch.load(f) + def _create_bad_func(self, name): + def bad_func(self, *args, **kwargs): + raise RuntimeError(f"running {name}") + return bad_func + + @parametrize("wrapper", (True, False)) + def test_tensor_subclass_method_spoofing(self, wrapper): + ''' + This tests seeks to do the following: + - determine which methods of a tensor subclass might be called during unpickling (weights_only=False) + we consider these methods "risky" for weights_only + - ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True) + - ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True) + + We achieve this by overriding all methods of a tensor subclass to raise a RuntimeError + when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that + only the RuntimeErrors that we expect are thrown. + + We then load with weights_only and ensure that weights_only will fail unless all the risky methods + are not overriden by resetting the risky methods to the non-overriden version in a loop and calling load. + The final weights_only load call when all the risky methods are no longer overriden. + ''' + subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof + t = subclass(torch.randn(2, 3)) + # To trigger setattr for the non-wrapper case + if not wrapper: + t.foo = 'bar' + inp = {'weight': t} + + with TemporaryFileName() as f: + torch.save(inp, f) + loaded = torch.load(f, weights_only=True) + self.assertEqual(loaded['weight'], inp['weight']) + + restore_methods = dict() + methods = [func for func in dir(subclass) if callable(getattr(subclass, func))] + for method in methods: + if method != "__class__": + restore_methods[method] = getattr(subclass, method) + setattr(subclass, method, self._create_bad_func(method)) + # These additional methods might be called during getattr or setattr + # but are not in methods above (not defined on tensor base class) + subclass.__get__ = self._create_bad_func("__get__") + subclass.__set__ = self._create_bad_func("__set__") + subclass.__getattr__ = self._create_bad_func("__getattr__") + restore_methods["__get__"] = None + restore_methods["__getattr__"] = None + restore_methods["__set__"] = None + + try: + # Check that weights_only=False load raises the RuntimeErrors we expect + with self.assertRaisesRegex(RuntimeError, "running __getattribute__"): + torch.load(f, weights_only=False) + subclass.__getattribute__ = restore_methods['__getattribute__'] + with self.assertRaisesRegex(RuntimeError, "running __setstate__"): + torch.load(f, weights_only=False) + subclass.__setstate__ = restore_methods['__setstate__'] + with self.assertRaisesRegex(RuntimeError, "running __setattr__"): + torch.load(f, weights_only=False) + subclass.__setattr__ = restore_methods['__setattr__'] + # should finally work + torch.load(f, weights_only=False) + + # Check that weights_only=True catches that risky methods are overriden + subclass.__setstate__ = self._create_bad_func("__setstate__") + subclass.__getattribute__ = self._create_bad_func("__getattribute__") + subclass.__setattr__ = self._create_bad_func("__setattr__") + with self.assertRaisesRegex(pickle.UnpicklingError, + "methods: __getattribute__=True __getattr__=True __get__=True " + "__setattr__=True __set__=True __setstate__=True"): + torch.load(f, weights_only=True) + risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__'] + for i, meth in enumerate(risky_methods): + setattr(subclass, meth, restore_methods[meth]) + if i != len(risky_methods) - 1: + # When the given methods are not all back to default, load should still throw + # but reflect which methods are no longer overriden + with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"): + torch.load(f, weights_only=True) + else: + # When the given methods are all back to default, weights_only load should finally work + loaded = torch.load(f, weights_only=True) + finally: + for method, func in restore_methods.items(): + setattr(subclass, method, func) + a = subclass(torch.randn(2, 3)) + + @skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined") + def test_safe_globals_for_weights_only(self): + ''' + Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs + ''' + # Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment + global TwoTensor + t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3)) + p = torch.nn.Parameter(t) + sd = OrderedDict([('t', t), ('p', p)]) + + with tempfile.NamedTemporaryFile() as f: + torch.save(sd, f) + # unimport TwoTensor + try: + del sys.modules['torch.testing._internal.two_tensor'] + + # Loading tensor subclass with weights_only=True should fail + # if tensor subclass has not been imported + with self.assertRaisesRegex(pickle.UnpicklingError, + "expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"): + f.seek(0) + sd = torch.load(f, weights_only=True) + + # Loading tensor subclass with weights_only=True should work + # if target methods are not overriden and user has imported the subclass + from torch.testing._internal.two_tensor import TwoTensor + f.seek(0) + sd = torch.load(f, weights_only=True) + self.assertEqual(sd['t'], t) + self.assertEqual(sd['p'], p) + + # Loading tensor subclass with weights_only=True should fail + # if __setstate__ is overriden + f.seek(0) + restore_setstate = TwoTensor.__setstate__ + try: + TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state) + with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): + torch.load(f, weights_only=True) + + # Loading tensor subclass with overriden __setstate__ with weights_only=True should work + # if the class is marked safe + f.seek(0) + torch.serialization.add_safe_globals([TwoTensor]) + self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) + sd = torch.load(f, weights_only=True) + self.assertEqual(sd['t'], t) + self.assertEqual(sd['p'], p) + + # Should fail again when safe globals are cleared + torch.serialization.clear_safe_globals() + f.seek(0) + with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): + torch.load(f, weights_only=True) + finally: + TwoTensor.__setstate__ = restore_setstate + finally: + from torch.testing._internal.two_tensor import TwoTensor + + + def test_tensor_subclass_parent_module_method_spoofing(self): + ''' + Tests that weights_only load does not call any methods of the parent module + that contains the tensor subclass. + + We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError + when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that + no RuntimeErrors are thrown. + ''' + # Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass` + class SpoofModule(ModuleType): + pass + + spoof_mod = SpoofModule('bla') + spoof_mod.TestEmptySubclass = TestEmptySubclass + inp = {'weight': TestEmptySubclass(torch.randn(2, 3))} + TestEmptySubclass.__module__ = 'spoof_mod' + sys.modules['spoof_mod'] = spoof_mod + + try: + with TemporaryFileName() as f: + torch.save(inp, f) + torch.load(f, weights_only=True) + restore_methods = dict() + methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))] + for method in methods: + if method != "__class__": + restore_methods[method] = getattr(SpoofModule, method) + setattr(SpoofModule, method, self._create_bad_func(method)) + SpoofModule.__get__ = self._create_bad_func("__get__") + SpoofModule.__getattr__ = self._create_bad_func("__getattr__") + loaded = torch.load(f, weights_only=True) + self.assertEqual(loaded['weight'], inp['weight']) + finally: + TestEmptySubclass.__module__ = __name__ + del sys.modules['spoof_mod'] + + def test_rebuild_from_type_v2_spoof(self): + t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False) + inp = {'weight': t} + + with TemporaryFileName() as f: + torch.save(inp, f) + # subclass will be pushed onto unpickler's stack as a string + # and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2 + with self.assertRaisesRegex(TypeError, "'str' object is not callable"): + loaded = torch.load(f, weights_only=True) + + instantiate_device_type_tests(TestBothSerialization, globals()) instantiate_parametrized_tests(TestSubclassSerialization) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index ac70396c468e..0599da2117fb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1196,6 +1196,7 @@ def _has_storage(x: Tensor) -> _bool: ... def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... +def _check_tp_alloc_is_default(cls: Type) -> _bool: ... # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 44dd8223862a..6c9f3b61ae8b 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -9,6 +9,10 @@ # - `torch.nn.Parameter` # - `collections.Counter` # - `collections.OrderedDict` +# Additionally, users can use an allowlist for adding classes they have deemed as safe using +# `_add_safe_globals()` (`torch.serialization.add_safe_globals`) +# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`) +# `_get_safe_globals()` (`torch.serialization.get_safe_globals`) # Based of https://github.com/python/cpython/blob/main/Lib/pickle.py # Expected to be useful for loading PyTorch model weights @@ -19,6 +23,7 @@ import functools as _functools from collections import Counter, OrderedDict +from inspect import getattr_static from pickle import ( APPEND, APPENDS, @@ -59,11 +64,57 @@ UnpicklingError, ) from struct import unpack -from sys import maxsize -from typing import Any, Dict, List +from sys import maxsize, modules +from typing import Any, Dict, List, Type import torch +_marked_safe_globals_list: List[Any] = [] + + +def _add_safe_globals(safe_globals: List[Any]): + global _marked_safe_globals_list + _marked_safe_globals_list += safe_globals + + +def _get_safe_globals() -> List[Any]: + global _marked_safe_globals_list + return _marked_safe_globals_list + + +def _clear_safe_globals(): + global _marked_safe_globals_list + _marked_safe_globals_list = [] + + +# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals +# For example if user had a script like +# torch.load(file_a) +# torch.serialization._add_safe_globals([torch.foo]) +# torch.load(file_b) +# the dynamic additions to safe_globals would not be picked up by +# _get_allowed_globals due to the lru_cache +def _get_user_allowed_globals(): + rc: Dict[str, Any] = {} + for f in _marked_safe_globals_list: + rc[f"{f.__module__}.{f.__name__}"] = f + return rc + + +def _tensor_rebuild_functions(): + return { + torch._utils._rebuild_parameter, + torch._utils._rebuild_parameter_with_state, + torch._utils._rebuild_qtensor, + torch._utils._rebuild_tensor, + torch._utils._rebuild_tensor_v2, + torch._utils._rebuild_tensor_v3, + torch._utils._rebuild_sparse_tensor, + torch._utils._rebuild_meta_tensor_no_storage, + torch._utils._rebuild_nested_tensor, + torch._utils._rebuild_wrapper_subclass, + } + # Unpickling machinery @_functools.lru_cache(maxsize=1) @@ -75,6 +126,7 @@ def _get_allowed_globals(): "torch.serialization._get_layout": torch.serialization._get_layout, "torch.Size": torch.Size, "torch.Tensor": torch.Tensor, + "torch.device": torch.device, } # dtype for t in torch.storage._dtype_to_storage_type_map().keys(): @@ -103,17 +155,7 @@ def _get_allowed_globals(): ]: rc[str(qt)] = qt # Rebuild functions - for f in [ - torch._utils._rebuild_parameter, - torch._utils._rebuild_parameter_with_state, - torch._utils._rebuild_qtensor, - torch._utils._rebuild_tensor, - torch._utils._rebuild_tensor_v2, - torch._utils._rebuild_tensor_v3, - torch._utils._rebuild_sparse_tensor, - torch._utils._rebuild_meta_tensor_no_storage, - torch._utils._rebuild_nested_tensor, - ]: + for f in _tensor_rebuild_functions(): rc[f"torch._utils.{f.__name__}"] = f # Handles Tensor Subclasses, Tensor's with attributes. @@ -128,6 +170,11 @@ def __init__(self, file, *, encoding: str = "bytes"): self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} + # tensor subclass types found from GLOBAL instructions that have passed the criteria + # to be allowed as the second argument to `torch._tensor._rebuild_from_type_v2` + # This enables rebuilding of tensor subclasses defined outside the `torch` package. + # See [Note: Criteria for allowing out-of-core tensor subclasses] for details on the criteria. + self.tensor_subclasses_found: Dict[str, Type] = {} def load(self): """Read a pickled object representation from the open file. @@ -151,8 +198,124 @@ def load(self): full_path = f"{module}.{name}" if full_path in _get_allowed_globals(): self.append(_get_allowed_globals()[full_path]) + elif full_path in _get_user_allowed_globals(): + self.append(_get_user_allowed_globals()[full_path]) else: - raise RuntimeError(f"Unsupported class {full_path}") + # The logic in this branch handles user-defined tensor subclasses. + # We can automatically allow and raise and error for anything that is not provably safe. + # [Note: Criteria for allowing out-of-core tensor subclasses] + # GLOBAL '.' instructions will get the class and + # push the string (not the actual type) while adding the type to the dictionary keyed + # by the string onto the unpickler's stack if they satisfy the following conditions: + # (1) The that defines them is in `sys.modules` + # (we will use getattr_static to access it to ensure no code execution) + # (2) They inherit from `torch.Tensor` + # (2) The class is not overriding any of the `torch.Tensor` methods listed here: + # `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, `__set__`, + # and `tp_alloc` + # The methods that we ban overriding were selected in a test-driven manner + # by overriding every callable method on a tensor subclass and determinining + # which might get called during unpickling. + # When executing REDUCE, the string will be appropriately converted back to the type only + # for `torch._tensor._rebuild_from_type_v2` as other use of the class could use methods + # we didn't audit. + if module == "__builtin__": + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) + elif module not in modules: + # TODO: add a link here to a doc that explains to users what we mean by trust + raise RuntimeError( + f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was " + f"not in the pre-defined list of allowed globals that are considered safe by the " + "weights_only unpickler for rebuilding state_dicts. This is the expected behavior if " + f"`{full_path}` is a class or function that is not in the list of allowed globals " + f"If `{full_path}` is NOT a tensor subclass, you might consider" + "`torch.serialization.add_safe_globals` if it is appropriate. However, if it is a " + "user-defined tensor subclass not defined in the `torch` package, this error might arise " + f"as we expect `{module}` to be present in `sys.modules` (i.e. it " + "must be imported in the current environment), but this was not the case. " + f"If you intend to unpickle a tensor subclass `{full_path}` please import `{name}` from " + f"`{module}`. Note that having this imported will *only* allow the type `{full_path}` to " + "be passed as the second argument to `torch._tensor._rebuild_from_type_v2`, which should " + "enable the tensor subclass to be unpickled without any arbitrary code execution as long " + # If the user imports and these are overridden the next error will prompt them to use + # torch.serialization.add_safe_globals. + "a sa pre-defined list of methods called when unpickling are not overridden. In " + "particular, the methods are `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, " + "`__set__`, as well as the implementation of `tp_alloc`." + ) + else: + try: + class_type = getattr_static(modules[module], name) + except AttributeError as e: + raise AttributeError( + "For safety during weights_only loading, we use inspect.getattr_state to " + f"get {name} from {module}, if {module} implements the descriptor protocol, " + "__getattr__ or __getattribute__ these will not be called." + ) from e + # None of the objects here contain any data from the pickle so this is safe + if isinstance(class_type, type) and issubclass( + class_type, torch.Tensor + ): + # getattr is called by the getattr call in `_rebuild_from_type_v2` + custom_get_attribute = ( + class_type.__getattribute__ + is not torch.Tensor.__getattribute__ + ) + custom_get = ( + getattr_static(class_type, "__get__", None) is not None + ) + custom_get_attr = ( + getattr_static(class_type, "__getattr__", None) + is not None + ) + # Tensor.__setstate__ might be called in `_rebuild_from_type_v2` + custom_set_state = ( + class_type.__setstate__ is not torch.Tensor.__setstate__ + ) + # setattr is called in `torch._utils._set_obj_state` + custom_set_attr = ( + class_type.__setattr__ is not object.__setattr__ + ) + custom_set = ( + getattr_static(class_type, "__set__", None) is not None + ) + # tp_alloc is called by `Tensor._rebuild_wrapper_subclass` and `Tensor.as_subclass` + has_custom_tp_alloc = ( + not torch._C._check_tp_alloc_is_default(class_type) + ) + custom_methods = { + "__getattribute__": custom_get_attribute, + "__getattr__": custom_get_attr, + "__get__": custom_get, + "__setattr__": custom_set_attr, + "__set__": custom_set, + "__setstate__": custom_set_state, + "tp_alloc": has_custom_tp_alloc, + } + if any(custom_methods.values()): + error = "" + for k, v in custom_methods.items(): + error += f" {k}={v}" + raise RuntimeError( + f"Trying to unpickle tensor subclass `{full_path}` that has defined a custom " + f"version for one of these methods:{error}. Please check whether you trust these " + "methods and allowlist the subclass with `torch.serialization.add_safe_globals` if so." + ) + # push the string full_path onto the stack (in REBUILD, there is special logic to + # access this from tensor_subclasses_found for rebuild_from_type_v2) + self.tensor_subclasses_found[full_path] = class_type + self.append(full_path) + else: + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) + elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() @@ -162,10 +325,33 @@ def load(self): elif key[0] == REDUCE[0]: args = self.stack.pop() func = self.stack[-1] - if func not in _get_allowed_globals().values(): + if ( + func not in _get_allowed_globals().values() + and func not in _get_user_allowed_globals().values() + ): raise RuntimeError( f"Trying to call reduce for unrecognized function {func}" ) + # Special handling for tensor subclass type found in GLOBAL that is pushed + # onto stack as str to prevent it from being used anywhere except the + # second arg of _rebuild_from_type_v2 and within argument tuple for _rebuild_wrapper_subclass + # _rebuild_from_type_v2 is called with args (func, type, func_args, state) + # where both type and, when func is rebuild_wrapper_subclass, func_args[0] could be the subclass type + # Since we pushed these subclass types onto the stack as strings, convert them to the actual + # type here. + if func is torch._tensor._rebuild_from_type_v2 and type(args[1]) is str: + args_after = args[2:] + if ( + args[0] is torch._utils._rebuild_wrapper_subclass + and type(args[2][0]) is str + ): + new_arg_tuple = ( + self.tensor_subclasses_found[args[2][0]], + ) + args[2][1:] + args_after = (new_arg_tuple,) + args[3:] + args = ( + args[:1] + (self.tensor_subclasses_found[args[1]],) + args_after + ) self.stack[-1] = func(*args) elif key[0] == BUILD[0]: state = self.stack.pop() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3be764220e0d..9ff9131435f4 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -422,6 +422,19 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { END_HANDLE_TH_ERRORS } +PyObject* THPModule_check_tp_alloc_is_default( + PyObject* _unused, + PyObject* cls) { + HANDLE_TH_ERRORS + TORCH_CHECK_TYPE( + PyType_Check(cls), + "cls must be a type (got ", + Py_TYPE(cls)->tp_name, + ")"); + return PyBool_FromLong(Py_TYPE(cls)->tp_alloc == PyType_GenericAlloc); + END_HANDLE_TH_ERRORS +} + PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { // adds a __doc__ string to a function, similar to numpy's arr_add_docstring static std::vector all_docs; @@ -1268,6 +1281,10 @@ static PyMethodDef TorchMethods[] = { // NOLINT {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr}, + {"_check_tp_alloc_is_default", + THPModule_check_tp_alloc_is_default, + METH_O, + nullptr}, {"_init_names", THPModule_initNames, METH_O, nullptr}, {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr}, {"_set_default_tensor_type", diff --git a/torch/serialization.py b/torch/serialization.py index 64a1e6e0ce06..a7703b9964d0 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -59,6 +59,9 @@ 'LoadEndianness', 'get_default_load_endianness', 'set_default_load_endianness', + 'clear_safe_globals', + 'get_safe_globals', + 'add_safe_globals', ] @@ -148,6 +151,27 @@ def set_default_mmap_options(flags: int): f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}") _default_mmap_options = flags +def clear_safe_globals() -> None: + ''' + Clears the list of globals that are safe for ``weights_only`` load. + ''' + _weights_only_unpickler._clear_safe_globals() + +def get_safe_globals() -> List[Any]: + ''' + Returns the list of user-added globals that are safe for ``weights_only`` load. + ''' + return _weights_only_unpickler._get_safe_globals() + +def add_safe_globals(safe_globals: List[Any]) -> None: + ''' + Marks the given globals as safe for ``weights_only`` load. + + Args: + safe_globals (List[Any]): list of globals to mark as safe + ''' + _weights_only_unpickler._add_safe_globals(safe_globals) + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -952,7 +976,9 @@ def load( UNSAFE_MESSAGE = ( "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" " will likely succeed, but it can result in arbitrary code execution." - "Do it only if you get the file from a trusted source. WeightsUnpickler error: " + " Do it only if you get the file from a trusted source. Alternatively, to load" + " with `weights_only` please check the recommended steps in the following error message." + " WeightsUnpickler error: " ) # Add ability to force safe only weight loads via environment variable if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: From ecd9a4e5c38c1439ebea1a4b97ec8fc6b1758453 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Thu, 16 May 2024 14:11:15 -0700 Subject: [PATCH 069/116] Enable FX graph cache for huggingface and timm benchmarks (#126205) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126205 Approved by: https://github.com/eellison --- benchmarks/dynamo/huggingface.py | 4 ++++ benchmarks/dynamo/timm_models.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index d6014706479e..a998d10bf33c 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -15,6 +15,10 @@ log = logging.getLogger(__name__) +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + def pip_install(package): subprocess.check_call([sys.executable, "-m", "pip", "install", package]) diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index ed5132001827..1d291e8d1d75 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -13,6 +13,10 @@ from torch._dynamo.testing import collect_results, reduce_to_scalar_loss from torch._dynamo.utils import clone_inputs +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + def pip_install(package): subprocess.check_call([sys.executable, "-m", "pip", "install", package]) From 6931f781c21cb9892efdbb18696d0d3b3f9c4b26 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 17 May 2024 07:15:40 -0700 Subject: [PATCH 070/116] [quant][pt2e] Allow multi users without output observers (#126487) Summary: The PT2E quantization flow does not support unquantized outputs yet. To work around this, users may wish to remove the output observer from their graphs. However, this fails currently in some cases because the `PortNodeMetaForQDQ` pass is too restrictive, for example: ``` conv -> obs -------> output0 \\-> add -> output1 ``` Previously we expected conv to always have exactly 1 user, which is the observer. When the observer is removed, however, conv now has 2 users, and this fails the check. ``` conv -------> output0 \\-> add -> output1 ``` This commit relaxes the error into a warning to enable this workaround. Test Plan: python test/test_quantization.py TestQuantizePT2E.test_multi_users_without_output_observer Reviewers: jerryzh168 Subscribers: jerryzh168, supriyar Differential Revision: [D57472601](https://our.internmc.facebook.com/intern/diff/D57472601) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126487 Approved by: https://github.com/tarun292 --- ...E.test_multi_users_without_output_observer | 0 test/quantization/pt2e/test_quantize_pt2e.py | 41 +++++++++++++++++++ .../quantization/pt2e/port_metadata_pass.py | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 test/dynamo_expected_failures/TestQuantizePT2E.test_multi_users_without_output_observer diff --git a/test/dynamo_expected_failures/TestQuantizePT2E.test_multi_users_without_output_observer b/test/dynamo_expected_failures/TestQuantizePT2E.test_multi_users_without_output_observer new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index b96e1ff12ac3..75cf3c444571 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2278,5 +2278,46 @@ def validate(self, model: torch.fx.GraphModule) -> None: node_list, ) + def test_multi_users_without_output_observer(self): + """ + Test the case in which a node is used by multiple users, + and had its output observer removed. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv(x) + return x, x + 1 + + example_inputs = (torch.randn(1, 3, 5, 5),) + m = M() + m = capture_pre_autograd_graph(m, example_inputs) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(), + ) + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + + # Remove output observer + observer_to_remove = None + for n in m.graph.nodes: + if n.op == "output": + observer_to_remove = n.args[0][0] + assert observer_to_remove.op == "call_module" + assert observer_to_remove.target.startswith("activation_post_process_") + break + assert observer_to_remove is not None + observer_to_remove.replace_all_uses_with(observer_to_remove.args[0]) + m.graph.erase_node(observer_to_remove) + m.recompile() + + # Convert should succeed + m = convert_pt2e(m) + m(*example_inputs) + instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index c47e82073578..5ea1f939a3b6 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -136,7 +136,7 @@ def _port_metadata_for_output_quant_nodes( node_users = _filter_sym_size_users(node) if len(node_users) != 1: - raise InternalError(f"Expecting {node} to have single user") + logger.warning(f"Expecting {node} to have single user") # noqa: G004 q_node = node_users.pop() if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: logger.warning( From de42af4b0087118cf5527261c532927efcb9a0df Mon Sep 17 00:00:00 2001 From: briancoutinho Date: Fri, 17 May 2024 19:08:50 +0000 Subject: [PATCH 071/116] Add coms metadata to execution trace (ET) (#126317) Add Execution Trace communication collective meta data. For specification see https://github.com/pytorch/pytorch/issues/124674 New fields look like ``` { "id": 80, "name": "record_param_comms", "ctrl_deps": 79, "inputs": {"values": [[[78,74,0,100,4,"cuda:0"]],21,["0","default_pg"],0,"allreduce",[],[],0,1,2], "shapes": [[[100]],[],[[],[]],[],[],[],[],[],[],[]], "types": ["GenericList[Tensor(float)]","Int","Tuple[String,String]","Int","String","GenericList[]","GenericList[]","Int","Int","Int"]}, "outputs": {"values": [[[78,74,0,100,4,"cuda:0"]]], "shapes": [[[100]]], "types": ["GenericList[Tensor(float)]"]}, "attrs": [{"name": "rf_id", "type": "uint64", "value": 53},{"name": "fw_parent", "type": "uint64", "value": 0},{"name": "seq_id", "type": "int64", "value": -1},{"name": "scope", "type": "uint64", "value": 0},{"name": "tid", "type": "uint64", "value": 2},{"name": "fw_tid", "type": "uint64", "value": 0},{"name": "op_schema", "type": "string", "value": ""},{"name": "kernel_backend", "type": "string", "value": ""},{"name": "kernel_file", "type": "string", "value": ""}, {"name": "collective_name", "type": "string", "value": "allreduce"}, {"name": "dtype", "type": "string", "value": "Float"}, {"name": "in_msg_nelems", "type": "uint64", "value": 100}, {"name": "out_msg_nelems", "type": "uint64", "value": 100}, {"name": "in_split_size", "type": "string", "value": "[]"}, {"name": "out_split_size", "type": "string", "value": "[]"}, {"name": "global_rank_start", "type": "uint64", "value": 0}, {"name": "global_rank_stride", "type": "uint64", "value": 1}, {"name": "pg_name", "type": "string", "value": "0"}, {"name": "pg_desc", "type": "string", "value": "default_pg"}, {"name": "pg_size", "type": "uint64", "value": 2}] } ``` ## Unit Test Added a new unit test to check the execution trace collected has right attributes `touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_ddp_profiling_execution_trace` ``` STAGE:2024-05-08 17:39:10 62892:62892 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 17:39:10 62893:62893 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 17:39:11 62892:62892 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 17:39:11 62893:62893 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 17:39:11 62892:62892 ActivityProfilerController.cpp:326] Completed Stage: Post Processing STAGE:2024-05-08 17:39:11 62893:62893 ActivityProfilerController.cpp:326] Completed Stage: Post Processing [rank1]:[W508 17:39:12.329544411 reducer.cpp:1399] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [rank0]:[W508 17:39:12.329626774 reducer.cpp:1399] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [rank0]:[W508 17:39:12.339239982 execution_trace_observer.cpp:825] Enabling Execution Trace Observer [rank1]:[W508 17:39:12.339364516 execution_trace_observer.cpp:825] Enabling Execution Trace Observer STAGE:2024-05-08 17:39:12 62892:62892 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 17:39:12 62893:62893 ActivityProfilerController.cpp:316] Completed Stage: Warm Up [rank1]:[W508 17:39:12.352452400 execution_trace_observer.cpp:837] Disabling Execution Trace Observer STAGE:2024-05-08 17:39:12 62893:62893 ActivityProfilerController.cpp:322] Completed Stage: Collection [rank0]:[W508 17:39:12.354019014 execution_trace_observer.cpp:837] Disabling Execution Trace Observer STAGE:2024-05-08 17:39:12 62893:62893 ActivityProfilerController.cpp:326] Completed Stage: Post Processing STAGE:2024-05-08 17:39:12 62892:62892 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 17:39:12 62892:62892 ActivityProfilerController.cpp:326] Completed Stage: Post Processing Execution trace saved at /tmp/tmpy01ngc3w.et.json Execution trace saved at /tmp/tmptf8543k4.et.json ok ---------------------------------------------------------------------- ``` Also run profilerunit test `touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_ddp_profiling_torch_profiler` ``` STAGE:2024-05-08 18:24:22 1926775:1926775 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 18:24:22 1926774:1926774 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 18:24:24 1926774:1926774 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 18:24:24 1926775:1926775 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 18:24:24 1926774:1926774 ActivityProfilerController.cpp:326] Completed Stage: Post Processing STAGE:2024-05-08 18:24:24 1926775:1926775 ActivityProfilerController.cpp:326] Completed Stage: Post Processing [rank1]:[W508 18:24:24.508622236 reducer.cpp:1399] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [rank0]:[W508 18:24:24.508622241 reducer.cpp:1399] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) STAGE:2024-05-08 18:24:24 1926774:1926774 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 18:24:24 1926775:1926775 ActivityProfilerController.cpp:316] Completed Stage: Warm Up STAGE:2024-05-08 18:24:24 1926774:1926774 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 18:24:24 1926775:1926775 ActivityProfilerController.cpp:322] Completed Stage: Collection STAGE:2024-05-08 18:24:24 1926774:1926774 ActivityProfilerController.cpp:326] Completed Stage: Post Processing STAGE:2024-05-08 18:24:24 1926775:1926775 ActivityProfilerController.cpp:326] Completed Stage: Post Processing Trace saved to /tmp/tmpdrw_cmcu.json Trace saved to /tmp/tmpnio7ec9j.json ok ---------------------------------------------------------------------- Ran 1 test in 19.772s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126317 Approved by: https://github.com/yoyoyocmu, https://github.com/sanrise --- .../standalone/execution_trace_observer.cpp | 98 +++++++++++++++- torch/csrc/profiler/util.cpp | 73 +++--------- torch/csrc/profiler/util.h | 17 ++- .../_internal/distributed/distributed_test.py | 108 +++++++++++++++++- 4 files changed, 232 insertions(+), 64 deletions(-) diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index d100d8090c07..9e8a995ec977 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -28,8 +28,27 @@ #include #include +#ifdef USE_DISTRIBUTED +#include +#endif // USE_DISTRIBUTED + using namespace at; +// Collective property attributes +// https://github.com/pytorch/pytorch/issues/124674 +#ifdef USE_DISTRIBUTED +constexpr auto kETCommsName = "collective_name"; +constexpr auto kETInMsgNelems = "in_msg_nelems"; +constexpr auto kETOutMsgNelems = "out_msg_nelems"; +constexpr auto kETInSplit = "in_split_size"; +constexpr auto kETOutSplit = "out_split_size"; +constexpr auto kETGlobalRankStart = "global_rank_start"; +constexpr auto kETGlobalRankStride = "global_rank_stride"; +constexpr auto kETGroupSize = "pg_size"; +constexpr auto kETProcessGroupName = "pg_name"; +constexpr auto kETProcessGroupDesc = "pg_desc"; +#endif // USE_DISTRIBUTED + namespace torch { namespace profiler { namespace impl { @@ -258,6 +277,19 @@ static std::ofstream openOutputFile(const std::string& name) { return stream; } +static inline std::string getAttrJson( + const std::string& name, + const std::string& type, + const std::string& value) { + // note name and type are not quoted but value should be if it is a string. + return fmt::format( + R"JSON( + {{"name": "{}", "type": "{}", "value": {}}})JSON", + name, + type, + value); +} + static void writeJsonNode( std::ofstream& out, const std::string& name, @@ -277,14 +309,15 @@ static void writeJsonNode( const std::string& output_types = "[]", const std::string& operator_schema = "", const std::string& kernel_backend = "", - const std::string& kernel_file = "") { + const std::string& kernel_file = "", + const std::string& additiona_attrs = "") { out << fmt::format( R"JSON( {{ "id": {}, "name": "{}", "ctrl_deps": {}, "inputs": {{"values": {}, "shapes": {}, "types": {}}}, "outputs": {{"values": {}, "shapes": {}, "types": {}}}, - "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}] + "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}{}] }})JSON", id, name, @@ -303,7 +336,8 @@ static void writeJsonNode( fw_tid, operator_schema, kernel_backend, - kernel_file); + kernel_file, + additiona_attrs); } inline std::string timeString(const std::time_t timepoint) { @@ -332,7 +366,7 @@ static bool initExecutionTraceStart(ExecutionTraceObserver& ob) { ob.out << fmt::format( R"JSON({{ - "schema": "1.0.4-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {}, + "schema": "1.1.0-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {}, "nodes": [)JSON", ob.pid, ob.record_time, @@ -486,6 +520,56 @@ inline void handleKernelBackendInfo( } } +// Additional attributes for commounication collectives +inline std::string getCommsNodeAttrs(const RecordFunction& fn) { + std::vector attrs; + +#ifdef USE_DISTRIBUTED + // We rely on paramcommsdebug object that is available in thread local info + auto debugInfo = dynamic_cast( + c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); + if (debugInfo == nullptr) { + LOG(WARNING) << "ParamCommsDebugInfo not available for function: " + << fn.name(); + return ", " + getAttrJson("debug", "string", "\"missing comms info\""); + } + + // get NcclMeta from record function, this used ParamCommsDebugInfo above + auto meta = saveNcclMeta(fn, false /*truncate*/); + + auto addAttr = + [&](const char* commsMetaName, const char* etMetaName, const char* type) { + auto it = meta.find(commsMetaName); + if (it != meta.end()) { + attrs.push_back(getAttrJson(etMetaName, type, it->second)); + } + }; + + addAttr(kCommsName, kETCommsName, "string"); + addAttr(kDtype, kDtype, "string"); + + addAttr(kInMsgNelems, kETInMsgNelems, "uint64"); + addAttr(kOutMsgNelems, kETOutMsgNelems, "uint64"); + + // following two metadata are lists. + addAttr(kInSplit, kETInSplit, "string"); + addAttr(kOutSplit, kETOutSplit, "string"); + + addAttr(kGlobalRankStart, kETGlobalRankStart, "uint64"); + addAttr(kGlobalRankStride, kETGlobalRankStride, "uint64"); + + // pg_name is a string. + addAttr(kProcessGroupName, kETProcessGroupName, "string"); + addAttr(kProcessGroupDesc, kETProcessGroupDesc, "string"); + + addAttr(kGroupSize, kETGroupSize, "uint64"); + +#endif // USE_DISTRIBUTED + + // XXX consider using as string stream? + return attrs.size() == 0 ? "" : fmt::format(", {}", fmt::join(attrs, ", ")); +} + static void recordOperatorStart( ExecutionTraceObserver& ob, FunctionCallContext& fc, @@ -645,6 +729,9 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) { op_schema_str = json_str_escape(c10::toString(op_schema.value())); } + const std::string additiona_attrs = + fn.isNcclMeta() ? getCommsNodeAttrs(fn) : ""; + writeJsonNode( ob->out, fc.name, @@ -664,7 +751,8 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) { vectorToString(output_types), op_schema_str, fc.kernel_backend, - fc.kernel_file); + fc.kernel_file, + additiona_attrs); ob->out << ","; } catch (const std::exception& e) { LOG(WARNING) << "Exception in execution trace observer: [" << fc.name diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index f301596fca81..21e16a7e7eae 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -334,25 +334,22 @@ std::vector inputTypes(const at::RecordFunction& fn) { // ---------------------------------------------------------------------------- // -- NCCL Metadata ----------------------------------------------------------- // ---------------------------------------------------------------------------- -#ifdef USE_DISTRIBUTED -static constexpr auto kCommsName = "Collective name"; -static constexpr auto kDtype = "dtype"; -static constexpr auto kInMsgNelems = "In msg nelems"; -static constexpr auto kOutMsgNelems = "Out msg nelems"; -static constexpr auto kInSplit = "In split size"; -static constexpr auto kOutSplit = "Out split size"; -static constexpr auto kGlobalRankStart = "Global rank start"; -static constexpr auto kGlobalRankStride = "Global rank stride"; -static constexpr auto kGroupSize = "Group size"; -static constexpr auto kProcessGroupName = "Process Group Name"; -static constexpr auto kProcessGroupDesc = "Process Group Description"; -static constexpr auto kGroupRanks = "Process Group Ranks"; static constexpr int32_t kTruncatLength = 30; -#endif // USE_DISTRIBUTED + +template +inline std::string format_list(ListLikeType list, bool truncate) { + if (truncate && list.size() > kTruncatLength) { + return fmt::format( + "\"[{}, ...]\"", + fmt::join(list.begin(), list.begin() + kTruncatLength, ", ")); + } + return fmt::format("\"[{}]\"", fmt::join(list.begin(), list.end(), ", ")); +} std::unordered_map saveNcclMeta( - const at::RecordFunction& fn) { + const at::RecordFunction& fn, + bool truncate) { std::unordered_map map; #ifdef USE_DISTRIBUTED auto debugInfo = dynamic_cast( @@ -369,34 +366,13 @@ std::unordered_map saveNcclMeta( kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType()))); map.emplace(kInMsgNelems, std::to_string(debugInfo->getInMessageNelems())); map.emplace(kOutMsgNelems, std::to_string(debugInfo->getOutMessageNelems())); + auto& inSplitSizes = debugInfo->getInputSplitSizes(); - if (!inSplitSizes.empty() && inSplitSizes.size() <= kTruncatLength) { - map.emplace( - kInSplit, fmt::format("\"[{}]\"", fmt::join(inSplitSizes, ", "))); - } else if (inSplitSizes.size() > kTruncatLength) { - map.emplace( - kInSplit, - fmt::format( - "\"[{}, ...]\"", - fmt::join( - inSplitSizes.begin(), - inSplitSizes.begin() + kTruncatLength, - ", "))); - } + map.emplace(kInSplit, format_list(inSplitSizes, truncate)); + auto& outSplitSizes = debugInfo->getOutputSplitSizes(); - if (!outSplitSizes.empty() && outSplitSizes.size() <= kTruncatLength) { - map.emplace( - kOutSplit, fmt::format("\"[{}]\"", fmt::join(outSplitSizes, ", "))); - } else if (outSplitSizes.size() > kTruncatLength) { - map.emplace( - kOutSplit, - fmt::format( - "\"[{}, ...]\"", - fmt::join( - outSplitSizes.begin(), - outSplitSizes.begin() + kTruncatLength, - ", "))); - } + map.emplace(kOutSplit, format_list(outSplitSizes, truncate)); + auto globalRankStart = debugInfo->getGlobalRankStart(); if (globalRankStart >= 0) { map.emplace(kGlobalRankStart, std::to_string(globalRankStart)); @@ -415,20 +391,7 @@ std::unordered_map saveNcclMeta( map.emplace(kProcessGroupDesc, fmt::format("\"{}\"", group_desc)); } auto& groupRanks = debugInfo->getGroupRanks(); - if (!groupRanks.empty() && groupRanks.size() <= kTruncatLength) { - map.emplace( - kGroupRanks, fmt::format("\"[{}]\"", fmt::join(groupRanks, ", "))); - } else if (groupRanks.size() > kTruncatLength) { - map.emplace( - kGroupRanks, - fmt::format( - "\"[{}, ..., {}]\"", - fmt::join( - groupRanks.begin(), - groupRanks.begin() + kTruncatLength - 1, - ", "), - groupRanks.back())); - } + map.emplace(kGroupRanks, format_list(groupRanks, truncate)); #endif // USE_DISTRIBUTED return map; } diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index c8216c93f41c..3c995b49e602 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -100,7 +100,7 @@ TORCH_API std::vector inputTypes(const at::RecordFunction& fn); std::unordered_map TORCH_API saveExtraArgs(const at::RecordFunction& fn); std::unordered_map TORCH_API -saveNcclMeta(const at::RecordFunction& fn); +saveNcclMeta(const at::RecordFunction& fn, bool truncate = true); uint64_t TORCH_API computeFlops( const std::string& op_name, @@ -157,6 +157,21 @@ struct HashCombine { } }; +#ifdef USE_DISTRIBUTED +constexpr auto kCommsName = "Collective name"; +constexpr auto kDtype = "dtype"; +constexpr auto kInMsgNelems = "In msg nelems"; +constexpr auto kOutMsgNelems = "Out msg nelems"; +constexpr auto kInSplit = "In split size"; +constexpr auto kOutSplit = "Out split size"; +constexpr auto kGlobalRankStart = "Global rank start"; +constexpr auto kGlobalRankStride = "Global rank stride"; +constexpr auto kGroupSize = "Group size"; +constexpr auto kProcessGroupName = "Process Group Name"; +constexpr auto kProcessGroupDesc = "Process Group Description"; +constexpr auto kGroupRanks = "Process Group Ranks"; +#endif // USE_DISTRIBUTED + } // namespace impl } // namespace profiler } // namespace torch diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 180513413093..b9873b9950fa 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -48,6 +48,10 @@ _verify_param_shape_across_processes, _sync_module_states, ) +from torch.profiler import ( + ExecutionTraceObserver, + ProfilerActivity, +) from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision @@ -6867,7 +6871,20 @@ def test_ddp_grad_div_uneven_inputs(self): net.zero_grad() torch.cuda.synchronize(device=self.rank) - def _test_ddp_profiling(self, profiler_ctx): + def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None): + """Runs DDP based model training and captures profiles. + This test will do two profiler runs. + 1. An inital basic run to check if profiler events are correctly captured. + 2. A second profiling pass after running some iterations of DDP, to check robustness of thread local state. + + args + profiler_ctx : Profiler context manager for pass 1 + profiler_ctx2 : Profiler context manager for pass 2. + This can be left out as None, in which case a deepcopy + of profiler_ctx is used. + Returns: + prof: Instantiated profiler object that can be used for post analysis. + """ batch = 3 dim = 10 num_iters = 6 @@ -6878,7 +6895,8 @@ def _test_ddp_profiling(self, profiler_ctx): model.cuda(self.rank), device_ids=[self.rank], ) - profiler_ctx_copy = copy.deepcopy(profiler_ctx) + if profiler_ctx2 is None: + profiler_ctx2 = copy.deepcopy(profiler_ctx) with profiler_ctx as prof: for i in range(num_iters): @@ -6913,7 +6931,7 @@ def _test_ddp_profiling(self, profiler_ctx): loss = net(inp).sum() loss.backward() # Now enable the profiler. - with profiler_ctx_copy as prof: + with profiler_ctx2 as prof: loss = net(inp).sum() loss.backward() @@ -6971,6 +6989,90 @@ def test_ddp_profiling_torch_profiler(self): self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}") self.assertEqual(a1["dtype"], "Int", msg=f"{a1}") + def _validate_execution_trace_nccl(self, et_file: str) -> None: + """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" + We test for basic fields in theese nodes in the Execution Trace. + """ + with open(et_file) as f: + et = json.load(f) + + nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"] + self.assertEqual(len(nccl_meta_nodes), 3) + per_coll_meta = defaultdict(list) + + # Sanity check NCCL metadata nodes + for n in nccl_meta_nodes: + attrs_list = n.get("attrs", []) + self.assertGreater(len(attrs_list), 0) + attrs = {a["name"]: a["value"] for a in attrs_list} + + collname = attrs.get("collective_name", "") + self.assertNotEqual(collname, "") + self.assertNotEqual(attrs.get("dtype", ""), "") + + per_coll_meta[collname].append(attrs) + if collname in {"wait"}: + continue + + self.assertEqual(attrs["pg_name"], "0") # yes this is a string + self.assertEqual(attrs["pg_desc"], "default_pg") + self.assertEqual(attrs["pg_size"], 2) + + self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0) + self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0) + self.assertTrue("in_split_size" in attrs.keys()) + self.assertTrue("out_split_size" in attrs.keys()) + self.assertEqual(attrs.get("global_rank_start", -1), 0) + self.assertEqual(attrs.get("global_rank_stride", -1), 1) + + # print(per_coll_meta) + self.assertEqual(len(per_coll_meta["allreduce"]), 2) + self.assertEqual(len(per_coll_meta["wait"]), 1) + + # check allreduce message sizes + a0 = per_coll_meta["allreduce"][0] + self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}") + self.assertEqual(a0["dtype"], "Float", msg=f"{a0}") + a1 = per_coll_meta["allreduce"][1] + self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}") + self.assertEqual(a1["dtype"], "Int", msg=f"{a1}") + + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.") + def test_ddp_profiling_execution_trace(self): + self.assertEqual(dist.get_backend(), "nccl") + # Create a temp file to save execution trace data + fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) + fp.close() + et_file = fp.name + + et = ExecutionTraceObserver().register_callback(et_file) + + # first profiler context need not have ET + torch_profiler_ctx1 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) + # collect ET in second profiler pass + torch_profiler_ctx2 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + execution_trace_observer=et + ) + prof = self._test_ddp_profiling( + profiler_ctx=torch_profiler_ctx1, + profiler_ctx2=torch_profiler_ctx2, + ) + + print(f"Execution trace saved at {fp.name}") + self._validate_execution_trace_nccl(et_file) + + @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( BACKEND not in DistTestCases.backend_feature["ddp"], From f89500030bb6909569668ba47b9895780dc44052 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 May 2024 19:19:02 +0000 Subject: [PATCH 072/116] Revert "Remove redundant serialization code (#126249)" This reverts commit aab448e381366d4cf499145adffe9fcb1ac2b28d. Reverted https://github.com/pytorch/pytorch/pull/126249 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing sigmoid/frontend:serialization_test internally ([comment](https://github.com/pytorch/pytorch/pull/126249#issuecomment-2118233656)) --- torch/_export/serde/serialize.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 5ff02a787690..e0d76920157c 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -539,9 +539,21 @@ def export_nn_module_stack(val): path, ty = val assert isinstance(path, str) - assert isinstance(ty, str) - return path + "," + ty + # node.meta["nn_module_stack"] could have two forms: + # 1. (path: str, module_type: 'type'), e.g. + # ('', ) + # 2. (path: str, module_type: str), e.g. + # ('', 'sigmoid.inference.MySimpleModel') + # ExportedProgram directly produced by torch.export() has form 1 + # ExportedProgram deserialized from disk has form 2 + # TODO: This is not ideal, we should fix this. + if isinstance(ty, str): + normalized_ty = ty + else: + normalized_ty = ty.__module__ + "." + ty.__qualname__ + + return path + "," + normalized_ty # Serialize to "key,orig_path,type_str" nn_module_list = [ From 875221dedf9eeda567b2327336a2aedc681b9cdc Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 May 2024 19:30:02 +0000 Subject: [PATCH 073/116] Revert "Fix aarch64 debug build with GCC (#126290)" This reverts commit 91bf952d10e9524a9b078900d9807efa5d252f5c. Reverted https://github.com/pytorch/pytorch/pull/126290 on behalf of https://github.com/huydhn due to There seems to be a mis-match closing curly bracket here and it breaks some internal build in D57474505 ([comment](https://github.com/pytorch/pytorch/pull/126290#issuecomment-2118246756)) --- aten/src/ATen/native/cpu/int8mm_kernel.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp index 9eaf43ec5f00..bd266030b256 100644 --- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp @@ -250,19 +250,11 @@ inline void tinygemm_kernel_( }); } -#if __OPTIMIZE__ float32x4_t scale_val = load_as_float32x4(scales); c10::ForcedUnroll{}([&](auto i) { C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i); }); } -#else - // Workaround GCCs inability to infer lane index at compile time - // See https://github.com/pytorch/pytorch/issues/126283 - c10::ForcedUnroll{}([&](auto i) { - C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]); - }); -#endif } template From eb0b16db92b981a8fa277d2897d32037703b7838 Mon Sep 17 00:00:00 2001 From: Kwanghoon An Date: Fri, 17 May 2024 19:44:50 +0000 Subject: [PATCH 074/116] Initial implementation of AdaRound (#126153) Summary: This is an implementation of AdaRound from a paper https://arxiv.org/abs/2004.10568 This algorithm is going to be used by multiple people, hence we need make it official implementation. Differential Revision: D57227565 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126153 Approved by: https://github.com/jerryzh168, https://github.com/huydhn --- .../core/experimental/test_adaround_eager.py | 118 +++++++++ .../experimental/adaround_fake_quantize.py | 148 +++++++++++ .../experimental/adaround_loss.py | 96 +++++++ .../experimental/adaround_optimization.py | 238 ++++++++++++++++++ 4 files changed, 600 insertions(+) create mode 100644 test/quantization/core/experimental/test_adaround_eager.py create mode 100644 torch/ao/quantization/experimental/adaround_fake_quantize.py create mode 100644 torch/ao/quantization/experimental/adaround_loss.py create mode 100644 torch/ao/quantization/experimental/adaround_optimization.py diff --git a/test/quantization/core/experimental/test_adaround_eager.py b/test/quantization/core/experimental/test_adaround_eager.py new file mode 100644 index 000000000000..33a16f21bd0f --- /dev/null +++ b/test/quantization/core/experimental/test_adaround_eager.py @@ -0,0 +1,118 @@ +# Owner(s): ["oncall: speech_infra"] + +import copy + +import torch +import torch.nn as nn +from torch.ao.quantization.experimental.adaround_optimization import ( + AdaptiveRoundingOptimizer, +) + +from torch.nn import functional as F +from torch.quantization.observer import MinMaxObserver +from torch.testing._internal.common_quantization import QuantizationTestCase + + +def forward_wrapper(fetcher): + def forward(module, input, output): + fetcher.append(input[0].detach()) + fetcher.append(output.detach()) + + return forward + + +class TestAdaround(QuantizationTestCase): + def feedforawrd_callback( + self, + model, + data, + ) -> None: + model(data) + + def run_adaround(self, model, img_data): + adaround_optimizer = AdaptiveRoundingOptimizer( + model, + self.feedforawrd_callback, + forward_wrapper, + img_data, + max_iter=100, + batch_size=10, + ) + adarounded_model = adaround_optimizer.run_adaround() + return adarounded_model + + def get_fake_quant(self, model): + hard_fake_quant_model = copy.deepcopy(model) + for _, module in hard_fake_quant_model.named_modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): + weight_observer = MinMaxObserver( + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + ) + weight_observer(module.weight) + scale, zero_point = weight_observer.calculate_qparams() + fake_quant_module = torch.fake_quantize_per_tensor_affine( + module.weight, + scale=scale, + zero_point=zero_point, + quant_min=-128, + quant_max=127, + ) + module.weight.data.copy_(fake_quant_module) + return hard_fake_quant_model + + def test_linear_chain(self): + class LinearChain(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(3, 4) + self.linear2 = nn.Linear(4, 5) + self.linear3 = nn.Linear(5, 6) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + float_model = LinearChain() + img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)] + adarounded_model = self.run_adaround(float_model, img_data) + fq_model = self.get_fake_quant(float_model) + rand_input = torch.rand(10, 3) + with torch.no_grad(): + ada_out = adarounded_model(rand_input) + fq_out = fq_model(rand_input) + float_out = float_model(rand_input) + ada_loss = F.mse_loss(ada_out, float_out) + fq_loss = F.mse_loss(fq_out, float_out) + self.assertTrue(ada_loss.item() < fq_loss.item()) + + def test_conv_chain(self): + class ConvChain(nn.Module): + def __init__(self): + super().__init__() + self.conv2d1 = nn.Conv2d(3, 4, 5, 5) + self.conv2d2 = nn.Conv2d(4, 5, 5, 5) + self.conv2d3 = nn.Conv2d(5, 6, 5, 5) + + def forward(self, x): + x = self.conv2d1(x) + x = self.conv2d2(x) + x = self.conv2d3(x) + return x + + float_model = ConvChain() + img_data = [torch.rand(10, 3, 125, 125, dtype=torch.float) for _ in range(50)] + adarounded_model = self.run_adaround(float_model, img_data) + fq_model = self.get_fake_quant(float_model) + rand_input = torch.rand(10, 3, 256, 256) + with torch.no_grad(): + ada_out = adarounded_model(rand_input) + fq_out = fq_model(rand_input) + float_out = float_model(rand_input) + ada_loss = F.mse_loss(ada_out, float_out) + fq_loss = F.mse_loss(fq_out, float_out) + self.assertTrue(ada_loss.item() < fq_loss.item()) diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py new file mode 100644 index 000000000000..4d988bbb25bb --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py @@ -0,0 +1,148 @@ +from typing import Tuple + +import torch +from torch.ao.quantization.fake_quantize import _is_symmetric_quant +from torch.ao.quantization.utils import is_per_tensor +from torch.quantization import FakeQuantize +from torch.quantization.observer import MinMaxObserver + + +class AdaroundFakeQuantizer(FakeQuantize): + """ + This is a FakeQuantizer that enables an adaptive rounding fake quantizer. + Adaround is a technique to adaptively round weights, derived from the paper https://arxiv.org/pdf/2004.10568.pdf + For HTP compatibility, we are targeting to use symmetric quantization + """ + + scale: torch.Tensor + zero_point: torch.Tensor + V: torch.nn.Parameter + + # pyre-fixme[3]: Return type must be annotated. + def __init__( + self, + observer=MinMaxObserver, + qscheme=torch.per_tensor_symmetric, # not used, but needed for fakequant + quant_min: int = -128, + quant_max: int = 127, + ch_axis: int = 0, + # pyre-fixme[2]: Parameter must be annotated. + **observer_kwargs, + ): + super().__init__( + observer=observer, + qscheme=qscheme, + quant_min=quant_min, + quant_max=quant_max, + is_dynamic=False, + **observer_kwargs, + ) + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + assert ( + quant_min <= quant_max + ), "quant_min must be less than or equal to quant_max" + # pyre-fixme[4]: Attribute must be annotated. + self.qscheme = qscheme + self.is_per_tensor: bool = is_per_tensor(qscheme) + self.is_symmetric: bool = _is_symmetric_quant(qscheme) + assert self.is_symmetric, "Only symmetric quantization is supported" + self.ch_axis: int = ch_axis + + self.scale = torch.tensor([], requires_grad=False) + self.zero_point = torch.tensor([], requires_grad=False) + self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True) + # Fixed Stretch parameters + self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False) + self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False) + self.sigmoid = torch.nn.Sigmoid() + self.use_soft_rounding = True + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self) -> str: + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}" + ) + + def enable_weight_fake_quant(self) -> None: + self.fake_quant_enabled[0] = 1 + + def get_rectified_sigmoid_func(self) -> torch.Tensor: + if self.use_soft_rounding: + return torch.clamp( + self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma, + min=0, + max=1, + ) + else: + # This will dump a binary solution + return (self.V >= 0).int() + + @torch.jit.ignore + def update_scale( + self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor + ) -> None: + if self.scale.numel() == 0: + self.scale.data = _scale.to(X.device) + self.zero_point = _zero_point.to(X.device) + else: + self.scale.data = _scale + if not self.is_symmetric: + self.zero_point = _zero_point + else: + self.zero_point = torch.zeros_like(_zero_point) + for i in range(X.dim()): + if i == self.ch_axis: + continue + self.zero_point = self.zero_point.unsqueeze(i) + X_q = X / self.scale + X_q_floor = torch.floor(X_q) + residual = X_q - X_q_floor # [0,1) + assert torch.all( + torch.ge(residual, 0) + ), "residual should be non-negative [0, 1)" + V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1) + self.V.data = V_init + + def forward(self, X: torch.Tensor) -> torch.Tensor: + if self.observer_enabled[0] == 1: + X_detached = X.detach() + self.activation_post_process(X_detached) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( + self.zero_point.device + ) + dims = list(range(X.dim())) + if not self.is_per_tensor: + dims.remove(self.ch_axis) + if not self.is_per_tensor: + for i in range(X.dim()): + if i == self.ch_axis: + continue + _scale = _scale.unsqueeze(i) + _zero_point = _zero_point.unsqueeze(i) + self.update_scale(X_detached, _scale, _zero_point) + + if self.fake_quant_enabled[0] == 1: + # Perform soft quantization + # See the equation (23) in Adaround paper + h_v = self.get_rectified_sigmoid_func() + X_q = X / self.scale + # Straight-Through Estimator for floor function + X_q_floor = torch.floor(X_q) + self.zero_point + # Regardless of rounding, gradient should be able to flow back to self.V from X_q_dq. + # With adaround, we don't train weight, but train V only. + X_q_dq = ( + torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max) + - self.zero_point + ) * self.scale + return X_q_dq + else: + return X diff --git a/torch/ao/quantization/experimental/adaround_loss.py b/torch/ao/quantization/experimental/adaround_loss.py new file mode 100644 index 000000000000..8080d72cc6da --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_loss.py @@ -0,0 +1,96 @@ +from typing import Tuple + +import numpy as np +import torch +from torch.nn import functional as F + +ADAROUND_ZETA: float = 1.1 +ADAROUND_GAMMA: float = -0.1 + + +class AdaptiveRoundingLoss(torch.nn.Module): + """ + Adaptive Rounding Loss functions described in https://arxiv.org/pdf/2004.10568.pdf + rounding regularization is eq [24] + reconstruction loss is eq [25] except regularization term + """ + + def __init__( + self, + max_iter: int, + warm_start: float = 0.2, + beta_range: Tuple[int, int] = (20, 2), + reg_param: float = 0.001, + ) -> None: + super().__init__() + self.max_iter = max_iter + self.warm_start = warm_start + self.beta_range = beta_range + self.reg_param = reg_param + + def rounding_regularization( + self, + V: torch.Tensor, + curr_iter: int, + ) -> torch.Tensor: + """ + Major logics copied from official Adaround Implementation. + Apply rounding regularization to the input tensor V. + """ + assert ( + curr_iter < self.max_iter + ), "Current iteration strictly les sthan max iteration" + if curr_iter < self.warm_start * self.max_iter: + return torch.tensor(0.0) + else: + start_beta, end_beta = self.beta_range + warm_start_end_iter = self.warm_start * self.max_iter + + # compute relative iteration of current iteration + rel_iter = (curr_iter - warm_start_end_iter) / ( + self.max_iter - warm_start_end_iter + ) + beta = end_beta + 0.5 * (start_beta - end_beta) * ( + 1 + np.cos(rel_iter * np.pi) + ) + + # A rectified sigmoid for soft-quantization as formualted [23] in https://arxiv.org/pdf/2004.10568.pdf + h_alpha = torch.clamp( + torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA, + min=0, + max=1, + ) + + # Apply rounding regularization + # This regularization term helps out term to converge into binary solution either 0 or 1 at the end of optimization. + inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta) + regularization_term = torch.add(1, -inner_term).sum() + return regularization_term * self.reg_param + + def reconstruction_loss( + self, + soft_quantized_output: torch.Tensor, + original_output: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the reconstruction loss between the soft quantized output and the original output. + """ + return F.mse_loss( + soft_quantized_output, original_output, reduction="none" + ).mean() + + def forward( + self, + soft_quantized_output: torch.Tensor, + original_output: torch.Tensor, + V: torch.Tensor, + curr_iter: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the asymmetric reconstruction formulation as eq [25] + """ + regularization_term = self.rounding_regularization(V, curr_iter) + reconstruction_term = self.reconstruction_loss( + soft_quantized_output, original_output + ) + return regularization_term, reconstruction_term diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py new file mode 100644 index 000000000000..7304f885a6f3 --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -0,0 +1,238 @@ +import copy +import logging +from typing import Any, Callable, List, Optional, Tuple, Type, Union + +import torch +from torch.ao.quantization.experimental.adaround_fake_quantize import ( + AdaroundFakeQuantizer, +) +from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss +from torch.ao.quantization.observer import MinMaxObserver +from torch.nn import functional as F +from torch.nn.parallel import DataParallel +from torch.utils.data import DataLoader, TensorDataset + +logger: logging.Logger = logging.getLogger(__name__) + + +class AdaptiveRoundingOptimizer: + def __init__( + self, + model: Union[torch.nn.Module, torch.nn.DataParallel], + callback: Callable[[torch.nn.Module, List[Any]], None], + forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable], + data: List[Any], + observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver, + max_iter=10000, + dtype: torch.dtype = torch.qint8, + quant_min=-128, + quant_max=127, + qscheme: torch.qscheme = torch.per_tensor_symmetric, + batch_size: int = 256, + ): + self.model = model + self.q_model = copy.deepcopy(self.model) + self.device = torch.device("cuda") if torch.cuda.is_available() else None + self.callback = callback + self.forward_hook_wrapper = forward_hook_wrapper + # TODO rather than having a data as list type or, we better pass *iterator* instead of list + self.data = data + self.batch_size = min(batch_size, len(data)) + self.max_iter = max_iter + self.adaptive_round_loss_fn = AdaptiveRoundingLoss( + max_iter=self.max_iter, warm_start=0.2 + ) + self.dtype = dtype + self.observer = observer + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + + def run_adaround(self) -> torch.nn.Module: + layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = [] + for (name, module), q_module in zip( + self.model.named_modules(), self.q_model.modules() + ): + if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)): + # Knowing activation ahead-of-time would be helpful for asymmetric formulation + # But this is challenging in eager mode, but graph module. + layer_list.append((name, module, q_module)) + logger.info(f"Total number of layers : {len(layer_list)}") # noqa: G004 + + for name, module, q_module in layer_list: + logger.info( + f"Kick start adaptive rounding on {name} module {module}" # noqa: G004 + ) + self.optimize_adaptive_rounding( + module, + q_module, + None, + ) + + return ( + self.q_model.module + if isinstance(self.q_model, DataParallel) + else self.q_model + ) + + def get_data_inp_out( + self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + fp_out: List[torch.Tensor] = [] + q_input: List[torch.Tensor] = [] + fp_input: List[torch.Tensor] = [] + fp32_fetcher: List[torch.Tensor] = [] + quant_fetcher: List[torch.Tensor] = [] + handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher)) + handler2 = q_module.register_forward_hook( + self.forward_hook_wrapper(quant_fetcher) + ) + for data_ in data: + with torch.no_grad(): + self.callback(self.model, data_) + self.callback(self.q_model, data_) + fp32_output = fp32_fetcher[1] + quant_input = quant_fetcher[0] + fp_out.append(fp32_output) + q_input.append(quant_input) + fp_input.append(fp32_fetcher[0]) + handler1.remove() + handler2.remove() + return q_input, fp_out, fp_input + + @torch.no_grad() + def feed_forward(self, x, weight, module): + if isinstance(module, torch.nn.Conv1d): + out = torch.nn.functional.conv1d( + x, + weight, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + elif isinstance(module, torch.nn.Linear): + out = torch.nn.functional.linear( + x, + weight, + bias=module.bias, + ) + else: + raise NotImplementedError + return out + + def _compute_and_display_local_losses( + self, + ada_quantizer: AdaroundFakeQuantizer, + q_module: torch.nn.Module, + q_inp: torch.Tensor, + fp_out: torch.Tensor, + ): + with torch.no_grad(): + ada_quantizer.use_soft_rounding = False + q_w_hard_round = ada_quantizer(q_module.weight) + out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module) + ada_quantizer.use_soft_rounding = True + q_w_soft_round = ada_quantizer(q_module.weight) + out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module) + soft_quant_loss = F.mse_loss(out_soft_quant, fp_out) + hard_quant_loss = F.mse_loss(out_hard_quant, fp_out) + logger.info( + f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004 + ) + + def optimize_adaptive_rounding( + self, + module: torch.nn.Module, + q_module: torch.nn.Module, + activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> None: + ada_quantizer = AdaroundFakeQuantizer( + dtype=self.dtype, + observer=self.observer, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + reduce_range=False, + ) + ada_quantizer.enable_observer() + ada_quantizer(q_module.weight) + ada_quantizer.disable_observer() + ada_quantizer.enable_fake_quant() + optimizer = torch.optim.Adam([ada_quantizer.V]) + inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data) + + logger.info("==================== Before adaround ====================") + test_in, test_out, fp_test_in = self.get_data_inp_out( + module, q_module, self.data[0] + ) + + assert ( + torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0 + ), "In-placed activation is detected, please do not use activation in-placed" + # Stack the tensors in each list into a single tensor + # Assuming inp and out are your lists of tensors + inp_tensor = torch.vstack(inp) + out_tensor = torch.vstack(out) + dataset = TensorDataset(inp_tensor, out_tensor) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + self._compute_and_display_local_losses( + ada_quantizer, q_module, test_in[0], test_out[0] + ) + global_idx = 0 + one_iter = len(out) // self.batch_size + for iteration in range(self.max_iter // one_iter): + reconstruction_loss = regularization_loss = torch.tensor(0) + for q_inp, fp_out in dataloader: + optimizer.zero_grad() + q_weight = ada_quantizer(q_module.weight) + if isinstance(module, torch.nn.Conv1d): + q_out = torch.nn.functional.conv1d( + q_inp, + q_weight, + stride=q_module.stride, + padding=q_module.padding, + dilation=q_module.dilation, + groups=q_module.groups, + ) + elif isinstance(q_module, torch.nn.Linear): + q_out = torch.nn.functional.linear( + q_inp, + q_weight, + bias=q_module.bias, + ) + else: + raise NotImplementedError + regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn( + fp_out, + q_out, + ada_quantizer.V, + curr_iter=global_idx, + ) + loss = regularization_loss + reconstruction_loss + loss.backward() + optimizer.step() + global_idx += 1 + if global_idx >= self.max_iter: + break + if global_idx >= self.max_iter: + break + if iteration % 30 == 0: + logger.info( + f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004 + f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004 + ) + logger.info("==================== After adaround ====================") + self._compute_and_display_local_losses( + ada_quantizer, q_module, test_in[0], test_out[0] + ) + + ada_quantizer.use_soft_rounding = True + ada_quantizer.V.requires_grad = False + ada_quantizer = ada_quantizer.eval() + q_weight = ada_quantizer(q_module.weight) + # At the end of optimization, we need to copy the adarounded weight back to the original module + q_module.weight.data.copy_(q_weight) + # Eager mode requires observer to be set as "weight_fake_quant" to be parsed + q_module.weight_fake_quant = ada_quantizer.activation_post_process From 90a5aeea79fce59b9ddfa717803a01164ecbffbe Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 17 May 2024 19:44:57 +0000 Subject: [PATCH 075/116] [distributed] Add cpp-httplib to pytorch (#126470) Adds https://github.com/yhirose/cpp-httplib such that we are able to use https for host to host communication in distributed (specifically torchrun) Todo: We likely need to add cpp-httplib somewhere in the build (cmake/bazel) but first we should write the code for it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126470 Approved by: https://github.com/d4l3k, https://github.com/Skylion007 --- .gitmodules | 4 ++++ third_party/cpp-httplib | 1 + 2 files changed, 5 insertions(+) create mode 160000 third_party/cpp-httplib diff --git a/.gitmodules b/.gitmodules index 4443eace838d..0d9a339fb53e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -148,3 +148,7 @@ [submodule "third_party/opentelemetry-cpp"] path = third_party/opentelemetry-cpp url = https://github.com/open-telemetry/opentelemetry-cpp.git +[submodule "third_party/cpp-httplib"] + path = third_party/cpp-httplib + url = git@github.com:yhirose/cpp-httplib.git + branch = v0.15.3 diff --git a/third_party/cpp-httplib b/third_party/cpp-httplib new file mode 160000 index 000000000000..3b6597bba913 --- /dev/null +++ b/third_party/cpp-httplib @@ -0,0 +1 @@ +Subproject commit 3b6597bba913d51161383657829b7e644e59c006 From 95b2766864a02cdb6bc49b97a7f7e859b7167e70 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 17 May 2024 19:52:08 +0000 Subject: [PATCH 076/116] [BE][Ez]: Use NotADirectoryError in tensorboard writer (#126534) Slightly improve exception typing for tensorboard wrriter Pull Request resolved: https://github.com/pytorch/pytorch/pull/126534 Approved by: https://github.com/ezyang --- torch/utils/tensorboard/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index cd281bc0d3fc..c646ce0c0c11 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -916,7 +916,7 @@ def add_embedding( "warning: Embedding dir exists, did you set global_step for add_embedding()?" ) else: - raise FileExistsError( + raise NotADirectoryError( f"Path: `{save_path}` exists, but is a file. Cannot proceed." ) else: From d782e43464bc939429054e3373af8cf80554fc45 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 May 2024 20:29:20 +0000 Subject: [PATCH 077/116] Revert "[FSDP2] Fixed 2D clip grad norm test (#126497)" This reverts commit 3f289063117673650db868c978bf3cb8125a22dc. Reverted https://github.com/pytorch/pytorch/pull/126497 on behalf of https://github.com/jeanschmidt due to reverting to check if might have introduced inductor cuda 12 issues ([comment](https://github.com/pytorch/pytorch/pull/126497#issuecomment-2118338716)) --- .ci/pytorch/test.sh | 1 - .../fsdp/test_fully_shard_clip_grad_norm_.py | 36 +++++------ torch/testing/_internal/common_fsdp.py | 61 +++++++------------ 3 files changed, 37 insertions(+), 61 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 1953b314ec83..9483bb630d4e 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -326,7 +326,6 @@ test_inductor_distributed() { python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_frozen.py --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose - python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py -k test_clip_grad_norm_2d --verbose python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index 4e1e897a11be..9139b62f1367 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -12,7 +12,7 @@ from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest, MLPStack +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, @@ -30,12 +30,12 @@ def _test_clip_grad_norm( ref_optim: torch.optim.Optimizer, model: nn.Module, optim: torch.optim.Optimizer, - inp: torch.Tensor, dp_mesh: Optional[DeviceMesh] = None, ): vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type) dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,)) torch.manual_seed(42 + dp_mesh.get_local_rank() + 1) + inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda") for iter_idx in range(10): ref_optim.zero_grad() ref_model(inp).sum().backward() @@ -53,11 +53,11 @@ def _test_clip_grad_norm( continue self.assertEqual(ref_grad, param.grad.full_tensor()) - # Check that at least one gradient has norm greater than the max - # norm before clipping to ensure the clipping is not vacuous - self.assertTrue(any(vector_norm_fn(g).item() > max_norm for g in ref_grads)) + # Check that all gradients have norm greater than the max norm + # before clipping to ensure the clipping is not vacuous + self.assertTrue(all(vector_norm_fn(g).item() > max_norm for g in ref_grads)) self.assertTrue( - any(vector_norm_fn(g).item() > max_norm for g in local_grads) + all(vector_norm_fn(g).item() > max_norm for g in local_grads) ) # Check gradient norm clipping via total norm and individual @@ -111,10 +111,7 @@ def test_clip_grad_norm_1d(self): fully_shard(module) fully_shard(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) - inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda") - self._test_clip_grad_norm( - 1, norm_type, ref_model, ref_optim, model, optim, inp - ) + self._test_clip_grad_norm(1, norm_type, ref_model, ref_optim, model, optim) class TestClipGradNormWorldSize4(_TestClipGradNormBase): @@ -133,23 +130,20 @@ def test_clip_grad_norm_2d(self): ) dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"] torch.manual_seed(42) - # Test using an MLP stack, not a transformer, since the transformer - # has some more significant numeric differences from the TP - model = MLPStack(16, with_seq_parallel=True) + model_args = ModelArgs(dropout_p=0.0) + model = Transformer(model_args) ref_model = replicate( copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group() ) ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) - model.parallelize( - tp_mesh, - dp_mesh, - use_activation_checkpointing=False, - reshard_after_forward=True, - ) + model = Transformer.parallelize(model, tp_mesh, use_seq_parallel=True) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard(module, mesh=dp_mesh) + fully_shard(model, mesh=dp_mesh) optim = torch.optim.Adam(model.parameters(), lr=1e-2) - inp = torch.randn(2, 16, device="cuda") self._test_clip_grad_norm( - 0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh + 1, norm_type, ref_model, ref_optim, model, optim, dp_mesh ) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 94b6a68f931c..283982b2ba44 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -10,17 +10,7 @@ from copy import deepcopy from enum import auto, Enum from functools import partial, wraps -from typing import ( - Any, - Callable, - Dict, - List, - no_type_check, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, no_type_check, Optional, Tuple, Type, Union from unittest import mock import torch @@ -49,7 +39,6 @@ ColwiseParallel, parallelize_module, RowwiseParallel, - SequenceParallel, ) from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP @@ -876,17 +865,15 @@ def reset_parameters(self): class MLPStack(nn.Sequential): - def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False): - modules: List[nn.Module] = [ + def __init__(self, mlp_dim: int): + modules = [ + nn.LayerNorm(mlp_dim, bias=False), # Use multiplier of 3 to exercise uneven case MLP(mlp_dim, dim_multiplier=3), MLP(mlp_dim), MLP(mlp_dim, dim_multiplier=3), ] - if with_seq_parallel: - modules.append(nn.LayerNorm(mlp_dim, bias=False)) super().__init__(*modules) - self.with_seq_parallel = with_seq_parallel def parallelize( self, @@ -895,29 +882,25 @@ def parallelize( use_activation_checkpointing: bool, reshard_after_forward: bool, ) -> "MLPStack": - parallelize_plan = { - # Pass `use_local_output=False` to keep as DTensor to preserve - # uneven activation dims - "0.in_proj": ColwiseParallel(use_local_output=False), - "0.out_proj": RowwiseParallel(use_local_output=False), - "1.in_proj": ColwiseParallel(use_local_output=False), - "1.out_proj": RowwiseParallel(use_local_output=False), - "2.in_proj": ColwiseParallel(use_local_output=False), - "2.out_proj": RowwiseParallel(output_layouts=Shard(1)) - if self.with_seq_parallel - else RowwiseParallel(), - } - if self.with_seq_parallel: - parallelize_plan["3"] = SequenceParallel(sequence_dim=1) - parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan) - for module in self: - if isinstance(module, nn.LayerNorm): - continue + parallelize_module( + self, + device_mesh=tp_mesh, + # Leave the layer norm as implicitly replicated + parallelize_plan={ + # Pass `use_local_output=False` to keep as DTensor to preserve + # uneven activation dims + "1.in_proj": ColwiseParallel(use_local_output=False), + "1.out_proj": RowwiseParallel(use_local_output=False), + "2.in_proj": ColwiseParallel(use_local_output=False), + "2.out_proj": RowwiseParallel(use_local_output=False), + "3.in_proj": ColwiseParallel(use_local_output=False), + "3.out_proj": RowwiseParallel(), + }, + ) + for mlp in self: if use_activation_checkpointing: - checkpoint(module) - fully_shard( - module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward - ) + checkpoint(mlp) + fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward) fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward) return self From 30b70b1a6398c15b5861b8f4e4199dd44ec31114 Mon Sep 17 00:00:00 2001 From: Peter Y Yeh Date: Fri, 17 May 2024 20:36:47 +0000 Subject: [PATCH 078/116] [ROCm] enable faster_load_save for Fused_SGD (#125456) Reopen due to rebase error. Fixes https://github.com/pytorch/pytorch/issues/117599 The reported hang test : `test_cuda.py::TestCuda::test_grad_scaling_autocast_fused_optimizers` is passing with this PR HSA Async copy / host wait on completion signal is resolved in MultiTensorApply.cuh ``` :4:command.cpp :347 : 8881368803196 us: [pid:1268211 tid:0x7f5af80d7180] Command (InternalMarker) enqueued: 0xc4e2070 :4:rocvirtual.cpp :556 : 8881368803201 us: [pid:1268211 tid:0x7f5af80d7180] Host wait on completion_signal=0x7f5967df3e00 :3:rocvirtual.hpp :66 : 8881368803209 us: [pid:1268211 tid:0x7f5af80d7180] Host active wait for Signal = (0x7f5967df3e00) for -1 ns ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125456 Approved by: https://github.com/jeffdaily, https://github.com/eqy, https://github.com/janeyx99 --- aten/src/ATen/native/cuda/FusedSgdKernel.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu index 61da02ce0b88..e644b1048c9b 100644 --- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -86,12 +86,8 @@ struct FusedSgdMathFunctor { init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; -#ifndef USE_ROCM const auto use_faster_load_store = (n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned; -#else - const auto use_faster_load_store{false}; -#endif if (use_faster_load_store) { for (auto i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; From bed1c600bbfa4552a1c0b71466c1220a52388220 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Thu, 16 May 2024 11:28:24 -0700 Subject: [PATCH 079/116] Experimental prototype for converting torch.jit.trace modules to export (#124449) Differential Revision: [D56440613](https://our.internmc.facebook.com/intern/diff/D56440613) We want to do this for following reasons: 1. There is current limitation in export tracing for torch.jit.trace d modules that cannot be easily upstreamed 2. We need to run internal CI regularly to understand feature gaps and continuously track them 3. Multiple people will be working on this prototype so it is better to have a checked in version so we don't always run into merge conflicts. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124449 Approved by: https://github.com/angelayi, https://github.com/avikchaudhuri --- test/export/test_experimental.py | 31 ++++++++++++ torch/_export/non_strict_utils.py | 43 +++++++++++------ torch/export/_trace.py | 80 ++++++++++++++++++++++++++----- torch/export/dynamic_shapes.py | 23 +++++---- 4 files changed, 141 insertions(+), 36 deletions(-) diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 2d7e88bfc111..b343dbff27a7 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -8,6 +8,7 @@ from torch._export.wrappers import _mark_strict_experimental from torch._functorch.aot_autograd import aot_export_module +from torch.export._trace import _convert_ts_to_export_experimental from torch.testing import FileCheck @@ -106,6 +107,36 @@ def forward(self, x): ): ep = torch.export.export(M(), inp, strict=False) + def test_torchscript_module_export(self): + class M(torch.nn.Module): + def forward(self, x): + return x.cos() + x.sin() + + model_to_trace = M() + inps = (torch.randn(4, 4),) + traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) + + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps))) + + def test_torchscript_module_export_single_input(self): + class M(torch.nn.Module): + def forward(self, x): + return x.cos() + x.sin() + + model_to_trace = M() + inps = torch.randn(4, 4) + traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) + + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps))) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 42c20aa55500..aff3d444c960 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -110,7 +110,9 @@ def make_fake_params_buffers( return faked_params_buffers # type: ignore[return-value] -def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): +def make_fake_inputs( + nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=False +): """ Given an nn module, example inputs, and constraints, return a new fake mode, fake inputs created in that mode whose dynamic shape dimensions are constrained @@ -127,7 +129,7 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): # - [post-tracing] guards.py processes input shape equalities. constraints = torch.export.dynamic_shapes._process_dynamic_shapes( - nn_module, args, kwargs, dynamic_shapes + nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace ) constraints = constraints or [] t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) @@ -136,13 +138,6 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): if constraint.shared is not None: t_constraints[constraint.shared.t_id][constraint.shared.dim] = constraint - code = nn_module.forward.__code__ - co_fields = { - "co_name": code.co_name, - "co_filename": code.co_filename, - "co_firstlineno": code.co_firstlineno, - } - context = torch._guards.TracingContext.try_get() if context is not None: # This occurs when we are exporting within dynamo. There already exists @@ -153,11 +148,22 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): len(constraints) == 0 ), "Found constraints when tracing with a toplevel tracing context." fake_mode = context.fake_mode - else: + elif not _is_torch_jit_trace: + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } fake_mode = FakeTensorMode( shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields), allow_non_fake_inputs=True, ) + else: + fake_mode = FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[]), + allow_non_fake_inputs=True, + ) if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: raise ValueError( "Detected fake_mode does not have a shape_env with tracked fakes. " @@ -166,7 +172,11 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): ) with fake_mode: - original_signature = inspect.signature(nn_module.forward) + # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock + if not _is_torch_jit_trace: + original_signature = inspect.signature(nn_module.forward) + else: + original_signature = None sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) fake_args, fake_kwargs = tree_map_with_path( lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), @@ -215,6 +225,7 @@ def produce_guards_and_solve_constraints( equalities_inputs: EqualityConstraint, original_signature: inspect.Signature, _disable_forced_specializations: Optional[bool] = False, + _is_torch_jit_trace=False, ): """ Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, @@ -259,9 +270,13 @@ def produce_guards_and_solve_constraints( ) dim_constraints.remove_redundant_dynamic_results() forced_specializations = dim_constraints.forced_specializations() - msg = dim_constraints.prettify_results( - original_signature, constraint_violation_error, forced_specializations - ) + if not _is_torch_jit_trace: + msg = dim_constraints.prettify_results( + original_signature, constraint_violation_error, forced_specializations + ) + else: + # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod + msg = "dummy constraint violation message" if constraint_violation_error: constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) elif forced_specializations: diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 55c7442380f9..1c2f3b880f35 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -478,6 +478,7 @@ def _export_non_strict( transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, should_insert_runtime_assertion=False, + _is_torch_jit_trace=False, ): # [NOTE] If the user is exporting under training mode, we want to detect if there is any # state change in the autograd global state and error. If the user is exporting under inference @@ -632,16 +633,18 @@ def make_argument_spec(i, node) -> ArgumentSpec: constants = rewrite_script_object_meta(gm) constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) - # prettify names for placeholder nodes - placeholder_naming_pass( - gm, - export_graph_signature, - mod, - fake_args, - fake_kwargs, - fake_params_buffers, - constants, - ) + # FIXME: Skipping this because traced modules do not have signature yet + if not _is_torch_jit_trace: + # prettify names for placeholder nodes + placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) @dataclasses.dataclass class _ExportedProgramNonStrict: @@ -889,6 +892,48 @@ def wrapper(*args, **kwargs): return wrapper +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + torch._C._jit_set_texpr_fuser_enabled(False) + + def process_trace_inputs_for_export(example_inputs, example_kwarg_inputs): + if not isinstance(example_inputs, tuple): + example_inputs = (example_inputs,) + + if example_kwarg_inputs is None: + example_kwarg_inputs = {} + return example_inputs, example_kwarg_inputs + + class _WrapperModule(torch.nn.Module): + def __init__(self, f): + super().__init__() + self.f = f + + def forward(self, *args, **kwargs): + return self.f(*args, **kwargs) + + from torch.jit._trace import TopLevelTracedModule + + export_args, export_kwargs = process_trace_inputs_for_export(args, kwargs) + + if isinstance(traced_callable, TopLevelTracedModule): + return _export( + traced_callable, + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + else: + return _export( + _WrapperModule(traced_callable), + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + @_log_export_wrapper @_disable_prexisiting_fake_mode def _export( @@ -901,6 +946,7 @@ def _export( preserve_module_call_signature: Tuple[str, ...] = (), pre_dispatch: bool = False, _disable_forced_specializations: Optional[bool] = False, + _is_torch_jit_trace: bool = False, ) -> ExportedProgram: """ Traces either an nn.Module's forward function or just a callable with PyTorch @@ -969,7 +1015,10 @@ def _export( flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs)) original_state_dict = mod.state_dict(keep_vars=True) - forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + if not _is_torch_jit_trace: + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + else: + forward_arg_names = None if not strict: out_spec = None @@ -1048,7 +1097,9 @@ def forward(self, *args, **kwargs): fake_kwargs, equalities_inputs, original_signature, - ) = make_fake_inputs(mod, args, kwargs, dynamic_shapes) + ) = make_fake_inputs( + mod, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace + ) fake_params_buffers = make_fake_params_buffers( fake_mode, _get_params_buffers(mod) @@ -1071,6 +1122,7 @@ def forward(self, *args, **kwargs): pre_dispatch=pre_dispatch, transform=_tuplify_outputs, should_insert_runtime_assertion=not strict, + _is_torch_jit_trace=_is_torch_jit_trace, ) # ep_non_strict.constants contains only fake script objects, we need to map them back ep_non_strict.constants = { @@ -1099,6 +1151,7 @@ def forward(self, *args, **kwargs): equalities_inputs, original_signature, _disable_forced_specializations=_disable_forced_specializations, + _is_torch_jit_trace=_is_torch_jit_trace, ) except (ConstraintViolationError, ValueRangeError) as e: raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 @@ -1148,7 +1201,8 @@ def forward(self, *args, **kwargs): _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) _verify_nn_module_stack(gm) _verify_stack_trace(gm) - _verify_placeholder_names(gm, ep_non_strict.sig) + if not _is_torch_jit_trace: + _verify_placeholder_names(gm, ep_non_strict.sig) exported_program = ExportedProgram( root=gm, graph=gm.graph, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index f3fecb1043fb..a4ed16e975b8 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -602,18 +602,20 @@ def f(t, *dynamic_shapes): return tree_map(f, tree, *dynamic_shapes, is_leaf=is_leaf) -def _combine_args(f, args, kwargs): +def _combine_args(f, args, kwargs, _is_torch_jit_trace=False): # combine args and kwargs following the signature of f, as it happens # in the body of f when called with *args, **kwargs if isinstance(f, ExportedProgram): f = f.module() - signature = ( - inspect.signature(f.forward) - if isinstance(f, torch.nn.Module) - else inspect.signature(f) - ) - kwargs = kwargs if kwargs is not None else {} - return signature.bind(*args, **kwargs).arguments + if not _is_torch_jit_trace: + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments + return args class ShapesCollection: @@ -692,6 +694,7 @@ def _process_dynamic_shapes( args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + _is_torch_jit_trace=False, ) -> Optional[List[Constraint]]: from torch._dynamo.exc import UserError, UserErrorType @@ -858,7 +861,9 @@ def assoc_shape(t, dynamic_shape): _tree_map(assoc_shape, combined_args, dynamic_shapes) - combined_args = _combine_args(f, args, kwargs) + combined_args = _combine_args( + f, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace + ) if not isinstance(dynamic_shapes, dict): assert isinstance(dynamic_shapes, (tuple, list)) combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] From 09fd77148527d81240c9040ce6f9a7868b22136e Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 17 May 2024 21:11:07 +0000 Subject: [PATCH 080/116] Disable vulkan test batch_norm_invalid_inputs (#126571) Fails flakily ex https://github.com/pytorch/pytorch/actions/runs/9130802617/job/25109131748 https://github.com/pytorch/pytorch/actions/runs/9125548571/job/25092535707 First bad I can find is https://hud.pytorch.org/pytorch/pytorch/commit/538877d2046a492a1112101e2d5d88e5754d477b Pull Request resolved: https://github.com/pytorch/pytorch/pull/126571 Approved by: https://github.com/SS-JIA --- aten/src/ATen/test/vulkan_api_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 687691a370bf..e41d3d3d6abe 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -788,7 +788,7 @@ TEST_F(VulkanAPITest, avg_pool2d) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { +TEST_F(VulkanAPITest, DISABLED_batch_norm_invalid_inputs) { c10::InferenceMode mode; // Act: Vulkan batchnorm only supports evaluation mode From c26f6548f9f3bdebf77aed061e3772d077b3ea0d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 17 May 2024 21:42:19 +0000 Subject: [PATCH 081/116] [AOTI] config target platform (#126306) Test Plan: AOTI compile stories15M for Android Differential Revision: D57392830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126306 Approved by: https://github.com/desertfire --- torch/_inductor/codecache.py | 8 ++++++++ torch/_inductor/config.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 0850b4a94bdc..70b467143111 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1961,6 +1961,14 @@ def _compile_consts_linux(consts: bytes) -> str: return consts_o def _compile_consts_darwin(consts: bytes) -> str: + if config.aot_inductor.debug_dump_consts_bin: + _, _binary_constants_path = write( + consts, + "bin", + specified_dir=specified_output_path, + ) + log.debug("binary constants path: %s", _binary_constants_path) + is_large_consts = len(consts) > 1024 consts_asm = "\t.section\t__DATA,__data\n" consts_asm += "\t.globl\t__binary_constants_bin_start\n" diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 7f14dc62de93..db8a6d9ae3b6 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -727,6 +727,10 @@ class aot_inductor: debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + debug_dump_consts_bin: bool = ( + os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1" + ) + # Serialized tree spec for flattening inputs serialized_in_spec = "" From d7de4c9d809697b36ae0fd9e16815f6e3b4d985b Mon Sep 17 00:00:00 2001 From: Yihan He Date: Fri, 17 May 2024 21:50:55 +0000 Subject: [PATCH 082/116] Fix issue of lowering nn.linear ops with kwargs (#126331) Summary: Support kwarg bias for nn.linear quantization Differential Revision: D57403190 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126331 Approved by: https://github.com/ZhengkaiZ, https://github.com/huydhn --- torch/ao/quantization/fx/_lower_to_native_backend.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 728506037b55..049f4e3135d9 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -926,8 +926,14 @@ def _lower_dynamic_weighted_ref_functional( # Linear prepack args: (quantized weights[, bias]) # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) prepack_args = [quantized_weight] + remaining_func_args + prepack_kwargs = {} if func_node.target == F.linear: prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + kwargs = func_node.kwargs.copy() + if 'bias' in kwargs: + prepack_kwargs['B'] = kwargs['bias'] + del kwargs['bias'] + func_node.kwargs = kwargs elif func_node.target in CONV_FUNCTIONAL_OPS: prepack_op = get_qconv_prepack_op(func_node.target) # For conv1d, the stride, padding, and dilation args may be ints, @@ -939,7 +945,7 @@ def _lower_dynamic_weighted_ref_functional( else: raise ValueError(f"Lowering is not supported for op '{func_node.target}'") with model.graph.inserting_before(func_node): - packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), {}) + packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), prepack_kwargs) # Step 3: Replace reference pattern with the corresponding quantized op func_node.target = q_relu_func if relu_node is not None else q_func From faa26df72e2a3ff08f9dd564bb50756916826854 Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Thu, 16 May 2024 23:00:21 +0000 Subject: [PATCH 083/116] [inductor] Load python modules using importlib (#126454) The `compile` + `exec` workflow is susceptible to behavior drifting from a "normal" import use importlib instead to avoid this. In particular here annotations were being stored as strings due to `from __futures__ import annotations` in the scope calling `compile`. Triton cares about annotations on global variables and this makes it much easier to reliably code-gen them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126454 Approved by: https://github.com/peterbell10 --- torch/_inductor/runtime/compile_tasks.py | 29 ++++++++++++------------ 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 66a36703da45..b29a95f64b6c 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -1,10 +1,10 @@ from __future__ import annotations import functools +import importlib import os import sys import warnings -from types import ModuleType from typing import Any, Callable @@ -31,19 +31,20 @@ def _reload_python_module_in_subproc(key, path): def _reload_python_module(key, path): - with open(path) as f: - try: - code = compile(f.read(), path, "exec") - except Exception as e: - raise RuntimeError( - f"Failed to import {path}\n{type(e).__name__}: {e}" - ) from None - mod = ModuleType(f"{__name__}.{key}") - mod.__file__ = path - mod.key = key # type: ignore[attr-defined] - exec(code, mod.__dict__, mod.__dict__) - sys.modules[mod.__name__] = mod - return mod + spec = importlib.util.spec_from_file_location(f"{__name__}.{key}", path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Failed to import {path}: path not found") + module = importlib.util.module_from_spec(spec) + module.key = key # type: ignore[attr-defined] + try: + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + + sys.modules[module.__name__] = module + return module @functools.lru_cache(None) From 173b1d811d430cc87f36cb16478d9eb1d57e93ed Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 17 May 2024 11:20:11 -0700 Subject: [PATCH 084/116] [dynamo] Sourceless builder - ordered dict and re.pattern (#126468) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126468 Approved by: https://github.com/Skylion007 --- torch/_dynamo/variables/builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 8f9ab01088a7..41b9fbd836ae 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2379,6 +2379,8 @@ def create(tx, value) -> VariableTracker: return PlacementVariable(value) elif DeviceMeshVariable.is_device_mesh(value): return DeviceMeshVariable(value) + elif isinstance(value, re.Pattern): + return RegexPatternVariable(value) unimplemented( f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}" ) @@ -2399,6 +2401,7 @@ def make_type_handlers(): ) handlers[dict] = lambda tx, value: ConstDictVariable( {create(tx, k): create(tx, v) for k, v in value.items()}, + type(value), mutable_local=MutableLocal(), ) handlers[list] = lambda tx, value: ListVariable( @@ -2410,6 +2413,7 @@ def make_type_handlers(): handlers[torch.Size] = lambda tx, value: SizeVariable( [create(tx, x) for x in value] ) + handlers[collections.OrderedDict] = handlers[dict] handlers[immutable_dict] = handlers[dict] handlers[immutable_list] = handlers[list] handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value) From d54c28e7fcb6377aa007c0d3e53837683420af4b Mon Sep 17 00:00:00 2001 From: Martim Mendes Date: Fri, 17 May 2024 23:41:43 +0000 Subject: [PATCH 085/116] Added error checks for invalid inputs on thnn_conv2d (#121906) Fixes #121188 Prevent Segmentation Fault in 'torch._C._nn.thnn_conv2d' Previously, calling 'torch._C._nn.thnn_conv2d' with invalid arguments for padding, stride, and kernel_size would result in a segmentation fault. This issue has been resolved by implementing argument validation (using Torch Check). Now, when invalid arguments are detected, a runtime error is raised with a debug message detailing the correct format. Additionally, this commit includes tests to cover the three referenced cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121906 Approved by: https://github.com/janeyx99 --- aten/src/ATen/native/ConvolutionMM2d.cpp | 5 +++++ test/test_nn.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/aten/src/ATen/native/ConvolutionMM2d.cpp b/aten/src/ATen/native/ConvolutionMM2d.cpp index 686948584c72..10ab4a70f091 100644 --- a/aten/src/ATen/native/ConvolutionMM2d.cpp +++ b/aten/src/ATen/native/ConvolutionMM2d.cpp @@ -543,6 +543,11 @@ Tensor& slow_conv2d_forward_out_cpu( IntArrayRef padding, Tensor& output) { // See [Note: hacky wrapper removal for optional tensor] + + TORCH_CHECK(kernel_size.size() == 2, "2D kernel_size expected"); + TORCH_CHECK(stride.size() == 2, "2D stride expected"); + TORCH_CHECK(padding.size() == 2, "2D padding expected"); + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; diff --git a/test/test_nn.py b/test/test_nn.py index 008354ad721e..76bc614f025d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8205,6 +8205,16 @@ def help(input, conv, memory_format): weight = torch.empty([1, 0, 1], dtype=dtype, device=device) torch._C._nn.slow_conv3d(inp, weight, 1) + with self.assertRaisesRegex(RuntimeError, re.escape("2D kernel_size expected")): + torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[], padding=[1, 1], stride=[1, 1], + weight=torch.rand([1, 1])) + with self.assertRaisesRegex(RuntimeError, re.escape("2D stride expected")): + torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[1, 1], stride=[], + weight=torch.rand([1, 1])) + with self.assertRaisesRegex(RuntimeError, re.escape("2D padding expected")): + torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[], stride=[1, 1], + weight=torch.rand([1, 1])) + def test_InstanceNorm1d_general(self, device): b = random.randint(3, 5) c = random.randint(3, 5) From 93844a31b30c6b99fd18c37ee466637cca708d14 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 17 May 2024 23:47:08 +0000 Subject: [PATCH 086/116] Fix aarch64 debug build with GCC (#126290) By working around GCCs quirks in instantiating templates that require immediate values. Provide alternative implementation for scaling the output if compiled without any optimizations (both GCC and clang define `__OPTIMIZE__` if invoked with anything but `-O0`) Test plan (after change was reverted): ssh into aarch64 runner and rebuild given file with `-O0` Fixes https://github.com/pytorch/pytorch/issues/126283 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126290 Approved by: https://github.com/atalman, https://github.com/seemethere --- aten/src/ATen/native/cpu/int8mm_kernel.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp index bd266030b256..d61a1933afc7 100644 --- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp @@ -250,10 +250,18 @@ inline void tinygemm_kernel_( }); } +#if __OPTIMIZE__ float32x4_t scale_val = load_as_float32x4(scales); c10::ForcedUnroll{}([&](auto i) { C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i); }); +#else + // Workaround GCCs inability to infer lane index at compile time + // See https://github.com/pytorch/pytorch/issues/126283 + c10::ForcedUnroll{}([&](auto i) { + C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]); + }); +#endif } } From 54bc55c515109b7162703610c589148cb1306d21 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 16 May 2024 18:19:09 -0700 Subject: [PATCH 087/116] Remove dist_ prefix from TORCH_LOGS shortcuts (#126499) e.g. dist_ddp -> ddp 'distributed' shortcut remains unchained Feedback has been that it is not appealing to have the dist_ prefix, and the main reason for it was to keep the distributed shortcuts grouped together in the help menu. It's nice to have shorter shortcuts. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126499 Approved by: https://github.com/XilunWu, https://github.com/kwen2501 ghstack dependencies: #126322 --- torch/_logging/_registrations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index d76b5610e97e..10463b864f44 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -26,13 +26,13 @@ register_log("torch", "torch") register_log("distributed", DISTRIBUTED) register_log( - "dist_c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"] + "c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"] ) register_log( - "dist_ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] + "ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] ) -register_log("dist_pp", ["torch.distributed.pipelining"]) -register_log("dist_fsdp", ["torch.distributed.fsdp"]) +register_log("pp", ["torch.distributed.pipelining"]) +register_log("fsdp", ["torch.distributed.fsdp"]) register_log("onnx", "torch.onnx") register_log("export", ["torch._dynamo", "torch.export", *DYNAMIC]) From 0d5ba547ec1923c27f45516ae509388b8c7aff0c Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Sat, 18 May 2024 00:10:46 +0000 Subject: [PATCH 088/116] Tool for scouting exportability in one shot (#126471) Summary: Tool for scouting exportability issues in one shot. - Collect sample inputs for all submodules by running eager inference with forward_pre_hook. - Start from root module, recursively try exporting child modules, if current module export fails. Limitations: - only works for nn.module that contains tree-like submodules structure. this doesn't work for flatten GraphModule. TODO: support dynamic_dims Sample output: https://docs.google.com/spreadsheets/d/1jnixrqBTYbWO_y6AaKA13XqOZmeB1MQAMuWL30dGoOg/edit?usp=sharing ``` exportability_report = { '': UnsupportedOperatorException(func=), 'submod_1': UnsupportedOperatorException(func=), 'submod_2': None } ``` Test Plan: buck2 run mode/dev-nosan fbcode//caffe2/test:test_export -- -r TestExportTools Differential Revision: D57466486 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126471 Approved by: https://github.com/zhxchen17 --- test/export/test_tools.py | 67 ++++++++++++++++++ torch/_export/tools.py | 139 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 test/export/test_tools.py create mode 100644 torch/_export/tools.py diff --git a/test/export/test_tools.py b/test/export/test_tools.py new file mode 100644 index 000000000000..b8ab7616fd67 --- /dev/null +++ b/test/export/test_tools.py @@ -0,0 +1,67 @@ +# Owner(s): ["oncall: export"] + +import torch +from torch._dynamo.test_case import TestCase +from torch._export.tools import report_exportability + +from torch.testing._internal.common_utils import run_tests + +torch.library.define( + "testlib::op_missing_meta", + "(Tensor(a!) x, Tensor(b!) z) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, +) + + +@torch.library.impl("testlib::op_missing_meta", "cpu") +@torch._dynamo.disable +def op_missing_meta(x, z): + x.add_(5) + z.add_(5) + return x + z + + +class TestExportTools(TestCase): + def test_report_exportability_basic(self): + class Module(torch.nn.Module): + def forward(self, x, y): + return x[0] + y + + f = Module() + inp = ([torch.ones(1, 3)], torch.ones(1, 3)) + + report = report_exportability(f, inp) + self.assertTrue(len(report) == 1) + self.assertTrue(report[""] is None) + + def test_report_exportability_with_issues(self): + class Unsupported(torch.nn.Module): + def forward(self, x): + return torch.ops.testlib.op_missing_meta(x, x.cos()) + + class Supported(torch.nn.Module): + def forward(self, x): + return x.sin() + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.unsupported = Unsupported() + self.supported = Supported() + + def forward(self, x): + y = torch.nonzero(x) + return self.unsupported(y) + self.supported(y) + + f = Module() + inp = (torch.ones(4, 4),) + + report = report_exportability(f, inp, strict=False, pre_dispatch=True) + + self.assertTrue(report[""] is not None) + self.assertTrue(report["unsupported"] is not None) + self.assertTrue(report["supported"] is None) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_export/tools.py b/torch/_export/tools.py new file mode 100644 index 000000000000..d76392993bd2 --- /dev/null +++ b/torch/_export/tools.py @@ -0,0 +1,139 @@ +import logging +import warnings +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +import torch.export +import torch.export._trace +from torch._utils_internal import log_export_usage + +log = logging.getLogger(__name__) + +__all__ = ["report_exportability"] + + +def _generate_inputs_for_submodules( + model: torch.nn.Module, + target_submodules: Iterable[str], + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, Tuple[Any, Any]]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + kwargs = kwargs or {} + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_args, module_kwargs): + results[submodule_to_names[module]] = (module_args, module_kwargs) + + try: + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append( + mod.register_forward_pre_hook(pre_forward, with_kwargs=True) + ) + model(*args, **kwargs) + except Exception as e: + warnings.warn( + f"Failed to generate submodule inputs because of the following error:\n{e}" + ) + finally: + for h in handles: + h.remove() + return results + + +def report_exportability( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + strict: bool = True, + pre_dispatch: bool = False, +) -> Dict[str, Optional[Exception]]: + """ + Report exportability issues for a module in one-shot. + + Args: + mod: root module. + args: args to the root module. + kwargs: kwargs to the root module. + Returns: + A dict that maps from submodule name to the exception that was raised when trying to export it. + `None` means the module is exportable without issue. + Sample output: + { + '': UnsupportedOperatorException(func=), + 'submod_1': UnsupportedOperatorException(func=), + 'submod_2': None + } + """ + + log_export_usage(event="export.report_exportability") + + kwargs = kwargs or {} + + all_submod_names = [name for name, _ in mod.named_modules() if name != ""] + submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) + + report: Dict[str, Optional[Exception]] = {} + + def try_export(module, module_name, args, kwargs): + nonlocal submod_inputs, report, strict, pre_dispatch + + if args is not None or kwargs is not None: + try: + torch.export._trace._export( + module, + args, + kwargs, + strict=strict, + pre_dispatch=pre_dispatch, + ) + report[module_name] = None + log.info("Successfully exported `%s`", module_name) + return + except Exception as e: + short_msg = repr(e).split("\n")[0] + log.warning( + "Failed exporting `%s` with exception: %s", module_name, short_msg + ) + report[module_name] = e + + for name, submod in module.named_children(): + sub_module_name = name if module_name == "" else f"{module_name}.{name}" + + submod_args, submod_kwargs = submod_inputs.get( + sub_module_name, (None, None) + ) + + try_export(submod, sub_module_name, submod_args, submod_kwargs) + + return + + try_export(mod, "", args, kwargs) + + unique_issues = set() + for exception in report.values(): + if exception is not None: + key = repr(exception).split("\\n")[0] + unique_issues.add(key) + + log.warning("Found %d export issues:", len(unique_issues)) + for issue in unique_issues: + log.warning(issue) + + return report From 2863c76b1fe4c28721935ae63f637c1ea5de0e47 Mon Sep 17 00:00:00 2001 From: Kostas Tsiampouris Date: Sat, 18 May 2024 00:17:13 +0000 Subject: [PATCH 089/116] [torch-distributed] Make log directory creation idempotent (#126496) Summary: https://docs.python.org/3/library/os.html#os.makedirs > If exist_ok is False (the default), a FileExistsError is raised if the target directory already exists. Test Plan: Existing tests Differential Revision: D57471577 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126496 Approved by: https://github.com/d4l3k --- torch/distributed/elastic/multiprocessing/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 4c584dc32e70..eb0b110f25ee 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -246,7 +246,7 @@ def __init__( if not log_dir: log_dir = tempfile.mkdtemp(prefix="torchelastic_") elif not os.path.exists(log_dir): - os.makedirs(log_dir) + os.makedirs(log_dir, exist_ok=True) else: if os.path.isfile(log_dir): raise NotADirectoryError(f"log_dir: {log_dir} is a file") From 41fb4bcc73caf12d0c975931e2c3088448db7fbe Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Sat, 18 May 2024 00:39:42 +0000 Subject: [PATCH 090/116] [AOTI] Flag to include aoti sources when building lite interpreter (#126572) Summary: Added USE_LITE_AOTI cmake flag, which is turned OFF by default. When it is turned on, the AOTI sources (inductor_core_resources) are included when building lite interpreter Test Plan: ``` ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DUSE_LITE_AOTI=ON ``` Differential Revision: D57394078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126572 Approved by: https://github.com/malfet --- CMakeLists.txt | 3 +++ caffe2/CMakeLists.txt | 3 +++ 2 files changed, 6 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1925bd8636f4..02cf8dedc79e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -279,6 +279,9 @@ endif() option(USE_SLEEF_FOR_ARM_VEC256 "Use sleef for arm" OFF) option(USE_SOURCE_DEBUG_ON_MOBILE "Enable" ON) option(USE_LITE_INTERPRETER_PROFILER "Enable" ON) +cmake_dependent_option( + USE_LITE_AOTI "Include AOTI sources" OFF + "BUILD_LITE_INTERPRETER" OFF) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) # option USE_XNNPACK: try to enable xnnpack by default. diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index bd2588b5aef3..369bb9b106a0 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -457,6 +457,9 @@ if(BUILD_LITE_INTERPRETER) append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS) list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS}) list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_PROFILER_SRCS}) + if(USE_LITE_AOTI) + append_filelist("inductor_core_resources" LIBTORCH_CMAKE_SRCS) + endif() set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) else() append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS) From bcee6f708afcf437c85dfa8d06997bebd5044fd0 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 17 May 2024 12:05:21 -0700 Subject: [PATCH 091/116] [Pipelining] Fix 1f1b schedule (#126419) This schedule was running fine locally but failing (hanging) on CI. After analysis (https://fburl.com/gdoc/xt80h1gd), it seems like the schedule was not correct previously but may still work depending on the runtime. The fix bundles together fwd-recv(s->s+1) and bwd-send(s+1->s) into one coalesced group so they would not block each other. Design drawing image Flight recorder traces show the same coalescing pattern as designed image Pull Request resolved: https://github.com/pytorch/pytorch/pull/126419 Approved by: https://github.com/c-p-i-o, https://github.com/kwen2501 --- .../pipelining/PipelineSchedule.py | 62 +++++++++++++++---- 1 file changed, 51 insertions(+), 11 deletions(-) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index bf4116688bff..c8a256299bd5 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -411,34 +411,71 @@ def _step_microbatches( fwd_sends_to_wait: List[dist.Work] = [] bwd_sends_to_wait: List[dist.Work] = [] + def is_forward_step(i): + assert i >= 0, i + return i < self._n_microbatches + + def is_backward_step(i): + assert i < total_steps, i + return i >= warmup_steps and self._has_backward + + def is_1f1b_step(i): + return is_forward_step(i) and is_backward_step(i) + + def is_warmup_step(i): + return is_forward_step(i) and not is_backward_step(i) + + def is_cooldown_step(i): + return not is_forward_step(i) and is_backward_step(i) + + def should_coalesce_fwd_send_bwd_recv(fwd_send_i): + return ( + is_1f1b_step(fwd_send_i) + or (is_warmup_step(fwd_send_i) and is_cooldown_step(fwd_send_i + 1)) + or ( + fwd_send_i >= 1 + and is_warmup_step(fwd_send_i - 1) + and is_cooldown_step(fwd_send_i) + ) + ) + + def should_coalesce_bwd_send_fwd_recv(bwd_send_i): + # The backward send to prev stage should be coalesced with the fwd recv from the previous stage + return bwd_send_i >= warmup_steps and is_1f1b_step(bwd_send_i + 1) + # bwd chunk counter bwd_mb_index = 0 self._stage._configure_data_parallel_mode(last_backward=False) for i in range(total_steps): - if i < self._n_microbatches: - # forward + if is_forward_step(i): with record_function(f"Forward {i}"): ops = self._stage.get_fwd_recv_ops() + if should_coalesce_bwd_send_fwd_recv(i - 1): + ops.extend(self._stage.get_bwd_send_ops()) + works = sorted_batch_isend_irecv(ops) for work in works.values(): work.wait() output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] - ops = self._stage.get_fwd_send_ops() - works = sorted_batch_isend_irecv(ops) - fwd_sends_to_wait.extend(works.values()) + if not should_coalesce_fwd_send_bwd_recv(i): + ops = self._stage.get_fwd_send_ops() + works = sorted_batch_isend_irecv(ops) + fwd_sends_to_wait.extend(works.values()) self._maybe_compute_loss(self._stage, output, target_mbs, i) - if i >= warmup_steps and self._has_backward: + if is_backward_step(i): self._stage._configure_data_parallel_mode( last_backward=(i == total_steps - 1) ) - - # backward with record_function(f"Backward {bwd_mb_index}"): ops = self._stage.get_bwd_recv_ops() + + if should_coalesce_fwd_send_bwd_recv(i): + ops.extend(self._stage.get_fwd_send_ops()) + works = sorted_batch_isend_irecv(ops) for work in works.values(): work.wait() @@ -446,9 +483,12 @@ def _step_microbatches( loss = self._maybe_get_loss(self._stage, bwd_mb_index) self._stage.backward_one_chunk(loss=loss) - ops = self._stage.get_bwd_send_ops() - works = sorted_batch_isend_irecv(ops) - bwd_sends_to_wait.extend(works.values()) + if not should_coalesce_bwd_send_fwd_recv(i): + # see Note: coalesced bwd-send/fwd-recv + ops = self._stage.get_bwd_send_ops() + works = sorted_batch_isend_irecv(ops) + bwd_sends_to_wait.extend(works.values()) + bwd_mb_index += 1 # Wait for all forward sends to finish From 224f2bef9f80368ed54a67942d008ef64c5b201a Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 17 May 2024 12:05:22 -0700 Subject: [PATCH 092/116] [C10D] Add __repr__ to P2POp class (#126538) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126538 Approved by: https://github.com/Skylion007, https://github.com/kwen2501, https://github.com/c-p-i-o ghstack dependencies: #126419 --- torch/distributed/distributed_c10d.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 21d7abd93837..70283cada928 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -408,6 +408,21 @@ def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, _check_single_tensor(tensor, "tensor") return object.__new__(cls) + def __repr__(self): + my_group_rank = get_rank(self.group) + peer_group_rank = get_group_rank(self.group, self.peer) if self.group else self.peer + op_name = self.op.__name__ + group_name = self.group.group_name if self.group else "default_pg" + if "send" in op_name: + s = my_group_rank + d = peer_group_rank + elif "recv" in op_name: + s = peer_group_rank + d = my_group_rank + else: + return super().__repr__() + + return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" class _CollOp: """ From 661ecedbd0478cb5b20b4d7a1e7aa454c218abf1 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sat, 18 May 2024 01:31:23 +0000 Subject: [PATCH 093/116] gitmodules: switch cpp-httplib to https (#126580) Fixes issue introduced in https://github.com/pytorch/pytorch/pull/126470#issuecomment-2118374811 Test plan: CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/126580 Approved by: https://github.com/PaliC, https://github.com/jeffdaily --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 0d9a339fb53e..811be7863194 100644 --- a/.gitmodules +++ b/.gitmodules @@ -150,5 +150,5 @@ url = https://github.com/open-telemetry/opentelemetry-cpp.git [submodule "third_party/cpp-httplib"] path = third_party/cpp-httplib - url = git@github.com:yhirose/cpp-httplib.git + url = https://github.com/yhirose/cpp-httplib.git branch = v0.15.3 From 1191168c452a8131fcf5ff020fdb5940d7dca7bc Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 14 May 2024 14:47:42 -0700 Subject: [PATCH 094/116] [pipelining] Follow improvements in export.unflatten (#126217) Previously, we make a copy of `torch.export.unflatten` in pippy/_unflatten.py. But it turns out to be too hard to track bug fixes and improvements in upstream version. For example, `torch.export.unflatten` recently added support for tied parameters, which is something pipelining needs. Now that we moved into pytorch, we make a reference to `torch.export.unflatten` instead of maintaining a copy. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126217 Approved by: https://github.com/H-Huang --- test/distributed/pipelining/test_unflatten.py | 7 +- torch/distributed/pipelining/_IR.py | 7 +- torch/distributed/pipelining/_unflatten.py | 522 +----------------- 3 files changed, 9 insertions(+), 527 deletions(-) diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 9c388d279cdf..37eaf599e4d8 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -45,7 +45,6 @@ def test_unflatten(self): constant = torch.ones(1, 16, 256, 256) mod = M() - print("Original model:\n", mod) pipe = pipeline( mod, @@ -58,21 +57,19 @@ def test_unflatten(self): orig_state_dict = mod.state_dict() # Check qualnames - print("\nParameters of each stage:") for stage_idx in range(pipe.num_stages): - print(f"\nStage {stage_idx}:") stage_mod = pipe.get_stage_module(stage_idx) for param_name, param in stage_mod.named_parameters(): assert ( param_name in orig_state_dict ), f"{param_name} not in original state dict" - print(f"{param_name}: {param.size()}") + print("Param qualname test passed") # Check equivalence ref = mod(x, constant) out = pipe(x, constant)[0] torch.testing.assert_close(out, ref) - print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") + print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") if __name__ == "__main__": diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 1beb75e9df66..aeb1a676c99e 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -11,12 +11,13 @@ import torch import torch.fx as fx from torch.export import ExportedProgram +from torch.export.unflatten import _assign_attr, _AttrKind, _sink_params from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module from ._backward import _null_coalesce_accumulate, stage_backward from ._debug import PIPPY_VERBOSITY -from ._unflatten import _assign_attr, _AttrKind, _outline_submodules, _sink_params +from ._unflatten import _outline_submodules from ._utils import QualnameMapMixin from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec @@ -869,8 +870,8 @@ def move_param_to_callee( # After moving the params to their corresponding hierarchies, we also # need to move the `get_attr` nodes from the root of the graph to those # hierarchies. - inputs_to_state: Dict[str, str] = { - attr.name: attr.target for attr in attr_nodes + inputs_to_state: Dict[str, List[str]] = { + attr.name: [attr.target] for attr in attr_nodes } # This is done by (1) `_sind_params` at each submodule; for name, submod in split.named_children(): diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index 684fcfbc1d6d..27241d17874c 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -1,453 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -# This file is a copy of private utilities in pytorch/torch/export/unflatten.py -# pylint: skip-file - -import copy -import operator -from enum import Enum -from typing import cast, Dict, List, Optional, Union +from typing import Dict import torch -import torch.fx._pytree as fx_pytree -import torch.utils._pytree as pytree -from torch.export.exported_program import ( - ConstantArgument, - ModuleCallSignature, - SymIntArgument, - TensorArgument, -) -from torch.export.unflatten import InterpreterModule - - -class _AttrKind(Enum): - PARAMETER = "parameter" - BUFFER = "buffer" - CONSTANT = "constant" - - -# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module -# This installs empty Modules where none exist yet if they are subpaths of target -def _assign_attr( - from_obj: Union[torch.Tensor, torch.ScriptObject], - to_module: torch.nn.Module, - target: str, - attr_kind: _AttrKind, - persistent: bool = True, -): - *prefix, field = target.split(".") - for item in prefix: - t = getattr(to_module, item, None) - - if t is None: - t = torch.nn.Module() - setattr(to_module, item, t) - to_module = t - - if attr_kind == _AttrKind.PARAMETER: - assert isinstance(from_obj, torch.nn.Parameter) - to_module.register_parameter(field, from_obj) - elif attr_kind == _AttrKind.BUFFER: - assert isinstance(from_obj, torch.Tensor) - to_module.register_buffer(field, from_obj, persistent=persistent) - elif attr_kind == _AttrKind.CONSTANT: - assert isinstance(from_obj, (torch.Tensor, torch.ScriptObject)) - setattr(to_module, field, from_obj) - - -def _is_prefix(candidate, target): - """Check whether `candidate` is a prefix of `target`.""" - return len(candidate) < len(target) and target[: len(candidate)] == candidate - - -def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: - if parent_fqn == "": - # Handle the root module correctly. - return child_fqn - - parent_split = parent_fqn.split(".") - child_split = child_fqn.split(".") - - assert ( - child_split[: len(parent_split)] == parent_split - ), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'" - return ".".join(child_split[len(parent_split) :]) - - -def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): - def graph_dump(graph: torch.fx.Graph) -> str: - ret = [] - nodes_idx: Dict[int, int] = {} - - def arg_dump(arg) -> str: - if isinstance(arg, torch.fx.Node): - return "%" + str(nodes_idx[id(arg)]) - return str(arg) - - for i, node in enumerate(graph.nodes): - args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] - args_dump += [ - f"{key}={value}" - for key, value in pytree.tree_map(arg_dump, node.kwargs).items() - ] - target = node.target if node.op == "call_function" else "" - ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") - nodes_idx[id(node)] = i - return "\n".join(ret) - - assert graph_dump(x.graph) == graph_dump(y.graph) - - -def _add_spec(gm: torch.nn.Module, spec) -> str: - i = 0 - while hasattr(gm, f"_spec_{i}"): - i += 1 - name = f"_spec_{i}" - setattr(gm, name, spec) - return name - - -def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node: - name = _add_spec(gm, spec) - spec_node = gm.graph.get_attr(name) - return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) - - -def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node: - name = _add_spec(gm, spec) - spec_node = gm.graph.get_attr(name) - return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) - - -def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module): - *prefix, field = target.split(".") - - for item in prefix: - submod = getattr(mod, item, None) - - if submod is None: - submod = torch.nn.Module() - setattr(mod, item, submod) - - if not isinstance(submod, torch.nn.Module): - return False - - mod = submod - - mod.add_module(field, module_to_add) - - -class _ModuleFrame: - def __init__( - self, - flat_graph, - nodes, - seen_nodes, - seen_modules, - parent, - module_stack, - module_id, - module_call_graph: Optional[Dict[str, ModuleCallSignature]] = None, - module: Optional[torch.nn.Module] = None, - ): - self.flat_graph = flat_graph - self.nodes = nodes - self.seen_nodes = seen_nodes - self.seen_modules = seen_modules - self.parent = parent - self.module_stack = module_stack - self.module_id = module_id - - self.module_call_graph = module_call_graph - self.verbose = False - - self.fqn = self.module_stack[-1] - if module is not None: - self.module = module - else: - self.module = InterpreterModule(torch.fx.Graph()) - if self.module_id in self.seen_modules: - self.cached_graph_module = self.seen_modules[self.module_id] - else: - self.cached_graph_module = None - self.seen_modules[self.module_id] = self.module - - self.graph = self.module.graph - - # Mapping of nodes in the flat graph to nodes in this graph. - self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {} - self.node_to_placeholder = {} - - self.parent_call_module: Optional[torch.fx.Node] = None - if parent is not None: - accessor = _compute_accessor(parent.fqn, self.fqn) - _add_submodule( - parent.module, - accessor, - self.module - if self.cached_graph_module is None - else self.cached_graph_module, - ) - self.parent_call_module = parent.graph.call_module(accessor) - - signature = self.get_signature() - - if signature is not None and self.parent is not None: - assert signature.in_spec.num_children == 2 - args_spec = signature.in_spec.children_specs[0] - kwargs_spec = signature.in_spec.children_specs[1] - assert args_spec.context is None - assert kwargs_spec.context is not None - - with self.graph.inserting_after(None): - arg_nodes = [] - for idx in range(args_spec.num_children): - arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}")) - kwarg_nodes = {} - for name in kwargs_spec.context: - kwarg_nodes[name] = self.graph.placeholder(name) - flat_args = _generate_flatten( - self.module, - (tuple(arg_nodes), kwarg_nodes), - signature.in_spec, - ) - for idx, arg in enumerate(signature.inputs): - flat_arg_node = self.graph.create_node( - op="call_function", - target=operator.getitem, - args=(flat_args, idx), - name=arg.name - if not isinstance(arg, ConstantArgument) - else f"_constant_{idx}", - ) - if isinstance(arg, ConstantArgument): - continue - flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node - - with self.parent.graph.inserting_before(self.parent_call_module): - input_nodes: List[Optional[torch.fx.Node]] = [] - for input in signature.inputs: - if isinstance(input, ConstantArgument) and input.value is None: - input_nodes.append(None) - else: - assert isinstance(input, (TensorArgument, SymIntArgument)) - input_nodes.append( - self.parent.remap_input(self.seen_nodes[input.name]) - ) - - inputs_node = _generate_unflatten( - self.parent.module, - input_nodes, - signature.in_spec, - ) - - args_node = self.parent.graph.call_function( - operator.getitem, (inputs_node, 0) - ) - kwargs_node = self.parent.graph.call_function( - operator.getitem, (inputs_node, 1) - ) - arg_nodes = [ - self.parent.graph.call_function(operator.getitem, (args_node, i)) - for i in range(args_spec.num_children) - ] - kwarg_nodes = { - k: self.parent.graph.call_function( - operator.getitem, (kwargs_node, k) - ) - for k in kwargs_spec.context - } - assert self.parent_call_module is not None - self.parent_call_module.args = tuple(arg_nodes) - self.parent_call_module.kwargs = kwarg_nodes - - def add_placeholder(self, x): - assert x.graph is self.flat_graph - # x is not in subgraph, create a new placeholder for subgraph - with self.graph.inserting_before(None): - placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) - # copy all meta fields, even if some fields might be irrelvant for - # the placeholder node - placeholder_node.meta = copy.copy(x.meta) - self.node_to_placeholder[x] = placeholder_node - - def remap_input(self, x): - assert x.graph is self.flat_graph - if x in self.node_map: - return self.node_map[x] - if x not in self.node_to_placeholder: - self.add_placeholder(x) - if self.parent_call_module is not None: - # Important to *prepend* the output to match how we are - # inserting placeholder nodes. - self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) - return self.node_to_placeholder[x] - - def get_signature(self): - if self.module_call_graph is not None: - return self.module_call_graph.get(self.fqn) - return None - - def finalize_outputs(self): - orig_outputs = [] - signature = self.get_signature() - - if signature is not None and self.parent is not None: - for output in signature.outputs: - if isinstance(output, (TensorArgument, SymIntArgument)): - orig_outputs.append(self.seen_nodes[output.name]) - else: - raise RuntimeError( - f"Unsupported data type for output node: {output}" - ) - - tree_out_node = _generate_unflatten( - self.module, - tuple( - self.node_map[self.seen_nodes[output.name]] - for output in orig_outputs - ), - signature.out_spec, - ) - parent_out: Optional[torch.fx.Node] = _generate_flatten( - self.parent.module, self.parent_call_module, signature.out_spec - ) - graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node - else: - graph_outputs = [] - # Iterate through nodes we have copied into self.graph. - for orig_node in self.node_map.keys(): - for user_node in orig_node.users: - if user_node.name not in self.seen_nodes: - # external user node, need to expose as an output - orig_outputs.append(orig_node) - graph_outputs.append(self.node_map[orig_node]) - break - - parent_out = self.parent_call_module - if len(graph_outputs) == 1: - graph_outputs = graph_outputs[0] - - assert isinstance(graph_outputs, (list, torch.fx.Node)) - - self.graph.output(graph_outputs) - - # Rewrite outputs in parent module - if parent_out is None: - return - - parent_out.meta["val"] = ( - graph_outputs.meta.get("val") - if isinstance(graph_outputs, torch.fx.Node) - else [o.meta.get("val") for o in graph_outputs] - ) - - if len(orig_outputs) == 1 and signature is None: - self.parent.node_map[orig_outputs[0]] = parent_out - else: - for i, orig_output in enumerate(orig_outputs): - # Use Proxy to record getitem access. - proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] - proxy_out.meta["val"] = orig_output.meta.get("val") - self.parent.node_map[orig_output] = proxy_out - - if self.cached_graph_module is not None: - _verify_graph_equivalence(self.cached_graph_module, self.module) - - def copy_node(self, node): - self.print("copying", node.format_node()) - self.node_map[node] = self.graph.node_copy(node, self.remap_input) - self.seen_nodes[node.name] = node - - def run_outer(self): - i = 0 - for node in self.flat_graph.nodes: - self.print(i, node.meta.get("nn_module_stack"), node.format_node()) - i += 1 - - # Copy all graph inputs - node_idx: int = 0 - node = self.nodes[node_idx] - while node.op == "placeholder": - self.copy_node(node) - node_idx += 1 - node = self.nodes[node_idx] - - self.run_from(node_idx) - - # Copy graph outputs - for node in self.flat_graph.nodes: - if node.op == "output": - self.copy_node(node) - - def print(self, *args, **kwargs): - if self.verbose: - print(*args, **kwargs) - - def run_from(self, node_idx): - module_idx = 0 - # Walk through the graph, building up a new graph with the right submodules - while node_idx < len(self.nodes): - node = self.nodes[node_idx] - assert node.op != "placeholder" - - self.print() - self.print("STEP", node_idx, node.format_node()) - self.print(self.module_stack) - if node.op == "output": - if len(self.module_stack) == 1: - # We want the output node of the original graph to be handled - # specially by the outermost stack frame (in run_outer). So - # skip finalization here. - return node_idx - - # We've reached the end of the graph. Wrap up all the existing stack frames. - self.finalize_outputs() - return node_idx - - node_module_stack = ( - [path for path, ty in node.meta["nn_module_stack"].values()] - if "nn_module_stack" in node.meta - else self.module_stack - ) - if node_module_stack[: len(self.module_stack)] != self.module_stack: - # This means that the current module is done executing and the - # current node is the beginning of a new module. - # - # In this case, we should finalize this module and return without - # incrementing the node counter. - self.finalize_outputs() - self.print("outlining", self.fqn) - self.print(self.graph) - return node_idx - - assert node_module_stack is not None - - if _is_prefix(self.module_stack, node_module_stack): - # This means that the current node represents the execution of a new - # module. - next_module = node_module_stack[len(self.module_stack)] - self.print("Creating new stack frame for", next_module) - # Run a nested version of module outliner from the current node - # counter. Once it is complete, continue from that point. - node_idx = _ModuleFrame( - self.flat_graph, - self.nodes, - self.seen_nodes, - self.seen_modules, - self, - self.module_stack + [next_module], - list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], - self.module_call_graph, - ).run_from(node_idx) - module_idx += 1 - continue - - # The only remaining possibility is that we are in the right stack - # frame. Copy the node into this frame's graph and increment the node counter. - assert node_module_stack == self.module_stack - self.copy_node(node) - node_idx += 1 +from torch.export.unflatten import _ModuleFrame def _outline_submodules(orig_graph: torch.fx.Graph): @@ -463,80 +18,9 @@ def _outline_submodules(orig_graph: torch.fx.Graph): None, [""], "", + {}, module=new_module, ).run_outer() new_module.graph.lint() new_module.recompile() return new_module - - -def _sink_params( - module: torch.nn.Module, - inputs_to_state: Dict[str, str], - scope: List[str], -): - """Sink params, buffers, and constants from graph inputs into get_attr nodes. - - Exported modules are purely functional, so they pass their parameters and - buffers in as inputs to the graph. - - To replicate eager's semantics, we need to get them from the module state - via get_attr instead. - - module: GraphModule, potentially containining nested submodules. - inputs_to_state: mapping graph input names to the corresponding key in the state_dict. - scope: tracks where we are in the module hierarchy, so that we can emit the - right `getattr(self, "foo.bar")` calls, etc. - """ - # We need to use _modules here instead of named_children(), because we - # explicitly want duplicate modules to show up in the traversal. - for name, submodule in module._modules.items(): - _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]) - - if not hasattr(module, "graph"): - # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) - return - - graph = module.graph - inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) - the_last_input = inputs[-1] - - # Also remove from call_module nodes - call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) - for node in call_module_nodes: - node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args)) - - for node in inputs: - if node.name not in inputs_to_state: - continue - - if len(node.users) > 0: - state_name = inputs_to_state[node.name].split(".") - # If there's a mismatch beteewn scope name and state name, then there must be multuple scopes - # pointing to the same state name, meaning some modules are shared. In such case, we can simply - # skip updating the current node because another later iteration will take care of this input - # node when the unique match between scope and state name occurs. - # To make sure this always happen, we should enforce the invariant that no placeholder node - # in the unflattened graph appears in inputs_to_state dict, which means all the extra input - # nodes have been handled. - if state_name[: len(scope)] != scope: - continue - attr_path = state_name[len(scope) :] - state_attr = _recursive_getattr(module, attr_path) - assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) - - # Make sure the newly created get_attr node is placed after the last placeholder node - with graph.inserting_after(the_last_input): - new_node = graph.create_node("get_attr", ".".join(attr_path)) - - node.replace_all_uses_with(new_node, propagate_meta=True) - graph.erase_node(node) - if isinstance(module, InterpreterModule): - module.finalize() - - -def _recursive_getattr(obj, attr_path): - for attr in attr_path: - obj = getattr(obj, attr) - - return obj From 74b99438f269989ccd23f17824c8f137fc4f23ca Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 18 May 2024 02:28:17 +0000 Subject: [PATCH 095/116] [Submodule] Remove third-party CUB (#126540) Because it was updated 4 years ago, and now all supported CUDA versions provide CUB. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126540 Approved by: https://github.com/Skylion007 --- .gitmodules | 4 ---- cmake/Dependencies.cmake | 7 +++---- third_party/cub | 1 - 3 files changed, 3 insertions(+), 9 deletions(-) delete mode 160000 third_party/cub diff --git a/.gitmodules b/.gitmodules index 811be7863194..db7698876a29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,10 +2,6 @@ ignore = dirty path = third_party/pybind11 url = https://github.com/pybind/pybind11.git -[submodule "third_party/cub"] - ignore = dirty - path = third_party/cub - url = https://github.com/NVlabs/cub.git [submodule "third_party/eigen"] ignore = dirty path = third_party/eigen diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 4670ebadf2f5..8bd80f167a5b 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1300,11 +1300,10 @@ endif() # ---[ CUB if(USE_CUDA) find_package(CUB) - if(CUB_FOUND) - include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) - else() - include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/cub) + if(NOT CUB_FOUND) + message(FATAL_ERROR "Cannot find CUB.") endif() + include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() if(USE_DISTRIBUTED AND USE_TENSORPIPE) diff --git a/third_party/cub b/third_party/cub deleted file mode 160000 index d106ddb991a5..000000000000 --- a/third_party/cub +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d106ddb991a56c3df1b6d51b2409e36ba8181ce4 From b98decfc38dc1fd453a1514e5e280e06a6b6082d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 17 May 2024 14:50:05 -0700 Subject: [PATCH 096/116] [halide-backend] Refactor codegen/triton.py into codegen/simd.py (#126415) This PR is primarily just moving stuff around. It creates a new common baseclass for TritonCodegen and the (upcoming) HalideCodegen. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126415 Approved by: https://github.com/shunting314 --- torch/_inductor/codegen/simd.py | 1917 ++++++++++++++++++ torch/_inductor/codegen/triton.py | 1912 +---------------- torch/_inductor/codegen/triton_split_scan.py | 7 +- torch/_inductor/scheduler.py | 4 +- torch/_inductor/select_algorithm.py | 7 +- torch/_inductor/utils.py | 2 +- 6 files changed, 1977 insertions(+), 1872 deletions(-) create mode 100644 torch/_inductor/codegen/simd.py diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py new file mode 100644 index 000000000000..2a002fa3677f --- /dev/null +++ b/torch/_inductor/codegen/simd.py @@ -0,0 +1,1917 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +from typing import ( + Any, + Callable, + Counter, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy + +import torch +import torch._logging + +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from ..._dynamo.utils import counters +from .. import config, ir, scheduler +from ..codecache import code_hash + +from ..dependencies import Dep, MemoryDep, StarDep, WeakDep +from ..ir import TritonTemplateBuffer +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK +from ..runtime.runtime_utils import get_max_y_grid, green_text, yellow_text +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..utils import ( + get_dtype_size, + IndentedBuffer, + Placeholder, + sympy_dot, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, +) +from ..virtualized import V +from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter +from .multi_kernel import MultiKernel + +if TYPE_CHECKING: + pass + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +pexpr = PythonPrinter().doprint + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: Set[sympy.Symbol] + mask_str: str + expand_str: Optional[str] + _has_rindex: bool + index: sympy.Expr + + def has_mask(self): + return bool(self.mask_vars) + + def has_rindex(self): + return self._has_rindex + + def has_tmpmask(self): + return "tmp" in self.mask_str + + def has_rmask(self): + return "rmask" in self.mask_str + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. + """ + + def __init__( + self, + name: str, + var_list: List[sympy.Symbol], + var_ranges: Dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: SIMDKernel, + divisor=sympy.Integer(1), + length=sympy.Integer(1), + root: IterationRangesRoot, + ): + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + def symbol(self): + return sympy_index_symbol(self.name) + + +class IterationRangesRoot(IterationRanges): + def __init__( + self, + name: str, + numel: sympy.Expr, + # TODO: this is probably SymTy.INDEX and SymTy.RINDEX + prefix: str, + index: int, + kernel: SIMDKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + has_zdim: bool, + ): + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: Dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + assert not is_loop or (prefix == "r" and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + self.has_zdim = has_zdim + + def __repr__(self): + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self): + for node in self.nodes.values(): + node.cache_clear() + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries(self, lengths: List[sympy.Expr]): + divisor = sympy.Integer(1) + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return list(reversed(itervars)) + + def construct(self, lengths: List[sympy.Expr]): + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes(self, index: sympy.Expr): + """Figure out vars from this tree used in index""" + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor)) + divisor = sympy.Integer(1) + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return list(reversed(index_vars)), list(reversed(sizes)) + + def ranges_code(self): + assert self.tensor_dim is not None + size = self.kernel.indexing_size_str(self.tensor_dim) + index_dtype = self.kernel.index_dtype + convert = f".to({index_dtype})" if index_dtype != "tl.int32" else "" + return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}" + + def scalar_code(self, value): + index_dtype = self.kernel.index_dtype + ndim = self.kernel.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def get_pid(self): + assert self.grid_dim is not None + key = f"tl.program_id({self.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if ( + self.grid_dim == 1 + and not self.has_zdim + and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid()) + ): + key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)" + pid = self.pid_cache.get(key, key) + if self.kernel.index_dtype != "tl.int32": + return f"{pid}.to({self.kernel.index_dtype})" + return pid + + def codegen_header(self, code): + x = self.prefix + if self.is_loop: + code.writeline(f"{self.name} = {x}offset + {x}base") + elif self.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{self.name} = {self.ranges_code()}") + code.writeline(f"{x}offset = 0") + else: + if self.tensor_dim is not None: + line = f"{x}offset + {self.ranges_code()}" + else: + line = self.scalar_code(f"{x}offset") + code.writelines( + [ + f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK", + f"{self.name} = {line}", + ] + ) + code.writeline(f"{x}mask = {self.name} < {x}numel") + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ): + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self): + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name): + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self): + self.codegen.cache_clear() + + def _codegen(self): + V.kernel.codegen_iteration_ranges_entry(self) + return self.name + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + return self.name == other.name + + +def triton_constant(value): + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +class SIMDKernel(Kernel): + """ + Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. + """ + + sexpr = pexpr + kexpr: Callable[[sympy.Expr], str] + allow_block_ptr = False + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[Set[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + disable_persistent_reduction=False, + ): + if pid_cache is None: + pid_cache = {} + super().__init__() + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.numels = [V.graph.sizevars.simplify(s) for s in groups] + self.mutations: Set[str] = mutations if mutations is not None else set() + self.range_trees: List[IterationRangesRoot] = [] + self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = self.numels[-1] != 1 + self.reduction_hint = reduction_hint + self.index_dtype: str = index_dtype + self.last_usage: Set[str] = set() + self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) + self.persistent_reduction: bool = ( + not disable_persistent_reduction + ) and self.should_use_persistent_reduction() + self.no_x_dim = self.want_no_x_dim() + self.code_hash = None + + # define this in a closure to make cache local to object + @functools.lru_cache(None) + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + return index + + self.simplify_indexing = simplify_indexing + self.initialize_range_tree(pid_cache) + + def want_no_x_dim(self): + return False + + def initialize_range_tree(self, pid_cache): + no_r_dim = not self.inside_reduction or self.numels[-1] == 1 + + prefixes = "zyxr" + active_prefixes = prefixes[-len(self.numels) :] + + grid_dims = "xyz" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyz" + else: + tensor_dims = "xyzr" + + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix == "r" + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + self.numels[i], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim="z" in active_prefixes, + ) + ) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + prior = self.inside_reduction + self.inside_reduction = False + try: + return self.store(name, index, value) + finally: + self.inside_reduction = prior + + def should_use_persistent_reduction(self) -> bool: + return False # defined in subclass + + def var_ranges(self): + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def triton_tensor_ndim(self): + return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i): + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> List[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + if tree.prefix != "r" or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def dense_size_str(self): + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def set_last_usage(self, nodes): + if not self.inside_reduction or self.persistent_reduction: + return + self.last_usage = set( + itertools.chain.from_iterable( + n.last_usage for n in nodes if n is not EnableReduction + ) + ) + + def disable_reduction(self): + should_flush = self.range_trees[-1].is_loop + + @contextlib.contextmanager + def ctx(): + if self.numels[-1] == 1: + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths): + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(size, idx1, idx2): + def getter(flat_vars): + return size * flat_vars[idx1] + flat_vars[idx2] + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], 1 # type: ignore[arg-type] + ): + # scroll to next group with remaining elements + current_group += 1 + + if current_group + 1 < len(remaining) and sv.statically_known_gt( + size, remaining[current_group] + ): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit + size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + size2, + add_range(current_group, size1), + add_range(current_group + 1, size2), + ) + ) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all( + V.graph.sizevars.size_hint(s) == 1 for s in remaining + ), f"failed to set ranges {remaining} {lengths}" + + return new_ranges, return_getters_groups + + @classmethod + def is_compatible( + cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.Integer(1) + + if len(lengths) == len(self.range_trees) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return self.set_ranges(*lengths) + + new_ranges, return_getters_groups = self._split_iteration_ranges( + groups, lengths + ) + itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr): + # tmpX means indirect indexing + return free_symbol_is_type(index, SymT.TMP) + + def is_broadcasted(self, index: sympy.Expr): + # Note. This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels) + ) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in triton code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the triton kernel. + + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return self.kexpr( # type: ignore[call-arg] + self.rename_indexing(self.codegen_indexing(index)) + ) + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ): + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + index = self.simplify_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX) + if override_mask: + pass + elif symbol_is_type(var, SymT.TMP): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + pass + else: + # var is one of xN, yN or rN + assert symbol_is_type( + var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK) + ), var.name + mask_vars.add(f"{var.name[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars = set() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + block_ptr + and self.allow_block_ptr + and config.triton.use_block_ptr + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/openai/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees(reorder=True) + symbols = [t.symbol() for t in range_trees] + strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols] + offset = sympy.Wild("_offset", exclude=symbols) + m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset) + # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with + # a tl.reshape the correct block. We will miss these cases today. + if m: + self.filter_masks(mask_vars) + from .triton import BlockPtrOptions + + return BlockPtrOptions.create( + [m[s] for s in strides], + m[offset], + range_trees, + mask_vars, # type: ignore[arg-type] + ) + + expand_str = None + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + return IndexingOptions( + index_str, set(), "None", expand_str, has_rindex, index + ) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" + mask_vars = dense_mask_vars + + if override_mask: + mask_vars = {override_mask} + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index) # type: ignore[arg-type] + + def active_range_trees(self, reorder=False): + trees = [ + t for t in self.range_trees if t.prefix != "r" or self.inside_reduction + ] + if reorder and len(trees) > 1: + count = sum(t.prefix in "xyz" for t in trees) + assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [ + t.prefix for t in trees[:count] + ] + trees[:count] = reversed(trees[:count]) + return trees + + def filter_masks(self, mask_vars): + for tree in self.range_trees: + # Masks are superfluous if we only have one element + if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] + mask_vars.discard(f"{tree.prefix}mask") + continue + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.prefix.upper() not in TRITON_MAX_BLOCK: + continue + max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # It's faster to avoid masking at all. But it is sound to always + # mask. + if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] + mask_vars.discard(f"{tree.prefix}mask") + + def codegen_indexing(self, expr: sympy.Expr): + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + @contextlib.contextmanager + def mask_loads(self, mask): + """Context manager to add an additional mask to tl.load/store""" + prior = self._load_mask + if prior: + mask = self.cse.generate(self.compute, f"{mask} & {prior}") + + self._load_mask = mask + try: + # TODO(jansel): do we need a reshape here? + yield mask + finally: + self._load_mask = prior + + def load_mask(self, var): + mask = "" + mask_vars = set(var.mask_vars) + if self._load_mask: + mask_vars.add(self._load_mask) + + if mask_vars: + mask = ( + f"{next(iter(mask_vars))}" + if len(mask_vars) == 1 + # sorted for deterministic order + else f"({' & '.join(sorted(map(str, mask_vars)))})" + ) + return mask + + def get_strides_of_load(self, index: sympy.Expr): + """ + This gets the stride of the index for each of the tiling variables + (technically, it does it at index 0) + + For example, if + xindex = x0 + 512*x1 + 1024*r0 + x0 = (xindex//512) + x1 = (xindex % 512) + r0 = rindex // 1024 + + this function would return + {xindex: 512, rindex: 1024} + """ + index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] + strides = {} + for range_tree in self.range_trees: + s = sympy_index_symbol(range_tree.name) + strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( + index_in_tile_vars, {s: 0} + ) + return strides + + @staticmethod + def _map_tuple_or_scalar(fn, value): + if isinstance(value, tuple): + return tuple(map(fn, value)) + return fn(value) + + def estimate_kernel_num_bytes(self): + """ + Try the best to estimate the total size (in bytes) of the + kernel's inputs and outputs, which is used for estimating the memory + throughput of this kernel. This information is used for checking how + far we are from the peak memory bandwidth. It's important that + we want to avoid overestimating the sizes of the inputs and outputs, + because it can wrongfully give us a very large memory traffic value, + which may be even larger than the theoretical bandwidth and thus + become very misleading. This is particularly problematic for cases + where we slice some inputs. In those cases, we should only count + the size of the "slices" instead of the original inputs, because + only the slices contribute to the real memory traffic. + """ + nbytes = [] + ninplace_args = len(unique(self.args.inplace_buffers.values())) + _, call_args, _ = self.args.python_argdefs() + + # For pointwise and reduction kernels, this is the upper-bound numels + # for the output buffer. + # FIXME: This is not exactly right for cases like below: + # def foo(tensor0, tensor1): + # x0 = narrow(tensor0) + # return cat(x0, tensor1) + # For this example, we will end up overestimate the size for the + # slice s0. Potentially, we could have precise inputs information + # if we maintained the original inputs of the Pointwise kernel created + # for the "cat". However, I think it might be a bit overwhelming that + # we add such complexity only for handling some particular cases for + # benchmarking. + out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels)) + for i, arg in enumerate(call_args): + # "buf" may be narrowed. In this case, the number of memory accesses + # should be estimated based on the reinterpreted layout. + # On the other hand, buf may be broadcasted. In this case, + # counting the size of the underline storage would give us + # a better estimation in terms of memory accesses. + if arg not in self.buf_accesses: + nbytes.append(0) + continue + arg_numel = V.graph.get_numel(arg) + buf_size = V.graph.sizevars.size_hint(arg_numel) + if buf_size > out_numel: + # This arg points to a buf that has been sliced. + # We need to count each individual slice to have + # a better estimation. + indices: Set[Any] = set() + no_index_dep_count = 0 + for dep in self.buf_accesses[arg]: + if isinstance(dep, (StarDep, WeakDep)): + indices.add(f"no_index_dep_{no_index_dep_count}") + no_index_dep_count += 1 + else: + indices.add(dep.index) + numel = len(indices) * out_numel + else: + numel = buf_size + dtype = V.graph.get_dtype(arg) + dtype_size = get_dtype_size(dtype) + nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(nbytes) + + def warn_mix_layout(self, kernel_name): + """ + Print message if the kernel have mixed layout inputs. + Only care about 4D tensor for now. + """ + if ( + len(self.args.input_buffers) == 1 + and len(self.args.output_buffers) == 1 + and len(self.args.inplace_buffers) == 0 + ): + # even if input buffer and output buffer have different layout, + # this can be a layout conversion kernel. No need to warn for + # the mix layouts. + return + + argdefs, call_args, signature = self.args.python_argdefs() + uniform_stride_order = None + for arg_name in call_args: + buf = V.graph.get_buffer(arg_name) + if buf and len(buf.layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in buf.layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(buf.layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order(V.graph.get_buffer(name).layout.stride) + if V.graph.get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).layout.size + if V.graph.get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + msg = yellow_text( + f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def codegen_kernel(self): + raise NotImplementedError + + def codegen_body(self): + pass + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + raise NotImplementedError + + +class SIMDScheduling(BaseScheduling): + kernel_type = SIMDKernel # override in subclass + int32_type = "torch.int32" + int64_type = "torch.int64" + + def __init__(self, scheduler): + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. + """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + + if node1.is_template(): + # Only allow fusion for TritonTemplates for now. + # Fusion for CUDATemplates are not supported. + is_triton_template = isinstance(node1.node, TritonTemplateBuffer) + if not is_triton_template: + why("node1 is not TritonTemplateBuffer") + return is_triton_template + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = self.select_tiling( + node1.get_nodes(), numel1 + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid + return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: List[Any] = [] + current_loop_writes: Set[str] = set() + + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + current_loop_reduced_writes = set() + current_loop_has_writes = False + done = set() + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = n.group + return node_numel == numel and node_rnumel == 1 and rnumel != 1 + + def schedule_node_in_loop(n): + nonlocal current_loop_has_writes + done.add(n) + node_schedule.append(n) + current_loop_has_writes = True + # A scan is modelled as a reduction in the scheduler but has a + # full sized output that can be used inside the loop body + if ( + n.is_reduction() + and isinstance(n, scheduler.SchedulerNode) + and isinstance(n.node, ir.ComputedBuffer) + and not isinstance(n.node.data, ir.Scan) + ): + current_loop_reduced_writes.add(n.get_name()) + + @contextlib.contextmanager + def end_current_reduction_loop(): + nonlocal current_loop_has_writes + if current_loop_has_writes: + # flush out any other runnable nodes to reduce number of loops + for other_node in nodes[index + 1 :]: + if ( + node not in done + and fits_in_main_body(other_node) + and not (current_loop_reduced_writes & other_node.ancestors) + ): + schedule_node_in_loop(node) + + if node_schedule and node_schedule[-1] is EnableReduction: + node_schedule.pop() + else: + node_schedule.append(DisableReduction) + yield + node_schedule.append(EnableReduction) + current_loop_reduced_writes.clear() + current_loop_has_writes = False + + for index, node in enumerate(nodes): + if node in done: + continue + done.add(node) + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not current_loop_reduced_writes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(current_loop_reduced_writes) + + if fits_in_main_body(node): + if requires_closing_previous_reduction(node, node_schedule): + with end_current_reduction_loop(): + pass # need to start a new reduction loop + + schedule_node_in_loop(node) + elif fits_outside_reduction(node): + with end_current_reduction_loop(): + node_schedule.append(node) + else: + raise NotImplementedError( + f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" + ) + + return node_schedule + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + buf_accesses = collections.defaultdict(list) + for node in nodes: + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) + + @staticmethod + def reduction_hint(node): + assert node.is_reduction() + if all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] + ) -> bool: + int_max = torch.iinfo(torch.int32).max + size_hint = V.graph.sizevars.size_hint + has_hint = V.graph.sizevars.shape_env.has_hint + + def within_32bit(e): + # Allow for unhinted e as long as we can still statically prove + # (e.g., via ValueRanges) that it is still in bounds + if V.graph.sizevars.is_expr_static_and_true(e <= int_max): + return True + # Otherwise, the hint MUST exist and be in range + return has_hint(e) and size_hint(e) <= int_max + + if not within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if not isinstance(buf.get_layout(), ir.MultiOutputLayout) + ] + + if not all(within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + return True + + @classmethod + def select_index_dtype(cls, node_schedule, numel, reduction_numel): + # Gather all used buffer names + buffer_names = set() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + + buffer_names.update(node.get_names()) + buffer_names.update(node.used_buffer_names()) + + # Get buffers objects + + def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: + buf = V.graph.get_buffer(name) + if buf is None: + raise RuntimeError(f"Failed to find buffer matching name {name}") + return buf + + buffers = [V.graph.get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. + total_numel = numel * reduction_numel + + if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): + return cls.int32_type + return cls.int64_type + + def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel): + pointwise_nodes = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and not n.is_reduction() + and n.group[1][0] == numel * rnumel, + node_schedule, + ) + ) + for node in pointwise_nodes: + # An index can be an integer when loading a random seed. + if not all( + not isinstance(dep, MemoryDep) + or dep.is_contiguous() + or isinstance(dep.index, (sympy.Integer, int)) + or dep.stride1_for_last_dim() + for dep in itertools.chain( + node.read_writes.reads, node.read_writes.writes + ) + ): + return True + return False + + def get_kernel_args(self, node_schedule, numel, reduction_numel): + reductions = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and n.is_reduction(), + node_schedule, + ) + ) + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + + if ( + reduction_hint_val == ReductionHint.INNER + and self.has_non_contiguous_pw_in_reduction_kernel( + node_schedule, numel, reduction_numel + ) + ): + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + + mutations = set() + for node in node_schedule: + if hasattr(node, "get_mutations"): + mutations.update(node.get_mutations()) + + index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) + + return reduction_hint_val, mutations, index_dtype + + def codegen_node_schedule( + self, node_schedule, buf_accesses, numel, reduction_numel + ): + from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel + + tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) + ( + reduction_hint_val, + mutations, + index_dtype, + ) = self.get_kernel_args(node_schedule, numel, reduction_numel) + + is_split_scan = any( + isinstance(node, BaseSchedulerNode) and node.is_split_scan() + for node in node_schedule + ) + kernel_type = TritonSplitScanKernel if is_split_scan else self.kernel_type + kernel_args = tiled_groups + kernel_kwargs = { + "reduction_hint": reduction_hint_val, + "mutations": mutations, + "index_dtype": index_dtype, + } + kernel = kernel_type( + *kernel_args, + **kernel_kwargs, + ) + kernel.buf_accesses = buf_accesses + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + + if kernel.persistent_reduction and config.triton.multi_kernel: + kernel2 = self.kernel_type( + *kernel_args, + **kernel_kwargs, + disable_persistent_reduction=True, + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel2) + with V.set_kernel_handler(kernel2): + src_code2 = kernel2.codegen_kernel() + kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel) + kernel2.kernel_name = kernel_name2 + kernel2.code_hash = code_hash(src_code2) + + final_kernel = MultiKernel([kernel, kernel2]) + else: + final_kernel = kernel # type: ignore[assignment] + + with V.set_kernel_handler(final_kernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernel.args.live_output_buffers() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + name = node.get_name() + if name not in live_outs: + continue + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + def codegen_node_schedule_with_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + stack = contextlib.ExitStack() + kernel.set_last_usage(current_reduction_nodes(node_schedule)) + + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.decide_inplace_update() + for i, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) + else: + # TODO - use split ranges ? + indexing_dtype_strength_reduction(node._body) + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node.codegen(index_vars) + + def codegen_template( + self, template_node, epilogue_nodes, only_gen_src_code=False + ) -> Optional[str]: + """ + Codegen a triton template + + If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper + """ + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + kernel, render = template_node.node.make_kernel_render(template_node.node) + with kernel: + if not only_gen_src_code: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + partial_code = render() + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + + # finalize must be called after adding epilogue above + with V.set_kernel_handler(kernel): + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + node_schedule = [template_node, *epilogue_nodes] + + if config.benchmark_kernel: + num_gb = kernel.estimate_kernel_num_bytes() / 1e9 + grid_args = V.graph.sizevars.size_hints(kernel.call_sizes) + assert kernel.meta is not None, "meta is None" + grid = kernel.grid_fn(*grid_args, kernel.meta) + src_code = ( + f"{kernel.imports_for_benchmark_kernel()}\n" + f"{src_code}\n" + f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}" + ) + + if only_gen_src_code: + return src_code + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + + self.codegen_comment(node_schedule) + kernel.call_kernel(kernel_name, template_node.node) + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.scheduler.free_buffers() + return None + + def codegen_sync(self): + V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) + + def codegen_foreach(self, foreach_node): + from .triton_foreach import ForeachKernel + + for partitions_with_metadata in ForeachKernel.horizontal_partition( + foreach_node.get_subkernel_nodes(), self + ): + kernel = ForeachKernel() + for nodes, tiled_groups, numel, rnumel in partitions_with_metadata: + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + ( + reduction_hint_val, + mutations, + index_dtype, + ) = self.get_kernel_args(node_schedule, numel, rnumel) + + subkernel = kernel.create_sub_kernel( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel( + node_schedule, + subkernel, + ) + + with V.set_kernel_handler(subkernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, [foreach_node], kernel) + self.codegen_comment([foreach_node]) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.scheduler.free_buffers() + + @staticmethod + @functools.lru_cache(32) + def candidate_tilings(node): + ranges, reduction_ranges = node.get_ranges() + if len(ranges) <= 1: + return () + + rw = node.pointwise_read_writes() + assert len(rw.range_vars) == len(ranges) + + # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads + # that need to access the entire tensor; they don't contribute read indexing + # information (and practically, they don't have dep.index so they can't be used + # for stride_hints below + dep_sources = [rw.reads, rw.writes] + assert all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + ) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep) + ] + write_names = {dep.name for dep in rw.writes} + + tilings: List[CandidateTiling] = [] + + for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert len(strides) == len(ranges) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + # if this is a broadcasted tensor and all dimensions after split are broadcast, + # this is not a real split + continue + + except ValueError: + continue + tiled_groups = ( + V.graph.sizevars.simplify(sympy_product(ranges[:split])), + V.graph.sizevars.simplify(sympy_product(ranges[split:])), + ) + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append(CandidateTiling(tiled_groups, score, dep.name)) + return tilings + + @classmethod + def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + if reduction_numel != 1 or config.triton.max_tiles <= 1: + # TODO(jansel): should we tile reductions? + # do perf hint here if stride-1 dim is not being reduced + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if len(cls.candidate_tilings(node)) > 0: + perf_hint_log.info("reduction over non-contiguous dims") + break + return (numel, reduction_numel) + + seen_names = set() + candidate_tiles: Counter[Any] = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for tiling in cls.candidate_tilings(node): + if tiling.name in seen_names: + continue + seen_names.add(tiling.name) + candidate_tiles[tiling.tiling] += tiling.score + + ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] + + if config.triton.max_tiles >= 3: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. + + # Add one 3D tiling choice + for i in range(1, len(ranked_tilings)): + a0, a1 = ranked_tilings[0] + b0, b1 = ranked_tilings[i] + if V.graph.sizevars.size_hint(a1 - b1) == 0: + continue + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + a0, a1 = ranked_tilings[i] + b0, b1 = ranked_tilings[0] + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if V.graph.sizevars.statically_known_multiple_of(a1, b1): + tiling = (a0, FloorDiv(a1, b1), b1) + ranked_tilings = [tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + for tiled_groups in ranked_tilings: + new_groups = (*tiled_groups, reduction_numel) + if all( + SIMDKernel.is_compatible(new_groups, node.get_ranges()) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ): + return new_groups + + return (numel, reduction_numel) + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + @dataclasses.dataclass + class LastUsageHolder: + n: Any + last_usage: Any + + def __del__(self): + self.n.last_usage = self.last_usage + + last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] + + # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. + for n in nodes: + n.last_usage = set() + + if not nodes[0].is_template(): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + + tiled_groups = self.select_tiling(node_schedule, numel, rnumel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, rnumel + ) + + kernel = self.kernel_type( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with config.patch( + "benchmark_kernel", benchmark_kernel + ), V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + else: + template_node = nodes[0] + epilogue_nodes = nodes[1:] + + with config.patch("benchmark_kernel", benchmark_kernel): + src_code = self.codegen_template( + template_node, epilogue_nodes, only_gen_src_code=True + ) + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return src_code + + def codegen_comment(self, node_schedule): + pass + + def define_kernel(self, src_code, node_schedule, kernel): + raise NotImplementedError + + +@dataclasses.dataclass +class CandidateTiling: + tiling: Tuple[sympy.Expr, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class DisableReduction: + """ + Marker to invoke `kernel.disable_reduction()`. This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction: + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule): + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. + """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node + + +class CantSplit(Exception): + pass diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index b1b7d951b99a..183d28605b87 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1,30 +1,13 @@ from __future__ import annotations -import collections -import contextlib import dataclasses import functools import itertools import logging -import math -import operator import os import textwrap from functools import lru_cache -from typing import ( - Any, - Callable, - cast, - Counter, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union import sympy @@ -33,42 +16,24 @@ import torch.utils._pytree as pytree from torch._dynamo.utils import preserve_rng_state -from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties from torch._prims_common import is_integer_dtype -from torch.utils._sympy.functions import FloorDiv, ModularIndexing -from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT -from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._triton import has_triton_package +from ...utils._sympy.value_ranges import ValueRanges -from ..._dynamo.utils import counters -from .. import config, ir, scheduler +from .. import config, ir from ..codecache import code_hash, get_path, PyCodeCache -from ..dependencies import Dep, MemoryDep, StarDep, WeakDep -from ..ir import IRNode, TritonTemplateBuffer -from ..optimize_indexing import indexing_dtype_strength_reduction +from ..ir import IRNode +from ..metrics import is_metric_table_enabled, log_kernel_metadata from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK -from ..runtime.runtime_utils import ( - do_bench_gpu, - get_max_y_grid, - green_text, - next_power_of_2, - yellow_text, -) -from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..runtime.runtime_utils import do_bench_gpu, next_power_of_2 from ..utils import ( cache_on_self, get_bounds_index_expr, - get_dtype_size, get_fused_kernel_name, get_kernel_metadata, is_welford_reduction, Placeholder, - sympy_dot, - sympy_index_symbol, - sympy_product, - sympy_subs, - unique, ) from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V from ..wrapper_benchmark import get_kernel_category_by_source_code @@ -77,17 +42,21 @@ CSEVariable, DeferredLine, IndentedBuffer, - index_prevent_reordering, - Kernel, OpOverrides, PythonPrinter, SizeArg, TensorArg, ) -from .multi_kernel import MultiKernel +from .simd import ( + IndexingOptions, + IterationRangesEntry, + pexpr, + SIMDKernel, + SIMDScheduling, + triton_constant, +) from .triton_utils import config_of, signature_of, signature_to_meta - log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") @@ -133,27 +102,6 @@ def gen_common_triton_imports(): return imports.getvalue() -@dataclasses.dataclass -class IndexingOptions: - index_str: str - mask_vars: Set[sympy.Symbol] - mask_str: str - expand_str: Optional[str] - _has_rindex: bool - - def has_mask(self): - return bool(self.mask_vars) - - def has_rindex(self): - return self._has_rindex - - def has_tmpmask(self): - return "tmp" in self.mask_str - - def has_rmask(self): - return "rmask" in self.mask_str - - @dataclasses.dataclass class BlockPtrOptions: constant_offset: sympy.Expr @@ -416,7 +364,6 @@ def _print_RoundDecimal(self, expr): texpr = TritonPrinter().doprint -pexpr = PythonPrinter().doprint def triton_compute_type(dtype): @@ -455,16 +402,6 @@ def triton_acc_type(dtype): return triton_compute_type(dtype) -def triton_constant(value): - if value == float("inf"): - return 'float("inf")' - elif value == float("-inf"): - return 'float("-inf")' - elif math.isnan(value): - return 'float("nan")' - return repr(value) - - class TritonCSEVariable(CSEVariable): def __init__(self, name, bounds: ValueRanges[Any]): super().__init__(name, bounds) @@ -487,9 +424,6 @@ def update_on_args(self, name, args, kwargs): # those reads should subsequently be masked, self.mask_vars.update({f"{arg.name[0]}mask"}) - def __repr__(self): - return f"TritonCSEVariable(name={self.name})" - class TritonOverrides(OpOverrides): """Map element-wise ops to Triton""" @@ -965,283 +899,6 @@ def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str return h -@dataclasses.dataclass -class IterationRanges: - """ - Each range tree represents multiple sets of iteration indexing - in a single tiled dimension in the output kernel. - - If you have two loops ranges one (4, 3, 2) and another (4, 6), - then the range tree will be: - 4 (i0) - 3 (i1) 6 (i3) - 2 (i2) - Where i0 is shared between both loops, but then the split into - different indexing vars. All loop ranges must iterate over - the same number of elements. - """ - - def __init__( - self, - name: str, - var_list: List[sympy.Symbol], - var_ranges: Dict[sympy.Symbol, sympy.Expr], - numel: sympy.Expr, - prefix: str, - *, - kernel: TritonKernel, - divisor=sympy.Integer(1), - length=sympy.Integer(1), - root: IterationRangesRoot, - ): - super().__init__() - self.name = name - self.var_list = var_list - self.var_ranges = var_ranges - self.numel = numel - self.prefix = prefix - self.divisor = divisor - self.length = length - self.kernel = kernel - self.root = root - - def symbol(self): - return sympy_index_symbol(self.name) - - -class IterationRangesRoot(IterationRanges): - def __init__( - self, - name: str, - numel: sympy.Expr, - # TODO: this is probably SymTy.INDEX and SymTy.RINDEX - prefix: str, - index: int, - kernel: TritonKernel, - pid_cache=None, - *, - is_loop: bool, - tensor_dim: Optional[int], - grid_dim: Optional[int], - has_zdim: bool, - ): - if pid_cache is None: - pid_cache = {} - super().__init__( - name=name, - var_list=[], - var_ranges={}, - numel=numel, - prefix=prefix, - kernel=kernel, - root=self, - ) - self.index = index - # Store all the nodes in one flat list - self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} - # This is for re-ordering program ID in triton mm template - # pid_cache["tl.program_id(0)"] = pid_m - self.pid_cache: Dict[str, str] = pid_cache - - # True if the dimension is implemented as a single program looping over - # the full dimension (currently only used for non-persistent reduction) - assert not is_loop or (prefix == "r" and grid_dim is None) - self.is_loop = is_loop - # Index of corresponding dimension on triton tensors - self.tensor_dim = tensor_dim - # Index of corresponding dimension in the triton grid - self.grid_dim = grid_dim - self.has_zdim = has_zdim - - def __repr__(self): - return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" - - def cache_clear(self): - for node in self.nodes.values(): - node.cache_clear() - - def lookup(self, divisor, length): - """ - Lookup a given RangeTreeEntry, creating it if needed - """ - if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): - expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) - else: - expr = ModularIndexing( - sympy_index_symbol(f"{self.prefix}index"), divisor, length - ) - - if expr not in self.nodes: - node = IterationRangesEntry( - f"{self.prefix}{next(V.kernel.iter_vars_count)}", - divisor, - length, - expr, - self, - ) - V.kernel.range_tree_nodes[node.symbol()] = node - self.var_list.append(node.symbol()) - self.var_ranges[node.symbol()] = length - self.nodes[expr] = node - return self.nodes[expr] - - def construct_entries(self, lengths: List[sympy.Expr]): - divisor = sympy.Integer(1) - itervars = [] - for length in reversed(lengths): - itervars.append(self.lookup(divisor, length)) - divisor = divisor * length - return list(reversed(itervars)) - - def construct(self, lengths: List[sympy.Expr]): - return [e.symbol() for e in self.construct_entries(lengths)] - - def vars_and_sizes(self, index: sympy.Expr): - """Figure out vars from this tree used in index""" - nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] - nodes = [n for n in nodes if n and n.prefix == self.prefix] - nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor)) - divisor = sympy.Integer(1) - index_vars = [] - sizes = [] - - def add(node): - nonlocal divisor - index_vars.append(node.symbol()) - sizes.append(node.length) - divisor = divisor * node.length - - for node in nodes: - if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): - # fill in unused index var - add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) - divisor = node.divisor - add(node) - if not V.graph.sizevars.statically_known_equals(self.numel, divisor): - # fill in unused index var - add(self.lookup(divisor, FloorDiv(self.numel, divisor))) - - return list(reversed(index_vars)), list(reversed(sizes)) - - def ranges_code(self): - assert self.tensor_dim is not None - size = self.kernel.indexing_size_str(self.tensor_dim) - index_dtype = self.kernel.index_dtype - convert = f".to({index_dtype})" if index_dtype != "tl.int32" else "" - return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}" - - def scalar_code(self, value): - index_dtype = self.kernel.index_dtype - ndim = self.kernel.triton_tensor_ndim() - size = [1] * ndim - return f"tl.full({size}, {value}, {index_dtype})" - - def get_pid(self): - assert self.grid_dim is not None - key = f"tl.program_id({self.grid_dim})" - # y_grid has a limit, so express it in terms of y and z in case of overflow. - # z grid is only exercised when max_tiles == 3 (off by default). - if ( - self.grid_dim == 1 - and not self.has_zdim - and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid()) - ): - key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)" - pid = self.pid_cache.get(key, key) - if self.kernel.index_dtype != "tl.int32": - return f"{pid}.to({self.kernel.index_dtype})" - return pid - - def codegen_header(self, code): - x = self.prefix - if self.is_loop: - code.writeline(f"{self.name} = {x}offset + {x}base") - elif self.grid_dim is None: - # no need to "{x}offset = " - code.writeline(f"{self.name} = {self.ranges_code()}") - code.writeline(f"{x}offset = 0") - else: - if self.tensor_dim is not None: - line = f"{x}offset + {self.ranges_code()}" - else: - line = self.scalar_code(f"{x}offset") - code.writelines( - [ - f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK", - f"{self.name} = {line}", - ] - ) - code.writeline(f"{x}mask = {self.name} < {x}numel") - - -class IterationRangesEntry(IterationRanges): - def __init__( - self, - name: str, - divisor: sympy.Expr, - length: sympy.Expr, - expr: sympy.Expr, - parent: IterationRanges, - ): - super().__init__( - name=name, - numel=parent.numel / length, - var_list=parent.var_list, - var_ranges=parent.var_ranges, - prefix=parent.prefix, - divisor=divisor, - length=length, - kernel=parent.kernel, - root=parent.root, - ) - self.parent = parent - self.codegen = functools.lru_cache(None)(self._codegen) - self.expr = expr - - def __repr__(self): - return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" - - def set_name(self, name): - self.codegen = lambda: name # type: ignore[assignment] - self.codegen.cache_clear = lambda: None # type: ignore[method-assign] - self.name = name - - def cache_clear(self): - self.codegen.cache_clear() - - def writeline(self, line): - if self.root.is_loop: - V.kernel.indexing_code.writeline(line) - else: - # lift non-reduction stores outside loop - V.kernel.body.writeline(line) - - def _codegen(self): - self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr))) - return self.name - - def precomputed_args(self): - # for dynamic shapes, find parts of indexing expressions that have to be precomputed - precomputed_args: List[sympy.Expr] = [] - if isinstance(self.expr, sympy.Symbol): - return precomputed_args - assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) - for arg in self.expr.args[1:]: - if not isinstance(arg, (sympy.Integer, sympy.Symbol)): - symbols = arg.free_symbols - if len(symbols) > 0 and all( - symbol_is_type(s, SymT.SIZE) for s in symbols - ): - precomputed_args.append(arg) - return precomputed_args - - def __hash__(self): - return hash(self.name) - - def __eq__(self, other): - return self.name == other.name - - class HelperFunctions: """An ordered set of helper functions.""" @@ -1281,11 +938,11 @@ def __getitem__(self, idx): return self.finalized_helpers[idx] -class TritonKernel(Kernel): +class TritonKernel(SIMDKernel): overrides = TritonKernelOverrides # type: ignore[assignment] - sexpr = pexpr - helper_functions: HelperFunctions + kexpr: Callable[[sympy.Expr], str] = texpr + allow_block_ptr = True def __init__( self, @@ -1297,54 +954,35 @@ def __init__( min_elem_per_thread=0, disable_persistent_reduction=False, ): - if pid_cache is None: - pid_cache = {} - super().__init__() - self.numels = [V.graph.sizevars.simplify(s) for s in groups] - self.mutations: Set[str] = mutations if mutations is not None else set() - self.range_trees: List[IterationRangesRoot] = [] - self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} - self.iter_vars_count = itertools.count() - self.inside_reduction = self.numels[-1] != 1 - self.body = IndentedBuffer() - self.indexing_code = IndentedBuffer() + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + reduction_hint=reduction_hint, + pid_cache=pid_cache, + disable_persistent_reduction=disable_persistent_reduction, + ) self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment] self.outside_loop_vars: Set[Any] = set() - self.reduction_hint = reduction_hint - self.index_dtype: str = index_dtype self.min_elem_per_thread = min_elem_per_thread - self.last_usage: Set[str] = set() self.block_ptr_id = itertools.count() - # buffer accesses in the kernel - self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) - - self.persistent_reduction: bool = ( - not disable_persistent_reduction - ) and self.should_use_persistent_reduction() - self.no_x_dim = ( - self.reduction_hint == ReductionHint.INNER - and self.persistent_reduction - and len(self.numels) == 2 - and self.numels[-1] >= 256 - ) - self.initialize_range_tree(pid_cache) - self.helper_functions = HelperFunctions() # A set of autotuning hints to pass as part of triton_meta self.autotune_hints: Set[AutotuneHint] = set() + self.triton_meta: Optional[Dict[str, object]] = None - # define this in a closure to make cache local to object - @functools.lru_cache(None) - def simplify_indexing(index: sympy.Expr): - index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) - for tree in self.range_trees: - index = self.combine_contiguous_dims(index, tree) - return index + self.codegen_range_tree() - self.simplify_indexing = simplify_indexing - self.code_hash = None - self.triton_meta: Optional[Dict[str, object]] = None + def codegen_range_tree(self): + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + tree.codegen_header(self.body) + if self.inside_reduction and self.range_trees[-1].is_loop: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}") def need_numel_args(self): r""" @@ -1384,508 +1022,21 @@ def should_use_persistent_reduction(self) -> bool: V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type] return True - def set_last_usage(self, nodes): - if not self.inside_reduction or self.persistent_reduction: - return - self.last_usage = set( - itertools.chain.from_iterable( - n.last_usage for n in nodes if n is not EnableReduction - ) - ) - - def initialize_range_tree(self, pid_cache): - no_r_dim = not self.inside_reduction or self.numels[-1] == 1 - - prefixes = "zyxr" - active_prefixes = prefixes[-len(self.numels) :] - - grid_dims = "xyz" - if self.no_x_dim: - tensor_dims = "r" - elif no_r_dim: - tensor_dims = "xyz" - else: - tensor_dims = "xyzr" - - tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) - - for i, prefix in enumerate(active_prefixes): - is_reduction = prefix == "r" - tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None - grid_dim = None if is_reduction else grid_dims.find(prefix) - index = i if grid_dim is None else grid_dim - self.range_trees.append( - IterationRangesRoot( - f"{prefix}index", - self.numels[i], - prefix, - index, - self, - pid_cache=pid_cache, - is_loop=is_reduction and not self.persistent_reduction, - tensor_dim=tensor_dim, - grid_dim=grid_dim, - has_zdim="z" in active_prefixes, - ) - ) - for tree in self.range_trees: - # reduction indexing goes inside a loop - if not tree.is_loop: - tree.codegen_header(self.body) - if self.inside_reduction and self.range_trees[-1].is_loop: - # workaround for this issue: - # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 - self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}") - - def disable_reduction(self): - should_flush = self.range_trees[-1].is_loop - - @contextlib.contextmanager - def ctx(): - if self.numels[-1] == 1: - assert not self.inside_reduction - yield - return - if should_flush: - # calling codegen_body() will flush all the pending buffers - # and write out a reduction loop - self.codegen_body() - self.inside_reduction = False - try: - yield - if should_flush: - # flush out any code before opening the next loop - self.codegen_body() - finally: - self.inside_reduction = True - - return ctx() - - def set_ranges(self, *lengths): - assert len(lengths) == len(self.range_trees) - return [ - ranges.construct(length) - for length, ranges in zip(lengths, self.range_trees) - ] - - @staticmethod - def _split_iteration_ranges( - groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] - ): - sv = V.graph.sizevars - new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] - remaining = [sv.simplify(g) for g in groups] - var_count = itertools.count() - - def add_range(i, expr): - expr = sv.simplify(expr) - if not sv.statically_known_multiple_of(remaining[i], expr): - raise CantSplit - # guard on the last item out - remaining[i] = FloorDiv(remaining[i], expr) - new_ranges[i].append(expr) - return next(var_count) - - def make_combined(size, idx1, idx2): - def getter(flat_vars): - return size * flat_vars[idx1] + flat_vars[idx2] - - return getter - - return_getters_groups = [] - current_group = 0 - for length_group in lengths: - return_getters = [] - for size in length_group: - if sv.statically_known_equals(size, 1): # type: ignore[arg-type] - return_getters.append(lambda _: sympy.Integer(0)) - continue - - while current_group < len(remaining) and sv.statically_known_equals( - remaining[current_group], 1 # type: ignore[arg-type] - ): - # scroll to next group with remaining elements - current_group += 1 - - if current_group + 1 < len(remaining) and sv.statically_known_gt( - size, remaining[current_group] - ): - # need to break size in two - if not sv.statically_known_multiple_of( - size, remaining[current_group] - ): - raise CantSplit - size1 = remaining[current_group] - size2 = FloorDiv(size, remaining[current_group]) - return_getters.append( - make_combined( - size2, - add_range(current_group, size1), - add_range(current_group + 1, size2), - ) - ) - else: - return_getters.append( - operator.itemgetter(add_range(current_group, size)) - ) - return_getters_groups.append(return_getters) - - assert all( - V.graph.sizevars.size_hint(s) == 1 for s in remaining - ), f"failed to set ranges {remaining} {lengths}" - - return new_ranges, return_getters_groups - - @classmethod - def is_compatible( - cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] - ): - try: - cls._split_iteration_ranges(groups, lengths) - return True - except CantSplit: - return False - - def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): - """ - We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). - - To do this we need to split up the iteration space of i0 into something like: - for i1 in s0: - for i2 in s1: - i0 = i1*s1 + i2 - .... - - This function matches and resplits lengths to the groups of - this kernel to enable tiled + non-tiled fusions. - """ - groups = [rt.numel for rt in self.range_trees] - if not self.inside_reduction: - groups[-1] = sympy.Integer(1) - - if len(lengths) == len(self.range_trees) and all( - V.graph.sizevars.simplify(sympy_product(x) - g) == 0 - for x, g in zip(lengths, groups) - ): - return self.set_ranges(*lengths) - - new_ranges, return_getters_groups = self._split_iteration_ranges( - groups, lengths - ) - itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) - return [[fn(itervars) for fn in fns] for fns in return_getters_groups] - - def is_indirect_indexing(self, index: sympy.Expr): - # tmpX means indirect indexing - return free_symbol_is_type(index, SymT.TMP) - - def is_broadcasted(self, index: sympy.Expr): - # Note. This may not be correct when there is indirect indexing - if self.is_indirect_indexing(index): - return False - - index_numels = [1] * len(self.numels) - for symbol in index.free_symbols: - if symbol not in self.range_tree_nodes: - # Non-iterated variables, e.g. strides - continue - entry = self.range_tree_nodes[symbol] # type: ignore[index] - assert isinstance(entry.parent, IterationRangesRoot) - index_numels[entry.parent.index] *= entry.length - - # If the index variables only iterate over a subset of the kernel - # numels, then it must be broadcasted. - simplify = V.graph.sizevars.simplify - return any( - simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] - for idx_range, iter_range in zip(index_numels, self.numels) - ) - - def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): - """ - More aggressive simplification to merge contiguous dims - """ - if isinstance(index, (sympy.Integer, sympy.Symbol)): - return index - index_vars, sizes = tree.vars_and_sizes(index) - if len(sizes) <= 1: - return index - new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( - index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) - ) - if new_sizes == sizes: - return index - new_index_vars = tree.construct(new_sizes) - new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) - return new_index - - def index_to_str(self, index: sympy.Expr) -> str: - """ - Convert an index expr to a string that can be used in triton code. - e.g. a sympy expression "s2" may actually appear as "ks1" in the triton kernel. - - Index expressions often need to be passed in as arguments to the triton kernel. - Rename_indexing and codegen_indexing keep track of the needed indices and add - new parameters to the function signature. - """ - if isinstance(index, list): - return f"[{', '.join(map(self.index_to_str, index))}]" - return texpr(self.rename_indexing(self.codegen_indexing(index))) - - def indexing( - self, - index: sympy.Expr, - *, - copy_shape=None, - dense_indexing=False, - override_mask=None, - block_ptr=False, - ) -> Union[IndexingOptions, BlockPtrOptions]: - """ - Compute the index and mask to pass to tl.load() or tl.store() - """ - index = self.simplify_indexing(index) - index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) - # if simple replacements didn't get rid of floor/ceil, try full subs - if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): - index = index.subs(V.graph.sizevars.precomputed_replacements) - # last resort, if no range vars are in the expr, hoist it - # TODO instead of trying to blindly find complicated exprs, we should hoist the - # inputs/outputs sizes and strides, but at the time indexing is generated - # kernel inputs and outputs are not set yet, we'd need a deeper refactor - # to do it this way - - if len(index.atoms(sympy.ceiling)): - for a in index.atoms(sympy.ceiling): - # for nested exprs, atoms yields top level first (?) - # so if everything goes fine, lower level replacements will come up empty - symbols = a.free_symbols - if len(symbols) > 0 and all( - symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) - for s in symbols - ): - replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} - index = sympy_subs(index, replacements) - - index = self.simplify_indexing(index) - index_vars = index.free_symbols - has_rindex = False - - mask_vars: Set[str] = set() - for var in index_vars: - assert isinstance(var, sympy.Symbol) - has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX) - if override_mask: - pass - elif symbol_is_type(var, SymT.TMP): - # indirect indexing - cse_var = self.cse.varname_map[var.name] - mask_vars.update(cse_var.mask_vars) - elif symbol_is_type( - var, - ( - SymT.UNBACKED_INT, - SymT.SIZE, - SymT.PRECOMPUTED_SIZE, - SymT.INDEX, - SymT.FLOAT, - SymT.UNBACKED_FLOAT, - ), - ): - pass - else: - # var is one of xN, yN or rN - assert symbol_is_type( - var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK) - ), var.name - mask_vars.add(f"{var.name[0]}mask") - - need_dense = ( - config.triton.dense_indexing - or dense_indexing - or self._load_mask is not None - ) and index != 0 - - have_dense = True - have_loop_vars = False - dense_mask_vars = set() - - for tree in self.active_range_trees(): - if index_vars.intersection(tree.var_list): - have_loop_vars = True - else: - have_dense = False - dense_mask_vars.add(f"{tree.prefix}mask") - - if ( - block_ptr - and config.triton.use_block_ptr - and not override_mask - and not self._load_mask - and len(mask_vars - dense_mask_vars) == 0 - and not self.is_indirect_indexing(index) - and have_loop_vars - # workaround https://github.com/openai/triton/issues/2821 - and self.index_dtype == "tl.int32" - ): - index_relative_to_xyr_index = sympy_subs( - index, {v: t.expr for v, t in self.range_tree_nodes.items()} - ) - range_trees = self.active_range_trees(reorder=True) - symbols = [t.symbol() for t in range_trees] - strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols] - offset = sympy.Wild("_offset", exclude=symbols) - m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset) - # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with - # a tl.reshape the correct block. We will miss these cases today. - if m: - self.filter_masks(mask_vars) - return BlockPtrOptions.create( - [m[s] for s in strides], - m[offset], - range_trees, - mask_vars, # type: ignore[arg-type] - ) - - expand_str = None - index_str = self.index_to_str(index) - if isinstance(index, sympy.Integer): - expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() - index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" - return IndexingOptions(index_str, set(), "None", expand_str, has_rindex) - - if need_dense and not have_dense: - expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() - index_str = f"tl.broadcast_to({index_str}, {expand_str})" - mask_vars = dense_mask_vars - elif not have_loop_vars and copy_shape: - index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" - mask_vars = dense_mask_vars - - if override_mask: - mask_vars = {override_mask} - - if self._load_mask: - mask_vars.add(self._load_mask) - - self.filter_masks(mask_vars) - - mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" - return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type] - - def active_range_trees(self, reorder=False): - trees = [ - t for t in self.range_trees if t.prefix != "r" or self.inside_reduction - ] - if reorder and len(trees) > 1: - count = sum(t.prefix in "xyz" for t in trees) - assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [ - t.prefix for t in trees[:count] - ] - trees[:count] = reversed(trees[:count]) - return trees - - def filter_masks(self, mask_vars): - for tree in self.range_trees: - # Masks are superfluous if we only have one element - if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] - mask_vars.discard(f"{tree.prefix}mask") - continue - # Masks are superfluous if numel is a multiple of BLOCK - # (We use the fact that BLOCK is required by triton to be a power of 2) - if tree.prefix.upper() not in TRITON_MAX_BLOCK: - continue - max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] - # Optional optimization: if block divides numel exactly, we will - # never need to do a masked load to handle stragglers at the end. - # It's faster to avoid masking at all. But it is sound to always - # mask. - if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] - mask_vars.discard(f"{tree.prefix}mask") - - def var_ranges(self): - return dict( - itertools.chain.from_iterable( - tree.var_ranges.items() for tree in self.range_trees - ) + def want_no_x_dim(self): + return ( + self.reduction_hint == ReductionHint.INNER + and self.persistent_reduction + and len(self.numels) == 2 + and self.numels[-1] >= 256 ) - def codegen_indexing(self, expr: sympy.Expr): - expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) - for sym in sorted(expr.free_symbols, key=str): - if sym in self.range_tree_nodes: - # if indexing expression is complicated, we precompute it on the host side - # and send the result as a kernel argument - replacements = {} - for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] - replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) - if len(replacements) > 0: - self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] - self.range_tree_nodes[sym].expr, replacements # type: ignore[index] - ) - self.range_tree_nodes[sym].codegen() # type: ignore[index] - return expr - - @contextlib.contextmanager - def mask_loads(self, mask): - """Context manager to add an additional mask to tl.load/store""" - prior = self._load_mask - if prior: - mask = self.cse.generate(self.compute, f"{mask} & {prior}") - - self._load_mask = mask - try: - # TODO(jansel): do we need a reshape here? - yield mask - finally: - self._load_mask = prior - def generate_assert(self, check): return torch.version.hip is None and super().generate_assert(check) - def load_mask(self, var): - mask = "" - mask_vars = set(var.mask_vars) - if self._load_mask: - mask_vars.add(self._load_mask) - - if mask_vars: - mask = ( - f"{next(iter(mask_vars))}" - if len(mask_vars) == 1 - # sorted for deterministic order - else f"({' & '.join(sorted(map(str, mask_vars)))})" - ) - return mask - @property def assert_function(self) -> str: return "tl.device_assert" - def get_strides_of_load(self, index: sympy.Expr): - """ - This gets the stride of the index for each of the tiling variables - (technically, it does it at index 0) - - For example, if - xindex = x0 + 512*x1 + 1024*r0 - x0 = (xindex//512) - x1 = (xindex % 512) - r0 = rindex // 1024 - - this function would return - {xindex: 512, rindex: 1024} - """ - index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} - index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] - strides = {} - for range_tree in self.range_trees: - s = sympy_index_symbol(range_tree.name) - strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( - index_in_tile_vars, {s: 0} - ) - return strides - def codegen_block_ptr( self, name: str, var: str, indexing: BlockPtrOptions, other="" ) -> Tuple[str, Optional[DeferredLine], str]: @@ -2126,12 +1277,6 @@ def reduction_resize(self, value): sizes[-1] = "None" return f"{value}[{', '.join(sizes)}]" - @staticmethod - def _map_tuple_or_scalar(fn, value): - if isinstance(value, tuple): - return tuple(map(fn, value)) - return fn(value) - def reduction( self, dtype: torch.dtype, @@ -2685,68 +1830,6 @@ def imports_for_benchmark_kernel(self): ) ) - def estimate_kernel_num_bytes(self): - """ - Try the best to estimate the total size (in bytes) of the - kernel's inputs and outputs, which is used for estimating the memory - throughput of this kernel. This information is used for checking how - far we are from the peak memory bandwidth. It's important that - we want to avoid overestimating the sizes of the inputs and outputs, - because it can wrongfully give us a very large memory traffic value, - which may be even larger than the theoretical bandwidth and thus - become very misleading. This is particularly problematic for cases - where we slice some inputs. In those cases, we should only count - the size of the "slices" instead of the original inputs, because - only the slices contribute to the real memory traffic. - """ - nbytes = [] - ninplace_args = len(unique(self.args.inplace_buffers.values())) - _, call_args, _ = self.args.python_argdefs() - - # For pointwise and reduction kernels, this is the upper-bound numels - # for the output buffer. - # FIXME: This is not exactly right for cases like below: - # def foo(tensor0, tensor1): - # x0 = narrow(tensor0) - # return cat(x0, tensor1) - # For this example, we will end up overestimate the size for the - # slice s0. Potentially, we could have precise inputs information - # if we maintained the original inputs of the Pointwise kernel created - # for the "cat". However, I think it might be a bit overwhelming that - # we add such complexity only for handling some particular cases for - # benchmarking. - out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels)) - for i, arg in enumerate(call_args): - # "buf" may be narrowed. In this case, the number of memory accesses - # should be estimated based on the reinterpreted layout. - # On the other hand, buf may be broadcasted. In this case, - # counting the size of the underline storage would give us - # a better estimation in terms of memory accesses. - if arg not in self.buf_accesses: - nbytes.append(0) - continue - arg_numel = V.graph.get_numel(arg) - buf_size = V.graph.sizevars.size_hint(arg_numel) - if buf_size > out_numel: - # This arg points to a buf that has been sliced. - # We need to count each individual slice to have - # a better estimation. - indices: Set[Any] = set() - no_index_dep_count = 0 - for dep in self.buf_accesses[arg]: - if isinstance(dep, (StarDep, WeakDep)): - indices.add(f"no_index_dep_{no_index_dep_count}") - no_index_dep_count += 1 - else: - indices.add(dep.index) - numel = len(indices) * out_numel - else: - numel = buf_size - dtype = V.graph.get_dtype(arg) - dtype_size = get_dtype_size(dtype) - nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) - return sum(nbytes) - def _get_heuristic(self): if self.persistent_reduction: assert self.inside_reduction @@ -2992,28 +2075,6 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp if tree.prefix == "x" and self.no_x_dim: code.writeline("XBLOCK: tl.constexpr = 1") - def triton_tensor_ndim(self): - return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) - - def indexing_size_str(self, i): - sizes = ["None"] * self.triton_tensor_ndim() - sizes[i] = ":" - return f"[{', '.join(sizes)}]" - - def dense_size_list(self) -> List[str]: - sizes = ["1"] * self.triton_tensor_ndim() - for tree in self.range_trees: - if tree.tensor_dim is None: - continue - - if tree.prefix != "r" or self.inside_reduction: - sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" - return sizes - - def dense_size_str(self): - sizes = self.dense_size_list() - return f"[{', '.join(sizes)}]" - def _get_grid_fn(self): return "grid" @@ -3077,439 +2138,22 @@ def codegen_nan_check(self): line = f"assert not {arg}.isinf().any().item()" wrapper.writeline(line) - def warn_mix_layout(self, kernel_name): - """ - Print message if the kernel have mixed layout inputs. - Only care about 4D tensor for now. - """ - if ( - len(self.args.input_buffers) == 1 - and len(self.args.output_buffers) == 1 - and len(self.args.inplace_buffers) == 0 - ): - # even if input buffer and output buffer have different layout, - # this can be a layout conversion kernel. No need to warn for - # the mix layouts. - return - - argdefs, call_args, signature = self.args.python_argdefs() - uniform_stride_order = None - for arg_name in call_args: - buf = V.graph.get_buffer(arg_name) - if buf and len(buf.layout.size) == 4: - # ignore the tensor if only 1 dimension is non-zero - if len([x for x in buf.layout.size if x == 1]) == 3: - continue - stride_order = ir.get_stride_order(buf.layout.stride) - if uniform_stride_order is None: - uniform_stride_order = stride_order - elif uniform_stride_order != stride_order: - msg = yellow_text( - f"Expected stride order {uniform_stride_order}, but found stride order" - + f" {stride_order} for kernel {kernel_name}" - ) - log.warning(msg) - - stride_order_list = [ - ir.get_stride_order(V.graph.get_buffer(name).layout.stride) - if V.graph.get_buffer(name) - else None - for name in call_args - ] - size_list = [ - V.graph.get_buffer(name).layout.size - if V.graph.get_buffer(name) - else None - for name in call_args - ] - source_list = [ - "GraphInput" - if name in V.graph.graph_inputs - else "IntermediateBuffer" - if name in V.graph.name_to_buffer - else None - for name in call_args - ] - - msg = yellow_text( - f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" - + f"\n sizes {size_list}\n sources {source_list}\n" - ) - log.warning(msg) - return - msg = green_text( - f"All the inputs for the triton kernel {kernel_name} have uniform layout" - ) - log.warning(msg) - def create_cse_var(self, *args, **kwargs): return TritonCSEVariable(*args, **kwargs) - -class TritonScheduling(BaseScheduling): - def __init__(self, scheduler): - self.scheduler = scheduler - - def group_fn(self, sizes): - return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) - - def can_fuse(self, node1, node2): - """ - Hook called by Scheduler to determine if the Triton backend - can fuse node1 and node2. These nodes might already be - FusedSchedulerNodes. - """ - if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( - node2, scheduler.ForeachKernelSchedulerNode - ): - return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) - - _, (numel1, rnumel1) = node1.group - _, (numel2, rnumel2) = node2.group - why = WhyNoFuse(node1, node2) - - if node1.is_split_scan() and not node2.is_split_scan(): - if node2.is_reduction(): - why("Split scan cannot fuse with reductions") - elif node2.is_split_scan() and not node1.is_split_scan(): - if node1.is_reduction(): - why("Split scan cannot fuse with reductions") - - if node1.is_reduction() and node2.is_reduction(): - reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 - if not reduction_can_fuse: - why( - "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", - numel1, - numel2, - rnumel1, - rnumel2, - ) - return reduction_can_fuse - - if not node1.is_reduction() and not node2.is_reduction(): - if not (numel1 == numel2 and rnumel1 == rnumel2): - why( - "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", - numel1, - numel2, - rnumel1, - rnumel2, - ) - return False - - if node1.is_template(): - # Only allow fusion for TritonTemplates for now. - # Fusion for CUDATemplates are not supported. - is_triton_template = isinstance(node1.node, TritonTemplateBuffer) - if not is_triton_template: - why("node1 is not TritonTemplateBuffer") - return is_triton_template - - # check for a bad combined tiling - tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) - tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) - tiling3 = self.select_tiling( - node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 - ) - if config.triton.tiling_prevents_pointwise_fusion: - cond = True - if len(tiling1) > 2: - if len(tiling2) > 2: - cond = tiling1 == tiling2 == tiling3 - else: - cond = tiling1 == tiling3 - elif len(tiling2) > 2: - cond = tiling2 == tiling3 - if not cond: - why( - "tiling mismatch (%s, %s, %s)", - tiling1, - tiling2, - tiling3, - ) - return False - - return True - - if not node1.is_reduction() and node2.is_reduction(): - assert rnumel1 == 1 and rnumel2 != 1 - if numel1 == numel2 * rnumel2: - if not all( - TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges()) - for n in node1.get_nodes() - ): - why("nodes numel/rnumel incompatibility") - return False - if ( - config.triton.tiling_prevents_reduction_fusion - and not node1.is_template() - ): - is_reduction_tiling_valid = self.select_tiling( - node1.get_nodes(), numel1 - ) in ( - (numel1, 1), - (numel2, rnumel2, 1), - ) - if not is_reduction_tiling_valid: - why("invalid tiling for reduction") - return is_reduction_tiling_valid - return True - - if numel1 != numel2: - why("nodes numel incompatibility") - return numel1 == numel2 - - assert node1.is_reduction() and not node2.is_reduction() - # swap args to hit the case above - return self.can_fuse_horizontal(node2, node1) - - can_fuse_vertical = can_fuse - can_fuse_horizontal = can_fuse - - def generate_node_schedule(self, nodes, numel, rnumel): - node_schedule: List[Any] = [] - current_loop_writes: Set[str] = set() - - # Writes with a reduced shape, meaning they are only present once the - # reduction loop has ended - current_loop_reduced_writes = set() - current_loop_has_writes = False - done = set() - - def fits_in_main_body(n): - _, (node_numel, node_rnumel) = n.group - return (node_numel == numel and node_rnumel == rnumel) or ( - node_numel == numel * rnumel and node_rnumel == 1 - ) - - def fits_outside_reduction(n): - _, (node_numel, node_rnumel) = n.group - return node_numel == numel and node_rnumel == 1 and rnumel != 1 - - def schedule_node_in_loop(n): - nonlocal current_loop_has_writes - done.add(n) - node_schedule.append(n) - current_loop_has_writes = True - # A scan is modelled as a reduction in the scheduler but has a - # full sized output that can be used inside the loop body - if ( - n.is_reduction() - and isinstance(n, scheduler.SchedulerNode) - and isinstance(n.node, ir.ComputedBuffer) - and not isinstance(n.node.data, ir.Scan) - ): - current_loop_reduced_writes.add(n.get_name()) - - @contextlib.contextmanager - def end_current_reduction_loop(): - nonlocal current_loop_has_writes - if current_loop_has_writes: - # flush out any other runnable nodes to reduce number of loops - for other_node in nodes[index + 1 :]: - if ( - node not in done - and fits_in_main_body(other_node) - and not (current_loop_reduced_writes & other_node.ancestors) - ): - schedule_node_in_loop(node) - - if node_schedule and node_schedule[-1] is EnableReduction: - node_schedule.pop() - else: - node_schedule.append(DisableReduction) - yield - node_schedule.append(EnableReduction) - current_loop_reduced_writes.clear() - current_loop_has_writes = False - - for index, node in enumerate(nodes): - if node in done: - continue - done.add(node) - - def requires_closing_previous_reduction(node, node_schedule): - if rnumel == 1: - return False - if not current_loop_reduced_writes & node.ancestors: - return False - assert node_schedule and not isinstance( - node_schedule[-1], (EnableReduction, DisableReduction) - ) - return bool(current_loop_reduced_writes) - - if fits_in_main_body(node): - if requires_closing_previous_reduction(node, node_schedule): - with end_current_reduction_loop(): - pass # need to start a new reduction loop - - schedule_node_in_loop(node) - elif fits_outside_reduction(node): - with end_current_reduction_loop(): - node_schedule.append(node) - else: - raise NotImplementedError( - f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" - ) - - return node_schedule - - def codegen_node( - self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] - ): - """ - Given a set of pre-fused nodes, generate a Triton kernel. - """ - - nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] - - _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group - - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - buf_accesses = collections.defaultdict(list) - for node in nodes: - for access in node.read_writes.reads | node.read_writes.writes: - buf_accesses[access.name].append(access) - - schedule_log.debug("Schedule:\n %s", node_schedule) - - return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) - - @staticmethod - def reduction_hint(node): - assert node.is_reduction() - if all( - dep.is_contiguous() - for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) - ): - return ReductionHint.INNER + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" + if entry.root.is_loop: + self.indexing_code.writeline(line) else: - return node.node.data.reduction_hint - - @staticmethod - def can_use_32bit_indexing( - numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] - ) -> bool: - int_max = torch.iinfo(torch.int32).max - size_hint = V.graph.sizevars.size_hint - has_hint = V.graph.sizevars.shape_env.has_hint - - def within_32bit(e): - # Allow for unhinted e as long as we can still statically prove - # (e.g., via ValueRanges) that it is still in bounds - if V.graph.sizevars.is_expr_static_and_true(e <= int_max): - return True - # Otherwise, the hint MUST exist and be in range - return has_hint(e) and size_hint(e) <= int_max - - if not within_32bit(numel): - return False - - # Any use of a MultiOutputLayout will create a buffer with a - # Layout whose sizes are accounted for - buf_sizes = [ - buf.get_layout().storage_size() - for buf in buffers - if not isinstance(buf.get_layout(), ir.MultiOutputLayout) - ] - - if not all(within_32bit(size) for size in buf_sizes): - return False - - # Only install guards for 32-bit indexing as there is no correctness - # issue with using 64-bit for everything - V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] - for size in buf_sizes: - V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] - return True - - @staticmethod - def select_index_dtype(node_schedule, numel, reduction_numel): - # Gather all used buffer names - buffer_names = set() - for node in node_schedule: - if not isinstance(node, scheduler.BaseSchedulerNode): - continue - - buffer_names.update(node.get_names()) - buffer_names.update(node.used_buffer_names()) - - # Get buffers objects - - def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: - buf = V.graph.get_buffer(name) - if buf is None: - raise RuntimeError(f"Failed to find buffer matching name {name}") - return buf - - buffers = [V.graph.get_buffer(name) for name in buffer_names] - - # In theory we can separately check xnumel and rnumel are <= int_max - # but some indexers do use the full linear index so we need to be - # conservative here. - total_numel = numel * reduction_numel - - if TritonScheduling.can_use_32bit_indexing(total_numel, buffers): - return "tl.int32" - return "tl.int64" - - def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel): - pointwise_nodes = list( - filter( - lambda n: n not in (EnableReduction, DisableReduction) - and not n.is_reduction() - and n.group[1][0] == numel * rnumel, - node_schedule, - ) - ) - for node in pointwise_nodes: - # An index can be an integer when loading a random seed. - if not all( - not isinstance(dep, MemoryDep) - or dep.is_contiguous() - or isinstance(dep.index, (sympy.Integer, int)) - or dep.stride1_for_last_dim() - for dep in itertools.chain( - node.read_writes.reads, node.read_writes.writes - ) - ): - return True - return False - - def get_kernel_args(self, node_schedule, numel, reduction_numel): - reductions = list( - filter( - lambda n: n not in (EnableReduction, DisableReduction) - and n.is_reduction(), - node_schedule, - ) - ) - if len(reductions) > 0: - hints = [self.reduction_hint(n) for n in reductions] - if hints.count(hints[0]) == len(hints): - reduction_hint_val = hints[0] - else: - reduction_hint_val = ReductionHint.DEFAULT - - if ( - reduction_hint_val == ReductionHint.INNER - and self.has_non_contiguous_pw_in_reduction_kernel( - node_schedule, numel, reduction_numel - ) - ): - reduction_hint_val = ReductionHint.DEFAULT - else: - reduction_hint_val = ReductionHint.DEFAULT - - mutations = set() - for node in node_schedule: - if hasattr(node, "get_mutations"): - mutations.update(node.get_mutations()) + # lift non-reduction stores outside loop + self.body.writeline(line) - index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) - return reduction_hint_val, mutations, index_dtype +class TritonScheduling(SIMDScheduling): + int32_type = "tl.int32" + int64_type = "tl.int64" + kernel_type = TritonKernel def codegen_comment(self, node_schedule): wrapper = V.graph.wrapper_code @@ -3537,123 +2181,7 @@ def codegen_comment(self, node_schedule): f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" ) - def codegen_node_schedule( - self, node_schedule, buf_accesses, numel, reduction_numel - ): - from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel - - tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) - ( - reduction_hint_val, - mutations, - index_dtype, - ) = self.get_kernel_args(node_schedule, numel, reduction_numel) - - is_split_scan = any( - isinstance(node, BaseSchedulerNode) and node.is_split_scan() - for node in node_schedule - ) - kernel_type = TritonSplitScanKernel if is_split_scan else TritonKernel - kernel_args = tiled_groups - kernel_kwargs = { - "reduction_hint": reduction_hint_val, - "mutations": mutations, - "index_dtype": index_dtype, - } - kernel = kernel_type( - *kernel_args, - **kernel_kwargs, - ) - kernel.buf_accesses = buf_accesses - - self.codegen_node_schedule_with_kernel(node_schedule, kernel) - - with V.set_kernel_handler(kernel): - src_code = kernel.codegen_kernel() - - kernel_name = self.define_kernel(src_code, node_schedule) - log.debug("Generating kernel code with kernel_name: %s", kernel_name) - kernel.kernel_name = kernel_name - kernel.code_hash = code_hash(src_code) - - if kernel.persistent_reduction and config.triton.multi_kernel: - kernel2 = TritonKernel( - *kernel_args, - **kernel_kwargs, - disable_persistent_reduction=True, - ) - self.codegen_node_schedule_with_kernel(node_schedule, kernel2) - with V.set_kernel_handler(kernel2): - src_code2 = kernel2.codegen_kernel() - kernel_name2 = self.define_kernel(src_code2, node_schedule) - kernel2.kernel_name = kernel_name2 - kernel2.code_hash = code_hash(src_code2) - - final_kernel = MultiKernel([kernel, kernel2]) - else: - final_kernel = kernel # type: ignore[assignment] - - with V.set_kernel_handler(final_kernel): - for node in node_schedule: - if node not in (EnableReduction, DisableReduction): - node.mark_run() - - self.codegen_comment(node_schedule) - final_kernel.call_kernel(final_kernel.kernel_name) - if config.nan_asserts: - final_kernel.codegen_nan_check() - if config.warn_mix_layout: - final_kernel.warn_mix_layout(kernel_name) - - V.graph.removed_buffers |= final_kernel.removed_buffers - V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove - - if ( - V.graph.wrapper_code.supports_intermediate_hooks - and config.generate_intermediate_hooks - ): - # Not every node in the schedule will actually be live on output; - # we can't check dead buffers. - live_outs = kernel.args.live_output_buffers() - for node in node_schedule: - if not isinstance(node, scheduler.BaseSchedulerNode): - continue - name = node.get_name() - if name not in live_outs: - continue - origin_node = node.node.get_origin_node() - if origin_node is not None: - counters["inductor"]["intermediate_hooks"] += 1 - V.graph.wrapper_code.writeline( - f"run_intermediate_hooks({origin_node.name!r}, {name})" - ) - - self.scheduler.free_buffers() - - def codegen_node_schedule_with_kernel(self, node_schedule, kernel): - def current_reduction_nodes(nodes): - return itertools.takewhile(lambda n: n is not DisableReduction, nodes) - - with kernel: - stack = contextlib.ExitStack() - kernel.set_last_usage(current_reduction_nodes(node_schedule)) - - for node in node_schedule: - if node not in (EnableReduction, DisableReduction): - node.decide_inplace_update() - for i, node in enumerate(node_schedule): - if node is DisableReduction: - stack.enter_context(kernel.disable_reduction()) - elif node is EnableReduction: - stack.close() - kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) - else: - # TODO - use split ranges ? - indexing_dtype_strength_reduction(node._body) - index_vars = kernel.split_and_set_ranges(node.get_ranges()) - node.codegen(index_vars) - - def define_kernel(self, src_code, node_schedule): + def define_kernel(self, src_code, node_schedule, kernel): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: kernel_name = wrapper.src_to_kernel[src_code] @@ -3705,293 +2233,6 @@ def define_kernel(self, src_code, node_schedule): return kernel_name - def codegen_template( - self, template_node, epilogue_nodes, only_gen_src_code=False - ) -> Optional[str]: - """ - Codegen a triton template - - If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper - """ - _, (numel, rnumel) = template_node.group - assert rnumel == 1 - kernel, render = template_node.node.make_kernel_render(template_node.node) - with kernel: - if not only_gen_src_code: - for node in [template_node, *epilogue_nodes]: - node.mark_run() - partial_code = render() - for node in epilogue_nodes: - node.codegen(kernel.split_and_set_ranges(node.get_ranges())) - - # finalize must be called after adding epilogue above - with V.set_kernel_handler(kernel): - # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) - node_schedule = [template_node, *epilogue_nodes] - - if config.benchmark_kernel: - num_gb = kernel.estimate_kernel_num_bytes() / 1e9 - grid_args = V.graph.sizevars.size_hints(kernel.call_sizes) - assert kernel.meta is not None, "meta is None" - grid = kernel.grid_fn(*grid_args, kernel.meta) - src_code = ( - f"{kernel.imports_for_benchmark_kernel()}\n" - f"{src_code}\n" - f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}" - ) - - if only_gen_src_code: - return src_code - - kernel_name = self.define_kernel(src_code, node_schedule) - - self.codegen_comment(node_schedule) - kernel.call_kernel(kernel_name, template_node.node) - V.graph.removed_buffers |= kernel.removed_buffers - V.graph.inplaced_to_remove |= kernel.inplaced_to_remove - self.scheduler.free_buffers() - return None - - def codegen_sync(self): - V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) - - def codegen_foreach(self, foreach_node): - from .triton_foreach import ForeachKernel - - for partitions_with_metadata in ForeachKernel.horizontal_partition( - foreach_node.get_subkernel_nodes(), self - ): - kernel = ForeachKernel() - for nodes, tiled_groups, numel, rnumel in partitions_with_metadata: - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - ( - reduction_hint_val, - mutations, - index_dtype, - ) = self.get_kernel_args(node_schedule, numel, rnumel) - - subkernel = kernel.create_sub_kernel( - *tiled_groups, - reduction_hint=reduction_hint_val, - mutations=mutations, - index_dtype=index_dtype, - ) - - self.codegen_node_schedule_with_kernel( - node_schedule, - subkernel, - ) - - with V.set_kernel_handler(subkernel): - for node in node_schedule: - if node not in (EnableReduction, DisableReduction): - node.mark_run() - V.graph.removed_buffers |= subkernel.removed_buffers - V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove - - src_code = kernel.codegen_kernel() - kernel_name = self.define_kernel(src_code, [foreach_node]) - self.codegen_comment([foreach_node]) - kernel.call_kernel(V.graph.wrapper_code, kernel_name) - - self.scheduler.free_buffers() - - @staticmethod - @functools.lru_cache(32) - def candidate_tilings(node): - ranges, reduction_ranges = node.get_ranges() - if len(ranges) <= 1: - return () - - rw = node.pointwise_read_writes() - assert len(rw.range_vars) == len(ranges) - - # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads - # that need to access the entire tensor; they don't contribute read indexing - # information (and practically, they don't have dep.index so they can't be used - # for stride_hints below - dep_sources = [rw.reads, rw.writes] - assert all( - isinstance(dep, (MemoryDep, StarDep)) - for dep in itertools.chain.from_iterable(dep_sources) - ) - deps = [ - dep - for dep in itertools.chain.from_iterable(dep_sources) - if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep) - ] - write_names = {dep.name for dep in rw.writes} - - tilings: List[CandidateTiling] = [] - - for dep in deps: - strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) - assert len(strides) == len(ranges) - try: - split = strides.index(1) + 1 - if split == len(ranges): - continue - if all(s == 0 for s in strides[split:]): - # if this is a broadcasted tensor and all dimensions after split are broadcast, - # this is not a real split - continue - - except ValueError: - continue - tiled_groups = ( - V.graph.sizevars.simplify(sympy_product(ranges[:split])), - V.graph.sizevars.simplify(sympy_product(ranges[split:])), - ) - # score by number of elements - score = V.graph.sizevars.size_hint( - sympy_product( - size for size, stride in zip(ranges, strides) if stride != 0 - ) - ) - if dep.name in write_names: - # ngimel said contiguous writes is more important than reads - score *= 2 - if CandidateTiling.is_good_size(tiled_groups[0]): - score *= 2 - if CandidateTiling.is_good_size(tiled_groups[1]): - score *= 2 - - if ( - V.graph.sizevars.size_hint( - score - sympy_product(itertools.chain(ranges, reduction_ranges)) - ) - >= 0 - ): - tilings.append(CandidateTiling(tiled_groups, score, dep.name)) - return tilings - - @classmethod - def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): - """ - Heuristics to decide how to tile kernels. - Currently, we tile based on stride-1 dimensions. - - Returns: - `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` - - """ - if reduction_numel != 1 or config.triton.max_tiles <= 1: - # TODO(jansel): should we tile reductions? - # do perf hint here if stride-1 dim is not being reduced - if perf_hint_log.level <= logging.WARNING: - for node in EnableReduction.filter(node_schedule): - if len(cls.candidate_tilings(node)) > 0: - perf_hint_log.info("reduction over non-contiguous dims") - break - return (numel, reduction_numel) - - seen_names = set() - candidate_tiles: Counter[Any] = collections.Counter() - for node in EnableReduction.filter(node_schedule): - for tiling in cls.candidate_tilings(node): - if tiling.name in seen_names: - continue - seen_names.add(tiling.name) - candidate_tiles[tiling.tiling] += tiling.score - - ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] - - if config.triton.max_tiles >= 3: - # Consider adding a third dimension of tiling, but only - # when a1 is a multiple of b1; otherwise, you have a lot - # of stragglers which is annoying to generate code for. - # - # NB: More than three max tiles is not enabled by default. - - # Add one 3D tiling choice - for i in range(1, len(ranked_tilings)): - a0, a1 = ranked_tilings[0] - b0, b1 = ranked_tilings[i] - if V.graph.sizevars.size_hint(a1 - b1) == 0: - continue - if V.graph.sizevars.size_hint(a1 - b1) < 0: - # swap so a0 is bigger - a0, a1 = ranked_tilings[i] - b0, b1 = ranked_tilings[0] - assert V.graph.sizevars.size_hint(a1 - b1) > 0 - if V.graph.sizevars.statically_known_multiple_of(a1, b1): - tiling = (a0, FloorDiv(a1, b1), b1) - ranked_tilings = [tiling] + ranked_tilings - break # only 1 choice for now - - if len(ranked_tilings) > 1: - perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) - - for tiled_groups in ranked_tilings: - new_groups = (*tiled_groups, reduction_numel) - if all( - TritonKernel.is_compatible(new_groups, node.get_ranges()) - for node in node_schedule - if isinstance(node, scheduler.SchedulerNode) - ): - return new_groups - - return (numel, reduction_numel) - - def flush(self): - pass - - def ready_to_flush(self) -> bool: - return False - - def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): - @dataclasses.dataclass - class LastUsageHolder: - n: Any - last_usage: Any - - def __del__(self): - self.n.last_usage = self.last_usage - - last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] - - # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. - for n in nodes: - n.last_usage = set() - - if not nodes[0].is_template(): - _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - - tiled_groups = self.select_tiling(node_schedule, numel, rnumel) - reduction_hint_val, mutations, index_dtype = self.get_kernel_args( - node_schedule, numel, rnumel - ) - - kernel = TritonKernel( - *tiled_groups, - reduction_hint=reduction_hint_val, - mutations=mutations, - index_dtype=index_dtype, - ) - - self.codegen_node_schedule_with_kernel(node_schedule, kernel) - with config.patch( - "benchmark_kernel", benchmark_kernel - ), V.set_kernel_handler(kernel): - src_code = kernel.codegen_kernel() - else: - template_node = nodes[0] - epilogue_nodes = nodes[1:] - - with config.patch("benchmark_kernel", benchmark_kernel): - src_code = self.codegen_template( - template_node, epilogue_nodes, only_gen_src_code=True - ) - - src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") - return src_code - @preserve_rng_state() def benchmark_fused_nodes(self, nodes): src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True) @@ -4062,50 +2303,3 @@ def store_cache(): ) store_cache() return ms, mod.__file__ - - -@dataclasses.dataclass -class CandidateTiling: - tiling: Tuple[sympy.Expr, sympy.Expr] - score: int # higher is better - name: Optional[str] = None - - @staticmethod - def is_good_size(s): - """Somewhat arbitrary heuristic used to boost scores for some sizes""" - s = V.graph.sizevars.size_hint(s) - return s >= 32 and (s % 32 == 0) - - -class DisableReduction: - """ - Marker to invoke `kernel.disable_reduction()`. This closes a - reduction loop and allows for pointwise ops to occur on the output - of a reduction. - """ - - -class EnableReduction: - """ - Marker to end a DisableReduction block. - """ - - @staticmethod - def filter(node_schedule): - """ - Get the nodes from node_schedule skipping those in a - DisableReduction block. - """ - disabled = False - for node in node_schedule: - if node in (EnableReduction, DisableReduction): - # Don't tile stuff outside the main reduction loop - disabled = node is DisableReduction - elif disabled: - pass - else: - yield node - - -class CantSplit(Exception): - pass diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 8df904946e4a..2a8e0142fbd4 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -4,12 +4,9 @@ import torch._inductor.runtime.hints from torch._inductor import config +from torch._inductor.codegen.simd import IterationRangesRoot -from torch._inductor.codegen.triton import ( - IterationRangesRoot, - triton_compute_type, - TritonKernel, -) +from torch._inductor.codegen.triton import triton_compute_type, TritonKernel from torch._prims_common import prod diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index f98333875f5c..456e0c50567d 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -336,7 +336,7 @@ def decide_inplace_update(self): isinstance(self, (SchedulerNode,)) and config.inplace_buffers and ( - not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel) + not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) or getattr(V.kernel, "mutations", None) is not None ) ): @@ -390,7 +390,7 @@ def decide_inplace_update(self): ) # mutations not tracked in cpp kernels if isinstance( - V.kernel, torch._inductor.codegen.triton.TritonKernel + V.kernel, torch._inductor.codegen.simd.SIMDKernel ): V.kernel.mutations.add(input_node.get_name()) V.kernel.mutations.add(self.get_name()) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 4940f53b1e79..5cb10e1820cf 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -454,11 +454,8 @@ def indexing( block_ptr=block_ptr, ) - def initialize_range_tree(self, pid_cache): - super().initialize_range_tree(pid_cache) - # ignore default codegen - self.body.clear() - self.indexing_code.clear() + def codegen_range_tree(self): + pass # ignore default codegen def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): wrapper = V.graph.wrapper_code diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 1bbeac16e21e..fb3d221bbfe4 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1456,7 +1456,7 @@ def dump_node_schedule(node_schedule): An API that can be used in pdb to dump a node_schedule. Right mainly dump the read/write dependencies but can add more as needed. """ - from torch._inductor.codegen.triton import DisableReduction, EnableReduction + from torch._inductor.codegen.simd import DisableReduction, EnableReduction from torch._inductor.scheduler import SchedulerNode print(f"Node schedule with {len(node_schedule)} nodes") From c1767d8626edcf4bdab5b3b44207ab319e844845 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 15 May 2024 15:07:59 -0700 Subject: [PATCH 097/116] Faster(?) FP16 gemv kernel (#126297) Differential Revision: [D57369266](https://our.internmc.facebook.com/intern/diff/D57369266/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D57369266/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/126297 Approved by: https://github.com/malfet --- aten/src/ATen/native/BlasKernel.cpp | 119 +++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 48a077814880..af34ae5c582a 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -215,6 +215,87 @@ static inline float16_t reduce(float16x8_t x) { return reduce(vadd_f16(vget_low_f16(x), vget_high_f16(x))); } +/* + * The below reduce overload and + * fp16_gemv_trans_fp16_arith_by_dot_products function is adapted from + * llama.cpp's ggml_vec_dot_f16 and surrounding utility functions, so + * here is the required copyright notice: + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#define F16_ELEMENTS_PER_ITERATION 32 +#define F16_ELEMENTS_PER_REGISTER 8 +#define F16_REGISTERS_PER_ITERATION (F16_ELEMENTS_PER_ITERATION / F16_ELEMENTS_PER_REGISTER) +static inline double reduce(float16x8_t x[F16_REGISTERS_PER_ITERATION]) { + int offset = F16_REGISTERS_PER_ITERATION / 2; + for (int i = 0; i < offset; ++i) { + x[i] = vaddq_f16(x[i], x[offset + i]); + } + offset /= 2; + for (int i = 0; i < offset; ++i) { + x[i] = vaddq_f16(x[i], x[offset + i]); + } + offset /= 2; + for (int i = 0; i < offset; ++i) { + x[i] = vaddq_f16(x[i], x[offset + i]); + } + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0])); + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); + return (double)vaddvq_f32(vaddq_f32(t0, t1)); + +} + +static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) { +#ifdef __ARM_FEATURE_FMA + return vfmaq_f16(a, b, c); +#else + return vaddq_f16(a, vmulq_f16(b, c)); +#endif +} + +// Rather than unrolling to process multiple rows (transposed columns) +// of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll +// along an individual dot product. +static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { + parallel_for(0, n, 1, [&](int begin, int end) { + for (int i = begin; i < end; ++i) { + float16x8_t sum[F16_REGISTERS_PER_ITERATION] = {vdupq_n_f16(0)}; + float16x8_t ax[F16_REGISTERS_PER_ITERATION]; + float16x8_t ay[F16_REGISTERS_PER_ITERATION]; + + for (int j = 0; j < m; j += F16_ELEMENTS_PER_ITERATION) { + for (int k = 0; k < F16_REGISTERS_PER_ITERATION; ++k) { + ax[k] = vld1q_f16(x + j + k * F16_ELEMENTS_PER_REGISTER); + ay[k] = vld1q_f16(a + lda * i + j + k * F16_ELEMENTS_PER_REGISTER); + sum[k] = f16_fma(sum[k], ax[k], ay[k]); + } + } + // TODO: add a tail fixup so we don't have to have such a + // restrictive gate to enter this path. + y[i * incy] = reduce(sum); + } + }); +} static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n / 4, 1, [&](int begin, int end) { @@ -230,13 +311,13 @@ static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t for (auto j = 0; j < m; j += 8) { float16x8_t xVec = vld1q_f16(x + j); float16x8_t a0Vec = vld1q_f16(row0 + j); - sum0Vec = vaddq_f16(sum0Vec, vmulq_f16(a0Vec, xVec)); + sum0Vec = f16_fma(sum0Vec, a0Vec, xVec); float16x8_t a1Vec = vld1q_f16(row1 + j); - sum1Vec = vaddq_f16(sum1Vec, vmulq_f16(a1Vec, xVec)); + sum1Vec = f16_fma(sum1Vec, a1Vec, xVec); float16x8_t a2Vec = vld1q_f16(row2 + j); - sum2Vec = vaddq_f16(sum2Vec, vmulq_f16(a2Vec, xVec)); + sum2Vec = f16_fma(sum2Vec, a2Vec, xVec); float16x8_t a3Vec = vld1q_f16(row3 + j); - sum3Vec = vaddq_f16(sum3Vec, vmulq_f16(a3Vec, xVec)); + sum3Vec = f16_fma(sum3Vec, a3Vec, xVec); } y[(i + 0) * incy] = reduce(sum0Vec); y[(i + 1) * incy] = reduce(sum1Vec); @@ -245,6 +326,7 @@ static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t } }); } + #endif static inline float reduce(float32x4_t x) { @@ -252,6 +334,14 @@ static inline float reduce(float32x4_t x) { return vgetq_lane_f32(vpaddq_f32(sum, sum), 0); } +static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) { +#ifdef __ARM_FEATURE_FMA + return vfmaq_f32(a, b, c); +#else + return vaddq_f32(a, vmulq_f32(b, c)); +#endif +} + static void fp16_gemv_trans_fp32_arith(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n / 4, 1, [&](int begin, int end) { for (auto i = begin * 4 ; i < end * 4; i += 4) { @@ -266,13 +356,13 @@ static void fp16_gemv_trans_fp32_arith(const int m, const int n, const float16_t for (auto j = 0; j < m; j += 4) { float32x4_t xVec = vcvt_f32_f16(vld1_f16(x + j)); float32x4_t a0Vec = vcvt_f32_f16(vld1_f16(row0 + j)); - sum0Vec = vaddq_f32(sum0Vec, vmulq_f32(a0Vec, xVec)); + sum0Vec = f32_fma(sum0Vec, a0Vec, xVec); float32x4_t a1Vec = vcvt_f32_f16(vld1_f16(row1 + j)); - sum1Vec = vaddq_f32(sum1Vec, vmulq_f32(a1Vec, xVec)); + sum1Vec = f32_fma(sum1Vec, a1Vec, xVec); float32x4_t a2Vec = vcvt_f32_f16(vld1_f16(row2 + j)); - sum2Vec = vaddq_f32(sum2Vec, vmulq_f32(a2Vec, xVec)); + sum2Vec = f32_fma(sum2Vec, a2Vec, xVec); float32x4_t a3Vec = vcvt_f32_f16(vld1_f16(row3 + j)); - sum3Vec = vaddq_f32(sum3Vec, vmulq_f32(a3Vec, xVec)); + sum3Vec = f32_fma(sum3Vec, a3Vec, xVec); } y[(i + 0) * incy] = reduce(sum0Vec); y[(i + 1) * incy] = reduce(sum1Vec); @@ -295,11 +385,16 @@ void fp16_gemv_trans( const int incy) { if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && n % 4 == 0) { #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC - return at::globalContext().allowFP16ReductionCPU() && m % 8 == 0 ? fp16_gemv_trans_fp16_arith(m, n, a, lda, x, y, incy) - : fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy); -#else - return fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy); + if (at::globalContext().allowFP16ReductionCPU()) { + if (m % 32 == 0 && n % 32 == 0) { + return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy); + } + if (m % 8 == 0) { + return fp16_gemv_trans_fp16_arith(m, n, a, lda, x, y, incy); + } + } #endif + return fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy); } for (const auto i : c10::irange(n)) { float sum = 0; From bf099a08f09af6c33a99ad7c90ea98284db451b9 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Fri, 17 May 2024 07:11:55 +0000 Subject: [PATCH 098/116] [2/N] Non-Tensor: Scalar Support: Add scalar to the cache for eager-through-torch.compile (#124070) Add scalar information to the kernel configuration. #### Additional Context Currently, the input parameters are orchestrated by input order in the kernel configuration and loaded/mapped to the kernel at runtime. For example, the cache order of the input parameters of `torch.add(a, b, alpha=2.0)` is `a' first, followed by `b` and then `alpha`. The same order is for cache loading. However, the orchestration mechanism does not support kwargs because the order of kwargs is useless. For example, the `out` of `aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)` may be before `approximate`. We will support it with subsequent PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124070 Approved by: https://github.com/jansel, https://github.com/jgong5 --- test/inductor/test_torchinductor.py | 80 ++++++++++++ torch/_inductor/utils.py | 47 ++++--- .../inductor/aoti_eager/kernel_holder.cpp | 116 ++++++++++++++++-- .../csrc/inductor/aoti_eager/kernel_holder.h | 6 +- .../inductor/aoti_eager/kernel_meta_info.cpp | 43 +++++++ .../inductor/aoti_eager/kernel_meta_info.h | 9 ++ 6 files changed, 278 insertions(+), 23 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 73779d22bd42..1201e68f277e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -841,6 +841,86 @@ def fn(a): self.assertTrue(kernel_lib_path in kernel_libs_abs_path) + @skipCUDAIf(not SM80OrLater, "Requires sm80") + def test_eager_aoti_with_scalar(self): + namespace_name = "aten" + op_name = "add" + op_overload_name = "Tensor" + op_name_with_overload = f"{op_name}.{op_overload_name}" + + dispatch_key = "CPU" + device = torch.device("cpu") + if self.device.lower() == "cuda": + dispatch_key = "CUDA" + device = torch.device("cuda") + + # Test the difference between scalar tensor and scalar + a = torch.scalar_tensor(1.0, device=device) + b = torch.scalar_tensor(2.0, device=device) + + kernel_lib_path = aoti_compile_with_persistent_cache( + namespace_name, + op_name_with_overload, + a.device.type, + False, + torch.ops.aten.add, + args=(a, b), + kwargs={"alpha": 3.0}, + ) + self.assertTrue(Path(kernel_lib_path).exists()) + device_kernel_cache = aoti_eager_cache_dir(namespace_name, device.type) + kernel_conf = device_kernel_cache / f"{op_name_with_overload}.json" + self.assertTrue(kernel_conf.exists()) + json_data = load_aoti_eager_cache( + namespace_name, op_name_with_overload, a.device.type + ) + op_info = json_data[0] + self.assertTrue(isinstance(op_info, dict)) + self.assertTrue("meta_info" in op_info) + self.assertTrue(len(op_info["meta_info"]) == 3) + self.assertTrue(op_info["meta_info"][0]["sizes"] == []) + self.assertTrue(op_info["meta_info"][0]["strides"] == []) + # Scalar Tensor + self.assertTrue("scalar_value" not in op_info["meta_info"][0]) + self.assertTrue(op_info["meta_info"][1]["sizes"] == []) + self.assertTrue(op_info["meta_info"][1]["strides"] == []) + # Scalar Tensor + self.assertTrue("scalar_value" not in op_info["meta_info"][1]) + self.assertTrue(op_info["meta_info"][2]["sizes"] == []) + self.assertTrue(op_info["meta_info"][2]["strides"] == []) + # Scalar + self.assertTrue("scalar_value" in op_info["meta_info"][2]) + + with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: + a = torch.randn(128, device=device) + b = torch.randn(128, device=device) + + scalar_values = [1.0, 2.0, 3.0] + ref_values = [] + for scalar_value in scalar_values: + ref_values.append(torch.add(a, b, alpha=scalar_value)) + + qualified_op_name = f"{namespace_name}::{op_name}" + _, overload_names = torch._C._jit_get_operation(qualified_op_name) + for overload_name in overload_names: + try: + reg_op_name = qualified_op_name + schema = torch._C._get_schema(reg_op_name, overload_name) + if schema.overload_name: + reg_op_name = f"{reg_op_name}.{schema.overload_name}" + torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + reg_op_name, dispatch_key + ) + except Exception as e: + continue + + res_values = [] + for scalar_value in scalar_values: + res_values.append(torch.add(a, b, alpha=scalar_value)) + + self.assertEqual(len(ref_values), len(res_values)) + self.assertEqual(ref_values, res_values) + @skipCUDAIf(not SM80OrLater, "Requires sm80") def test_torch_compile_override_registration(self): dynamic = False diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index fb3d221bbfe4..59baad51885e 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1578,16 +1578,23 @@ def aoti_compile_with_persistent_cache( """ Compile the given function with persistent cache for AOTI eager mode. """ - flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) - assert all( - isinstance(input, torch.Tensor) for input in flattened_inputs - ), "Only support tensor for now" assert not dynamic, "Only support static shape for now" + type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} + supported_scalar_types = tuple(type_to_torch_dtype.keys()) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + if not all( + isinstance(input, (supported_scalar_types, torch.Tensor)) + for input in flattened_inputs + ): + raise NotImplementedError("Only support tensor, int, float, bool for now") persistent_cache = aoti_eager_cache_dir(ns, device_type) - persistent_cache.mkdir(parents=True, exist_ok=True) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + persistent_cache_lib = persistent_cache / "lib" - persistent_cache_lib.mkdir(parents=True, exist_ok=True) + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() with mock.patch.dict( os.environ, @@ -1609,18 +1616,30 @@ def aoti_compile_with_persistent_cache( ) kernel_metadata_items = [] - for input_tensor in flattened_inputs: + for input in flattened_inputs: # TODO(Eikan): To add dynamic support metadata: Dict[str, Any] = {} metadata["is_dynamic"] = dynamic - metadata["device_type"] = f"{input_tensor.device.type}" - if is_cpu_device([input_tensor]): - metadata["device_index"] = -1 + + if isinstance(input, torch.Tensor): + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) else: - metadata["device_index"] = input_tensor.device.index - metadata["dtype"] = f"{input_tensor.dtype}" - metadata["sizes"] = list(input_tensor.size()) - metadata["strides"] = list(input_tensor.stride()) + assert isinstance(input, supported_scalar_types) + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["sizes"] = [] + metadata["strides"] = [] + metadata["scalar_value"] = input + kernel_metadata_items.append(metadata) kernel_meta_info: Dict[str, Any] = {} diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index 238050f50122..1ada9415ea12 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -96,8 +96,13 @@ bool unpack_tensors( const std::vector& arguments, const torch::jit::Stack& stack, const c10::Device& device, - std::vector& inputs) { + std::vector& inputs, + bool with_scalar = false) { for (size_t idx = 0; idx < stack.size(); idx++) { + if (!with_scalar && stack[idx].isScalar()) { + continue; + } + if (!unpack_ivalue(arguments[idx], stack[idx], device, inputs)) { return false; } @@ -106,6 +111,40 @@ bool unpack_tensors( return true; } +std::vector get_tensor_parameter_index( + const std::vector& arguments, + const torch::jit::Stack& stack) { + std::vector tensor_parameter_index; + for (size_t idx = 0; idx < stack.size(); idx++) { + if (stack[idx].isScalar() || stack[idx].isTensor()) { + // scalar and tensor + tensor_parameter_index.push_back(idx); + } else if (stack[idx].isTensorList()) { + // tensor list + std::fill_n( + std::back_inserter(tensor_parameter_index), + stack[idx].toListRef().size(), + idx); + } else if (stack[idx].isOptionalTensorList()) { + // optional tensor list: std::vector> + for (const auto& item : stack[idx].toListRef()) { + if (item.toOptional().has_value()) { + tensor_parameter_index.push_back(idx); + } + } + } else if ( + *arguments[idx].real_type() == + *c10::getTypePtr>()) { + // optional tensor + if (stack[idx].toOptional().has_value()) { + tensor_parameter_index.push_back(idx); + } + } + } + + return tensor_parameter_index; +} + } // namespace AOTIPythonKernelHolder::AOTIPythonKernelHolder( @@ -149,14 +188,19 @@ bool AOTIPythonKernelHolder::cache_lookup( "Not implemented for operations that return a non-Tensor value."); std::vector inputs; - auto res = unpack_tensors(op.schema().arguments(), *stack, device_, inputs); + auto res = + unpack_tensors(op.schema().arguments(), *stack, device_, inputs, true); TORCH_CHECK_NOT_IMPLEMENTED( res && inputs.size() > 0, "Not implemented for operations that contain a parameter which is ", "not one of the following types: at::Tensor, at::TensorList, ", "std::optional, std::vector>."); - auto inputs_metadata = get_inputs_metadata(inputs); + auto tensor_parameter_index = + get_tensor_parameter_index(op.schema().arguments(), *stack); + TORCH_INTERNAL_ASSERT(tensor_parameter_index.size() == inputs.size()); + auto inputs_metadata = get_inputs_metadata( + inputs, op.schema().arguments(), tensor_parameter_index); auto aoti_kernel_state = aoti_kernel_cache_.find(inputs_metadata); if (aoti_kernel_state == aoti_kernel_cache_.end()) { return false; @@ -197,18 +241,49 @@ void AOTIPythonKernelHolder::cache_hit( } AOTIKernelMetadata AOTIPythonKernelHolder::get_inputs_metadata( - const std::vector& inputs) { + const std::vector& inputs, + const std::vector& inputs_argument, + const std::vector& inputs_argument_index) { AOTIKernelMetadata inputs_metadata; - for (const auto& input : inputs) { + for (size_t idx = 0; idx < inputs.size(); ++idx) { + auto input = inputs[idx]; + auto input_info = inputs_argument[inputs_argument_index[idx]]; + auto device = input.device(); if (device.is_cpu()) { // If the device is CPU, set the device index to -1. device = c10::Device(device.type(), -1); } + c10::Scalar scalar_value((double)1.0); + auto tensor_type = input.scalar_type(); + + bool is_scalar = input_info.type()->isSubtypeOf(*c10::NumberType::get()); + if (is_scalar) { + if (c10::isFloatingType(input.scalar_type())) { + auto scalar_numeric_value = input.item().toDouble(); + tensor_type = c10::ScalarType::Double; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (c10::isIntegralType(input.scalar_type(), false)) { + auto scalar_numeric_value = input.item().toUInt64(); + tensor_type = c10::ScalarType::UInt64; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (input.scalar_type() == c10::ScalarType::Bool) { + auto scalar_numeric_value = input.item().toBool(); + tensor_type = c10::ScalarType::Bool; + scalar_value = c10::Scalar(scalar_numeric_value); + } else { + TORCH_CHECK( + false, + "Unsupported scalar tensor type: ", + c10::toString(input.scalar_type())); + } + } + inputs_metadata.emplace_back( - false, // is symbloic - input.scalar_type(), + false, + tensor_type, + c10::IValue(scalar_value), device, input.sizes().vec(), input.strides().vec()); @@ -269,6 +344,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { reinterpret_cast(data_type_obj.ptr())->scalar_type; auto sizes = metadata["sizes"].cast>(); auto strides = metadata["strides"].cast>(); + bool is_scalar = metadata.contains("scalar_value"); std::vector> sym_optional_sizes; std::vector> sym_optional_strides; @@ -279,10 +355,34 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { sym_optional_strides.push_back(std::optional(stride)); } - // Now you can use these variables in your code + // If an input parameter is a scalar, its detailed value is cached. + // This is done to ensure correctness during subsequent checks. + c10::Scalar scalar_value((double)1.0); + if (is_scalar) { + if (c10::isFloatingType(data_type)) { + auto scalar_numeric_value = metadata["scalar_value"].cast(); + data_type = c10::ScalarType::Double; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (c10::isIntegralType(data_type, false)) { + auto scalar_numeric_value = metadata["scalar_value"].cast(); + data_type = c10::ScalarType::UInt64; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (data_type == c10::ScalarType::Bool) { + auto scalar_numeric_value = metadata["scalar_value"].cast(); + data_type = c10::ScalarType::Bool; + scalar_value = c10::Scalar(scalar_numeric_value); + } else { + TORCH_CHECK( + false, + "Unsupported scalar tensor type: ", + c10::toString(data_type)); + } + } + tensor_metadata_list.emplace_back( is_dynamic, data_type, + c10::IValue(scalar_value), c10::Device(c10::Device(device_type).type(), device_index), sizes, strides); diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.h b/torch/csrc/inductor/aoti_eager/kernel_holder.h index 9cbcc217d7c3..b67e4e7d4464 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.h +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -82,7 +83,10 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel { void init_aoti_kernel_cache(); // Abstract the meta information of each tensor for the given operation. The // meta infomation will be used for cache lookup as the key. - AOTIKernelMetadata get_inputs_metadata(const std::vector&); + AOTIKernelMetadata get_inputs_metadata( + const std::vector& inputs, + const std::vector& inputs_argument, + const std::vector& inputs_argument_index); // Load the AOTIModelContainerRunner object from the given file path. std::shared_ptr load_aoti_model_runner( const std::string&); diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index e89c59142328..a49fab21d671 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -1,5 +1,6 @@ #if !defined(C10_MOBILE) && !defined(ANDROID) #include +#include namespace torch::inductor { @@ -17,6 +18,24 @@ TensorMetadata::TensorMetadata( std::vector strides) : is_symbolic_(is_symbolic), dtype_(dtype), + scalar_value_((float)1.0), + device_(device), + sizes_(sizes), + strides_(strides) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !is_symbolic_, "Not support symbolic shape now"); +} + +TensorMetadata::TensorMetadata( + bool is_symbolic, + c10::ScalarType dtype, + c10::IValue scalar_value, + c10::Device device, + std::vector sizes, + std::vector strides) + : is_symbolic_(is_symbolic), + dtype_(dtype), + scalar_value_(scalar_value), device_(device), sizes_(sizes), strides_(strides) { @@ -29,15 +48,39 @@ bool TensorMetadata::operator==(const TensorMetadata& other) const { !is_symbolic_, "Not support symbolic shape now"); return this->is_symbolic_ == other.is_symbolic_ && this->dtype_ == other.dtype_ && + this->scalar_value_ == other.scalar_value_ && this->device_.type() == other.device_.type() && this->sizes_ == other.sizes_ && this->strides_ == other.strides_; } +std::ostream& operator<<( + std::ostream& stream, + const TensorMetadata& tensor_metadata) { + stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << std::endl; + stream << "dtype_: " << tensor_metadata.dtype_ << std::endl; + stream << "scalar_value_: " << tensor_metadata.scalar_value_.type()->str() + << "(" << tensor_metadata.scalar_value_ << ")" << std::endl; + stream << "device_: " << tensor_metadata.device_ << std::endl; + stream << "sizes_: "; + for (const auto& size : tensor_metadata.sizes_) { + stream << size << " "; + } + stream << std::endl; + stream << "strides_: "; + for (const auto& stride : tensor_metadata.strides_) { + stream << stride << " "; + } + stream << std::endl; + return stream; +} + size_t TensorMetadataHash::operator()( const TensorMetadata& tensor_metadata) const { auto hash = std::hash()(tensor_metadata.is_symbolic_); hash = c10::hash_combine( hash, std::hash()(tensor_metadata.dtype_)); + hash = + c10::hash_combine(hash, c10::IValue::hash(tensor_metadata.scalar_value_)); hash = c10::hash_combine( hash, std::hash()(tensor_metadata.device_.type())); diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h index c7f8315d2707..5c22e9b75f65 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -33,6 +33,8 @@ struct TensorMetadata { bool is_symbolic_; // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor) c10::ScalarType dtype_; + // Concrete scalar value. Serve for operations w/ scalar parameter + c10::IValue scalar_value_; // Device of a tensor. c10::Device device_; // Sizes of a tensor. Currently, we only support static shape and use int64_t @@ -49,6 +51,13 @@ struct TensorMetadata { c10::Device device, std::vector sizes, std::vector strides); + TensorMetadata( + bool is_symbolic, + c10::ScalarType dtype, + c10::IValue scalar_value, + c10::Device device, + std::vector sizes, + std::vector strides); bool operator==(const TensorMetadata& other) const; }; From d4704dcacc543023e7a746263a970bef155d58c4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Sat, 18 May 2024 03:19:13 +0000 Subject: [PATCH 099/116] Map float8 types to uint8 for allgather (#126556) # Summary Different take on this one: https://github.com/pytorch/pytorch/issues/126338 We should probably not allow this mapping for 'compute' ops e.g. reductions ### Corresponding fp8 PR https://github.com/pytorch-labs/float8_experimental/pull/263 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126556 Approved by: https://github.com/wanchaol --- test/distributed/test_c10d_nccl.py | 93 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 24 ++++- 2 files changed, 115 insertions(+), 2 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index a0629f054ae0..5a958acdbdd7 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2577,6 +2577,27 @@ def test_all_reduce_coalesced_nccl(self): ), ) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_all_reduce_coalesced_nccl_float8_errors(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + process_group = c10d.distributed_c10d._get_default_group() + device = torch.device("cuda:%d" % self.rank) + tensors = [ + torch.full( + (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float + ).to(torch.float8_e4m3fn) + for i in range(5) + ] + with self.assertRaisesRegex( + RuntimeError, + "Float8 dtypes are not currenlty supported for NCCL reductions", + ): + torch.distributed.all_reduce_coalesced(tensors, group=process_group) + @requires_nccl() @skip_if_lt_x_gpu(2) def test_all_reduce_coalesced_manager_nccl(self): @@ -2940,6 +2961,56 @@ def test_reduce_scatter_tensor_coalesced(self): dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_base_k_float8_errors(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + output_tensor = ( + torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank) + ) + input_tensors = ( + torch.arange(self.world_size * 2, dtype=torch.float32) + .to(torch.float8_e4m3fn) + .to(self.rank) + ) + input_tensors = torch.reshape(input_tensors, (self.world_size, 2)) + with self.assertRaisesRegex( + RuntimeError, + "Float8 dtypes are not currenlty supported for NCCL reductions", + ): + dist.reduce_scatter_tensor(output_tensor, input_tensors) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_tensor_coalesced_float8_errors(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank) + input_tensors = [ + torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank) + for _ in range(self.world_size) + ] + + with self.assertRaisesRegex( + RuntimeError, + "Float8 dtypes are not currenlty supported for NCCL reductions", + ): + with dist._coalescing_manager(): + for i in range(self.world_size): + dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) + self.assertEqual(output_tensors, input_tensors[self.rank]) + class SetDeviceMethod(Enum): TORCH_CUDA_SET = auto() # torch.cuda.set_device @@ -2980,6 +3051,28 @@ def test_allgather_base(self): dist.all_gather_into_tensor(output_tensor, tensor) self.assertEqual(output_tensor, tensor) + @requires_nccl() + @skip_if_lt_x_gpu(1) + @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_allgather_float8(self, float8_dtype): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" + tensor = torch.ones(10, 16, device=torch.device(device)).to(float8_dtype) + output_tensor = torch.zeros(10, 16, device=torch.device(device)).to( + float8_dtype + ) + dist.all_gather_into_tensor(output_tensor, tensor) + self.assertEqual(output_tensor.view(torch.float32), tensor.view(torch.float32)) + + +instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests) + class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase): def setUp(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 7ff75e1bd7f5..7586058475ff 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1,4 +1,3 @@ - #ifdef USE_C10D_NCCL #include @@ -64,6 +63,10 @@ std::map ncclDataType = { {at::kLong, ncclInt64}, {at::kHalf, ncclHalf}, {at::kBool, ncclUint8}, + {at::kFloat8_e5m2, ncclUint8}, + {at::kFloat8_e4m3fn, ncclUint8}, + {at::kFloat8_e4m3fnuz, ncclUint8}, + {at::kFloat8_e5m2fnuz, ncclUint8}, #if HAS_NCCL_BF16_DATATYPE {at::kBFloat16, ncclBfloat16}, #endif @@ -3039,6 +3042,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( const AllreduceOptions& opts) { TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto tensor = tensors.back(); + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); #ifdef IS_NCCLX tensor = tensor.coalesce(); at::Tensor outputTensor = @@ -3153,7 +3159,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( return c10::make_intrusive(); } } - + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( static_cast( @@ -3180,6 +3188,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { auto total_numel = check_gpu_tensors_same_device(tensors); + TORCH_CHECK( + !isFloat8Type(tensors.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( @@ -3552,6 +3563,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( check_gpu_single_tensor(outputTensor); // @lint-ignore CLANGTIDY auto inputTensors_ = inputTensors.back(); + TORCH_CHECK( + !isFloat8Type(outputTensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); RECORD_PARAM_COMMS_DATA( static_cast( @@ -3663,6 +3677,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( // @lint-ignore CLANGTIDY const auto& tensor = outputTensor; + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); RECORD_PARAM_COMMS_DATA( static_cast( this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective @@ -3723,6 +3740,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( std::vector& outputs, std::vector& inputs, const ReduceScatterOptions& opts) { + TORCH_CHECK( + !isFloat8Type(inputs.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); return collectiveCoalesced( inputs, outputs, From a44d0cf227c9ac3ab3bc0187e1a382b16a5ea7d0 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 18 May 2024 04:44:28 +0000 Subject: [PATCH 100/116] [Traceable FSDP2] Change from register_multi_grad_hook to per-tensor backward hook (#126350) As discussed with Andrew before, under compile we will register per-tensor backward hook instead of multi-grad hook, because it's difficult for Dynamo to support `register_multi_grad_hook` (or anything `.grad_fn` related). We expect both to have the same underlying behavior, ~~and we will add integration test (in subsequent PR) to show that compile and eager has same numerics.~~ As discussed below, we will change eager path to use per-tensor backward hook as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126350 Approved by: https://github.com/awgu --- torch/distributed/_composable/fsdp/_fsdp_param_group.py | 2 ++ torch/distributed/_composable/fsdp/_fsdp_state.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 569858e92656..ea2307222ce1 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -277,6 +277,8 @@ def _record_post_forward(self) -> None: self._post_forward_indices.append(post_forward_index) def pre_backward(self, *unused: Any): + if self._training_state == TrainingState.PRE_BACKWARD: + return with torch.profiler.record_function("FSDP::pre_backward"): self._training_state = TrainingState.PRE_BACKWARD self.unshard() # no-op if prefetched diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index bab24c283063..15a00e83f086 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn from torch.autograd import Variable -from torch.autograd.graph import register_multi_grad_hook from torch.distributed._composable_state import ( _get_module_state, _insert_module_state, @@ -201,11 +200,12 @@ def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any: ) return output - def _pre_backward(self, *unused: Any) -> None: + def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: self._training_state = TrainingState.PRE_BACKWARD self._register_root_post_backward_final_callback() if self._fsdp_param_group: - self._fsdp_param_group.pre_backward(*unused) + self._fsdp_param_group.pre_backward() + return grad def _root_post_backward_final_callback(self) -> None: with torch.profiler.record_function("FSDP::root_post_backward_callback"): @@ -235,7 +235,8 @@ def _register_pre_backward_hook(self, output: Any) -> Any: t for t in flat_outputs if (torch.is_tensor(t) and t.requires_grad) ) if tensors: - register_multi_grad_hook(tensors, self._pre_backward, mode="any") + for tensor in tensors: + tensor.register_hook(self._pre_backward) return output def _register_root_post_backward_final_callback(self): From 6bb9d6080d33c817fcbf9e5ae8a59b76812a53d2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 18 May 2024 05:02:14 +0000 Subject: [PATCH 101/116] [Dynamo] Treat integers stored on nn.Modules as dynamic (#126466) Fixes #115711 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126466 Approved by: https://github.com/jansel --- test/dynamo/test_modules.py | 57 ++++++++++++++++++++++++++++++ torch/_dynamo/variables/builder.py | 4 --- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index ceb1521ffe69..b22f02ee2fcc 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -19,6 +19,7 @@ from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.testing import expectedFailureDynamic, same +from torch._dynamo.utils import ifdynstaticdefault from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import Parameter, UninitializedParameter @@ -1104,6 +1105,37 @@ def forward(self, x): return self.m(x) +class ModuleWithIntAttr(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(4, 4) + self.step = 10 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + 1 + self.step += 1 + return self.layer(x) + self.step + + +class UnspecInlinableModule(torch.nn.Module): + torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule + + def forward(self, x): + return torch.sin(x) + + +class UnspecModuleWithIntAttr(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = UnspecInlinableModule() + self.step = 10 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + 1 + self.step += 1 + return self.layer(x) + self.step + + def make_test(fn, expected_ops=None): def test_fn(self): return torch._dynamo.testing.standard_test( @@ -1357,6 +1389,31 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) self.assertTrue(torch._dynamo.testing.same(out1, out_post)) + def test_nn_module_unspec_int_attr(self): + for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]: + mod = module_class() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod)) + x = torch.randn(3, 4) + + # Compiling self.step as static. + ref1 = mod(x) + res1 = opt_mod(x) + self.assertTrue(torch.allclose(ref1, res1)) + self.assertEqual(cnt.frame_count, 1) + + # Compiling self.step as dynamic. + ref2 = mod(x) + res2 = opt_mod(x) + self.assertTrue(torch.allclose(ref2, res2)) + self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) + + # No re-compilation! + ref3 = mod(x) + res3 = opt_mod(x) + self.assertTrue(torch.allclose(ref3, res3)) + self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) + # RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic def test_lazy_module1(self): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 41b9fbd836ae..c1b9f68639f5 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1162,10 +1162,6 @@ def wrap_literal(self, value): value in self._common_constants() # Assume integers from global variables want to be specialized or not self.source.guard_source().is_local() - # Assume that integers that came from NN modules want to be - # specialized (as we don't expect users to be changing the - # NN modules on the fly) - or self.source.guard_source().is_nn_module() or is_from_defaults(self.source) or is_cell_contents(self.source) ): From 99af1b3ab03289e922b95252bab51758de6035c9 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Sat, 18 May 2024 06:05:14 +0000 Subject: [PATCH 102/116] Refactor variables / function names related to non-strict export (#126458) Improve variable and function naming for better clarity: `non strict` --> `aten`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126458 Approved by: https://github.com/angelayi --- torch/export/_trace.py | 54 +++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 1c2f3b880f35..c85a82c8c4c5 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -468,7 +468,7 @@ def _export_to_torch_ir( return gm_torch_level -def _export_non_strict( +def _export_to_aten_ir( mod: torch.nn.Module, fake_args, fake_kwargs, @@ -647,7 +647,7 @@ def make_argument_spec(i, node) -> ArgumentSpec: ) @dataclasses.dataclass - class _ExportedProgramNonStrict: + class _ExportedArtifact: gm: torch.fx.GraphModule sig: ExportGraphSignature constants: Dict[ @@ -659,7 +659,7 @@ class _ExportedProgramNonStrict: ], ] - return _ExportedProgramNonStrict( + return _ExportedArtifact( gm, export_graph_signature, constants, @@ -1113,7 +1113,7 @@ def forward(self, *args, **kwargs): new_fake_constant_attrs, map_fake_to_real, ): - ep_non_strict = _export_non_strict( + aten_export_artifact = _export_to_aten_ir( patched_mod, new_fake_args, new_fake_kwargs, @@ -1124,15 +1124,15 @@ def forward(self, *args, **kwargs): should_insert_runtime_assertion=not strict, _is_torch_jit_trace=_is_torch_jit_trace, ) - # ep_non_strict.constants contains only fake script objects, we need to map them back - ep_non_strict.constants = { + # aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj - for fqn, obj in ep_non_strict.constants.items() + for fqn, obj in aten_export_artifact.constants.items() } - ep_non_strict.gm.meta["inline_constraints"] = { + aten_export_artifact.gm.meta["inline_constraints"] = { k: v for k, v in fake_mode.shape_env.var_to_range.items() if free_unbacked_symbols(k) @@ -1140,14 +1140,14 @@ def forward(self, *args, **kwargs): num_lifted = len( [ spec - for spec in ep_non_strict.sig.input_specs + for spec in aten_export_artifact.sig.input_specs if spec.kind != InputKind.USER_INPUT ] ) try: produce_guards_and_solve_constraints( fake_mode, - ep_non_strict.gm, + aten_export_artifact.gm, equalities_inputs, original_signature, _disable_forced_specializations=_disable_forced_specializations, @@ -1159,7 +1159,7 @@ def forward(self, *args, **kwargs): combined_args = _combine_args(mod, args, kwargs) range_constraints = make_constraints( fake_mode, - ep_non_strict.gm, + aten_export_artifact.gm, combined_args, dynamic_shapes, num_lifted, @@ -1167,7 +1167,7 @@ def forward(self, *args, **kwargs): assert out_spec is not None - gm = ep_non_strict.gm + gm = aten_export_artifact.gm gm.meta["forward_arg_names"] = forward_arg_names module_call_signatures = { @@ -1194,26 +1194,30 @@ def forward(self, *args, **kwargs): node.replace_all_uses_with(new_node) gm.graph.erase_node(node) - res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm) + res = CollectTracepointsPass( + module_call_signatures, aten_export_artifact.sig + )(gm) assert res is not None gm = res.graph_module - _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) + _rewrite_non_persistent_buffers( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) _verify_nn_module_stack(gm) _verify_stack_trace(gm) if not _is_torch_jit_trace: - _verify_placeholder_names(gm, ep_non_strict.sig) + _verify_placeholder_names(gm, aten_export_artifact.sig) exported_program = ExportedProgram( root=gm, graph=gm.graph, - graph_signature=ep_non_strict.sig, + graph_signature=aten_export_artifact.sig, state_dict=original_state_dict, range_constraints=range_constraints, module_call_graph=_make_module_call_graph( _EXPORT_MODULE_HIERARCHY, orig_in_spec, out_spec, module_call_signatures ), example_inputs=(args, kwargs), - constants=ep_non_strict.constants, + constants=aten_export_artifact.constants, ) return exported_program @@ -1310,7 +1314,7 @@ def forward(self, *args, **kwargs): # NOTE: graph module expects only positional args constant_attrs = _gather_constant_attrs(mod) - ep_non_strict = _export_non_strict( + aten_export_artifact = _export_to_aten_ir( gm_torch_level, _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), {}, @@ -1320,9 +1324,9 @@ def forward(self, *args, **kwargs): should_insert_runtime_assertion=not strict, ) - gm = ep_non_strict.gm - export_graph_signature = ep_non_strict.sig - constants = ep_non_strict.constants + gm = aten_export_artifact.gm + export_graph_signature = aten_export_artifact.sig + constants = aten_export_artifact.constants # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes for metadata in params_buffers_to_node_meta.values(): @@ -1378,15 +1382,17 @@ def forward(self, *args, **kwargs): _rewrite_dynamo_tensor_constants( orig_mod_buffers=set(mod.buffers()), traced_mod_buffers=dict(gm_torch_level.named_buffers()), - graph_signature=ep_non_strict.sig, - constants=ep_non_strict.constants, + graph_signature=aten_export_artifact.sig, + constants=aten_export_artifact.constants, ) # 2. Restore FQN of param/buffers param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) _replace_param_buffer_names(param_buffer_table, export_graph_signature) # 3. Remove non-persistent buffers from the graph signature - _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) + _rewrite_non_persistent_buffers( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) # 4. Rewrite constants to have the same FQN as the original module. _remap_constants(constant_attrs, export_graph_signature, constants) From ad67553c5c1672d65b810acd7a6a01e11695098b Mon Sep 17 00:00:00 2001 From: Tarunbir Gambhir Date: Sat, 18 May 2024 15:42:45 +0000 Subject: [PATCH 103/116] Updated test_torch.py to use new OptimizerInfo infrastructure (#125538) Fixes #123451 (only addresses test_torch.py cases) This PR solves the specific task to update `test_grad_scaling_autocast` and `test_params_invalidated_with_grads_invalidated_between_unscale_and_step` in `test/test_torch.py` to use the new OptimizerInfo infrastructure. I have combined tests that call `_grad_scaling_autocast_test` into one called `test_grad_scaling_autocast` and used `_get_optim_inputs_including_global_cliquey_kwargs` to avoid hard-coded configurations. ``` $ lintrunner test/test_cuda.py ok No lint issues. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125538 Approved by: https://github.com/janeyx99 --- test/test_torch.py | 96 +++++++++++++++++++++------------------------- 1 file changed, 43 insertions(+), 53 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 81da78f9a882..c8cff93bd1bf 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -29,6 +29,8 @@ from functools import partial from torch import multiprocessing as mp from torch.testing import make_tensor +from torch.testing._internal.common_optimizers import ( + optim_db, optims, _get_optim_inputs_including_global_cliquey_kwargs) from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON, @@ -5877,8 +5879,13 @@ def _run_scaling_case(self, device, run, unskipped, skipped, atol=1e-7, optimize self.assertEqual(c, s, atol=atol, rtol=1e-05) - # Compares no scaling + no autocasting against scaling + autocasting. - def _grad_scaling_autocast_test(self, *, device="cuda", atol=1e-3, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): + @onlyNativeDeviceTypes + @parametrize("foreach, fused", [(None, None), (True, None), (None, True)]) + @optims( + [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]], + dtypes=[torch.float32] + ) + def test_grad_scaling_autocast(self, device, dtype, optim_info, foreach, fused): try_pickle = False def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): @@ -5902,6 +5909,9 @@ def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_ optimizer.step() return scaler + optimizer_ctor = optim_info.optim_cls + + # Compares no scaling + no autocasting against scaling + autocasting. # NOTE(mkozuki): With current way of testing, `torch.optim.Adam` is failing in spite of `foreach` and `fused`. # Giving some flexibility to this test might help. context = contextlib.nullcontext @@ -5911,71 +5921,51 @@ def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_ with context(): # sets atol=1e-3 because we're comparing pure fp32 arithmetic vs a mixture of fp16 and fp32 self._run_scaling_case( - device, run, unskipped=3, skipped=1, atol=atol, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, + device, run, unskipped=3, skipped=1, atol=1e-3, + optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused}, ) # this will be picked up by try_pickle within run(): try_pickle = True self._run_scaling_case( - device, run, unskipped=3, skipped=1, atol=atol, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, + device, run, unskipped=3, skipped=1, atol=1e-3, + optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused}, ) - @onlyNativeDeviceTypes - def test_grad_scaling_autocast(self, device): - device = torch.device(device) - for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW): - self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor) - - @onlyNativeDeviceTypes - def test_grad_scaling_autocast_foreach(self, device): - device = torch.device(device) - for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW): - self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True}) - - @onlyNativeDeviceTypes - def test_grad_scaling_autocast_fused(self, device): - device = torch.device(device) - for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW): - self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"fused": True}) - # Make sure that the parameters become nonsense when scaled gradients are finite # but they get invalidated before `optimizer.step`, after `GradScaler.unscale_` @onlyNativeDeviceTypes - def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device): - device = torch.device(device) - for optimizer_ctor, optimizer_kwargs in product( - (torch.optim.Adam, torch.optim.AdamW), - ( - {"foreach": False, "fused": False}, - {"foreach": True, "fused": False}, - {"foreach": False, "fused": True}, - ), - ): - with self.subTest(optimizer=optimizer_ctor, optimizer_kwargs=optimizer_kwargs): - self._test_grads_invalidated_between_unscale_and_step(device.type, optimizer_ctor, optimizer_kwargs) - - def _test_grads_invalidated_between_unscale_and_step(self, device, optimizer_ctor, optimizer_kwargs): - model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case( - device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, - ) - scaler = torch.GradScaler(device=device, init_scale=128.0) + @optims( + [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]], + dtypes=[torch.float32] + ) + def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info): + optimizer_ctor = optim_info.optim_cls + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",)) + + for optim_input in all_optim_inputs: + model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case( + device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs, + ) + scaler = torch.GradScaler(device=device, init_scale=128.0) - for input, target in data: - optimizer.zero_grad() - with torch.autocast(device_type=device, dtype=torch.half): - output = model(input) - loss = loss_fn(output, target) - scaler.scale(loss).backward() - scaler.unscale_(optimizer) + for input, target in data: + optimizer.zero_grad() + with torch.autocast(device_type=device, dtype=torch.half): + output = model(input) + loss = loss_fn(output, target) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) - # deliberately break grads - for j, param in enumerate(model.parameters()): - param.grad.copy_(torch.inf if j % 2 else torch.nan) + # deliberately break grads + for j, param in enumerate(model.parameters()): + param.grad.copy_(torch.inf if j % 2 else torch.nan) - scaler.step(optimizer) - scaler.update() + scaler.step(optimizer) + scaler.update() - self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters())) + self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters())) @onlyNativeDeviceTypes def test_grad_scale_will_not_overflow(self, device): From abc4b661249337670c10b91d2449b4bb8f6ddb20 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sat, 18 May 2024 23:56:03 +0000 Subject: [PATCH 104/116] Forward fix the failed new test from D57474327 (#126596) Summary: TSIA. The two looks the same to me, but buck was failing with the following error when `with torch._inductor.utils.fresh_inductor_cache()` is used: ``` _________________________ ReproTests.test_issue126128 __________________________ self = def test_issue126128(self): def fn(): x = torch.randn(1, 10) y = torch.randn(10, 1) return torch.mm(x, y).sum() def fn2(): x = torch.randn(10, 100) y = torch.randn(100, 10) return torch.mm(x, y).sum() > with torch._inductor.utils.fresh_inductor_cache(): E AttributeError: module 'torch._inductor' has no attribute 'utils' ``` Test Plan: `buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --exact 'caffe2/test/dynamo:test_dynamo - test_repros.py::ReproTests::test_issue126128'` Differential Revision: D57516676 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126596 Approved by: https://github.com/xmfan --- test/dynamo/test_repros.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 510b32be905b..96bf924e0999 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -33,6 +33,7 @@ from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import CompileCounter, rand_strided, same +from torch._inductor.utils import fresh_inductor_cache from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -4982,7 +4983,7 @@ def fn2(): y = torch.randn(100, 10) return torch.mm(x, y).sum() - with torch._inductor.utils.fresh_inductor_cache(): + with fresh_inductor_cache(): torch.compile(fn)() torch.compile(fn2)() From e3230f87aa77e46d30cac6192aae4a2b47bf1cbf Mon Sep 17 00:00:00 2001 From: chilli Date: Sat, 18 May 2024 06:07:50 -0700 Subject: [PATCH 105/116] Cached required_fw_nodes creation (#126613) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126613 Approved by: https://github.com/anijain2305 --- torch/_dynamo/convert_frame.py | 8 ++++---- torch/_functorch/partitioners.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index c1733280e31f..d5c24a67d9e2 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -312,7 +312,7 @@ def profile_wrapper(*args, **kwargs): retval = prof.runcall(func, *args, **kwargs) profile_latency = time.time() - start_ts prof.disable() - log.info( + log.warning( "### Cprofile for %s trace id [%s] took %.3f seconds ###", func.__name__, trace_id, @@ -322,7 +322,7 @@ def profile_wrapper(*args, **kwargs): try: prof.dump_stats(profile_path) except PermissionError: - log.info("Cannot write to %s", str(profile_path)) + log.warning("Cannot write to %s", str(profile_path)) svg_path = profile_path.with_suffix(".svg") try: gprof2dot_process = subprocess.Popen( @@ -341,9 +341,9 @@ def profile_wrapper(*args, **kwargs): ["dot", "-Tsvg", "-o", str(svg_path)], stdin=gprof2dot_process.stdout, ) - log.info("Generated SVG from profile at %s", str(svg_path)) + log.warning("Generated SVG from profile at %s", str(svg_path)) except FileNotFoundError: - log.info( + log.warning( "Failed to generate SVG from profile -- dumping stats instead." "Try installing gprof2dot and dot for a better visualization" ) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index d104247b3f63..0956ee7e367c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -73,7 +73,7 @@ class NodeInfo: unclaimed_nodes: Set[fx.Node] fw_order: Dict[fx.Node, int] - @property + @functools.cached_property def required_fw_nodes(self) -> List[fx.Node]: return sorted( (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] @@ -985,7 +985,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: orders = [ node_info.get_fw_order(user) for user in used_node.users - if user in node_info.required_fw_nodes + if node_info.is_required_fw(user) ] fw_users = [ user for user in used_node.users if node_info.is_required_fw(user) @@ -994,7 +994,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: first_unfusible_use = find_first_unfusible(fw_users, max(orders)) for user in tuple(used_node.users): if ( - user in node_info.required_fw_nodes + node_info.is_required_fw(user) and node_info.get_fw_order(user) > first_unfusible_use and is_fusible(used_node, user) ): @@ -1024,7 +1024,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: if min_cut_options.ban_if_long_fusible_chains: visited = set() for start_node in joint_graph.nodes: - if start_node not in node_info.required_fw_nodes: + if not node_info.is_required_fw(start_node): continue fusible = [(node_info.get_fw_order(start_node), start_node)] start_order = node_info.get_fw_order(start_node) @@ -1050,7 +1050,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: for user in cur.users: if ( - user in node_info.required_fw_nodes + node_info.is_required_fw(user) and is_fusible(cur, user) and user not in banned_nodes ): @@ -1404,7 +1404,7 @@ def classify_nodes(joint_module): for node in reversed(joint_module.graph.nodes): if node.op == "output": node.dist_from_bw = int(1e9) - elif node not in node_info.required_fw_nodes: + elif not node_info.is_required_fw(node): node.dist_from_bw = 0 else: node.dist_from_bw = int(1e9) From 71b6459edc19dac02b6e559977a0f615bf451210 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 19 May 2024 02:52:11 +0000 Subject: [PATCH 106/116] Revert "[Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)" This reverts commit 6bb9d6080d33c817fcbf9e5ae8a59b76812a53d2. Reverted https://github.com/pytorch/pytorch/pull/126466 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the ONNX test failure looks legit, not flaky, as it starts failing in trunk https://hud.pytorch.org/pytorch/pytorch/commit/6bb9d6080d33c817fcbf9e5ae8a59b76812a53d2 ([comment](https://github.com/pytorch/pytorch/pull/126466#issuecomment-2119078245)) --- test/dynamo/test_modules.py | 57 ------------------------------ torch/_dynamo/variables/builder.py | 4 +++ 2 files changed, 4 insertions(+), 57 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index b22f02ee2fcc..ceb1521ffe69 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -19,7 +19,6 @@ from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.testing import expectedFailureDynamic, same -from torch._dynamo.utils import ifdynstaticdefault from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import Parameter, UninitializedParameter @@ -1105,37 +1104,6 @@ def forward(self, x): return self.m(x) -class ModuleWithIntAttr(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(4, 4) - self.step = 10 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + 1 - self.step += 1 - return self.layer(x) + self.step - - -class UnspecInlinableModule(torch.nn.Module): - torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule - - def forward(self, x): - return torch.sin(x) - - -class UnspecModuleWithIntAttr(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = UnspecInlinableModule() - self.step = 10 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + 1 - self.step += 1 - return self.layer(x) + self.step - - def make_test(fn, expected_ops=None): def test_fn(self): return torch._dynamo.testing.standard_test( @@ -1389,31 +1357,6 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) self.assertTrue(torch._dynamo.testing.same(out1, out_post)) - def test_nn_module_unspec_int_attr(self): - for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]: - mod = module_class() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod)) - x = torch.randn(3, 4) - - # Compiling self.step as static. - ref1 = mod(x) - res1 = opt_mod(x) - self.assertTrue(torch.allclose(ref1, res1)) - self.assertEqual(cnt.frame_count, 1) - - # Compiling self.step as dynamic. - ref2 = mod(x) - res2 = opt_mod(x) - self.assertTrue(torch.allclose(ref2, res2)) - self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) - - # No re-compilation! - ref3 = mod(x) - res3 = opt_mod(x) - self.assertTrue(torch.allclose(ref3, res3)) - self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) - # RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic def test_lazy_module1(self): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index c1b9f68639f5..41b9fbd836ae 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1162,6 +1162,10 @@ def wrap_literal(self, value): value in self._common_constants() # Assume integers from global variables want to be specialized or not self.source.guard_source().is_local() + # Assume that integers that came from NN modules want to be + # specialized (as we don't expect users to be changing the + # NN modules on the fly) + or self.source.guard_source().is_nn_module() or is_from_defaults(self.source) or is_cell_contents(self.source) ): From 7dae7d3ca5acd1b8f5cf5ed9f24175c367b5d42b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 17 May 2024 10:31:49 -0700 Subject: [PATCH 107/116] Remove unnecessary implementations from MockHandler (#126511) Dead implementations are confusing and can cause bugs when people accidentally hit them. Better for it to be missing. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126511 Approved by: https://github.com/peterbell10, https://github.com/lezcano --- torch/_inductor/lowering.py | 2 +- torch/_inductor/ops_handler.py | 23 +++++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 80eee352458d..07899fe2ccd0 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -4559,7 +4559,7 @@ def fn(idx): factor = ops.index_expr(hend - hstart, torch.int32) divide_factors.append(factor) divide_factor = functools.reduce(ops.mul, divide_factors) - return ops.div(fn_sum(idx, x_loader), divide_factor) + return ops.truediv(fn_sum(idx, x_loader), divide_factor) rv = Pointwise.create( device=x.get_device(), diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 3df6749083c1..71395c71c9b6 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -17,7 +17,6 @@ import torch import torch.utils._pytree as pytree -from torch.fx.graph import magic_methods from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str T = TypeVar("T") @@ -748,7 +747,27 @@ def inner(*args): return inner - for name, format_string in itertools.chain(magic_methods.items()): + for name, format_string in { + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "mod": "{} % {}", # careful, depending on target semantics varies + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "neg": "-{}", + }.items(): setattr(cls, name, make_handler(format_string)) From c4dfd783f471b0aac004e2eb1869c0c04eb24cf6 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 18 May 2024 21:44:19 -0700 Subject: [PATCH 108/116] UFMT torch.utils._sympy.functions (#126553) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126553 Approved by: https://github.com/lezcano, https://github.com/Skylion007 ghstack dependencies: #126511 --- .lintrunner.toml | 2 -- torch/utils/_sympy/functions.py | 59 ++++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 8dfc1554041a..50eb09984fec 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1929,8 +1929,6 @@ exclude_patterns = [ 'torch/utils/_mode_utils.py', 'torch/utils/_python_dispatch.py', 'torch/utils/_stats.py', - 'torch/utils/_sympy/__init__.py', - 'torch/utils/_sympy/functions.py', 'torch/utils/_traceback.py', 'torch/utils/_zip.py', 'torch/utils/backcompat/__init__.py', diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 427333b07c16..e8c4a57d84c8 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,11 +1,21 @@ +import math + import sympy from sympy import S from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or -import math __all__ = [ - "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", "Pow", "TrueDiv", - "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", "Round", "RoundDecimal", + "FloorDiv", + "ModularIndexing", + "CleanDiv", + "CeilDiv", + "Pow", + "TrueDiv", + "LShift", + "RShift", + "IsNonOverlappingAndDenseIndicator", + "Round", + "RoundDecimal", ] @@ -21,6 +31,7 @@ class FloorDiv(sympy.Function): 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) """ + nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 @@ -53,11 +64,14 @@ def _eval_is_integer(self): @classmethod def eval(cls, base, divisor): def check_supported_type(x): - if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean: + if ( + x.is_integer is False and x.is_real is False and x.is_complex + ) or x.is_Boolean: raise TypeError( f"unsupported operand type(s) for //: " f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real") + f", expected integer or real" + ) check_supported_type(base) check_supported_type(divisor) @@ -77,7 +91,9 @@ def check_supported_type(x): return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)): + if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( + divisor, (sympy.Integer, sympy.Float) + ): return sympy.floor(base / divisor) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) @@ -125,7 +141,9 @@ def eval(cls, base, divisor, modulus): gcd = sympy.gcd(base, divisor) if gcd != 1: return ModularIndexing( - sympy.simplify(base / gcd), sympy.simplify(divisor / gcd), modulus + sympy.simplify(base / gcd), + sympy.simplify(divisor / gcd), + modulus, ) except sympy.PolynomialError: pass # https://github.com/pytorch/pytorch/issues/108276 @@ -178,6 +196,7 @@ def eval(cls, c, p, q): elif c == sympy.false: return q + class Mod(sympy.Function): """ We maintain this so that we avoid SymPy correctness issues, such as: @@ -263,16 +282,17 @@ class LShift(sympy.Function): @classmethod def eval(cls, base, shift): if shift < 0: - raise ValueError('negative shift count') - return base * 2 ** shift + raise ValueError("negative shift count") + return base * 2**shift class RShift(sympy.Function): @classmethod def eval(cls, base, shift): if shift < 0: - raise ValueError('negative shift count') - return base // 2 ** shift + raise ValueError("negative shift count") + return base // 2**shift + # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 @@ -284,7 +304,8 @@ def eval(cls, base, exp): elif base.is_zero and exp < 0: raise ZeroDivisionError(f"{base} cannot be raised to a negative power") else: - return base ** exp + return base**exp + # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 @@ -317,13 +338,14 @@ def eval(cls, *args): # in dim 0. if all(isinstance(a, sympy.Integer) for a in args): # sym_node imported in torch.__init__. Local import to avoid an import cycle - from torch.fx.experimental.symbolic_shapes import eval_is_non_overlapping_and_dense + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) size_args = args[0:dim] stride_args = args[dim:] return eval_is_non_overlapping_and_dense( - [int(a) for a in size_args], - [int(a) for a in stride_args] + [int(a) for a in size_args], [int(a) for a in stride_args] ) return None @@ -361,7 +383,11 @@ def eval(cls, number, ndigits): if number.is_integer and ndigits >= 0: return number elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = (int, sympy.Integer) if isinstance(number, sympy.Integer) else (float, sympy.Float) + value_type, output_type = ( + (int, sympy.Integer) + if isinstance(number, sympy.Integer) + else (float, sympy.Float) + ) return output_type(round(value_type(number), int(ndigits))) @@ -401,6 +427,7 @@ def eval(cls, a): return OpaqueUnaryFn + # Keep in sync with math_op_names in torch/fx/experimental/sym_node.py OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt") OpaqueUnaryFn_cos = make_opaque_unary_fn("cos") From 5ea956a61f43f910f4192ee9cf00268db1ede5ef Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 17 May 2024 08:39:43 -0700 Subject: [PATCH 109/116] Update hf_BirdBird periodic-dynamo-benchmarks results (#126414) can't repro this regression. also nothing in the faulty PR range would cause it only for 1 model. the job is still causing noise, so we should mute it. I think just updating the graph break count is better than skipping the model here since it's still passing Pull Request resolved: https://github.com/pytorch/pytorch/pull/126414 Approved by: https://github.com/ezyang --- .../ci_expected_accuracy/aot_eager_torchbench_inference.csv | 2 +- .../ci_expected_accuracy/aot_eager_torchbench_training.csv | 2 +- .../ci_expected_accuracy/aot_inductor_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../ci_expected_accuracy/dynamo_eager_torchbench_inference.csv | 2 +- .../ci_expected_accuracy/dynamo_eager_torchbench_training.csv | 2 +- .../ci_expected_accuracy/inductor_torchbench_inference.csv | 2 +- .../ci_expected_accuracy/inductor_torchbench_training.csv | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 5b5646e85487..20fb340690ac 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,0 +hf_BigBird,pass,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 0dd9ce3482f4..5131c2e9ade4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,6 +hf_BigBird,pass, 52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 3e0af614a38c..40382a4f277c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,0 +hf_BigBird,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 07bbe765f161..431a91d10669 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,pass,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 80035c453fbf..1e1a4be4149e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,fail_to_run,3 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 07bbe765f161..f652e5ffa91a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,fail_accuracy,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index eb1195caa9a1..ee58808c0bb0 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,fail_to_run,3 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 5b5646e85487..20fb340690ac 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,0 +hf_BigBird,pass,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 0dd9ce3482f4..cfc524426644 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,6 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 4ced1b19f245..108bc6543aa9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,0 +hf_BigBird,fail_accuracy,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 0dd9ce3482f4..cfc524426644 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,6 +hf_BigBird,pass,52 From 853081a8e7e84ce3f09749929bb3cb998fd95b30 Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 19 May 2024 13:21:39 +0000 Subject: [PATCH 110/116] Replace torch.library.impl_abstract with torch.library.register_fake (#126606) To remove the disrupting warning ``` warnings.warn("torch.library.impl_abstract was renamed to " "torch.library.register_fake. Please use that instead; " "we will remove torch.library.impl_abstract in a future " "version of PyTorch.", DeprecationWarning, stacklevel=2) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126606 Approved by: https://github.com/ezyang --- torch/_custom_ops.py | 2 +- torch/distributed/pipelining/_IR.py | 2 +- torch/onnx/_internal/fx/decomposition_skip.py | 2 +- torch/testing/_internal/custom_op_db.py | 4 ++-- torch/testing/_internal/hop_db.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/_custom_ops.py b/torch/_custom_ops.py index c13b0aaf339a..c09a8ae68543 100644 --- a/torch/_custom_ops.py +++ b/torch/_custom_ops.py @@ -250,7 +250,7 @@ def impl_abstract(qualname, *, func=None): """ import torch.library - return torch.library.impl_abstract(qualname, func, _stacklevel=2) + return torch.library.register_fake(qualname, func, _stacklevel=2) def impl_save_for_backward(qualname, *, func=None): diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index aeb1a676c99e..204a60a34022 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -304,7 +304,7 @@ def _pipe_split(): return None -@torch.library.impl_abstract("pippy::_pipe_split") # type: ignore[no-redef] +@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] def _pipe_split(): # noqa: F811 return None diff --git a/torch/onnx/_internal/fx/decomposition_skip.py b/torch/onnx/_internal/fx/decomposition_skip.py index 425e8604468b..7fb971a3307a 100644 --- a/torch/onnx/_internal/fx/decomposition_skip.py +++ b/torch/onnx/_internal/fx/decomposition_skip.py @@ -71,7 +71,7 @@ def register_custom_op(cls): new_op_qualname = f"{_NEW_OP_NAMESPACE}::{cls.new_op_name}" torch.library.define(new_op_qualname, cls.new_op_schema) torch.library.impl(new_op_qualname, "default", cls.replacement) - torch.library.impl_abstract(new_op_qualname, cls.abstract) + torch.library.register_fake(new_op_qualname, cls.abstract) @classmethod def replacement(cls, *args, **kwargs): diff --git a/torch/testing/_internal/custom_op_db.py b/torch/testing/_internal/custom_op_db.py index 3177fb9c8bb5..ee170cc36058 100644 --- a/torch/testing/_internal/custom_op_db.py +++ b/torch/testing/_internal/custom_op_db.py @@ -458,7 +458,7 @@ def source1_fake(x): lib.define("source2(Tensor x) -> Tensor") -@torch.library.impl_abstract("_torch_testing::source2", lib=lib) +@torch.library.register_fake("_torch_testing::source2", lib=lib) def _(x): return x.clone() @@ -467,7 +467,7 @@ def _(x): def source3_fake(x): return x.clone() -torch.library.impl_abstract("_torch_testing::source3", source3_fake, lib=lib) +torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib) @torch.library.custom_op("_torch_testing::source4", mutates_args=()) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index df78812a6504..1602c1ef6562 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -82,7 +82,7 @@ def foo_impl_cuda(x, z): return x, z, x + z -@torch.library.impl_abstract("testlib::mutating_custom_op") +@torch.library.register_fake("testlib::mutating_custom_op") def foo_impl_abstract(x, z): return x, z, x + z From 574ae9afb876f2fc25d763ade496d656daaa8b7b Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 19 May 2024 22:34:24 +0000 Subject: [PATCH 111/116] [Submodule] Remove third-party onnx-tensorrt (#126542) It seems that tensorrt is not used by the C++ code, may be due to the removal of Caffe2. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126542 Approved by: https://github.com/ezyang --- .gitmodules | 4 ---- CMakeLists.txt | 1 - WORKSPACE | 5 ----- caffe2/core/macros.h.in | 2 -- cmake/Caffe2Config.cmake.in | 7 ------- cmake/Dependencies.cmake | 29 ----------------------------- third_party/onnx-tensorrt | 1 - 7 files changed, 49 deletions(-) delete mode 160000 third_party/onnx-tensorrt diff --git a/.gitmodules b/.gitmodules index db7698876a29..bd62cb8280ea 100644 --- a/.gitmodules +++ b/.gitmodules @@ -58,10 +58,6 @@ ignore = dirty path = third_party/onnx url = https://github.com/onnx/onnx.git -[submodule "third_party/onnx-tensorrt"] - ignore = dirty - path = third_party/onnx-tensorrt - url = https://github.com/onnx/onnx-tensorrt [submodule "third_party/sleef"] ignore = dirty path = third_party/sleef diff --git a/CMakeLists.txt b/CMakeLists.txt index 02cf8dedc79e..3c6320e68d39 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,7 +265,6 @@ option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON) option(USE_SNPE "Use Qualcomm's SNPE library" OFF) option(USE_SYSTEM_EIGEN_INSTALL "Use system Eigen instead of the one under third_party" OFF) -option(USE_TENSORRT "Using Nvidia TensorRT library" OFF) cmake_dependent_option( USE_VALGRIND "Use Valgrind. Only available on Linux." ON "LINUX" OFF) diff --git a/WORKSPACE b/WORKSPACE index 8eabea571a57..f7e604332213 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -355,9 +355,4 @@ local_repository( path = "third_party/onnx/third_party/benchmark", ) -local_repository( - name = "unused_onnx_tensorrt_benchmark", - path = "third_party/onnx-tensorrt/third_party/onnx/third_party/benchmark", -) - ### Unused repos end diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index e1d8b89325ec..2497effd8637 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -25,7 +25,6 @@ #cmakedefine USE_MKLDNN #cmakedefine CAFFE2_USE_NVTX #cmakedefine CAFFE2_USE_ITT -#cmakedefine CAFFE2_USE_TRT #ifndef EIGEN_MPL2_ONLY #cmakedefine EIGEN_MPL2_ONLY @@ -67,7 +66,6 @@ {"USE_MKLDNN", "${USE_MKLDNN}"}, \ {"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \ {"USE_ITT", "${CAFFE2_USE_ITT}"}, \ - {"USE_TRT", "${CAFFE2_USE_TRT}"}, \ {"USE_ROCM_KERNEL_ASSERT", "${USE_ROCM_KERNEL_ASSERT}"}, \ {"USE_CUSPARSELT", "${USE_CUSPARSELT}"}, \ } diff --git a/cmake/Caffe2Config.cmake.in b/cmake/Caffe2Config.cmake.in index 30e53c5fc752..c23b3990aff8 100644 --- a/cmake/Caffe2Config.cmake.in +++ b/cmake/Caffe2Config.cmake.in @@ -79,7 +79,6 @@ if(@USE_CUDA@) # If Caffe2 was compiled with the libraries below, they must # be found again when including the Caffe2 target. set(CAFFE2_USE_CUDA @USE_CUDA@) - set(CAFFE2_USE_TENSORRT @USE_TENSORRT@) # Add current directory to module path so we pick up FindCUDAToolkit.cmake set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") @@ -93,12 +92,6 @@ if(@USE_CUDA@) "libraries. Please set the proper CUDA prefixes and / or install " "CUDA.") endif() - if(@CAFFE2_USE_TENSORRT@ AND NOT CAFFE2_USE_TENSORRT) - message(FATAL_ERROR - "Your installed Caffe2 version uses TensorRT but I cannot find the TensorRT " - "libraries. Please set the proper TensorRT prefixes and / or install " - "TensorRT.") - endif() endif() if(@USE_XPU@) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 8bd80f167a5b..a7e38ee73bcc 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -40,7 +40,6 @@ if(USE_CUDA) set(CAFFE2_USE_CUDNN ${USE_CUDNN}) set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT}) set(CAFFE2_USE_NVRTC ${USE_NVRTC}) - set(CAFFE2_USE_TENSORRT ${USE_TENSORRT}) include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake) if(CAFFE2_USE_CUDA) # A helper variable recording the list of Caffe2 dependent libraries @@ -63,11 +62,6 @@ if(USE_CUDA) else() caffe2_update_option(USE_CUSPARSELT OFF) endif() - if(CAFFE2_USE_TENSORRT) - list(APPEND Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS caffe2::tensorrt) - else() - caffe2_update_option(USE_TENSORRT OFF) - endif() find_program(SCCACHE_EXECUTABLE sccache) if(SCCACHE_EXECUTABLE) # Using RSP/--options-file renders output noncacheable by sccache @@ -84,12 +78,10 @@ if(USE_CUDA) caffe2_update_option(USE_CUDNN OFF) caffe2_update_option(USE_CUSPARSELT OFF) caffe2_update_option(USE_NVRTC OFF) - caffe2_update_option(USE_TENSORRT OFF) set(CAFFE2_USE_CUDA OFF) set(CAFFE2_USE_CUDNN OFF) set(CAFFE2_USE_CUSPARSELT OFF) set(CAFFE2_USE_NVRTC OFF) - set(CAFFE2_USE_TENSORRT OFF) endif() endif() @@ -1491,27 +1483,6 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS}) endif() -# --[ TensorRT integration with onnx-trt -function(add_onnx_tensorrt_subdir) - # We pass the paths we found to onnx tensorrt. - set(CUDNN_INCLUDE_DIR "${CUDNN_INCLUDE_PATH}") - set(CUDNN_LIBRARY "${CUDNN_LIBRARY_PATH}") - set(CMAKE_VERSION_ORIG "{CMAKE_VERSION}") - # TODO: this WAR is for https://github.com/pytorch/pytorch/issues/18524 - set(CMAKE_VERSION "3.9.0") - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt EXCLUDE_FROM_ALL) - set(CMAKE_VERSION "{CMAKE_VERSION_ORIG}") -endfunction() -if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) - if(USE_TENSORRT) - add_onnx_tensorrt_subdir() - include_directories("${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt") - caffe2_interface_library(nvonnxparser_static onnx_trt_library) - list(APPEND Caffe2_DEPENDENCY_WHOLE_LINK_LIBS onnx_trt_library) - set(CAFFE2_USE_TRT 1) - endif() -endif() - # --[ ATen checks set(USE_LAPACK 0) diff --git a/third_party/onnx-tensorrt b/third_party/onnx-tensorrt deleted file mode 160000 index c153211418a7..000000000000 --- a/third_party/onnx-tensorrt +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c153211418a7c57ce071d9ce2a41f8d1c85a878f From be67985bd7fcf104748ad0ef0b0a4a435fd1b826 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 16 May 2024 17:52:04 -0700 Subject: [PATCH 112/116] [compiled autograd] log in cpp using python logger (#126483) Internal infra may not preserve python and c++ log ordering e.g. MAST logs: https://fburl.com/mlhub/38576cxn, all the `[python_compiled_autograd.cpp] Creating cache entry [...]` logs of the entire run are at the beginning of the file Pull Request resolved: https://github.com/pytorch/pytorch/pull/126483 Approved by: https://github.com/jansel ghstack dependencies: #126144, #126146, #126148 --- test/inductor/test_compiled_autograd.py | 110 +++++++++--------- torch/_C/_dynamo/compiled_autograd.pyi | 2 +- torch/_dynamo/compiled_autograd.py | 11 +- .../csrc/dynamo/python_compiled_autograd.cpp | 44 ++++--- 4 files changed, 91 insertions(+), 76 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 201dd4a3c77d..5fbd4c0705ff 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1533,64 +1533,68 @@ def fn(): ) def test_verbose_logs_cpp(self): - script = """ -import torch - -def compiler_fn(gm): - return torch.compile(gm, backend="eager") + torch._logging.set_logs(compiled_autograd_verbose=True) -def main(): - torch._logging.set_logs(compiled_autograd_verbose=True) - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.ReLU(), - torch.nn.Linear(4, 4), - torch.nn.ReLU(), - ) + def fn(): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + ) + for i in [10, 11, 12]: + model.zero_grad() + x = torch.randn([i, 4]) + result = model(x).sum() + result.backward() + yield model[0].weight.grad + yield model[0].bias.grad + yield model[2].weight.grad + yield model[2].bias.grad - for i in range(10, 100): - x = torch.randn([i, 4]) - result = model(x).sum() - with torch._dynamo.compiled_autograd.enable(compiler_fn): - result.backward() + logs, ctx = logs_to_string( + torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" + ) + with ctx(): + self.check_output_and_recompiles(fn, count=2) + + patterns1 = [ + r".*Creating cache entry for torch::autograd::GraphRoot, with key of size (\d+)\n", + r".*Creating cache entry for SumBackward0, with key of size (\d+)\n", + r".*Creating cache entry for ReluBackward0, with key of size (\d+)\n", + r".*Creating cache entry for AddmmBackward0, with key of size 1(\d+)\n", + r".*Creating cache entry for TBackward0, with key of size (\d+)\n", + r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", + r".*Creating cache entry for ReluBackward0, with key of size (\d+)\n", + r".*Creating cache entry for AddmmBackward0, with key of size (\d+)\n", + r".*Creating cache entry for TBackward0, with key of size (\d+)\n", + r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", + r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", + r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", + ] -main() -""" - stdout, _ = self.run_process_no_exception(script) - stdout = stdout.decode("utf-8") - - patterns = [ - r"\[python_compiled_autograd.cpp\] Creating cache entry for SumBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for TBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for TBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for TBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", + # recompile + patterns2 = [ + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", ] - pattern = r"".join(patterns) - matches = re.findall(pattern, stdout) - self.assertEqual(len(matches), 1) - self.assertEqual(len(matches[0]), len(patterns)) + all_logs = logs.getvalue() + + pattern1 = r"".join(patterns1) + matches1 = re.findall(pattern1, all_logs) + self.assertEqual(len(matches1), 1) + self.assertEqual(len(matches1[0]), len(patterns1)) + + pattern2 = r"".join(patterns2) + matches2 = re.findall(pattern2, all_logs) + self.assertEqual(len(matches2), 1) + self.assertEqual(len(matches2[0]), len(patterns2)) def test_snapshot_verbose_logs_flag(self): def fn(): diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi index 8ec4fbbdae8c..d2067a583921 100644 --- a/torch/_C/_dynamo/compiled_autograd.pyi +++ b/torch/_C/_dynamo/compiled_autograd.pyi @@ -7,4 +7,4 @@ def set_autograd_compiler( ) -> Optional[Callable[[], AutogradCompilerInstance]]: ... def clear_cache() -> None: ... def is_cache_empty() -> bool: ... -def set_verbose_logging(enable: bool) -> bool: ... +def set_verbose_logger(fn: Optional[Callable[[str], None]]) -> bool: ... diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index f9cf03947a8c..7a87a2c7d575 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -37,6 +37,10 @@ def snapshot_verbose_logging_enabled(): ) +def cpp_verbose_log_fn(msg: str) -> None: + verbose_log.debug(msg) + + def maybe_clone(x): if x is not None: return clone_preserve_strides(x) @@ -292,9 +296,8 @@ def enable(compiler_fn): prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( functools.partial(AutogradCompilerInstance, compiler_fn) ) - torch._C._dynamo.compiled_autograd.set_verbose_logging( - snapshot_verbose_logging_enabled() - ) + if snapshot_verbose_logging_enabled(): + torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn) global compiled_autograd_enabled, compiled_autograd_enabled_count compiled_autograd_enabled = True compiled_autograd_enabled_count += 1 @@ -326,4 +329,4 @@ def reset() -> None: compiled_autograd_enable = False assert compiled_autograd_enabled_count == 0 torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) - torch._C._dynamo.compiled_autograd.set_verbose_logging(False) + torch._C._dynamo.compiled_autograd.set_verbose_logger(None) diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index fb27b39b28e6..ab3e19b8c126 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -50,14 +50,6 @@ at trace time. namespace torch::dynamo::autograd { using c10::SymInt; -// snapshot of python verbose logging toggle -static bool is_verbose_logging_enabled; -static constexpr std::string_view VLOG_PREFIX = - "[python_compiled_autograd.cpp] "; -std::ostream& vcout() { - return std::cout << VLOG_PREFIX; -} - static PyObject* wrap_int_list(const std::vector& inputs) { PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -91,6 +83,13 @@ static void check(bool result) { check(nullptr); } +// snapshot of python verbose logging toggle +static PyObject* python_verbose_logger = nullptr; +void verbose_log_fn(std::string_view msg) { + TORCH_CHECK(python_verbose_logger != nullptr); + check(PyObject_CallFunction(python_verbose_logger, "s", msg.data())); +} + struct CacheNode { // A node in the shadow graph, we follow next edges until we reach the end of // the graph @@ -161,9 +160,10 @@ struct CacheNode { if (changed_value) { if (!was_dynamic) { cache_hit = false; - if (is_verbose_logging_enabled) { - vcout() << "cache miss: marking sizes[" << i << "] as dynamic" - << std::endl; + if (python_verbose_logger != nullptr) { + verbose_log_fn( + "cache miss: marking sizes[" + std::to_string(i) + + "] as dynamic"); } } expected = SizeInput(SizeInput::DYNAMIC, data[i].value); @@ -257,11 +257,18 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) { END_HANDLE_TH_ERRORS; } -static PyObject* set_verbose_logging(PyObject* dummy, PyObject* args) { +static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; - if (!PyArg_ParseTuple(args, "p", &is_verbose_logging_enabled)) { + PyObject* logger = nullptr; + if (!PyArg_ParseTuple(args, "O", &logger)) { Py_RETURN_FALSE; } + + if (logger == Py_None) { + python_verbose_logger = nullptr; + } else { + python_verbose_logger = logger; + } Py_RETURN_TRUE; END_HANDLE_TH_ERRORS; } @@ -271,7 +278,7 @@ static PyMethodDef _methods[] = { {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr}, {"clear_cache", clear_cache, METH_NOARGS, nullptr}, {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr}, - {"set_verbose_logging", set_verbose_logging, METH_VARARGS, nullptr}, + {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { @@ -367,10 +374,11 @@ CacheNode* _compiled_autograd_impl( node_args.collect(call.node->next_edges()); } CacheKey key = node_args.key(); - if (is_verbose_logging_enabled && + if (python_verbose_logger != nullptr && cache->lookup(key, /*create=*/false) == nullptr) { - vcout() << "Creating cache entry for " << fn->name() - << ", with key of size " << key.key_size << std::endl; + verbose_log_fn( + "Creating cache entry for " + fn->name() + ", with key of size " + + std::to_string(key.key_size)); } cache = cache->lookup(key); } @@ -454,7 +462,7 @@ CacheNode* _compiled_autograd_impl( inputs = THPVariable_UnpackList(pyinputs); } - if (is_verbose_logging_enabled) { + if (python_verbose_logger != nullptr) { std::string _node_name = call.node->name(); THPObjectPtr node_name(PyUnicode_FromString(_node_name.data())); TORCH_INTERNAL_ASSERT(node_name != nullptr); From 5fb11cda4fe60c1a7b30e6c844f84ce8933ef953 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 17 May 2024 18:28:06 -0700 Subject: [PATCH 113/116] [compiled autograd] Better cache miss logging (#126602) - log only first node key cache miss - log existing node key sizes - log which node's collected sizes became dynamic e.g. ``` DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[] ... DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::AccumulateGrad (NodeCall 5) with key size 32, previous key sizes=[21] ... DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 0 of torch::autograd::GraphRoot (NodeCall 0) DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of SumBackward0 (NodeCall 1) DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 4 of SumBackward0 (NodeCall 1) DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of ReluBackward0 (NodeCall 2) DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 9 of AddmmBackward0 (NodeCall 3) DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of torch::autograd::AccumulateGrad (NodeCall 5) DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of ReluBackward0 (NodeCall 6) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126602 Approved by: https://github.com/jansel ghstack dependencies: #126144, #126146, #126148, #126483 --- test/inductor/test_compiled_autograd.py | 34 +++--- .../csrc/dynamo/python_compiled_autograd.cpp | 109 +++++++++++++++--- 2 files changed, 108 insertions(+), 35 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 5fbd4c0705ff..87299d796f6c 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1559,29 +1559,20 @@ def fn(): self.check_output_and_recompiles(fn, count=2) patterns1 = [ - r".*Creating cache entry for torch::autograd::GraphRoot, with key of size (\d+)\n", - r".*Creating cache entry for SumBackward0, with key of size (\d+)\n", - r".*Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r".*Creating cache entry for AddmmBackward0, with key of size 1(\d+)\n", - r".*Creating cache entry for TBackward0, with key of size (\d+)\n", - r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r".*Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r".*Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - r".*Creating cache entry for TBackward0, with key of size (\d+)\n", - r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r".*Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", + r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), " + r"previous key sizes=\[\]\n", ] # recompile patterns2 = [ - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", - r".*cache miss: marking sizes\[(\d+)\] as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::GraphRoot \(NodeCall 0\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 2\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of AddmmBackward0 \(NodeCall 3\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::AccumulateGrad " + r"\(NodeCall 5\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 6\) as dynamic\n", ] all_logs = logs.getvalue() @@ -1589,7 +1580,10 @@ def fn(): pattern1 = r"".join(patterns1) matches1 = re.findall(pattern1, all_logs) self.assertEqual(len(matches1), 1) - self.assertEqual(len(matches1[0]), len(patterns1)) + assert isinstance( + matches1[0], str + ) # for a single match: matches1=['match'], for multiple matches: matches1=[('match1', 'match2')]... + self.assertEqual(len(matches1), len(patterns1)) pattern2 = r"".join(patterns2) matches2 = re.findall(pattern2, all_logs) diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index ab3e19b8c126..3a79a7bc6372 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include /* @@ -85,10 +86,79 @@ static void check(bool result) { // snapshot of python verbose logging toggle static PyObject* python_verbose_logger = nullptr; -void verbose_log_fn(std::string_view msg) { - TORCH_CHECK(python_verbose_logger != nullptr); - check(PyObject_CallFunction(python_verbose_logger, "s", msg.data())); -} +struct VerboseLogger { + static std::optional maybe_create() { + if (python_verbose_logger == nullptr) { + return std::nullopt; + } + return VerboseLogger(); + } + + void verbose_log_fn(std::string_view msg) const { + TORCH_CHECK(python_verbose_logger != nullptr); + check(PyObject_CallFunction(python_verbose_logger, "s", msg.data())); + } + + void log_node_check( + const Node& fn, + size_t size_inputs_num, + std::unordered_set cached_keys, + const CacheKey& key, + size_t node_idx) { + std::string node_name = + fn.name() + " (NodeCall " + std::to_string(node_idx) + ")"; + + cumulative_sizes_per_node[size_inputs_num] = node_name; + + if (!logged_node_miss && cached_keys.find(key) == cached_keys.end()) { + _log_node_miss(typeid(fn), cached_keys, key, node_name); + logged_node_miss = true; + } + } + + void _log_node_miss( + const std::type_info& node_type, + std::unordered_set cached_keys, + const CacheKey& key, + const std::string& node_name) const { + std::ostringstream oss; + oss << "Cache miss due to new autograd node: " << node_name + << " with key size " << std::to_string(key.key_size) + << ", previous key sizes=["; + + for (auto it = cached_keys.begin(); it != cached_keys.end(); it++) { + if (it->node_type != node_type) { + continue; + } + oss << it->key_size; + if (std::next(it) != cached_keys.end()) { + oss << ","; + } + } + oss << "]"; + verbose_log_fn(oss.str()); + } + + void log_dynamic_shapes_check(size_t size_idx) const { + if (cumulative_sizes_per_node.empty()) { + return; + } + + auto it = cumulative_sizes_per_node.lower_bound(size_idx); + TORCH_CHECK(it != cumulative_sizes_per_node.end()); + size_t start_idx = + it == cumulative_sizes_per_node.begin() ? 0 : std::prev(it)->first; + verbose_log_fn( + "Cache miss due to changed shapes: marking size idx " + + std::to_string(size_idx - start_idx) + " of " + it->second + + " as dynamic"); + } + + // track which size index belongs to which node + std::map cumulative_sizes_per_node; + // only log cache miss due to node key once + bool logged_node_miss = false; +}; struct CacheNode { // A node in the shadow graph, we follow next edges until we reach the end of @@ -134,7 +204,9 @@ struct CacheNode { CacheNode& operator=(const CacheNode&) = delete; CacheNode& operator=(CacheNode&&) = delete; - bool check_dynamic_sizes(AutogradCompilerCall& call) { + bool check_dynamic_sizes( + AutogradCompilerCall& call, + const std::optional& vlogger) { /* We start off by assuming everything is static, then we mark things as dynamic when we see them change. This function: @@ -160,10 +232,8 @@ struct CacheNode { if (changed_value) { if (!was_dynamic) { cache_hit = false; - if (python_verbose_logger != nullptr) { - verbose_log_fn( - "cache miss: marking sizes[" + std::to_string(i) + - "] as dynamic"); + if (vlogger.has_value()) { + vlogger->log_dynamic_shapes_check(i); } } expected = SizeInput(SizeInput::DYNAMIC, data[i].value); @@ -360,6 +430,8 @@ CacheNode* _compiled_autograd_impl( calls.reserve( check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1); + int i = 0; + std::optional vlogger = VerboseLogger::maybe_create(); while (!worklist.empty()) { std::shared_ptr fn = std::move(worklist.back()); worklist.pop_back(); @@ -374,11 +446,17 @@ CacheNode* _compiled_autograd_impl( node_args.collect(call.node->next_edges()); } CacheKey key = node_args.key(); - if (python_verbose_logger != nullptr && - cache->lookup(key, /*create=*/false) == nullptr) { - verbose_log_fn( - "Creating cache entry for " + fn->name() + ", with key of size " + - std::to_string(key.key_size)); + if (vlogger.has_value()) { + std::unordered_set cached_keys; + for (const auto& [k, _] : cache->next) { + cached_keys.emplace(k); + } + vlogger->log_node_check( + *fn, + compiler_call.all_size_inputs.size(), + std::move(cached_keys), + key, + i); } cache = cache->lookup(key); } @@ -403,10 +481,11 @@ CacheNode* _compiled_autograd_impl( worklist.emplace_back(edge.function); } } + i++; } // TODO(jansel): some dynamic sizes seem to be ints not symints - if (!cache->check_dynamic_sizes(compiler_call)) { + if (!cache->check_dynamic_sizes(compiler_call, vlogger)) { // cache miss, need to capture FX graph ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); From cf35a591b95220aa1bfcc04ff8a943efd1d6d6eb Mon Sep 17 00:00:00 2001 From: jayanth domalapalli Date: Mon, 20 May 2024 06:20:45 +0000 Subject: [PATCH 114/116] Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127) This PR is meant to address issue #123451, more specifically, the ```test_graph_optims``` and ```test_graph_scaling_fused_optimizers``` functions in ```test_cuda.py``` have been updated so that they now use the new OptimizerInfo infrastructure. Lintrunner passed: ``` $ lintrunner test/test_cuda.py ok No lint issues. ``` Tests passed: ``` >python test_cuda.py -k test_graph_optims Ran 19 tests in 7.463s OK (skipped=9) >python test_cuda.py -k test_graph_scaling_fused_optimizers Ran 6 tests in 2.800s OK (skipped=3) ``` Both the functions have been moved to the newly created TestCase class ```TestCudaOptims```. The test is mostly the same except the ```@optims``` decorator is used at the top of the function to implicitly call the function using each of the optimizers mentioned in the decorator instead of explicitly using a for loop to iterate through each of the optimizers. I was unable to use the ```_get_optim_inputs_including_global_cliquey_kwargs``` to get all kwargs for each of the optimizers since some of the kwargs that are used in the original ```test_graph_optims``` function are not being returned by the new OptimizerInfo infrastructure, more specifically, for the ```torch.optim.rmsprop.RMSprop``` optimizer, the following kwargs are not returned whenever ```_get_optim_inputs_including_global_cliquey_kwargs``` is called: ``` {'foreach': False, 'maximize': True, 'weight_decay': 0} { 'foreach': True, 'maximize': True, 'weight_decay': 0} ``` I ran into the same issue for ```test_graph_scaling_fused_optimizers```, for the ```torch.optim.adamw.AdamW``` optimizer, whenever ```optim_info.optim_inputs_func(device=device)``` was called, the following kwarg was not returned: ``` {'amsgrad': True} ``` Due to this issue, I resorted to using a dictionary to store the kwargs for each of the optimizers, I am aware that this is less than ideal. I was wondering whether I should use the OptimizerInfo infrastructure to get all the kwargs regardless of the fact that it lacks some kwargs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127 Approved by: https://github.com/janeyx99 --- test/inductor/test_compiled_optimizers.py | 2 + test/test_cuda.py | 393 ++++++++----------- torch/testing/_internal/common_optimizers.py | 64 ++- 3 files changed, 225 insertions(+), 234 deletions(-) diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index b2d0ed91809f..7100837e9b92 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -136,6 +136,8 @@ class KernelCounts(NamedTuple): "test_sgd_momentum_foreach_cuda": 5, "test_sgd_weight_decay_maximize_cuda": 4, "test_sgd_weight_decay_maximize_cpu": 4, + "test_sgd_weight_decay_cpu": 4, + "test_sgd_weight_decay_cuda": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, "test_sgd_cuda": 4, diff --git a/test/test_cuda.py b/test/test_cuda.py index 93e08eff4df6..cc3e2380f266 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -37,7 +37,11 @@ instantiate_device_type_tests, onlyCUDA, ) -from torch.testing._internal.common_optimizers import optim_db, optims +from torch.testing._internal.common_optimizers import ( + _get_optim_inputs_including_global_cliquey_kwargs, + optim_db, + optims, +) from torch.testing._internal.common_utils import ( freeze_rng_state, gcIfJetson, @@ -3200,111 +3204,6 @@ def _test_graphed_optimizer( for p_control, p_graphed in zip(params_control, params_graphed): self.assertEqual(p_control, p_graphed) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_optims(self): - # Needs generalization if we want to extend this test to non-Adam-like optimizers. - cases = ( - [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "decoupled_weight_decay": decoupled_weight_decay, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product( - ( - torch.optim.NAdam, - torch.optim.RAdam, - ), - ( - False, - True, - ), - ( - False, - True, - ), - ( - 0.0, - 0.1, - ), - ) - ] - + [ - ( - torch.optim.Rprop, - {"lr": 0.1, "foreach": foreach, "maximize": maximize}, - ) - for foreach, maximize in product( - ( - False, - True, - ), - ( - False, - True, - ), - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "amsgrad": amsgrad, - }, - ) - for optimizer_ctor, foreach, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), - (False, True), - (False, True), - ) - ] - + [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "foreach": foreach, - "maximize": maximize, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, maximize, weight_decay in product( - ( - torch.optim.Adamax, - torch.optim.ASGD, - torch.optim.Adadelta, - torch.optim.RMSprop, - ), - (False, True), - (False, True), - (0, 0.1), - ) - ] - ) - - for optimizer_ctor, kwargs in cases: - with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): - self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) - @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3376,123 +3275,6 @@ def test_graph_optims_with_explicitly_capturable_param_groups(self): self.assertEqual(ref_p1, param1) self.assertEqual(ref_p2, param2) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_scaling_fused_optimizers(self): - cases = [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] + list( - product( - (torch.optim.SGD,), - [ - { - "lr": 0.1, - "momentum": 0.0, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) - ] - + [ - { - "lr": 0.1, - "momentum": 0.5, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) - ], - ) - ) - - steps_warmup = 3 - steps_train = 2 - - for OptClass, kwargs in cases: - has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) - for actually_do_graphs in (True, False) if has_capturable_arg else (True,): - params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - with torch.no_grad(): - grads_control = [[g.clone() for g in gs] for gs in grads] - grads_graphed = [[g.clone() for g in gs] for gs in grads] - - # Gradient Scaler - scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) - with torch.no_grad(): - scaler_for_control._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) - - scaler_for_graphed = torch.cuda.amp.GradScaler() - scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) - with torch.no_grad(): - scaler_for_graphed._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) - - # Control (capturable=False) - if has_capturable_arg: - kwargs["capturable"] = False - opt = OptClass(params_control, **kwargs) - - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads_control[i][j] - scaler_for_control.step(opt) - scaler_for_control.update() - - # capturable=True - if has_capturable_arg: - kwargs["capturable"] = True - opt = OptClass(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads_graphed[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i + steps_warmup][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -4698,10 +4480,175 @@ def test_no_triton_on_import(self): self.assertEqual(rc, "False", "Triton was imported when importing torch!") +@torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" + ) + @optims( + [optim for optim in optim_db if optim.has_capturable_arg], + dtypes=[torch.float32], + ) + def test_graph_optims(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",) + ) + + steps_warmup = 3 + steps_train = 2 + + for optim_input in all_optim_inputs: + kwargs = optim_input.kwargs + + # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam + # and torch.optim.adamw + kwargs["lr"] = 0.1 + + for actually_do_graphs in (True, False): + params = [ + torch.randn((i + 5, i + 5), device=device) for i in range(2) + ] + [torch.randn((), device=device)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + + # Control (capturable=False) + kwargs["capturable"] = False + + opt = optim_cls(params_control, **kwargs) + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads[i][j] + opt.step() + + # capturable=True + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads[i][j] + opt.step() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + opt.step() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads[i + steps_warmup][j] + opt.step() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @optims( + [optim for optim in optim_db if "fused" in optim.supported_impls], + dtypes=[torch.float32], + ) + def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + + steps_warmup = 3 + steps_train = 2 + + optim_inputs = optim_info.optim_inputs_func(device=device) + + for optim_input in optim_inputs: + kwargs = optim_input.kwargs + kwargs["fused"] = True + + for actually_do_graphs in ( + (True, False) if optim_info.has_capturable_arg else (True,) + ): + params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + with torch.no_grad(): + grads_control = [[g.clone() for g in gs] for gs in grads] + grads_graphed = [[g.clone() for g in gs] for gs in grads] + + # Gradient Scaler + scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) + with torch.no_grad(): + scaler_for_control._lazy_init_scale_growth_tracker(device) + + scaler_for_graphed = torch.cuda.amp.GradScaler() + scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) + with torch.no_grad(): + scaler_for_graphed._lazy_init_scale_growth_tracker(device) + + # Control (capturable=False) + if optim_info.has_capturable_arg: + kwargs["capturable"] = False + opt = optim_cls(params_control, **kwargs) + + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads_control[i][j] + scaler_for_control.step(opt) + scaler_for_control.update() + + # capturable=True + if optim_info.has_capturable_arg: + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads_graphed[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i + steps_warmup][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + @onlyCUDA @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index c81efb093cd8..5abacf2df1d6 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -123,6 +123,8 @@ def __init__( supported_impls: Tuple[str] = ("foreach", "differentiable"), # the optim supports passing in sparse gradients as well as dense grads supports_sparse: bool = False, + # the optimizer constructor supports passing in capturable as a kwarg + has_capturable_arg: bool = False, # the optim only supports one config: sparse grads w/ dense params, see SparseAdam only_supports_sparse_grads: bool = False, # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, @@ -147,6 +149,7 @@ def __init__( self.scheduler_inputs = scheduler_inputs self.supported_impls = supported_impls self.supports_sparse = supports_sparse + self.has_capturable_arg = has_capturable_arg self.metadata_for_sparse = metadata_for_sparse self.only_supports_sparse_grads = only_supports_sparse_grads self.supports_complex = supports_complex @@ -311,10 +314,11 @@ def optim_inputs_func_adadelta(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", + desc="maximize, weight_decay", ), OptimizerInput( params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" @@ -528,9 +532,14 @@ def optim_inputs_func_adamax(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, + kwargs={"maximize": True}, desc="maximize", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, weight_decay", + ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -683,14 +692,20 @@ def optim_inputs_func_nadam(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + kwargs={ + "weight_decay": 0.1, + }, desc="weight_decay", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + desc="weight_decay, momentum_decay", + ), OptimizerInput( params=None, kwargs={ "weight_decay": 0.1, - "momentum_decay": 6e-3, "decoupled_weight_decay": True, }, desc="decoupled_weight_decay", @@ -818,11 +833,26 @@ def optim_inputs_func_rmsprop(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + }, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True}, desc="centered", ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + "weight_decay": 0.1, + }, + desc="maximize, weight_decay", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, @@ -836,7 +866,7 @@ def optim_inputs_func_rmsprop(device, dtype=None): "momentum": 0.1, "maximize": True, }, - desc="maximize", + desc="maximize, centered, weight_decay, w/ momentum", ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -907,7 +937,15 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" ), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" + ), OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "dampening": 0.5}, @@ -916,18 +954,13 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"momentum": 0.9, "weight_decay": 0.1}, - desc="non-zero weight_decay", + desc="weight_decay w/ momentum", ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, desc="nesterov", ), - OptimizerInput( - params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", - ), ] @@ -1097,6 +1130,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adadelta, optim_error_inputs_func=optim_error_inputs_func_adadelta, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1232,6 +1266,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), + has_capturable_arg=True, decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1298,6 +1333,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamax, optim_error_inputs_func=optim_error_inputs_func_adamax, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1348,6 +1384,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), + has_capturable_arg=True, decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1414,6 +1451,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_asgd, optim_error_inputs_func=optim_error_inputs_func_asgd, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1506,6 +1544,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1561,6 +1600,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_radam, optim_error_inputs_func=optim_error_inputs_func_radam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1606,6 +1646,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rmsprop, optim_error_inputs_func=optim_error_inputs_func_rmsprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1655,6 +1696,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rprop, optim_error_inputs_func=optim_error_inputs_func_rprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # Rprop doesn't update for non-contiguous, see #118117 From 5ad2f100340c64123db607824f70119c8fe2b38a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 20 May 2024 06:41:11 +0000 Subject: [PATCH 115/116] Revert "[inductor] Load python modules using importlib (#126454)" This reverts commit faa26df72e2a3ff08f9dd564bb50756916826854. Reverted https://github.com/pytorch/pytorch/pull/126454 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/126454#issuecomment-2119771267)) --- torch/_inductor/runtime/compile_tasks.py | 29 ++++++++++++------------ 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index b29a95f64b6c..66a36703da45 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -1,10 +1,10 @@ from __future__ import annotations import functools -import importlib import os import sys import warnings +from types import ModuleType from typing import Any, Callable @@ -31,20 +31,19 @@ def _reload_python_module_in_subproc(key, path): def _reload_python_module(key, path): - spec = importlib.util.spec_from_file_location(f"{__name__}.{key}", path) - if spec is None or spec.loader is None: - raise RuntimeError(f"Failed to import {path}: path not found") - module = importlib.util.module_from_spec(spec) - module.key = key # type: ignore[attr-defined] - try: - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError( - f"Failed to import {path}\n{type(e).__name__}: {e}" - ) from None - - sys.modules[module.__name__] = module - return module + with open(path) as f: + try: + code = compile(f.read(), path, "exec") + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + mod = ModuleType(f"{__name__}.{key}") + mod.__file__ = path + mod.key = key # type: ignore[attr-defined] + exec(code, mod.__dict__, mod.__dict__) + sys.modules[mod.__name__] = mod + return mod @functools.lru_cache(None) From 53f73cdeb6fe53155f91064ca15b722e80dec2f3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 20 May 2024 06:59:58 +0000 Subject: [PATCH 116/116] Revert "Add symbolic_shape_specialization structured trace (#126450)" This reverts commit da1fc85d60fcf0bd1e8638d643a7c0c6560c3a5f. Reverted https://github.com/pytorch/pytorch/pull/126450 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/126450#issuecomment-2119798075)) --- torch/fx/experimental/symbolic_shapes.py | 27 ++++++------------------ 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e310d490b77c..be1be24137f8 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4397,9 +4397,6 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No Use this instead of `self.replacements[a] = tgt`. """ - if tgt == self.replacements.get(a, None): - return - # Precondition: a == tgt assert isinstance(a, sympy.Symbol) @@ -4490,24 +4487,14 @@ def issubset(x, y): "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return - if isinstance(tgt, (sympy.Integer, sympy.Float)): - # specializing to a constant, which is likely unexpected (unless - # you specified dynamic=True) - - user_tb = TracingContext.extract_stack() - trace_structured( - "symbolic_shape_specialization", - metadata_fn=lambda: { - "symbol": repr(a), - "sources": [s.name() for s in self.var_to_sources[a]], - "value": repr(tgt), - "reason": msg, - "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), - "user_stack": structured.from_traceback(user_tb) if user_tb else None, - } - ) + if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected - if config.print_specializations: + # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g., + # when adding a to self.replacements, and again when simplifying an expression containing a. + # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is, + # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant. + if a not in self.replacements or tgt != self.replacements[a]: self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)