From b7affa2ac3888d0c6eaf8150267102afb429e395 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 26 Oct 2023 14:21:39 +0000 Subject: [PATCH 01/78] Add unit test for ONNX models with torch.distributions.normal.Normal (#111498) Fixes #111034 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111498 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- .ci/docker/common/install_onnx.sh | 2 +- test/onnx/test_fx_op_consistency.py | 4 ---- test/onnx/test_fx_to_onnx.py | 16 ++++++++++++++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index e70234739f813..37185b0d2b30b 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -32,7 +32,7 @@ pip_install coloredlogs packaging retry pip_install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ --no-cache-dir --no-input ort-nightly==1.17.0.dev20231005006 pip_install -i https://test.pypi.org/simple/ onnx==1.15.0rc2 -pip_install onnxscript==0.1.0.dev20231006 --no-deps +pip_install onnxscript==0.1.0.dev20231025 --no-deps # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index b6d9d3af37f8b..c23c4afe1a77c 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -460,10 +460,6 @@ def skip_torchlib_forward_compatibility( "nn.functional.dropout", reason=onnx_test_common.reason_dynamo_does_not_support("Dropout"), ), - xfail( - "nn.functional.embedding", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten.embedding_renorm.default"), - ), xfail( "nn.functional.max_pool2d", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 864859b7e4144..8effa50e0bfaa 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -538,6 +538,22 @@ def forward(self, x): with self.assertRaises(torch.onnx.InvalidExportOptionsError): raise self._export_exception + def test_exported_program_torch_distributions_normal_Normal(self): + class Model(torch.nn.Module): + def __init__(self): + self.normal = torch.distributions.normal.Normal(0, 1) + super().__init__() + + def forward(self, x): + return self.normal.sample(x.shape) + + x = torch.randn(2, 3) + exported_program = torch.export.export(Model(), args=(x,)) + _ = torch.onnx.dynamo_export( + exported_program, + x, + ) + if __name__ == "__main__": common_utils.run_tests() From 3a284dae30e1de15d18372c6b448ac67c13d9040 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 26 Oct 2023 18:53:14 +0000 Subject: [PATCH 02/78] Revert "Do not materialize entire randperm in RandomSampler (#103339)" This reverts commit d80174e2db679365f8b58ff8583bdc4af5a8b74c. Reverted https://github.com/pytorch/pytorch/pull/103339 on behalf of https://github.com/kit1980 due to Cause issues on MPS, and also fails without numpy ([comment](https://github.com/pytorch/pytorch/pull/103339#issuecomment-1781705172)) --- torch/utils/data/sampler.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 3bdde318b923f..bdbb577f14d2c 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -157,13 +157,12 @@ def __iter__(self) -> Iterator[int]: if self.replacement: for _ in range(self.num_samples // 32): - yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).numpy()) - final_samples = torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator) - yield from map(int, final_samples.numpy()) + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() else: for _ in range(self.num_samples // n): - yield from map(int, torch.randperm(n, generator=generator).numpy()) - yield from map(int, torch.randperm(n, generator=generator)[:self.num_samples % n].numpy()) + yield from torch.randperm(n, generator=generator).tolist() + yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] def __len__(self) -> int: return self.num_samples From 5e5329155e0aea98418923cf8046fa3cdbde81cc Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Thu, 26 Oct 2023 19:47:04 +0000 Subject: [PATCH 03/78] [aotinductor] only include -lc10 for non-fbcode case (#112125) Summary: otherwise, we would break internal uses Differential Revision: D50681467 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112125 Approved by: https://github.com/swolchok, https://github.com/desertfire, https://github.com/SherlockNoMad --- torch/_inductor/codecache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c52947ec029e7..3e8b98436269f 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1199,9 +1199,10 @@ def get_include_and_linking_paths( else: libs = ["omp"] if config.is_fbcode() else ["gomp"] - # Unconditionally import c10 to use TORCH_CHECK - See PyTorch #108690 - libs += ["c10"] - lpaths += [cpp_extension.TORCH_LIB_PATH] + # Unconditionally import c10 for non-fbcode to use TORCH_CHECK - See PyTorch #108690 + if not config.is_fbcode(): + libs += ["c10"] + lpaths += [cpp_extension.TORCH_LIB_PATH] # third party libs if config.is_fbcode(): From 1b702b185e8dddadb4ad3f487f5412a02c8777e1 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Thu, 26 Oct 2023 19:48:37 +0000 Subject: [PATCH 04/78] [pytorch-vulkan] disable one zero-dim tensor test to fix test (#112087) Summary: D50347338 has bug on android (not Mac, not Devserver). This diff disable the test for time being while I identify the actual cause. Test Plan: ## Compile on devserver ``` [yipjustin@129360.od ~/fbsource (e415d865c)]$ buck2 build -c ndk.static_linking=true -c pt.enable_qpl=0 --target-platforms=ovr_config//platform/android:arm32-fbsource //xplat/caffe2:pt_vulkan_api_test_binAndroid --show-output File changed: fbcode//caffe2/aten/src/ATen/test/vulkan_api_test.cpp File changed: fbsource//xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp Buck UI: https://www.internalfb.com/buck2/99d47e63-ed6e-4db9-bee2-24909d647b78 Network: Up: 3.2KiB Down: 67KiB (reSessionID-459e359b-773c-48a4-b129-81fde7c5e876) Jobs completed: 4664. Time elapsed: 7.3s. Cache hits: 100%. Commands: 38 (cached: 38, remote: 0, local: 0) BUILD SUCCEEDED fbsource//xplat/caffe2:pt_vulkan_api_test_binAndroid buck-out/v2/gen/fbsource/f1f3f9bed27e143c/xplat/caffe2/__pt_vulkan_api_test_binAndroid__/pt_vulkan_api_test_binAndroid ``` ## Run test. adb shell /data/local/tmp/pt_vulkan_api_test_binAndroid | pastry Result: P864940908 ``` ... [ OK ] VulkanAPITest.lstm_success (7 ms) [ RUN ] VulkanAPITest.lstm_mclareninputs_success [ OK ] VulkanAPITest.lstm_mclareninputs_success (56 ms) [ RUN ] VulkanAPITest.lstm_prepack_success [ OK ] VulkanAPITest.lstm_prepack_success (7 ms) [ RUN ] VulkanAPITest.querypool_flushed_shader_log xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp:7568: Skipped QueryPool is not available [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms) [----------] 391 tests from VulkanAPITest (30715 ms total) [----------] Global test environment tear-down [==========] 391 tests from 1 test suite ran. (30715 ms total) [ PASSED ] 390 tests. [ SKIPPED ] 1 test, listed below: [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log YOU HAVE 7 DISABLED TESTS ``` Reviewed By: liuk22 Differential Revision: D50668570 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112087 Approved by: https://github.com/izaitsevfb, https://github.com/SS-JIA --- aten/src/ATen/native/vulkan/ops/Sum.cpp | 5 +++++ aten/src/ATen/test/vulkan_api_test.cpp | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/aten/src/ATen/native/vulkan/ops/Sum.cpp b/aten/src/ATen/native/vulkan/ops/Sum.cpp index 14d9f2222cd42..9b51040e54ca3 100644 --- a/aten/src/ATen/native/vulkan/ops/Sum.cpp +++ b/aten/src/ATen/native/vulkan/ops/Sum.cpp @@ -135,6 +135,11 @@ Tensor sum_dim_IntList( Tensor sum(const Tensor& self, const c10::optional dtype) { std::vector dims; for (int64_t d = 0; d < self.dim(); d++) { + // If any dimension has zero elements, we will shortcut to a zero-dim. + if (self.size(d) == 0) { + return self.new_zeros({}, at::device(at::kVulkan).dtype(self.dtype())); + } + dims.push_back(d); } diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index c15fb63123621..382c9e464caeb 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -282,6 +282,11 @@ TEST_F(VulkanAPITest, zero_size_tensor) { ASSERT_TRUE(at::equal(out_vk, cpu)); } +TEST_F(VulkanAPITest, zero_size_tensor_numel) { + auto vk = at::rand({18, 0, 5}, at::device(at::kVulkan).dtype(at::kFloat)); + ASSERT_TRUE(vk.numel() == 0); +} + TEST_F(VulkanAPITest, zero_dim_tensor_1) { auto cpu = at::rand({}, at::device(at::kCPU).dtype(at::kFloat)); auto vv = cpu.item(); @@ -302,6 +307,12 @@ TEST_F(VulkanAPITest, zero_dim_tensor_2) { ASSERT_TRUE(almostEqual(cpu, vk.cpu())); } +TEST_F(VulkanAPITest, zero_dim_tensor_3) { + auto vk = at::zeros({}, at::device(at::kVulkan).dtype(at::kFloat)); + + ASSERT_TRUE(vk.cpu().item() == 0.0f); +} + TEST_F(VulkanAPITest, local_scalar_dense) { float v = 8.31f; // Force the zero-dim tensor to a non-zero constant v. @@ -4530,6 +4541,8 @@ TEST_F(VulkanAPITest, sum_test) { test_sum({6}); test_sum({5, 6}); test_sum({0, 3, 1}); + test_sum({5, 0, 1}); + test_sum({5, 3, 0}); test_sum({3, 3, 1}); test_sum({7, 6, 6}); test_sum({7, 8, 5, 6}); From 8a7c3cec78686e661b3781b916a8aae59083f90a Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 25 Oct 2023 18:45:50 +0000 Subject: [PATCH 05/78] Constrain sdpa to fx strides (#111721) Fix for https://github.com/pytorch/pytorch/issues/109607. sdpa requires last dimension strides to be 1. Add constraint so that we run the op with the strides we observed in tracing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111721 Approved by: https://github.com/drisspg, https://github.com/Chillee, https://github.com/jansel ghstack dependencies: #111976 --- test/inductor/test_torchinductor.py | 56 ++++++++++++++++ ...st_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 67 +++++++++++++++++-- 3 files changed, 120 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 669d99632e9e2..a0b65087c9211 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6772,6 +6772,62 @@ def forward(arg6, arg7, arg16): # expanded dim should not cause copy in require_stride_order self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + @requires_cuda() + def test_sdpa(self): + def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): + view = torch.ops.aten.view.default(arg3_1, [23760, 128]) + arg3_1 = None + mm = torch.ops.aten.mm.default(view, arg4_1) + view = arg4_1 = None + view_1 = torch.ops.aten.view.default(mm, [3, 99, 80, 8]) + mm = None + view_2 = torch.ops.aten.view.default(view_1, [3, 99, 80, 8]) + view_1 = None + permute = torch.ops.aten.permute.default(view_2, [0, 3, 1, 2]) + view_2 = None + view_3 = torch.ops.aten.view.default(permute, [3, 8, 99, 80]) + permute = None + + clone = torch.ops.aten.clone.default( + view_3, memory_format=torch.contiguous_format + ) + view_3 = None + + expand = torch.ops.aten.expand.default(clone, [3, 8, 99, 80]) + clone = None + _scaled_dot_product_efficient_attention = ( + torch.ops.aten._scaled_dot_product_efficient_attention.default( + arg0_1, arg1_1, arg2_1, expand, False + ) + ) + arg0_1 = arg1_1 = arg2_1 = expand = None + getitem = _scaled_dot_product_efficient_attention[0] + _scaled_dot_product_efficient_attention = None + return (getitem,) + + DEVICE = torch.device("cuda:0") + DTYPE = torch.float16 + B = 3 + H = 8 + Q = 99 + K = 80 + D = 32 + C_bias = 128 + + # inputs + query = torch.randn((B, H, Q, D), device=DEVICE, dtype=DTYPE) + key = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) + value = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) + bias = torch.randn((B, Q, K, C_bias), device=DEVICE, dtype=DTYPE) + weights = torch.randn((C_bias, H), device=DEVICE, dtype=DTYPE) + + self.common( + foo, + (query, key, value, bias, weights), + atol=0.02, + rtol=1e4, + ) + def test_where_with_logical_op(self): def fn_and(x, y): return torch.where(torch.logical_and(x, y), 1.0, 0.0) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 8677453a55c5c..0e80d8adb5828 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -261,6 +261,7 @@ def run(*ex, **kwargs): "test_zero_dim_reductions_dynamic_shapes": TestFailure( ("cpu", "cuda"), is_skip=True ), + "test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True), # # The following tests do not support dynamic shapes yet: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index b569da72aba01..ed142c57b1017 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2008,10 +2008,69 @@ def apply_constraint(arg, fx_arg): make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) make_fallback(aten.grid_sampler_2d_backward, require_dense) make_fallback(aten.randperm) -make_fallback(aten._scaled_dot_product_efficient_attention) -make_fallback(aten._scaled_dot_product_efficient_attention_backward) -make_fallback(aten._scaled_dot_product_flash_attention, warn=False) -make_fallback(aten._scaled_dot_product_flash_attention_backward) + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + if not meta_val.is_cuda: + return arg + + stride_order = ir.get_stride_order(meta_val.stride()) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + ALIGNMENT = 16 + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + assert isinstance(arg, TensorBox) + unaligned_input_shape = isinstance(arg.data, ir.SliceView) and not is_aligned( + arg + ) + aligned_input_view = unaligned_input_shape and is_aligned(arg.unwrap_view()) + + # input is padded, requiring_stride_order will unwrap the view and unpad. + # Would be nice to be able to require certain padding from inductor ir, nyi + if aligned_input_view: + return arg + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) + make_fallback(aten.sort) make_fallback(aten.sort.stable) make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) From f839a5627bccc20607099d28ce93b73367365949 Mon Sep 17 00:00:00 2001 From: Lengyue Date: Thu, 26 Oct 2023 20:30:46 +0000 Subject: [PATCH 06/78] Add bf16 support to replicate padding (#112099) Fixes #99433 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112099 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/cuda/ReplicationPadding.cu | 12 ++++++------ .../testing/_internal/common_methods_invocations.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index 24e783fb6c5ce..e65c0e90fe03d 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -268,7 +268,7 @@ void replication_pad2d_backward_out_cuda_template( } gradInput.zero_(); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "replication_pad2d_backward_cuda", [&] { auto gradInput_ = gradInput; @@ -383,7 +383,7 @@ void replication_pad3d_backward_out_cuda_template( } gradInput.zero_(); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "replication_pad3d_backward_cuda", [&] { auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; @@ -437,7 +437,7 @@ TORCH_IMPL_FUNC(replication_pad1d_out_cuda) ( return; } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "replication_pad1d_cuda", [&] { at::Tensor input_ = input; at::Tensor output_ = output; @@ -499,7 +499,7 @@ TORCH_IMPL_FUNC(replication_pad1d_backward_out_cuda) ( } gradInput.zero_(); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "replication_pad1d_backward_cuda", [&] { auto gradInput_ = gradInput; @@ -543,7 +543,7 @@ TORCH_IMPL_FUNC(replication_pad2d_out_cuda) ( // const auto padR = paddingSize[1]; // This padding is ignored here const auto padT = paddingSize[2]; // const auto padB = paddingSize[3]; // This padding is ignored here - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "replication_pad2d_cuda", [&] { at::Tensor input_ = input; at::Tensor output_ = output; @@ -635,7 +635,7 @@ TORCH_IMPL_FUNC(replication_pad3d_out_cuda) ( return; } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "replication_pad3d_cuda", [&] { at::Tensor input_ = input; at::Tensor output_ = output; diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index c7df2f6422cf9..a248e6b5551eb 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12931,7 +12931,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), skips=( # Doesn't have a corresponding aten operator. From f66cc675629fead56304205c24b712adba1401a3 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 26 Oct 2023 10:39:33 -0700 Subject: [PATCH 07/78] [aotinductor] Fix duplicated unbacked symbol declarations (#111823) Summary: For https://github.com/pytorch/pytorch/issues/111711 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111823 Approved by: https://github.com/ezyang, https://github.com/aakhundov --- test/inductor/test_aot_inductor.py | 20 ++++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 10 ++++++++++ torch/_inductor/ir.py | 6 +++--- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 5cc4ad9df5713..670628e97c822 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -933,6 +933,24 @@ def forward(self, x): example_inputs = (torch.randn(8, 4, 4, device=self.device),) self.check_model(Model(), example_inputs) + def test_dup_unbacked_sym_decl(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + abs_1 = torch.ops.aten.abs.default(x) + lt = torch.ops.aten.lt.Scalar(abs_1, 0.001) + eq = torch.ops.aten.eq.Scalar(lt, 0) + index_1 = torch.ops.aten.index.Tensor(x, [eq]) + sin = torch.ops.aten.sin.default(index_1) + index_2 = torch.ops.aten.index.Tensor(x, [eq]) + div_3 = torch.ops.aten.div.Tensor(sin, index_2) + return div_3 + + example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),) + self.check_model(Model(), example_inputs) + class AOTInductorTestABICompatibleCpu(TestCase): device = "cpu" @@ -951,6 +969,7 @@ class AOTInductorTestABICompatibleCpu(TestCase): "test_bmm_multiple_dynamic": TestFailure(("abi_compatible_cpu",)), "test_dynamic_cat": TestFailure(("abi_compatible_cpu",)), "test_dynamic_smem_above_default_limit": TestFailure(("abi_compatible_cpu",)), + "test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cpu",)), "test_foreach_multiple_dynamic": TestFailure(("abi_compatible_cpu",)), # TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally, # NotImplementedError: Cannot access storage of OpaqueTensorImpl @@ -977,6 +996,7 @@ class AOTInductorTestABICompatibleCuda(TestCase): "abi_compatible_cuda", # test_failures, xfail by default, set is_skip=True to skip { + "test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)), "test_normal_functional": TestFailure(("abi_compatible_cuda",)), }, ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 9c844152e3c7c..8bef2105f6266 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -327,6 +327,7 @@ def __init__(self): self.expr_printer = pexpr self.cached_thread_locals = set() self.user_defined_kernel_count = 0 + self.unbacked_symbol_decls = set() self.write_header() self.write_prefix() @@ -1106,6 +1107,15 @@ def codegen_inplace_reuse(self, input_buffer, output_buffer): self.reuses[output_buffer.get_name()] = input_buffer.get_name() self.writeline(ReuseLine(self, input_buffer, output_buffer)) + def codegen_unbacked_symbol_decl(self, symbol): + name = str(symbol) + if name in self.unbacked_symbol_decls: + return name + else: + # When in CppWrapperCodeGen, we should only generate the declaration once + self.unbacked_symbol_decls.add(name) + return self.declare + name + class CppWrapperCodeGen(WrapperCodeGen): """ diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d00ee55b04631..7be3836e0d75e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2718,18 +2718,18 @@ def codegen_unbacked_symbol_defs(self, wrapper): for i, s in enumerate(self.get_size()): if s in symbols_to_define: wrapper.writeline( - f"{wrapper.declare}{s} = {self.get_name()}.size({i}){wrapper.ending}" + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.size({i}){wrapper.ending}" ) symbols_to_define.remove(s) for i, s in enumerate(self.get_stride()): if s in symbols_to_define: wrapper.writeline( - f"{wrapper.declare}{s} = {self.get_name()}.stride({i}){wrapper.ending}" + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.stride({i}){wrapper.ending}" ) symbols_to_define.remove(s) if (s := self.get_offset()) in symbols_to_define: wrapper.writeline( - f"{wrapper.declare}{s} = {self.get_name()}.storage_offset(){wrapper.ending}" + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.storage_offset(){wrapper.ending}" ) symbols_to_define.remove(s) assert ( From 73f36e44fbb2298cfb537362646d72788f71f510 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 26 Oct 2023 10:39:34 -0700 Subject: [PATCH 08/78] [aotinductor] Add a debug compile flag (#112021) Summary: When the debug compile flag is specified, model.so is compiled with "-O0 -g". Pull Request resolved: https://github.com/pytorch/pytorch/pull/112021 Approved by: https://github.com/chenyang78 ghstack dependencies: #111823 --- torch/_inductor/codecache.py | 4 +++- torch/_inductor/config.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 3e8b98436269f..9193781b8ae35 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -990,7 +990,9 @@ def cpp_wrapper_flags() -> str: def optimization_flags() -> str: - base_flags = "-O3 -DNDEBUG -ffast-math -fno-finite-math-only" + base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG" + base_flags += " -ffast-math -fno-finite-math-only" + if config.is_fbcode(): # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies. # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths. diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 287692589d52c..cad5cc9923038 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -484,6 +484,8 @@ class aot_inductor: # If not specified, a temp directory will be created under the default caching path output_path = "" + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + # Wether to codegen abi compatible model.so abi_compatible = is_fbcode() From 27cf49549a35dd78475098b7de02c0a5ab1367ea Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Thu, 26 Oct 2023 21:12:59 +0000 Subject: [PATCH 09/78] [dynamo] `ExecutorchCallDelegateHigherOrderVariable` - add sanity check that input and output tensors are disjoint (#111960) Fixes https://github.com/pytorch/pytorch/issues/111917 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111960 Approved by: https://github.com/zou3519 --- torch/_dynamo/variables/higher_order_ops.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 77833bff5a59d..776561e1eb84b 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -82,6 +82,18 @@ def only_consist_of(var, types): return False +def _assert_tensors_nonaliasing(inputs, outputs): + input_tensor_ids = set( + pytree.tree_flatten_only(torch.Tensor, lambda t: id(t), inputs) + ) + output_tensor_ids = set( + pytree.tree_flatten_only(torch.Tensor, lambda t: id(t), outputs) + ) + assert input_tensor_ids.isdisjoint( + output_tensor_ids + ), "inputs to function body cannot alias outputs" + + def validate_args_and_maybe_create_graph_inputs( sub_args, tracer, tx, manually_set_subgraph_inputs ): @@ -708,7 +720,14 @@ def call_function( real_sub_args = pytree.tree_map_only( torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args ) + example_res = lowered_module.original_module(*real_sub_args) + + # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: + # executorch modules promise not to alias inputs and outputs. + # Thus, output FakeTensors will correctly not alias input FakeTensors. + _assert_tensors_nonaliasing(real_sub_args, example_res) + example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) p_args = (lowered_node,) + p_args From d91a18c4335d36c46da382e2301fec57b5cd283a Mon Sep 17 00:00:00 2001 From: rzou Date: Thu, 26 Oct 2023 07:25:00 -0700 Subject: [PATCH 10/78] Grandfather in torchgen'ed aten ops to torch.Tag.pt2_compliant_tag (#112053) In torchgen, we add the pt2_compliant_tag to all aten ops. Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/112053 Approved by: https://github.com/soulitzer --- test/test_custom_ops.py | 4 ++++ torchgen/model.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index e2e64320879a6..7c4ac74145451 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1693,6 +1693,10 @@ def test_define_with_tags(self, tags): self.assertTrue(isinstance(actual, list)) self.assertEqual(actual, list(tags)) + def test_builtin_aten_ops_are_pt2_compliant(self): + for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]: + self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) + def test_define_bad_schema(self): lib = self.lib() with self.assertRaisesRegex(ValueError, "expected schema to look like"): diff --git a/torchgen/model.py b/torchgen/model.py index 3084c2822b4e7..29178a60d20f9 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -649,6 +649,10 @@ def from_yaml( tags_inp = [tags_inp] assert isinstance(tags_inp, list) + # All aten ops generated by torchgen receive the pt2_compliant tag. + if namespace == "aten" and "pt2_compliant_tag" in valid_tags: + tags_inp.append("pt2_compliant_tag") + tags: Set[str] = set() for t in tags_inp: assert len(valid_tags) > 0 From abe172e268b007a0b59cfaa1b19cb0b822eaf52a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 26 Oct 2023 21:29:40 +0000 Subject: [PATCH 11/78] Revert "Cleanup error reporting for ProcessGroupNCCL (#111979)" This reverts commit b29c658265d6b95d8ec77f7052eff4f25190fbfc. Reverted https://github.com/pytorch/pytorch/pull/111979 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing multigpu test in trunk https://hud.pytorch.org/pytorch/pytorch/commit/b29c658265d6b95d8ec77f7052eff4f25190fbfc ([comment](https://github.com/pytorch/pytorch/pull/111979#issuecomment-1781919184)) --- c10/util/Exception.h | 3 - test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 2 +- test/distributed/test_c10d_nccl.py | 38 ++-- .../distributed/c10d/ProcessGroupNCCL.cpp | 191 ++++++++---------- 4 files changed, 105 insertions(+), 129 deletions(-) diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 6b312cafb89a5..735b4fb837fda 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -324,9 +324,6 @@ C10_API std::string GetExceptionString(const std::exception& e); throw ::c10::err_type( \ {__func__, __FILE__, static_cast(__LINE__)}, msg) -#define C10_BUILD_ERROR(err_type, msg) \ - ::c10::err_type({__func__, __FILE__, static_cast(__LINE__)}, msg) - // Private helper macro for workaround MSVC misexpansion of nested macro // invocations involving __VA_ARGS__. See // https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index d396987308b26..b7e944aeb2820 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -224,7 +224,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { // Now run all reduce with errors. pg.set_timedout_error(); work = pg.allreduce(tensors_); - EXPECT_THROW(work->wait(), c10::DistBackendError); + EXPECT_THROW(work->wait(), std::runtime_error); // Communicators might be aborted here, further operations would fail. } diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index e2c5a2221becb..d7071edeef995 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -202,7 +202,7 @@ def tearDown(self): def test_init_no_gpus(self): store = c10d.FileStore(self.file.name, self.world_size) with self.assertRaisesRegex( - ValueError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!" + RuntimeError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!" ): c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -407,7 +407,7 @@ def allreduce(tensors, op): for op, err in zip((c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR), ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR")): with self.assertRaisesRegex( - ValueError, "Cannot use " + err + " with NCCL" + RuntimeError, "Cannot use " + err + " with NCCL" ): allreduce(tensors, op) @@ -524,7 +524,7 @@ def reduce(xs, rootRank, rootTensor, op=None): ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"), ): with self.assertRaisesRegex( - ValueError, "Cannot use " + err + " with NCCL" + RuntimeError, "Cannot use " + err + " with NCCL" ): reduce(tensors, self.rank, rt, op) @@ -610,7 +610,7 @@ def allgather_base(output_t, input_t): # anticipate an error with self.assertRaisesRegex( - ValueError, + RuntimeError, "output tensor size must be equal to world_size times input tensor size", ): tensor = torch.tensor([self.rank]).cuda(local_device_id) @@ -622,7 +622,7 @@ def allgather_base(output_t, input_t): # anticipate an error with self.assertRaisesRegex( - TypeError, "output tensor must have the same type as input tensor" + RuntimeError, "output tensor must have the same type as input tensor" ): tensor = torch.tensor([self.rank], dtype=torch.float).cuda(local_device_id) output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda( @@ -731,7 +731,7 @@ def test_gather_checks(self): for rank in range(self.world_size): output_ts[idx].append(torch.tensor([-1]).cuda(gpu_idx)) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.GatherOptions() opts.rootRank = -1 pg.gather(output_ts, tensors, opts) @@ -739,7 +739,7 @@ def test_gather_checks(self): with self.assertRaisesRegex(TypeError, "incompatible function arguments"): pg.gather(output_ts, tensors, 0) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.GatherOptions() opts.rootRank = self.world_size pg.gather(output_ts, tensors, opts) @@ -753,7 +753,7 @@ def test_gather_checks(self): pg.gather(output_ts, [], opts) with self.assertRaisesRegex( - ValueError, "Tensors must be on distinct GPU devices" + RuntimeError, "Tensors must be on distinct GPU devices" ): # init input tensors2 = [] @@ -866,7 +866,7 @@ def test_scatter_checks(self): for rank in range(self.world_size): scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx)) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.ScatterOptions() opts.rootRank = -1 pg.scatter(tensors, scatter_list, opts) @@ -874,7 +874,7 @@ def test_scatter_checks(self): with self.assertRaisesRegex(TypeError, "incompatible function arguments"): pg.scatter(tensors, scatter_list, 0) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.ScatterOptions() opts.rootRank = self.world_size pg.scatter(tensors, scatter_list, opts) @@ -900,7 +900,7 @@ def reduce_scatter_base(output_t, input_t): # anticipate an error with self.assertRaisesRegex( - ValueError, + RuntimeError, "input tensor must be the same size as output size times world size", ): input_t = torch.tensor([self.rank]).cuda(local_device_id) @@ -912,7 +912,7 @@ def reduce_scatter_base(output_t, input_t): # anticipate an error with self.assertRaisesRegex( - TypeError, "input tensor must be the same type as the output tensor." + RuntimeError, "input tensor must be the same type as the output tensor." ): tensor = torch.tensor([self.rank], dtype=torch.float).cuda(local_device_id) output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda( @@ -1116,7 +1116,7 @@ def test_send_recv(self): # Test with non-contiguous tensors. send_tensor_view = send_tensor.t() if self.rank == 0: - with self.assertRaisesRegex(ValueError, 'Tensors must be contiguous'): + with self.assertRaisesRegex(RuntimeError, 'Tensors must be contiguous'): dist.send(send_tensor_view, 1) @requires_nccl() @@ -1243,13 +1243,13 @@ def test_nccl_propagate_error_reason(self): if self.rank != 0: # Time out due to rank 0 not calling into allreduce. - with self.assertRaises(dist.DistBackendError): + with self.assertRaises(RuntimeError): pg.allreduce([inp]).wait(timedelta(seconds=5)) # Now when nonzero rank attempts to use communicator, original failure reason should be logged.j try: pg.allreduce([torch.ones(2).cuda(self.rank)]).wait() - except dist.DistBackendError as e: + except RuntimeError as e: self.assertTrue("aborted" in str(e)) else: self.fail("Expected error to be raised!") @@ -2783,7 +2783,7 @@ def _test_nccl_errors_blocking(self, func): process_group.allreduce(torch.rand(10).cuda(self.rank)) if self.rank == 0: work = process_group.allreduce(torch.rand(10).cuda(self.rank)) - with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg): + with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): # Operation would time out in blocking mode. work.wait(timeout=timedelta(seconds=self.op_timeout_sec)) # Run some GPU operations to make sure cuda has not gotten stuck. @@ -2852,7 +2852,7 @@ def test_nccl_blocking_wait_with_barrier(self): ) process_group.barrier().wait() if self.rank == 0: - with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg): + with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): # This should timeout process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec)) @@ -2890,7 +2890,7 @@ def test_nccl_timeout(self): if self.rank == 0: # This should timeout in about 1 second. # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out. - with self.assertRaisesRegex(DistBackendError, self.blocking_wait_error_msg): + with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=failed_collective_timeout) # Now do a barrier to tell other rank to go ahead. pg_gloo.barrier().wait() @@ -3093,7 +3093,7 @@ def test_nccl_barrier_timeout(self): store = c10d.FileStore(self.file_name, self.world_size) if self.rank == 0: with self.assertRaisesRegex( - DistBackendError, "Health check failure" + RuntimeError, "Health check failure" ): c10d.init_process_group( backend="nccl", diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index a4f69c6a0b553..564caeda5ea89 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -73,8 +73,7 @@ std::map ncclDataType = { // Helper function that gets the data type and issues error if not supported ncclDataType_t getNcclDataType(at::ScalarType type) { auto it = ncclDataType.find(type); - TORCH_CHECK_WITH( - TypeError, + TORCH_CHECK( it != ncclDataType.end(), "Input tensor data type is not supported for NCCL process group: ", type); @@ -124,8 +123,7 @@ ncclRedOpRAII getNcclReduceOp( } #ifdef NCCL_HAS_AVG if (reduceOp == ReduceOp::AVG) { - C10_THROW_ERROR( - TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); + TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs"); } #endif } @@ -142,38 +140,37 @@ ncclRedOpRAII getNcclReduceOp( return unpackPreMulSum( reduceOp, comm, dev_in_group); default: - C10_THROW_ERROR( - TypeError, "PreMulSum Data type must be half, float, or double"); + TORCH_CHECK( + false, "PreMulSum Data type must be half, float, or double"); ncclRedOp_t unused; return unused; } #else - C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); + TORCH_CHECK(false, "PreMulSum requires NCCL>=2.11.1"); #endif } return ncclOp.at(reduceOp); } catch (const std::out_of_range& e) { switch (reduceOp) { case ReduceOp::AVG: - C10_THROW_ERROR( - ValueError, - c10::str( - "AVG requires NCCL 2.10+. The current version is ", - NCCL_MAJOR, - ".", - NCCL_MINOR)); + TORCH_CHECK( + false, + "AVG requires NCCL 2.10+. The current version is ", + NCCL_MAJOR, + ".", + NCCL_MINOR); break; case ReduceOp::BAND: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); + TORCH_CHECK(false, "Cannot use ReduceOp.BAND with NCCL"); break; case ReduceOp::BOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); + TORCH_CHECK(false, "Cannot use ReduceOp.BOR with NCCL"); break; case ReduceOp::BXOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); + TORCH_CHECK(false, "Cannot use ReduceOp.BXOR with NCCL"); break; default: - C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); + TORCH_CHECK(false, "Unhandled ReduceOp"); break; } } @@ -218,7 +215,7 @@ std::vector getDeviceList(const std::vector& tensors) { // Return CUDA device with ordinal given by input rank. at::Device getDeviceForRank(int rank) { - TORCH_CHECK_WITH(ValueError, rank >= 0, "Invalid rank ", rank); + TORCH_CHECK(rank >= 0, "Invalid rank ", rank); auto numGPUs = at::cuda::getNumGPUs(); int16_t deviceIdx = static_cast(rank % numGPUs); return at::Device(at::DeviceType::CUDA, deviceIdx); @@ -283,8 +280,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); static const uint64_t cur_version = torch::cuda::nccl::version(); if (cur_version < min_version) { - TORCH_CHECK_WITH( - NotImplementedError, + TORCH_CHECK( status == c10::cuda::CaptureStatus::None, "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); } @@ -703,7 +699,7 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout( LOG(ERROR) << exceptionMsg; std::exception_ptr exception_ptr = - std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg)); + std::make_exception_ptr(std::runtime_error(exceptionMsg)); setException(exception_ptr); return true; } @@ -866,8 +862,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( traceKeyEnd_(getTraceEndKey("NCCL", rank)), terminateProcessGroup_(false), uid_(process_group_id++) { - TORCH_CHECK_WITH( - ValueError, + TORCH_CHECK( at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); @@ -1022,8 +1017,7 @@ void ProcessGroupNCCL::runHealthCheck() { } // If there is no exception, the likely culprit is a timeout/hang which is how // most communicator init issues manifest themselves. - TORCH_CHECK_WITH( - DistBackendError, + TORCH_CHECK( healthCheckData.healthCheckSuccess, "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank ", rank_); @@ -1038,13 +1032,11 @@ uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { void ProcessGroupNCCL::registerOnCompletionHook( std::function)>&& hook) { - TORCH_CHECK_WITH( - DistBackendError, + TORCH_CHECK( onCompletionHook_ == nullptr, "ProcessGroupNCCL OnCompletion hook already registered"); - TORCH_CHECK_WITH( - ValueError, + TORCH_CHECK( enableTiming_.load(), "ProcessGroupNCCL OnCompletion hook requires recording start and end " "events which require setting NCCL_ENABLE_TIMING environment variable. " @@ -1184,8 +1176,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() { LOG(ERROR) << exitMsg; // TODO(whc) clean up the rethrow - why is it stored in a class var and // rethrown? - watchDogException_ = - std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + watchDogException_ = std::make_exception_ptr(std::runtime_error(exitMsg)); std::rethrow_exception(watchDogException_); } } catch (...) { @@ -1194,8 +1185,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() { rank_, "] NCCL watchdog thread terminated with exception: unknown"); LOG(ERROR) << exitMsg; - watchDogException_ = - std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + watchDogException_ = std::make_exception_ptr(std::runtime_error(exitMsg)); std::rethrow_exception(watchDogException_); } } @@ -1392,18 +1382,15 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( // commFailureReason is set. auto commFailureReason = ncclComm->getNcclCommFailureReason(); if (commFailureReason != c10::nullopt) { - return std::make_exception_ptr(C10_BUILD_ERROR( - DistBackendError, - c10::str( - "NCCL communicator encountered error set by ProcessGroupNCCL: ", - *commFailureReason))); + return std::make_exception_ptr(std::runtime_error(c10::str( + "NCCL communicator encountered error set by ProcessGroupNCCL: ", + *commFailureReason))); } ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); if (ncclAsyncErr != ncclSuccess) { - return std::make_exception_ptr(C10_BUILD_ERROR( - DistBackendError, + return std::make_exception_ptr(std::runtime_error( "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" + - getNcclErrorDetailStr(ncclAsyncErr))); + getNcclErrorDetailStr(ncclAsyncErr))); } } @@ -1443,10 +1430,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } else { try { auto vec = store_->get(storeKey); - TORCH_CHECK_WITH( - DistBackendError, - vec.size() == NCCL_UNIQUE_ID_BYTES, - "Invalid size for ncclUniqueId"); + TORCH_CHECK(vec.size() == NCCL_UNIQUE_ID_BYTES); std::memcpy(ncclID, vec.data(), vec.size()); } catch (const std::exception& e) { std::string exceptionMsg = c10::str( @@ -1458,13 +1442,13 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( "', but store->get('", storeKey, "') got error: "); - C10_THROW_ERROR( - DistBackendError, + TORCH_CHECK( + false, exceptionMsg + e.what() + ". This may indicate a possible application crash on rank 0 or a network set up issue."); } catch (...) { - C10_THROW_ERROR( - DistBackendError, + TORCH_CHECK( + false, c10::str( "Unknown exception while [", rank_, @@ -1508,8 +1492,8 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( bool isSendRecvSelf) { // Sanity check if (devicesKey.empty()) { - C10_THROW_ERROR( - DistBackendError, + TORCH_CHECK( + false, "Not able to create/get the NCCL Communicator since " "the GPU devices are not known"); } @@ -1675,10 +1659,10 @@ namespace { // Check validity of tensor void check_gpu_single_tensor(const at::Tensor& tensor) { if (!tensor.is_cuda() || tensor.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + TORCH_CHECK(false, "Tensors must be CUDA and dense"); } if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + TORCH_CHECK(false, "Tensors must be contiguous"); } } @@ -1690,11 +1674,11 @@ void check_gpu_single_tensor(const at::Tensor& tensor) { void check_gpu_tensors_different_devices( const std::vector& tensors) { if (tensors.size() == 0) { - C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); + TORCH_CHECK(false, "Tensor list must be nonempty"); } if (tensors.size() > static_cast(at::cuda::getNumGPUs())) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "Tensor list mustn't be larger than the number of available GPUs"); } @@ -1706,23 +1690,23 @@ void check_gpu_tensors_different_devices( for (const auto& t : tensors) { if (!t.is_cuda() || t.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + TORCH_CHECK(false, "Tensors must be CUDA and dense"); } if (t.scalar_type() != first.scalar_type()) { - C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + TORCH_CHECK(false, "Tensors must have identical type"); } if (t.sizes() != first.sizes()) { - C10_THROW_ERROR(ValueError, "Tensors must have identical size"); + TORCH_CHECK(false, "Tensors must have identical size"); } if (t.strides() != first.strides()) { - C10_THROW_ERROR(ValueError, "Tensors must have identical strides"); + TORCH_CHECK(false, "Tensors must have identical strides"); } if (!t.is_contiguous(t.suggest_memory_format())) { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + TORCH_CHECK(false, "Tensors must be contiguous"); } const auto inserted = usedDevices.insert(t.get_device()).second; if (!inserted) { - C10_THROW_ERROR(ValueError, "Tensors must be on distinct GPU devices"); + TORCH_CHECK(false, "Tensors must be on distinct GPU devices"); } } } @@ -1736,7 +1720,7 @@ void check_gpu_tensors_different_devices( // different devices in the same process. int64_t check_gpu_tensors_same_device(const std::vector& tensors) { if (tensors.size() == 0) { - C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); + TORCH_CHECK(false, "Tensor list must be nonempty"); } const auto& first = tensors.front(); @@ -1744,20 +1728,19 @@ int64_t check_gpu_tensors_same_device(const std::vector& tensors) { int64_t total_numel = 0; for (const auto& t : tensors) { if (!t.is_cuda() || t.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + TORCH_CHECK(false, "Tensors must be CUDA and dense"); } if (t.scalar_type() != first.scalar_type()) { - C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + TORCH_CHECK(false, "Tensors must have identical type"); } if (!t.is_non_overlapping_and_dense()) { - C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); + TORCH_CHECK(false, "Tensors must be non-overlapping and dense"); } // If we're in this function, the user called a _coalesced collective // on a set of tensors with potentially different sizes and strides. // Therefore, we don't check for matching sizes and strides, // but we do double-check tensors are on the same device. - TORCH_CHECK_WITH( - ValueError, + TORCH_CHECK( t.get_device() == tensors[0].get_device(), "Expected list of tensors on the same device"); total_numel += t.numel(); @@ -1782,8 +1765,8 @@ std::vector flatten_for_scatter_gather( std::vector& other, size_t world_size) { if (tensor_lists.size() != other.size()) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "Tensor list operands to scatter/gather must have the same length"); } const auto num_devices = tensor_lists.size(); @@ -1793,8 +1776,8 @@ std::vector flatten_for_scatter_gather( for (const auto i : c10::irange(size_t{}, num_devices)) { if (tensor_lists[i].size() != world_size * num_devices) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, c10::str( "Tensor list input to scatter/gather must match number of collective participants ", "but got ", @@ -1810,16 +1793,16 @@ std::vector flatten_for_scatter_gather( // Only check device match for the first tensor in the list; the call to // newLikeFlat() below will check the rest. if (tensor_lists[i].front().get_device() != other[i].get_device()) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "Corresponding input/output tensors to scatter/gather must all reside" " on the same device"); } for (const auto& t : tensor_lists[i]) { if (t.numel() != other[i].numel()) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "All tensor operands to scatter/gather must have the same number of elements"); } } @@ -2440,8 +2423,8 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( return work; #else // If the nccl branch is not "exp" then we just error - C10_THROW_ERROR( - Error, + TORCH_CHECK( + false, "allreduce_sparse is only available in the NCCL experimental branch."); #endif } @@ -2587,8 +2570,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( // @lint-ignore CLANGTIDY auto in_tensor = inputTensors.back(); if (tensor.numel() != in_tensor.numel()) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "Tensor input and output of _broadcast_oop must have the same number of elements "); } RECORD_PARAM_COMMS_DATA( @@ -2694,8 +2677,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( // @lint-ignore CLANGTIDY auto in_tensor = inputTensors.back(); if (tensor.numel() != in_tensor.numel()) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "Tensor input and output of _reduce_oop must have the same number of elements "); } RECORD_PARAM_COMMS_DATA( @@ -2853,9 +2836,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { - C10_THROW_ERROR( - NotImplementedError, - "ProcessGroupNCCL does not support allgather_coalesced"); + TORCH_CHECK(false, "ProcessGroupNCCL does not support allgather_coalesced"); } c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( @@ -3006,13 +2987,13 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( at::Tensor& inputTensor, const ReduceScatterOptions& opts) { if (inputTensor.dtype() != outputTensor.dtype()) { - C10_THROW_ERROR( - TypeError, "input tensor must be the same type as the output tensor."); + TORCH_CHECK( + false, "input tensor must be the same type as the output tensor."); } if (inputTensor.numel() != outputTensor.numel() * size_) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "input tensor must be the same size as output size times world size"); } @@ -3393,8 +3374,8 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { - C10_THROW_ERROR( - NotImplementedError, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } @@ -3402,8 +3383,8 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { - C10_THROW_ERROR( - NotImplementedError, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } @@ -3411,8 +3392,8 @@ c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& /* unused */, int /* unused */, int /* unused */) { - C10_THROW_ERROR( - NotImplementedError, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); } @@ -3420,8 +3401,8 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { - C10_THROW_ERROR( - NotImplementedError, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports recv for NCCL lib version >= 2.7.0"); } #endif @@ -3476,7 +3457,7 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( std::vector& inputTensors, const GatherOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg); + TORCH_CHECK(false, "ProcessGroupNCCL::gather: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); @@ -3562,7 +3543,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( std::vector>& inputTensors, const ScatterOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg); + TORCH_CHECK(false, "ProcessGroupNCCL::scatter: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); @@ -3648,8 +3629,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { - C10_THROW_ERROR( - NotImplementedError, "ProcessGroupNCCL does not support recvAnysource"); + TORCH_CHECK(false, "ProcessGroupNCCL does not support recvAnysource"); } c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( @@ -3660,13 +3640,12 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( check_gpu_single_tensor(output_tensor); if (input_tensor.dtype() != output_tensor.dtype()) { - C10_THROW_ERROR( - TypeError, "output tensor must have the same type as input tensor"); + TORCH_CHECK(false, "output tensor must have the same type as input tensor"); } if (input_tensor.numel() * size_ != output_tensor.numel()) { - C10_THROW_ERROR( - ValueError, + TORCH_CHECK( + false, "output tensor size must be equal to world_size times input tensor size"); } From 190b6e4ba88f6cf00d0bd08d6212a3fe6bb76eaa Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Oct 2023 21:41:22 +0000 Subject: [PATCH 12/78] Make numpy/lib vendored tests dynamo traceable (#112147) Follow up https://github.com/pytorch/pytorch/pull/112146 and #112141 : make numpy/lib vendored tests dynamo traceable Pull Request resolved: https://github.com/pytorch/pytorch/pull/112147 Approved by: https://github.com/lezcano --- pytest.ini | 2 + .../torch_np/numpy_tests/lib/test_arraypad.py | 26 +- .../numpy_tests/lib/test_arraysetops.py | 58 ++-- .../numpy_tests/lib/test_function_base.py | 254 +++++++++++------- .../numpy_tests/lib/test_histograms.py | 73 ++--- .../numpy_tests/lib/test_index_tricks.py | 47 +++- .../numpy_tests/lib/test_shape_base_.py | 76 ++++-- .../numpy_tests/lib/test_twodim_base.py | 102 ++++--- .../numpy_tests/lib/test_type_check.py | 54 ++-- 9 files changed, 460 insertions(+), 232 deletions(-) diff --git a/pytest.ini b/pytest.ini index 67a691290076d..532e3bce098f3 100644 --- a/pytest.ini +++ b/pytest.ini @@ -13,3 +13,5 @@ testpaths = junit_logging_reruns = all filterwarnings = ignore:Module already imported so cannot be rewritten.*hypothesis:pytest.PytestAssertRewriteWarning + +xfail_strict = True diff --git a/test/torch_np/numpy_tests/lib/test_arraypad.py b/test/torch_np/numpy_tests/lib/test_arraypad.py index 54745e8316d51..befa9d76ac467 100644 --- a/test/torch_np/numpy_tests/lib/test_arraypad.py +++ b/test/torch_np/numpy_tests/lib/test_arraypad.py @@ -1,15 +1,27 @@ # Owner(s): ["module: dynamo"] -from unittest import expectedFailure as xfail, skipIf as skipif +from unittest import skipIf as skipif -import torch._numpy as np -from torch._numpy.testing import assert_allclose, assert_array_equal +from torch.testing._internal.common_utils import ( + run_tests, + TEST_WITH_TORCHDYNAMO, + TestCase, + xpassIfTorchDynamo, +) -from torch.testing._internal.common_utils import run_tests, TestCase + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy.testing import assert_allclose, assert_array_equal +else: + import torch._numpy as np + from torch._numpy.testing import assert_allclose, assert_array_equal class TestConstant(TestCase): - @xfail # (reason="tuple values") + @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant(self): a = np.arange(100) a = np.pad(a, (25, 20), "constant", constant_values=(10, 20)) @@ -357,7 +369,7 @@ def test_check_constant_float2(self): ) assert_allclose(test, expected) - @xfail # (reason="tuple values") + @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant_float3(self): a = np.arange(100, dtype=float) a = np.pad(a, (25, 20), "constant", constant_values=(-1.1, -1.2)) @@ -528,7 +540,7 @@ def test_check_constant_odd_pad_amount(self): ) assert_allclose(test, expected) - @xfail # (reason="tuple values") + @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant_pad_2d(self): arr = np.arange(4).reshape(2, 2) test = np.lib.pad( diff --git a/test/torch_np/numpy_tests/lib/test_arraysetops.py b/test/torch_np/numpy_tests/lib/test_arraysetops.py index 0f9773ece6dfa..e046558078591 100644 --- a/test/torch_np/numpy_tests/lib/test_arraysetops.py +++ b/test/torch_np/numpy_tests/lib/test_arraysetops.py @@ -3,24 +3,39 @@ """Test functions for 1D array set operations. """ -from unittest import expectedFailure as xfail +from unittest import skipIf -import torch._numpy as np -from pytest import raises as assert_raises - -from torch._numpy import unique +import numpy -from torch._numpy.testing import assert_array_equal, assert_equal +from pytest import raises as assert_raises from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + subtest, + TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, + xpassIfTorchDynamo, ) -@xfail # (reason="TODO") +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ediff1d, in1d, intersect1d, setdiff1d, setxor1d, union1d, unique + from numpy.testing import assert_array_equal, assert_equal, assert_raises_regex + +else: + import torch._numpy as np + from torch._numpy import unique + from torch._numpy.testing import assert_array_equal, assert_equal + + +@skipIf(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") +@xpassIfTorchDynamo # (reason="TODO") @instantiate_parametrized_tests class TestSetOps(TestCase): def test_intersect1d(self): @@ -145,11 +160,14 @@ def test_ediff1d(self): (np.array([1, 2, 3], dtype=np.int64), None, np.nan, "to_end"), # should fail because attempting # to downcast to int type: - ( - np.array([1, 2, 3], dtype=np.int64), - np.array([5, 7, 2], dtype=np.float32), - None, - "to_begin", + subtest( + ( + np.array([1, 2, 3], dtype=np.int64), + np.array([5, 7, 2], dtype=np.float32), + None, + "to_begin", + ), + decorators=[xfailIfTorchDynamo], ), # should fail because attempting to cast # two special floating point values @@ -205,6 +223,7 @@ def test_ediff1d_scalar_handling(self, ary, prepend, append, expected): assert_equal(actual, expected) assert actual.dtype == expected.dtype + @skipIf(True, reason="NP_VER: fails with NumPy 1.22.x") @parametrize("kind", [None, "sort", "table"]) def test_isin(self, kind): # the tests for in1d cover most of isin's behavior @@ -217,7 +236,7 @@ def _isin_slow(a, b): isin_slow = np.vectorize(_isin_slow, otypes=[bool], excluded={1}) def assert_isin_equal(a, b): - x = isin(a, b, kind=kind) + x = np.isin(a, b, kind=kind) y = isin_slow(a, b) assert_array_equal(x, y) @@ -444,7 +463,7 @@ def test_in1d_table_timedelta_fails(self): a = np.array([0, 1, 2], dtype="timedelta64[s]") b = a # Make sure it raises a value error: - with pytest.raises(ValueError): + with assert_raises(ValueError): in1d(a, b, kind="table") @parametrize( @@ -475,7 +494,7 @@ def test_in1d_mixed_dtype(self, dtype1, dtype2, kind): ) if expect_failure: - with pytest.raises(RuntimeError, match="exceed the maximum"): + with assert_raises(RuntimeError, match="exceed the maximum"): in1d(ar1, ar2, kind=kind) else: assert_array_equal(in1d(ar1, ar2, kind=kind), expected) @@ -744,7 +763,7 @@ def check_all(a, b, i1, i2, c, dt): # assert_equal(a3_idx.dtype, np.intp) # assert_equal(a3_inv.dtype, np.intp) - @xfail # (reason="unique with nans") + @xpassIfTorchDynamo # (reason="unique with nans") def test_unique_1d_2(self): # test for ticket 2111 - float a = [2.0, np.nan, 1.0, np.nan] @@ -790,7 +809,7 @@ def test_unique_axis_list(self): assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg) assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg) - @xfail # _run_axis_tests xfails with the message + @xpassIfTorchDynamo # _run_axis_tests xfails with the message # torch has different unique ordering behaviour" def test_unique_axis(self): types = [] @@ -816,7 +835,7 @@ def test_unique_1d_with_axis(self, axis): uniq = unique(x, axis=axis) assert_array_equal(uniq, [1, 2, 3, 4]) - @xfail # (reason="unique / return_index") + @xpassIfTorchDynamo # (reason="unique / return_index") def test_unique_axis_zeros(self): # issue 15559 single_zero = np.empty(shape=(2, 0), dtype=np.int8) @@ -923,7 +942,8 @@ def _run_axis_tests(self, dtype): msg = "Unique's return_counts=True failed with axis=1" assert_array_equal(cnt, np.array([2, 1, 1]), msg) - @xfail # (reason="unique / return_index / nans") + @skipIf(True, reason="NP_VER: fails on CI with older NumPy") + @xpassIfTorchDynamo # (reason="unique / return_index / nans") def test_unique_nanequals(self): # issue 20326 a = np.array([1, 1, np.nan, np.nan, np.nan]) diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index 3934613a64fc4..a524e9f6528a5 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -11,29 +11,21 @@ import hypothesis import hypothesis.strategies as st -import pytest -import torch._numpy as np +import numpy + +import pytest from hypothesis.extra.numpy import arrays from pytest import raises as assert_raises -from torch._numpy.testing import ( - assert_, - assert_allclose, # IS_PYPY, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - assert_raises_regex, - assert_warns, - suppress_warnings, # HAS_REFCOUNT, IS_WASM -) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, subtest, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) skip = functools.partial(skipif, True) @@ -47,25 +39,79 @@ # from numpy lib import digitize, piecewise, trapz, select, trim_zeros, interp from numpy.lib import delete, extract, insert, msort, place, setxor1d, unwrap, vectorize -from torch._numpy import ( - angle, - bartlett, - blackman, - corrcoef, - cov, - diff, - flipud, - gradient, - hamming, - hanning, - i0, - kaiser, - meshgrid, - sinc, - unique, -) -from torch._numpy._util import normalize_axis_tuple -from torch._numpy.random import rand + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + angle, + bartlett, + blackman, + corrcoef, + cov, + diff, + digitize, + flipud, + gradient, + hamming, + hanning, + i0, + interp, + kaiser, + meshgrid, + sinc, + trapz, + trim_zeros, + unique, + ) + from numpy.core.numeric import normalize_axis_tuple + from numpy.random import rand + + from numpy.testing import ( + assert_, + assert_allclose, # IS_PYPY, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + assert_warns, + suppress_warnings, # HAS_REFCOUNT, IS_WASM + ) +else: + import torch._numpy as np + from torch._numpy import ( + angle, + bartlett, + blackman, + corrcoef, + cov, + diff, + flipud, + gradient, + hamming, + hanning, + i0, + kaiser, + meshgrid, + sinc, + unique, + ) + from torch._numpy._util import normalize_axis_tuple + from torch._numpy.random import rand + + from torch._numpy.testing import ( + assert_, + assert_allclose, # IS_PYPY, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + assert_warns, + suppress_warnings, # HAS_REFCOUNT, IS_WASM + ) def get_mat(n): @@ -251,7 +297,7 @@ def test_basic(self): assert_equal(a[0, 0], 1) assert_equal(a_copy[0, 0], 10) - @xfail # (reason="order='F' not implemented") + @xpassIfTorchDynamo # (reason="order='F' not implemented") def test_order(self): # It turns out that people rely on np.copy() preserving order by # default; changing this broke scikit-learn: @@ -477,7 +523,7 @@ def test_many_arguments(self): select(conditions, choices) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") @instantiate_parametrized_tests class TestInsert(TestCase): def test_basic(self): @@ -795,7 +841,7 @@ def test_append(self): assert_raises(np.AxisError, diff, x, append=0, axis=3) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") @instantiate_parametrized_tests class TestDelete(TestCase): def setUp(self): @@ -867,7 +913,9 @@ def test_index_floats(self): with pytest.raises(IndexError): np.delete([0, 1, 2], np.array([], dtype=float)) - @parametrize("indexer", [np.array([1]), [1]]) + @parametrize( + "indexer", [subtest(np.array([1]), name="array([1])"), subtest([1], name="[1]")] + ) def test_single_item_array(self, indexer): a_del_int = delete(self.a, 1) a_del = delete(self.a, indexer) @@ -1142,7 +1190,7 @@ def test_basic(self): assert_array_almost_equal(z, zo, 11) -@xfail # (reason="trim_zeros not implemented") +@xpassIfTorchDynamo @instantiate_parametrized_tests class TestTrimZeros(TestCase): a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) @@ -1151,7 +1199,11 @@ class TestTrimZeros(TestCase): # d = a.astype(object) def values(self): - attr_names = ("a", "b", "c", "d") + attr_names = ( + "a", + "b", + "c", + ) # "d") return (getattr(self, name) for name in attr_names) def test_basic(self): @@ -1210,7 +1262,7 @@ def test_list_to_list(self): assert isinstance(res, list) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestExtins(TestCase): def test_basic(self): a = np.array([1, 3, 2, 1, 2, 3, 3]) @@ -1612,7 +1664,7 @@ def test_size_zero_output(self): f(x) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestDigitize(TestCase): def test_forward(self): x = np.arange(-6, 5) @@ -1716,7 +1768,9 @@ def test_period(self): @instantiate_parametrized_tests class TestFilterwindows(TestCase): - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_hanning(self, dtype: str, M: int) -> None: scalar = M @@ -1736,7 +1790,9 @@ def test_hanning(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.500, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_hamming(self, dtype: str, M: int) -> None: scalar = M @@ -1756,7 +1812,9 @@ def test_hamming(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.9400, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_bartlett(self, dtype: str, M: int) -> None: scalar = M @@ -1776,7 +1834,9 @@ def test_bartlett(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.4444, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_blackman(self, dtype: str, M: int) -> None: scalar = M @@ -1796,7 +1856,9 @@ def test_blackman(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 3.7800, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_kaiser(self, dtype: str, M: int) -> None: scalar = M @@ -1817,7 +1879,7 @@ def test_kaiser(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 10, 15) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestTrapz(TestCase): def test_simple(self): x = np.arange(-10, 10, 0.1) @@ -1886,13 +1948,13 @@ def test_simple(self): assert_(unique(np.array([1, 1, 1, 1, 1])) == np.array([1])) - @xfail # (reason="unique not implemented for 'ComplexDouble'") + @xpassIfTorchDynamo # (reason="unique not implemented for 'ComplexDouble'") def test_simple_complex(self): x = np.array([5 + 6j, 1 + 1j, 1 + 10j, 10, 5 + 6j]) assert_(np.all(unique(x) == [1 + 1j, 1 + 10j, 5 + 6j, 10])) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestCheckFinite(TestCase): def test_simple(self): a = [1, 2, 3] @@ -2537,7 +2599,19 @@ def test_error_not_1d(self, vals): np.bincount(vals) -@xfail # (reason="TODO: implement") +parametrize_interp_sc = parametrize( + "sc", + [ + subtest(lambda x: np.float_(x), name="real"), + subtest(lambda x: _make_complex(x, 0), name="complex-real"), + subtest(lambda x: _make_complex(0, x), name="complex-imag"), + subtest(lambda x: _make_complex(x, np.multiply(x, -2)), name="complex-both"), + ], +) + + +@xpassIfTorchDynamo # (reason="TODO: implement") +@instantiate_parametrized_tests class TestInterp(TestCase): def test_exceptions(self): assert_raises(ValueError, interp, 0, [], []) @@ -2612,19 +2686,7 @@ def test_non_finite_behavior_exact_x(self): fp = [1, 2, np.nan, 4] assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.nan, np.nan, 4]) - @pytest.fixture( - params=[ - lambda x: np.float_(x), - lambda x: _make_complex(x, 0), - lambda x: _make_complex(0, x), - lambda x: _make_complex(x, np.multiply(x, -2)), - ], - ids=["real", "complex-real", "complex-imag", "complex-both"], - ) - def sc(self, request): - """scale function used by the below tests""" - return request.param - + @parametrize_interp_sc def test_non_finite_any_nan(self, sc): """test that nans are propagated""" assert_equal(np.interp(0.5, [np.nan, 1], sc([0, 10])), sc(np.nan)) @@ -2632,6 +2694,7 @@ def test_non_finite_any_nan(self, sc): assert_equal(np.interp(0.5, [0, 1], sc([np.nan, 10])), sc(np.nan)) assert_equal(np.interp(0.5, [0, 1], sc([0, np.nan])), sc(np.nan)) + @parametrize_interp_sc def test_non_finite_inf(self, sc): """Test that interp between opposite infs gives nan""" assert_equal(np.interp(0.5, [-np.inf, +np.inf], sc([0, 10])), sc(np.nan)) @@ -2641,6 +2704,7 @@ def test_non_finite_inf(self, sc): # unless the y values are equal assert_equal(np.interp(0.5, [-np.inf, +np.inf], sc([10, 10])), sc(10)) + @parametrize_interp_sc def test_non_finite_half_inf_xf(self, sc): """Test that interp where both axes have a bound at inf gives nan""" assert_equal(np.interp(0.5, [-np.inf, 1], sc([-np.inf, 10])), sc(np.nan)) @@ -2652,6 +2716,7 @@ def test_non_finite_half_inf_xf(self, sc): assert_equal(np.interp(0.5, [0, +np.inf], sc([0, -np.inf])), sc(np.nan)) assert_equal(np.interp(0.5, [0, +np.inf], sc([0, +np.inf])), sc(np.nan)) + @parametrize_interp_sc def test_non_finite_half_inf_x(self, sc): """Test interp where the x axis has a bound at inf""" assert_equal(np.interp(0.5, [-np.inf, -np.inf], sc([0, 10])), sc(10)) @@ -2659,6 +2724,7 @@ def test_non_finite_half_inf_x(self, sc): assert_equal(np.interp(0.5, [0, +np.inf], sc([0, 10])), sc(0)) assert_equal(np.interp(0.5, [+np.inf, +np.inf], sc([0, 10])), sc(0)) + @parametrize_interp_sc def test_non_finite_half_inf_f(self, sc): """Test interp where the f axis has a bound at inf""" assert_equal(np.interp(0.5, [0, 1], sc([0, -np.inf])), sc(-np.inf)) @@ -2786,7 +2852,7 @@ def test_2D(self): x = np.array([[1, 1, 1], [1, 1, 1], [4, 4, 3], [1, 1, 1], [1, 1, 1]]) assert_array_equal(np.percentile(x, 50, axis=0), [1, 1, 1]) - @xfail # (reason="TODO: implement") + @xpassIfTorchDynamo # (reason="TODO: implement") @parametrize("dtype", np.typecodes["Float"]) def test_linear_nan_1D(self, dtype): # METHOD 1 of H&F @@ -2796,14 +2862,14 @@ def test_linear_nan_1D(self, dtype): np.testing.assert_equal(res.dtype, arr.dtype) H_F_TYPE_CODES = [ - (int_type, np.float64) for int_type in np.typecodes["AllInteger"] + (int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"] ] + [ (np.float16, np.float16), (np.float32, np.float32), (np.float64, np.float64), ] - @xfail # (reason="TODO: implement percentile interpolations") + @skip(reason="NEP 50 is new in 1.24") @parametrize("input_dtype, expected_dtype", H_F_TYPE_CODES) @parametrize( "method, expected", @@ -2821,7 +2887,11 @@ def test_linear_nan_1D(self, dtype): ) def test_linear_interpolation(self, method, expected, input_dtype, expected_dtype): expected_dtype = np.dtype(expected_dtype) - if np._get_promotion_state() == "legacy": + + if ( + hasattr(np, "_get_promotion_state") + and np._get_promotion_state() == "legacy" + ): expected_dtype = np.promote_types(expected_dtype, np.float64) arr = np.asarray([15.0, 20.0, 35.0, 40.0, 50.0], dtype=input_dtype) @@ -3076,7 +3146,7 @@ def test_percentile_overwrite(self): b = np.percentile([2, 3, 4, 1], [50], overwrite_input=True) assert_equal(b, np.array([2.5])) - @xfail # (reason="pytorch percentile does not support tuple axes.") + @xpassIfTorchDynamo # (reason="pytorch percentile does not support tuple axes.") def test_extended_axis(self): o = np.random.normal(size=(71, 23)) x = np.dstack([o] * 10) @@ -3165,6 +3235,7 @@ def test_keepdims_2(self): np.percentile(d, [1, 7], axis=(0, 3), keepdims=True).shape, (2, 1, 5, 7, 1) ) + @skipif(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") @parametrize( "q", [ @@ -3172,7 +3243,7 @@ def test_keepdims_2(self): subtest( [1, 7], decorators=[ - xfail, + xpassIfTorchDynamo, ], ), ], @@ -3186,13 +3257,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( (-3, -1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), ], @@ -3242,7 +3313,7 @@ def test_out_nan(self): assert_equal(np.percentile(d, 1, out=o), o) assert_equal(np.percentile(d, 1, method="nearest", out=o), o) - @xfail # (reason="np.percentile undocumented nan weirdness") + @xpassIfTorchDynamo # (reason="np.percentile undocumented nan weirdness") def test_nan_behavior(self): a = np.arange(24, dtype=float) a[2] = np.nan @@ -3335,7 +3406,7 @@ def test_basic(self): assert_equal(np.quantile(x, 1), 3.5) assert_equal(np.quantile(x, 0.5), 1.75) - @xfail # (reason="quantile w/integers or bools") + @xpassIfTorchDynamo # (reason="quantile w/integers or bools") def test_correct_quantile_value(self): a = np.array([True]) tf_quant = np.quantile(True, False) @@ -3394,8 +3465,8 @@ def test_no_p_overwrite(self): np.quantile(np.arange(100.0), p, method="midpoint") assert_array_equal(p, p0) - @xfail # (reason="TODO: make quantile preserve integers") - @parametrize("dtype", np.typecodes["AllInteger"]) + @xpassIfTorchDynamo # (reason="TODO: make quantile preserve integers") + @parametrize("dtype", "Bbhil") # np.typecodes["AllInteger"]) def test_quantile_preserve_int_type(self, dtype): res = np.quantile(np.array([1, 2], dtype=dtype), [0.5], method="nearest") assert res.dtype == dtype @@ -3406,50 +3477,50 @@ def test_quantile_preserve_int_type(self, dtype): subtest( "inverted_cdf", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "averaged_inverted_cdf", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "closest_observation", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "interpolated_inverted_cdf", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "hazen", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "weibull", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), "linear", subtest( "median_unbiased", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "normal_unbiased", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), "nearest", @@ -3517,7 +3588,7 @@ def test_basic(self): a = np.array([0.0444502, 0.141249, 0.0463301]) assert_equal(a[-1], np.median(a)) - @xfail # (reason="median: scalar output vs 0-dim") + @xpassIfTorchDynamo # (reason="median: scalar output vs 0-dim") def test_basic_2(self): # check array scalar result a = np.array([0.0444502, 0.141249, 0.0463301]) @@ -3626,7 +3697,7 @@ def test_nan_behavior(self): b[1, 2] = np.nan assert_equal(np.median(a, 1), b) - @xfail # (reason="median: does not support tuple axes") + @xpassIfTorchDynamo # (reason="median: does not support tuple axes") def test_nan_behavior_2(self): a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan @@ -3638,7 +3709,7 @@ def test_nan_behavior_2(self): b[2] = np.nan assert_equal(np.median(a, (0, 2)), b) - @xfail # (reason="median: scalar vs 0-dim") + @xpassIfTorchDynamo # (reason="median: scalar vs 0-dim") def test_nan_behavior_3(self): a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan @@ -3647,7 +3718,7 @@ def test_nan_behavior_3(self): # no axis assert_equal(np.median(a).ndim, 0) - @xfail # (reason="median: torch.quantile does not handle empty tensors") + @xpassIfTorchDynamo # (reason="median: torch.quantile does not handle empty tensors") @skipif(IS_WASM, reason="fp errors don't work correctly") def test_empty(self): # mean(empty array) emits two warnings: empty slice and divide by 0 @@ -3678,7 +3749,7 @@ def test_empty(self): assert_equal(np.median(a, axis=2), b) assert_(w[0].category is RuntimeWarning) - @xfail # (reason="median: tuple axes not implemented") + @xpassIfTorchDynamo # (reason="median: tuple axes not implemented") def test_extended_axis(self): o = np.random.normal(size=(71, 23)) x = np.dstack([o] * 10) @@ -3728,7 +3799,7 @@ def test_keepdims(self): d = np.ones((3, 5, 7, 11)) assert_equal(np.median(d, axis=None, keepdims=True).shape, (1, 1, 1, 1)) - @xfail # (reason="median: tuple axis") + @xpassIfTorchDynamo # (reason="median: tuple axis") def test_keepdims_2(self): d = np.ones((3, 5, 7, 11)) assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape, (1, 1, 7, 11)) @@ -3737,6 +3808,7 @@ def test_keepdims_2(self): assert_equal(np.median(d, axis=(0, 1, 2, 3), keepdims=True).shape, (1, 1, 1, 1)) assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape, (1, 1, 7, 1)) + @skipif(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") @parametrize( "axis", [ @@ -3746,13 +3818,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( (-3, -1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), ], @@ -3772,7 +3844,7 @@ def test_keepdims_out(self, axis): assert_equal(result.shape, shape_out) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") @instantiate_parametrized_tests class TestSortComplex(TestCase): @parametrize( diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index 9d6b0364fc2d2..4b09ef5b207b4 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -3,32 +3,46 @@ # from numpy.testing._private.utils import requires_memory import functools -from unittest import expectedFailure as xfail, skipIf +from unittest import skipIf -import pytest -import torch._numpy as np from pytest import raises as assert_raises -from torch._numpy import histogram, histogramdd - -# from numpy.lib.histograms import histogram, histogramdd, histogram_bin_edges -from torch._numpy.testing import ( - assert_, - assert_allclose, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, -) + +skip = functools.partial(skipIf, True) + from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, slowTest as slow, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) -skip = functools.partial(skipIf, True) +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import histogram, histogram_bin_edges, histogramdd + from numpy.testing import ( + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, + ) +else: + import torch._numpy as np + from torch._numpy import histogram, histogramdd + from torch._numpy.testing import ( + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, + ) class TestHistogram(TestCase): @@ -189,7 +203,7 @@ def test_weights(self): ) assert_almost_equal(a, [0.2, 0.1, 0.1, 0.075]) - @xfail # (reason="histogram complex weights") + @xpassIfTorchDynamo # (reason="histogram complex weights") def test_exotic_weights(self): # Test the use of weights that are not integer or floats, but e.g. # complex numbers or object types. @@ -251,7 +265,7 @@ def test_invalid_range(self): with assert_raises((RuntimeError, ValueError)): np.histogram(vals, range=[0.1, 0.01]) - @xfail # (reason="edge cases") + @xpassIfTorchDynamo # (reason="edge cases") def test_bin_edge_cases(self): # Ensure that floating-point computations correctly place edge cases. arr = np.array([337, 404, 739, 806, 1007, 1811, 2012]) @@ -275,7 +289,7 @@ def test_bin_array_dims(self): with assert_raises((RuntimeError, ValueError)): np.histogram(vals, bins=bins) - @xfail # (reason="no uint64") + @xpassIfTorchDynamo # (reason="no uint64") def test_unsigned_monotonicity_check(self): # Ensures ValueError is raised if bins not increasing monotonically # when bins contain unsigned values (see #9222) @@ -301,7 +315,7 @@ def test_object_array_of_0d(self): np.histogram([np.array(0.5) for i in range(10)] + [0.500000000000001]) np.histogram([np.array(0.5) for i in range(10)] + [0.5]) - @xfail # (reason="bins='auto'") + @xpassIfTorchDynamo # (reason="bins='auto'") def test_some_nan_values(self): # gh-7503 one_nan = np.array([0, 1, np.nan]) @@ -339,7 +353,7 @@ def test_signed_overflow_bounds(self): self.do_signed_overflow_bounds(np.short) self.do_signed_overflow_bounds(np.intc) - @xfail # (reason="int->float conversin loses precision") + @xpassIfTorchDynamo # (reason="int->float conversin loses precision") def test_signed_overflow_bounds_2(self): self.do_signed_overflow_bounds(np.int_) self.do_signed_overflow_bounds(np.longlong) @@ -382,14 +396,14 @@ def do_precision(self, float_small, float_large): self.do_precision_lower_bound(float_small, float_large) self.do_precision_upper_bound(float_small, float_large) - @xfail # (reason="mixed dtypes") + @xpassIfTorchDynamo # (reason="mixed dtypes") def test_precision(self): # not looping results in a useful stack trace upon failure self.do_precision(np.half, np.single) self.do_precision(np.half, np.double) self.do_precision(np.single, np.double) - @xfail # (reason="histogram_bin_edges") + @xpassIfTorchDynamo # (reason="histogram_bin_edges") def test_histogram_bin_edges(self): hist, e = histogram([1, 2, 3, 4], [1, 2]) edges = histogram_bin_edges([1, 2, 3, 4], [1, 2]) @@ -405,7 +419,7 @@ def test_histogram_bin_edges(self): assert_array_equal(edges, e) # @requires_memory(free_bytes=1e10) - @xfail # (reason="pytorch does not support bins = [int, int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") @slow def test_big_arrays(self): sample = np.zeros([100000000, 3]) @@ -416,7 +430,7 @@ def test_big_arrays(self): assert_equal(type(hist), type((1, 2))) -@xfail # (reason="TODO") +@xpassIfTorchDynamo # (reason="TODO") @instantiate_parametrized_tests class TestHistogramOptimBinNums(TestCase): """ @@ -698,7 +712,6 @@ def test_simple_weighted(self): """ Check that weighted data raises a TypeError """ - pytest.xpass(reason="passes by chance") estimator_list = ["fd", "scott", "rice", "sturges", "auto"] for estimator in estimator_list: assert_raises(TypeError, histogram, [1, 2, 3], estimator, weights=[1, 2, 3]) @@ -840,13 +853,13 @@ def test_bins_errors(self): (RuntimeError, ValueError), np.histogramdd, x, bins=[1, 1, 1, [1, 2, 3, -3]] ) - @xfail # (reason="pytorch does not support bins = [int, int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") def test_bins_error_2(self): # mixing scalar (# of bins) and explicit bin arrays, ugh x = np.arange(8).reshape(2, 4) assert_(np.histogramdd(x, bins=[1, 1, 1, [1, 2, 3, 4]])) - @xfail # (reason="pytorch does not support bins = [int, int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") def test_inf_edges(self): # Test using +/-inf bin edges works. See #1788. x = np.arange(6).reshape(3, 2) @@ -897,7 +910,7 @@ def test_finite_range(self): range=[[0.0, 1.0], [np.nan, 0.75], [0.25, 0.5]], ) - @xfail # (reason="pytorch does not allow equal entries") + @xpassIfTorchDynamo # (reason="pytorch does not allow equal entries") def test_equal_edges(self): """Test that adjacent entries in an edge array can be equal""" x = np.array([0, 1, 2]) @@ -928,7 +941,7 @@ def test_edge_dtype(self): def test_large_integers(self): big = 2**60 # Too large to represent with a full precision float - x = np.array([0], np.int64) + x = np.asarray([0], dtype=np.int64) x_edges = np.array([-1, +1], np.int64) y = big + x y_edges = big + x_edges diff --git a/test/torch_np/numpy_tests/lib/test_index_tricks.py b/test/torch_np/numpy_tests/lib/test_index_tricks.py index d3aac7663ec2e..e43e33be03946 100644 --- a/test/torch_np/numpy_tests/lib/test_index_tricks.py +++ b/test/torch_np/numpy_tests/lib/test_index_tricks.py @@ -4,29 +4,52 @@ from unittest import expectedFailure as xfail, skipIf -import torch._numpy as np - from pytest import raises as assert_raises # , assert_raises_regex, -from torch._numpy import diag_indices, diag_indices_from, fill_diagonal, index_exp, s_ -from torch._numpy.testing import ( - assert_, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, -) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) skip = functools.partial(skipIf, True) -@xfail # (reason="unravel_index not implemented") +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import diag_indices, diag_indices_from, fill_diagonal, index_exp, s_ + from numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + ) +else: + import torch._numpy as np + from torch._numpy import ( + diag_indices, + diag_indices_from, + fill_diagonal, + index_exp, + s_, + ) + from torch._numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + ) + + +@xpassIfTorchDynamo # (reason="unravel_index not implemented") @instantiate_parametrized_tests class TestRavelUnravelIndex(TestCase): def test_basic(self): @@ -428,7 +451,7 @@ def test_repeated_input(self): class TestC(TestCase): - @xfail # (reason="c_ not implemented") + @xpassIfTorchDynamo # (reason="c_ not implemented") def test_c_(self): a = np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])] assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]]) diff --git a/test/torch_np/numpy_tests/lib/test_shape_base_.py b/test/torch_np/numpy_tests/lib/test_shape_base_.py index 673d1ed0b537e..20c04f3e1a215 100644 --- a/test/torch_np/numpy_tests/lib/test_shape_base_.py +++ b/test/torch_np/numpy_tests/lib/test_shape_base_.py @@ -5,34 +5,62 @@ from unittest import expectedFailure as xfail, skipIf as skipif -import torch._numpy as np - from pytest import raises as assert_raises -from torch._numpy import ( - array_split, - column_stack, - dsplit, - dstack, - expand_dims, - hsplit, - kron, - put_along_axis, - split, - take_along_axis, - tile, - vsplit, -) -from torch._numpy.random import rand, randint - -from torch._numpy.testing import assert_, assert_array_equal, assert_equal from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, + xpassIfTorchDynamo, ) + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + apply_along_axis, + array_split, + column_stack, + dsplit, + dstack, + expand_dims, + hsplit, + kron, + put_along_axis, + split, + take_along_axis, + tile, + vsplit, + ) + from numpy.random import rand, randint + + from numpy.testing import assert_, assert_array_equal, assert_equal + +else: + import torch._numpy as np + from torch._numpy import ( + array_split, + column_stack, + dsplit, + dstack, + expand_dims, + hsplit, + kron, + put_along_axis, + split, + take_along_axis, + tile, + vsplit, + ) + from torch._numpy.random import rand, randint + from torch._numpy.testing import assert_, assert_array_equal, assert_equal + + skip = functools.partial(skipif, True) @@ -126,7 +154,7 @@ def test_replace_max(self): assert_equal(i_min, i_max) - @xfail # ( + @xpassIfTorchDynamo # ( # reason="RuntimeError: Expected index [1, 2, 5] to be smaller than self [3, 4, 1] apart from dimension 1") def test_broadcast(self): """Test that non-indexing dimensions are broadcast in both directions""" @@ -136,7 +164,7 @@ def test_broadcast(self): assert_equal(take_along_axis(a, ai, axis=1), 20) -@xfail # (reason="apply_along_axis not implemented") +@xpassIfTorchDynamo # (reason="apply_along_axis not implemented") class TestApplyAlongAxis(TestCase): def test_simple(self): a = np.ones((20, 10), "d") @@ -679,6 +707,8 @@ def test_basic(self): assert_equal(res.ndim, 0) assert type(res) is np.ndarray + @xfailIfTorchDynamo + def test_basic_2(self): aa = np.ones((3, 1, 4, 1, 1)) assert aa.squeeze().tensor._base is aa.tensor @@ -712,7 +742,7 @@ def test_squeeze_contiguous(self): assert_(a.flags.f_contiguous) assert_(b.flags.f_contiguous) - @xfail # (reason="XXX: noop in torch, while numpy raises") + @xpassIfTorchDynamo # (reason="XXX: noop in torch, while numpy raises") def test_squeeze_axis_handling(self): with assert_raises(ValueError): np.squeeze(np.array([[1], [2], [3]]), axis=0) @@ -810,7 +840,7 @@ def test_kroncompare(self): assert_equal(large, klarge) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestMayShareMemory(TestCase): def test_basic(self): d = np.ones((50, 60)) diff --git a/test/torch_np/numpy_tests/lib/test_twodim_base.py b/test/torch_np/numpy_tests/lib/test_twodim_base.py index bbf9fd1bbc5cd..dda807b556369 100644 --- a/test/torch_np/numpy_tests/lib/test_twodim_base.py +++ b/test/torch_np/numpy_tests/lib/test_twodim_base.py @@ -8,40 +8,72 @@ from unittest import expectedFailure as xfail, skipIf as skipif import pytest - -import torch._numpy as np from pytest import raises as assert_raises -from torch._numpy import ( - arange, - array, - diag, - eye, - fliplr, - flipud, - histogram2d, - ones, - tri, # mask_indices, - tril_indices, - tril_indices_from, - triu_indices, - triu_indices_from, - vander, - zeros, -) -from torch._numpy.testing import ( - assert_allclose, - assert_array_almost_equal, - assert_array_equal, # assert_array_max_ulp, - assert_equal, -) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + arange, + array, + diag, + eye, + fliplr, + flipud, + histogram2d, + ones, + tri, # mask_indices, + tril_indices, + tril_indices_from, + triu_indices, + triu_indices_from, + vander, + zeros, + ) + from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, # assert_array_max_ulp, + assert_equal, + ) +else: + import torch._numpy as np + from torch._numpy import ( + arange, + array, + diag, + eye, + fliplr, + flipud, + histogram2d, + ones, + tri, # mask_indices, + tril_indices, + tril_indices_from, + triu_indices, + triu_indices_from, + vander, + zeros, + ) + from torch._numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, # assert_array_max_ulp, + assert_equal, + ) + + skip = functools.partial(skipif, True) @@ -101,7 +133,7 @@ def test_eye_bounds(self): def test_bool(self): assert_equal(eye(2, 2, dtype=bool), [[True, False], [False, True]]) - @xfail # (reason="TODO: implement order=non-default") + @xpassIfTorchDynamo # (reason="TODO: implement order=non-default") def test_order(self): mat_c = eye(4, 3, k=-1) mat_f = eye(4, 3, k=-1, order="F") @@ -127,9 +159,10 @@ def test_vector(self): assert_equal(diag(vals, k=2), b) assert_equal(diag(vals, k=-2), c) - def test_matrix(self, vals=None): - if vals is None: - vals = (100 * get_mat(5) + 1).astype("l") + def test_matrix(self): + self.check_matrix(vals=(100 * get_mat(5) + 1).astype("l")) + + def check_matrix(self, vals): b = zeros((5,)) for k in range(5): b[k] = vals[k, k] @@ -142,10 +175,10 @@ def test_matrix(self, vals=None): b[k] = vals[k + 2, k] assert_equal(diag(vals, -2), b[:3]) - @xfail # (reason="TODO implement orders") + @xpassIfTorchDynamo # (reason="TODO implement orders") def test_fortran_order(self): vals = array((100 * get_mat(5) + 1), order="F", dtype="l") - self.test_matrix(vals) + self.check_matrix(vals) def test_diag_bounds(self): A = [[1, 2], [3, 4], [5, 6]] @@ -251,7 +284,7 @@ def test_empty(self): # assert_array_max_ulp(a, np.zeros((4, 4))) assert_allclose(a, np.zeros((4, 4)), atol=1e-15) - @xfail # (reason="pytorch does not support bins = [int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, array]") def test_binparameter_combination(self): x = array([0, 0.09207008, 0.64575234, 0.12875982, 0.47390599, 0.59944483, 1]) y = array([0, 0.14344267, 0.48988575, 0.30558665, 0.44700682, 0.15886423, 1]) @@ -285,6 +318,7 @@ def test_binparameter_combination(self): assert_array_equal(H, answer) assert_array_equal(xe, array([0.0, 0.25, 0.5, 0.75, 1])) + @skip(reason="NP_VER: fails on CI with older NumPy") @parametrize("x_len, y_len", [(10, 11), (20, 19)]) def test_bad_length(self, x_len, y_len): x, y = np.ones(x_len), np.ones(y_len) @@ -368,7 +402,7 @@ def test_mask_indices(self): iu1 = mask_indices(3, np.triu, 1) assert_array_equal(a[iu1], array([1, 2, 5])) - @xfail # (reason="np.tril_indices == our tuple(tril_indices)") + @xpassIfTorchDynamo # (reason="np.tril_indices == our tuple(tril_indices)") def test_tril_indices(self): # indices without and with offset il1 = tril_indices(4) @@ -428,7 +462,7 @@ def test_tril_indices(self): ) -@xfail # (reason="np.triu_indices == our tuple(triu_indices)") +@xpassIfTorchDynamo # (reason="np.triu_indices == our tuple(triu_indices)") class TestTriuIndices(TestCase): def test_triu_indices(self): iu1 = triu_indices(4) diff --git a/test/torch_np/numpy_tests/lib/test_type_check.py b/test/torch_np/numpy_tests/lib/test_type_check.py index 0afa518edb228..96c0ddbc9672b 100644 --- a/test/torch_np/numpy_tests/lib/test_type_check.py +++ b/test/torch_np/numpy_tests/lib/test_type_check.py @@ -5,22 +5,44 @@ from unittest import expectedFailure as xfail, skipIf as skipif -import torch._numpy as np from pytest import raises as assert_raises - -from torch._numpy import ( - common_type, - iscomplex, - iscomplexobj, - isneginf, - isposinf, - isreal, - isrealobj, - nan_to_num, - real_if_close, +from torch.testing._internal.common_utils import ( + run_tests, + TEST_WITH_TORCHDYNAMO, + TestCase, + xpassIfTorchDynamo, ) -from torch._numpy.testing import assert_, assert_array_equal, assert_equal -from torch.testing._internal.common_utils import run_tests, TestCase + + +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + common_type, + iscomplex, + iscomplexobj, + isneginf, + isposinf, + isreal, + isrealobj, + nan_to_num, + real_if_close, + ) + from numpy.testing import assert_, assert_array_equal, assert_equal +else: + import torch._numpy as np + from torch._numpy import ( + common_type, + iscomplex, + iscomplexobj, + isneginf, + isposinf, + isreal, + isrealobj, + nan_to_num, + real_if_close, + ) + from torch._numpy.testing import assert_, assert_array_equal, assert_equal + skip = functools.partial(skipif, True) @@ -29,7 +51,7 @@ def assert_all(x): assert_(np.all(x), x) -@xfail # (reason="common_type not implemented") +@xpassIfTorchDynamo # (reason="common_type not implemented") class TestCommonType(TestCase): def test_basic(self): ai32 = np.array([[1, 2], [3, 4]], dtype=np.int32) @@ -96,7 +118,7 @@ def test_default_3(self): assert_equal(mintypecode("idD"), "D") -@xfail # (reason="TODO: decide on if [1] is a scalar or not") +@xpassIfTorchDynamo # (reason="TODO: decide on if [1] is a scalar or not") class TestIsscalar(TestCase): def test_basic(self): assert_(np.isscalar(3)) From e660bd142267d5b608afe32d9793a860a38f1a91 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 26 Oct 2023 22:16:33 +0000 Subject: [PATCH 13/78] Re-enable some embedded bag tests (#111712) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit They were temporary disabled in 2019 by https://github.com/pytorch/pytorch/pull/26599 As suggested, increased relative tolerance from 0 to 2% when tests are using float16 dtype ### 🤖 Generated by Copilot at 1e49d84 > _`TestEmbeddingNN`_ > _CUDA tests restored_ > _Bug fixed in autumn breeze_ Pull Request resolved: https://github.com/pytorch/pytorch/pull/111712 Approved by: https://github.com/huydhn --- test/nn/test_embedding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index 5d876a1556996..d0bf1a63b67a5 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -1009,16 +1009,17 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, # We have more floating point error here because we are dealing with larger numbers if backward_prec is None: needed_prec = dtype2prec_DONTUSE[wdtype] * 5 + rtol = 0.02 if wdtype == torch.half else 0 else: needed_prec = backward_prec + rtol = 0 - self.assertEqual(es_weight_grad, e.weight.grad, atol=needed_prec, rtol=0) + self.assertEqual(es_weight_grad, e.weight.grad, atol=needed_prec, rtol=rtol) if test_per_sample_weights and trainable_per_sample_weights: self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad, atol=dtype2prec_DONTUSE[wdtype], rtol=0) - @skipCUDAIf(True, "Temporarily disabled. See t54369166") @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.half, torch.float, torch.double))) @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes): From 73cc5d1cdda118007ccdb0be8d775ba76726596e Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Thu, 26 Oct 2023 11:07:05 -0700 Subject: [PATCH 14/78] [inductor] benchmark fusion (#108193) Pull Request resolved: https://github.com/pytorch/pytorch/pull/108193 Approved by: https://github.com/jansel --- test/inductor/test_benchmark_fusion.py | 138 +++++++++++++++++++++++++ torch/_inductor/codegen/common.py | 21 ++-- torch/_inductor/codegen/cpp.py | 1 + torch/_inductor/codegen/triton.py | 82 ++++++++++++++- torch/_inductor/config.py | 1 + torch/_inductor/scheduler.py | 105 ++++++++++++++++++- torch/_inductor/virtualized.py | 17 ++- 7 files changed, 354 insertions(+), 11 deletions(-) create mode 100644 test/inductor/test_benchmark_fusion.py diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py new file mode 100644 index 0000000000000..bc43100fd4bac --- /dev/null +++ b/test/inductor/test_benchmark_fusion.py @@ -0,0 +1,138 @@ +# Owner(s): ["module: inductor"] +import math +import os +import sys + +import torch +from torch.testing._internal.common_utils import ( + IS_CI, + IS_WINDOWS, + skipIfRocm, + TEST_WITH_ASAN, + TestCase as TorchTestCase, +) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) + +import contextlib +import unittest + +from torch._inductor import config +from torch._inductor.scheduler import Scheduler + +if IS_WINDOWS and IS_CI: + sys.stderr.write( + "Windows CI does not have necessary dependencies for test_torchinductor yet\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + +from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests + + +class TestCase(TorchTestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._stack = contextlib.ExitStack() + cls._stack.enter_context( + config.patch( + { + "benchmark_kernel": True, + "benchmark_fusion": True, + } + ) + ) + + @classmethod + def tearDownClass(cls): + cls._stack.close() + super().tearDownClass() + + +class BenchmarkFusionTestTemplate: + def test_softmax(self): + def f(x): + return torch.nn.functional.softmax(x, dim=-1) + + self.common(f, (torch.rand(2, 8192),)) + + @skipIfRocm # fail accuracy check on ROCm + def test_resnet18(self): + import torchvision + + model = torchvision.models.resnet18() + model.eval() + batch_size = 16 + inputs = (torch.randn((batch_size, 3, 224, 224)),) + self.common(model, inputs, atol=1e-2, rtol=1e-2) + + def test_register_spills(self): + """ + The test can potentially trigger register spills + """ + old_benchmark_fn = Scheduler.benchmark_fused_nodes + + def new_benchmark_fn(scheduler, nodes): + """ + We override Scheduler.benchmark_fused_nodes to return latency 1.0 + if there are no register spills. Without this, we may not able to + test the code path handling register spilling because before register + start spilling, the related fusion may have already been skipped + due to longer lantency. + """ + ms = old_benchmark_fn(scheduler, nodes) + if not math.isinf(ms): + ms = 1.0 + return ms + + # Disable dynamic_scale_rblock to make it easier to trigger register + # spilling. + with unittest.mock.patch.object( + Scheduler, "benchmark_fused_nodes", new_benchmark_fn + ), config.patch("dynamic_scale_rblock", False): + S = 512 + + def f(*inputs): + inputs = list(inputs) + outputs = [] + out = torch.zeros(S, device=self.device) + for x in inputs: + x = x * 2 + x = x + 1 + x = x.sum(dim=-1) + outputs.append(x) + out = out + x + return outputs, out + + N = int(os.environ.get("NINP", "30")) + inputs = [torch.randn(S, 2560, device=self.device) for _ in range(N)] + opt_f = torch.compile(f) + opt_f(*inputs) + + +if HAS_CUDA and not TEST_WITH_ASAN: + + class BenchmarkFusionCudaTest(TestCase): + common = check_model_cuda + device = "cuda" + + copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda") + +if HAS_CPU and not torch.backends.mps.is_available(): + + class BenchmarkFusionCpuTest(TestCase): + common = check_model + device = "cpu" + + copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCpuTest, "cpu") + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + if HAS_CPU or HAS_CUDA: + run_tests() diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 9ab92e22146a7..f0ca422461517 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -438,14 +438,14 @@ def __init__(self, name, line): self.name = name def __call__(self): - # V.kernel may be null since this method may be called for the - # wrapper codegen where there is no specific kernel. - if ( - self.name - not in ( - V.graph.removed_buffers | getattr(V.kernel, "removed_buffers", set()) + if all( + self.name not in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, ) - and self.name not in V.graph.inplaced_to_remove ): return self.line return None @@ -647,7 +647,10 @@ def aliases(self): if self._buffer_is_marked_removed(inplaced): continue for other in inplaced.other_names: - if other in V.graph.inplaced_to_remove: + if ( + other in V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): continue if other in self.input_buffers: yield self.input_buffers[other], inplaced.inner_name @@ -888,6 +891,8 @@ def __init__(self, args=None, increase_kernel_count=True): self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {} self.removed_buffers = set() + self.inplaced_to_remove = set() + # key: the buffer to write # value: the buffer to read and whose memory can be reused for # the buffer specified by key diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 7a1aedf6413a7..ccb242cfb73b2 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2661,6 +2661,7 @@ def run(kernel): scalar_kernel = codegen_kernel(CppKernel) V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove self.loop_nest = LoopNestWithSplit.build(scalar_kernel) if not self.picked_vec_isa: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index c3d05eb3e3501..f486e63f1cd79 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -8,6 +8,7 @@ import logging import math import operator +import os from typing import Any, Counter, Dict, Iterable, List, Optional, Set, Tuple import sympy @@ -20,13 +21,14 @@ from torch.utils._sympy.value_ranges import ValueRanges from ..._dynamo.utils import counters from .. import config, ir, scheduler -from ..codecache import code_hash, get_path +from ..codecache import code_hash, get_path, PyCodeCache from ..dependencies import MemoryDep, StarDep from ..ir import IRNode, ReductionHint, TritonTemplateBuffer from ..optimize_indexing import indexing_dtype_strength_reduction from ..scheduler import BaseScheduling from ..triton_heuristics import AutotuneHint from ..utils import ( + do_bench, get_fused_kernel_name, get_kernel_metadata, green_text, @@ -2521,6 +2523,7 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel): self.codegen_comment(node_schedule) kernel.call_kernel(kernel_name) V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove if config.warn_mix_layout: kernel.warn_mix_layout(kernel_name) @@ -2640,6 +2643,7 @@ def codegen_template(self, template_node, epilogue_nodes): self.codegen_comment(node_schedule) kernel.call_kernel(kernel_name, template_node.node) V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove self.scheduler.free_buffers() def codegen_sync(self): @@ -2677,6 +2681,7 @@ def codegen_foreach(self, foreach_node): if node not in (EnableReduction, DisableReduction): node.mark_run() V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, [foreach_node]) @@ -2825,6 +2830,81 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): def flush(self): pass + def benchmark_fused_nodes(self, nodes): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiled_groups = self.select_tiling(node_schedule, numel, rnumel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, rnumel + ) + + kernel = TritonKernel( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. + for n in nodes: + n.last_usage = set() + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel): # type: ignore[attr-defined] + src_code = kernel.codegen_kernel() + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + mod = PyCodeCache.load(src_code) + + def cache_file_path(): + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" # type: ignore[type-var,operator] + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return float(fd.read()) + return None + + def store_cache(): + path = cache_file_path() + with open(path, "w") as fd: + fd.write(str(ms)) + + log.debug( + "kernel src code for %s written to: %s", + {n.get_name() for n in nodes}, + mod.__file__, + ) + ms = load_cache() + if ms is not None: + return ms + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + call(wrapped_jit_function.clone_args(*args)) + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. + ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args))) + + log.debug( + "The fused kernel for %s took %.3f ms to run", + {n.get_name() for n in nodes}, + ms, + ) + store_cache() + return ms + @dataclasses.dataclass class CandidateTiling: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index cad5cc9923038..ec87843e097c5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -198,6 +198,7 @@ # For each fused kernel in the wrapper, comment with the nodes that get fused. # Useful for debugging fusion. debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" # how many nodes to allow into a single fusion max_fusion_size = 64 diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 1ed30aeb83837..40762a5dce23a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3,6 +3,7 @@ import functools import itertools import logging +import math import os import pprint import textwrap @@ -28,6 +29,8 @@ get_device_tflops, get_dtype_size, get_gpu_dram_gbps, + green_text, + red_text, sympy_product, ) from .virtualized import V @@ -1495,6 +1498,97 @@ def fuse_nodes(self): fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) break + def benchmark_fused_nodes(self, nodes): + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + assert len(nodes) > 0 + device = nodes[0].get_device() + V.graph.scheduler = self + self.current_device = device + backend = self.get_backend(device) + return backend.benchmark_fused_nodes(nodes) + + def speedup_by_fusion(self, node1, node2): + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. + """ + if not config.benchmark_fusion: + return True + + if node1.is_template(): + # TODO support benchmarking epilogue fusion + return True + + node_list_1 = node1.get_nodes() + device = node_list_1[0].get_device() + + # don't support benchmark fusion for CPU right now. + if device.type == "cpu": + return True + + node_list_2 = node2.get_nodes() + node_list_fused = node_list_1 + node_list_2 + + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + # Skip benchmarking them by allowing fusion. + if any( + hasattr(n.node, "data") + and hasattr(n.node.data, "scatter_mode") + and n.node.data.scatter_mode == "atomic_add" + for n in node_list_fused + ): + return True + + from triton.compiler.errors import CompilationError + + try: + ms1 = self.benchmark_fused_nodes(node_list_1) + if math.isinf(ms1): + log.debug( + "Skip fusion because of register spilling of the first kernel" + ) + return False + ms2 = self.benchmark_fused_nodes(node_list_2) + if math.isinf(ms2): + log.debug( + "Skip fusion because of register spilling of the second kernel" + ) + return False + ms_fused = self.benchmark_fused_nodes(node_list_fused) + if math.isinf(ms_fused): + log.debug( + "Skip fusion because of register spilling of the fused kernel" + ) + return False + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + return True # allow fusion + else: + raise + + if log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + log.debug( + "Fusing %s with %s cause %sx speedup", + node1.get_names(), + node2.get_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + log.debug( + "Fusing %s with %s cause %sx slowdown", + node1.get_names(), + node2.get_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + return ms_fused < ms1 + ms2 + def fuse_nodes_once(self): """ Mutates self.nodes to combine nodes into FusedSchedulerNodes. @@ -1510,6 +1604,8 @@ def fuse_nodes_once(self): if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( node1, node2 ): + if not self.speedup_by_fusion(node1, node2): + continue node3 = fuse(node1, node2) fused_nodes.remove(node1) fused_nodes.remove(node2) @@ -1886,7 +1982,7 @@ def remove_filter(n): remove = all(n in names_to_remove for n in buf.other_names) if remove: self.remove_inplace_buffer(name) - V.graph.inplaced_to_remove.add(name) + V.kernel.inplaced_to_remove.add(name) else: self.remove_buffer(name) @@ -2088,3 +2184,10 @@ def flush(self): Flush the generated kernel and python wrapper code to the source code file. """ raise NotImplementedError() + + def benchmark_fused_nodes(self, nodes): + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError() diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 593e949538293..6e981f9225fa9 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -56,6 +56,21 @@ class NullHandler: pass +class NullKernelHandler(NullHandler): + """ + We need access `V.kernel.removed_buffers` in DeferredLine class when there + is no kernel in the context. This happens when codegening the wrapper. + Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't + need call 'getattr' with default value which is error prone to typo in + attribute name. + """ + + def __init__(self): + super().__init__() + self.removed_buffers = set() + self.inplaced_to_remove = set() + + def _arg_str(a) -> str: if isinstance(a, sympy.Expr): return sympy_str(a) @@ -169,7 +184,7 @@ def __getattr__(self, item): _graph = Virtualized("graph", NullHandler) _real_inputs = Virtualized("real_inputs", NullHandler) _fake_mode = Virtualized("fake_mode", NullHandler) -_kernel = Virtualized("kernel", NullHandler) +_kernel = Virtualized("kernel", NullKernelHandler) _debug = Virtualized("debug", NullHandler) _interpreter = Virtualized("interpreter", NullHandler) _aot_compilation = Virtualized("aot_compilation", NullHandler) From 4a94f77c8ec9346e05862cdf0d552e77f34d0e79 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 26 Oct 2023 23:23:49 +0000 Subject: [PATCH 15/78] Revert "Make numpy/lib vendored tests dynamo traceable (#112147)" This reverts commit 190b6e4ba88f6cf00d0bd08d6212a3fe6bb76eaa. Reverted https://github.com/pytorch/pytorch/pull/112147 on behalf of https://github.com/huydhn due to Sorry for reverting this again, but this is failing in trunk https://hud.pytorch.org/pytorch/pytorch/commit/190b6e4ba88f6cf00d0bd08d6212a3fe6bb76eaa ([comment](https://github.com/pytorch/pytorch/pull/112147#issuecomment-1782056995)) --- pytest.ini | 2 - .../torch_np/numpy_tests/lib/test_arraypad.py | 26 +- .../numpy_tests/lib/test_arraysetops.py | 58 ++-- .../numpy_tests/lib/test_function_base.py | 254 +++++++----------- .../numpy_tests/lib/test_histograms.py | 73 +++-- .../numpy_tests/lib/test_index_tricks.py | 47 +--- .../numpy_tests/lib/test_shape_base_.py | 76 ++---- .../numpy_tests/lib/test_twodim_base.py | 102 +++---- .../numpy_tests/lib/test_type_check.py | 54 ++-- 9 files changed, 232 insertions(+), 460 deletions(-) diff --git a/pytest.ini b/pytest.ini index 532e3bce098f3..67a691290076d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -13,5 +13,3 @@ testpaths = junit_logging_reruns = all filterwarnings = ignore:Module already imported so cannot be rewritten.*hypothesis:pytest.PytestAssertRewriteWarning - -xfail_strict = True diff --git a/test/torch_np/numpy_tests/lib/test_arraypad.py b/test/torch_np/numpy_tests/lib/test_arraypad.py index befa9d76ac467..54745e8316d51 100644 --- a/test/torch_np/numpy_tests/lib/test_arraypad.py +++ b/test/torch_np/numpy_tests/lib/test_arraypad.py @@ -1,27 +1,15 @@ # Owner(s): ["module: dynamo"] -from unittest import skipIf as skipif +from unittest import expectedFailure as xfail, skipIf as skipif -from torch.testing._internal.common_utils import ( - run_tests, - TEST_WITH_TORCHDYNAMO, - TestCase, - xpassIfTorchDynamo, -) +import torch._numpy as np +from torch._numpy.testing import assert_allclose, assert_array_equal - -# If we are going to trace through these, we should use NumPy -# If testing on eager mode, we use torch._numpy -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy.testing import assert_allclose, assert_array_equal -else: - import torch._numpy as np - from torch._numpy.testing import assert_allclose, assert_array_equal +from torch.testing._internal.common_utils import run_tests, TestCase class TestConstant(TestCase): - @xpassIfTorchDynamo # (reason="tuple values") + @xfail # (reason="tuple values") def test_check_constant(self): a = np.arange(100) a = np.pad(a, (25, 20), "constant", constant_values=(10, 20)) @@ -369,7 +357,7 @@ def test_check_constant_float2(self): ) assert_allclose(test, expected) - @xpassIfTorchDynamo # (reason="tuple values") + @xfail # (reason="tuple values") def test_check_constant_float3(self): a = np.arange(100, dtype=float) a = np.pad(a, (25, 20), "constant", constant_values=(-1.1, -1.2)) @@ -540,7 +528,7 @@ def test_check_constant_odd_pad_amount(self): ) assert_allclose(test, expected) - @xpassIfTorchDynamo # (reason="tuple values") + @xfail # (reason="tuple values") def test_check_constant_pad_2d(self): arr = np.arange(4).reshape(2, 2) test = np.lib.pad( diff --git a/test/torch_np/numpy_tests/lib/test_arraysetops.py b/test/torch_np/numpy_tests/lib/test_arraysetops.py index e046558078591..0f9773ece6dfa 100644 --- a/test/torch_np/numpy_tests/lib/test_arraysetops.py +++ b/test/torch_np/numpy_tests/lib/test_arraysetops.py @@ -3,39 +3,24 @@ """Test functions for 1D array set operations. """ -from unittest import skipIf - -import numpy +from unittest import expectedFailure as xfail +import torch._numpy as np from pytest import raises as assert_raises +from torch._numpy import unique + +from torch._numpy.testing import assert_array_equal, assert_equal + from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, - subtest, - TEST_WITH_TORCHDYNAMO, TestCase, - xfailIfTorchDynamo, - xpassIfTorchDynamo, ) -# If we are going to trace through these, we should use NumPy -# If testing on eager mode, we use torch._numpy -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import ediff1d, in1d, intersect1d, setdiff1d, setxor1d, union1d, unique - from numpy.testing import assert_array_equal, assert_equal, assert_raises_regex - -else: - import torch._numpy as np - from torch._numpy import unique - from torch._numpy.testing import assert_array_equal, assert_equal - - -@skipIf(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") -@xpassIfTorchDynamo # (reason="TODO") +@xfail # (reason="TODO") @instantiate_parametrized_tests class TestSetOps(TestCase): def test_intersect1d(self): @@ -160,14 +145,11 @@ def test_ediff1d(self): (np.array([1, 2, 3], dtype=np.int64), None, np.nan, "to_end"), # should fail because attempting # to downcast to int type: - subtest( - ( - np.array([1, 2, 3], dtype=np.int64), - np.array([5, 7, 2], dtype=np.float32), - None, - "to_begin", - ), - decorators=[xfailIfTorchDynamo], + ( + np.array([1, 2, 3], dtype=np.int64), + np.array([5, 7, 2], dtype=np.float32), + None, + "to_begin", ), # should fail because attempting to cast # two special floating point values @@ -223,7 +205,6 @@ def test_ediff1d_scalar_handling(self, ary, prepend, append, expected): assert_equal(actual, expected) assert actual.dtype == expected.dtype - @skipIf(True, reason="NP_VER: fails with NumPy 1.22.x") @parametrize("kind", [None, "sort", "table"]) def test_isin(self, kind): # the tests for in1d cover most of isin's behavior @@ -236,7 +217,7 @@ def _isin_slow(a, b): isin_slow = np.vectorize(_isin_slow, otypes=[bool], excluded={1}) def assert_isin_equal(a, b): - x = np.isin(a, b, kind=kind) + x = isin(a, b, kind=kind) y = isin_slow(a, b) assert_array_equal(x, y) @@ -463,7 +444,7 @@ def test_in1d_table_timedelta_fails(self): a = np.array([0, 1, 2], dtype="timedelta64[s]") b = a # Make sure it raises a value error: - with assert_raises(ValueError): + with pytest.raises(ValueError): in1d(a, b, kind="table") @parametrize( @@ -494,7 +475,7 @@ def test_in1d_mixed_dtype(self, dtype1, dtype2, kind): ) if expect_failure: - with assert_raises(RuntimeError, match="exceed the maximum"): + with pytest.raises(RuntimeError, match="exceed the maximum"): in1d(ar1, ar2, kind=kind) else: assert_array_equal(in1d(ar1, ar2, kind=kind), expected) @@ -763,7 +744,7 @@ def check_all(a, b, i1, i2, c, dt): # assert_equal(a3_idx.dtype, np.intp) # assert_equal(a3_inv.dtype, np.intp) - @xpassIfTorchDynamo # (reason="unique with nans") + @xfail # (reason="unique with nans") def test_unique_1d_2(self): # test for ticket 2111 - float a = [2.0, np.nan, 1.0, np.nan] @@ -809,7 +790,7 @@ def test_unique_axis_list(self): assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg) assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg) - @xpassIfTorchDynamo # _run_axis_tests xfails with the message + @xfail # _run_axis_tests xfails with the message # torch has different unique ordering behaviour" def test_unique_axis(self): types = [] @@ -835,7 +816,7 @@ def test_unique_1d_with_axis(self, axis): uniq = unique(x, axis=axis) assert_array_equal(uniq, [1, 2, 3, 4]) - @xpassIfTorchDynamo # (reason="unique / return_index") + @xfail # (reason="unique / return_index") def test_unique_axis_zeros(self): # issue 15559 single_zero = np.empty(shape=(2, 0), dtype=np.int8) @@ -942,8 +923,7 @@ def _run_axis_tests(self, dtype): msg = "Unique's return_counts=True failed with axis=1" assert_array_equal(cnt, np.array([2, 1, 1]), msg) - @skipIf(True, reason="NP_VER: fails on CI with older NumPy") - @xpassIfTorchDynamo # (reason="unique / return_index / nans") + @xfail # (reason="unique / return_index / nans") def test_unique_nanequals(self): # issue 20326 a = np.array([1, 1, np.nan, np.nan, np.nan]) diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index a524e9f6528a5..3934613a64fc4 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -11,21 +11,29 @@ import hypothesis import hypothesis.strategies as st - -import numpy - import pytest + +import torch._numpy as np from hypothesis.extra.numpy import arrays from pytest import raises as assert_raises +from torch._numpy.testing import ( + assert_, + assert_allclose, # IS_PYPY, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + assert_warns, + suppress_warnings, # HAS_REFCOUNT, IS_WASM +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, subtest, - TEST_WITH_TORCHDYNAMO, TestCase, - xpassIfTorchDynamo, ) skip = functools.partial(skipif, True) @@ -39,79 +47,25 @@ # from numpy lib import digitize, piecewise, trapz, select, trim_zeros, interp from numpy.lib import delete, extract, insert, msort, place, setxor1d, unwrap, vectorize - -# If we are going to trace through these, we should use NumPy -# If testing on eager mode, we use torch._numpy -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import ( - angle, - bartlett, - blackman, - corrcoef, - cov, - diff, - digitize, - flipud, - gradient, - hamming, - hanning, - i0, - interp, - kaiser, - meshgrid, - sinc, - trapz, - trim_zeros, - unique, - ) - from numpy.core.numeric import normalize_axis_tuple - from numpy.random import rand - - from numpy.testing import ( - assert_, - assert_allclose, # IS_PYPY, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - assert_raises_regex, - assert_warns, - suppress_warnings, # HAS_REFCOUNT, IS_WASM - ) -else: - import torch._numpy as np - from torch._numpy import ( - angle, - bartlett, - blackman, - corrcoef, - cov, - diff, - flipud, - gradient, - hamming, - hanning, - i0, - kaiser, - meshgrid, - sinc, - unique, - ) - from torch._numpy._util import normalize_axis_tuple - from torch._numpy.random import rand - - from torch._numpy.testing import ( - assert_, - assert_allclose, # IS_PYPY, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - assert_raises_regex, - assert_warns, - suppress_warnings, # HAS_REFCOUNT, IS_WASM - ) +from torch._numpy import ( + angle, + bartlett, + blackman, + corrcoef, + cov, + diff, + flipud, + gradient, + hamming, + hanning, + i0, + kaiser, + meshgrid, + sinc, + unique, +) +from torch._numpy._util import normalize_axis_tuple +from torch._numpy.random import rand def get_mat(n): @@ -297,7 +251,7 @@ def test_basic(self): assert_equal(a[0, 0], 1) assert_equal(a_copy[0, 0], 10) - @xpassIfTorchDynamo # (reason="order='F' not implemented") + @xfail # (reason="order='F' not implemented") def test_order(self): # It turns out that people rely on np.copy() preserving order by # default; changing this broke scikit-learn: @@ -523,7 +477,7 @@ def test_many_arguments(self): select(conditions, choices) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") @instantiate_parametrized_tests class TestInsert(TestCase): def test_basic(self): @@ -841,7 +795,7 @@ def test_append(self): assert_raises(np.AxisError, diff, x, append=0, axis=3) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") @instantiate_parametrized_tests class TestDelete(TestCase): def setUp(self): @@ -913,9 +867,7 @@ def test_index_floats(self): with pytest.raises(IndexError): np.delete([0, 1, 2], np.array([], dtype=float)) - @parametrize( - "indexer", [subtest(np.array([1]), name="array([1])"), subtest([1], name="[1]")] - ) + @parametrize("indexer", [np.array([1]), [1]]) def test_single_item_array(self, indexer): a_del_int = delete(self.a, 1) a_del = delete(self.a, indexer) @@ -1190,7 +1142,7 @@ def test_basic(self): assert_array_almost_equal(z, zo, 11) -@xpassIfTorchDynamo +@xfail # (reason="trim_zeros not implemented") @instantiate_parametrized_tests class TestTrimZeros(TestCase): a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) @@ -1199,11 +1151,7 @@ class TestTrimZeros(TestCase): # d = a.astype(object) def values(self): - attr_names = ( - "a", - "b", - "c", - ) # "d") + attr_names = ("a", "b", "c", "d") return (getattr(self, name) for name in attr_names) def test_basic(self): @@ -1262,7 +1210,7 @@ def test_list_to_list(self): assert isinstance(res, list) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") class TestExtins(TestCase): def test_basic(self): a = np.array([1, 3, 2, 1, 2, 3, 3]) @@ -1664,7 +1612,7 @@ def test_size_zero_output(self): f(x) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") class TestDigitize(TestCase): def test_forward(self): x = np.arange(-6, 5) @@ -1768,9 +1716,7 @@ def test_period(self): @instantiate_parametrized_tests class TestFilterwindows(TestCase): - @parametrize( - "dtype", "Bbhil" + "efd" - ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_hanning(self, dtype: str, M: int) -> None: scalar = M @@ -1790,9 +1736,7 @@ def test_hanning(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.500, 4) - @parametrize( - "dtype", "Bbhil" + "efd" - ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_hamming(self, dtype: str, M: int) -> None: scalar = M @@ -1812,9 +1756,7 @@ def test_hamming(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.9400, 4) - @parametrize( - "dtype", "Bbhil" + "efd" - ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_bartlett(self, dtype: str, M: int) -> None: scalar = M @@ -1834,9 +1776,7 @@ def test_bartlett(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.4444, 4) - @parametrize( - "dtype", "Bbhil" + "efd" - ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_blackman(self, dtype: str, M: int) -> None: scalar = M @@ -1856,9 +1796,7 @@ def test_blackman(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 3.7800, 4) - @parametrize( - "dtype", "Bbhil" + "efd" - ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_kaiser(self, dtype: str, M: int) -> None: scalar = M @@ -1879,7 +1817,7 @@ def test_kaiser(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 10, 15) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") class TestTrapz(TestCase): def test_simple(self): x = np.arange(-10, 10, 0.1) @@ -1948,13 +1886,13 @@ def test_simple(self): assert_(unique(np.array([1, 1, 1, 1, 1])) == np.array([1])) - @xpassIfTorchDynamo # (reason="unique not implemented for 'ComplexDouble'") + @xfail # (reason="unique not implemented for 'ComplexDouble'") def test_simple_complex(self): x = np.array([5 + 6j, 1 + 1j, 1 + 10j, 10, 5 + 6j]) assert_(np.all(unique(x) == [1 + 1j, 1 + 10j, 5 + 6j, 10])) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") class TestCheckFinite(TestCase): def test_simple(self): a = [1, 2, 3] @@ -2599,19 +2537,7 @@ def test_error_not_1d(self, vals): np.bincount(vals) -parametrize_interp_sc = parametrize( - "sc", - [ - subtest(lambda x: np.float_(x), name="real"), - subtest(lambda x: _make_complex(x, 0), name="complex-real"), - subtest(lambda x: _make_complex(0, x), name="complex-imag"), - subtest(lambda x: _make_complex(x, np.multiply(x, -2)), name="complex-both"), - ], -) - - -@xpassIfTorchDynamo # (reason="TODO: implement") -@instantiate_parametrized_tests +@xfail # (reason="TODO: implement") class TestInterp(TestCase): def test_exceptions(self): assert_raises(ValueError, interp, 0, [], []) @@ -2686,7 +2612,19 @@ def test_non_finite_behavior_exact_x(self): fp = [1, 2, np.nan, 4] assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.nan, np.nan, 4]) - @parametrize_interp_sc + @pytest.fixture( + params=[ + lambda x: np.float_(x), + lambda x: _make_complex(x, 0), + lambda x: _make_complex(0, x), + lambda x: _make_complex(x, np.multiply(x, -2)), + ], + ids=["real", "complex-real", "complex-imag", "complex-both"], + ) + def sc(self, request): + """scale function used by the below tests""" + return request.param + def test_non_finite_any_nan(self, sc): """test that nans are propagated""" assert_equal(np.interp(0.5, [np.nan, 1], sc([0, 10])), sc(np.nan)) @@ -2694,7 +2632,6 @@ def test_non_finite_any_nan(self, sc): assert_equal(np.interp(0.5, [0, 1], sc([np.nan, 10])), sc(np.nan)) assert_equal(np.interp(0.5, [0, 1], sc([0, np.nan])), sc(np.nan)) - @parametrize_interp_sc def test_non_finite_inf(self, sc): """Test that interp between opposite infs gives nan""" assert_equal(np.interp(0.5, [-np.inf, +np.inf], sc([0, 10])), sc(np.nan)) @@ -2704,7 +2641,6 @@ def test_non_finite_inf(self, sc): # unless the y values are equal assert_equal(np.interp(0.5, [-np.inf, +np.inf], sc([10, 10])), sc(10)) - @parametrize_interp_sc def test_non_finite_half_inf_xf(self, sc): """Test that interp where both axes have a bound at inf gives nan""" assert_equal(np.interp(0.5, [-np.inf, 1], sc([-np.inf, 10])), sc(np.nan)) @@ -2716,7 +2652,6 @@ def test_non_finite_half_inf_xf(self, sc): assert_equal(np.interp(0.5, [0, +np.inf], sc([0, -np.inf])), sc(np.nan)) assert_equal(np.interp(0.5, [0, +np.inf], sc([0, +np.inf])), sc(np.nan)) - @parametrize_interp_sc def test_non_finite_half_inf_x(self, sc): """Test interp where the x axis has a bound at inf""" assert_equal(np.interp(0.5, [-np.inf, -np.inf], sc([0, 10])), sc(10)) @@ -2724,7 +2659,6 @@ def test_non_finite_half_inf_x(self, sc): assert_equal(np.interp(0.5, [0, +np.inf], sc([0, 10])), sc(0)) assert_equal(np.interp(0.5, [+np.inf, +np.inf], sc([0, 10])), sc(0)) - @parametrize_interp_sc def test_non_finite_half_inf_f(self, sc): """Test interp where the f axis has a bound at inf""" assert_equal(np.interp(0.5, [0, 1], sc([0, -np.inf])), sc(-np.inf)) @@ -2852,7 +2786,7 @@ def test_2D(self): x = np.array([[1, 1, 1], [1, 1, 1], [4, 4, 3], [1, 1, 1], [1, 1, 1]]) assert_array_equal(np.percentile(x, 50, axis=0), [1, 1, 1]) - @xpassIfTorchDynamo # (reason="TODO: implement") + @xfail # (reason="TODO: implement") @parametrize("dtype", np.typecodes["Float"]) def test_linear_nan_1D(self, dtype): # METHOD 1 of H&F @@ -2862,14 +2796,14 @@ def test_linear_nan_1D(self, dtype): np.testing.assert_equal(res.dtype, arr.dtype) H_F_TYPE_CODES = [ - (int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"] + (int_type, np.float64) for int_type in np.typecodes["AllInteger"] ] + [ (np.float16, np.float16), (np.float32, np.float32), (np.float64, np.float64), ] - @skip(reason="NEP 50 is new in 1.24") + @xfail # (reason="TODO: implement percentile interpolations") @parametrize("input_dtype, expected_dtype", H_F_TYPE_CODES) @parametrize( "method, expected", @@ -2887,11 +2821,7 @@ def test_linear_nan_1D(self, dtype): ) def test_linear_interpolation(self, method, expected, input_dtype, expected_dtype): expected_dtype = np.dtype(expected_dtype) - - if ( - hasattr(np, "_get_promotion_state") - and np._get_promotion_state() == "legacy" - ): + if np._get_promotion_state() == "legacy": expected_dtype = np.promote_types(expected_dtype, np.float64) arr = np.asarray([15.0, 20.0, 35.0, 40.0, 50.0], dtype=input_dtype) @@ -3146,7 +3076,7 @@ def test_percentile_overwrite(self): b = np.percentile([2, 3, 4, 1], [50], overwrite_input=True) assert_equal(b, np.array([2.5])) - @xpassIfTorchDynamo # (reason="pytorch percentile does not support tuple axes.") + @xfail # (reason="pytorch percentile does not support tuple axes.") def test_extended_axis(self): o = np.random.normal(size=(71, 23)) x = np.dstack([o] * 10) @@ -3235,7 +3165,6 @@ def test_keepdims_2(self): np.percentile(d, [1, 7], axis=(0, 3), keepdims=True).shape, (2, 1, 5, 7, 1) ) - @skipif(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") @parametrize( "q", [ @@ -3243,7 +3172,7 @@ def test_keepdims_2(self): subtest( [1, 7], decorators=[ - xpassIfTorchDynamo, + xfail, ], ), ], @@ -3257,13 +3186,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( (-3, -1), decorators=[ - xpassIfTorchDynamo, + xfail, ], ), ], @@ -3313,7 +3242,7 @@ def test_out_nan(self): assert_equal(np.percentile(d, 1, out=o), o) assert_equal(np.percentile(d, 1, method="nearest", out=o), o) - @xpassIfTorchDynamo # (reason="np.percentile undocumented nan weirdness") + @xfail # (reason="np.percentile undocumented nan weirdness") def test_nan_behavior(self): a = np.arange(24, dtype=float) a[2] = np.nan @@ -3406,7 +3335,7 @@ def test_basic(self): assert_equal(np.quantile(x, 1), 3.5) assert_equal(np.quantile(x, 0.5), 1.75) - @xpassIfTorchDynamo # (reason="quantile w/integers or bools") + @xfail # (reason="quantile w/integers or bools") def test_correct_quantile_value(self): a = np.array([True]) tf_quant = np.quantile(True, False) @@ -3465,8 +3394,8 @@ def test_no_p_overwrite(self): np.quantile(np.arange(100.0), p, method="midpoint") assert_array_equal(p, p0) - @xpassIfTorchDynamo # (reason="TODO: make quantile preserve integers") - @parametrize("dtype", "Bbhil") # np.typecodes["AllInteger"]) + @xfail # (reason="TODO: make quantile preserve integers") + @parametrize("dtype", np.typecodes["AllInteger"]) def test_quantile_preserve_int_type(self, dtype): res = np.quantile(np.array([1, 2], dtype=dtype), [0.5], method="nearest") assert res.dtype == dtype @@ -3477,50 +3406,50 @@ def test_quantile_preserve_int_type(self, dtype): subtest( "inverted_cdf", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( "averaged_inverted_cdf", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( "closest_observation", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( "interpolated_inverted_cdf", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( "hazen", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( "weibull", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), "linear", subtest( "median_unbiased", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( "normal_unbiased", decorators=[ - xpassIfTorchDynamo, + xfail, ], ), "nearest", @@ -3588,7 +3517,7 @@ def test_basic(self): a = np.array([0.0444502, 0.141249, 0.0463301]) assert_equal(a[-1], np.median(a)) - @xpassIfTorchDynamo # (reason="median: scalar output vs 0-dim") + @xfail # (reason="median: scalar output vs 0-dim") def test_basic_2(self): # check array scalar result a = np.array([0.0444502, 0.141249, 0.0463301]) @@ -3697,7 +3626,7 @@ def test_nan_behavior(self): b[1, 2] = np.nan assert_equal(np.median(a, 1), b) - @xpassIfTorchDynamo # (reason="median: does not support tuple axes") + @xfail # (reason="median: does not support tuple axes") def test_nan_behavior_2(self): a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan @@ -3709,7 +3638,7 @@ def test_nan_behavior_2(self): b[2] = np.nan assert_equal(np.median(a, (0, 2)), b) - @xpassIfTorchDynamo # (reason="median: scalar vs 0-dim") + @xfail # (reason="median: scalar vs 0-dim") def test_nan_behavior_3(self): a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan @@ -3718,7 +3647,7 @@ def test_nan_behavior_3(self): # no axis assert_equal(np.median(a).ndim, 0) - @xpassIfTorchDynamo # (reason="median: torch.quantile does not handle empty tensors") + @xfail # (reason="median: torch.quantile does not handle empty tensors") @skipif(IS_WASM, reason="fp errors don't work correctly") def test_empty(self): # mean(empty array) emits two warnings: empty slice and divide by 0 @@ -3749,7 +3678,7 @@ def test_empty(self): assert_equal(np.median(a, axis=2), b) assert_(w[0].category is RuntimeWarning) - @xpassIfTorchDynamo # (reason="median: tuple axes not implemented") + @xfail # (reason="median: tuple axes not implemented") def test_extended_axis(self): o = np.random.normal(size=(71, 23)) x = np.dstack([o] * 10) @@ -3799,7 +3728,7 @@ def test_keepdims(self): d = np.ones((3, 5, 7, 11)) assert_equal(np.median(d, axis=None, keepdims=True).shape, (1, 1, 1, 1)) - @xpassIfTorchDynamo # (reason="median: tuple axis") + @xfail # (reason="median: tuple axis") def test_keepdims_2(self): d = np.ones((3, 5, 7, 11)) assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape, (1, 1, 7, 11)) @@ -3808,7 +3737,6 @@ def test_keepdims_2(self): assert_equal(np.median(d, axis=(0, 1, 2, 3), keepdims=True).shape, (1, 1, 1, 1)) assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape, (1, 1, 7, 1)) - @skipif(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") @parametrize( "axis", [ @@ -3818,13 +3746,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xpassIfTorchDynamo, + xfail, ], ), subtest( (-3, -1), decorators=[ - xpassIfTorchDynamo, + xfail, ], ), ], @@ -3844,7 +3772,7 @@ def test_keepdims_out(self, axis): assert_equal(result.shape, shape_out) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") @instantiate_parametrized_tests class TestSortComplex(TestCase): @parametrize( diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index 4b09ef5b207b4..9d6b0364fc2d2 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -3,46 +3,32 @@ # from numpy.testing._private.utils import requires_memory import functools -from unittest import skipIf +from unittest import expectedFailure as xfail, skipIf +import pytest +import torch._numpy as np from pytest import raises as assert_raises - -skip = functools.partial(skipIf, True) - +from torch._numpy import histogram, histogramdd + +# from numpy.lib.histograms import histogram, histogramdd, histogram_bin_edges +from torch._numpy.testing import ( + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, slowTest as slow, - TEST_WITH_TORCHDYNAMO, TestCase, - xpassIfTorchDynamo, ) -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import histogram, histogram_bin_edges, histogramdd - from numpy.testing import ( - assert_, - assert_allclose, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, - ) -else: - import torch._numpy as np - from torch._numpy import histogram, histogramdd - from torch._numpy.testing import ( - assert_, - assert_allclose, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, - ) +skip = functools.partial(skipIf, True) class TestHistogram(TestCase): @@ -203,7 +189,7 @@ def test_weights(self): ) assert_almost_equal(a, [0.2, 0.1, 0.1, 0.075]) - @xpassIfTorchDynamo # (reason="histogram complex weights") + @xfail # (reason="histogram complex weights") def test_exotic_weights(self): # Test the use of weights that are not integer or floats, but e.g. # complex numbers or object types. @@ -265,7 +251,7 @@ def test_invalid_range(self): with assert_raises((RuntimeError, ValueError)): np.histogram(vals, range=[0.1, 0.01]) - @xpassIfTorchDynamo # (reason="edge cases") + @xfail # (reason="edge cases") def test_bin_edge_cases(self): # Ensure that floating-point computations correctly place edge cases. arr = np.array([337, 404, 739, 806, 1007, 1811, 2012]) @@ -289,7 +275,7 @@ def test_bin_array_dims(self): with assert_raises((RuntimeError, ValueError)): np.histogram(vals, bins=bins) - @xpassIfTorchDynamo # (reason="no uint64") + @xfail # (reason="no uint64") def test_unsigned_monotonicity_check(self): # Ensures ValueError is raised if bins not increasing monotonically # when bins contain unsigned values (see #9222) @@ -315,7 +301,7 @@ def test_object_array_of_0d(self): np.histogram([np.array(0.5) for i in range(10)] + [0.500000000000001]) np.histogram([np.array(0.5) for i in range(10)] + [0.5]) - @xpassIfTorchDynamo # (reason="bins='auto'") + @xfail # (reason="bins='auto'") def test_some_nan_values(self): # gh-7503 one_nan = np.array([0, 1, np.nan]) @@ -353,7 +339,7 @@ def test_signed_overflow_bounds(self): self.do_signed_overflow_bounds(np.short) self.do_signed_overflow_bounds(np.intc) - @xpassIfTorchDynamo # (reason="int->float conversin loses precision") + @xfail # (reason="int->float conversin loses precision") def test_signed_overflow_bounds_2(self): self.do_signed_overflow_bounds(np.int_) self.do_signed_overflow_bounds(np.longlong) @@ -396,14 +382,14 @@ def do_precision(self, float_small, float_large): self.do_precision_lower_bound(float_small, float_large) self.do_precision_upper_bound(float_small, float_large) - @xpassIfTorchDynamo # (reason="mixed dtypes") + @xfail # (reason="mixed dtypes") def test_precision(self): # not looping results in a useful stack trace upon failure self.do_precision(np.half, np.single) self.do_precision(np.half, np.double) self.do_precision(np.single, np.double) - @xpassIfTorchDynamo # (reason="histogram_bin_edges") + @xfail # (reason="histogram_bin_edges") def test_histogram_bin_edges(self): hist, e = histogram([1, 2, 3, 4], [1, 2]) edges = histogram_bin_edges([1, 2, 3, 4], [1, 2]) @@ -419,7 +405,7 @@ def test_histogram_bin_edges(self): assert_array_equal(edges, e) # @requires_memory(free_bytes=1e10) - @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") + @xfail # (reason="pytorch does not support bins = [int, int, array]") @slow def test_big_arrays(self): sample = np.zeros([100000000, 3]) @@ -430,7 +416,7 @@ def test_big_arrays(self): assert_equal(type(hist), type((1, 2))) -@xpassIfTorchDynamo # (reason="TODO") +@xfail # (reason="TODO") @instantiate_parametrized_tests class TestHistogramOptimBinNums(TestCase): """ @@ -712,6 +698,7 @@ def test_simple_weighted(self): """ Check that weighted data raises a TypeError """ + pytest.xpass(reason="passes by chance") estimator_list = ["fd", "scott", "rice", "sturges", "auto"] for estimator in estimator_list: assert_raises(TypeError, histogram, [1, 2, 3], estimator, weights=[1, 2, 3]) @@ -853,13 +840,13 @@ def test_bins_errors(self): (RuntimeError, ValueError), np.histogramdd, x, bins=[1, 1, 1, [1, 2, 3, -3]] ) - @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") + @xfail # (reason="pytorch does not support bins = [int, int, array]") def test_bins_error_2(self): # mixing scalar (# of bins) and explicit bin arrays, ugh x = np.arange(8).reshape(2, 4) assert_(np.histogramdd(x, bins=[1, 1, 1, [1, 2, 3, 4]])) - @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") + @xfail # (reason="pytorch does not support bins = [int, int, array]") def test_inf_edges(self): # Test using +/-inf bin edges works. See #1788. x = np.arange(6).reshape(3, 2) @@ -910,7 +897,7 @@ def test_finite_range(self): range=[[0.0, 1.0], [np.nan, 0.75], [0.25, 0.5]], ) - @xpassIfTorchDynamo # (reason="pytorch does not allow equal entries") + @xfail # (reason="pytorch does not allow equal entries") def test_equal_edges(self): """Test that adjacent entries in an edge array can be equal""" x = np.array([0, 1, 2]) @@ -941,7 +928,7 @@ def test_edge_dtype(self): def test_large_integers(self): big = 2**60 # Too large to represent with a full precision float - x = np.asarray([0], dtype=np.int64) + x = np.array([0], np.int64) x_edges = np.array([-1, +1], np.int64) y = big + x y_edges = big + x_edges diff --git a/test/torch_np/numpy_tests/lib/test_index_tricks.py b/test/torch_np/numpy_tests/lib/test_index_tricks.py index e43e33be03946..d3aac7663ec2e 100644 --- a/test/torch_np/numpy_tests/lib/test_index_tricks.py +++ b/test/torch_np/numpy_tests/lib/test_index_tricks.py @@ -4,52 +4,29 @@ from unittest import expectedFailure as xfail, skipIf +import torch._numpy as np + from pytest import raises as assert_raises # , assert_raises_regex, +from torch._numpy import diag_indices, diag_indices_from, fill_diagonal, index_exp, s_ +from torch._numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, - TEST_WITH_TORCHDYNAMO, TestCase, - xpassIfTorchDynamo, ) skip = functools.partial(skipIf, True) -# If we are going to trace through these, we should use NumPy -# If testing on eager mode, we use torch._numpy -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import diag_indices, diag_indices_from, fill_diagonal, index_exp, s_ - from numpy.testing import ( - assert_, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - assert_raises_regex, - ) -else: - import torch._numpy as np - from torch._numpy import ( - diag_indices, - diag_indices_from, - fill_diagonal, - index_exp, - s_, - ) - from torch._numpy.testing import ( - assert_, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - ) - - -@xpassIfTorchDynamo # (reason="unravel_index not implemented") +@xfail # (reason="unravel_index not implemented") @instantiate_parametrized_tests class TestRavelUnravelIndex(TestCase): def test_basic(self): @@ -451,7 +428,7 @@ def test_repeated_input(self): class TestC(TestCase): - @xpassIfTorchDynamo # (reason="c_ not implemented") + @xfail # (reason="c_ not implemented") def test_c_(self): a = np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])] assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]]) diff --git a/test/torch_np/numpy_tests/lib/test_shape_base_.py b/test/torch_np/numpy_tests/lib/test_shape_base_.py index 20c04f3e1a215..673d1ed0b537e 100644 --- a/test/torch_np/numpy_tests/lib/test_shape_base_.py +++ b/test/torch_np/numpy_tests/lib/test_shape_base_.py @@ -5,62 +5,34 @@ from unittest import expectedFailure as xfail, skipIf as skipif +import torch._numpy as np + from pytest import raises as assert_raises +from torch._numpy import ( + array_split, + column_stack, + dsplit, + dstack, + expand_dims, + hsplit, + kron, + put_along_axis, + split, + take_along_axis, + tile, + vsplit, +) +from torch._numpy.random import rand, randint + +from torch._numpy.testing import assert_, assert_array_equal, assert_equal from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, - TEST_WITH_TORCHDYNAMO, TestCase, - xfailIfTorchDynamo, - xpassIfTorchDynamo, ) - -# If we are going to trace through these, we should use NumPy -# If testing on eager mode, we use torch._numpy -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import ( - apply_along_axis, - array_split, - column_stack, - dsplit, - dstack, - expand_dims, - hsplit, - kron, - put_along_axis, - split, - take_along_axis, - tile, - vsplit, - ) - from numpy.random import rand, randint - - from numpy.testing import assert_, assert_array_equal, assert_equal - -else: - import torch._numpy as np - from torch._numpy import ( - array_split, - column_stack, - dsplit, - dstack, - expand_dims, - hsplit, - kron, - put_along_axis, - split, - take_along_axis, - tile, - vsplit, - ) - from torch._numpy.random import rand, randint - from torch._numpy.testing import assert_, assert_array_equal, assert_equal - - skip = functools.partial(skipif, True) @@ -154,7 +126,7 @@ def test_replace_max(self): assert_equal(i_min, i_max) - @xpassIfTorchDynamo # ( + @xfail # ( # reason="RuntimeError: Expected index [1, 2, 5] to be smaller than self [3, 4, 1] apart from dimension 1") def test_broadcast(self): """Test that non-indexing dimensions are broadcast in both directions""" @@ -164,7 +136,7 @@ def test_broadcast(self): assert_equal(take_along_axis(a, ai, axis=1), 20) -@xpassIfTorchDynamo # (reason="apply_along_axis not implemented") +@xfail # (reason="apply_along_axis not implemented") class TestApplyAlongAxis(TestCase): def test_simple(self): a = np.ones((20, 10), "d") @@ -707,8 +679,6 @@ def test_basic(self): assert_equal(res.ndim, 0) assert type(res) is np.ndarray - @xfailIfTorchDynamo - def test_basic_2(self): aa = np.ones((3, 1, 4, 1, 1)) assert aa.squeeze().tensor._base is aa.tensor @@ -742,7 +712,7 @@ def test_squeeze_contiguous(self): assert_(a.flags.f_contiguous) assert_(b.flags.f_contiguous) - @xpassIfTorchDynamo # (reason="XXX: noop in torch, while numpy raises") + @xfail # (reason="XXX: noop in torch, while numpy raises") def test_squeeze_axis_handling(self): with assert_raises(ValueError): np.squeeze(np.array([[1], [2], [3]]), axis=0) @@ -840,7 +810,7 @@ def test_kroncompare(self): assert_equal(large, klarge) -@xpassIfTorchDynamo # (reason="TODO: implement") +@xfail # (reason="TODO: implement") class TestMayShareMemory(TestCase): def test_basic(self): d = np.ones((50, 60)) diff --git a/test/torch_np/numpy_tests/lib/test_twodim_base.py b/test/torch_np/numpy_tests/lib/test_twodim_base.py index dda807b556369..bbf9fd1bbc5cd 100644 --- a/test/torch_np/numpy_tests/lib/test_twodim_base.py +++ b/test/torch_np/numpy_tests/lib/test_twodim_base.py @@ -8,72 +8,40 @@ from unittest import expectedFailure as xfail, skipIf as skipif import pytest + +import torch._numpy as np from pytest import raises as assert_raises +from torch._numpy import ( + arange, + array, + diag, + eye, + fliplr, + flipud, + histogram2d, + ones, + tri, # mask_indices, + tril_indices, + tril_indices_from, + triu_indices, + triu_indices_from, + vander, + zeros, +) +from torch._numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, # assert_array_max_ulp, + assert_equal, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, - TEST_WITH_TORCHDYNAMO, TestCase, - xpassIfTorchDynamo, ) - -# If we are going to trace through these, we should use NumPy -# If testing on eager mode, we use torch._numpy -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import ( - arange, - array, - diag, - eye, - fliplr, - flipud, - histogram2d, - ones, - tri, # mask_indices, - tril_indices, - tril_indices_from, - triu_indices, - triu_indices_from, - vander, - zeros, - ) - from numpy.testing import ( - assert_allclose, - assert_array_almost_equal, - assert_array_equal, # assert_array_max_ulp, - assert_equal, - ) -else: - import torch._numpy as np - from torch._numpy import ( - arange, - array, - diag, - eye, - fliplr, - flipud, - histogram2d, - ones, - tri, # mask_indices, - tril_indices, - tril_indices_from, - triu_indices, - triu_indices_from, - vander, - zeros, - ) - from torch._numpy.testing import ( - assert_allclose, - assert_array_almost_equal, - assert_array_equal, # assert_array_max_ulp, - assert_equal, - ) - - skip = functools.partial(skipif, True) @@ -133,7 +101,7 @@ def test_eye_bounds(self): def test_bool(self): assert_equal(eye(2, 2, dtype=bool), [[True, False], [False, True]]) - @xpassIfTorchDynamo # (reason="TODO: implement order=non-default") + @xfail # (reason="TODO: implement order=non-default") def test_order(self): mat_c = eye(4, 3, k=-1) mat_f = eye(4, 3, k=-1, order="F") @@ -159,10 +127,9 @@ def test_vector(self): assert_equal(diag(vals, k=2), b) assert_equal(diag(vals, k=-2), c) - def test_matrix(self): - self.check_matrix(vals=(100 * get_mat(5) + 1).astype("l")) - - def check_matrix(self, vals): + def test_matrix(self, vals=None): + if vals is None: + vals = (100 * get_mat(5) + 1).astype("l") b = zeros((5,)) for k in range(5): b[k] = vals[k, k] @@ -175,10 +142,10 @@ def check_matrix(self, vals): b[k] = vals[k + 2, k] assert_equal(diag(vals, -2), b[:3]) - @xpassIfTorchDynamo # (reason="TODO implement orders") + @xfail # (reason="TODO implement orders") def test_fortran_order(self): vals = array((100 * get_mat(5) + 1), order="F", dtype="l") - self.check_matrix(vals) + self.test_matrix(vals) def test_diag_bounds(self): A = [[1, 2], [3, 4], [5, 6]] @@ -284,7 +251,7 @@ def test_empty(self): # assert_array_max_ulp(a, np.zeros((4, 4))) assert_allclose(a, np.zeros((4, 4)), atol=1e-15) - @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, array]") + @xfail # (reason="pytorch does not support bins = [int, array]") def test_binparameter_combination(self): x = array([0, 0.09207008, 0.64575234, 0.12875982, 0.47390599, 0.59944483, 1]) y = array([0, 0.14344267, 0.48988575, 0.30558665, 0.44700682, 0.15886423, 1]) @@ -318,7 +285,6 @@ def test_binparameter_combination(self): assert_array_equal(H, answer) assert_array_equal(xe, array([0.0, 0.25, 0.5, 0.75, 1])) - @skip(reason="NP_VER: fails on CI with older NumPy") @parametrize("x_len, y_len", [(10, 11), (20, 19)]) def test_bad_length(self, x_len, y_len): x, y = np.ones(x_len), np.ones(y_len) @@ -402,7 +368,7 @@ def test_mask_indices(self): iu1 = mask_indices(3, np.triu, 1) assert_array_equal(a[iu1], array([1, 2, 5])) - @xpassIfTorchDynamo # (reason="np.tril_indices == our tuple(tril_indices)") + @xfail # (reason="np.tril_indices == our tuple(tril_indices)") def test_tril_indices(self): # indices without and with offset il1 = tril_indices(4) @@ -462,7 +428,7 @@ def test_tril_indices(self): ) -@xpassIfTorchDynamo # (reason="np.triu_indices == our tuple(triu_indices)") +@xfail # (reason="np.triu_indices == our tuple(triu_indices)") class TestTriuIndices(TestCase): def test_triu_indices(self): iu1 = triu_indices(4) diff --git a/test/torch_np/numpy_tests/lib/test_type_check.py b/test/torch_np/numpy_tests/lib/test_type_check.py index 96c0ddbc9672b..0afa518edb228 100644 --- a/test/torch_np/numpy_tests/lib/test_type_check.py +++ b/test/torch_np/numpy_tests/lib/test_type_check.py @@ -5,44 +5,22 @@ from unittest import expectedFailure as xfail, skipIf as skipif +import torch._numpy as np from pytest import raises as assert_raises -from torch.testing._internal.common_utils import ( - run_tests, - TEST_WITH_TORCHDYNAMO, - TestCase, - xpassIfTorchDynamo, -) - - -if TEST_WITH_TORCHDYNAMO: - import numpy as np - from numpy import ( - common_type, - iscomplex, - iscomplexobj, - isneginf, - isposinf, - isreal, - isrealobj, - nan_to_num, - real_if_close, - ) - from numpy.testing import assert_, assert_array_equal, assert_equal -else: - import torch._numpy as np - from torch._numpy import ( - common_type, - iscomplex, - iscomplexobj, - isneginf, - isposinf, - isreal, - isrealobj, - nan_to_num, - real_if_close, - ) - from torch._numpy.testing import assert_, assert_array_equal, assert_equal +from torch._numpy import ( + common_type, + iscomplex, + iscomplexobj, + isneginf, + isposinf, + isreal, + isrealobj, + nan_to_num, + real_if_close, +) +from torch._numpy.testing import assert_, assert_array_equal, assert_equal +from torch.testing._internal.common_utils import run_tests, TestCase skip = functools.partial(skipif, True) @@ -51,7 +29,7 @@ def assert_all(x): assert_(np.all(x), x) -@xpassIfTorchDynamo # (reason="common_type not implemented") +@xfail # (reason="common_type not implemented") class TestCommonType(TestCase): def test_basic(self): ai32 = np.array([[1, 2], [3, 4]], dtype=np.int32) @@ -118,7 +96,7 @@ def test_default_3(self): assert_equal(mintypecode("idD"), "D") -@xpassIfTorchDynamo # (reason="TODO: decide on if [1] is a scalar or not") +@xfail # (reason="TODO: decide on if [1] is a scalar or not") class TestIsscalar(TestCase): def test_basic(self): assert_(np.isscalar(3)) From 55ab9932f508466bc3f3e2c96de24dce07e27130 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 26 Oct 2023 23:27:57 +0000 Subject: [PATCH 16/78] Revert "Constrain sdpa to fx strides (#111721)" This reverts commit 8a7c3cec78686e661b3781b916a8aae59083f90a. Reverted https://github.com/pytorch/pytorch/pull/111721 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is breaking ROCm job in trunk https://hud.pytorch.org/pytorch/pytorch/commit/8a7c3cec78686e661b3781b916a8aae59083f90a ([comment](https://github.com/pytorch/pytorch/pull/111721#issuecomment-1782064133)) --- test/inductor/test_torchinductor.py | 56 ---------------- ...st_torchinductor_codegen_dynamic_shapes.py | 1 - torch/_inductor/lowering.py | 67 ++----------------- 3 files changed, 4 insertions(+), 120 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index a0b65087c9211..669d99632e9e2 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6772,62 +6772,6 @@ def forward(arg6, arg7, arg16): # expanded dim should not cause copy in require_stride_order self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) - @requires_cuda() - def test_sdpa(self): - def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): - view = torch.ops.aten.view.default(arg3_1, [23760, 128]) - arg3_1 = None - mm = torch.ops.aten.mm.default(view, arg4_1) - view = arg4_1 = None - view_1 = torch.ops.aten.view.default(mm, [3, 99, 80, 8]) - mm = None - view_2 = torch.ops.aten.view.default(view_1, [3, 99, 80, 8]) - view_1 = None - permute = torch.ops.aten.permute.default(view_2, [0, 3, 1, 2]) - view_2 = None - view_3 = torch.ops.aten.view.default(permute, [3, 8, 99, 80]) - permute = None - - clone = torch.ops.aten.clone.default( - view_3, memory_format=torch.contiguous_format - ) - view_3 = None - - expand = torch.ops.aten.expand.default(clone, [3, 8, 99, 80]) - clone = None - _scaled_dot_product_efficient_attention = ( - torch.ops.aten._scaled_dot_product_efficient_attention.default( - arg0_1, arg1_1, arg2_1, expand, False - ) - ) - arg0_1 = arg1_1 = arg2_1 = expand = None - getitem = _scaled_dot_product_efficient_attention[0] - _scaled_dot_product_efficient_attention = None - return (getitem,) - - DEVICE = torch.device("cuda:0") - DTYPE = torch.float16 - B = 3 - H = 8 - Q = 99 - K = 80 - D = 32 - C_bias = 128 - - # inputs - query = torch.randn((B, H, Q, D), device=DEVICE, dtype=DTYPE) - key = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) - value = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) - bias = torch.randn((B, Q, K, C_bias), device=DEVICE, dtype=DTYPE) - weights = torch.randn((C_bias, H), device=DEVICE, dtype=DTYPE) - - self.common( - foo, - (query, key, value, bias, weights), - atol=0.02, - rtol=1e4, - ) - def test_where_with_logical_op(self): def fn_and(x, y): return torch.where(torch.logical_and(x, y), 1.0, 0.0) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 0e80d8adb5828..8677453a55c5c 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -261,7 +261,6 @@ def run(*ex, **kwargs): "test_zero_dim_reductions_dynamic_shapes": TestFailure( ("cpu", "cuda"), is_skip=True ), - "test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True), # # The following tests do not support dynamic shapes yet: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index ed142c57b1017..b569da72aba01 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2008,69 +2008,10 @@ def apply_constraint(arg, fx_arg): make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) make_fallback(aten.grid_sampler_2d_backward, require_dense) make_fallback(aten.randperm) - - -def sdpa_constraint(fx_node, *args, **kwargs): - # sdpa requires dense last dimension - def apply_constraint(arg, fx_arg): - if not isinstance(arg, ir.IRNode): - return arg - - meta_val = fx_arg.meta["val"] - if not meta_val.is_cuda: - return arg - - stride_order = ir.get_stride_order(meta_val.stride()) - if stride_order and stride_order[-1] != 0: - # contiguous stride order - stride_order = list(reversed(range(len(arg.get_size())))) - - ALIGNMENT = 16 - - def is_aligned(x): - return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 - - assert isinstance(arg, TensorBox) - unaligned_input_shape = isinstance(arg.data, ir.SliceView) and not is_aligned( - arg - ) - aligned_input_view = unaligned_input_shape and is_aligned(arg.unwrap_view()) - - # input is padded, requiring_stride_order will unwrap the view and unpad. - # Would be nice to be able to require certain padding from inductor ir, nyi - if aligned_input_view: - return arg - - return ir.ExternKernel.require_stride_order(arg, stride_order) - - args = tuple( - apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) - ) - kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} - return args, kwargs - - -make_fallback( - aten._scaled_dot_product_efficient_attention.default, - sdpa_constraint, - warn=False, -) -make_fallback( - aten._scaled_dot_product_efficient_attention_backward.default, - sdpa_constraint, - warn=False, -) -make_fallback( - aten._scaled_dot_product_flash_attention.default, - sdpa_constraint, - warn=False, -) -make_fallback( - aten._scaled_dot_product_flash_attention_backward.default, - sdpa_constraint, - warn=False, -) - +make_fallback(aten._scaled_dot_product_efficient_attention) +make_fallback(aten._scaled_dot_product_efficient_attention_backward) +make_fallback(aten._scaled_dot_product_flash_attention, warn=False) +make_fallback(aten._scaled_dot_product_flash_attention_backward) make_fallback(aten.sort) make_fallback(aten.sort.stable) make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) From 4f7f46ee3594278a76aea0550412017b15885fe3 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 26 Oct 2023 20:00:32 +0000 Subject: [PATCH 17/78] Move SymDispatchMode to its own file (#112035) This is just code movement + a getter and a setter to break the dependency of SymDispatchMode, and in turn, ProxySymDispatchMode on sympy. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112035 Approved by: https://github.com/peterbell10 --- torch/fx/experimental/_sym_dispatch_mode.py | 58 ++++++++++++++++++ torch/fx/experimental/proxy_tensor.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 66 ++++----------------- 3 files changed, 70 insertions(+), 57 deletions(-) create mode 100644 torch/fx/experimental/_sym_dispatch_mode.py diff --git a/torch/fx/experimental/_sym_dispatch_mode.py b/torch/fx/experimental/_sym_dispatch_mode.py new file mode 100644 index 0000000000000..52fa6221c3de8 --- /dev/null +++ b/torch/fx/experimental/_sym_dispatch_mode.py @@ -0,0 +1,58 @@ +from typing import List, Type + +__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"] + +SYM_FUNCTION_MODE = None + + +# SymDispatchMode gets invoked whenever an operation is processed on +# a PySymInt. When this occurs, you get called at __sym_dispatch__ +# with the operation in question. This is symmetric to TorchDispatchMode +# but with some caveats: +# +# - In TorchDispatchMode, you get the same arguments as what a user +# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b), +# you get (a, b) as args to your call. In SymDispatchMode, if +# you call a + b (where a and b are SymInts), you will get +# (a.node, b.node) as your args (these are PySymInts) +# +# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor). +# So you have to manually call Tracer/create_node to write into +# the graph. See ProxySymDispatchMode for an example +# +class SymDispatchMode: + def __sym_dispatch__(self, func, types, args, kwargs): + raise NotImplementedError() + + def __enter__(self): + global SYM_FUNCTION_MODE + old = SYM_FUNCTION_MODE + if hasattr(self, "inner"): + raise RuntimeError( + f"{self} has already been used as a mode. Please use a fresh version" + ) + else: + self.inner = old + SYM_FUNCTION_MODE = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global SYM_FUNCTION_MODE + SYM_FUNCTION_MODE = self.inner + + +def handle_sym_dispatch(func, args, kwargs): + global SYM_FUNCTION_MODE + mode = sym_function_mode() + assert mode + SYM_FUNCTION_MODE = mode.inner + try: + # TODO: properly compute types + types: List[Type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) + finally: + SYM_FUNCTION_MODE = mode + + +def sym_function_mode(): + return SYM_FUNCTION_MODE diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 0f99ebe71e323..7b8146e5c1c5c 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -29,7 +29,8 @@ _push_mode, ) -from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode +from .symbolic_shapes import ShapeEnv, SymNode +from ._sym_dispatch_mode import SymDispatchMode from torch.fx import Proxy import torch.fx.traceback as fx_traceback from torch import SymInt, SymFloat, SymBool diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 0a597f6e919b9..619381088ff4f 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -28,6 +28,7 @@ replay_shape_env_events, shape_env_check_state_equal ) +from torch.fx.experimental._sym_dispatch_mode import handle_sym_dispatch, sym_function_mode # NB: The sym_* functions are used via getattr() and must be imported here. from torch import ( # noqa: F401 @@ -66,7 +67,7 @@ class GuardOnDataDependentSymNode(RuntimeError): __all__ = [ "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", - "SymDispatchMode", "guard_int", "guard_float", "guard_scalar", "wrap_node", + "guard_int", "guard_float", "guard_scalar", "wrap_node", "method_to_operator", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", ] @@ -90,8 +91,6 @@ def uninteresting_files(): ] return {inspect.getfile(m) for m in mods} -SYM_FUNCTION_MODE = None - # We don't bother with the metaclass as all of the dispatching logic happens # entirely from Python # @@ -101,39 +100,6 @@ def uninteresting_files(): class ConstraintViolationError(RuntimeError): pass -# SymDispatchMode gets invoked whenever an operation is processed on -# a PySymInt. When this occurs, you get called at __sym_dispatch__ -# with the operation in question. This is symmetric to TorchDispatchMode -# but with some caveats: -# -# - In TorchDispatchMode, you get the same arguments as what a user -# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b), -# you get (a, b) as args to your call. In SymDispatchMode, if -# you call a + b (where a and b are SymInts), you will get -# (a.node, b.node) as your args (these are PySymInts) -# -# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor). -# So you have to manually call Tracer/create_node to write into -# the graph. See ProxySymDispatchMode for an example -# -class SymDispatchMode: - def __sym_dispatch__(self, func, types, args, kwargs): - raise NotImplementedError() - - def __enter__(self): - global SYM_FUNCTION_MODE - old = SYM_FUNCTION_MODE - if hasattr(self, "inner"): - raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version") - else: - self.inner = old - SYM_FUNCTION_MODE = self - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - global SYM_FUNCTION_MODE - SYM_FUNCTION_MODE = self.inner - def has_symbolic_sizes_strides(elem): return elem._has_symbolic_sizes_strides @@ -143,18 +109,6 @@ def create_contiguous(shape): strides.append(dim * strides[-1]) return list(reversed(strides)) -def _handle_sym_dispatch(func, args, kwargs): - global SYM_FUNCTION_MODE - mode = SYM_FUNCTION_MODE - assert mode - SYM_FUNCTION_MODE = mode.inner - try: - # TODO: properly compute types - types: List[Type] = [] - return mode.__sym_dispatch__(func, types, args, kwargs) - finally: - SYM_FUNCTION_MODE = mode - def hint_int(a): if isinstance(a, torch.SymInt): return a.node.require_hint() @@ -1404,8 +1358,8 @@ def binary_magic_impl(self, other): if alternate_impl and out_hint is not None: return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) - if SYM_FUNCTION_MODE: - return to_node(self, _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})) + if sym_function_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})) assert isinstance(other, SymNode) # TODO: consider constant prop here try: @@ -1438,8 +1392,8 @@ def binary_magic_impl(self, other): def unary_magic_impl(self): op = method_to_operator(method) - if SYM_FUNCTION_MODE: - return to_node(self, _handle_sym_dispatch(op, (wrap_node(self),), {})) + if sym_function_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) # TODO: consider constant prop here expr = self.expr if method == "floor" or method == "ceiling": @@ -1472,10 +1426,10 @@ def unary_magic_impl(self): def sym_ite_impl(pred_node, then_node, else_node): out_hint = then_node.hint if pred_node.hint else else_node.hint - if SYM_FUNCTION_MODE: + if sym_function_mode(): return to_node( pred_node, - _handle_sym_dispatch( + handle_sym_dispatch( sym_ite, (wrap_node(pred_node), wrap_node(then_node), wrap_node(else_node)), {} ) @@ -1506,10 +1460,10 @@ def _make_node_sizes_strides(method, func): def sizes_strides_impl(self, sizes, strides): op = getattr(sys.modules[__name__], method) - if SYM_FUNCTION_MODE: + if sym_function_mode(): return to_node( self, - _handle_sym_dispatch( + handle_sym_dispatch( op, ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), {} From deac5357db28631a49b8be9f2454f099c7ef9335 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 26 Oct 2023 20:00:32 +0000 Subject: [PATCH 18/78] Make proxy_tensor.py not depend on SymPy (#112036) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112036 Approved by: https://github.com/malfet, https://github.com/peterbell10 ghstack dependencies: #112035 --- test/test_python_dispatch.py | 2 +- torch/fx/experimental/proxy_tensor.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 0349075e36566..2173cdbae00ee 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -4,7 +4,7 @@ import torch from copy import deepcopy from torch.library import Library, impl, fallthrough_kernel -from torch.fx.experimental.proxy_tensor import ShapeEnv +from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch import SymInt from torch._subclasses.fake_tensor import FakeTensorMode from torch.cuda.jiterator import _create_jit_fn diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 7b8146e5c1c5c..80a674ce83a53 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -29,7 +29,6 @@ _push_mode, ) -from .symbolic_shapes import ShapeEnv, SymNode from ._sym_dispatch_mode import SymDispatchMode from torch.fx import Proxy import torch.fx.traceback as fx_traceback @@ -81,6 +80,8 @@ def set_proxy_slot(obj, tracer, proxy): # on a tensor, and it affects the metadata on the proxy. tracer.tensor_tracker[obj] = proxy else: + # Avoid importing sympy at a module level + from .symbolic_shapes import SymNode # NB: Never clobber pre-existing proxy. Although the proxies # are in principle equivalent, when we do graph partitioning # we need there not to be spurious dependencies on tangent inputs. @@ -92,6 +93,8 @@ def set_proxy_slot(obj, tracer, proxy): tracer.symnode_tracker[obj] = proxy def has_proxy_slot(obj, tracer): + # Avoid importing sympy at a module level + from .symbolic_shapes import SymNode assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) return get_proxy_slot(obj, tracer, False, lambda _: True) @@ -102,6 +105,8 @@ def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): if isinstance(obj, torch.Tensor): tracker = tracer.tensor_tracker else: + # Avoid importing sympy at a module level + from .symbolic_shapes import SymNode assert isinstance(obj, SymNode), type(obj) tracker = tracer.symnode_tracker @@ -769,6 +774,9 @@ def make_fx(f, @functools.wraps(f) def wrapped(*args): + # Avoid importing sympy at a module level + from .symbolic_shapes import ShapeEnv + phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] fx_tracer = PythonKeyTracer() fake_tensor_mode: Any = nullcontext() From 47ccf0488530b6f3148abf344b87d44dc5ecfd86 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 26 Oct 2023 20:00:33 +0000 Subject: [PATCH 19/78] Split SymNode into its own file (#112037) This PR: - Moves TrueDiv, LShift, RShift, IsNonOverlappingAndDenseIndicator to `_sympy.functions.py` - Moves SymNode to `fx.experimental.sym_node`. - This file does not have any SymPy dependencies at import time - It installs the magic methods in Sym{Bool,Int,Float}. - N.b. With this split, we may be able to move Sym{Bool,Int,Float} to this file, and remove quite a few of the hacks around these classes - Imports `sym_node` in `torch/__init__.py` rather than the whole `symbolic_shapes.py`. This breaks the import-time dependency between torch and SymPy Pull Request resolved: https://github.com/pytorch/pytorch/pull/112037 Approved by: https://github.com/peterbell10 ghstack dependencies: #112035, #112036 --- docs/source/conf.py | 37 +- docs/source/fx.rst | 1 + test/allowlist_for_publicAPI.json | 47 +- test/test_dynamic_shapes.py | 38 +- test/test_proxy_tensor.py | 4 +- torch/__init__.py | 14 +- torch/_dynamo/variables/torch.py | 2 +- torch/_functorch/partitioners.py | 4 +- torch/_inductor/graph.py | 9 +- torch/_inductor/lowering.py | 2 +- torch/fx/experimental/proxy_tensor.py | 7 +- torch/fx/experimental/sym_node.py | 1095 ++++++++++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 877 +---------------- torch/fx/experimental/validator.py | 2 +- torch/utils/_sympy/functions.py | 58 +- torch/utils/_sympy/interp.py | 4 +- 16 files changed, 1244 insertions(+), 957 deletions(-) create mode 100644 torch/fx/experimental/sym_node.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 99e6422161275..1f4d48e504697 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -926,10 +926,25 @@ "record_shapeenv_event", "replay_shape_env_events", "shape_env_check_state_equal", + # torch.fx.experimental.sym_node + "ceil_impl", + "floor_ceil_helper", + "floor_impl", + "method_to_operator", + "sympy_is_channels_last_contiguous_2d", + "sympy_is_channels_last_contiguous_3d", + "sympy_is_channels_last_strides_2d", + "sympy_is_channels_last_strides_3d", + "sympy_is_channels_last_strides_generic", + "sympy_is_contiguous", + "sympy_is_contiguous_generic", + "to_node", + "wrap_node", + "sym_sqrt", + "sym_ite", # torch.fx.experimental.symbolic_shapes "bind_symbols", "cast_symbool_to_symint_guardless", - "ceil_impl", "constrain_range", "constrain_unify", "create_contiguous", @@ -940,8 +955,6 @@ "eval_is_non_overlapping_and_dense", "expect_true", "find_symbol_binding_fx_nodes", - "floor_ceil_helper", - "floor_impl", "free_symbols", "free_unbacked_symbols", "fx_placeholder_targets", @@ -963,21 +976,9 @@ "is_non_overlapping_and_dense_indicator", "is_symbol_binding_fx_node", "is_symbolic", - "method_to_operator", "parallel_and", "parallel_or", - "sym_sqrt", - "sym_ite", - "sympy_is_channels_last_contiguous_2d", - "sympy_is_channels_last_contiguous_3d", - "sympy_is_channels_last_strides_2d", - "sympy_is_channels_last_strides_3d", - "sympy_is_channels_last_strides_generic", - "sympy_is_contiguous", - "sympy_is_contiguous_generic", "tensor_has_hints", - "to_node", - "wrap_node", # torch.fx.experimental.unification.core "reify", # torch.fx.experimental.unification.match @@ -2851,6 +2852,8 @@ "RewritingTracer", # torch.fx.experimental.schema_type_annotation "AnnotateTypesWithSchema", + # torch.fx.experimental.sym_node + "SymNode", # torch.fx.experimental.symbolic_shapes "Constraint", "ConstraintViolationError", @@ -2859,17 +2862,13 @@ "DynamicDimConstraintPrinter", "EqualityConstraint", "GuardOnDataDependentSymNode", - "IsNonOverlappingAndDenseIndicator", "LoggingShapeGuardPrinter", - "Pow", "RelaxedUnspecConstraint", "RuntimeAssert", "ShapeEnv", "ShapeGuardPrinter", "StrictMinMaxConstraint", "SymDispatchMode", - "SymNode", - "TrueDiv", # torch.fx.experimental.unification.match "Dispatcher", "VarDispatcher", diff --git a/docs/source/fx.rst b/docs/source/fx.rst index dd71d411a2c20..df0ce8904d374 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -1149,6 +1149,7 @@ API Reference .. py:module:: torch.fx.experimental.rewriter .. py:module:: torch.fx.experimental.schema_type_annotation .. py:module:: torch.fx.experimental.symbolic_shapes +.. py:module:: torch.fx.experimental.sym_node .. py:module:: torch.fx.experimental.unification.core .. py:module:: torch.fx.experimental.unification.dispatch .. py:module:: torch.fx.experimental.unification.match diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index c1d71d1f07173..20b515934ca37 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1967,6 +1967,31 @@ "Transformer", "Tuple" ], + "torch.fx.experimental.sym_dispatch_mode": [ + "sym_function_mode", + "set_sym_function_mode" + ], + "torch.fx.experimental.sym_node": [ + "SymNode", + "method_to_operator", + "magic_methods", + "to_node", + "wrap_node", + "is_channels_last_contiguous_2d", + "is_channels_last_contiguous_3d", + "is_channels_last_strides_2d", + "is_channels_last_strides_3d", + "is_non_overlapping_and_dense_indicator", + "sympy_is_channels_last_contiguous_2d", + "sympy_is_channels_last_contiguous_3d", + "sympy_is_channels_last_strides_2d", + "sympy_is_channels_last_strides_3d", + "sympy_is_channels_last_strides_generic", + "is_contiguous", + "sympy_is_contiguous", + "sympy_is_contiguous_generic", + "sym_sqrt" + ], "torch.fx.experimental.symbolic_shapes": [ "Constraint", "ConstraintViolationError", @@ -1975,18 +2000,13 @@ "DynamicDimConstraintPrinter", "EqualityConstraint", "GuardOnDataDependentSymNode", - "IsNonOverlappingAndDenseIndicator", "LoggingShapeGuardPrinter", - "Pow", "RelaxedUnspecConstraint", "RuntimeAssert", "ShapeGuardPrinter", "StrictMinMaxConstraint", - "SymNode", - "TrueDiv", "bind_symbols", "cast_symbool_to_symint_guardless", - "ceil_impl", "constrain_range", "constrain_unify", "definitely_false", @@ -1996,33 +2016,16 @@ "eval_is_non_overlapping_and_dense", "expect_true", "find_symbol_binding_fx_nodes", - "floor_ceil_helper", - "floor_impl", "free_unbacked_symbols", "fx_placeholder_targets", "fx_placeholder_vals", "guard_bool", "has_hint", - "is_channels_last_contiguous_2d", - "is_channels_last_contiguous_3d", - "is_channels_last_strides_2d", - "is_channels_last_strides_3d", - "is_contiguous", - "is_non_overlapping_and_dense_indicator", "is_symbolic", "parallel_and", "parallel_or", "safe_expand", - "sym_sqrt", - "sympy_is_channels_last_contiguous_2d", - "sympy_is_channels_last_contiguous_3d", - "sympy_is_channels_last_strides_2d", - "sympy_is_channels_last_strides_3d", - "sympy_is_channels_last_strides_generic", - "sympy_is_contiguous", - "sympy_is_contiguous_generic", "tensor_has_hints", - "to_node", "uninteresting_files" ], "torch.fx.experimental.unification.match": [ diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 07180f8210dfa..c75ddb4061695 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -14,8 +14,9 @@ import torch.nn.functional as F from torch import sym_int, SymBool, SymFloat, SymInt from torch._C import _disabled_torch_function_impl -from torch.fx.experimental import symbolic_shapes +from torch.fx.experimental import sym_node from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.sym_node import to_node, sym_sqrt, SymNode from torch.fx.experimental.symbolic_shapes import ( DimConstraints, DimDynamic, @@ -25,10 +26,6 @@ guard_int, GuardOnDataDependentSymNode, ShapeEnv, - sym_float, - sym_sqrt, - SymNode, - to_node, is_symbolic, ) from torch.testing._internal.common_utils import ( @@ -334,7 +331,7 @@ def test_numel(self): def test_int_to_float(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) - r = sym_float(x.shape[0]) + r = torch.sym_float(x.shape[0]) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) def test_aten_ops(self): @@ -386,7 +383,7 @@ def test_sym_int(self): self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""") a3 = create_symint(shape_env, 3) - r = sym_int(2.0 * sym_float(a3)) + r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""") @@ -669,11 +666,11 @@ def maybe_xfail(inp1, inp2): else: return contextlib.nullcontext() - if fn in symbolic_shapes.magic_methods_on_math: + if fn in sym_node.magic_methods_on_math: lambda_apply = getattr(math, fn) - elif fn in symbolic_shapes.magic_methods_on_submodule: - lambda_apply = getattr(symbolic_shapes, fn) - elif fn in symbolic_shapes.magic_methods_on_operator_with_trailing_underscore: + elif fn in sym_node.magic_methods_on_submodule: + lambda_apply = getattr(sym_node, fn) + elif fn in sym_node.magic_methods_on_operator_with_trailing_underscore: lambda_apply = getattr(operator, f"{fn}_") else: lambda_apply = getattr(operator, fn) @@ -700,7 +697,7 @@ def guard_fn(v): out = lambda_apply(sym_inp1) else: out = lambda_apply(sym_inp1, inp2) - if fn not in symbolic_shapes.alternate_impl_if_hinted_methods: + if fn not in sym_node.alternate_impl_if_hinted_methods: self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) out = guard_fn(out) self.assertEqual(out, ref_out) @@ -712,7 +709,7 @@ def guard_fn(v): sym_inp2 = get_sym_inp(inp2) with maybe_xfail(inp1, sym_inp2): out = lambda_apply(inp1, sym_inp2) - if fn not in symbolic_shapes.alternate_impl_if_hinted_methods: + if fn not in sym_node.alternate_impl_if_hinted_methods: self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) out = guard_fn(out) self.assertEqual(out, ref_out) @@ -720,24 +717,24 @@ def guard_fn(v): # Symified both args with maybe_xfail(sym_inp1, sym_inp2): out = lambda_apply(sym_inp1, sym_inp2) - if fn not in symbolic_shapes.alternate_impl_if_hinted_methods: + if fn not in sym_node.alternate_impl_if_hinted_methods: self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) out = guard_fn(out) self.assertEqual(out, ref_out) - @parametrize("fn", list(symbolic_shapes.magic_methods.keys())) + @parametrize("fn", list(sym_node.magic_methods.keys())) def test_bool_method(self, fn): # sym_ite has its own tests - if fn not in symbolic_shapes.bool_magic_methods or fn == "sym_ite": + if fn not in sym_node.bool_magic_methods or fn == "sym_ite": self.skipTest(f"{fn} is non-bool") - is_unary_fn = fn in symbolic_shapes.unary_magic_methods + is_unary_fn = fn in sym_node.unary_magic_methods shape_env = ShapeEnv() self._do_test(fn, True, False, shape_env, is_unary_fn) - @parametrize("fn", list(symbolic_shapes.magic_methods.keys())) + @parametrize("fn", list(sym_node.magic_methods.keys())) @parametrize("first_type", ["int", "float"]) @parametrize("second_type", ["int", "float"]) def test_method(self, fn, first_type, second_type): @@ -745,12 +742,12 @@ def test_method(self, fn, first_type, second_type): # TODO: Hmm, this looks like we skip all floats self.skipTest(f"{fn} is not a float magic method") - is_unary_fn = fn in symbolic_shapes.unary_magic_methods + is_unary_fn = fn in sym_node.unary_magic_methods # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": self.skipTest(f"{fn} is unary and already tested") - if fn in symbolic_shapes.bool_magic_methods: + if fn in sym_node.bool_magic_methods: self.skipTest(f"{fn} is bool") # Only floats here since these will be converted to int if necessary. @@ -1080,7 +1077,6 @@ def is_complex(x): class TestDimConstraints(TestCase): def test_dim_constraints_reduce_congruences_simple(self): from sympy import Symbol - from torch.fx.experimental.symbolic_shapes import DimConstraints s = Symbol("s", positive=True, integer=True) dim_constraints = DimConstraints({}, {}, set(), {}) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 50d46af03bc8b..8ebd9d2c1e53b 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -11,7 +11,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode from torch._decomp import decomposition_table from torch.fx.experimental.symbolic_shapes import ( - sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, + eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, guard_int, GuardOnDataDependentSymNode ) from torch.testing._internal.custom_op_db import custom_op_db @@ -1408,7 +1408,7 @@ def test_elementwise_meta_with_sym_numbers(self): def f(x, offset, as_sym_float=False): x0 = x.size()[0] if as_sym_float: - x0 = sym_float(x0) + x0 = torch.sym_float(x0) return torch.add(x0, offset) fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) diff --git a/torch/__init__.py b/torch/__init__.py index e90a440afb130..1ad056f08fe9a 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -261,7 +261,7 @@ def __int__(self): def __index__(self): return self.node.int_() - # Magic methods installed by torch.fx.experimental.symbolic_shapes + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -313,7 +313,7 @@ def __init__(self, node): def __bool__(self): return self.node.bool_() - # Magic methods installed by torch.fx.experimental.symbolic_shapes + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -363,7 +363,7 @@ def __bool__(self): def __int__(self): return builtins.int(self.node.bool_()) - # Magic methods installed by torch.fx.experimental.symbolic_shapes + # Magic methods installed by torch.fx.experimental.sym_node def __and__(self, other) -> "SymBool": raise AssertionError("type stub not overridden") @@ -996,7 +996,8 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab if not isinstance(cond, (builtins.bool, torch.SymBool)): raise TypeError(f'cond must be a bool, but got {type(cond)}') - if torch.fx.experimental.symbolic_shapes.expect_true(cond): + from torch.fx.experimental.symbolic_shapes import expect_true + if expect_true(cond): return # error_type must be a subclass of Exception and not subclass of Warning @@ -1045,7 +1046,8 @@ def _check_is_size(i, message=None): """ # This is responsible for the expect_true _check(i >= 0, message) - torch.fx.experimental.symbolic_shapes._advise_is_size(i) + from torch.fx.experimental.symbolic_shapes import _advise_is_size + _advise_is_size(i) def _check_index(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition @@ -1799,7 +1801,7 @@ def _register_device_module(device_type, module): csan.enable_cuda_sanitizer() # Populate magic methods on SymInt and SymFloat -import torch.fx.experimental.symbolic_shapes +import torch.fx.experimental.sym_node from torch import func as func from torch.func import vmap diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index eee4e9e0c4ef9..1b4280ba563f5 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -657,7 +657,7 @@ def fn_with_prim_types(x): fn_ = self.value if any(isinstance(x, SymNodeVariable) for x in args): if self.value == math.sqrt: - from torch.fx.experimental.symbolic_shapes import sym_sqrt + from torch.fx.experimental.sym_node import sym_sqrt fn_ = sym_sqrt diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index a922c80af8da9..3d84f3aac7ca5 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,7 +1,7 @@ from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types +from torch.fx.experimental.sym_node import magic_methods, method_to_operator from torch.fx.experimental.symbolic_shapes import ( - hint_int, magic_methods, method_to_operator, free_symbols, - is_symbol_binding_fx_node, find_symbol_binding_fx_nodes + hint_int, free_symbols, is_symbol_binding_fx_node, find_symbol_binding_fx_nodes ) import torch import torch.fx as fx diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 9aca7dea4983e..0b52b604e58d7 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -17,13 +17,8 @@ from torch._decomp import get_decompositions from torch._dynamo.utils import dynamo_timed from torch._logging import LazyString -from torch.fx.experimental.symbolic_shapes import ( - free_symbols, - magic_methods, - method_to_operator, - ShapeEnv, - SymTypes, -) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import free_symbols, ShapeEnv, SymTypes from torch.utils._mode_utils import no_dispatch from . import config, ir diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index b569da72aba01..0541eb7357a0f 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -28,7 +28,7 @@ is_integer_dtype, Number, ) -from torch.fx.experimental.symbolic_shapes import magic_methods, method_to_operator +from torch.fx.experimental.sym_node import magic_methods, method_to_operator from torch.utils._pytree import tree_flatten from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from .._dynamo.utils import import_submodule diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 80a674ce83a53..bba1a894a7f82 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -29,6 +29,7 @@ _push_mode, ) +from .sym_node import SymNode from ._sym_dispatch_mode import SymDispatchMode from torch.fx import Proxy import torch.fx.traceback as fx_traceback @@ -80,8 +81,6 @@ def set_proxy_slot(obj, tracer, proxy): # on a tensor, and it affects the metadata on the proxy. tracer.tensor_tracker[obj] = proxy else: - # Avoid importing sympy at a module level - from .symbolic_shapes import SymNode # NB: Never clobber pre-existing proxy. Although the proxies # are in principle equivalent, when we do graph partitioning # we need there not to be spurious dependencies on tangent inputs. @@ -93,8 +92,6 @@ def set_proxy_slot(obj, tracer, proxy): tracer.symnode_tracker[obj] = proxy def has_proxy_slot(obj, tracer): - # Avoid importing sympy at a module level - from .symbolic_shapes import SymNode assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) return get_proxy_slot(obj, tracer, False, lambda _: True) @@ -105,8 +102,6 @@ def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): if isinstance(obj, torch.Tensor): tracker = tracer.tensor_tracker else: - # Avoid importing sympy at a module level - from .symbolic_shapes import SymNode assert isinstance(obj, SymNode), type(obj) tracker = tracer.symnode_tracker diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py new file mode 100644 index 0000000000000..b1281718eaab2 --- /dev/null +++ b/torch/fx/experimental/sym_node.py @@ -0,0 +1,1095 @@ +""" +This file does three things: +- Contains the definition of SymNode +- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time +- Does not depend on sympy at import time + +As this file is imported from within torch/__init__.py we do not want it to depend on SymPy +to avoid having to load SymPy at import time, as doing so is *very* slow. +""" + +import builtins +import itertools +import logging +import math +import operator +import sys +from functools import lru_cache +from typing import Optional, Type, TYPE_CHECKING, Union + +import torch + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import ( # noqa: F401 + sym_float, + sym_ite, + sym_max, + sym_min, + sym_not, + SymBool, + SymFloat, + SymInt, +) + +from torch.fx.experimental._sym_dispatch_mode import ( + handle_sym_dispatch, + sym_function_mode, +) + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +log = logging.getLogger(__name__) + + +__all__ = ["SymNode", "method_to_operator", "magic_methods", "sym_sqrt"] + + +SymTypes = (SymInt, SymFloat, SymBool) + + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class SymNode: + """ + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. + """ + + def __init__( + self, + expr, + shape_env, + pytype, + hint: Optional[Union[int, float]], + constant=None, + fx_node=None, + ): + self._expr = expr + self.shape_env = shape_env + self.pytype = pytype + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # hint_expr is only set if we don't have a hint. When it is set, it + # contains the expression which contains the unbacked symnodes that, + # if constrained, would allow this expression to be hinted again. + if hint is None: + self._hint_expr = self.expr.xreplace(shape_env.var_to_val) + self._hint = None + self._update_hint() # check if the replacement actually was enough + else: + self._hint_expr = None + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant + + from torch.fx.experimental.validator import translation_validation_enabled + + # Record the FX node of the current node if we are doing translation + # validation. They will be used for building the input assertions for + # the translation validation problem. + self.fx_node = fx_node if translation_validation_enabled() else None + + def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + @property + def expr(self): + return self.shape_env.replace(self._expr) + + # Check if we have replacements hint_expr that would allow us to + # simplify it into a hint + def _update_hint(self): + if self._hint_expr.free_symbols <= self.shape_env.replacements.keys(): + new_hint = self.shape_env.replace(self._hint_expr) + # NB: unification constraints could result in a replacement that + # doesn't actually solve the hint! Check for this. + if new_hint.free_symbols: + self._hint_expr = new_hint + return + self._hint = self.pytype(new_hint) + self._hint_expr = None + + @property + def hint(self): + if self._hint is None: + self._update_hint() + return self._hint + + def has_hint(self): + return self._hint is not None + + def require_hint(self): + if self._hint is None: + self._update_hint() + if self._hint is None: + raise self.shape_env._make_data_dependent_error( + self._hint_expr, self.expr + ) + else: + return self._hint + else: + return self._hint + + def maybe_as_int(self): + if self.expr.free_symbols: + return None + else: + return int(self.expr) + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def is_bool(self): + return self.pytype is bool + + def wrap_int(self, num): + assert type(num) is int + import sympy + + return SymNode( + sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num + ) + + def wrap_float(self, num): + assert type(num) is float + import sympy + + return SymNode( + sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num + ) + + def wrap_bool(self, num): + assert type(num) is bool + import sympy + + return SymNode( + sympy.true if num else sympy.false, + self.shape_env, + bool, + num, + constant=num, + fx_node=num, + ) + + def clone(self): + return self + + def str(self): + return f"{self.expr}" + + def __str__(self): + return self.str() + + def __repr__(self): + return self.str() + + # These methods call the metaprogrammed methods, they're hand written + # here so we get good stack traces + def abs(self) -> "SymNode": + return self._abs() # type: ignore[attr-defined] + + def add(self, other) -> "SymNode": + return self._add(other) # type: ignore[attr-defined] + + def sub(self, other) -> "SymNode": + return self._sub(other) # type: ignore[attr-defined] + + def mul(self, other) -> "SymNode": + return self._mul(other) # type: ignore[attr-defined] + + def mod(self, other) -> "SymNode": + return self._mod(other) # type: ignore[attr-defined] + + def pow(self, other) -> "SymNode": + return self._pow(other) # type: ignore[attr-defined] + + def and_(self, other) -> "SymNode": + return self._and_(other) # type: ignore[attr-defined] + + def or_(self, other) -> "SymNode": + return self._or_(other) # type: ignore[attr-defined] + + def truediv(self, other) -> "SymNode": + return self._truediv(other) # type: ignore[attr-defined] + + def floordiv(self, other) -> "SymNode": + return self._floordiv(other) # type: ignore[attr-defined] + + def lshift(self, other) -> "SymNode": + return self._lshift(other) # type: ignore[attr-defined] + + def rshift(self, other) -> "SymNode": + return self._rshift(other) # type: ignore[attr-defined] + + def sym_not(self) -> "SymNode": # noqa: F811 + return self._sym_not() # type: ignore[attr-defined] + + def eq(self, other) -> "SymNode": + return self._eq(other) # type: ignore[attr-defined] + + def ne(self, other) -> "SymNode": + return self._ne(other) # type: ignore[attr-defined] + + def gt(self, other) -> "SymNode": + return self._gt(other) # type: ignore[attr-defined] + + def lt(self, other) -> "SymNode": + return self._lt(other) # type: ignore[attr-defined] + + def le(self, other) -> "SymNode": + return self._le(other) # type: ignore[attr-defined] + + def ge(self, other) -> "SymNode": + return self._ge(other) # type: ignore[attr-defined] + + def floor(self) -> "SymNode": + return self._floor() # type: ignore[attr-defined] + + def sym_float(self) -> "SymNode": # noqa: F811 + return self._sym_float() # type: ignore[attr-defined] + + def sym_int(self) -> "SymNode": + return self._sym_int() # type: ignore[attr-defined] + + def ceil(self) -> "SymNode": + return self._ceil() # type: ignore[attr-defined] + + def neg(self) -> "SymNode": + return self._neg() # type: ignore[attr-defined] + + def sym_min(self, other) -> "SymNode": # noqa: F811 + return self._sym_min(other) # type: ignore[attr-defined] + + def sym_max(self, other) -> "SymNode": # noqa: F811 + return self._sym_max(other) # type: ignore[attr-defined] + + def sym_ite(self, then_val, else_val) -> "SymNode": + return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] + + def sym_sqrt(self) -> "SymNode": + return self._sym_sqrt() # type: ignore[attr-defined] + + def is_contiguous(self, sizes, strides) -> "SymNode": + return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] + + def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": + return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + + def is_non_overlapping_and_dense(self, sizes, strides): + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] + + def int_(self): + return self.guard_int("", 0) # NB: uses Python backtrace + + # You can manually trigger a guard with this function + def guard_int(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return int(r) + except Exception: + log.warning("Failed to convert to int: %s", r) + raise + + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return float(r) + except Exception: + log.warning("Failed to convert to float: %s", r) + raise + + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def expect_true(self, file, line): + if self.has_hint(): + # OK to generate guards + return self.guard_bool(file, line) + # Generate a deferred runtime assert (this might actually end up doing + # a regular guard if we can!) + # TODO: file/line here is very important, because the assert has been + # deferred so you can't backtrace easily + return self.shape_env.defer_runtime_assert( + self.expr, f"{file}:{line}", fx_node=self.fx_node + ) + + def expect_size(self, file, line): + from torch.fx.experimental.symbolic_shapes import _advise_is_size + + b = self.ge(self.wrap_int(0)) + # Generate a deferred runtime assert + r = b.expect_true(file, line) + # Refine compile time range, but only if it's unbacked. + # If you refine range for hinted variables, you can end up making + # improper deductions since compile time reasoning may be + # incompatible with runtime reasoning. + if r and not self.has_hint(): + _advise_is_size(SymInt(self)) + return r + + def bool_(self): + return self.guard_bool("", 0) + + def is_symbolic(self): + return True + + def singleton_int(self): + return None + + def is_constant(self): + return False + + +unary_magic_methods = { + "abs", + "sym_float", + "ceil", + "floor", + "neg", + "sym_sqrt", + "sym_not", +} + + +# Most methods are only registered on SymInt and SymFloat +# Some methods are only be registered on SymBool +only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} +# Methods that are also on SymBool, in addition to on SymInt and SymFloat +also_bool_magic_methods = {"eq"} +bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods + +magic_methods_on_math = {"ceil", "floor"} +magic_methods_on_submodule = { + "sym_float", + "sym_sqrt", + "sym_min", + "sym_max", + "sym_not", + "sym_ite", +} +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} + + +always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"} +always_int_magic_methods = {"ceil", "floor"} +always_bool_magic_methods = { + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + "and", + "or", + "sym_not", + "is_non_overlapping_and_dense", +} + +# Methods that have a `__foo__` as well as `__rfoo__` + + +def _sympy_truediv(a, b): + from torch.utils._sympy.functions import TrueDiv + + return TrueDiv(a, b) + + +def _sympy_floordiv(a, b): + from torch.utils._sympy.functions import FloorDiv + + return FloorDiv(a, b) + + +def _sympy_mod(a, b): + from torch.utils._sympy.functions import Mod + + return Mod(a, b) + + +def _sympy_pow(a, b): + from torch.utils._sympy.functions import Pow + + return Pow(a, b) + + +def _sympy_and(a, b): + import sympy + + return sympy.And(a, b) + + +def _sympy_or(a, b): + import sympy + + return sympy.Or(a, b) + + +def _sympy_lshift(a, b): + from torch.utils._sympy.functions import LShift + + return LShift(a, b) + + +def _sympy_rshift(a, b): + from torch.utils._sympy.functions import RShift + + return RShift(a, b) + + +reflectable_magic_methods = { + "add": lambda a, b: a + b, + "sub": lambda a, b: a - b, + "mul": lambda a, b: a * b, + "mod": _sympy_mod, + "pow": _sympy_pow, + "and": _sympy_and, + "or": _sympy_or, + "truediv": _sympy_truediv, + "floordiv": _sympy_floordiv, + "lshift": _sympy_lshift, + "rshift": _sympy_rshift, +} + + +def _floor_ceil_helper(a, fn): + import sympy + + if isinstance(a, sympy.Mul): + aa = a.args + if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: + coef = sympy.Integer(aa[0]) + if aa[0] == coef: # structural equality test + return coef * aa[1] + if ( + isinstance(a, sympy.Float) + and a == sympy.Integer(a) + or isinstance(a, sympy.Integer) + ): + return sympy.Integer(a) + return fn(a) + + +def _sympy_floor(a): + import sympy + + return _floor_ceil_helper(a, sympy.floor) + + +def _sympy_ceil(a): + import sympy + + return _floor_ceil_helper(a, sympy.ceiling) + + +def _sympy_eq(a, b): + import sympy + + return sympy.Eq(a, b) + + +def _sympy_ne(a, b): + import sympy + + return sympy.Ne(a, b) + + +def _sympy_gt(a, b): + import sympy + + return sympy.Gt(a, b) + + +def _sympy_lt(a, b): + import sympy + + return sympy.Lt(a, b) + + +def _sympy_le(a, b): + import sympy + + return sympy.Le(a, b) + + +def _sympy_ge(a, b): + import sympy + + return sympy.Ge(a, b) + + +def _sympy_min(a, b): + import sympy + + return sympy.Min(a, b) + + +def _sympy_max(a, b): + import sympy + + return sympy.Max(a, b) + + +def _sympy_ite(a, t, f): + import sympy + + return sympy.Piecewise((t, a), (f, True)) + + +def _sympy_sqrt(a): + import sympy + + return sympy.sqrt(a) + + +def _sympy_abs(a): + import sympy + + return sympy.Abs(a) + + +magic_methods = { + **reflectable_magic_methods, + "sym_not": lambda a: ~a, + "eq": _sympy_eq, + "ne": _sympy_ne, + "gt": _sympy_gt, + "lt": _sympy_lt, + "le": _sympy_le, + "ge": _sympy_ge, + "floor": _sympy_floor, + "sym_float": lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals + "ceil": _sympy_ceil, + "neg": lambda a: -a, + "sym_min": _sympy_min, + "sym_max": _sympy_max, + "sym_ite": _sympy_ite, + "sym_sqrt": _sympy_sqrt, + "abs": _sympy_abs, +} + + +# Drop in replacement for math.sqrt +def sym_sqrt(a): + if hasattr(a, "__sym_sqrt__"): + return a.__sym_sqrt__() + return math.sqrt(a) + + +def sympy_is_contiguous(sizes, strides): + dim = len(sizes) + return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) + + +def sympy_is_contiguous_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if len(dim_order) != dim: + return sympy.false + + is_contiguous = sympy.true + z = sympy.Integer(1) + # Contiguous if the strides make sense (or the dim is size 1) + for d in dim_order: + is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) + z *= sizes[d] + # OR if any size is zero + for d in range(dim): + is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) + return is_contiguous + + +# NB: There is a TODO in C++ to allow omitting the batch dim. If that +# happens you will need to refactor this + + +def sympy_is_channels_last_contiguous_2d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_contiguous_3d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if dim != len(dim_order): + return sympy.false + + m = sympy.Integer(0) + r = sympy.true + + # special case for trivial C dimension. default to NCHW + r &= sympy.Ne(strides[1], 0) + + for d in dim_order: + r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) + # Fallback to NCHW as default layout for ambiguous cases + # This is the flaw of implicit memory_format from strides. + # N111 tensor with identical strides for size 1 dimension; + # Two cases could lead us here: + # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + # b. N11W contiguous Tensor sliced on the W-dimension. + # ([N,1,1,1]@[W,W,W,W]) + if d == 0: + r &= sympy.Ne(m, strides[1]) + # This is necessary to: + # 1. distinguish the memory_format of N1H1; + # [H, 1, 1, 1] channels_last stride + # [H, H, 1, 1] contiguous stride + # 2. permutation of 1C1W: + # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as + # channels_last + m = strides[d] * sympy.Max(sizes[d], 1) + + return r + + +def sympy_is_channels_last_strides_2d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_strides_3d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): + from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator + + return IsNonOverlappingAndDenseIndicator(*sizes, *strides) + + +sizes_strides_methods = { + # TODO: These could also be done with indicators, maybe it is better + # for reasoning to do it that way + "is_contiguous": sympy_is_contiguous, + "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, + "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, + "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, + "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, + "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, +} + +alternate_impl_if_hinted_methods = { + "sym_min": builtins.min, + "sym_max": builtins.max, +} + + +def to_node(self, num): + if isinstance(num, SymTypes): + return num.node + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: + return self.wrap_int(num) + elif type(num) is float: + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + + +def wrap_node(x): + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: + return x.constant + if x.is_int(): + return SymInt(x) + elif x.is_float(): + return SymFloat(x) + elif x.is_bool(): + return SymBool(x) + else: + raise AssertionError(f"unrecognized return type {x}") + + +def method_to_operator(method): + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + if method in magic_methods_on_submodule: + op = getattr(torch.fx.experimental.sym_node, method_attr) + elif method in magic_methods_on_math: + op = getattr(math, method_attr) + else: + op = getattr(operator, method_attr) + return op + + +def _make_node_magic(method, func): + func = lru_cache(256)(func) + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def binary_magic_impl(self, other): + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = method_to_operator(method) + + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) + + alternate_impl = alternate_impl_if_hinted_methods.get(method) + if alternate_impl and out_hint is not None: + return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) + + if sym_function_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + ) + assert isinstance(other, SymNode) + # TODO: consider constant prop here + try: + out = func(self.expr, other.expr) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) + raise + out = safe_expand(out) + pytype: Type + # This is not strictly correct. In Python, a**b may return complex when + # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This + # returns a float while both arguments are ints: 2**(-1). Also, max and + # min do not type promote. To avoid having data-dependent control flow + # here, we just set the type to float if one of the args is a float. In + # case of a type mismatch, we assume that it will be detected during + # evaluation. + if method in always_float_magic_methods: + pytype = float + elif method in always_bool_magic_methods: + pytype = bool + elif self.pytype is float or other.pytype is float: + pytype = float + else: + pytype = self.pytype + + # Create a FX node that corresponds to the operation being applied to + # this node. + fx_node, _ = self.shape_env.create_fx_call_function( + op, (self.fx_node, other.fx_node) + ) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + def unary_magic_impl(self): + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = method_to_operator(method) + if sym_function_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) + # TODO: consider constant prop here + expr = self.expr + if method == "floor" or method == "ceiling": + expr = self.shape_env._simplify_floor_div(expr) + + try: + out = func(expr) + except Exception: + log.warning("failed to eval %s(%s)", method, expr) + raise + + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) + out = safe_expand(out) + pytype: Type + if method in always_int_magic_methods: + pytype = int + elif method in always_float_magic_methods: + pytype = float + else: + pytype = self.pytype + + fx_node, _ = self.shape_env.create_fx_call_function(op, (self.fx_node,)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + if method in unary_magic_methods: + setattr(SymNode, f"_{method_attr}", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.symbolic_shapes import safe_expand + + out_hint = then_node.hint if pred_node.hint else else_node.hint + if sym_function_mode(): + return to_node( + pred_node, + handle_sym_dispatch( + sym_ite, + ( + wrap_node(pred_node), + wrap_node(then_node), + wrap_node(else_node), + ), + {}, + ), + ) + + try: + out = func(pred_node.expr, then_node.expr, else_node.expr) + except Exception: + log.warning( + "failed to eval %s(%s, %s, %s)", + method, + pred_node.expr, + then_node.expr, + else_node.expr, + ) + raise + + out = safe_expand(out) + fx_node, _ = pred_node.shape_env.create_fx_call_function( + sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) + ) + return SymNode( + out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node + ) + + setattr(SymNode, f"_{method_attr}", sym_ite_impl) + else: + setattr(SymNode, f"_{method_attr}", binary_magic_impl) + + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + op = getattr(sys.modules[__name__], method) + if sym_function_mode(): + return to_node( + self, + handle_sym_dispatch( + op, + ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), + {}, + ), + ) + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(size_exprs, stride_exprs) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) + raise + # bool is never expandable + + size_hints = [] + out_hint = None + for s in sizes: + if s.hint is None: + break + size_hints.append(s.hint) + else: + stride_hints = [] + for s in strides: + if s.hint is None: + break + stride_hints.append(s.hint) + else: + out_hint = op(size_hints, stride_hints) + + # NB: This is the indicator function, not the actual bool! + pytype: Type + if method.endswith("_indicator"): + pytype = int + else: + pytype = bool + return SymNode(out, self.shape_env, pytype, out_hint) + + setattr(SymNode, f"_{method}", sizes_strides_impl) + + # TODO: This is technically hotpath, but in the ideal end state + # guards on this will resolve at a higher level so you never + # spend time in this code + def sizes_strides_user(sizes, strides): + import sympy + + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + for a in itertools.chain(sizes, strides): + if isinstance(a, SymInt): + return wrap_node( + getattr(a.node, method)( + [to_node(a.node, b) for b in sizes], + [to_node(a.node, b) for b in strides], + ) + ) + if method == "is_non_overlapping_and_dense_indicator": + return eval_is_non_overlapping_and_dense(sizes, strides) + else: + # TODO: this is an awful implementation + return bool( + func( + [sympy.sympify(a) for a in sizes], + [sympy.sympify(a) for a in strides], + ) + ) + + # Skip for is_non_overlapping_and_dense_indicator + if not hasattr(sys.modules[__name__], method): + setattr(sys.modules[__name__], method, sizes_strides_user) + + +for method, func in magic_methods.items(): + _make_node_magic(method, func) + +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): + if isinstance(x, (int, float, bool)): + return x + if isinstance(x, SymBool): + return x.node.guard_bool("", 0) + raise AssertionError("expect to be called with constant SymBools") + + def is_constant(x): + if isinstance(x, (int, float, bool)): + return True + if isinstance(x, (SymInt, SymFloat, SymBool)): + return x.node.is_constant() + return False + + # Before and after performing the operation, check if any operands are constant. + # If so, extract out the constant values first. If `self` itself is a + # constant, then "redispatch" by calling back into the operator. Sometimes + # this means that operations involving SymBool return plain bools. + # Alternatively, we could also rewrap into constant Symbool (i.e. by + # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that + # today for no particular reason. + def unary_magic_impl(self): + if is_constant(self): + return (method_to_operator(method))(get_constant(self)) + return wrap_node(getattr(self.node, method_attr)()) + + def binary_magic_impl(self, other): + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(self.node, method_attr)(other_node)) + return get_constant(ret) if is_constant(ret) else ret + + def rbinary_magic_impl(self, other): + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(other_node, method_attr)(self.node)) + return get_constant(ret) if is_constant(ret) else ret + + if method in unary_magic_methods: + setattr(user_type, f"__{method}__", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_magic_impl(pred, then_val, else_val): + pred_node = pred.node + then_node = to_node(pred_node, then_val) + else_node = to_node(pred_node, else_val) + if then_node is NotImplemented or else_node is NotImplemented: + return NotImplemented + assert ( + isinstance(then_node, SymNode) + and isinstance(else_node, SymNode) + and then_node.pytype == else_node.pytype + ) + ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) + return get_constant(ret) if ret.node.is_constant() else ret + + setattr(user_type, f"__{method}__", sym_ite_magic_impl) + else: + setattr(user_type, f"__{method}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattr(user_type, f"__r{method}__", rbinary_magic_impl) + + +for method, func in magic_methods.items(): # type: ignore[assignment] + if method in only_bool_magic_methods: + _make_user_magic(method, SymBool) + continue + if method in also_bool_magic_methods: + _make_user_magic(method, SymBool) + _make_user_magic(method, SymInt) + _make_user_magic(method, SymFloat) + +del method +del func diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 619381088ff4f..857e0a35ae3fc 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -28,22 +28,13 @@ replay_shape_env_events, shape_env_check_state_equal ) -from torch.fx.experimental._sym_dispatch_mode import handle_sym_dispatch, sym_function_mode +from torch.fx.experimental.sym_node import SymNode, SymTypes # NB: The sym_* functions are used via getattr() and must be imported here. -from torch import ( # noqa: F401 - sym_float, - sym_max, - sym_min, - sym_not, - sym_ite, - SymBool, - SymFloat, - SymInt, -) +from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift +from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -52,7 +43,6 @@ InputList = List DimList = List -SymTypes = (SymInt, SymFloat, SymBool) log = logging.getLogger(__name__) @@ -67,8 +57,8 @@ class GuardOnDataDependentSymNode(RuntimeError): __all__ = [ "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", - "guard_int", "guard_float", "guard_scalar", "wrap_node", - "method_to_operator", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", + "guard_int", "guard_float", "guard_scalar", + "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", ] @@ -298,6 +288,7 @@ def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compi else: shape_env.runtime_var_to_range[s] = ValueRanges(runtime_min, runtime_max) + def _advise_is_size(a): """ Don't use this directly; use torch._check_is_size instead. @@ -330,7 +321,12 @@ def _advise_is_size(a): # that hints have to be consistent with static analysis! If you somehow # have an unbounded SymInt that later constrains to 1, this will be # inconsistent with the range - if isinstance(a, SymInt) and isinstance(a.node, SymNode) and not a.node.has_hint() and isinstance(a.node.expr, sympy.Symbol): + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and not a.node.has_hint() + and isinstance(a.node.expr, sympy.Symbol) + ): _constrain_range_for_size(a) @record_shapeenv_event() @@ -531,26 +527,6 @@ def guard_float(a): assert isinstance(a, float), a return a -# Drop in replacement for math.sqrt -def sym_sqrt(a): - if hasattr(a, '__sym_sqrt__'): - return a.__sym_sqrt__() - return math.sqrt(a) - -def to_node(self, num): - if isinstance(num, SymTypes): - return num.node - elif type(num) is bool: - return self.wrap_bool(num) - elif type(num) is int: - return self.wrap_int(num) - elif type(num) is float: - return self.wrap_float(num) - else: - # NotImplemented is important so that Python tries the - # other magic method - return NotImplemented - # Given a GraphModule, return all the FakeTensors for all the placeholders def fx_placeholder_vals(gm): return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"] @@ -713,366 +689,11 @@ def render(self): def is_equal(self, source1, source2): return self._find(source1) == self._find(source2) - -# TODO: An incomplete list -# 1. Set variables to be equal when we do equality -# 2. Specialize on 0/1 when we do subtraction -class SymNode: - """ - This is a type erased SymInt/SymFloat which we use to do actual operations. - End users don't touch this. Magic methods are NOT defined on this object. - """ - def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None, fx_node=None): - self._expr = expr - self.shape_env = shape_env - self.pytype = pytype - # What's the difference between hint and constant? - # - # - A constant is known to be invariant across invocations of the model; - # it will always be this value. We only really know this when we - # encounter an honest-to-goodness literal (when wrapping it into - # a SymNode, we set constant.) Most of the time, constant is None - # - # - A hint is a *particular* value from the particular run we are - # tracing, but it may vary the next time around. It's useful to - # keep this around, as if we need a concrete value from a SymNode, - # we will return the hint and guard on the expression that produced - # it giving the same hint next time around. The hint is not - # guaranteed to be set either: if you have an unbacked SymNode, - # there won't be any hint; it was the result of some tensor-dependent - # computation, but we don't know what it actually is because we - # haven't actually run the tensor computation. - # - # hint_expr is only set if we don't have a hint. When it is set, it - # contains the expression which contains the unbacked symnodes that, - # if constrained, would allow this expression to be hinted again. - if hint is None: - self._hint_expr = self.expr.xreplace(shape_env.var_to_val) - self._hint = None - self._update_hint() # check if the replacement actually was enough - else: - self._hint_expr = None - self._hint = hint - self.constant: Optional[Union[int, float, bool]] = constant - # Record the FX node of the current node if we are doing translation - # validation. They will be used for building the input assertions for - # the translation validation problem. - self.fx_node = fx_node if _translation_validation_enabled() else None - - def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": - return SymNode(self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node) - - @property - def expr(self): - return self.shape_env.replace(self._expr) - - # Check if we have replacements hint_expr that would allow us to - # simplify it into a hint - def _update_hint(self): - if self._hint_expr.free_symbols <= self.shape_env.replacements.keys(): - new_hint = self.shape_env.replace(self._hint_expr) - # NB: unification constraints could result in a replacement that - # doesn't actually solve the hint! Check for this. - if new_hint.free_symbols: - self._hint_expr = new_hint - return - self._hint = self.pytype(new_hint) - self._hint_expr = None - - @property - def hint(self): - if self._hint is None: - self._update_hint() - return self._hint - - def has_hint(self): - return self._hint is not None - - def require_hint(self): - if self._hint is None: - self._update_hint() - if self._hint is None: - raise self.shape_env._make_data_dependent_error(self._hint_expr, self.expr) - else: - return self._hint - else: - return self._hint - - def maybe_as_int(self): - if self.expr.free_symbols: - return None - else: - return int(self.expr) - - def is_int(self): - return self.pytype is int - - def is_float(self): - return self.pytype is float - - def is_bool(self): - return self.pytype is bool - - def wrap_int(self, num): - assert type(num) is int - return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num) - - def wrap_float(self, num): - assert type(num) is float - return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num) - - def wrap_bool(self, num): - assert type(num) is bool - return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num, fx_node=num) - - def clone(self): - return self - - def str(self): - return f"{self.expr}" - - def __str__(self): - return self.str() - - def __repr__(self): - return self.str() - - # These methods call the metaprogrammed methods, they're hand written - # here so we get good stack traces - def abs(self) -> "SymNode": # noqa: F811 - return self._abs() # type: ignore[attr-defined] - - def add(self, other) -> "SymNode": # noqa: F811 - return self._add(other) # type: ignore[attr-defined] - - def sub(self, other) -> "SymNode": # noqa: F811 - return self._sub(other) # type: ignore[attr-defined] - - def mul(self, other) -> "SymNode": # noqa: F811 - return self._mul(other) # type: ignore[attr-defined] - - def mod(self, other) -> "SymNode": # noqa: F811 - return self._mod(other) # type: ignore[attr-defined] - - def pow(self, other) -> "SymNode": # noqa: F811 - return self._pow(other) # type: ignore[attr-defined] - - def and_(self, other) -> "SymNode": # noqa: F811 - return self._and_(other) # type: ignore[attr-defined] - - def or_(self, other) -> "SymNode": # noqa: F811 - return self._or_(other) # type: ignore[attr-defined] - - def truediv(self, other) -> "SymNode": # noqa: F811 - return self._truediv(other) # type: ignore[attr-defined] - - def floordiv(self, other) -> "SymNode": # noqa: F811 - return self._floordiv(other) # type: ignore[attr-defined] - - def lshift(self, other) -> "SymNode": # noqa: F811 - return self._lshift(other) # type: ignore[attr-defined] - - def rshift(self, other) -> "SymNode": # noqa: F811 - return self._rshift(other) # type: ignore[attr-defined] - - def sym_not(self) -> "SymNode": # noqa: F811 - return self._sym_not() # type: ignore[attr-defined] - - def eq(self, other) -> "SymNode": # noqa: F811 - return self._eq(other) # type: ignore[attr-defined] - - def ne(self, other) -> "SymNode": # noqa: F811 - return self._ne(other) # type: ignore[attr-defined] - - def gt(self, other) -> "SymNode": # noqa: F811 - return self._gt(other) # type: ignore[attr-defined] - - def lt(self, other) -> "SymNode": # noqa: F811 - return self._lt(other) # type: ignore[attr-defined] - - def le(self, other) -> "SymNode": # noqa: F811 - return self._le(other) # type: ignore[attr-defined] - - def ge(self, other) -> "SymNode": # noqa: F811 - return self._ge(other) # type: ignore[attr-defined] - - def floor(self) -> "SymNode": # noqa: F811 - return self._floor() # type: ignore[attr-defined] - - def sym_float(self) -> "SymNode": # noqa: F811 - return self._sym_float() # type: ignore[attr-defined] - - def sym_int(self) -> "SymNode": # noqa: F811 - return self._sym_int() # type: ignore[attr-defined] - - def ceil(self) -> "SymNode": # noqa: F811 - return self._ceil() # type: ignore[attr-defined] - - def neg(self) -> "SymNode": # noqa: F811 - return self._neg() # type: ignore[attr-defined] - - def sym_min(self, other) -> "SymNode": # noqa: F811 - return self._sym_min(other) # type: ignore[attr-defined] - - def sym_max(self, other) -> "SymNode": # noqa: F811 - return self._sym_max(other) # type: ignore[attr-defined] - - def sym_ite(self, then_val, else_val) -> "SymNode": - return self._sym_ite(then_val, else_val) - - def sym_sqrt(self) -> "SymNode": # noqa: F811 - return self._sym_sqrt() # type: ignore[attr-defined] - - def is_contiguous(self, sizes, strides) -> "SymNode": # noqa: F811 - return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] - - def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": # noqa: F811 - return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] - - def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": # noqa: F811 - return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] - - def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": # noqa: F811 - return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] - - def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": # noqa: F811 - return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] - - def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": # noqa: F811 - return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] - - # Make C++ happy - def sym_or(self, other): # noqa: F811 - return self.or_(other) - - def sym_and(self, other): # noqa: F811 - return self.and_(other) - - def is_non_overlapping_and_dense(self, sizes, strides): - return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] - - def int_(self): - return self.guard_int("", 0) # NB: uses Python backtrace - - # You can manually trigger a guard with this function - def guard_int(self, file, line): - # TODO: use the file/line for some useful diagnostic on why a - # guard occurred - r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) - try: - return int(r) - except Exception: - log.warning("Failed to convert to int: %s", r) - raise - - def guard_float(self, file, line): - # TODO: use the file/line for some useful diagnostic on why a - # guard occurred - r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) - try: - return float(r) - except Exception: - log.warning("Failed to convert to float: %s", r) - raise - - def guard_bool(self, file, line): - # TODO: use the file/line for some useful diagnostic on why a - # guard occurred - r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) - try: - return bool(r) - except Exception: - log.warning("Failed to convert to bool: %s", r) - raise - - def expect_true(self, file, line): - if self.has_hint(): - # OK to generate guards - return self.guard_bool(file, line) - # Generate a deferred runtime assert (this might actually end up doing - # a regular guard if we can!) - # TODO: file/line here is very important, because the assert has been - # deferred so you can't backtrace easily - return self.shape_env.defer_runtime_assert(self.expr, f"{file}:{line}", fx_node=self.fx_node) - - def expect_size(self, file, line): - b = self.ge(self.wrap_int(0)) - # Generate a deferred runtime assert - r = b.expect_true(file, line) - # Refine compile time range, but only if it's unbacked. - # If you refine range for hinted variables, you can end up making - # improper deductions since compile time reasoning may be - # incompatible with runtime reasoning. - if r and not self.has_hint(): - _advise_is_size(SymInt(self)) - return r - - def bool_(self): - return self.guard_bool("", 0) - - def is_symbolic(self): - return True - - def singleton_int(self): - return None - - def is_constant(self): - return False - def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: if isinstance(val, (int, float, bool)): return False return val.node.is_symbolic() -# Overloaded to be compatible with regular Python. -# https://github.com/pytorch/pytorch/issues/90900 -class Pow(sympy.Function): - @classmethod - def eval(cls, base, exp): - if exp.is_zero: - return sympy.Integer(1) - elif base.is_zero and exp < 0: - raise ZeroDivisionError(f"{base} cannot be raised to a negative power") - else: - return base ** exp - -# Overloaded to be compatible with regular Python. -# https://github.com/pytorch/pytorch/issues/90900 -class TrueDiv(sympy.Function): - @classmethod - def eval(cls, base, divisor): - if divisor.is_zero: - raise ZeroDivisionError("division by zero") - else: - return base / divisor - -# TODO: As an indicator, this != 0 implies == 1 (and vice versa). -# Because we do not have the ability to guard on the stride permutation -# at the moment, it is hard to make further inferences when this is true, -# as although we know the tensor is contiguous in *some* layout, we don't -# know which one (however, you could, for example, make the inference that -# reshaping this to a 1D tensor can be guard-free.) -class IsNonOverlappingAndDenseIndicator(sympy.Function): - is_integer = True - - @classmethod - def eval(cls, *args): - assert len(args) % 2 == 0 - dim = len(args) // 2 - # TODO: it is possible to make progress evaluating this guard - # even if not all of the inputs are known. For example, a 2D - # tensor with non-0/1 sizes but strides (0, 1) is definitely - # false, because we know its numel > 1 but it's broadcasted - # in dim 0. - if all(isinstance(a, sympy.Integer) for a in args): - size_args = args[0:dim] - stride_args = args[dim:] - return eval_is_non_overlapping_and_dense( - [int(a) for a in size_args], - [int(a) for a in stride_args] - ) - return None - IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) @lru_cache(256) @@ -1086,151 +707,10 @@ def safe_expand(r): else: return r -# Methods that have a `__foo__` as well as `__rfoo__` -reflectable_magic_methods = { - 'add': lambda a, b: a + b, - 'sub': lambda a, b: a - b, - 'mul': lambda a, b: a * b, - 'mod': lambda a, b: Mod(a, b), - 'pow': lambda a, b: Pow(a, b), - 'and': lambda a, b: sympy.And(a, b), - 'or': lambda a, b: sympy.Or(a, b), - 'truediv': lambda a, b: TrueDiv(a, b), - 'floordiv': lambda a, b: FloorDiv(a, b), - 'lshift': lambda a, b: LShift(a, b), - 'rshift': lambda a, b: RShift(a, b), -} - - def error(): raise AssertionError("shouldn't be hit") -def floor_ceil_helper(a, fn): - if isinstance(a, sympy.Mul): - aa = a.args - if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: - coef = sympy.Integer(aa[0]) - if aa[0] == coef: # structural equality test - return coef * aa[1] - if isinstance(a, sympy.Float) and a == sympy.Integer(a) or isinstance(a, sympy.Integer): - return sympy.Integer(a) - return fn(a) - -def floor_impl(a): - return floor_ceil_helper(a, sympy.floor) - -def ceil_impl(a): - return floor_ceil_helper(a, sympy.ceiling) - - -magic_methods = { - **reflectable_magic_methods, - 'sym_not': lambda a: ~a, - 'eq': lambda a, b: sympy.Eq(a, b), - 'ne': lambda a, b: sympy.Ne(a, b), - 'gt': lambda a, b: sympy.Gt(a, b), - 'lt': lambda a, b: sympy.Lt(a, b), - 'le': lambda a, b: sympy.Le(a, b), - 'ge': lambda a, b: sympy.Ge(a, b), - 'floor': floor_impl, - 'sym_float': lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals - 'ceil': ceil_impl, - 'neg': lambda a: -a, - 'sym_min': lambda a, b: sympy.Min(a, b), - 'sym_max': lambda a, b: sympy.Max(a, b), - 'sym_ite': lambda a, t, f: sympy.Piecewise((t, a), (f, True)), - 'sym_sqrt': lambda a: sympy.sqrt(a), - 'abs': lambda a: sympy.Abs(a), -} - -sizes_strides_methods = { - # TODO: These could also be done with indicators, maybe it is better - # for reasoning to do it that way - 'is_contiguous': lambda sizes, strides: sympy_is_contiguous(sizes, strides), - 'is_channels_last_contiguous_2d': lambda sizes, strides: sympy_is_channels_last_contiguous_2d(sizes, strides), - 'is_channels_last_contiguous_3d': lambda sizes, strides: sympy_is_channels_last_contiguous_3d(sizes, strides), - 'is_channels_last_strides_2d': lambda sizes, strides: sympy_is_channels_last_strides_2d(sizes, strides), - 'is_channels_last_strides_3d': lambda sizes, strides: sympy_is_channels_last_strides_3d(sizes, strides), - 'is_non_overlapping_and_dense_indicator': lambda sizes, strides: IsNonOverlappingAndDenseIndicator(*sizes, *strides), -} - -alternate_impl_if_hinted_methods = { - "sym_min": builtins.min, - "sym_max": builtins.max, -} - -def sympy_is_contiguous_generic(sizes, strides, dim_order): - dim = len(sizes) - - if len(dim_order) != dim: - return sympy.false - - is_contiguous = sympy.true - z = sympy.Integer(1) - # Contiguous if the strides make sense (or the dim is size 1) - for d in dim_order: - is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) - z *= sizes[d] - # OR if any size is zero - for d in range(dim): - is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) - return is_contiguous - -def sympy_is_contiguous(sizes, strides): - dim = len(sizes) - return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) - -# NB: There is a TODO in C++ to allow omitting the batch dim. If that -# happens you will need to refactor this - -def sympy_is_channels_last_contiguous_2d(sizes, strides): - return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) - -def sympy_is_channels_last_contiguous_3d(sizes, strides): - return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) - -def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): - dim = len(sizes) - - if dim != len(dim_order): - return sympy.false - - m = sympy.Integer(0) - r = sympy.true - - # special case for trivial C dimension. default to NCHW - r &= sympy.Ne(strides[1], 0) - - for d in dim_order: - r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) - # Fallback to NCHW as default layout for ambiguous cases - # This is the flaw of implicit memory_format from strides. - # N111 tensor with identical strides for size 1 dimension; - # Two cases could lead us here: - # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) - # b. N11W contiguous Tensor sliced on the W-dimension. - # ([N,1,1,1]@[W,W,W,W]) - if d == 0: - r &= sympy.Ne(m, strides[1]) - # This is necessary to: - # 1. distinguish the memory_format of N1H1; - # [H, 1, 1, 1] channels_last stride - # [H, H, 1, 1] contiguous stride - # 2. permutation of 1C1W: - # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) - # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as - # channels_last - m = strides[d] * sympy.Max(sizes[d], 1) - - return r - -def sympy_is_channels_last_strides_2d(sizes, strides): - return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) - -def sympy_is_channels_last_strides_3d(sizes, strides): - return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) - # TODO: Deduplicate this with torch/_prims_common/__init__.py def eval_is_non_overlapping_and_dense(sizes, strides): return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) @@ -1265,39 +745,6 @@ def _eval_is_non_overlapping_and_dense(sizes, strides): return True -unary_magic_methods = { - 'abs', - 'sym_float', - 'ceil', - 'floor', - 'neg', - 'sym_sqrt', - 'sym_not', -} - -# Most methods are only registered on SymInt and SymFloat -# Some methods are only be registered on SymBool -only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} -# Methods that are also on SymBool, in addition to on SymInt and SymFloat -also_bool_magic_methods = {"eq"} -bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods - -magic_methods_on_math = {"ceil", "floor"} -magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not", "sym_ite"} -magic_methods_on_operator_with_trailing_underscore = {"and", "or"} - -def method_to_operator(method): - if method in magic_methods_on_operator_with_trailing_underscore: - method_attr = f"{method}_" - else: - method_attr = method - if method in magic_methods_on_submodule: - op = getattr(torch.fx.experimental.symbolic_shapes, method_attr) - elif method in magic_methods_on_math: - op = getattr(math, method_attr) - else: - op = getattr(operator, method_attr) - return op def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True)) @@ -1322,306 +769,6 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, } -always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"} -always_int_magic_methods = {"ceil", "floor"} -always_bool_magic_methods = {"eq", "ne", "gt", "lt", "le", "ge", "and", "or", "sym_not", "is_non_overlapping_and_dense"} - -def wrap_node(x): - # TODO: let C++ also take advantage of this - if isinstance(x, SymNode) and x.constant is not None: - return x.constant - if x.is_int(): - return SymInt(x) - elif x.is_float(): - return SymFloat(x) - elif x.is_bool(): - return SymBool(x) - else: - raise AssertionError(f"unrecognized return type {x}") - -def _make_node_magic(method, func): - func = lru_cache(256)(func) - - if method in magic_methods_on_operator_with_trailing_underscore: - method_attr = f"{method}_" - else: - method_attr = method - - def binary_magic_impl(self, other): - op = method_to_operator(method) - - out_hint = None - if self.hint is not None and other.hint is not None: - out_hint = op(self.hint, other.hint) - - alternate_impl = alternate_impl_if_hinted_methods.get(method) - if alternate_impl and out_hint is not None: - return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) - - if sym_function_mode(): - return to_node(self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})) - assert isinstance(other, SymNode) - # TODO: consider constant prop here - try: - out = func(self.expr, other.expr) - except Exception: - log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) - raise - out = safe_expand(out) - pytype: Type - # This is not strictly correct. In Python, a**b may return complex when - # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This - # returns a float while both arguments are ints: 2**(-1). Also, max and - # min do not type promote. To avoid having data-dependent control flow - # here, we just set the type to float if one of the args is a float. In - # case of a type mismatch, we assume that it will be detected during - # evaluation. - if method in always_float_magic_methods: - pytype = float - elif method in always_bool_magic_methods: - pytype = bool - elif self.pytype is float or other.pytype is float: - pytype = float - else: - pytype = self.pytype - - # Create a FX node that corresponds to the operation being applied to - # this node. - fx_node, _ = self.shape_env.create_fx_call_function(op, (self.fx_node, other.fx_node)) - return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) - - def unary_magic_impl(self): - op = method_to_operator(method) - if sym_function_mode(): - return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) - # TODO: consider constant prop here - expr = self.expr - if method == "floor" or method == "ceiling": - expr = self.shape_env._simplify_floor_div(expr) - - try: - out = func(expr) - except Exception: - log.warning("failed to eval %s(%s)", method, expr) - raise - - out_hint = None - if self.hint is not None: - out_hint = op(self.hint) - out = safe_expand(out) - pytype: Type - if method in always_int_magic_methods: - pytype = int - elif method in always_float_magic_methods: - pytype = float - else: - pytype = self.pytype - - fx_node, _ = self.shape_env.create_fx_call_function(op, (self.fx_node,)) - return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) - - if method in unary_magic_methods: - setattr(SymNode, f"_{method_attr}", unary_magic_impl) - elif method == "sym_ite": - - def sym_ite_impl(pred_node, then_node, else_node): - out_hint = then_node.hint if pred_node.hint else else_node.hint - if sym_function_mode(): - return to_node( - pred_node, - handle_sym_dispatch( - sym_ite, - (wrap_node(pred_node), wrap_node(then_node), wrap_node(else_node)), {} - ) - ) - - try: - out = func(pred_node.expr, then_node.expr, else_node.expr) - except Exception: - log.warning("failed to eval %s(%s, %s, %s)", method, pred_node.expr, then_node.expr, else_node.expr) - raise - - out = safe_expand(out) - fx_node, _ = pred_node.shape_env.create_fx_call_function( - sym_ite, - ( - pred_node.fx_node, - then_node.fx_node, - else_node.fx_node - ) - ) - return SymNode(out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node) - setattr(SymNode, f"_{method_attr}", sym_ite_impl) - else: - setattr(SymNode, f"_{method_attr}", binary_magic_impl) - -def _make_node_sizes_strides(method, func): - # NB: don't LRU cache, lots of arguments - - def sizes_strides_impl(self, sizes, strides): - op = getattr(sys.modules[__name__], method) - if sym_function_mode(): - return to_node( - self, - handle_sym_dispatch( - op, - ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), - {} - ) - ) - size_exprs = [s.expr for s in sizes] - stride_exprs = [s.expr for s in strides] - try: - out = func(size_exprs, stride_exprs) - except Exception: - log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) - raise - # bool is never expandable - - size_hints = [] - out_hint = None - for s in sizes: - if s.hint is None: - break - size_hints.append(s.hint) - else: - stride_hints = [] - for s in strides: - if s.hint is None: - break - stride_hints.append(s.hint) - else: - out_hint = op(size_hints, stride_hints) - - # NB: This is the indicator function, not the actual bool! - pytype: Type - if method.endswith("_indicator"): - pytype = int - else: - pytype = bool - return SymNode(out, self.shape_env, pytype, out_hint) - - setattr(SymNode, f"_{method}", sizes_strides_impl) - - # TODO: This is technically hotpath, but in the ideal end state - # guards on this will resolve at a higher level so you never - # spend time in this code - def sizes_strides_user(sizes, strides): - for a in itertools.chain(sizes, strides): - if isinstance(a, SymInt): - return wrap_node(getattr(a.node, method)( - [to_node(a.node, b) for b in sizes], - [to_node(a.node, b) for b in strides], - )) - if method == "is_non_overlapping_and_dense_indicator": - return eval_is_non_overlapping_and_dense(sizes, strides) - else: - # TODO: this is an awful implementation - return bool(func( - [sympy.sympify(a) for a in sizes], - [sympy.sympify(a) for a in strides], - )) - - # Skip for is_non_overlapping_and_dense_indicator - if not hasattr(sys.modules[__name__], method): - setattr(sys.modules[__name__], method, sizes_strides_user) - -for method, func in magic_methods.items(): - _make_node_magic(method, func) - -for method, func in sizes_strides_methods.items(): - _make_node_sizes_strides(method, func) - -def _make_user_magic(method, user_type): - # User magic takes care of wrapping the other operand into a node, - # so that our internal logic can assume everything is nodes - - if method in magic_methods_on_operator_with_trailing_underscore: - method_attr = f"{method}_" - else: - method_attr = method - - def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): - if isinstance(x, (int, float, bool)): - return x - if isinstance(x, SymBool): - return x.node.guard_bool("", 0) - raise AssertionError("expect to be called with constant SymBools") - - def is_constant(x): - if isinstance(x, (int, float, bool)): - return True - if isinstance(x, (SymInt, SymFloat, SymBool)): - return x.node.is_constant() - return False - - # Before and after performing the operation, check if any operands are constant. - # If so, extract out the constant values first. If `self` itself is a - # constant, then "redispatch" by calling back into the operator. Sometimes - # this means that operations involving SymBool return plain bools. - # Alternatively, we could also rewrap into constant Symbool (i.e. by - # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that - # today for no particular reason. - def unary_magic_impl(self): - if is_constant(self): - return (method_to_operator(method))(get_constant(self)) - return wrap_node(getattr(self.node, method_attr)()) - - def binary_magic_impl(self, other): - if is_constant(self): - return (method_to_operator(method))(get_constant(self), other) - if is_constant(other): - other = get_constant(other) - other_node = to_node(self.node, other) - if other_node is NotImplemented: - return NotImplemented - ret = wrap_node(getattr(self.node, method_attr)(other_node)) - return get_constant(ret) if is_constant(ret) else ret - - def rbinary_magic_impl(self, other): - if is_constant(self): - return (method_to_operator(method))(get_constant(self), other) - if is_constant(other): - other = get_constant(other) - other_node = to_node(self.node, other) - if other_node is NotImplemented: - return NotImplemented - ret = wrap_node(getattr(other_node, method_attr)(self.node)) - return get_constant(ret) if is_constant(ret) else ret - - if method in unary_magic_methods: - setattr(user_type, f"__{method}__", unary_magic_impl) - elif method == "sym_ite": - - def sym_ite_magic_impl(pred, then_val, else_val): - pred_node = pred.node - then_node = to_node(pred_node, then_val) - else_node = to_node(pred_node, else_val) - if then_node is NotImplemented or else_node is NotImplemented: - return NotImplemented - assert isinstance(then_node, SymNode) and isinstance(else_node, SymNode) and then_node.pytype == else_node.pytype - ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) - return get_constant(ret) if ret.node.is_constant() else ret - - setattr(user_type, f"__{method}__", sym_ite_magic_impl) - else: - setattr(user_type, f"__{method}__", binary_magic_impl) - if method in reflectable_magic_methods: - setattr(user_type, f"__r{method}__", rbinary_magic_impl) - -for method, func in magic_methods.items(): - if method in only_bool_magic_methods: - _make_user_magic(method, SymBool) - continue - if method in also_bool_magic_methods: - _make_user_magic(method, SymBool) - _make_user_magic(method, SymInt) - _make_user_magic(method, SymFloat) - -del method -del func - - def _translation_validation_enabled() -> bool: from torch.fx.experimental.validator import translation_validation_enabled return translation_validation_enabled() diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 8ca34aaa4a089..c5b14082c42b4 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -224,7 +224,7 @@ def abs(self, number: z3.ArithRef) -> z3.ArithRef: # 2. Calls an operation that corresponds to 'op', but works with Z3 # inhabitants (left as is if it works as is) def z3op(op: Callable, validator: "TranslationValidator") -> Callable: - from torch.fx.experimental.symbolic_shapes import sym_sqrt + from torch.fx.experimental.sym_node import sym_sqrt # Operations that have booleans as their argument. # This is needed because the argument of some FX nodes were diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index dfc4b92449781..b9088ac65de62 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -2,7 +2,10 @@ from sympy import S from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or -__all__ = ["FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", "LShift", "RShift"] +__all__ = [ + "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", "Pow", "TrueDiv", + "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", +] class FloorDiv(sympy.Function): @@ -252,3 +255,56 @@ def eval(cls, base, shift): if shift < 0: raise ValueError('negative shift count') return base // 2 ** shift + +# Overloaded to be compatible with regular Python. +# https://github.com/pytorch/pytorch/issues/90900 +class Pow(sympy.Function): + @classmethod + def eval(cls, base, exp): + if exp.is_zero: + return sympy.Integer(1) + elif base.is_zero and exp < 0: + raise ZeroDivisionError(f"{base} cannot be raised to a negative power") + else: + return base ** exp + +# Overloaded to be compatible with regular Python. +# https://github.com/pytorch/pytorch/issues/90900 +class TrueDiv(sympy.Function): + @classmethod + def eval(cls, base, divisor): + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + else: + return base / divisor + + +# TODO: As an indicator, this != 0 implies == 1 (and vice versa). +# Because we do not have the ability to guard on the stride permutation +# at the moment, it is hard to make further inferences when this is true, +# as although we know the tensor is contiguous in *some* layout, we don't +# know which one (however, you could, for example, make the inference that +# reshaping this to a 1D tensor can be guard-free.) +class IsNonOverlappingAndDenseIndicator(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, *args): + assert len(args) % 2 == 0 + dim = len(args) // 2 + # TODO: it is possible to make progress evaluating this guard + # even if not all of the inputs are known. For example, a 2D + # tensor with non-0/1 sizes but strides (0, 1) is definitely + # false, because we know its numel > 1 but it's broadcasted + # in dim 0. + if all(isinstance(a, sympy.Integer) for a in args): + # sym_node imported in torch.__init__. Local import to avoid an import cycle + from torch.fx.experimental.symbolic_shapes import eval_is_non_overlapping_and_dense + + size_args = args[0:dim] + stride_args = args[dim:] + return eval_is_non_overlapping_and_dense( + [int(a) for a in size_args], + [int(a) for a in stride_args] + ) + return None diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 742737be9ec32..ac05defe9e4e1 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -14,7 +14,7 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom import torch -from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing, Where +from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing, Pow, TrueDiv, Where # TODO: Dedupe this with SYMPY_INTERP @@ -22,8 +22,6 @@ @functools.lru_cache(None) def handlers(): - from torch.fx.experimental.symbolic_shapes import Pow, TrueDiv - # TODO add CeilDiv (it doesn't appear in the index_expr) # TODO default to some decompositions if the interpreter doesn't have them From acd02a60d511cf70aa60b950d6f7541e7b51a785 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 26 Oct 2023 20:00:33 +0000 Subject: [PATCH 20/78] Add a test making sure we are not importing SymPy when importing torch (#112038) As per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/112038 Approved by: https://github.com/malfet, https://github.com/peterbell10 ghstack dependencies: #112035, #112036, #112037 --- test/test_testing.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_testing.py b/test/test_testing.py index 7aebaa78e4111..1c668f8476c61 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2240,6 +2240,17 @@ def test_no_warning_on_import(self) -> None: out = self._check_python_output("import torch") self.assertEqual(out, "") + def test_not_import_sympy(self) -> None: + out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)") + self.assertEqual(out.strip(), "True", + "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n" + "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n" + "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n" + "If you hit this error, you may want to:\n" + " - Refactor your code to avoid depending on sympy files you may not need to depend\n" + " - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n" + " - Import things that depend on SymPy locally") + @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") @parametrize('path', ['torch', 'functorch']) def test_no_mutate_global_logging_on_import(self, path) -> None: From 5b7183478508f03bccc0a662614d77ebaf5fedf7 Mon Sep 17 00:00:00 2001 From: Dino Viehland Date: Thu, 26 Oct 2023 23:55:34 +0000 Subject: [PATCH 21/78] Avoid c++ exception and stack trace (#111438) Summary: When raising an exception here this causes pybind11's dispatcher to kick in, which causes aiplatform's logic to kick in (aiplatform::error_reporting::util::printAddressesWithBestEffortLocationInfo), which ultimately uses `folly::symbolizer::Symbolizer::symbolize` for building up the stack trace. In 3.8 this uses about 3.62% of the CPU time per pyperf (https://fburl.com/scuba/pyperf_experimental/on_demand/oi554uvy). In Cinder 3.8 for some reason this is worse - using 5.94% of the CPU. This exception is happening when doing a hasattr() on `prims` for things like `bitwise_left_shift` which don't exist: https://www.internalfb.com/code/fbsource/[2d695f650d00]/fbcode/caffe2/torch/_inductor/lowering.py?lines=590 That exception is ultimately going to be swallowed anyway, and the stack trace has no meaningful value. Furthermore because this is kind of an expected outcome in the code versus some random C++ exception the stack trace is less valuable as well. This changes this to return a (None, None) on the failure case instead of returning a valid op/overload list, avoiding the exception, and reclaiming the 3.62%-5.94% of time. Test Plan: Existing CI and perf run: https://fburl.com/scuba/pyperf_experimental/on_demand/oi554uvy Differential Revision: D50018789 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111438 Approved by: https://github.com/davidberard98 --- torch/_ops.py | 4 ++++ torch/csrc/jit/python/init.cpp | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/_ops.py b/torch/_ops.py index 5be632959510e..53ef91a18363b 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -823,6 +823,10 @@ def __getattr__(self, op_name): qualified_op_name = f"{namespace_name}::{op_name}" try: op, overload_names = torch._C._jit_get_operation(qualified_op_name) + if op is None: + raise AttributeError( + f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" + ) except RuntimeError as e: # Turn this into AttributeError so getattr(obj, key, default) # works (this is called by TorchScript with __origin__) diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index c4d5d2e4e2863..21e09d78a0aee 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1628,7 +1628,10 @@ void initJITBindings(PyObject* module) { try { auto symbol = Symbol::fromQualString(op_name); const auto& unsortedOps = getAllOperatorsFor(symbol); - TORCH_CHECK(!unsortedOps.empty(), "No such operator ", op_name); + if (unsortedOps.empty()) { + // No such operator + return py::make_tuple(py::none(), py::none()); + } // Depending on the order of registration, aten or jit ops may be // registered first. This sorting is helpful in cases where From 1569df7f01058cf0b6a263a111e2f237ffa235b6 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 27 Oct 2023 00:13:55 +0000 Subject: [PATCH 22/78] Don't search getitem for batch fusions (#112088) Batch mm fusion regressed optimizer compile time by about ~1m, excluding getitem solves this problem. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112088 Approved by: https://github.com/yanboliang --- torch/_inductor/fx_passes/group_batch_fusion.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index eaa2cc29633c4..5b5123c1370bf 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -37,6 +37,10 @@ # The maximum tensor size that can go into the fusion group MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# exclude these nodes from BFS +# excluding get item improves optimizer compilation time by 60s +SEARCH_EXCLUSIONS = {operator.getitem} + class GroupBatchFusionBase: def match(self, node): @@ -582,6 +586,10 @@ def get_fusion_candidates( candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict( list ) + + if root_node.target in SEARCH_EXCLUSIONS: + return candidate_dict + visited_set: Set[torch.fx.Node] = set() for next_node in root_node.all_input_nodes: From 22221c6d60613e498aa67b7f7f0f83ec97e35b8a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 27 Oct 2023 00:23:03 +0000 Subject: [PATCH 23/78] Revert "Trigger specialization when you call size()/stride() from C++ (#111935)" This reverts commit 5846705e36795d76941e18073e49c6edba90c994. Reverted https://github.com/pytorch/pytorch/pull/111935 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111935#issuecomment-1782107024)) --- c10/core/TensorImpl.cpp | 8 ++------ test/functorch/test_aotdispatch.py | 6 ++++++ test/test_proxy_tensor.py | 14 ++++---------- torch/_subclasses/fake_tensor.py | 6 ------ torch/csrc/utils/python_dispatch.cpp | 4 ---- torch/fx/experimental/proxy_tensor.py | 4 +--- 6 files changed, 13 insertions(+), 29 deletions(-) diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index f40be43379db8..4310937a808f5 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -635,9 +635,7 @@ bool TensorImpl::is_non_overlapping_and_dense_custom() const { } IntArrayRef TensorImpl::sizes_custom() const { - if (C10_UNLIKELY( - matches_python_custom(SizesStridesPolicy::CustomSizes) || - has_symbolic_sizes_strides_)) { + if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { return pyobj_slot_.load_pyobj_interpreter()->sizes(this); } return sizes_default(); @@ -672,9 +670,7 @@ c10::Device TensorImpl::device_custom() const { } IntArrayRef TensorImpl::strides_custom() const { - if (C10_UNLIKELY( - matches_python_custom(SizesStridesPolicy::CustomStrides) || - has_symbolic_sizes_strides_)) { + if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { return pyobj_slot_.load_pyobj_interpreter()->strides(this); } return strides_default(); diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 949d5843a32b1..6fdbb5972d4d7 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -3482,13 +3482,17 @@ def forward(self, x): xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('combinations', ''), # aten.masked_select.default xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition + xfail('gradient', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition xfail('linalg.lstsq', ''), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition xfail('linalg.lstsq', 'grad_oriented'), # aten.linalg_lstsq.default - couldn't find symbolic meta funct... xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco... + xfail('linalg.multi_dot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for >(data)); }); - m.def("_non_sym_sizes", [](const at::Tensor& a) { - return a.sizes(); // NB: NOT sym_size - }); - using c10::impl::TorchDispatchModeKey; py::enum_(m, "_TorchDispatchModeKey") .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index bba1a894a7f82..cdf7982e8427b 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -275,9 +275,7 @@ def can_handle_tensor(x): return r # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. - if not pre_dispatch and func not in [ - torch.ops.aten.size.default, torch.ops.aten.stride.default, torch.ops.aten.storage_offset.default - ]: + if not pre_dispatch: with proxy_mode: r = func.decompose(*args, **kwargs) if r is not NotImplemented: From ac4cc5dbea51a794ccd7ffcdab39a6671f625f63 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 27 Oct 2023 00:39:28 +0000 Subject: [PATCH 24/78] [Dynamo] Do not crash if numpy is not installed (#112175) `s/isinstance(value, np.generic)/np is not None and isinstance(value, np.generic)/` Found while looking at https://github.com/pytorch/pytorch/pull/110512 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112175 Approved by: https://github.com/ev-br, https://github.com/kit1980 --- torch/_dynamo/variables/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 64aadc9157a46..3ac8c2b970bb0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -517,7 +517,7 @@ def index_source(key): source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) - elif isinstance(value, np.generic): + elif np is not None and isinstance(value, np.generic): # numpy array scalars: convert to 0D arrays return self.wrap_numpy_ndarray(np.asarray(value)) elif is_numpy(value): From 797d7100de3dd22db097225bd1e700ce4f047158 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 27 Oct 2023 01:35:27 +0000 Subject: [PATCH 25/78] Revert "[quant][pt2e][be] Cleanup observer insertion logic (#111828)" This reverts commit bf998a2c5d549cf4856c7becfca4a169bf68b709. Reverted https://github.com/pytorch/pytorch/pull/111828 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111828#issuecomment-1782154648)) --- .../pt2e/test_xnnpack_quantizer.py | 16 +-- torch/ao/quantization/pt2e/prepare.py | 118 ++++++++++-------- torch/ao/quantization/quantizer/quantizer.py | 4 +- .../quantizer/xnnpack_quantizer.py | 7 +- 4 files changed, 81 insertions(+), 64 deletions(-) diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 2bab0aa69052d..b1720d20ffa56 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -365,16 +365,12 @@ def test_propagate_annotation(self): m = prepare_pt2e(m, quantizer) m(*example_inputs) - act_post_processes_pairs = [] - for n in m.graph.nodes: - if n.target in [ - torch.ops.aten.view.default, - torch.ops.aten.hardtanh.default, - ]: - input_act = getattr(m, n.args[0].target) - output_act = getattr(m, list(n.users)[0].target) - self.assertEqual(id(input_act), id(output_act)) - + self.assertEqual( + id(m.activation_post_process_2), id(m.activation_post_process_3) + ) + self.assertEqual( + id(m.activation_post_process_3), id(m.activation_post_process_4) + ) m = convert_pt2e(m, fold_quantize=True) node_occurrence = { # input and output are using quantize_per_tensor and weight is using quantize_per_channel diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index d078d64baa752..69be6a1f15fa5 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -1,6 +1,9 @@ import torch from torch._subclasses import FakeTensor from torch.ao.quantization.fx.prepare import ( + _get_arg_as_input_act_obs_or_fq, + _get_output_act_obs_or_fq, + _get_dtype_and_is_dynamic, _insert_obs_or_fq, _save_state, _is_activation_post_process_node, @@ -18,6 +21,7 @@ from torch.ao.quantization.fx.custom_config import PrepareCustomConfig from typing import Dict, Tuple, Union, Any, Optional from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, EdgeOrNode, SharedQuantizationSpec, QuantizationSpecBase, @@ -256,56 +260,70 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # default (no observer) new_arg = arg - # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes - original_arg = arg - while _is_activation_post_process_node(original_arg, named_modules): - original_arg = original_arg.args[0] # type: ignore[assignment] - assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}" - - input_edge = (original_arg, node) - if input_edge not in obs_or_fq_map: - return new_arg - # input_edge needs to be observed - input_edge_obs_or_fq = obs_or_fq_map[input_edge] - if input_edge_obs_or_fq is None: - return new_arg - - arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None) - # the arg is observed as the output and is using the same instance as the input_edge - # we'll reuse the inserted observer/fake_quant - if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq): - return new_arg - - # otherwise, we'll insert a new observer/fake_quant node - - existing_obs_node = None - # skip inserting new observers if there is an observer inserted for the arg before - # that has the same dtype that we want to insert here - # alternatively we could have a dedup pass after we insert all observers to deduplicate - # observers - # Example: - # conv1 -> obs1 -> existing_obs -> conv2 - # \ -> conv3 - # - # instead of inserting new observers we will have: - # conv1 -> obs1 -> existing_obs -> conv2 - # \ -> conv3 - for maybe_obs_node in arg.users.keys(): - if not _is_activation_post_process_node(maybe_obs_node, named_modules): - continue - maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] - if ( - type(maybe_obs_mod) == type(input_edge_obs_or_fq) and - maybe_obs_mod.dtype == input_edge_obs_or_fq.dtype - ): - input_edge_obs_or_fq = maybe_obs_mod # type: ignore[assignment] - existing_obs_node = maybe_obs_node - break - - if existing_obs_node is None: - new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph) - else: - new_arg = existing_obs_node + quantization_annotation = node.meta.get("quantization_annotation", QuantizationAnnotation()) + arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat) + arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq) + + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat) + arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) + + if arg_as_input_target_is_dynamic or arg_as_input_target_dtype not in [torch.float, None]: + if arg_as_input_target_dtype == arg_as_output_target_dtype and \ + arg_as_input_target_is_dynamic == arg_as_output_target_is_dynamic: + assert _is_activation_post_process_node(arg, named_modules) + assert arg_as_input_act_obs_or_fq is not None + observed_arg = arg.args[0] + assert isinstance(observed_arg, Node), f"expect observed argument to be a Node, but got: {type(observed_arg)}" + assert observed_arg in obs_or_fq_map, \ + f"can't find a sharing group for node: {observed_arg}" + # reuse the existing obs/fq + arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg] + # we don't need to insert new observer node + new_arg = arg + else: + # skip inserting new observers if there is an observer inserted for the arg before + # that has the same dtype that we want to insert here + # alternatively we could have a dedup pass after we insert all observers to deduplicate + # observers + # Example: + # arg -> existing_obs -> conv1 + # \ -> conv2 + # + # instead of inserting new observers we will have: + # arg -> existing_obs -> conv1 + # \ -> conv2 + existing_obs_node = None + for maybe_obs_node in arg.users.keys(): + if maybe_obs_node.op == 'call_module': + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if ( + type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and + maybe_obs_mod.dtype == arg_as_input_target_dtype + ): + arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] + existing_obs_node = maybe_obs_node + break + + assert arg_as_input_act_obs_or_fq is not None + if existing_obs_node is None: + maybe_observed_arg = arg + # When quantizing two layers with different configs we can have + # conv2d (int8) -> avgpool(uint8) + # In this case observer insertion for avgpool will come here but the input + # to avgpool will be output observer of conv2d + # Now the obs map that we update must correspond to the original input of + # avgpool and not the output obs of conv2d + # This is because when referring to the edge, quantizer would refer to + # original input and not the observed one. + while _is_activation_post_process_node(arg, named_modules): + arg = arg.args[0] # type: ignore[assignment] + arg_as_input_act_obs_or_fq = obs_or_fq_map[(arg, node)] + new_obs_node = _insert_obs_or_fq( + maybe_observed_arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph) + # override this arg to be the observed arg + new_arg = new_obs_node + else: + new_arg = existing_obs_node return new_arg diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index 2a2de76a0c6ee..607e1b47a3bd3 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -138,9 +138,7 @@ class QuantizationAnnotation: """ # a map from torch.fx.Node to a type of QuantizationSpecBase - input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field( - default_factory=dict - ) + input_qspec_map: Dict[Node, QuantizationSpecBase] = field(default_factory=dict) # How the output of this node is quantized, expressed as QuantizationSpec # TODO: change the value to QuantizationSpec in a separate PR diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 25328c3d0f623..e00dfe649d854 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -154,7 +154,12 @@ def get_symmetric_quantization_config( ), ) - bias_quantization_spec = None + bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + PlaceholderObserver + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr + ) if is_dynamic: quantization_config = QuantizationConfig( act_quantization_spec, From 0a3199dd7e12b6515bc6f3952b927f7b7f761aac Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 27 Oct 2023 01:37:31 +0000 Subject: [PATCH 26/78] Revert "Readded device_assert skipping in index and index_put (and also added (#112093)" This reverts commit e38347f490ae14bf96913a19e7dab9b5e752c276. Reverted https://github.com/pytorch/pytorch/pull/112093 on behalf of https://github.com/izaitsevfb due to Sorry, trying to resolve a conflict with intern, and unblock the revert of #108690 ([comment](https://github.com/pytorch/pytorch/pull/112093#issuecomment-1782154814)) --- test/inductor/test_torchinductor.py | 14 ----------- torch/_inductor/codegen/common.py | 10 ++++---- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/fx_passes/post_grad.py | 13 ++++------ torch/_inductor/index_propagation.py | 4 ++-- torch/_inductor/ir.py | 4 ++-- torch/_inductor/lowering.py | 33 ++++++++++++-------------- torch/_inductor/pattern_matcher.py | 1 - torch/_inductor/virtualized.py | 6 ++--- 9 files changed, 31 insertions(+), 56 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 669d99632e9e2..e38eaa8d84be4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -827,7 +827,6 @@ def repeat(x, n): self.assertEqual(actual, repeat(x, 3)) @skipIfRocm - @config.patch(debug_index_asserts=False) def test_neg_index(self): def test(fn, inps, has_assert: bool, has_wrapping: bool): for dynamic in (True, False): @@ -894,11 +893,6 @@ def flip_with_index(a): # Constant is propagated as we can prove that the result is always negative. test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False) - def unsafe_index(a, b): - return aten._unsafe_index(a, (b,)) - - test(unsafe_index, (a, b), has_assert=False, has_wrapping=True) - def test_computed_buffer_inlining(self): def flip(x): idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -3553,14 +3547,6 @@ def matmul_with_op(x, y, fn): out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn) self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn)) - def test_remove_noop_copy(self): - def fn(x, y): - x = x.cos() - a = x.copy_(y) - return a.sin() - - self.common(fn, (torch.randn(8, 8), torch.randn(8))) - def test_cat_of_loops_and_extern_kernel(self): class M(torch.nn.Module): def __init__( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f0ca422461517..a7a429e89ec92 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -992,7 +992,7 @@ def inner(*args, **kwargs): return inner @staticmethod - def indirect_indexing(var, size, add_asserts=True): + def indirect_indexing(var, size, check=True): # Skip CSE since this doesn't return an expression if var.bounds.lower < 0: @@ -1020,7 +1020,7 @@ def indirect_indexing(var, size, add_asserts=True): new_var.update_on_args("index_wrap", (var,), {}) var = new_var - if self.generate_assert(add_asserts): + if self.generate_assert(check): mask = self.load_mask(var) # An assertion line may have been written already, if so just @@ -1129,10 +1129,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): V.graph.scheduler.remove_kernel_local_buffers() super().__exit__(exc_type, exc_val, exc_tb) - def generate_assert(self, add_asserts): - return ( - add_asserts or config.debug_index_asserts - ) and config.assert_indirect_indexing + def generate_assert(self, check): + return (check or config.debug_index_asserts) and config.assert_indirect_indexing def load_mask(self, var): # only the triton kernel requires mask diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ccb242cfb73b2..ea0c74334b328 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2275,7 +2275,7 @@ def can_use_int32(): return tmp_var @staticmethod - def indirect_indexing(index_var, size, add_asserts=True): + def indirect_indexing(index_var, size, check=True): return sympy_symbol(str(index_var)) @staticmethod diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 1defdab5abb13..cd2ffcbdc58b7 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -484,8 +484,8 @@ def is_valid_splitwithsizes_cat(match): return True -def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): - """True if two nodes have the same metadata""" +def same_layout(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same size/strides""" val1 = node1.meta.get("val") val2 = node2.meta.get("val") return ( @@ -493,7 +493,6 @@ def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): and val2 is not None and val1.size() == val2.size() and val1.layout == val2.layout - and val1.device == val2.device and (val1.layout != torch.strided or val1.stride() == val2.stride()) ) @@ -506,7 +505,6 @@ def register_fun(cond): register_decomposition(targets, registry=noop_registry, unsafe=True)( (cond, nop_arg) ) - return cond return register_fun @@ -566,10 +564,7 @@ def cat_noop(inputs, dim=0): return len(inputs) == 1 -# Note, we also always have a check for identical metadata, which is why these -# are safe -@register_noop_decomp([aten.copy], nop_arg=1) -@register_noop_decomp([aten.alias, aten.clone]) +@register_noop_decomp([aten.clone, aten.alias]) def true_noop(*args, **kwargs): return True @@ -612,7 +607,7 @@ def remove_noop_ops(graph: torch.fx.Graph): is_valid, args, kwargs = get_fake_args_kwargs(node) if not is_valid: continue - if same_meta(node, src) and cond(*args, **kwargs): + if same_layout(node, src) and cond(*args, **kwargs): node.replace_all_uses_with(src) graph.erase_node(node) diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 861a465d5fb4b..6be483d880e3b 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -249,7 +249,7 @@ def inner(*args: Any, **kwargs: Any) -> IndexPropResult: return inner def indirect_indexing( - self, index: Union[Any, IndexPropVar], size: Any, add_asserts: bool = True + self, index: Union[Any, IndexPropVar], size: Any, check: bool = True ) -> Any: # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE # for SymPy expressions, so we don't want to repeat idx too much @@ -259,4 +259,4 @@ def indirect_indexing( # If we are turning a indirect indexing into direct, we need to wrap it. index = index.value.expr return index + Where(index >= 0, 0, size) - return self.fallback("indirect_indexing", (index, size, add_asserts), {}).value + return self.fallback("indirect_indexing", (index, size, check), {}).value diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 7be3836e0d75e..beb76bc45b197 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6144,7 +6144,7 @@ def shim(mask, other): ) @staticmethod - def indirect_indexing(index_proxy, size, add_asserts=True): + def indirect_indexing(index_proxy, size, check=True): """ Flow data from tensors into indexing formulas. Introduce a call_module to update the indexing. @@ -6154,7 +6154,7 @@ def indirect_indexing(index_proxy, size, add_asserts=True): def set_indirect(new_var): self.body.replace_indirect( - var, V.ops.indirect_indexing(new_var, size, add_asserts) + var, V.ops.indirect_indexing(new_var, size, check) ) tracer.create_proxy( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0541eb7357a0f..e6248be16d338 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2696,7 +2696,6 @@ def index_output_size_and_inner_fn( indices_loaders, indexed_size, x_loader, - add_asserts, ): # Note that behavior of indexing differs when there are non consecutive # tensors. In this case, the tensor index is pulled to the beginning. @@ -2747,7 +2746,7 @@ def fn(idx): ops.indirect_indexing( loader(idx[start_offset : start_offset + rank]), size, - add_asserts=add_asserts, + check=check, ) ) new_index = [ @@ -2759,7 +2758,7 @@ def fn(idx): return output_size, fn -def index_impl(x, indices, add_asserts): +def index_impl(x, indices, check): assert isinstance(indices, (list, tuple)) x_loader = x.make_loader() indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) @@ -2786,7 +2785,6 @@ def index_impl(x, indices, add_asserts): indices_loaders, indexed_size, x_loader, - add_asserts=add_asserts, ) return Pointwise.create( @@ -2800,7 +2798,7 @@ def index_impl(x, indices, add_asserts): @register_lowering(aten.index, type_promotion_kind=None) def index(x, indices): try: - return index_impl(x, indices, add_asserts=True) + return index_impl(x, indices, check=True) except NotImplementedError: # Fallback to ATen for boolean indexing x.realize() @@ -2809,7 +2807,7 @@ def index(x, indices): @register_lowering(aten._unsafe_index, type_promotion_kind=None) def _unsafe_index(x, indices): - return index_impl(x, indices, add_asserts=False) + return index_impl(x, indices, check=False) # All the indexing decompositions are written in terms of index, index_put, and index_put_ @@ -2827,7 +2825,7 @@ def index_put(x, indices, values, accumulate=False): @register_lowering(aten._unsafe_index_put) def _unsafe_index_put(x, indices, values, accumulate=False): - return index_put_impl_(clone(x), indices, values, accumulate, add_asserts=False) + return index_put_impl_(clone(x), indices, values, accumulate, check=False) def index_put_as_masked_fill(self, indices, value, accumulate): @@ -2849,10 +2847,10 @@ def index_put_fallback(self, indices, values, accumulate): @register_lowering(aten.index_put_, type_promotion_kind=None) def index_put_(self, indices, values, accumulate=False): - return index_put_impl_(self, indices, values, accumulate, add_asserts=True) + return index_put_impl_(self, indices, values, accumulate, check=True) -def index_put_impl_(self, indices, values, accumulate, add_asserts): +def index_put_impl_(self, indices, values, accumulate, check): # Dispatch to masked fill for single boolean index with single value if ( values.get_numel() == 1 @@ -2917,7 +2915,6 @@ def index_put_impl_(self, indices, values, accumulate, add_asserts): indices_loaders, indexed_size, None, - add_asserts=add_asserts, ) values = expand(values, expected_vals_size) @@ -3179,7 +3176,7 @@ def scale_fn(x, scale, size): x = ops.index_expr(x, torch.float32) x = ops.mul(x, ops.constant(scale, torch.float32)) x = ops.to_dtype(x, torch.int32) - return ops.indirect_indexing(x, size, add_asserts=False) + return ops.indirect_indexing(x, size, check=False) def fn(idx): x = idx[-n:] @@ -3308,8 +3305,8 @@ def load_bounded(fy, fx): _0 = ops.constant(0, torch.int32) iHm1 = ops.constant(iH - 1, torch.int32) iWm1 = ops.constant(iW - 1, torch.int32) - iy = ops.indirect_indexing(clamp(fy, _0, iHm1), iH, add_asserts=False) - ix = ops.indirect_indexing(clamp(fx, _0, iWm1), iW, add_asserts=False) + iy = ops.indirect_indexing(clamp(fy, _0, iHm1), iH, check=False) + ix = ops.indirect_indexing(clamp(fx, _0, iWm1), iW, check=False) return x_loader([n, c, iy, ix]) iy = ops.to_dtype(in_y, get_int_dtype(iH + 1)) @@ -3346,7 +3343,7 @@ def reflect(x, size, offset): x = ops.index_expr(x, torch.int32) x = ops.sub(x, ops.index_expr(offset, torch.int32)) x = ops.sub(size, ops.abs(ops.sub(size, ops.abs(x)))) - return ops.indirect_indexing(x, size_num, add_asserts=False) + return ops.indirect_indexing(x, size_num, check=False) def fn(idx): *b, x, y = idx @@ -3806,12 +3803,12 @@ def fn(idx): ops.indirect_indexing( ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), indices_size[-2], - add_asserts=False, + check=False, ), ops.indirect_indexing( ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), indices_size[-1], - add_asserts=False, + check=False, ), ] @@ -4257,14 +4254,14 @@ def fn(idx): ph, ops.sub(phend, ops.constant(1, torch.int32)) ), pooled_height, - add_asserts=False, + check=False, ), ops.indirect_indexing( ops.minimum( pw, ops.sub(pwend, ops.constant(1, torch.int32)) ), pooled_width, - add_asserts=False, + check=False, ), ] ), diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index c90d178b549e0..aabcd15f89a92 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -939,7 +939,6 @@ def normalize_args(**kwargs): normalize_args=normalize_args, ) pattern.register(pass_dict) - return pattern.pattern @functorch_config.patch(functionalize_rng_ops=False) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 6e981f9225fa9..a1750bfa9e542 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -94,7 +94,7 @@ def masked(mask, body, other) -> str: return f"ops.masked({mask}, {body()}, {other})" @staticmethod - def indirect_indexing(index_var, size, add_asserts=True) -> sympy.Symbol: + def indirect_indexing(index_var, size, check=True) -> sympy.Symbol: return sympy_symbol(f"({str(index_var)})") @classmethod @@ -269,10 +269,10 @@ def _wrap(x): return OpsValue(x) @staticmethod - def indirect_indexing(index, size, add_asserts=True): + def indirect_indexing(index, size, check=True): # Returns a sympy value, not IR value index = OpsWrapper._unwrap(index) - return _ops.indirect_indexing(index, size, add_asserts) + return _ops.indirect_indexing(index, size, check) ops = OpsWrapper() From 64fd027f2e0ee743235f9339f97b3a9224527cae Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 27 Oct 2023 01:40:06 +0000 Subject: [PATCH 27/78] Revert "[inductor] benchmark fusion (#108193)" This reverts commit 73cc5d1cdda118007ccdb0be8d775ba76726596e. Reverted https://github.com/pytorch/pytorch/pull/108193 on behalf of https://github.com/izaitsevfb due to Trying to unblock the revert of #108690, please rebase and reland. ([comment](https://github.com/pytorch/pytorch/pull/108193#issuecomment-1782157638)) --- test/inductor/test_benchmark_fusion.py | 138 ------------------------- torch/_inductor/codegen/common.py | 21 ++-- torch/_inductor/codegen/cpp.py | 1 - torch/_inductor/codegen/triton.py | 82 +-------------- torch/_inductor/config.py | 1 - torch/_inductor/scheduler.py | 105 +------------------ torch/_inductor/virtualized.py | 17 +-- 7 files changed, 11 insertions(+), 354 deletions(-) delete mode 100644 test/inductor/test_benchmark_fusion.py diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py deleted file mode 100644 index bc43100fd4bac..0000000000000 --- a/test/inductor/test_benchmark_fusion.py +++ /dev/null @@ -1,138 +0,0 @@ -# Owner(s): ["module: inductor"] -import math -import os -import sys - -import torch -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - skipIfRocm, - TEST_WITH_ASAN, - TestCase as TorchTestCase, -) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA - -# Make the helper files in test/ importable -pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -sys.path.append(pytorch_test_dir) - -import contextlib -import unittest - -from torch._inductor import config -from torch._inductor.scheduler import Scheduler - -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests - - -class TestCase(TorchTestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._stack = contextlib.ExitStack() - cls._stack.enter_context( - config.patch( - { - "benchmark_kernel": True, - "benchmark_fusion": True, - } - ) - ) - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - -class BenchmarkFusionTestTemplate: - def test_softmax(self): - def f(x): - return torch.nn.functional.softmax(x, dim=-1) - - self.common(f, (torch.rand(2, 8192),)) - - @skipIfRocm # fail accuracy check on ROCm - def test_resnet18(self): - import torchvision - - model = torchvision.models.resnet18() - model.eval() - batch_size = 16 - inputs = (torch.randn((batch_size, 3, 224, 224)),) - self.common(model, inputs, atol=1e-2, rtol=1e-2) - - def test_register_spills(self): - """ - The test can potentially trigger register spills - """ - old_benchmark_fn = Scheduler.benchmark_fused_nodes - - def new_benchmark_fn(scheduler, nodes): - """ - We override Scheduler.benchmark_fused_nodes to return latency 1.0 - if there are no register spills. Without this, we may not able to - test the code path handling register spilling because before register - start spilling, the related fusion may have already been skipped - due to longer lantency. - """ - ms = old_benchmark_fn(scheduler, nodes) - if not math.isinf(ms): - ms = 1.0 - return ms - - # Disable dynamic_scale_rblock to make it easier to trigger register - # spilling. - with unittest.mock.patch.object( - Scheduler, "benchmark_fused_nodes", new_benchmark_fn - ), config.patch("dynamic_scale_rblock", False): - S = 512 - - def f(*inputs): - inputs = list(inputs) - outputs = [] - out = torch.zeros(S, device=self.device) - for x in inputs: - x = x * 2 - x = x + 1 - x = x.sum(dim=-1) - outputs.append(x) - out = out + x - return outputs, out - - N = int(os.environ.get("NINP", "30")) - inputs = [torch.randn(S, 2560, device=self.device) for _ in range(N)] - opt_f = torch.compile(f) - opt_f(*inputs) - - -if HAS_CUDA and not TEST_WITH_ASAN: - - class BenchmarkFusionCudaTest(TestCase): - common = check_model_cuda - device = "cuda" - - copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda") - -if HAS_CPU and not torch.backends.mps.is_available(): - - class BenchmarkFusionCpuTest(TestCase): - common = check_model - device = "cpu" - - copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCpuTest, "cpu") - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - if HAS_CPU or HAS_CUDA: - run_tests() diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a7a429e89ec92..de967395518b3 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -438,14 +438,14 @@ def __init__(self, name, line): self.name = name def __call__(self): - if all( - self.name not in x - for x in ( - V.graph.removed_buffers, - V.kernel.removed_buffers, - V.graph.inplaced_to_remove, - V.kernel.inplaced_to_remove, + # V.kernel may be null since this method may be called for the + # wrapper codegen where there is no specific kernel. + if ( + self.name + not in ( + V.graph.removed_buffers | getattr(V.kernel, "removed_buffers", set()) ) + and self.name not in V.graph.inplaced_to_remove ): return self.line return None @@ -647,10 +647,7 @@ def aliases(self): if self._buffer_is_marked_removed(inplaced): continue for other in inplaced.other_names: - if ( - other in V.graph.inplaced_to_remove - or other in V.kernel.inplaced_to_remove - ): + if other in V.graph.inplaced_to_remove: continue if other in self.input_buffers: yield self.input_buffers[other], inplaced.inner_name @@ -891,8 +888,6 @@ def __init__(self, args=None, increase_kernel_count=True): self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {} self.removed_buffers = set() - self.inplaced_to_remove = set() - # key: the buffer to write # value: the buffer to read and whose memory can be reused for # the buffer specified by key diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ea0c74334b328..31e7ee396851a 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2661,7 +2661,6 @@ def run(kernel): scalar_kernel = codegen_kernel(CppKernel) V.graph.removed_buffers |= scalar_kernel.removed_buffers - V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove self.loop_nest = LoopNestWithSplit.build(scalar_kernel) if not self.picked_vec_isa: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index f486e63f1cd79..c3d05eb3e3501 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -8,7 +8,6 @@ import logging import math import operator -import os from typing import Any, Counter, Dict, Iterable, List, Optional, Set, Tuple import sympy @@ -21,14 +20,13 @@ from torch.utils._sympy.value_ranges import ValueRanges from ..._dynamo.utils import counters from .. import config, ir, scheduler -from ..codecache import code_hash, get_path, PyCodeCache +from ..codecache import code_hash, get_path from ..dependencies import MemoryDep, StarDep from ..ir import IRNode, ReductionHint, TritonTemplateBuffer from ..optimize_indexing import indexing_dtype_strength_reduction from ..scheduler import BaseScheduling from ..triton_heuristics import AutotuneHint from ..utils import ( - do_bench, get_fused_kernel_name, get_kernel_metadata, green_text, @@ -2523,7 +2521,6 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel): self.codegen_comment(node_schedule) kernel.call_kernel(kernel_name) V.graph.removed_buffers |= kernel.removed_buffers - V.graph.inplaced_to_remove |= kernel.inplaced_to_remove if config.warn_mix_layout: kernel.warn_mix_layout(kernel_name) @@ -2643,7 +2640,6 @@ def codegen_template(self, template_node, epilogue_nodes): self.codegen_comment(node_schedule) kernel.call_kernel(kernel_name, template_node.node) V.graph.removed_buffers |= kernel.removed_buffers - V.graph.inplaced_to_remove |= kernel.inplaced_to_remove self.scheduler.free_buffers() def codegen_sync(self): @@ -2681,7 +2677,6 @@ def codegen_foreach(self, foreach_node): if node not in (EnableReduction, DisableReduction): node.mark_run() V.graph.removed_buffers |= subkernel.removed_buffers - V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, [foreach_node]) @@ -2830,81 +2825,6 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): def flush(self): pass - def benchmark_fused_nodes(self, nodes): - _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - tiled_groups = self.select_tiling(node_schedule, numel, rnumel) - reduction_hint_val, mutations, index_dtype = self.get_kernel_args( - node_schedule, numel, rnumel - ) - - kernel = TritonKernel( - *tiled_groups, - reduction_hint=reduction_hint_val, - mutations=mutations, - index_dtype=index_dtype, - ) - - # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. - for n in nodes: - n.last_usage = set() - - self.codegen_node_schedule_with_kernel(node_schedule, kernel) - with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel): # type: ignore[attr-defined] - src_code = kernel.codegen_kernel() - - src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") - mod = PyCodeCache.load(src_code) - - def cache_file_path(): - return os.path.splitext(mod.__file__)[0] + ".kernel_perf" # type: ignore[type-var,operator] - - def load_cache(): - path = cache_file_path() - if os.path.exists(path): - with open(path) as fd: - return float(fd.read()) - return None - - def store_cache(): - path = cache_file_path() - with open(path, "w") as fd: - fd.write(str(ms)) - - log.debug( - "kernel src code for %s written to: %s", - {n.get_name() for n in nodes}, - mod.__file__, - ) - ms = load_cache() - if ms is not None: - return ms - - args = mod.get_args() - call = mod.call - wrapped_jit_function = mod.triton_ - - # call once to trigger the compilation - call(wrapped_jit_function.clone_args(*args)) - - launchers = wrapped_jit_function.launchers - assert len(launchers) == 1 - if launchers[0].n_spills > 0: - # skip benchmarking the kernel if there are register spills - ms = float("inf") - else: - # We have to clone the inplace updated arguments to avoid earlier calls - # generating out of range indices for later calls. - ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args))) - - log.debug( - "The fused kernel for %s took %.3f ms to run", - {n.get_name() for n in nodes}, - ms, - ) - store_cache() - return ms - @dataclasses.dataclass class CandidateTiling: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index ec87843e097c5..cad5cc9923038 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -198,7 +198,6 @@ # For each fused kernel in the wrapper, comment with the nodes that get fused. # Useful for debugging fusion. debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" -benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" # how many nodes to allow into a single fusion max_fusion_size = 64 diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 40762a5dce23a..1ed30aeb83837 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3,7 +3,6 @@ import functools import itertools import logging -import math import os import pprint import textwrap @@ -29,8 +28,6 @@ get_device_tflops, get_dtype_size, get_gpu_dram_gbps, - green_text, - red_text, sympy_product, ) from .virtualized import V @@ -1498,97 +1495,6 @@ def fuse_nodes(self): fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) break - def benchmark_fused_nodes(self, nodes): - """ - Benchmark fused list of nodes and return the execution time - in milliseconds on randomly generated inputs. - """ - assert len(nodes) > 0 - device = nodes[0].get_device() - V.graph.scheduler = self - self.current_device = device - backend = self.get_backend(device) - return backend.benchmark_fused_nodes(nodes) - - def speedup_by_fusion(self, node1, node2): - """ - If config.benchmark_fusion is False, always return True. - Otherwise, return True if fusion can brings speedup. - """ - if not config.benchmark_fusion: - return True - - if node1.is_template(): - # TODO support benchmarking epilogue fusion - return True - - node_list_1 = node1.get_nodes() - device = node_list_1[0].get_device() - - # don't support benchmark fusion for CPU right now. - if device.type == "cpu": - return True - - node_list_2 = node2.get_nodes() - node_list_fused = node_list_1 + node_list_2 - - # We can not accurately benchmark kernel using atomic_add - # due to how we generate random integer inputs. - # Skip benchmarking them by allowing fusion. - if any( - hasattr(n.node, "data") - and hasattr(n.node.data, "scatter_mode") - and n.node.data.scatter_mode == "atomic_add" - for n in node_list_fused - ): - return True - - from triton.compiler.errors import CompilationError - - try: - ms1 = self.benchmark_fused_nodes(node_list_1) - if math.isinf(ms1): - log.debug( - "Skip fusion because of register spilling of the first kernel" - ) - return False - ms2 = self.benchmark_fused_nodes(node_list_2) - if math.isinf(ms2): - log.debug( - "Skip fusion because of register spilling of the second kernel" - ) - return False - ms_fused = self.benchmark_fused_nodes(node_list_fused) - if math.isinf(ms_fused): - log.debug( - "Skip fusion because of register spilling of the fused kernel" - ) - return False - except CompilationError as e: - # workaround triton issue: https://github.com/openai/triton/issues/2151 - if "Loop-carried variable" in str(e): - return True # allow fusion - else: - raise - - if log.isEnabledFor(logging.DEBUG): - if ms_fused < ms1 + ms2: - log.debug( - "Fusing %s with %s cause %sx speedup", - node1.get_names(), - node2.get_names(), - green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), - ) - else: - log.debug( - "Fusing %s with %s cause %sx slowdown", - node1.get_names(), - node2.get_names(), - red_text(f"{ms_fused / (ms1 + ms2):.3f}"), - ) - - return ms_fused < ms1 + ms2 - def fuse_nodes_once(self): """ Mutates self.nodes to combine nodes into FusedSchedulerNodes. @@ -1604,8 +1510,6 @@ def fuse_nodes_once(self): if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( node1, node2 ): - if not self.speedup_by_fusion(node1, node2): - continue node3 = fuse(node1, node2) fused_nodes.remove(node1) fused_nodes.remove(node2) @@ -1982,7 +1886,7 @@ def remove_filter(n): remove = all(n in names_to_remove for n in buf.other_names) if remove: self.remove_inplace_buffer(name) - V.kernel.inplaced_to_remove.add(name) + V.graph.inplaced_to_remove.add(name) else: self.remove_buffer(name) @@ -2184,10 +2088,3 @@ def flush(self): Flush the generated kernel and python wrapper code to the source code file. """ raise NotImplementedError() - - def benchmark_fused_nodes(self, nodes): - """ - Benchmark fused list of nodes and return the execution time - in milliseconds on randomly generated inputs. - """ - raise NotImplementedError() diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index a1750bfa9e542..ce6438e52e979 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -56,21 +56,6 @@ class NullHandler: pass -class NullKernelHandler(NullHandler): - """ - We need access `V.kernel.removed_buffers` in DeferredLine class when there - is no kernel in the context. This happens when codegening the wrapper. - Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't - need call 'getattr' with default value which is error prone to typo in - attribute name. - """ - - def __init__(self): - super().__init__() - self.removed_buffers = set() - self.inplaced_to_remove = set() - - def _arg_str(a) -> str: if isinstance(a, sympy.Expr): return sympy_str(a) @@ -184,7 +169,7 @@ def __getattr__(self, item): _graph = Virtualized("graph", NullHandler) _real_inputs = Virtualized("real_inputs", NullHandler) _fake_mode = Virtualized("fake_mode", NullHandler) -_kernel = Virtualized("kernel", NullKernelHandler) +_kernel = Virtualized("kernel", NullHandler) _debug = Virtualized("debug", NullHandler) _interpreter = Virtualized("interpreter", NullHandler) _aot_compilation = Virtualized("aot_compilation", NullHandler) From 94e90c199c6abb188d1d86244621b9e94e43598c Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 25 Oct 2023 18:28:25 -0700 Subject: [PATCH 28/78] [dtensor] fix pointwise op linearity with strategy (#112107) This PR fixes the pointwise op strategy linearity, and switch the linear pointwise ops to use strategy. Also add tests show that using the new way we can enable full shard (S(0), S(0)) like operations Why this is useful? for 2-D Parallel like patterns where the named parameters are possibly fully sharded on all devices, [S(0), S(0)] or [S(1), S(0)], etc. need to work, since we don't use the sharding rules anymore, this is possible at this point. @awgu Pull Request resolved: https://github.com/pytorch/pytorch/pull/112107 Approved by: https://github.com/wz337 --- test/distributed/_tensor/test_math_ops.py | 26 ++++++++++++- .../distributed/_tensor/test_pointwise_ops.py | 10 ++++- torch/distributed/_tensor/ops/common_rules.py | 9 ----- .../distributed/_tensor/ops/pointwise_ops.py | 38 +++++++++++++------ 4 files changed, 61 insertions(+), 22 deletions(-) diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index 8c3ba342b087f..45f160b1acd67 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -5,7 +5,7 @@ import torch -from torch.distributed._tensor import distribute_tensor +from torch.distributed._tensor import DeviceMesh, distribute_tensor from torch.distributed._tensor.placement_types import Replicate, Shard from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -114,6 +114,30 @@ def test_softmax_with_bwd(self): dist_x_grad = dist_x.grad.redistribute(device_mesh, [Replicate()]) self.assertEqual(dist_x_grad.to_local(), x.grad) + @with_comms + def test_full_shard_math_ops(self): + mesh_shape = (2, self.world_size // 2) + mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(*mesh_shape), + ) + global_tensor = torch.ones(4, 4) + double_shard_tensor = distribute_tensor( + global_tensor, mesh, [Shard(0), Shard(0)] + ) + fully_shard_tensor = distribute_tensor( + global_tensor, mesh, [Shard(0), Shard(1)] + ) + + # for op in [torch.add, torch.sub, torch.mul, torch.div]: + for op in [torch.add, torch.sub, torch.mul, torch.div]: + expect_rs = op(global_tensor, 2) + actual_rs = op(double_shard_tensor, 2).redistribute( + mesh, [Replicate(), Replicate()] + ) + actual_local_res = actual_rs.to_local() + self.assertEqual(actual_local_res, expect_rs) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py index 1103c1a6961b0..a3b5bfedeacaa 100644 --- a/test/distributed/_tensor/test_pointwise_ops.py +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -148,7 +148,15 @@ def test_partial_add(self): d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()]) d_2 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()]) d_3 = d_1 + d_2 - self.assertEqual(d_3._spec.placements[0].is_partial(), True) + self.assertTrue(d_3._spec.placements[0].is_partial()) + + def test_partial_mul_failure(self): + device_mesh = self.build_device_mesh() + d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()]) + d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()]) + d_3 = d_1 * d_2 + self.assertTrue(d_3._spec.placements[0].is_replicate()) + self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size**2)) def test_activations(self): device_mesh = self.build_device_mesh() diff --git a/torch/distributed/_tensor/ops/common_rules.py b/torch/distributed/_tensor/ops/common_rules.py index 36b777362ddaa..8d3deec677016 100644 --- a/torch/distributed/_tensor/ops/common_rules.py +++ b/torch/distributed/_tensor/ops/common_rules.py @@ -285,12 +285,3 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi linearity=linearity, enforce_sharding=enforce_sharding, ) - - -def linear_pointwise_rule(op_schema: OpSchema) -> OutputSharding: - """ - Linear pointwise operators can propagate pending reductions. - For example, c = add(a, b); if a is pending sum, then c will be - pending sum as well without any communication overhead. - """ - return pointwise_rule(op_schema, linearity=True) diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index 5819e7e6e8a86..08da12f6d6de0 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -14,16 +14,20 @@ StrategyType, ) -from torch.distributed._tensor.ops.common_rules import linear_pointwise_rule from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, register_op_strategy, - register_prop_rule, ) -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Shard +from torch.distributed._tensor.placement_types import ( + _Partial, + DTensorSpec, + Placement, + Replicate, + Shard, +) aten = torch.ops.aten @@ -394,7 +398,9 @@ ] -def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: +def pointwise_strategy( + mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False +) -> StrategyType: max_shards_strategy_index = -1 max_shards = -1 # handle broadcasting @@ -445,6 +451,11 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: common_ndim = len(common_shape) new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim out_placements.append(Shard(new_shard_dim)) + elif isinstance(placement, _Partial) and not linearity: + # clear the partial placemnet if op does not support linearity + # by default we just replicate the partial, need to see if this + # is optimal for all cases + out_placements.append(Replicate()) else: out_placements.append(placement) @@ -452,12 +463,6 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: redistribute_costs: List[List[float]] = [] for idx, input_arg in enumerate(op_schema.args_schema): if isinstance(input_arg, OpStrategy): - if idx == max_shards_strategy_index: - # the current input arg is the one we want to follow - input_specs.append(spec_to_follow) - redistribute_costs.append([0] * len(input_arg.strategies)) - continue - # every arg follow the out_placements, but need to handle broadcasting input_arg_spec = input_arg.strategies[0].output_spec input_arg_dims_map = infer_broadcast_dims_map( @@ -491,8 +496,19 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: return pointwise_strategy +def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + """ + return pointwise_strategy(mesh, op_schema, linearity=True) + + for op in linear_pointwise_ops: - register_prop_rule(op)(linear_pointwise_rule) + register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( + linear_pointwise_strategy + ) for op in pointwise_ops: From 7cb72704ccaffc47ad0fb37698b364e3e0d0f11b Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 26 Oct 2023 23:59:56 +0000 Subject: [PATCH 29/78] Constrain sdpa to fx strides (#111721) Fix for https://github.com/pytorch/pytorch/issues/109607. sdpa requires last dimension strides to be 1. Add constraint so that we run the op with the strides we observed in tracing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111721 Approved by: https://github.com/drisspg, https://github.com/Chillee, https://github.com/jansel ghstack dependencies: #111976 --- test/inductor/test_torchinductor.py | 61 +++++++++++++++++ ...st_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 67 +++++++++++++++++-- 3 files changed, 125 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e38eaa8d84be4..530e5026cbb45 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6758,6 +6758,67 @@ def forward(arg6, arg7, arg16): # expanded dim should not cause copy in require_stride_order self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + @requires_cuda() + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Does not support SDPA or pre-SM80 hardware", + ) + @skipIfRocm + def test_sdpa(self): + def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): + view = torch.ops.aten.view.default(arg3_1, [23760, 128]) + arg3_1 = None + mm = torch.ops.aten.mm.default(view, arg4_1) + view = arg4_1 = None + view_1 = torch.ops.aten.view.default(mm, [3, 99, 80, 8]) + mm = None + view_2 = torch.ops.aten.view.default(view_1, [3, 99, 80, 8]) + view_1 = None + permute = torch.ops.aten.permute.default(view_2, [0, 3, 1, 2]) + view_2 = None + view_3 = torch.ops.aten.view.default(permute, [3, 8, 99, 80]) + permute = None + + clone = torch.ops.aten.clone.default( + view_3, memory_format=torch.contiguous_format + ) + view_3 = None + + expand = torch.ops.aten.expand.default(clone, [3, 8, 99, 80]) + clone = None + _scaled_dot_product_efficient_attention = ( + torch.ops.aten._scaled_dot_product_efficient_attention.default( + arg0_1, arg1_1, arg2_1, expand, False + ) + ) + arg0_1 = arg1_1 = arg2_1 = expand = None + getitem = _scaled_dot_product_efficient_attention[0] + _scaled_dot_product_efficient_attention = None + return (getitem,) + + DEVICE = torch.device("cuda:0") + DTYPE = torch.float16 + B = 3 + H = 8 + Q = 99 + K = 80 + D = 32 + C_bias = 128 + + # inputs + query = torch.randn((B, H, Q, D), device=DEVICE, dtype=DTYPE) + key = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) + value = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) + bias = torch.randn((B, Q, K, C_bias), device=DEVICE, dtype=DTYPE) + weights = torch.randn((C_bias, H), device=DEVICE, dtype=DTYPE) + + self.common( + foo, + (query, key, value, bias, weights), + atol=0.02, + rtol=1e4, + ) + def test_where_with_logical_op(self): def fn_and(x, y): return torch.where(torch.logical_and(x, y), 1.0, 0.0) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 8677453a55c5c..0e80d8adb5828 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -261,6 +261,7 @@ def run(*ex, **kwargs): "test_zero_dim_reductions_dynamic_shapes": TestFailure( ("cpu", "cuda"), is_skip=True ), + "test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True), # # The following tests do not support dynamic shapes yet: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index e6248be16d338..e52f2b16e1aab 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2008,10 +2008,69 @@ def apply_constraint(arg, fx_arg): make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) make_fallback(aten.grid_sampler_2d_backward, require_dense) make_fallback(aten.randperm) -make_fallback(aten._scaled_dot_product_efficient_attention) -make_fallback(aten._scaled_dot_product_efficient_attention_backward) -make_fallback(aten._scaled_dot_product_flash_attention, warn=False) -make_fallback(aten._scaled_dot_product_flash_attention_backward) + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + if not meta_val.is_cuda: + return arg + + stride_order = ir.get_stride_order(meta_val.stride()) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + ALIGNMENT = 16 + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + assert isinstance(arg, TensorBox) + unaligned_input_shape = isinstance(arg.data, ir.SliceView) and not is_aligned( + arg + ) + aligned_input_view = unaligned_input_shape and is_aligned(arg.unwrap_view()) + + # input is padded, requiring_stride_order will unwrap the view and unpad. + # Would be nice to be able to require certain padding from inductor ir, nyi + if aligned_input_view: + return arg + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) + make_fallback(aten.sort) make_fallback(aten.sort.stable) make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) From aa9e65d8f50402d5067d468dd695daf1f7a8ea90 Mon Sep 17 00:00:00 2001 From: "Iris Zhang (PyTorch)" Date: Fri, 27 Oct 2023 04:27:29 +0000 Subject: [PATCH 30/78] [DCP] Add fsspec.transaction context when writing checkpoint to storage (#112191) Summary: Adding fsspec.transaction to safeguard checkpointing writing. With the context, it should only commit if there was no exception and discard otherwise. Test Plan: ``` command: buck test @//mode/dev-nosan //caffe2/test/distributed/checkpoint/fb:test_fsspec_filesystem -- --print-passing-details ``` Reviewed By: rohan-varma Differential Revision: D50701929 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112191 Approved by: https://github.com/rohan-varma --- .../checkpoint/_fsspec_filesystem.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index 0d37924cd36e4..48b8512fc2cec 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -14,8 +14,8 @@ from typing import Callable, cast, Dict, List, Optional, Union import fsspec - import torch +from fsspec import AbstractFileSystem from fsspec.core import url_to_fs from torch import Tensor from torch._utils import _get_device_module @@ -261,6 +261,7 @@ def _write_files_from_queue( result_queue: queue.Queue, planner: SavePlanner, inflight_threshhold: int, + fs: AbstractFileSystem, ): try: while True: @@ -289,18 +290,19 @@ def _write_files_from_queue( ] write_results = [] - with fsspec.open(file_name, "wb") as stream: - for write_item in bytes_w: - data = planner.resolve_data(write_item) - write_results.append( - _write_item(stream, data, write_item, storage_key) - ) + with fs.transaction: + with fsspec.open(file_name, "wb") as stream: + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append( + _write_item(stream, data, write_item, storage_key) + ) - for tensor, write_item in loader.values(): - assert tensor.is_cpu - write_results.append( - _write_item(stream, tensor, write_item, storage_key) - ) + for tensor, write_item in loader.values(): + assert tensor.is_cpu + write_results.append( + _write_item(stream, tensor, write_item, storage_key) + ) result_queue.put(write_results) except queue.Empty: pass @@ -399,6 +401,7 @@ def gen_file(): result_queue, planner, self.per_thread_copy_ahead, + self.fs, ), ) t.start() @@ -409,6 +412,7 @@ def gen_file(): result_queue=result_queue, planner=planner, inflight_threshhold=self.per_thread_copy_ahead, + fs=self.fs, ) for t in threads: From c84dbd2c0393fed3e3d98030c85c30fbbfdc7c3e Mon Sep 17 00:00:00 2001 From: Iris Zhang Date: Thu, 26 Oct 2023 17:09:03 -0700 Subject: [PATCH 31/78] [2D] Enable 2D optimizer set_state_dict() (#111778) Pull Request resolved: https://github.com/pytorch/pytorch/pull/111778 Approved by: https://github.com/fegin, https://github.com/fduwjj ghstack dependencies: #111774 --- .../tensor/parallel/test_fsdp_2d_parallel.py | 43 ++++++++++++++----- torch/distributed/fsdp/_optim_utils.py | 3 +- torch/distributed/fsdp/_shard_utils.py | 6 ++- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py index 341fe575dd24c..6fbd17c0fbd48 100644 --- a/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py +++ b/test/distributed/tensor/parallel/test_fsdp_2d_parallel.py @@ -15,7 +15,7 @@ checkpoint_wrapper, CheckpointImpl, ) -from torch.distributed.checkpoint.state_dict import get_state_dict +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import ( _get_module_fsdp_state, @@ -447,8 +447,7 @@ class TestNew2dParallelStateDict(DTensorTestBase): @with_comms @skip_if_lt_x_gpu(4) @parametrize("is_even_sharded_model", [True, False]) - @parametrize("use_orig_params", [True, False]) - def test_2d_state_dict(self, is_even_sharded_model, use_orig_params): + def test_2d_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven # Create a model without wrapper @@ -469,7 +468,7 @@ def test_2d_state_dict(self, is_even_sharded_model, use_orig_params): model_2d = FSDP( model_2d, device_mesh=dp_mesh, - use_orig_params=use_orig_params, + use_orig_params=True ) FSDP.set_state_dict_type( @@ -504,8 +503,7 @@ def test_2d_state_dict(self, is_even_sharded_model, use_orig_params): @with_comms @skip_if_lt_x_gpu(4) @parametrize("is_even_sharded_model", [True, False]) - @parametrize("use_orig_params", [True, False]) - def test_2d_load_state_dict(self, is_even_sharded_model, use_orig_params): + def test_2d_load_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven torch.manual_seed(0) @@ -520,7 +518,7 @@ def test_2d_load_state_dict(self, is_even_sharded_model, use_orig_params): model_2d = FSDP( model_2d, device_mesh=dp_mesh, - use_orig_params=use_orig_params, + use_orig_params=True ) optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01) @@ -560,8 +558,7 @@ def test_2d_load_state_dict(self, is_even_sharded_model, use_orig_params): @with_comms @skip_if_lt_x_gpu(4) @parametrize("is_even_sharded_model", [True, False]) - @parametrize("use_orig_params", [True, False]) - def test_2d_optim_state_dict(self, is_even_sharded_model, use_orig_params): + def test_2d_optim_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven # Create a model without wrapper @@ -586,7 +583,7 @@ def test_2d_optim_state_dict(self, is_even_sharded_model, use_orig_params): model_2d = FSDP( model_2d, device_mesh=mesh_2d["dp"], - use_orig_params=use_orig_params, + use_orig_params=True ) FSDP.set_state_dict_type( model_2d, @@ -619,6 +616,32 @@ def test_2d_optim_state_dict(self, is_even_sharded_model, use_orig_params): self.assertTrue(isinstance(dist_state, torch.Tensor)) self.assertTrue(torch.allclose(state, dist_state)) + # Update the parameters 2d optim states will be different from ref_optim_state_dict. + model_2d(model_2d.get_input().cuda(self.rank)).sum().backward() + optim_2d.step() + + set_state_dict(model_2d, optimizers=optim_2d, optim_state_dict=ref_optim_2d_osd) + _, new_optim_2d_osd = get_state_dict(model_2d, optimizers=optim_2d) + + ref_optim_2d_osd_states = ref_optim_2d_osd["state"] + new_optim_2d_osd_states = optim_2d_osd["state"] + + # Compare the new optim state dict after load with the reference one + self.assertEqual(len(ref_optim_2d_osd_states), len(new_optim_2d_osd_states)) + self.assertEqual(ref_optim_2d_osd_states.keys(), new_optim_2d_osd_states.keys()) + for fqn, states in ref_optim_2d_osd_states.items(): + new_states = new_optim_2d_osd_states.get(fqn) + + for state_name, state in states.items(): + new_state = new_states.get(state_name) + + if isinstance(new_state, DT): + self.assertEqual(new_state.placements, state.placements) + self.assertEqual(new_state.device_mesh, state.device_mesh) + self.assertTrue(torch.allclose(new_state.to_local(), state.to_local())) + else: + self.assertEqual(new_state, state) + instantiate_parametrized_tests(TestNew2dParallelStateDict) if __name__ == "__main__": diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index f8c5d494c291d..b459d6d8290a9 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1429,6 +1429,7 @@ def _unflatten_orig_param_states( value = gathered_state[state_name] param_idx = fsdp_param_info.param_indices[fqn] + # TODO: This solution is not general and only apply to PTD TP solution. if isinstance(value, DTensor): placement = value.placements[0] # If gathered state is a DTensor and its TP placement is not Replicate(), we need to @@ -1443,8 +1444,8 @@ def _unflatten_orig_param_states( # If gathered state is a replicate DTensor, we directly reshape it. else: value = value.reshape(flat_param._shapes[param_idx]) - # If gathered state is a tensor, we directly reshape it into unflatten state. else: + # If gathered state is a tensor, we directly reshape it into unflatten state. value = value.reshape(flat_param._shapes[param_idx]) if shard_state: diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py index 0e63ebd1b9c14..bd45e5f2159e9 100644 --- a/torch/distributed/fsdp/_shard_utils.py +++ b/torch/distributed/fsdp/_shard_utils.py @@ -89,8 +89,10 @@ def _gather_state_dict( value = value.to(value.device_mesh.device_type) # FSDP all_gather: [Shard(0)] -> [Replicate()] # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] - placements = list(copy.deepcopy(value.placements)) - placements[-1] = Replicate() + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] value = value.redistribute( device_mesh=value.device_mesh, placements=placements, From 589625cbae3c0c97a7ab22dd61a5e9164aeba069 Mon Sep 17 00:00:00 2001 From: Levy Zhao Date: Fri, 27 Oct 2023 04:46:24 +0000 Subject: [PATCH 32/78] Add bandwidth to extern kernel calc (#110539) Summary: - Modify the result of get_estimated_runtime() for ExternKernelSchedulerNode to count both bytes and FLOPs and return the maximum of the two. Reviewed By: xmfan Differential Revision: D48987490 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110539 Approved by: https://github.com/xw285cornell --- torch/_inductor/scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 1ed30aeb83837..5086696bed34b 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -558,9 +558,12 @@ def get_estimated_runtime(self) -> float: # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship factor = 1.0 counted_flops = flop_counter_mode.get_total_flops() + counted_bytes = self.get_read_write_buffers_sizes() + compute_time = (factor * counted_flops / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth # Return estimated runtime in nanoseconds - return (factor * counted_flops / gpu_flops) * 1e9 + return max(compute_time, transfer_time) elif isinstance(self, FusedSchedulerNode) or isinstance( self.node, ComputedBuffer From a6e556f8b0a471ae3a06a50225e7f30a14e36c27 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 25 Oct 2023 08:09:14 +0000 Subject: [PATCH 33/78] Support calling __torch_function__ attribute access (#111737) Triggers `__torch_function__` tracing on attribute/method/property access matching the eager behavior for non-overridden attributes/methods/properties that are present on `torch.Tensor`. Some caveats: 1. for methods there doesn't seem to be a way to check if the original implementation of a method is overridden via monkey patching or not. For example: ``` class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) x = torch.ones(2, 2).as_subclass(LocalSubclass) > x.sigmoid ``` There isn't a way to verify that this built-in method is equivalent to the base `torch.Tensor` implementation as each instance will have a different built-in method object that can't be traced back to the original `torch.Tensor` impl. You can check that the class itself has the original implementation via ``` > inspect.getattr_static(LocalSubclass, "sigmoid") ``` But we can't detect if the user dynamically patches an object with a built-in method called sigmoid which does something completely different. 2. If a user overrides a method but calls the original implementation we will still graph break. This will require modifying `SuperVariable` (and any other way to get the original impl) to handle tensor subclasses. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111737 Approved by: https://github.com/jansel, https://github.com/ezyang --- test/dynamo/test_subclasses.py | 190 ++++++++++++++++++++-- torch/_dynamo/utils.py | 1 - torch/_dynamo/variables/torch_function.py | 70 +++++++- 3 files changed, 242 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 6ec04513cbc6b..030b951e79083 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import contextlib import functools import unittest @@ -17,8 +16,11 @@ from torch.nested._internal.nested_tensor import jagged_from_list, ViewBufferFromNested from torch.testing._internal.inductor_utils import HAS_CUDA + requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") +compile_full_eager = torch.compile(backend="eager", fullgraph=True) + class MockSubclass(torch.Tensor): @classmethod @@ -28,6 +30,30 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return func(*args, **kwargs) +class DummyNDim(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func == torch.Tensor.ndim.__get__: + return 10 + + return super().__torch_function__(func, types, args, kwargs) + + +class SigmoidToExpSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func == torch.Tensor.sigmoid: + return super().__torch_function__(torch.Tensor.exp, types, args, kwargs) + + return super().__torch_function__(func, types, args, kwargs) + + class EagerRecordGraphAndInputs: def __init__(self): self.graphs = [] @@ -39,22 +65,18 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs): return gm -@contextlib.contextmanager -def preserve_subclass_config(): - old_subclass_set = set(torch._dynamo.config.traceable_tensor_subclasses) - try: - torch._dynamo.config.traceable_tensor_subclasses.add(MockSubclass) - yield - finally: - torch._dynamo.config.traceable_tensor_subclasses.clear() - torch._dynamo.config.traceable_tensor_subclasses.update(old_subclass_set) +GLOBAL_TEST_SUBCLASSES = {MockSubclass, DummyNDim, SigmoidToExpSubclass} class SubclassTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() - cls._exit_stack.enter_context(preserve_subclass_config()) + cls._exit_stack.enter_context( + torch._dynamo.config.patch( + "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES + ) + ) @classmethod def tearDownClass(cls): @@ -115,16 +137,152 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} return func(*args, **kwargs) - torch._dynamo.config.traceable_tensor_subclasses.add(LocalSubclass) + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return LocalSubclass(torch.add(x, 1.0)) + + input = torch.ones(2, 2) + + res = fn(input) + self.assertIsInstance(res, LocalSubclass) + + def test_torch_function_call_on_method(self): + x = torch.ones(2, 2) + y = torch.ones(2, 2) + z = torch.ones(2, 2) + wrapped = x.as_subclass(SigmoidToExpSubclass) + wrapped2 = y.as_subclass(SigmoidToExpSubclass) + + def fn(w): + return w.sigmoid() + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped2) + res_exp2 = z.exp() + + self.assertEqual(res_exp, res_act) + self.assertEqual(res_exp, res_exp2) + + def test_user_overidden_method_unsupported(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + def sigmoid(self): + return None @torch.compile(backend="eager", fullgraph=True) def fn(x): - return LocalSubclass(torch.add(x, 1.0)) + x.sigmoid() - input = torch.ones(2, 2) + msg = ( + "Accessing overidden method/attribute sigmoid on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) - res = fn(input) - self.assertIsInstance(res, LocalSubclass) + def test_user_overidden_attr_unsupported(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + ndim = 10 + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim + + msg = ( + "Accessing overidden method/attribute ndim on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) + + def test_user_overidden_property_unsupported(self): + class LocalSubclass(torch.Tensor): + def __init__(self): + self._ndim = 10 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + @property + def ndim(self): + return self._ndim + + @ndim.setter + def ndim(self, value): + self._ndim = value + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim + + msg = ( + "Accessing overidden method/attribute ndim on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) + + def test_overridden_method_guarding(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.sigmoid() + + with torch._dynamo.config.patch( + error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass} + ): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) + x.sigmoid = False + fn(x) + + def test_torch_function_call_on_attr(self): + x = torch.ones(2, 2) + wrapped = x.as_subclass(DummyNDim) + + def fn(w): + return w.ndim + torch.ones(2) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped) + + self.assertEqual(res_exp, res_act) + self.assertEqual(res_exp, torch.ones(2) + 10) def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 6f21a6c8bf5d0..2a2190513cd32 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2198,7 +2198,6 @@ def is_tensor_base_attr_getter(value): return ( isinstance(value, types.MethodWrapperType) and value.__name__ == "__get__" - and isinstance(value.__self__, types.GetSetDescriptorType) and value.__self__.__objclass__ is torch._C._TensorBase ) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 66a30c4eb393c..3e4415fc3b6b2 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,3 +1,4 @@ +import inspect from typing import Dict, List from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -45,6 +46,19 @@ def is_torch_function_user_object(obj): return hasattr(obj, "__torch_function__") +def _is_attr_overidden(tx, var, name): + import torch + + overridden = False + try: + attr_val = inspect.getattr_static(var.python_type(), name) + overridden |= attr_val != getattr(torch.Tensor, name) + except AttributeError: + pass + + return overridden + + def call_torch_function( tx, torch_function_type, torch_function_var, fn, types, args, kwargs ): @@ -138,6 +152,45 @@ def subclass_type_var(self): def global_mangled_class_name(self): return f"__subclass_{self.class_type.__name__}_{id(self.class_type)}" + def var_getattr(self, tx, name): + # [Note: __torch_function__] We currently only support attributes that are defined on + # base tensors, custom attribute accesses will graph break. + import torch + from .builder import SourcelessBuilder, VariableBuilder + + if name in banned_attrs or not hasattr(torch.Tensor, name): + unimplemented( + f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" + ) + + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Accessing overidden method/attribute {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + + if tx.output.torch_function_enabled: + if self.source: + get_fn = VariableBuilder( + tx, + source=AttrSource( + AttrSource(AttrSource(self.source, "__class__"), name), + "__get__", + ), + )(inspect.getattr_static(self.python_type(), name).__get__) + else: + get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.subclass_type_var()]), + [self], + {}, + ) + else: + return super().var_getattr(tx, name) + def call_torch_function(self, tx, fn, types, args, kwargs): return call_torch_function( tx, @@ -160,11 +213,24 @@ def call_method( # of `call_method`. if tx.output.torch_function_enabled: import torch - from .builder import SourcelessBuilder + from .builder import SourcelessBuilder, VariableBuilder + + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Calling overidden method {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) # [Note: __torch_function__] Currently we only support methods that are defined on tensor # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality - func_var = SourcelessBuilder()(tx, getattr(torch.Tensor, name)) + # We've established with the above check that the method is not overridden, so we guard that the method is the same + # as the impl defined on tensor and retrieve it + if self.source: + func_var = VariableBuilder( + tx, AttrSource(AttrSource(self.source, "__class__"), name) + )(inspect.getattr_static(self.python_type(), name)) + else: + func_var = SourcelessBuilder()(tx, getattr(torch.Tensor, name)) return dispatch_torch_function(tx, func_var, [self] + args, kwargs) else: return self.tensor_variable.call_method(tx, name, args, kwargs) From 033680c9afdff49863b1d0a82de244088482ead1 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 26 Oct 2023 18:51:14 -0700 Subject: [PATCH 34/78] [tp] fix PrepareModuleInput for multiple inputs (#112204) Not all inputs needs to annotate shardings and convert to DTensors, if user annotate only one inputs are mark the rest as Nones, we should skip creating DTensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/112204 Approved by: https://github.com/fduwjj --- .../tensor/parallel/test_tp_style.py | 30 +++++++++++++++++-- torch/distributed/tensor/parallel/style.py | 15 ++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py index ea0c9d6f427f4..a29d7e65d35cd 100644 --- a/test/distributed/tensor/parallel/test_tp_style.py +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed.tensor.parallel import parallelize_module from torch.distributed.tensor.parallel.style import ( ColwiseParallel, make_input_replicate_1d, @@ -21,14 +22,14 @@ from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, + NUM_DEVICES, ) class TensorParallelStyleTest(DTensorTestBase): @property def world_size(self): - gpu_num = torch.cuda.device_count() - return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4 + return NUM_DEVICES def _1d_input_func_check( self, @@ -261,6 +262,31 @@ def test_prepare_module_input(self): error_msgs="No device mesh is currently active", ) + @with_comms + def test_prepare_module_input_multiple_inputs(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(8, 8) + + def forward(self, x, y): + return self.linear(x) + y + + test_mod = TestModule().to(self.device_type) + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + parallelize_module(test_mod.linear, mesh, ColwiseParallel()) + parallelize_module( + test_mod, + mesh, + PrepareModuleInput(input_layouts=(Shard(0), None), output_layouts=(Replicate(), None)) + ) + output = test_mod( + torch.randn(2, 8, device=self.device_type), + torch.ones(self.world_size * 2, 8 // self.world_size, device=self.device_type) + ) + self.assertEqual(output.shape, (self.world_size * 2, 8 // self.world_size)) + @with_comms def test_prepare_module_output(self): tensor = torch.rand(8, 16, device=self.device_type) diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 716cb84c3f82a..6b90c96fbbd3d 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -476,11 +476,11 @@ def __init__( ) super().__init__( + _prepare_input=prepare_input_fn, + _prepare_output=prepare_output_fn, input_layouts=input_layouts, output_layouts=output_layouts, use_local_output=use_local_output, - _prepare_input=prepare_input_fn, - _prepare_output=prepare_output_fn, ) @staticmethod @@ -570,11 +570,14 @@ def _make_input_redistribute_1d( for input, input_layout, output_layout in zip( inputs, input_layouts, output_layouts # type: ignore[arg-type] ): - results.append( - _redistribute_per_both_layouts( - input, input_layout, output_layout, device_mesh + if input_layout is None: + results.append(input) + else: + results.append( + _redistribute_per_both_layouts( + input, input_layout, output_layout, device_mesh + ) ) - ) return tuple(results) From 7e5e951dfeb82ceb168875a5de5534d2ed7ed35a Mon Sep 17 00:00:00 2001 From: Fei Kou Date: Fri, 27 Oct 2023 05:08:31 +0000 Subject: [PATCH 35/78] [tp] update node meta with partitioned val (#112080) Test Plan: buck run mode/opt scripts/feikou/di:export_dummy_model -- --world-size=4 buck run mode/opt scripts/feikou/di:run_model -- --num_gpus=4 --num_iters=1 In sigmoid: Non-DI: ``` V1025 13:57:16.341391 2225036 run_model.cpp:84] Non-ditributed run outputs:[ 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:16.341391 2225036 run_model.cpp:84] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:16.341391 2225036 run_model.cpp:84] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:16.341391 2225036 run_model.cpp:84] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:16.341391 2225036 run_model.cpp:84] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:16.341391 2225036 run_model.cpp:84] [ CUDAFloatType{5,6} ]] ``` DI: ``` V1025 13:57:26.352564 2226855 run_model.cpp:278] [Rank 3] output wait_tensor_9: 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:26.352564 2226855 run_model.cpp:278] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:26.352564 2226855 run_model.cpp:278] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:26.352564 2226855 run_model.cpp:278] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:26.352564 2226855 run_model.cpp:278] 0.8350 0.5399 1.0196 0.9286 1.1265 1.0324 V1025 13:57:26.352564 2226855 run_model.cpp:278] [ CUDAFloatType{5,6} ] ``` Differential Revision: D50663481 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112080 Approved by: https://github.com/wanchaol --- torch/distributed/_tensor/experimental/tp_transform.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/distributed/_tensor/experimental/tp_transform.py b/torch/distributed/_tensor/experimental/tp_transform.py index 0e2b7292af6e0..18ef793872319 100644 --- a/torch/distributed/_tensor/experimental/tp_transform.py +++ b/torch/distributed/_tensor/experimental/tp_transform.py @@ -262,6 +262,9 @@ def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: node_sharding = node.meta["sharding"] if node.op == "placeholder": out_spec = node_sharding.output_spec + local_val = _partition_val(node.meta["val"], out_spec) + # update node value + node.meta["val"] = local_val elif node.op == "call_function": out_spec = node_sharding.output_spec # check if there's misaligned sharding, insert reshard if there is @@ -278,7 +281,9 @@ def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: _insert_reshard_gm( gm, node, input_arg, input_arg_spec, desired_spec ) - + # convert output val to its local component + output_val = node.meta["val"] + node.meta["val"] = _partition_val(output_val, out_spec) elif node.op == "output": for input_arg in node.all_input_nodes: # input args of output should be Replicate, otherwise redistribution is needed. @@ -292,7 +297,6 @@ def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: desired_spec.placements = (Replicate(),) if arg_spec != desired_spec: _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec) - else: raise RuntimeError(f"op code {node} not supported") @@ -321,7 +325,7 @@ def _partition_val(val: Any, spec: DTensorSpec) -> Any: assert my_coord is not None, "current rank not in mesh!" my_coord_on_mesh_dim = my_coord[idx] local_shard = placement._split_tensor( - local_shard, num_chunks, with_padding=False, contiguous=False + local_shard, num_chunks, with_padding=False, contiguous=True )[0][my_coord_on_mesh_dim] return local_shard elif isinstance(val, (list, tuple)): From 572b66331e1709ae1c9dab0b27c9b4cf7dd1ceb6 Mon Sep 17 00:00:00 2001 From: Shengbao Zheng Date: Fri, 27 Oct 2023 05:24:04 +0000 Subject: [PATCH 36/78] [PyTorch][ET] collect comms in ET for send/recv (#111985) Summary: collect send/recv comms op in Execution Trace Test Plan: run param comms with arbitrary collective size to collect operator send ``` { "name": "record_param_comms", "id": 153, "rf_id": 141, "parent": 152, "fw_parent": 0, "seq_id": -1, "scope": 0, "tid": 1, "fw_tid": 0, "op_schema": "", "inputs": [[[21,22,0,262144,4,"cuda:0"]],215038,139890792374272,1,"send",[],[]], "input_shapes": [[[262144]],[],[],[],[],[],[]], "input_types": ["GenericList[Tensor(float)]","Int","Int","Int","String","GenericList[]","GenericList[]"], "outputs": [[[21,22,0,262144,4,"cuda:0"]]], "output_shapes": [[[262144]]], "output_types": ["GenericList[Tensor(float)]"] }, ``` recv ``` { "name": "record_param_comms", "id": 172, "rf_id": 160, "parent": 171, "fw_parent": 0, "seq_id": -1, "scope": 0, "tid": 1, "fw_tid": 0, "op_schema": "", "inputs": [[[138,139,0,262144,4,"cuda:0"]],215042,139890792374272,1,"recv",[],[]], "input_shapes": [[[262144]],[],[],[],[],[],[]], "input_types": ["GenericList[Tensor(float)]","Int","Int","Int","String","GenericList[]","GenericList[]"], "outputs": [[[138,139,0,262144,4,"cuda:0"]]], "output_shapes": [[[262144]]], "output_types": ["GenericList[Tensor(float)]"] }, ``` Differential Revision: D50624443 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111985 Approved by: https://github.com/fduwjj --- .../distributed/c10d/ProcessGroupNCCL.cpp | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 564caeda5ea89..69096157abbe0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -3333,6 +3333,24 @@ c10::intrusive_ptr ProcessGroupNCCL::send( int dstRank, int /* unused */) { check_gpu_tensors_different_devices(tensors); + + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + this->getID(), + tensors, // inputTensors + tensors, // outputTensors + dstRank, // rank + "send", // colName + tensor.numel(), // inSize + tensor.numel(), // outSize + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + this->getSize()); // worldSize + auto ret = pointToPoint( tensors, [&](at::Tensor& input, @@ -3353,6 +3371,24 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( int srcRank, int /* unused */) { check_gpu_tensors_different_devices(tensors); + + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + this->getID(), + tensors, // inputTensors + tensors, // outputTensors + srcRank, // rank + "recv", // colName + tensor.numel(), // inSize + tensor.numel(), // outSize + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + this->getSize()); // worldSize + auto ret = pointToPoint( tensors, [&](at::Tensor& output, From 6a992915465530ab8566883275364de1092cd403 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 26 Oct 2023 23:59:56 +0000 Subject: [PATCH 37/78] Removing sdpa conv layout constraint (#112045) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously layout opt with sdpa would cause failures because we would pass a non-dense last dim to sdpa. Those layout constraints have been added in prior prs. Now we can do conv layout opt with sdpa. Improves twins_pcpvt_base 1.4622 → 1.5351, xcit_large_24_p8_224 3.0681 → 3.1839 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112045 Approved by: https://github.com/shunting314 ghstack dependencies: #111976, #111721 --- torch/_inductor/graph.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 0b52b604e58d7..8ffcef80fa1fa 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -327,30 +327,6 @@ def decide_layout_opt(gm) -> bool: log.debug("Skip layout opt because all convolution channels are too small") return False - # aten._scaled_dot_product_flash_attention requires the last stride of query/key/value - # to be 1. Check https://gist.github.com/shunting314/fa6eeab2aad8d1265c4d5e50b560d94f - # for more details. - # - # When a model contains aten._scaled_dot_product_flash_attention and we enable layout optimization, - # the op may get channels last input and fail. Example include: twins_pcpvt_base, xcit_large_24_p8_224 - # for _scaled_dot_product_flash_attention and xcit_large_24_p8_224 for _scaled_dot_product_efficient_attention. - # - # We disable layout optimization if a model contains aten._scaled_dot_product_flash_attention. - # - # An alternative is to do necessary layout conversion to make sure aten._scaled_dot_product_flash_attention's - # inputs have the layout needed. But that seems to have worse perf than disabing the layout opt. - # TODO(shunting) revisit if we can still apply layout optimization to models containing sdpa while - # bringing perf gains. - for n in gm.graph.nodes: - if n.target in ( - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - ): - log.debug( - "Skip layout optimization because sdpa (scaled dot product attention) is found" - ) - return False - return True def find_nodes_prefer_channels_last(self): From 632ac01bef20b4396f96f0d57b718a531b84505a Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 26 Oct 2023 20:01:00 -0700 Subject: [PATCH 38/78] [dynamo] Enable typechecking for exc.py (#112127) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112127 Approved by: https://github.com/Skylion007 ghstack dependencies: #111894, #111992, #112031 --- .lintrunner.toml | 1 + mypy-nofollow.ini | 3 +++ torch/_dynamo/exc.py | 22 +++++++++++----------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index ce80a46c7eca8..ebf81035fb398 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -181,6 +181,7 @@ include_patterns = [ 'torch/_dynamo/allowed_functions.py', 'torch/_dynamo/codegen.py', 'torch/_dynamo/eval_frame.py', + 'torch/_dynamo/exc.py', 'torch/_dynamo/funcname_cache.py', 'torch/_dynamo/convert_frame.py', 'torch/_dynamo/symbolic_convert.py', diff --git a/mypy-nofollow.ini b/mypy-nofollow.ini index b2c09ef5af5af..9e9e5e68fdc76 100644 --- a/mypy-nofollow.ini +++ b/mypy-nofollow.ini @@ -38,6 +38,9 @@ ignore_errors = True [mypy-torch._C.*] ignore_errors = True +[mypy-torch.fb.*] +ignore_missing_imports = True + [mypy-torchvision.*] ignore_missing_imports = True diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 96f98155e111e..a482b23b15177 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -2,7 +2,7 @@ import textwrap from enum import auto, Enum from traceback import extract_stack, format_exc, format_list, StackSummary -from typing import cast, Optional +from typing import cast, NoReturn, Optional import torch._guards @@ -165,7 +165,7 @@ class IncorrectUsage(Exception): ) -def unimplemented_with_warning(e, code, msg): +def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn: # This function calls unimplemented internally and eventually graph breaks # or falls to eager. unimplemented itself does not print any user warnings, # i.e., its very silent. This helper function is intended when an error is @@ -179,12 +179,12 @@ def unimplemented_with_warning(e, code, msg): raise unimplemented(msg) from e -def unimplemented(msg: str): +def unimplemented(msg: str) -> NoReturn: assert msg != os.environ.get("BREAK", False) raise Unsupported(msg) -def warning(msg: str): +def warning(msg: str) -> None: counters["warnings"][msg] += 1 assert msg != os.environ.get("BREAK", False) @@ -202,14 +202,12 @@ def __repr__(self) -> str: return self.__str__() -def augment_exc_message(exc, msg="\n", export=False): +def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None: import traceback real_stack = get_real_stack(exc) if real_stack is not None: - msg += ( - f"\nfrom user code:\n {''.join(traceback.format_list(get_real_stack(exc)))}" - ) + msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}" if config.replay_record_enabled and hasattr(exc, "record_filename"): msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\ @@ -250,7 +248,7 @@ def augment_exc_message(exc, msg="\n", export=False): exc.args = (new_msg,) + exc.args[1:] -def get_real_stack(exc, frame=None) -> Optional[StackSummary]: +def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]: real_stack = getattr(exc, "real_stack", None) if real_stack is None: return None @@ -292,7 +290,9 @@ def filter_stack(stack): return user_stack -def format_error_msg_verbose(exc, code, record_filename=None, frame=None): +def format_error_msg_verbose( + exc: Exception, code, record_filename=None, frame=None +) -> str: msg = ( f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n" ) @@ -314,7 +314,7 @@ def format_error_msg_verbose(exc, code, record_filename=None, frame=None): return msg -def format_error_msg(exc, code, record_filename=None, frame=None): +def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str: msg = os.linesep * 2 if config.verbose: From 20fc2b41869fe4b5f382bfbe21f00b0ba0f7e6cf Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 26 Oct 2023 20:01:01 -0700 Subject: [PATCH 39/78] [dynamo] Enable typechecking for compiled_autograd.py (#112128) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112128 Approved by: https://github.com/Skylion007 ghstack dependencies: #111894, #111992, #112031, #112127 --- .lintrunner.toml | 1 + torch/_dynamo/compiled_autograd.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index ebf81035fb398..0dfebd97c2218 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -180,6 +180,7 @@ code = 'MYPYNOFOLLOW' include_patterns = [ 'torch/_dynamo/allowed_functions.py', 'torch/_dynamo/codegen.py', + 'torch/_dynamo/compiled_autograd.py', 'torch/_dynamo/eval_frame.py', 'torch/_dynamo/exc.py', 'torch/_dynamo/funcname_cache.py', diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 4395e3d4c008a..3b8fbc3b11563 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,6 +1,6 @@ import contextlib import functools -from typing import List +from typing import List, Optional import torch from torch._dynamo.external_utils import call_hook @@ -20,6 +20,7 @@ track_tensor_tree, ) from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from torch.fx.proxy import Proxy compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") @@ -43,14 +44,14 @@ def __init__(self, compiler_fn) -> None: ) self.fx_tracer = PythonKeyTracer() self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") - self.hooks_proxy = None + self.hooks_proxy: Optional[Proxy] = None def wrap_fake(self, x, source): assert isinstance(x, torch.Tensor) return self.fake_tensor_mode.from_tensor(x, source=source) @staticmethod - def source(name, idx): + def source(name, idx) -> GetItemSource: return GetItemSource(LocalSource(name), idx) def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]): @@ -102,6 +103,7 @@ def proxy_call_hook(self, hook, *args): ) def tensor_pre_hook(self, inputs, hook_id, i: int): + assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] proxy = self.proxy_call_hook( hook, @@ -113,6 +115,7 @@ def tensor_pre_hook(self, inputs, hook_id, i: int): return inputs def pre_hook(self, inputs, hook_id): + assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] proxies = self.proxy_call_hook( hook, @@ -124,6 +127,7 @@ def pre_hook(self, inputs, hook_id): return inputs def post_hook(self, outputs, inputs, hook_id): + assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] proxies = self.proxy_call_hook( hook, From 46667c97fddc28d1e929a790985034ecb39a4bde Mon Sep 17 00:00:00 2001 From: Wei Lu Date: Fri, 27 Oct 2023 07:56:01 +0000 Subject: [PATCH 40/78] [Pytorch][Vulkan] var.dim (#111965) Summary: We implement [`torch.var`](https://pytorch.org/docs/stable/generated/torch.var.html) for tensors of 2d to 4d. By using the `mean`, `sub` and `pow` ops, we can compute the variance as below without adding a new shader. ``` at::Tensor self_mean = self.mean(opt_dim, true); at::Tensor output = (self.sub(self_mean).pow(2)).mean(opt_dim, keepdim); ``` Test Plan: ``` [luwei@devbig984.prn1 /data/users/luwei/fbsource (2da0640c6)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*var*" Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated Total time: 0.1 sec BUILD SUCCEEDED Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc Note: Google Test filter = *var* [==========] Running 6 tests from 1 test suite. [----------] Global test environment set-up. [----------] 6 tests from VulkanAPITest [ RUN ] VulkanAPITest.var_2d_unbiased [ OK ] VulkanAPITest.var_2d_unbiased (322 ms) [ RUN ] VulkanAPITest.var_2d_biased [ OK ] VulkanAPITest.var_2d_biased (0 ms) [ RUN ] VulkanAPITest.var_3d_unbiased [ OK ] VulkanAPITest.var_3d_unbiased (2 ms) [ RUN ] VulkanAPITest.var_3d_biased [ OK ] VulkanAPITest.var_3d_biased (2 ms) [ RUN ] VulkanAPITest.var_4d_unbiased [ OK ] VulkanAPITest.var_4d_unbiased (175 ms) [ RUN ] VulkanAPITest.var_4d_biased [ OK ] VulkanAPITest.var_4d_biased (5 ms) [----------] 6 tests from VulkanAPITest (508 ms total) [----------] Global test environment tear-down [==========] 6 tests from 1 test suite ran. (508 ms total) [ PASSED ] 6 tests. ``` Reviewed By: yipjustin Differential Revision: D50398925 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111965 Approved by: https://github.com/yipjustin --- aten/src/ATen/native/vulkan/ops/Var.cpp | 76 +++++++++++++++++ aten/src/ATen/test/vulkan_api_test.cpp | 103 ++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 aten/src/ATen/native/vulkan/ops/Var.cpp diff --git a/aten/src/ATen/native/vulkan/ops/Var.cpp b/aten/src/ATen/native/vulkan/ops/Var.cpp new file mode 100644 index 0000000000000..ff95fab34aade --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Var.cpp @@ -0,0 +1,76 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor var_dim_IntList( + const at::Tensor& self_arg, + const OptionalIntArrayRef opt_dim, + bool unbiased = true, // correction=1 in version 2.0 + bool keepdim = false) { + TORCH_CHECK( + self_arg.dim() >= 2 && self_arg.dim() <= 4, + "Vulkan var.dim_IntList only supports 2d, 3d, 4d tensors as input!"); + + TORCH_CHECK( + opt_dim.has_value(), "Vulkan var without a dim arg is not implemented"); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + + std::set dims_set; + if (opt_dim.has_value()) { + int sample_size = 1; + auto dims = opt_dim.value(); + + for (const auto& d : dims) { + TORCH_CHECK(d >= -self.dim() || d < self.dim(), "Dimension out of range"); + + int64_t dim_normalized = utils::normalize(d, self.dim()); + if (dims_set.find(dim_normalized) != dims_set.end()) { + TORCH_CHECK( + false, + "dim ", + dim_normalized, + " appears multiple times in the list of dims") + } + dims_set.insert(dim_normalized); + + sample_size *= self.sizes().vec()[dim_normalized]; + } + + at::Tensor self_mean = self.mean(opt_dim, true); + at::Tensor self_minus_mean = self.sub(self_mean); + // We write `self_minus_mean.mul(self_minus_mean)` instead of + // `self.sub(self_mean).pow(2)` because Vulkan driver on Android doesn't + // support negative input: "The result is undefined if x<0 or if x=0 and + // y≤0" see https://registry.khronos.org/OpenGL-Refpages/gl4/html/pow.xhtml + at::Tensor output = + self_minus_mean.mul(self_minus_mean).mean(opt_dim, keepdim); + if (unbiased == true) { + output = output.mul(sample_size * 1.0 / (sample_size - 1)); + } + return output; + } + return self; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::var.dim"), TORCH_FN(var_dim_IntList)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 382c9e464caeb..f61aa060091fb 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -5097,6 +5097,109 @@ TEST_F(VulkanAPITest, unbind_3d_depth_large) { test_unbind({100, 1, 144}, 0); } +void test_var(const at::IntArrayRef input_shape, const at::IntArrayRef dim_list, bool unbiased=true, bool keepdim=false) { + c10::InferenceMode mode; + + const auto in_cpu = at::rand(input_shape, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::var(in_cpu, dim_list, unbiased, keepdim); + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::var(in_vulkan, dim_list, unbiased, keepdim); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, var_2d_unbiased) { + test_var({3, 5}, {1}, true, true); + test_var({3, 5}, {1}, true, false); + + // inpu.dim() == dim_list.size(), only keepdim == true is supported + test_var({3, 5}, {0, 1}, true, true); +} + +TEST_F(VulkanAPITest, var_2d_biased) { + test_var({3, 5}, {1}, false, true); + test_var({3, 5}, {1}, false, false); + + // inpu.dim() == dim_list.size(), only keepdim == true is supported + test_var({3, 5}, {0, 1}, false, true); +} + +TEST_F(VulkanAPITest, var_3d_unbiased) { + test_var({3, 5, 7}, {1}, true, true); + test_var({3, 5, 7}, {1}, true, false); + + test_var({3, 5, 7}, {0, 1}, true, true); + test_var({3, 5, 7}, {0, 1}, true, false); + + test_var({3, 5, 7}, {0, 2}, true, true); + test_var({3, 5, 7}, {0, 2}, true, false); + + test_var({3, 5, 7}, {-1, -2}, true, true); + test_var({3, 5, 7}, {-1, -2}, true, false); + + test_var({3, 5, 7}, {0, 1, 2}, true, true); +} + +TEST_F(VulkanAPITest, var_3d_biased) { + test_var({3, 5, 7}, {1}, false, true); + test_var({3, 5, 7}, {1}, false, false); + + test_var({3, 5, 7}, {0, 1}, false, true); + test_var({3, 5, 7}, {0, 1}, false, false); + + test_var({3, 5, 7}, {0, 2}, false, true); + test_var({3, 5, 7}, {0, 2}, false, false); + + test_var({3, 5, 7}, {-1, -2}, false, true); + test_var({3, 5, 7}, {-1, -2}, false, false); + + test_var({3, 5, 7}, {0, 1, 2}, false, true); +} + +TEST_F(VulkanAPITest, var_4d_unbiased) { + test_var({3, 5, 7, 11}, {0}, true, true); + test_var({3, 5, 7, 11}, {1}, true, false); + + test_var({3, 5, 7, 11}, {0, 1}, true, true); + test_var({3, 5, 7, 11}, {0, 1}, true, false); + + test_var({3, 5, 7, 11}, {0, 2}, true, true); + test_var({3, 5, 7, 11}, {0, 2}, true, false); + + test_var({3, 5, 7, 11}, {-1, -2}, true, true); + test_var({3, 5, 7, 11}, {-1, -2}, true, false); + + test_var({3, 5, 7, 11}, {0, 1, 2}, true, true); + test_var({3, 5, 7, 11}, {0, -1, 2}, true, false); + + test_var({3, 5, 7, 11}, {0, 1, 2, 3}, true, true); +} + +TEST_F(VulkanAPITest, var_4d_biased) { + test_var({3, 5, 7, 11}, {0}, false, true); + test_var({3, 5, 7, 11}, {1}, false, false); + + test_var({3, 5, 7, 11}, {0, 1}, false, true); + test_var({3, 5, 7, 11}, {0, 1}, false, false); + + test_var({3, 5, 7, 11}, {0, 2}, false, true); + test_var({3, 5, 7, 11}, {0, 2}, false, false); + + test_var({3, 5, 7, 11}, {-1, -2}, false, true); + test_var({3, 5, 7, 11}, {-1, -2}, false, false); + + test_var({3, 5, 7, 11}, {0, 1, 2}, false, true); + test_var({3, 5, 7, 11}, {0, -1, 2}, false, false); + + test_var({3, 5, 7, 11}, {0, 1, 2, 3}, false, true); +} + TEST_F(VulkanAPITest, view_explicit) { c10::InferenceMode mode; From 2a86bcbac23f95e3d144d8167faf9176cd265b98 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 26 Oct 2023 14:42:28 -0700 Subject: [PATCH 41/78] [FSDP][state_dict] Cleanup the usage of _get_pg_default_device (#112168) _get_pg_default_device is not suitable for FSDP use case. We should always use the compute_device when communicating. Differential Revision: [D50698730](https://our.internmc.facebook.com/intern/diff/D50698730/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112168 Approved by: https://github.com/wz337 --- torch/distributed/fsdp/_optim_utils.py | 8 +++-- torch/distributed/fsdp/_state_dict_utils.py | 36 ++++++--------------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index b459d6d8290a9..0117aa73a08cf 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -344,11 +344,10 @@ def _broadcast_processed_state( def _broadcast_state( fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] ) -> Any: - device = _get_pg_default_device(group) if fsdp_state.rank == 0: if not isinstance(state, torch.Tensor) or state.dim() == 0: return state - tensor = state.to(device) + tensor = state.to(fsdp_state.compute_device) else: if isinstance(state, torch.Tensor): assert state.dim() == 0, ( @@ -358,7 +357,9 @@ def _broadcast_state( return state elif not isinstance(state, _PosDimTensorInfo): return state - tensor = torch.zeros(state.shape, dtype=state.dtype, device=device) + tensor = torch.zeros( + state.shape, dtype=state.dtype, device=fsdp_state.compute_device + ) dist.broadcast(tensor, src=0, group=group) return tensor @@ -1155,6 +1156,7 @@ def _check_missing_keys_on_rank( assert param_key >= 0 and param_key < len( param_key_to_param ), "Check the `param_key_to_param` construction" + # We cannot use FSDPState.compute_device as this API is a global view. device = _get_pg_default_device(group) num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) dist.all_reduce(num_missing, group=group) diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 2ab27ddc9587e..2e8d9c3d288d3 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -19,7 +19,6 @@ from torch.distributed._tensor import DTensor from torch.distributed._tensor.device_mesh import mesh_resources -from torch.distributed.distributed_c10d import _get_pg_default_device from torch.distributed.fsdp._common_utils import ( _FSDPState, _get_module_fsdp_state_if_fully_sharded_module, @@ -611,7 +610,6 @@ def _sharded_pre_load_state_dict_hook( zip(handle.flat_param._fqns, handle.flat_param._param_extensions) ) - device = fsdp_state.compute_device for fqn, _, _ in _param_name_infos(module, fsdp_state): if not _is_composable(fsdp_state): fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}" @@ -644,40 +642,24 @@ def _sharded_pre_load_state_dict_hook( ) if len(shards) == 1: local_tensor = shards[0].tensor.flatten() - pg_device = _get_pg_default_device(fsdp_state.process_group) - if local_tensor.device.type != pg_device.type: - with SimpleProfiler.profile(SimpleProfiler.Type.H2D): - local_tensor = local_tensor.to(pg_device) + with SimpleProfiler.profile(SimpleProfiler.Type.H2D): + local_tensor = local_tensor.to(fsdp_state.compute_device) num_padding = chunk_size - local_tensor.numel() if num_padding > 0: local_tensor = F.pad(local_tensor, [0, num_padding]) else: - local_tensor = torch.zeros(chunk_size, dtype=param.dtype, device=device) + local_tensor = torch.zeros( + chunk_size, dtype=param.dtype, device=fsdp_state.compute_device + ) tensor = torch.empty( chunk_size * fsdp_state.world_size, dtype=local_tensor.dtype, - device=device, + device=fsdp_state.compute_device, ) - if local_tensor.is_cpu: - # Tensor could be on FSDP GPU compute device, while local_tensor is on CPU. - # Convert to CPU so all_gather can work. - tensor_dev = tensor.device - with SimpleProfiler.profile(SimpleProfiler.Type.H2D): - tensor = tensor.cpu() - tensor_list = list( - torch.chunk(tensor, dist.get_world_size(fsdp_state.process_group)) + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + tensor, local_tensor, group=fsdp_state.process_group ) - with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): - dist.all_gather( - tensor_list, local_tensor, group=fsdp_state.process_group - ) - with SimpleProfiler.profile(SimpleProfiler.Type.D2H): - tensor.to(tensor_dev) - else: - with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): - dist.all_gather_into_tensor( - tensor, local_tensor, group=fsdp_state.process_group - ) tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) state_dict[fqn_from_global_root] = tensor else: From 7265c22a5d6a0d0649d06407c2bd1e86ef5c7a9e Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Fri, 27 Oct 2023 09:14:15 +0000 Subject: [PATCH 42/78] [AOTInductor] Enforce no_grad for Run entries (#111613) Summary: Always enter no_grad mode in AOTInductor run entries. ``` // AOTInductor uses at::addmm_out, which doesn't supports // arguments that requires gradient. For this reason, we // enforce no_grad context for run APIs. ``` Test Plan: buck2 test mode/dev-nosan caffe2/test/inductor:test_aot_inductor and OSS CI Differential Revision: D50432042 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111613 Approved by: https://github.com/chenyang78, https://github.com/khabinov --- test/inductor/test_aot_inductor.py | 27 ++++++++++++++++++ .../codegen/aoti_runtime/interface.cpp | 28 +++++++++++++++---- torch/csrc/inductor/aoti_torch/c/shim.h | 3 ++ .../csrc/inductor/aoti_torch/shim_common.cpp | 9 ++++++ 4 files changed, 62 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 670628e97c822..e2c1161f24742 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -951,6 +951,33 @@ def forward(self, x): example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),) self.check_model(Model(), example_inputs) + def test_run_with_grad_enabled(self): + class Model(torch.nn.Module): + def forward(self, x, weight, bias): + return torch.ops.aten.addmm(bias, weight, x) + + m = Model().to(device=self.device) + x = torch.rand(8, 8, device=self.device, requires_grad=True) + weight = torch.rand(8, 8, device=self.device, requires_grad=True) + bias = torch.rand(8, device=self.device, requires_grad=True) + example_inputs = (x, weight, bias) + + expected = m(*example_inputs) + expected, _ = pytree.tree_flatten(expected) + + # compiler under no_grad + with torch.no_grad(): + so_path = AOTInductorModelRunner.compile(m, example_inputs) + + # run under grad enabled + self.assertTrue(torch.is_grad_enabled()) + + optimized = AOTInductorModelRunner.load(self.device, so_path, example_inputs) + actual = optimized(example_inputs) + actual, _ = pytree.tree_flatten(actual) + + self.assertTrue(same(actual, expected)) + class AOTInductorTestABICompatibleCpu(TestCase): device = "cpu" diff --git a/torch/_inductor/codegen/aoti_runtime/interface.cpp b/torch/_inductor/codegen/aoti_runtime/interface.cpp index a824ea9fae148..8fef1237193c7 100644 --- a/torch/_inductor/codegen/aoti_runtime/interface.cpp +++ b/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -26,6 +26,22 @@ std::to_string(actual_size)); \ } while (0) +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) { + aoti_torch_grad_mode_set_enabled(false); + } + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + bool prev_mode; +}; + extern "C" { AOTIRuntimeError AOTInductorModelContainerCreate( @@ -79,6 +95,7 @@ AOTIRuntimeError AOTInductorModelContainerRun( auto stream = reinterpret_cast(stream_handle); CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; container->run( input_handles, output_handles, @@ -167,11 +184,12 @@ AOTIRuntimeError AOTInductorModelRun( AtenTensorHandle* output_handles) { auto model = reinterpret_cast(model_handle); CONVERT_EXCEPTION_TO_ERROR_CODE({ - model->run_impl( - input_handles, - output_handles, - (torch::aot_inductor::DeviceStreamType)nullptr, - nullptr); + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); }) } diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index b8564caeb3db7..92bd9bdf93be4 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -94,6 +94,9 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int32(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int64(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool(); +AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled(); +AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled); + // Free the tensor object AOTI_TORCH_EXPORT AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 1b1a9639183a7..db15f31215712 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -90,6 +91,14 @@ int32_t aoti_torch_dtype_bool() { return (int32_t)c10::ScalarType::Bool; } +bool aoti_torch_grad_mode_is_enabled() { + return c10::GradMode::is_enabled(); +} + +void aoti_torch_grad_mode_set_enabled(bool enabled) { + return c10::GradMode::set_enabled(enabled); +} + AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); From cf5479b57e607f4bbdee06c6eb6e890d42d3aa05 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Thu, 26 Oct 2023 17:58:58 -0700 Subject: [PATCH 43/78] [MPS] Make the device in MPSGenerator consistent with MPSAllocator (#112188) https://github.com/pytorch/pytorch/blob/1b702b185e8dddadb4ad3f487f5412a02c8777e1/aten/src/ATen/mps/MPSAllocator.mm#L751-L760 The device in an MPS tensor is actually allocated with a device index, so this PR makes the device generated by `MPSGenerator` consistent with that. Fixes https://github.com/pytorch/pytorch/issues/110820#issuecomment-1752088865 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112188 Approved by: https://github.com/malfet, https://github.com/kulinseth --- aten/src/ATen/mps/MPSGeneratorImpl.mm | 2 +- aten/src/ATen/mps/MPSStream.mm | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.mm b/aten/src/ATen/mps/MPSGeneratorImpl.mm index 1fd2f3b7577b3..16a0a4553e7b2 100644 --- a/aten/src/ATen/mps/MPSGeneratorImpl.mm +++ b/aten/src/ATen/mps/MPSGeneratorImpl.mm @@ -21,7 +21,7 @@ Generator createMPSGenerator(uint64_t seed_val) { } // namespace mps::detail MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in) - : c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)}, + : c10::GeneratorImpl{Device(DeviceType::MPS, 0), DispatchKeySet(c10::DispatchKey::MPS)}, data_({.seed = seed_in}), engine_(seed_in, 0, 0) {} diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm index afff3da2003d2..959c2a507042e 100644 --- a/aten/src/ATen/mps/MPSStream.mm +++ b/aten/src/ATen/mps/MPSStream.mm @@ -248,7 +248,7 @@ @interface MPSGraphExecutionDescriptor () MPSStream* MPSStreamImpl::getInstance() { if (_stream == nullptr) { - _stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0)); + _stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS, 0), 0)); } return _stream; } From 9f7bff11719cd67e5ffd9b2978811f3d8141b803 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 27 Oct 2023 14:44:43 +0000 Subject: [PATCH 44/78] Add timeout for master store if clients do not join (#111805) Currently, if the master_store does not have all clients join in the `timeout` time, it will just continue silently which could lead to errors down the road. However, if a client does not connect with the master within the specified time then an exception will be raised. This change will have master_store error out if not all clients have joined, making server and client consistent with each other. Since this is changing the default behavior of master store I am open to suggestions. Example: ```python import torch.distributed as dist import torch.multiprocessing as mp from datetime import timedelta def main(rank, world_size): if rank == 0: print("creating store") # world size is 2 so this eventually times out store = dist.TCPStore("localhost", 1234, 2, True, timeout=timedelta(seconds=5)) print("finished creating store") if __name__ == "__main__": world_size = 2 mp.spawn(main, (world_size,), nprocs=world_size) ``` Previous ``` print("creating store") print("finished creating store") ``` Now ``` print("creating store") torch.distributed.DistStoreError: Timed out after 6 seconds waiting for workers. 1/2 workers joined. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/111805 Approved by: https://github.com/XilunWu, https://github.com/fduwjj --- test/distributed/elastic/utils/distributed_test.py | 4 ++-- test/distributed/test_store.py | 10 +++++++++- torch/csrc/distributed/c10d/TCPStore.cpp | 9 ++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/test/distributed/elastic/utils/distributed_test.py b/test/distributed/elastic/utils/distributed_test.py index 0eab292969bd5..4e5fa8d7e0c32 100644 --- a/test/distributed/elastic/utils/distributed_test.py +++ b/test/distributed/elastic/utils/distributed_test.py @@ -14,7 +14,7 @@ import unittest from contextlib import closing -from torch.distributed import DistNetworkError +from torch.distributed import DistNetworkError, DistStoreError from torch.distributed.elastic.utils.distributed import ( create_c10d_store, get_socket_with_port, @@ -98,7 +98,7 @@ def test_create_store_multi(self): self.assertEqual(0, worker1.exitcode) def test_create_store_timeout_on_server(self): - with self.assertRaises(TimeoutError): + with self.assertRaises(DistStoreError): # use any available port (port 0) since timeout is expected create_c10d_store( is_server=True, diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 358f2d21e5894..e6c79724417ca 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -14,7 +14,7 @@ import torch.distributed as dist import torch.distributed.distributed_c10d as c10d import torch.distributed.rpc as rpc -from torch.distributed import DistNetworkError, DistError +from torch.distributed import DistNetworkError, DistError, DistStoreError from torch.testing._internal.common_distributed import MultiThreadedTestCase from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize @@ -405,6 +405,14 @@ def test_multi_get(self): self.assertEqual(b"po", v0) self.assertEqual(b"tato", v1) + def test_store_timeout_on_missing_clients(self): + with self.assertRaisesRegex(DistStoreError, r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined."): + # world_size is 2 so it should timeout + dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2)) + + # when wait_for_workers is not set, then there should be no exception raised + dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2), wait_for_workers=False) + class LibUvTCPStoreTest(TCPStoreTest): def _create_store(self): diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 907b0befb32b7..080956b9379b2 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -371,7 +372,13 @@ void TCPStore::waitForWorkers() { const auto elapsed = std::chrono::duration_cast( std::chrono::steady_clock::now() - start); if (timeout_ != kNoTimeout && elapsed > timeout_) { - break; + C10_THROW_ERROR( + DistStoreError, + fmt::format( + "Timed out after {} seconds waiting for clients. {}/{} clients joined.", + elapsed.count(), + numWorkersCompleted, + *numWorkers_)); } /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); From 7df675743c5196c5f1c5f8c1b8e2b000c0992a9b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 26 Oct 2023 09:59:24 -0700 Subject: [PATCH 45/78] Stop using defaultdict for deferred_runtime_asserts (#112172) In the ShapeEnv record replay machinery we do equality tests on this dict, but `{i0: []}` is considered not equal to `{}`. But you can unpredictably end up with the first by just doing reads from the dict. Doing a real dict removes this wobbliness. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/112172 Approved by: https://github.com/ysiraichi, https://github.com/Skylion007 --- torch/fx/experimental/symbolic_shapes.py | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 857e0a35ae3fc..7008b152fa86d 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1541,7 +1541,7 @@ def _init( # to the next unbacked symbol to wait on, but if we choose the # latest key, an assert will only show up at the moment when # we can actually codegen it. - self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = collections.defaultdict(list) + self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {} # This exists so we can efficiently invalidate the cache (it's used as # part of the cache key); otherwise we'd have to iterate through # deferred_runtime_asserts to compute its length @@ -2828,16 +2828,17 @@ def _maybe_evaluate_static( if s in self.var_to_val: continue subst = {} - for ra in self.deferred_runtime_asserts[s]: - if compute_hint: - e = ra.expr.xreplace(self.var_to_val) - else: - e = ra.expr - subst[e] = sympy.true - subst[sympy.Not(e)] = sympy.false - # NB: this doesn't match relations if they're flipped; e.g., - # if you have x < 5, we won't get 5 > x. Holler if this is - # a problem + if s in self.deferred_runtime_asserts: + for ra in self.deferred_runtime_asserts[s]: + if compute_hint: + e = ra.expr.xreplace(self.var_to_val) + else: + e = ra.expr + subst[e] = sympy.true + subst[sympy.Not(e)] = sympy.false + # NB: this doesn't match relations if they're flipped; e.g., + # if you have x < 5, we won't get 5 > x. Holler if this is + # a problem # NB: this helps us deal with And/Or connectives expr = expr.subs(subst) @@ -3345,7 +3346,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): ra = RuntimeAssert(expr, msg, stack) # TODO: Do this in a way that is less janky than int(s.name[1:]) cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:])) - self.deferred_runtime_asserts[cands[-1]].append(ra) + self.deferred_runtime_asserts.setdefault(cands[-1], []).append(ra) self.num_deferred_runtime_asserts += 1 # TODO: refine ranges # Unfortunately, range refinement is probably going to not From bd0ea72b28f238b4ef8cf126d4122ca5da7964ac Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Thu, 26 Oct 2023 13:13:51 -0400 Subject: [PATCH 46/78] torch.library: Create helper function `is_functional_schema` (#111660) I will need this again soon. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/111660 Approved by: https://github.com/soulitzer --- test/test_custom_ops.py | 6 +++--- torch/_custom_op/impl.py | 22 +++++--------------- torch/_library/utils.py | 44 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 7c4ac74145451..38ab7593e4537 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -467,15 +467,15 @@ def baz(x: Tensor) -> Tensor: raise NotImplementedError() def test_unsupported_schemas(self): - with self.assertRaisesRegex(ValueError, "does not support non-functional"): + with self.assertRaisesRegex(ValueError, "only supports functional"): custom_ops.custom_op( f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)" )(foo) - with self.assertRaisesRegex(ValueError, "does not support view functions"): + with self.assertRaisesRegex(ValueError, "only supports functional"): custom_ops.custom_op( f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)" )(foo) - with self.assertRaisesRegex(ValueError, "no outputs"): + with self.assertRaisesRegex(ValueError, "only supports functional"): custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")( foo ) diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index 807d47c26a40e..0751c041aeee1 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -661,24 +661,12 @@ def validate_namespace(ns: str) -> None: ) def validate_schema(schema: FunctionSchema) -> None: - # Coming in the future. Requires us to have correct logic for - # the ADInplaceOrView key - if schema.kind() != SchemaKind.functional: + if not torch._library.utils.is_functional_schema(schema): raise ValueError( - f"custom_op does not support non-functional function schema. Got: {schema}" - ) - - rets = schema.returns - is_non_mutating_view = len(rets) > 0 and any( - r.annotation is not None and not r.annotation.is_write for r in rets - ) - if is_non_mutating_view: - raise ValueError(f"custom_op does not support view functions. Got: {schema}") - - # Just seems weird so banning for now - if not schema.returns: - raise ValueError( - f"custom_op does not support function schema with no outputs. Got: {schema}" + f"custom_op only supports functional operators " + f"(ops that do not mutate any inputs, do not return " + f"views of the inputs, and has at least one return). " + f"Got the following non-functional schema: {schema}" ) # For simplicity: don't allow self arguments diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 5ec5a78f1208a..8e693c8243bf0 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -1,7 +1,9 @@ import dataclasses import inspect import sys -from typing import Callable, Tuple +from typing import Any, Callable, Tuple + +import torch @dataclasses.dataclass @@ -49,3 +51,43 @@ def parse_namespace(qualname: str) -> Tuple[str, str]: f"of a namespace and a name, e.g. aten::sin" ) return splits[0], splits[1] + + +def lookup_op(qualname: str) -> torch._ops.OpOverloadPacket: + namespace, name = parse_namespace(qualname) + if "." in name: + name, overload = name.split(".") + else: + overload = "default" + ns = getattr(torch.ops, namespace) + packet = getattr(ns, name) + return getattr(packet, overload) + + +def is_functional_schema(schema: Any) -> bool: + """Check if the schema is functional. + + An operator is functional if: + - it does not mutate any of its inputs + - it does not return a view on any of its inputs + - it has at least one return + """ + + # Lazy import because not all PyTorch builds have torchgen + from torchgen.model import FunctionSchema, SchemaKind + + assert isinstance(schema, (str, FunctionSchema)) + if isinstance(schema, str): + schema = FunctionSchema.parse(schema) + + if schema.kind() != SchemaKind.functional: + return False + rets = schema.returns + is_non_mutating_view = len(rets) > 0 and any( + r.annotation is not None and not r.annotation.is_write for r in rets + ) + if is_non_mutating_view: + return False + if not schema.returns: + return False + return True From fdbb73fa4e8fd5092bb1a641973f9b3ffa7ed35a Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 27 Oct 2023 15:35:30 +0000 Subject: [PATCH 47/78] Check both ops and refs in test_strided_layout (#112160) Trying #112023 again to see if CLA issue is fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112160 Approved by: https://github.com/lezcano, https://github.com/Neilblaze --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3e9b9f673a38d..a2d5a65466f5f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2196,7 +2196,7 @@ def test_fake_crossref_backward_no_amp(self, device, dtype, op): def test_fake_crossref_backward_amp(self, device, dtype, op): self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast) - @ops([op for op in op_db if op.is_factory_function]) + @ops([op for op in ops_and_refs if op.is_factory_function]) def test_strided_layout(self, device, dtype, op): samples = op.sample_inputs(device, dtype) for sample in samples: From 63c089b09de42fff67f7ad7ae410669ae0880e1f Mon Sep 17 00:00:00 2001 From: Aaron Enye Shi Date: Fri, 27 Oct 2023 16:18:40 +0000 Subject: [PATCH 48/78] [c10] Move profiler clock to libc10 for timestamps (#111972) Summary: Move the profiler's Approximate Clock from libtorch to libc10. The main reason is to allow c10 features to get time. The clock is using TSC when available for performance. CUDA Caching Allocator's implementation of memory snapshot will add the timestamps to memory events with this same clock in subsequent diff. Test Plan: CI Differential Revision: D50601935 Pulled By: aaronenyeshi Pull Request resolved: https://github.com/pytorch/pytorch/pull/111972 Approved by: https://github.com/davidberard98 --- c10/util/ApproximateClock.cpp | 79 ++++++++++++ c10/util/ApproximateClock.h | 121 ++++++++++++++++++ test/cpp/jit/test_backend_compiler_lib.cpp | 7 +- test/cpp/profiler/containers.cpp | 10 +- torch/csrc/autograd/profiler_kineto.cpp | 14 +- torch/csrc/autograd/profiler_legacy.cpp | 3 +- torch/csrc/autograd/profiler_python.cpp | 41 +++--- torch/csrc/distributed/c10d/reducer_timer.hpp | 3 +- .../jit/runtime/register_prim_ops_fulljit.cpp | 4 +- torch/csrc/profiler/collection.cpp | 20 +-- torch/csrc/profiler/collection.h | 35 +++-- .../profiler/orchestration/python_tracer.cpp | 4 +- .../profiler/orchestration/python_tracer.h | 7 +- torch/csrc/profiler/stubs/cuda.cpp | 3 +- torch/csrc/profiler/util.cpp | 71 ---------- torch/csrc/profiler/util.h | 103 --------------- 16 files changed, 283 insertions(+), 242 deletions(-) create mode 100644 c10/util/ApproximateClock.cpp create mode 100644 c10/util/ApproximateClock.h diff --git a/c10/util/ApproximateClock.cpp b/c10/util/ApproximateClock.cpp new file mode 100644 index 0000000000000..0bda220d83da9 --- /dev/null +++ b/c10/util/ApproximateClock.cpp @@ -0,0 +1,79 @@ +#include +#include +#include +#include + +namespace c10 { + +ApproximateClockToUnixTimeConverter::ApproximateClockToUnixTimeConverter() + : start_times_(measurePairs()) {} + +ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair +ApproximateClockToUnixTimeConverter::measurePair() { + // Take a measurement on either side to avoid an ordering bias. + auto fast_0 = getApproximateTime(); + auto wall = std::chrono::system_clock::now(); + auto fast_1 = getApproximateTime(); + + TORCH_INTERNAL_ASSERT(fast_1 >= fast_0, "getCount is non-monotonic."); + auto t = std::chrono::duration_cast( + wall.time_since_epoch()); + + // `x + (y - x) / 2` is a more numerically stable average than `(x + y) / 2`. + return {t.count(), fast_0 + (fast_1 - fast_0) / 2}; +} + +ApproximateClockToUnixTimeConverter::time_pairs +ApproximateClockToUnixTimeConverter::measurePairs() { + static constexpr auto n_warmup = 5; + for (C10_UNUSED const auto _ : c10::irange(n_warmup)) { + getApproximateTime(); + static_cast(steady_clock_t::now()); + } + + time_pairs out; + for (const auto i : c10::irange(out.size())) { + out[i] = measurePair(); + } + return out; +} + +std::function ApproximateClockToUnixTimeConverter:: + makeConverter() { + auto end_times = measurePairs(); + + // Compute the real time that passes for each tick of the approximate clock. + std::array scale_factors{}; + for (const auto i : c10::irange(replicates)) { + auto delta_ns = end_times[i].t_ - start_times_[i].t_; + auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_; + scale_factors[i] = (double)delta_ns / (double)delta_approx; + } + std::sort(scale_factors.begin(), scale_factors.end()); + long double scale_factor = scale_factors[replicates / 2 + 1]; + + // We shift all times by `t0` for better numerics. Double precision only has + // 16 decimal digits of accuracy, so if we blindly multiply times by + // `scale_factor` we may suffer from precision loss. The choice of `t0` is + // mostly arbitrary; we just need a factor that is the correct order of + // magnitude to bring the intermediate values closer to zero. We are not, + // however, guaranteed that `t0_approx` is *exactly* the getApproximateTime + // equivalent of `t0`; it is only an estimate that we have to fine tune. + auto t0 = start_times_[0].t_; + auto t0_approx = start_times_[0].approx_t_; + std::array t0_correction{}; + for (const auto i : c10::irange(replicates)) { + auto dt = start_times_[i].t_ - t0; + auto dt_approx = + (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor; + t0_correction[i] = dt - (time_t)dt_approx; // NOLINT + } + t0 += t0_correction[t0_correction.size() / 2 + 1]; // NOLINT + + return [=](approx_time_t t_approx) { + // See above for why this is more stable than `A * t_approx + B`. + return (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0; + }; +} + +} // namespace c10 diff --git a/c10/util/ApproximateClock.h b/c10/util/ApproximateClock.h new file mode 100644 index 0000000000000..7de498cebed6a --- /dev/null +++ b/c10/util/ApproximateClock.h @@ -0,0 +1,121 @@ +// Copyright 2023-present Facebook. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef _WIN32 +#include +#endif +#if defined(C10_IOS) && defined(C10_MOBILE) +#include // for gettimeofday() +#endif + +#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__) +#define C10_RDTSC +#if defined(_MSC_VER) +#include +#elif defined(__CUDACC__) || defined(__HIPCC__) +#undef C10_RDTSC +#elif defined(__clang__) +// `__rdtsc` is available by default. +// NB: This has to be first, because Clang will also define `__GNUC__` +#elif defined(__GNUC__) +#include +#else +#undef C10_RDTSC +#endif +#endif + +namespace c10 { + +using time_t = int64_t; +using steady_clock_t = std::conditional< + std::chrono::high_resolution_clock::is_steady, + std::chrono::high_resolution_clock, + std::chrono::steady_clock>::type; + +inline time_t getTimeSinceEpoch() { + auto now = std::chrono::system_clock::now().time_since_epoch(); + return std::chrono::duration_cast(now).count(); +} + +inline time_t getTime(bool allow_monotonic = false) { +#if defined(C10_IOS) && defined(C10_MOBILE) + // clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS + // can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime + // is implemented or not + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000000000 + + static_cast(now.tv_usec) * 1000; +#elif defined(_WIN32) || defined(__MACH__) + return std::chrono::duration_cast( + steady_clock_t::now().time_since_epoch()) + .count(); +#else + // clock_gettime is *much* faster than std::chrono implementation on Linux + struct timespec t {}; + auto mode = CLOCK_REALTIME; + if (allow_monotonic) { + mode = CLOCK_MONOTONIC; + } + clock_gettime(mode, &t); + return static_cast(t.tv_sec) * 1000000000 + + static_cast(t.tv_nsec); +#endif +} + +// We often do not need to capture true wall times. If a fast mechanism such +// as TSC is available we can use that instead and convert back to epoch time +// during post processing. This greatly reduce the clock's contribution to +// profiling. +// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/ +// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io +// TODO: We should use +// `https://github.com/google/benchmark/blob/main/src/cycleclock.h` +inline auto getApproximateTime() { +#if defined(C10_RDTSC) + return static_cast(__rdtsc()); +#else + return getTime(); +#endif +} + +using approx_time_t = decltype(getApproximateTime()); +static_assert( + std::is_same::value || + std::is_same::value, + "Expected either int64_t (`getTime`) or uint64_t (some TSC reads)."); + +// Convert `getCount` results to Nanoseconds since unix epoch. +class C10_API ApproximateClockToUnixTimeConverter final { + public: + ApproximateClockToUnixTimeConverter(); + std::function makeConverter(); + + struct UnixAndApproximateTimePair { + time_t t_; + approx_time_t approx_t_; + }; + static UnixAndApproximateTimePair measurePair(); + + private: + static constexpr size_t replicates = 1001; + using time_pairs = std::array; + time_pairs measurePairs(); + + time_pairs start_times_; +}; + +} // namespace c10 diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 14ea7bcd2f061..33262efd1e2b1 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -112,14 +113,14 @@ class BackendWithCompiler : public PyTorchBackendInterface { c10::List output_list; #ifndef NO_PROFILING - auto start_us = torch::profiler::impl::getTime() / 1000; + auto start_us = c10::getTime() / 1000; #endif for (const auto& token : handle.toList()) { IValue val = token; auto instruction = val.toTupleRef().elements()[0].toStringRef(); auto debug_handle = val.toTupleRef().elements()[1].toInt(); #ifndef NO_PROFILING - auto start_time_us = torch::profiler::impl::getTime() / 1000; + auto start_time_us = c10::getTime() / 1000; #endif try { if (instruction.rfind("prim::Constant", 0) == 0) { @@ -171,7 +172,7 @@ class BackendWithCompiler : public PyTorchBackendInterface { TORCH_DELEGATED_BACKEND_THROW(false, e.what(), debug_handle); } #ifndef NO_PROFILING - auto end_time_us = torch::profiler::impl::getTime() / 1000; + auto end_time_us = c10::getTime() / 1000; auto duration = end_time_us - start_time_us; op_runtimes_us.emplace_back(duration, debug_handle, instruction); #endif diff --git a/test/cpp/profiler/containers.cpp b/test/cpp/profiler/containers.cpp index cb417af1bbacc..d5870795d7467 100644 --- a/test/cpp/profiler/containers.cpp +++ b/test/cpp/profiler/containers.cpp @@ -5,6 +5,7 @@ #include +#include #include #include #include @@ -48,13 +49,12 @@ TEST(ProfilerTest, AppendOnlyList_ref) { // Test that we can convert TSC measurements back to wall clock time. TEST(ProfilerTest, clock_converter) { const int n = 10001; - torch::profiler::impl::ApproximateClockToUnixTimeConverter converter; - std::vector + c10::ApproximateClockToUnixTimeConverter converter; + std::vector< + c10::ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair> pairs; for (const auto i : c10::irange(n)) { - pairs.push_back(torch::profiler::impl::ApproximateClockToUnixTimeConverter:: - measurePair()); + pairs.push_back(c10::ApproximateClockToUnixTimeConverter::measurePair()); } auto count_to_ns = converter.makeConverter(); std::vector deltas; diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 056f4b17df1d9..63367ecb8eb57 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -56,7 +57,7 @@ inline int64_t getTimeUs() { #ifdef USE_KINETO return libkineto::timeSinceEpoch(std::chrono::system_clock::now()); #else - return torch::profiler::impl::getTime() / 1000; + return c10::getTime() / 1000; #endif // USE_KINETO } @@ -321,7 +322,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { void reportVulkanEventToProfiler(torch::profiler::impl::vulkan_id_t id) { if (!config_.disabled()) { record_queue_.getSubqueue()->emplace_vulkan_event( - torch::profiler::impl::getApproximateTime(), id); + c10::getApproximateTime(), id); } } @@ -333,7 +334,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { c10::Device device) override { if (config_.profile_memory && !config_.disabled()) { record_queue_.getSubqueue()->emplace_allocation_event( - torch::profiler::impl::getApproximateTime(), + c10::getApproximateTime(), ptr, alloc_size, total_allocated, @@ -350,7 +351,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { c10::Device device) override { if (config_.profile_memory && !config_.disabled()) { record_queue_.getSubqueue()->emplace_ooms_event( - torch::profiler::impl::getApproximateTime(), + c10::getApproximateTime(), alloc_size, total_allocated, total_reserved, @@ -421,7 +422,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { } uint64_t start_time_; - torch::profiler::impl::ApproximateClockToUnixTimeConverter clock_converter_; + c10::ApproximateClockToUnixTimeConverter clock_converter_; torch::profiler::impl::RecordQueue record_queue_; std::vector kineto_events_; std::vector event_tree_; @@ -452,8 +453,7 @@ void onFunctionExit( auto* kineto_ctx_ptr = static_cast(ctx_ptr); TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr); - kineto_ctx_ptr->event_->end_time_ = - torch::profiler::impl::getApproximateTime(); + kineto_ctx_ptr->event_->end_time_ = c10::getApproximateTime(); if (!config.experimental_config.performance_events.empty()) { state_ptr->record_queue_.getSubqueue()->disable_perf_profiler( *kineto_ctx_ptr->event_->counters_); diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 388695957e45d..5d7dd02312b7f 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -477,7 +478,7 @@ void LegacyEvent::record(bool record_cuda) { torch::profiler::impl::cudaStubs()->record(&device_, &cuda_event, &cpu_ns_); return; } - cpu_ns_ = torch::profiler::impl::getTime(); + cpu_ns_ = c10::getTime(); } /* static */ LegacyEvent LegacyEvent::fromIValue( diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index b716693b0ca89..9d827837bf4f0 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -293,7 +294,7 @@ class ValueCache { auto caller = load(callsite.caller_); TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value()); return ExtraFields::event_type>{ - /*end_time_ns=*/std::numeric_limits::min(), + /*end_time_ns=*/std::numeric_limits::min(), python_tid, caller.frame_state_, load(callsite.value_)}; @@ -666,8 +667,8 @@ struct ThreadLocalResults { ValueCache* value_cache_; PythonTracer* active_tracer_; CallTypeHelper::tuple_type trace_keys_; - AppendOnlyList exit_times_; - AppendOnlyList c_exit_times_; + AppendOnlyList exit_times_; + AppendOnlyList c_exit_times_; }; // ============================================================================ @@ -687,13 +688,13 @@ class PythonTracer final : public python_tracer::PythonTracerBase { void stop() override; std::vector> getEvents( - std::function time_converter, + std::function time_converter, std::vector& enters, - time_t end_time_ns) override; + c10::time_t end_time_ns) override; struct StartFrame { TraceKey trace_key_; - approx_time_t start_time{}; + c10::approx_time_t start_time{}; }; private: @@ -863,7 +864,7 @@ void PythonTracer::recordPyCall( return tls.intern(no_ephemeral_t(), frame, f_back); } }(); - const auto time = getApproximateTime(); + const auto time = c10::getApproximateTime(); is_startup_frame ? start_frames_.push_back({key, time}) : queue_->getSubqueue()->emplace_py_call(key, time); } @@ -879,7 +880,7 @@ void PythonTracer::recordCCall( // `frame->f_back`. auto key = tls.intern( arg, (void*)(fn->m_ml), frame); - queue_->getSubqueue()->emplace_py_call(key, getApproximateTime()); + queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); } // ============================================================================ @@ -890,17 +891,17 @@ struct Exit { return t_ > other.t_; } - time_t t_; + c10::time_t t_; size_t python_tid_; }; class PostProcess { public: PostProcess( - std::function time_converter, + std::function time_converter, std::deque& tls, const ValueCache& value_cache, - time_t end_time_ns) + c10::time_t end_time_ns) : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} { for (size_t python_tid : c10::irange(tls.size())) { CallTypeHelper::map( @@ -936,7 +937,9 @@ class PostProcess { } template - void addExits(AppendOnlyList& exits, size_t python_tid) { + void addExits( + AppendOnlyList& exits, + size_t python_tid) { for (const auto i : exits) { get_state().exits_.push({time_converter_(i), python_tid}); } @@ -961,7 +964,7 @@ class PostProcess { std::vector>& out) { using stack_t = std::vector>; const auto initial_size = out.size(); - auto pop = [](stack_t& stack, time_t t) { + auto pop = [](stack_t& stack, c10::time_t t) { TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty."); std::get>(stack.back()->extra_fields_).end_time_ns_ = t; stack.pop_back(); @@ -1026,8 +1029,8 @@ class PostProcess { return std::get < E == EventType::PyCall ? 0 : 1 > (state_); } - time_t end_time_; - std::function time_converter_; + c10::time_t end_time_; + std::function time_converter_; std::tuple, State> state_; }; @@ -1054,9 +1057,9 @@ struct PythonIDVisitor { }; std::vector> PythonTracer::getEvents( - std::function time_converter, + std::function time_converter, std::vector& enters, - time_t end_time_ns) { + c10::time_t end_time_ns) { value_cache_.trimPrefixes(); PostProcess post_process( std::move(time_converter), @@ -1099,12 +1102,12 @@ int PythonTracer::pyProfileFn( case PyTrace_EXCEPTION: case PyTrace_RETURN: - local_results.exit_times_.emplace_back(getApproximateTime()); + local_results.exit_times_.emplace_back(c10::getApproximateTime()); break; case PyTrace_C_EXCEPTION: case PyTrace_C_RETURN: - local_results.c_exit_times_.emplace_back(getApproximateTime()); + local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); break; } return 0; diff --git a/torch/csrc/distributed/c10d/reducer_timer.hpp b/torch/csrc/distributed/c10d/reducer_timer.hpp index ca8dd163eecdf..acd8975c4d2db 100644 --- a/torch/csrc/distributed/c10d/reducer_timer.hpp +++ b/torch/csrc/distributed/c10d/reducer_timer.hpp @@ -1,11 +1,12 @@ #pragma once +#include #include namespace c10d { constexpr int kUnsetTime = -1; inline int64_t current_time_in_nanos() { - return torch::profiler::impl::getTime(); + return c10::getTime(); } class TORCH_API Timer { diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 822db07fa7072..ac2dd62e64c16 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -370,8 +371,7 @@ RegisterOperators logging_operators( tracer::recordSourceLocation(node); graph->insertNode(node); } - auto output = - torch::profiler::impl::getTime(/*allow_monotonic=*/true); + auto output = c10::getTime(/*allow_monotonic=*/true); push(stack, output); if (jit::tracer::isTracing()) { jit::tracer::addOutput(node, output); diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index bb18304359a02..440cabe94d28a 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -362,7 +362,7 @@ std::unique_ptr ThreadLocalSubqueue::begin_op( nullptr, &out->fallback_->device_event_start_, nullptr); } - event->start_time_ = torch::profiler::impl::getApproximateTime(); + event->start_time_ = c10::getApproximateTime(); event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS(); if (!config_.experimental_config.performance_events.empty()) { const size_t n = config_.experimental_config.performance_events.size(); @@ -402,7 +402,7 @@ struct StealOrDefault { void ThreadLocalSubqueue::TorchOpStorage::materialize( std::vector>& out, - const std::function& time_converter, + const std::function& time_converter, const uint64_t tid, const kineto::DeviceAndResource& kineto_info) { // Plumb Autograd info to the top level annotation. @@ -471,7 +471,7 @@ void materialize_vulkan( std::vector>& out, AppendOnlyList::raw_event_t, BlockSize>& raw_events, - const std::function& time_converter, + const std::function& time_converter, const uint64_t tid, const kineto::DeviceAndResource& kineto_info) { for (const auto& i : raw_events) { @@ -530,7 +530,7 @@ int64_t torchOpEndNS( const ExtraFields& e, const bool finished, const std::weak_ptr& parent) { - if (finished && e.end_time_ns_ == std::numeric_limits::min()) { + if (finished && e.end_time_ns_ == std::numeric_limits::min()) { auto p = parent.lock(); if (p) { return p->endTimeNS(); @@ -1176,7 +1176,7 @@ void build_tree(std::vector>& sorted_events) { if (event->endTimeNS() > event->start_time_ns_) { stacks[event->start_tid_] = event; end_events_.push(event); - } else if (event->endTimeNS() == std::numeric_limits::min()) { + } else if (event->endTimeNS() == std::numeric_limits::min()) { // We use min time to indicate the lack of a termination event, so if we // encounter such a case we don't push to `end_events_`. stacks[event->start_tid_] = event; @@ -1350,12 +1350,12 @@ std::pair< std::vector>, std::unique_ptr> RecordQueue::getRecords( - std::function time_converter, + std::function time_converter, uint64_t start_time_us, uint64_t end_time_us) { - auto converter = [&](approx_time_t t) { - return t == std::numeric_limits::min() - ? std::numeric_limits::min() + auto converter = [&](c10::approx_time_t t) { + return t == std::numeric_limits::min() + ? std::numeric_limits::min() : time_converter(t); }; std::vector> out; @@ -1364,7 +1364,7 @@ RecordQueue::getRecords( auto& queue = *subqueue_it.second; auto materialize = [&](auto& events) { for (auto& i : events) { - time_t start_time_ns = 0; + c10::time_t start_time_ns = 0; if constexpr (std::is_same_v< std::remove_reference_t, ExtraFields>) { diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index b70def43c89ff..d4640eece1c4f 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -126,7 +127,7 @@ struct ExtraFields : TorchOpBasicFields { ExtraFields( TorchOpBasicFields&& f, uint64_t correlation_id, - time_t end_time_ns, + c10::time_t end_time_ns, std::vector&& inputs, std::vector&& concrete_inputs, jit_stack_t&& jit_stack, @@ -149,7 +150,7 @@ struct ExtraFields : TorchOpBasicFields { allow_tf32_cublas_{allow_tf32_cublas}, perf_event_counters_{std::move(perf_event_counters)} {} uint64_t correlation_id_; - time_t end_time_ns_; + c10::time_t end_time_ns_; std::vector inputs_; std::vector concrete_inputs_; jit_stack_t jit_stack_; @@ -175,7 +176,7 @@ struct ExtraFields { template <> struct ExtraFields { - using raw_event_t = std::pair; + using raw_event_t = std::pair; std::string name_; int64_t duration_ns_{0}; // While building the event tree, we want to report a vulkan event's duration @@ -184,7 +185,7 @@ struct ExtraFields { }; struct RawAllocation { - torch::profiler::impl::approx_time_t start_time_; + c10::approx_time_t start_time_; void* ptr_; int64_t alloc_size_; size_t total_allocated_; @@ -210,7 +211,7 @@ struct ExtraFields : RawAllocation { template <> struct ExtraFields { - torch::profiler::impl::approx_time_t start_time_; + c10::approx_time_t start_time_; int64_t alloc_size_; size_t total_allocated_; size_t total_reserved_; @@ -270,12 +271,15 @@ struct OptimizerInfo { }; struct PyExtraFieldsBase { - PyExtraFieldsBase(time_t end_time_ns, size_t python_tid, PyFrameState caller) + PyExtraFieldsBase( + c10::time_t end_time_ns, + size_t python_tid, + PyFrameState caller) : end_time_ns_{end_time_ns}, python_tid_{python_tid}, caller_{std::move(caller)} {} - time_t end_time_ns_; + c10::time_t end_time_ns_; size_t python_tid_; PyFrameState caller_; @@ -292,7 +296,7 @@ struct ExtraFields : public PyExtraFieldsBase { }; ExtraFields( - time_t end_time_ns, + c10::time_t end_time_ns, size_t python_tid, PyFrameState caller, args_t args) @@ -311,7 +315,7 @@ struct ExtraFields : public PyExtraFieldsBase { using args_t = at::StringView; ExtraFields( - time_t end_time_ns, + c10::time_t end_time_ns, size_t python_tid, PyFrameState caller, args_t args) @@ -421,10 +425,11 @@ struct TORCH_API Result : public std::enable_shared_from_this { struct KinetoObserverContext : public at::ObserverContext { struct Event { TorchOpBasicFields basic_fields_; - approx_time_t start_time_; + c10::approx_time_t start_time_; // Set in the exit callback. - approx_time_t end_time_{std::numeric_limits::min()}; + c10::approx_time_t end_time_{ + std::numeric_limits::min()}; bool allow_tf32_cublas_; std::unique_ptr counters_; @@ -549,7 +554,7 @@ class TORCH_API ThreadLocalSubqueue { // NB: This is a destructive operation. void materialize( std::vector>& out, - const std::function& time_converter, + const std::function& time_converter, const uint64_t tid, const kineto::DeviceAndResource& kineto_info); @@ -605,7 +610,9 @@ class TORCH_API ThreadLocalSubqueue { AppendOnlyList, BlockSize> ooms_; // with_stack (Python) - AppendOnlyList, BlockSize> + AppendOnlyList< + std::pair, + BlockSize> py_calls_; }; @@ -622,7 +629,7 @@ class TORCH_API RecordQueue { std::vector>, std::unique_ptr> getRecords( - std::function time_converter, + std::function time_converter, uint64_t start_time_us, uint64_t end_time_us); diff --git a/torch/csrc/profiler/orchestration/python_tracer.cpp b/torch/csrc/profiler/orchestration/python_tracer.cpp index 64db126b25ef6..61773fe23ca6b 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.cpp +++ b/torch/csrc/profiler/orchestration/python_tracer.cpp @@ -13,9 +13,9 @@ struct NoOpPythonTracer : public PythonTracerBase { void stop() override {} std::vector> getEvents( - std::function, + std::function, std::vector&, - time_t) override { + c10::time_t) override { return {}; } }; diff --git a/torch/csrc/profiler/orchestration/python_tracer.h b/torch/csrc/profiler/orchestration/python_tracer.h index 93becfee3cf31..f05aefecfe169 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.h +++ b/torch/csrc/profiler/orchestration/python_tracer.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -29,7 +30,7 @@ struct CompressedEvent { TraceKey key_; uint64_t system_tid_; kineto::DeviceAndResource kineto_info_; - time_t enter_t_; + c10::time_t enter_t_; }; /* @@ -49,9 +50,9 @@ struct TORCH_API PythonTracerBase { virtual void stop() = 0; virtual std::vector> getEvents( - std::function time_converter, + std::function time_converter, std::vector& enters, - time_t end_time_ns) = 0; + c10::time_t end_time_ns) = 0; }; using MakeFn = std::unique_ptr (*)(RecordQueue*); diff --git a/torch/csrc/profiler/stubs/cuda.cpp b/torch/csrc/profiler/stubs/cuda.cpp index dec87576f364c..ff39df4eb1645 100644 --- a/torch/csrc/profiler/stubs/cuda.cpp +++ b/torch/csrc/profiler/stubs/cuda.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -49,7 +50,7 @@ struct CUDAMethods : public ProfilerStubs { }); auto stream = at::cuda::getCurrentCUDAStream(); if (cpu_ns) { - *cpu_ns = torch::profiler::impl::getTime(); + *cpu_ns = c10::getTime(); } TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream)); } diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 65dc3b5a5fd76..f29366bc1955b 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -19,77 +19,6 @@ namespace torch { namespace profiler { namespace impl { -ApproximateClockToUnixTimeConverter::ApproximateClockToUnixTimeConverter() - : start_times_(measurePairs()) {} - -ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair -ApproximateClockToUnixTimeConverter::measurePair() { - // Take a measurement on either side to avoid an ordering bias. - auto fast_0 = getApproximateTime(); - auto wall = std::chrono::system_clock::now(); - auto fast_1 = getApproximateTime(); - - TORCH_INTERNAL_ASSERT(fast_1 >= fast_0, "getCount is non-monotonic."); - auto t = std::chrono::duration_cast( - wall.time_since_epoch()); - - // `x + (y - x) / 2` is a more numerically stable average than `(x + y) / 2`. - return {t.count(), fast_0 + (fast_1 - fast_0) / 2}; -} - -ApproximateClockToUnixTimeConverter::time_pairs -ApproximateClockToUnixTimeConverter::measurePairs() { - static constexpr auto n_warmup = 5; - for (C10_UNUSED const auto _ : c10::irange(n_warmup)) { - getApproximateTime(); - steady_clock_t::now(); - } - - time_pairs out; - for (const auto i : c10::irange(out.size())) { - out[i] = measurePair(); - } - return out; -} - -std::function ApproximateClockToUnixTimeConverter:: - makeConverter() { - auto end_times = measurePairs(); - - // Compute the real time that passes for each tick of the approximate clock. - std::array scale_factors{}; - for (const auto i : c10::irange(replicates)) { - auto delta_ns = end_times[i].t_ - start_times_[i].t_; - auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_; - scale_factors[i] = (double)delta_ns / (double)delta_approx; - } - std::sort(scale_factors.begin(), scale_factors.end()); - long double scale_factor = scale_factors[replicates / 2 + 1]; - - // We shift all times by `t0` for better numerics. Double precision only has - // 16 decimal digits of accuracy, so if we blindly multiply times by - // `scale_factor` we may suffer from precision loss. The choice of `t0` is - // mostly arbitrary; we just need a factor that is the correct order of - // magnitude to bring the intermediate values closer to zero. We are not, - // however, guaranteed that `t0_approx` is *exactly* the getApproximateTime - // equivilent of `t0`; it is only an estimate that we have to fine tune. - auto t0 = start_times_[0].t_; - auto t0_approx = start_times_[0].approx_t_; - std::array t0_correction{}; - for (const auto i : c10::irange(replicates)) { - auto dt = start_times_[i].t_ - t0; - auto dt_approx = - (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor; - t0_correction[i] = dt - (time_t)dt_approx; - } - t0 += t0_correction[t0_correction.size() / 2 + 1]; - - return [=](approx_time_t t_approx) { - // See above for why this is more stable than `A * t_approx + B`. - return (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0; - }; -} - namespace { c10::optional soft_assert_raises_; } // namespace diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index dd7147374f0e7..4b565c691ca04 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -14,29 +14,6 @@ #include #include -#ifndef _WIN32 -#include -#endif -#if defined(C10_IOS) && defined(C10_MOBILE) -#include // for gettimeofday() -#endif - -#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__) -#define C10_RDTSC -#if defined(_MSC_VER) -#include -#elif defined(__CUDACC__) || defined(__HIPCC__) -#undef C10_RDTSC -#elif defined(__clang__) -// `__rdtsc` is available by default. -// NB: This has to be first, because Clang will also define `__GNUC__` -#elif defined(__GNUC__) -#include -#else -#undef C10_RDTSC -#endif -#endif - // TODO: replace with pytorch/rfcs#43 when it is ready. #define SOFT_ASSERT(cond, ...) \ [&]() -> bool { \ @@ -83,89 +60,10 @@ TORCH_API void logSoftAssert( const char* cond, const std::string& args); -using time_t = int64_t; -using steady_clock_t = std::conditional< - std::chrono::high_resolution_clock::is_steady, - std::chrono::high_resolution_clock, - std::chrono::steady_clock>::type; - using shape = std::variant, std::vector>>; constexpr int TENSOR_LIST_DISPLAY_LENGTH_LIMIT = 30; -inline time_t getTimeSinceEpoch() { - auto now = std::chrono::system_clock::now().time_since_epoch(); - return std::chrono::duration_cast(now).count(); -} - -inline time_t getTime(bool allow_monotonic = false) { -#if defined(C10_IOS) && defined(C10_MOBILE) - // clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS - // can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime - // is implemented or not - struct timeval now; - gettimeofday(&now, NULL); - return static_cast(now.tv_sec) * 1000000000 + - static_cast(now.tv_usec) * 1000; -#elif defined(_WIN32) || defined(__MACH__) - return std::chrono::duration_cast( - steady_clock_t::now().time_since_epoch()) - .count(); -#else - // clock_gettime is *much* faster than std::chrono implementation on Linux - struct timespec t {}; - auto mode = CLOCK_REALTIME; - if (allow_monotonic) { - mode = CLOCK_MONOTONIC; - } - clock_gettime(mode, &t); - return static_cast(t.tv_sec) * 1000000000 + - static_cast(t.tv_nsec); -#endif -} - -// We often do not need to capture true wall times. If a fast mechanism such -// as TSC is available we can use that instead and convert back to epoch time -// during post processing. This greatly reduce the clock's contribution to -// profiling. -// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/ -// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io -// TODO: We should use -// `https://github.com/google/benchmark/blob/main/src/cycleclock.h` -inline auto getApproximateTime() { -#if defined(C10_RDTSC) - return static_cast(__rdtsc()); -#else - return getTime(); -#endif -} - -using approx_time_t = decltype(getApproximateTime()); -static_assert( - std::is_same::value || - std::is_same::value, - "Expected either int64_t (`getTime`) or uint64_t (some TSC reads)."); - -// Convert `getCount` results to Nanoseconds since unix epoch. -class ApproximateClockToUnixTimeConverter final { - public: - ApproximateClockToUnixTimeConverter(); - std::function makeConverter(); - - struct UnixAndApproximateTimePair { - time_t t_; - approx_time_t approx_t_; - }; - static UnixAndApproximateTimePair measurePair(); - - private: - static constexpr size_t replicates = 1001; - using time_pairs = std::array; - time_pairs measurePairs(); - - time_pairs start_times_; -}; - std::string getNvtxStr( const char* name, int64_t sequence_nr, @@ -267,7 +165,6 @@ namespace torch { namespace autograd { namespace profiler { using torch::profiler::impl::computeFlops; -using torch::profiler::impl::getTime; } // namespace profiler } // namespace autograd } // namespace torch From d97332f8391e9d1f0c3fa019376d04b523bfb06c Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 27 Oct 2023 16:54:23 +0000 Subject: [PATCH 49/78] Add cuda status checks to FA templates (#112229) # Summary cuda status checks were accidentely removed on latest update Pull Request resolved: https://github.com/pytorch/pytorch/pull/112229 Approved by: https://github.com/Skylion007 --- .../flash_attn/flash_bwd_launch_template.h | 21 +++++++++++++++++++ .../flash_attn/flash_fwd_launch_template.h | 6 ++++++ 2 files changed, 27 insertions(+) diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h index 016c42f4c2043..5bef2e301967f 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h @@ -149,6 +149,9 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const boo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers @@ -170,6 +173,9 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const boo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // Changing AtomLayoutMdQ from 2 to 4 takes the same time @@ -213,6 +219,9 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const boo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // if (params.h == params.h_k) { @@ -240,6 +249,9 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // if (params.h == params.h_k) { @@ -276,6 +288,9 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 116 * 1024) { run_flash_bwd, Is_dropout>(params, stream, configure); @@ -293,6 +308,9 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 136 * 1024) { run_flash_bwd, Is_dropout>(params, stream, configure); @@ -318,6 +336,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bo int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 176 * 1024) { // H100 run_flash_bwd, Is_dropout>(params, stream, configure); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h index ea8bce78ab355..c35dfef03ed3c 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h @@ -299,6 +299,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { @@ -327,6 +330,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); status_ = cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { From d3bf6803b62c79f1dafd1eec49b4bd65d5a27697 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Fri, 27 Oct 2023 17:14:58 +0000 Subject: [PATCH 50/78] [dynamo] add sanity check that we do not wrap tracked tensors (#112025) Identified as a result of https://github.com/pytorch/pytorch/pull/111911 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112025 Approved by: https://github.com/ezyang --- torch/_dynamo/variables/builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 3ac8c2b970bb0..9f3f9decbf024 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1003,6 +1003,10 @@ def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): def wrap_tensor(self, value: torch.Tensor): source = self.get_source() + # We cannot already be tracking the tensor, which implies + # it would have already been wrapped + assert value not in self.tx.output.side_effects + if ( source.guard_source().is_nn_module() or get_static_address_type(value) is not None From a26cb0a3f264135a2e3e35466384dfc5eaef1ec1 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 26 Oct 2023 20:01:01 -0700 Subject: [PATCH 51/78] [dynamo] Enable typechecking for testing.py (#112129) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112129 Approved by: https://github.com/Skylion007 ghstack dependencies: #111894, #111992, #112031, #112127, #112128 --- .lintrunner.toml | 1 + mypy-nofollow.ini | 3 +++ torch/_dynamo/testing.py | 25 +++++++++++++------------ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 0dfebd97c2218..02bc138fe0a4c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -186,6 +186,7 @@ include_patterns = [ 'torch/_dynamo/funcname_cache.py', 'torch/_dynamo/convert_frame.py', 'torch/_dynamo/symbolic_convert.py', + 'torch/_dynamo/testing.py', 'torch/_dynamo/types.py', 'torch/_dynamo/output_graph.py', 'torch/_dynamo/guards.py', diff --git a/mypy-nofollow.ini b/mypy-nofollow.ini index 9e9e5e68fdc76..657ab6eefe553 100644 --- a/mypy-nofollow.ini +++ b/mypy-nofollow.ini @@ -41,6 +41,9 @@ ignore_errors = True [mypy-torch.fb.*] ignore_missing_imports = True +[mypy-torch_xla.*] +ignore_missing_imports = True + [mypy-torchvision.*] ignore_missing_imports = True diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 13e34ed3f5119..b500ba2ac2521 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -8,9 +8,10 @@ import sys import types import unittest -from typing import Sequence, Union +from typing import List, Optional, Sequence, Union from unittest.mock import patch +np: Optional[types.ModuleType] = None try: import numpy as np except ModuleNotFoundError: @@ -62,7 +63,7 @@ def named_buffers_for_optimized_module(mod): return mod._orig_mod.named_buffers -def remove_optimized_module_prefix(name): +def remove_optimized_module_prefix(name) -> str: return re.sub(r"^_orig_mod[.]", "", name) @@ -140,21 +141,21 @@ def reduce_to_scalar_loss(out): raise NotImplementedError("Don't know how to reduce", type(out)) -def debug_dir(): +def debug_dir() -> str: path = os.path.join(os.path.dirname(__file__), "../debug") if not os.path.exists(path): os.mkdir(path) return path -def debug_dump(name, code: types.CodeType, extra=""): +def debug_dump(name, code: types.CodeType, extra="") -> None: with open(os.path.join(debug_dir(), name), "w") as fd: fd.write( f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n" ) -def debug_insert_nops(frame, cache_size, hooks, _): +def debug_insert_nops(frame, cache_size, hooks, _) -> Optional[GuardedCode]: """used to debug jump updates""" def insert_nops(instructions, code_options): @@ -187,7 +188,7 @@ def __init__(self): self.frame_count = 0 self.op_count = 0 - def __call__(self, gm: torch.fx.GraphModule, example_inputs): + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): self.frame_count += 1 for node in gm.graph.nodes: if "call" in node.op: @@ -206,7 +207,7 @@ def __init__(self, backend): self.backend = backend self.graphs = [] - def __call__(self, gm: torch.fx.GraphModule, example_inputs): + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): from .backends.registry import lookup_backend self.frame_count += 1 @@ -223,21 +224,21 @@ class EagerAndRecordGraphs: def __init__(self): self.graphs = [] - def __call__(self, gm: torch.fx.GraphModule, example_inputs): + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): self.graphs.append(gm) return gm -def strip_comment(code): +def strip_comment(code) -> str: code = str(code) return re.sub(r"(?m)^ *#.*\n?", "", code) -def remove_trailing_space(code): +def remove_trailing_space(code) -> str: return "\n".join([line.rstrip() for line in code.split("\n")]) -def normalize_gm(gm_str): +def normalize_gm(gm_str) -> str: # strip comments as comments have path to files which may differ from # system to system. return remove_trailing_space(strip_comment(gm_str)) @@ -252,7 +253,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None) expected = CompileCounter() try: gm = torch.fx.symbolic_trace(fn) - expected(gm) + expected(gm) # type: ignore[call-arg] # FIXME: https://github.com/pytorch/pytorch/issues/112230 print("\nfx.symbolic_trace graph:") gm.graph.print_tabular() expected_ops = expected.op_count From cb48ef21cc934d72d7eccb89cb94daffff477b95 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 27 Oct 2023 18:11:29 +0000 Subject: [PATCH 52/78] [no-ci] Clarify revert handling in release branches (#112262) Changes that has been reverted on trunk, must be reverted in release as well Pull Request resolved: https://github.com/pytorch/pytorch/pull/112262 Approved by: https://github.com/huydhn --- RELEASE.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index 7df92191d447a..50725dc40cf73 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -16,6 +16,7 @@ - [Release Candidate Storage](#release-candidate-storage) - [Release Candidate health validation](#release-candidate-health-validation) - [Cherry Picking Fixes](#cherry-picking-fixes) + - [Cherry Picking Reverts](#cherry-picking-reverts) - [Promoting RCs to Stable](#promoting-rcs-to-stable) - [Additional Steps to prepare for release day](#additional-steps-to-prepare-for-release-day) - [Modify release matrix](#modify-release-matrix) @@ -211,6 +212,11 @@ Please also make sure to add milestone target to the PR/issue, especially if it **NOTE**: The cherry pick process is not an invitation to add new features, it is mainly there to fix regressions +### Cherry Picking Reverts + +If PR that has been cherry-picked into release branch has been reverted, it's cherry-pick must be reverted as well. + +Reverts for changes that was committed into the main branch prior to the branch cut, must be propagated into release branch as well. ## Promoting RCs to Stable From 700071869afec0e68cf5333caa08b4dab4edf329 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 27 Oct 2023 18:12:12 +0000 Subject: [PATCH 53/78] [no-ci][EZ] Update RELEASE.md (#112253) Reflect default branch renames from master to main Pull Request resolved: https://github.com/pytorch/pytorch/pull/112253 Approved by: https://github.com/huydhn, https://github.com/ZainRizvi --- RELEASE.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 50725dc40cf73..0d22b6b5c4bca 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -133,8 +133,8 @@ them: * Example: https://github.com/pytorch/pytorch/pull/77983 and https://github.com/pytorch/pytorch/pull/77986 * A release branches should also be created in [`pytorch/xla`](https://github.com/pytorch/xla) and [`pytorch/builder`](https://github.com/pytorch/builder) repos and pinned in `pytorch/pytorch` * Example: https://github.com/pytorch/pytorch/pull/86290 and https://github.com/pytorch/pytorch/pull/90506 -* Update branch used in composite actions from trunk to release (for example, can be done by running `for i in .github/workflows/*.yml; do sed -i -e s#@master#@release/2.0# $i; done` - * Example: https://github.com/pytorch/pytorch/commit/51b42d98d696a9a474bc69f9a4c755058809542f +* Update branch used in composite actions from trunk to release (for example, can be done by running `for i in .github/workflows/*.yml; do sed -i -e s#@main#@release/2.0# $i; done` + * Example: https://github.com/pytorch/pytorch/commit/17f400404f2ca07ea5ac864428e3d08149de2304 These are examples of changes that should be made to the *default* branch after a release branch is cut From 33daaeb6b557c164876c8f2dda7fe29d1c2dabce Mon Sep 17 00:00:00 2001 From: Facebook Community Bot Date: Fri, 27 Oct 2023 18:14:49 +0000 Subject: [PATCH 54/78] Automated submodule update: FBGEMM (#112118) This is an automated pull request to update the first-party submodule for [pytorch/FBGEMM](https://github.com/pytorch/FBGEMM). New submodule commit: https://github.com/pytorch/FBGEMM/commit/6c2be8831a67d4ab5c12fc7456e1e4e192c67c38 Test Plan: Ensure that CI jobs succeed on GitHub before landing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112118 Approved by: https://github.com/malfet --- third_party/fbgemm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/fbgemm b/third_party/fbgemm index 9cd8ce8404c69..d4eec11d72a22 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 9cd8ce8404c696a9cf619c47bd820ac7d9bd2263 +Subproject commit d4eec11d72a2280cdf22b77561ba9b6f594cdd7e From baf3e054e36156b58f0f08510a1bf81dfa2ea294 Mon Sep 17 00:00:00 2001 From: DongDongBan Date: Fri, 27 Oct 2023 18:16:54 +0000 Subject: [PATCH 55/78] Fixed an error in the comment of file torch.utils.data.dataloader.py#944 . (#112244) Fixes #ISSUE_NUMBER @ssnl Pull Request resolved: https://github.com/pytorch/pytorch/pull/112244 Approved by: https://github.com/albanD --- torch/utils/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 60f8c7dc6f691..ff075bbc3aea2 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -941,7 +941,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # # No need to check main thread. If this thread is alive, the main loader # # thread must be alive, because this thread is set as daemonic. # While `pin_memory_thread_done_event` is not set: - # Get from `index_queue`. + # Get from `worker_result_queue`. # If timed out, continue to get in the next iteration. # Otherwise, process data. # While `pin_memory_thread_done_event` is not set: From b110d87ac271db01fd1d24a6595cf9633ac1ce43 Mon Sep 17 00:00:00 2001 From: chilli Date: Thu, 26 Oct 2023 19:15:39 -0700 Subject: [PATCH 56/78] Readded device_assert skipping in index and index_put (and also added (#112093) copy to noop pass) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112093 Approved by: https://github.com/oulgen, https://github.com/lezcano --- test/inductor/test_torchinductor.py | 14 +++++++++++ torch/_inductor/codegen/common.py | 10 ++++---- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/fx_passes/post_grad.py | 13 ++++++---- torch/_inductor/index_propagation.py | 4 ++-- torch/_inductor/ir.py | 4 ++-- torch/_inductor/lowering.py | 33 ++++++++++++++------------ torch/_inductor/pattern_matcher.py | 1 + torch/_inductor/virtualized.py | 6 ++--- 9 files changed, 56 insertions(+), 31 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 530e5026cbb45..a3e644e5e7e3b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -827,6 +827,7 @@ def repeat(x, n): self.assertEqual(actual, repeat(x, 3)) @skipIfRocm + @config.patch(debug_index_asserts=False) def test_neg_index(self): def test(fn, inps, has_assert: bool, has_wrapping: bool): for dynamic in (True, False): @@ -893,6 +894,11 @@ def flip_with_index(a): # Constant is propagated as we can prove that the result is always negative. test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False) + def unsafe_index(a, b): + return aten._unsafe_index(a, (b,)) + + test(unsafe_index, (a, b), has_assert=False, has_wrapping=True) + def test_computed_buffer_inlining(self): def flip(x): idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -3547,6 +3553,14 @@ def matmul_with_op(x, y, fn): out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn) self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn)) + def test_remove_noop_copy(self): + def fn(x, y): + x = x.cos() + a = x.copy_(y) + return a.sin() + + self.common(fn, (torch.randn(8, 8), torch.randn(8))) + def test_cat_of_loops_and_extern_kernel(self): class M(torch.nn.Module): def __init__( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index de967395518b3..9ab92e22146a7 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -987,7 +987,7 @@ def inner(*args, **kwargs): return inner @staticmethod - def indirect_indexing(var, size, check=True): + def indirect_indexing(var, size, add_asserts=True): # Skip CSE since this doesn't return an expression if var.bounds.lower < 0: @@ -1015,7 +1015,7 @@ def indirect_indexing(var, size, check=True): new_var.update_on_args("index_wrap", (var,), {}) var = new_var - if self.generate_assert(check): + if self.generate_assert(add_asserts): mask = self.load_mask(var) # An assertion line may have been written already, if so just @@ -1124,8 +1124,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): V.graph.scheduler.remove_kernel_local_buffers() super().__exit__(exc_type, exc_val, exc_tb) - def generate_assert(self, check): - return (check or config.debug_index_asserts) and config.assert_indirect_indexing + def generate_assert(self, add_asserts): + return ( + add_asserts or config.debug_index_asserts + ) and config.assert_indirect_indexing def load_mask(self, var): # only the triton kernel requires mask diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 31e7ee396851a..7a1aedf6413a7 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2275,7 +2275,7 @@ def can_use_int32(): return tmp_var @staticmethod - def indirect_indexing(index_var, size, check=True): + def indirect_indexing(index_var, size, add_asserts=True): return sympy_symbol(str(index_var)) @staticmethod diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index cd2ffcbdc58b7..1defdab5abb13 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -484,8 +484,8 @@ def is_valid_splitwithsizes_cat(match): return True -def same_layout(node1: torch.fx.Node, node2: torch.fx.Node): - """True if two nodes have the same size/strides""" +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" val1 = node1.meta.get("val") val2 = node2.meta.get("val") return ( @@ -493,6 +493,7 @@ def same_layout(node1: torch.fx.Node, node2: torch.fx.Node): and val2 is not None and val1.size() == val2.size() and val1.layout == val2.layout + and val1.device == val2.device and (val1.layout != torch.strided or val1.stride() == val2.stride()) ) @@ -505,6 +506,7 @@ def register_fun(cond): register_decomposition(targets, registry=noop_registry, unsafe=True)( (cond, nop_arg) ) + return cond return register_fun @@ -564,7 +566,10 @@ def cat_noop(inputs, dim=0): return len(inputs) == 1 -@register_noop_decomp([aten.clone, aten.alias]) +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) def true_noop(*args, **kwargs): return True @@ -607,7 +612,7 @@ def remove_noop_ops(graph: torch.fx.Graph): is_valid, args, kwargs = get_fake_args_kwargs(node) if not is_valid: continue - if same_layout(node, src) and cond(*args, **kwargs): + if same_meta(node, src) and cond(*args, **kwargs): node.replace_all_uses_with(src) graph.erase_node(node) diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 6be483d880e3b..861a465d5fb4b 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -249,7 +249,7 @@ def inner(*args: Any, **kwargs: Any) -> IndexPropResult: return inner def indirect_indexing( - self, index: Union[Any, IndexPropVar], size: Any, check: bool = True + self, index: Union[Any, IndexPropVar], size: Any, add_asserts: bool = True ) -> Any: # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE # for SymPy expressions, so we don't want to repeat idx too much @@ -259,4 +259,4 @@ def indirect_indexing( # If we are turning a indirect indexing into direct, we need to wrap it. index = index.value.expr return index + Where(index >= 0, 0, size) - return self.fallback("indirect_indexing", (index, size, check), {}).value + return self.fallback("indirect_indexing", (index, size, add_asserts), {}).value diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index beb76bc45b197..7be3836e0d75e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6144,7 +6144,7 @@ def shim(mask, other): ) @staticmethod - def indirect_indexing(index_proxy, size, check=True): + def indirect_indexing(index_proxy, size, add_asserts=True): """ Flow data from tensors into indexing formulas. Introduce a call_module to update the indexing. @@ -6154,7 +6154,7 @@ def indirect_indexing(index_proxy, size, check=True): def set_indirect(new_var): self.body.replace_indirect( - var, V.ops.indirect_indexing(new_var, size, check) + var, V.ops.indirect_indexing(new_var, size, add_asserts) ) tracer.create_proxy( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index e52f2b16e1aab..5b08b107bb0fc 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2755,6 +2755,7 @@ def index_output_size_and_inner_fn( indices_loaders, indexed_size, x_loader, + add_asserts, ): # Note that behavior of indexing differs when there are non consecutive # tensors. In this case, the tensor index is pulled to the beginning. @@ -2805,7 +2806,7 @@ def fn(idx): ops.indirect_indexing( loader(idx[start_offset : start_offset + rank]), size, - check=check, + add_asserts=add_asserts, ) ) new_index = [ @@ -2817,7 +2818,7 @@ def fn(idx): return output_size, fn -def index_impl(x, indices, check): +def index_impl(x, indices, add_asserts): assert isinstance(indices, (list, tuple)) x_loader = x.make_loader() indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) @@ -2844,6 +2845,7 @@ def index_impl(x, indices, check): indices_loaders, indexed_size, x_loader, + add_asserts=add_asserts, ) return Pointwise.create( @@ -2857,7 +2859,7 @@ def index_impl(x, indices, check): @register_lowering(aten.index, type_promotion_kind=None) def index(x, indices): try: - return index_impl(x, indices, check=True) + return index_impl(x, indices, add_asserts=True) except NotImplementedError: # Fallback to ATen for boolean indexing x.realize() @@ -2866,7 +2868,7 @@ def index(x, indices): @register_lowering(aten._unsafe_index, type_promotion_kind=None) def _unsafe_index(x, indices): - return index_impl(x, indices, check=False) + return index_impl(x, indices, add_asserts=False) # All the indexing decompositions are written in terms of index, index_put, and index_put_ @@ -2884,7 +2886,7 @@ def index_put(x, indices, values, accumulate=False): @register_lowering(aten._unsafe_index_put) def _unsafe_index_put(x, indices, values, accumulate=False): - return index_put_impl_(clone(x), indices, values, accumulate, check=False) + return index_put_impl_(clone(x), indices, values, accumulate, add_asserts=False) def index_put_as_masked_fill(self, indices, value, accumulate): @@ -2906,10 +2908,10 @@ def index_put_fallback(self, indices, values, accumulate): @register_lowering(aten.index_put_, type_promotion_kind=None) def index_put_(self, indices, values, accumulate=False): - return index_put_impl_(self, indices, values, accumulate, check=True) + return index_put_impl_(self, indices, values, accumulate, add_asserts=True) -def index_put_impl_(self, indices, values, accumulate, check): +def index_put_impl_(self, indices, values, accumulate, add_asserts): # Dispatch to masked fill for single boolean index with single value if ( values.get_numel() == 1 @@ -2974,6 +2976,7 @@ def index_put_impl_(self, indices, values, accumulate, check): indices_loaders, indexed_size, None, + add_asserts=add_asserts, ) values = expand(values, expected_vals_size) @@ -3235,7 +3238,7 @@ def scale_fn(x, scale, size): x = ops.index_expr(x, torch.float32) x = ops.mul(x, ops.constant(scale, torch.float32)) x = ops.to_dtype(x, torch.int32) - return ops.indirect_indexing(x, size, check=False) + return ops.indirect_indexing(x, size, add_asserts=False) def fn(idx): x = idx[-n:] @@ -3364,8 +3367,8 @@ def load_bounded(fy, fx): _0 = ops.constant(0, torch.int32) iHm1 = ops.constant(iH - 1, torch.int32) iWm1 = ops.constant(iW - 1, torch.int32) - iy = ops.indirect_indexing(clamp(fy, _0, iHm1), iH, check=False) - ix = ops.indirect_indexing(clamp(fx, _0, iWm1), iW, check=False) + iy = ops.indirect_indexing(clamp(fy, _0, iHm1), iH, add_asserts=False) + ix = ops.indirect_indexing(clamp(fx, _0, iWm1), iW, add_asserts=False) return x_loader([n, c, iy, ix]) iy = ops.to_dtype(in_y, get_int_dtype(iH + 1)) @@ -3402,7 +3405,7 @@ def reflect(x, size, offset): x = ops.index_expr(x, torch.int32) x = ops.sub(x, ops.index_expr(offset, torch.int32)) x = ops.sub(size, ops.abs(ops.sub(size, ops.abs(x)))) - return ops.indirect_indexing(x, size_num, check=False) + return ops.indirect_indexing(x, size_num, add_asserts=False) def fn(idx): *b, x, y = idx @@ -3862,12 +3865,12 @@ def fn(idx): ops.indirect_indexing( ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), indices_size[-2], - check=False, + add_asserts=False, ), ops.indirect_indexing( ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), indices_size[-1], - check=False, + add_asserts=False, ), ] @@ -4313,14 +4316,14 @@ def fn(idx): ph, ops.sub(phend, ops.constant(1, torch.int32)) ), pooled_height, - check=False, + add_asserts=False, ), ops.indirect_indexing( ops.minimum( pw, ops.sub(pwend, ops.constant(1, torch.int32)) ), pooled_width, - check=False, + add_asserts=False, ), ] ), diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index aabcd15f89a92..c90d178b549e0 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -939,6 +939,7 @@ def normalize_args(**kwargs): normalize_args=normalize_args, ) pattern.register(pass_dict) + return pattern.pattern @functorch_config.patch(functionalize_rng_ops=False) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index ce6438e52e979..593e949538293 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -79,7 +79,7 @@ def masked(mask, body, other) -> str: return f"ops.masked({mask}, {body()}, {other})" @staticmethod - def indirect_indexing(index_var, size, check=True) -> sympy.Symbol: + def indirect_indexing(index_var, size, add_asserts=True) -> sympy.Symbol: return sympy_symbol(f"({str(index_var)})") @classmethod @@ -254,10 +254,10 @@ def _wrap(x): return OpsValue(x) @staticmethod - def indirect_indexing(index, size, check=True): + def indirect_indexing(index, size, add_asserts=True): # Returns a sympy value, not IR value index = OpsWrapper._unwrap(index) - return _ops.indirect_indexing(index, size, check) + return _ops.indirect_indexing(index, size, add_asserts) ops = OpsWrapper() From c7dcba927634879fef93a1d58e57b6fc6d833405 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 27 Oct 2023 18:29:15 +0000 Subject: [PATCH 57/78] Remove passing disable_fastpath in kwargs (#112250) Fixes an issue that came up in https://github.com/pytorch/pytorch/pull/112030 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112250 Approved by: https://github.com/lezcano --- test/test_foreach.py | 16 +++++++--------- .../_internal/common_methods_invocations.py | 7 +++---- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 4b8f94736510e..d2982bda15049 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -123,7 +123,7 @@ def _get_funcs(self, op): dtypes=(torch.float32,) ) def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op): - wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op) + wrapped_op, _, inplace_op, _ = self._get_funcs(op) for sample in op.sample_zero_size_inputs(device, dtype): if not op.has_no_out_of_place: @@ -147,9 +147,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): else: func, ref, _, _ = self._get_funcs(op) for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous): - kwargs = sample.kwargs - disable_fastpath = kwargs.pop("disable_fastpath") - expect_fastpath = not (noncontiguous or disable_fastpath) + ref_kwargs = sample.kwargs + kwargs = ref_kwargs.copy() + expect_fastpath = not (noncontiguous or sample.disable_fastpath) if op in foreach_pointwise_op_db: values = kwargs.pop("values", None) if values is not None: @@ -168,9 +168,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): if not (op.has_no_in_place or op.has_no_out_of_place) else self.assertRaises(type(e)) ): - ref([ref_input, *sample.ref_args], **sample.ref_kwargs) + ref([ref_input, *sample.ref_args], **ref_kwargs) else: - expected = ref([ref_input, *sample.ref_args], **sample.ref_kwargs) + expected = ref([ref_input, *sample.ref_args], **ref_kwargs) self.assertEqual(expected, actual) def _binary_test( @@ -227,7 +227,6 @@ def clone(arg): (rhs_arg,) = sample.args kwargs = {} or sample.kwargs alpha = kwargs.pop("alpha", None) - _ = kwargs.pop("disable_fastpath") if is_fastpath else False wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op) if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete: scalar_self_arg_test_complete = True @@ -255,7 +254,7 @@ def test_pointwise_op_with_tensor_of_scalarlist_overload(self, device, dtype, op assert len(sample.args) == 2 inputs = [sample.input, *sample.args] kwargs = sample.kwargs - disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False + disable_fastpath = sample.disable_fastpath and is_fastpath wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op) values = kwargs.pop("values", None) @@ -691,7 +690,6 @@ def test_outplace_with_invalid_grads(self, device, dtype, op): func, *_ = self._get_funcs(op) sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0] self.assertTrue(all(t.requires_grad for t in sample.input)) - sample.kwargs.pop("disable_fastpath") if func.func in foreach_pointwise_op_db: sample.kwargs.pop("values", None) (out1, out2) = func([sample.input, *sample.args], is_cuda=False, is_fastpath=False, **sample.kwargs) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a248e6b5551eb..69a342467c261 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8562,12 +8562,12 @@ class ForeachRightmostArgType(enum.Enum): class ForeachSampleInput(SampleInput): ref_args: Any - ref_kwargs: Dict[Any, Any] + disable_fastpath: bool - def __init__(self, *args, **kwargs): + def __init__(self, *args, disable_fastpath=False, **kwargs): super().__init__(*args, **kwargs) self.ref_args = self.args - self.ref_kwargs = {k: self.kwargs[k] for k in self.kwargs if k != "disable_fastpath"} + self.disable_fastpath = disable_fastpath class foreach_inputs_sample_func: @@ -8869,7 +8869,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)) assert len(args) == 2, f"{len(args)=}" sample = ForeachSampleInput(input, *args, **kwargs) - sample.ref_kwargs["values"] = None if rightmost_arg_type == ForeachRightmostArgType.TensorList else rightmost_arg yield sample if rightmost_arg_type == ForeachRightmostArgType.TensorList: args.pop() From c120e5606ec022540bfb2d85d1194ad4782f05a4 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 27 Oct 2023 18:37:00 +0000 Subject: [PATCH 58/78] Use ops_and_refs in test_ops.py instead of _ops_and_refs (#112022) `ops_and_refs` and `_ops_and_refs` have the same definition. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112022 Approved by: https://github.com/lezcano --- test/test_ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a2d5a65466f5f..f69d0a369c552 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -103,7 +103,6 @@ op_db, ) ) -_ops_and_refs = op_db + python_ref_db def reduction_dtype_filter(op): if(not isinstance(op, ReductionPythonRefInfo) or not op.supports_out @@ -118,7 +117,7 @@ def reduction_dtype_filter(op): # Create a list of operators that are a subset of _ref_test_ops but don't have a # numpy ref to compare them too, If both CPU and CUDA are compared to numpy # then they do not need to be compared to each other -_ops_and_refs_with_no_numpy_ref = [op for op in _ops_and_refs if op.ref is None] +_ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None] aten = torch.ops.aten @@ -650,7 +649,7 @@ def test_noncontiguous_samples(self, device, dtype, op): # incorrectly sized out parameter warning properly yet # Cases test here: # - out= with the correct dtype and device, but the wrong shape - @ops(_ops_and_refs, dtypes=OpDTypes.none) + @ops(ops_and_refs, dtypes=OpDTypes.none) def test_out_warning(self, device, op): # Prefers running in float32 but has a fallback for the first listed supported dtype supported_dtypes = op.supported_dtypes(self.device_type) @@ -785,7 +784,7 @@ def _any_nonempty(out): # Case 3 and 4 are slightly different when the op is a factory function: # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out # - if device, dtype are passed, device and dtype should match - @ops(_ops_and_refs, dtypes=OpDTypes.any_one) + @ops(ops_and_refs, dtypes=OpDTypes.any_one) def test_out(self, device, dtype, op): # Prefers running in float32 but has a fallback for the first listed supported dtype samples = op.sample_inputs(device, dtype) @@ -977,7 +976,7 @@ def _case_four_transform(t): op_out(out=out) - @ops(filter(reduction_dtype_filter, _ops_and_refs), dtypes=(torch.int16,)) + @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,)) def test_out_integral_dtype(self, device, dtype, op): def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs): out = None From 192e795f3f40f52958102a076ae6afe0c8dd13af Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Fri, 27 Oct 2023 19:39:02 +0000 Subject: [PATCH 59/78] Change save -> load in comment (#112217) Change save -> load in comment because this is the load_state_dict API Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/112217 Approved by: https://github.com/wz337 --- torch/distributed/checkpoint/state_dict_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index ec370367bc5e2..e7018876bc112 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -54,7 +54,7 @@ def load_state_dict( coordinator_rank (int): Rank to use to coordinate the checkpoint. rank0 is used by default. - no_dist (bool): If ``True``, distributed checkpoint will not save + no_dist (bool): If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) Returns: From 328a4c54759a310fb8b7921ff3e99a0cb17518c9 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 27 Oct 2023 19:41:57 +0000 Subject: [PATCH 60/78] [BE] Enhance `OpInfo.supported_dtype` (#111995) Current implementation is prone to errors, as it accepts any object, but does not print an error or something if device_type is not recognized. Remediate it by accepting both device-type and device identifies (either `torch.device` instance or "{device_type}:{ordinal}" string Fixes https://github.com/pytorch/pytorch/issues/111179 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111995 Approved by: https://github.com/albanD --- test/test_testing.py | 24 +++++++++++++++++++++++- torch/testing/_internal/opinfo/core.py | 13 +++++-------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/test/test_testing.py b/test/test_testing.py index 1c668f8476c61..a4c99fafe0f27 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -22,7 +22,7 @@ parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf) from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, - get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyNativeDeviceTypes, + get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes, deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal import opinfo @@ -413,6 +413,28 @@ def test_get_supported_dtypes(self, device): self.assertTrue(set(dtypes) == set(dynamic_dtypes)) self.assertTrue(set(dtypes) == set(dynamic_dispatch.dispatch_fn())) + @onlyCPU + @ops( + [ + op + for op in op_db + if len( + op.supported_dtypes("cpu").symmetric_difference( + op.supported_dtypes("cuda") + ) + ) + > 0 + ][:1], + dtypes=OpDTypes.none, + ) + def test_supported_dtypes(self, device, op): + self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda")) + self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0")) + self.assertEqual( + op.supported_dtypes(torch.device("cuda")), + op.supported_dtypes(torch.device("cuda", index=1)), + ) + instantiate_device_type_tests(TestTesting, globals()) diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 722121d15f8aa..fc0fbf95864f1 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -1305,21 +1305,18 @@ def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): return result def supported_dtypes(self, device_type): - if device_type == "cpu": - return self.dtypes + device_type = torch.device(device_type).type if device_type == "cuda": return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA - else: - return self.dtypes + return self.dtypes def supported_backward_dtypes(self, device_type): if not self.supports_autograd: return set() + device_type = torch.device(device_type).type backward_dtypes = None - if device_type == "cpu": - backward_dtypes = self.backward_dtypes - elif device_type == "cuda": + if device_type == "cuda": backward_dtypes = ( self.backward_dtypesIfROCM if TEST_WITH_ROCM @@ -1333,7 +1330,7 @@ def supported_backward_dtypes(self, device_type): ) return set(allowed_backward_dtypes).intersection(backward_dtypes) - def supports_dtype(self, dtype, device_type): + def supports_dtype(self, dtype, device_type) -> bool: return dtype in self.supported_dtypes(device_type) @property From b9cb4103d771193084c023101f2380bf3c7ec9c6 Mon Sep 17 00:00:00 2001 From: Ter Chrng Ng Date: Fri, 27 Oct 2023 20:00:41 +0000 Subject: [PATCH 61/78] Fix iphoneos compilation (#111502) Summary: As title Test Plan: buck build @//arvr/mode/iphoneos/mac/opt //xplat/third-party/XNNPACK:ukernels_asm_aarch64 Reviewed By: mcr229 Differential Revision: D50423968 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111502 Approved by: https://github.com/mcr229 --- third_party/xnnpack.buck.bzl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index 435ea12a9b2e1..8d7b061b81c04 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1432,7 +1432,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F apple_sdks = (IOS, MACOSX, APPLETVOS), compiler_flags = [ "-O2", - ], + ] + select({ + "ovr_config//cpu:arm64": ["-march=armv8-a"], + "DEFAULT": [] + }), fbobjc_preprocessor_flags = [ "-DXNN_PRIVATE=", "-DXNN_INTERNAL=", @@ -1484,7 +1487,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F apple_sdks = (IOS, MACOSX, APPLETVOS), compiler_flags = [ "-O2", - ], + ] + select({ + "ovr_config//cpu:arm64": ["-march=armv8.2-a+dotprod"], + "DEFAULT": [] + }), fbobjc_preprocessor_flags = [ "-DXNN_PRIVATE=", "-DXNN_INTERNAL=", @@ -1567,7 +1573,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F apple_sdks = (IOS, MACOSX, APPLETVOS), compiler_flags = [ "-O2", - ], + ] + select({ + "ovr_config//cpu:arm64": ["-march=armv8.2-a+fp16"], + "DEFAULT": [] + }), fbobjc_preprocessor_flags = [ "-DXNN_PRIVATE=", "-DXNN_INTERNAL=", @@ -1727,7 +1736,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F apple_sdks = (IOS, MACOSX, APPLETVOS), compiler_flags = [ "-O2", - ], + ] + select({ + "ovr_config//cpu:arm64": ["-march=armv8.2-a+fp16+dotprod"], + "DEFAULT": [] + }), fbobjc_preprocessor_flags = [ "-DXNN_PRIVATE=", "-DXNN_INTERNAL=", From 1dcbd1c088f27213f4e2d68e43dd5d60e0336fe5 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 27 Oct 2023 16:52:24 +0000 Subject: [PATCH 62/78] [dynamo] [easy] Move Set to dicts.py (#110522) A set is more of a dict than a list if you ask me. This comes before the refactor where we implement sets and dicts via the same logic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110522 Approved by: https://github.com/jansel --- torch/_dynamo/symbolic_convert.py | 3 +- torch/_dynamo/utils.py | 43 ++++++ torch/_dynamo/variables/__init__.py | 2 +- torch/_dynamo/variables/builder.py | 2 +- torch/_dynamo/variables/builtin.py | 9 +- torch/_dynamo/variables/dicts.py | 170 +++++++++++++++++++++- torch/_dynamo/variables/lists.py | 212 +--------------------------- 7 files changed, 221 insertions(+), 220 deletions(-) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 2de0903af8770..e7e373f00a857 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -81,7 +81,7 @@ GenericContextWrappingVariable, WithExitFunctionVariable, ) -from .variables.dicts import ConstDictVariable +from .variables.dicts import ConstDictVariable, SetVariable from .variables.functions import ( BaseUserFunctionVariable, NestedUserFunctionVariable, @@ -92,7 +92,6 @@ BaseListVariable, ListIteratorVariable, ListVariable, - SetVariable, SliceVariable, TupleVariable, ) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2a2190513cd32..737c1cb9f3dd9 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -913,6 +913,49 @@ def enum_repr(value, local): return local_name +def _get_fake_tensor(vt): + fake_tensor = vt.as_proxy().node.meta.get("example_value") + if not is_fake(fake_tensor): + unimplemented("Cannot check Tensor object identity without its fake value") + return fake_tensor + + +def iter_contains(items, search, tx, options, check_tensor_identity=False): + from .variables import BuiltinVariable, ConstantVariable, TensorVariable + + if search.is_python_constant(): + found = any( + x.is_python_constant() + and x.as_python_constant() == search.as_python_constant() + for x in items + ) + return ConstantVariable.create(found, **options) + + must_check_tensor_id = False + if check_tensor_identity and isinstance(search, TensorVariable): + must_check_tensor_id = True + # Match of Tensor means match of FakeTensor + search = _get_fake_tensor(search) + + found = None + for x in items: + if must_check_tensor_id: + if isinstance(x, TensorVariable): + if search is _get_fake_tensor(x): # Object equivalence + return ConstantVariable.create(True) + else: + check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) + if found is None: + found = check + else: + found = BuiltinVariable(operator.or_).call_function( + tx, [check, found], {} + ) + if found is None: + found = ConstantVariable.create(False) + return found + + def dict_param_key_ids(value): return { id(k) for k in value.keys() if isinstance(k, (torch.nn.Parameter, torch.Tensor)) diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index f18eed9b8e0eb..bdeddd62947dc 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -16,6 +16,7 @@ CustomizedDictVariable, DataClassVariable, DefaultDictVariable, + SetVariable, ) from .functions import ( NestedUserFunctionVariable, @@ -29,7 +30,6 @@ ListVariable, NamedTupleVariable, RangeVariable, - SetVariable, SliceVariable, TupleVariable, ) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9f3f9decbf024..f830c695606cd 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -87,6 +87,7 @@ DefaultDictVariable, HFPretrainedConfigVariable, PythonSysModulesVariable, + SetVariable, ) from .distributed import ( DeviceMeshVariable, @@ -107,7 +108,6 @@ ListVariable, NamedTupleVariable, RangeVariable, - SetVariable, SizeVariable, SliceVariable, TupleIteratorVariable, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 907e912fbce4d..49b2583e52e8f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -39,12 +39,11 @@ from .base import MutableLocal, typestr, VariableTracker from .constant import ConstantVariable, EnumVariable from .ctx_manager import EventVariable, StreamVariable -from .dicts import ConstDictVariable +from .dicts import ConstDictVariable, SetVariable from .lists import ( BaseListVariable, ListIteratorVariable, ListVariable, - SetVariable, SizeVariable, TupleIteratorVariable, TupleVariable, @@ -821,7 +820,11 @@ def _dyn_proxy(self, tx, *args, **kwargs): def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): if self._dynamic_args(*args, **kwargs): return self._dyn_proxy(tx, *args, **kwargs) - cls = variables.BaseListVariable.cls_for(self.fn) + # TODO This should probably be treated as a dict, or dicts should also be treated here + if self.fn == set: + cls = SetVariable + else: + cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: if cls is SetVariable: return cls( diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index fda654f4eae98..fa22bac75c68f 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -3,7 +3,7 @@ import functools import inspect import sys -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.fx @@ -13,9 +13,9 @@ from ..eval_frame import skip_code from ..exc import unimplemented -from ..guards import GuardBuilder +from ..guards import GuardBuilder, make_dupe_guard from ..source import AttrSource, GetItemSource, GlobalWeakRefSource -from ..utils import global_key_name, istensor +from ..utils import global_key_name, istensor, iter_contains from .base import MutableLocal, VariableTracker from .constant import ConstantVariable from .tensor import TensorVariable @@ -25,6 +25,8 @@ class ConstDictVariable(VariableTracker): def __init__(self, items, user_cls, recursively_contains=None, **kwargs): super().__init__(recursively_contains=recursively_contains, **kwargs) + # All the keys are constants + assert not any(isinstance(x, VariableTracker) for x in items) self.guards.update(VariableTracker.propagate(items.values())["guards"]) self.items = items self.user_cls = user_cls @@ -77,7 +79,7 @@ def call_method( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": - from . import ConstantVariable, SetVariable, TupleVariable + from . import ConstantVariable, TupleVariable options = VariableTracker.propagate(self, args, kwargs.values()) val = self.items @@ -296,6 +298,166 @@ def call_method( return super().call_method(tx, name, args, kwargs) +class SetVariable(VariableTracker): + @dataclasses.dataclass + class SetElement: + vt: VariableTracker + underlying_value: Any + + def __hash__(self) -> int: + return hash(self.underlying_value) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SetVariable.SetElement): + return False + if isinstance(self.vt, variables.TensorVariable): + return self.underlying_value is other.underlying_value + else: + return self.underlying_value == other.underlying_value + + def __init__( + self, + items: List[VariableTracker], + recursively_contains=None, + regen_guards=True, + **kwargs, + ): + super().__init__(recursively_contains=recursively_contains, **kwargs) + # Note - Set is still backed by a list, because we want set behavior over the contents, + assert isinstance(items, list) + assert all(isinstance(x, VariableTracker) for x in items) + + self.items = [] + self._add(items) + + # Sometimes, we know that we have passed in the guards from the items in the set + if regen_guards: + self.guards.update(VariableTracker.propagate(items)["guards"]) + + def as_proxy(self): + return [x.as_proxy() for x in self.items] + + def python_type(self): + return set + + def reconstruct(self, codegen): + codegen.load_import_from("builtins", "set") + codegen.foreach(self.items) + return [ + create_instruction("BUILD_SET", arg=len(self.items)) + ] + create_call_function(1, True) + + # Note - this is only used for producing a set + def _as_set_element(self, vt): + from .base import VariableTracker + from .misc import MethodWrapperVariable + from .tensor import TensorVariable + + assert isinstance(vt, VariableTracker) + + if isinstance(vt, TensorVariable): + fake_tensor = vt.as_proxy().node.meta.get("example_value") + if fake_tensor is None: + unimplemented( + "Cannot check Tensor object identity without its fake value" + ) + return SetVariable.SetElement(vt, fake_tensor) + if isinstance(vt, ConstantVariable): + return SetVariable.SetElement(vt, vt.value) + if isinstance(vt, MethodWrapperVariable): + return SetVariable.SetElement(vt, vt.as_python_constant()) + + unimplemented(f"Sets with {type(vt)} NYI") + + @property + def _underlying_items(self): + underlying_items = set() + for current_item in self.items: + assert ( + current_item not in underlying_items + ), "Items modeling set invariant violated" + underlying_items.add(self._as_set_element(current_item)) + return underlying_items + + def _add(self, item): + underlying_items = self._underlying_items + + if isinstance(item, (list, set)): + items_to_add = item + else: + items_to_add = [item] + + for item_to_add in items_to_add: + set_element = self._as_set_element(item_to_add) + if set_element not in underlying_items: + underlying_items.add(set_element) + self.items.append(set_element.vt) + else: + for e in underlying_items: + if hash(set_element) == hash(e): + alias_guard = make_dupe_guard( + e.vt.source, set_element.vt.source + ) + if alias_guard: + e.vt = e.vt.add_guards( + {e.vt.source.make_guard(alias_guard)} + ) + + return self.items + + def call_method( + self, + tx, + name, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> "VariableTracker": + options = VariableTracker.propagate(self, args, kwargs.values()) + # Somewhat duplicative of CommonListMethodsVariable - but better than to violate substitution + # principles and end up with things like direct item access attempts on a set, or + # getitem sources. + if name == "add" and args and self.mutable_local: + assert not kwargs + item = args[0] + result = SetVariable( + self._add(item), + mutable_local=self.mutable_local, + regen_guards=False, + **options, + ) + tx.replace_all(self, result) + return ConstantVariable.create(None) + elif name == "pop" and self.mutable_local: + assert not kwargs + assert not args + items = list(self.items) + result = items.pop() + tx.replace_all( + self, + SetVariable(items, regen_guards=False, **options), + ) + return result + elif name == "__len__": + return ConstantVariable.create(len(self.items)).add_options(options) + elif name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains( + self.items, args[0], tx, options, check_tensor_identity=True + ) + else: + return super().call_method(tx, name, args, kwargs) + + def getitem_const(self, arg: VariableTracker): + raise RuntimeError("Illegal to getitem on a set") + + def as_python_constant(self): + return self.python_type()([x.as_python_constant() for x in self.items]) + + def unpack_var_sequence(self, tx): + return [x.add_options(self) for x in self.items] + + class DataClassVariable(ConstDictVariable): """ This is a bit of a hack to deal with diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index a449a15a7fe14..d96a6571d52a9 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1,23 +1,21 @@ import collections -import dataclasses import functools import inspect import operator -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import torch import torch.fx -from torch._subclasses.fake_tensor import is_fake from .. import polyfill, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import unimplemented -from ..guards import make_dupe_guard from ..source import GetItemSource from ..utils import ( get_fake_value, guard_if_dyn, is_namedtuple, + iter_contains, namedtuple_fields, odict_values, ) @@ -26,49 +24,6 @@ from .functions import UserFunctionVariable, UserMethodVariable -def _get_fake_tensor(vt): - fake_tensor = vt.as_proxy().node.meta.get("example_value") - if not is_fake(fake_tensor): - unimplemented("Cannot check Tensor object identity without its fake value") - return fake_tensor - - -def _listlike_contains_helper(items, search, tx, options, check_tensor_identity=False): - if search.is_python_constant(): - found = any( - x.is_python_constant() - and x.as_python_constant() == search.as_python_constant() - for x in items - ) - return variables.ConstantVariable.create(found, **options) - - from .builtin import BuiltinVariable - - must_check_tensor_id = False - if check_tensor_identity and isinstance(search, variables.TensorVariable): - must_check_tensor_id = True - # Match of Tensor means match of FakeTensor - search = _get_fake_tensor(search) - - found = None - for x in items: - if must_check_tensor_id: - if isinstance(x, variables.TensorVariable): - if search is _get_fake_tensor(x): # Object equivalence - return ConstantVariable.create(True) - else: - check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) - if found is None: - found = check - else: - found = BuiltinVariable(operator.or_).call_function( - tx, [check, found], {} - ) - if found is None: - found = ConstantVariable.create(False) - return found - - class BaseListVariable(VariableTracker): @staticmethod def cls_for_instance(obj): @@ -84,7 +39,6 @@ def cls_for(obj): slice: SliceVariable, torch.Size: SizeVariable, tuple: TupleVariable, - set: SetVariable, odict_values: ListVariable, torch.nn.ParameterList: ListVariable, torch.nn.ModuleList: ListVariable, @@ -172,7 +126,7 @@ def call_method( elif name == "__contains__": assert len(args) == 1 assert not kwargs - return _listlike_contains_helper(self.items, args[0], tx, options) + return iter_contains(self.items, args[0], tx, options) elif name == "index": from .builder import SourcelessBuilder @@ -765,163 +719,3 @@ def reconstruct(self, codegen): class TupleIteratorVariable(ListIteratorVariable): pass - - -class SetVariable(VariableTracker): - @dataclasses.dataclass - class SetElement: - vt: VariableTracker - underlying_value: Any - - def __hash__(self) -> int: - return hash(self.underlying_value) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SetVariable.SetElement): - return False - if isinstance(self.vt, variables.TensorVariable): - return self.underlying_value is other.underlying_value - else: - return self.underlying_value == other.underlying_value - - def __init__( - self, - items: List[VariableTracker], - recursively_contains=None, - regen_guards=True, - **kwargs, - ): - super().__init__(recursively_contains=recursively_contains, **kwargs) - # Note - Set is still backed by a list, because we want set behavior over the contents, - assert isinstance(items, list) - assert all(isinstance(x, VariableTracker) for x in items) - - self.items = [] - self._add(items) - - # Sometimes, we know that we have passed in the guards from the items in the set - if regen_guards: - self.guards.update(VariableTracker.propagate(items)["guards"]) - - def as_proxy(self): - return [x.as_proxy() for x in self.items] - - def python_type(self): - return set - - def reconstruct(self, codegen): - codegen.load_import_from("builtins", "set") - codegen.foreach(self.items) - return [ - create_instruction("BUILD_SET", arg=len(self.items)) - ] + create_call_function(1, True) - - # Note - this is only used for producing a set - def _as_set_element(self, vt): - from .base import VariableTracker - from .misc import MethodWrapperVariable - from .tensor import TensorVariable - - assert isinstance(vt, VariableTracker) - - if isinstance(vt, TensorVariable): - fake_tensor = vt.as_proxy().node.meta.get("example_value") - if fake_tensor is None: - unimplemented( - "Cannot check Tensor object identity without its fake value" - ) - return SetVariable.SetElement(vt, fake_tensor) - if isinstance(vt, ConstantVariable): - return SetVariable.SetElement(vt, vt.value) - if isinstance(vt, MethodWrapperVariable): - return SetVariable.SetElement(vt, vt.as_python_constant()) - - unimplemented(f"Sets with {type(vt)} NYI") - - @property - def _underlying_items(self): - underlying_items = set() - for current_item in self.items: - assert ( - current_item not in underlying_items - ), "Items modeling set invariant violated" - underlying_items.add(self._as_set_element(current_item)) - return underlying_items - - def _add(self, item): - underlying_items = self._underlying_items - - if isinstance(item, (list, set)): - items_to_add = item - else: - items_to_add = [item] - - for item_to_add in items_to_add: - set_element = self._as_set_element(item_to_add) - if set_element not in underlying_items: - underlying_items.add(set_element) - self.items.append(set_element.vt) - else: - for e in underlying_items: - if hash(set_element) == hash(e): - alias_guard = make_dupe_guard( - e.vt.source, set_element.vt.source - ) - if alias_guard: - e.vt = e.vt.add_guards( - {e.vt.source.make_guard(alias_guard)} - ) - - return self.items - - def call_method( - self, - tx, - name, - args: List[VariableTracker], - kwargs: Dict[str, VariableTracker], - ) -> "VariableTracker": - options = VariableTracker.propagate(self, args, kwargs.values()) - # Somewhat duplicative of CommonListMethodsVariable - but better than to violate substitution - # principles and end up with things like direct item access attempts on a set, or - # getitem sources. - if name == "add" and args and self.mutable_local: - assert not kwargs - item = args[0] - result = SetVariable( - self._add(item), - mutable_local=self.mutable_local, - regen_guards=False, - **options, - ) - tx.replace_all(self, result) - return ConstantVariable.create(None) - elif name == "pop" and self.mutable_local: - assert not kwargs - assert not args - items = list(self.items) - result = items.pop() - tx.replace_all( - self, - SetVariable(items, regen_guards=False, **options), - ) - return result - elif name == "__len__": - return ConstantVariable.create(len(self.items)).add_options(options) - elif name == "__contains__": - assert len(args) == 1 - assert not kwargs - return _listlike_contains_helper( - self.items, args[0], tx, options, check_tensor_identity=True - ) - else: - return super().call_method(tx, name, args, kwargs) - - def getitem_const(self, arg: VariableTracker): - raise RuntimeError("Illegal to getitem on a set") - - def as_python_constant(self): - return self.python_type()([x.as_python_constant() for x in self.items]) - - def unpack_var_sequence(self, tx): - return [x.add_options(self) for x in self.items] From 1774704fc1c4edba28e08d4d41d5b65b04d9b3b3 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 27 Oct 2023 16:52:25 +0000 Subject: [PATCH 63/78] [dynamo] Simplify add_dict in preparation to refactor it with call_set (#110523) The previous implementation had a fair amount of repeated code, and did things like calling `add_options` where options was always empty (which is fine, as the guards are already set within ConstDictVariable). Pull Request resolved: https://github.com/pytorch/pytorch/pull/110523 Approved by: https://github.com/yanboliang, https://github.com/jansel ghstack dependencies: #110522 --- torch/_dynamo/variables/builtin.py | 95 ++++++++++-------------------- torch/_dynamo/variables/misc.py | 18 +----- 2 files changed, 33 insertions(+), 80 deletions(-) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 49b2583e52e8f..5d18891083ae7 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -37,7 +37,7 @@ specialize_args_kwargs, ) from .base import MutableLocal, typestr, VariableTracker -from .constant import ConstantVariable, EnumVariable +from .constant import ConstantVariable from .ctx_manager import EventVariable, StreamVariable from .dicts import ConstDictVariable, SetVariable from .lists import ( @@ -861,30 +861,6 @@ def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): call_list = _call_iter_tuple_list call_set = _call_iter_tuple_list - @staticmethod - def is_supported_call_dict_arg(tx, arg): - return ( - arg is None - or isinstance(arg, ConstDictVariable) - or ( - isinstance( - arg, - ( - ListVariable, - TupleVariable, - ListIteratorVariable, - ), - ) - and all( - isinstance(x, (ListVariable, TupleVariable)) - and isinstance( - x.unpack_var_sequence(tx)[0], (ConstantVariable, EnumVariable) - ) - for x in arg.unpack_var_sequence(tx) - ) - ) - ) - def call_callable(self, tx, arg): from .functions import BaseUserFunctionVariable @@ -893,35 +869,6 @@ def call_callable(self, tx, arg): ): return variables.ConstantVariable.create(True).add_options(arg) - @staticmethod - def call_dict_helper(tx, user_cls, arg, **options): - if arg is None or isinstance(arg, dict): - return ConstDictVariable( - arg if arg is not None else {}, user_cls, mutable_local=MutableLocal() - ).add_options(options) - elif isinstance(arg, variables.ConstDictVariable): - return arg.clone( - user_cls=user_cls, mutable_local=MutableLocal() - ).add_options(options) - elif isinstance( - arg, - ( - ListVariable, - TupleVariable, - ListIteratorVariable, - ), - ): - items = user_cls() - for x in arg.unpack_var_sequence(tx): - k = x.unpack_var_sequence(tx)[0].as_python_constant() - v = x.unpack_var_sequence(tx)[1] - items.update({k: v}) - return ConstDictVariable( - items, user_cls, mutable_local=MutableLocal() - ).add_options(options) - else: - raise AssertionError("call_dict_helper with illegal arg") - def call_cast(self, _, *args, **kwargs): if len(args) == 2: return args[1] @@ -929,20 +876,38 @@ def call_cast(self, _, *args, **kwargs): unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}") def call_dict(self, tx, *args, **kwargs): - if not (args or kwargs): - return self.call_dict_helper(tx, dict, None) - elif ( - not kwargs - and len(args) == 1 - and self.is_supported_call_dict_arg(tx, args[0]) - ): - return self.call_dict_helper(tx, dict, args[0]) + return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + + @staticmethod + def call_custom_dict(tx, user_cls, *args, **kwargs): + if not kwargs: + if not args: + args = ({},) + assert len(args) == 1 + arg = args[0] + if isinstance(arg, dict): + return ConstDictVariable(arg, user_cls, mutable_local=MutableLocal()) + elif isinstance(arg, variables.ConstDictVariable): + return arg.clone(user_cls=user_cls, mutable_local=MutableLocal()) + elif isinstance( + arg, + ( + ListVariable, + TupleVariable, + ListIteratorVariable, + ), + ): + items = user_cls() + for x in arg.unpack_var_sequence(tx): + k, v = x.unpack_var_sequence(tx) + k = ConstDictVariable.get_key(k) + items.update({k: v}) + return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) elif not args and kwargs: return variables.ConstDictVariable( - dict(kwargs), user_cls=dict, mutable_local=MutableLocal() + dict(kwargs), user_cls=user_cls, mutable_local=MutableLocal() ) - else: - unimplemented(f"dict(): {args} {kwargs}") + unimplemented(f"dict(): {args} {kwargs}") def call_zip(self, tx, *args): options = VariableTracker.propagate(self, args) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index a57f9bfdc4302..ad07691919f7f 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -715,21 +715,9 @@ def call_function( if inspect.getattr_static(self.value, "_torchdynamo_disable", False): unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}") # Allowlist a few popular classes(e.g, collections.OrderedDict) calls in skip files. - elif self.value is collections.OrderedDict and ( - len(args) == 0 - or len(args) == 1 - and BuiltinVariable.is_supported_call_dict_arg(tx, args[0]) - ): - if len(args) == 0: - args = dict(kwargs) if kwargs else None - else: - args = args[0] - - return BuiltinVariable.call_dict_helper( - tx, - collections.OrderedDict, - args, - **options, + elif self.value is collections.OrderedDict: + return BuiltinVariable.call_custom_dict( + tx, collections.OrderedDict, *args, **kwargs ) elif ( self.value is collections.defaultdict From ca2106e871efac9b6213c40d4f862d33813b7418 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Fri, 27 Oct 2023 20:20:41 +0000 Subject: [PATCH 64/78] [pytorch-vulkan] floor-divide for tensor, tensor (#112190) Summary: tsia Test Plan: ## Compile on Mac and run on Android ``` buck2 build -c ndk.static_linking=true -c pt.enable_qpl=0 --target-platforms=ovr_config//platform/android:arm32-fbsource //xplat/caffe2:pt_vulkan_api_test_binAndroid --show-output && adb push buck-out/v2/gen/fbsource/f1f3f9bed27e143c/xplat/caffe2/__pt_vulkan_api_test_binAndroid__/pt_vulkan_api_test_binAndroid /data/local/tmp ``` Run on android ``` $ adb shell /data/local/tmp/pt_vulkan_api_test_binAndroid ... [ RUN ] VulkanAPITest.lstm_prepack_success [ OK ] VulkanAPITest.lstm_prepack_success (11 ms) [ RUN ] VulkanAPITest.querypool_flushed_shader_log xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp:7667: Skipped QueryPool is not available [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms) [----------] 396 tests from VulkanAPITest (29980 ms total) [----------] Global test environment tear-down [==========] 396 tests from 1 test suite ran. (29980 ms total) [ PASSED ] 395 tests. [ SKIPPED ] 1 test, listed below: [ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log YOU HAVE 7 DISABLED TESTS ``` All Passed. Full Output: P865232089 Reviewed By: copyrightly Differential Revision: D50677361 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112190 Approved by: https://github.com/manuelcandales --- .../glsl/templates/binary_op_params.yaml | 6 ++ aten/src/ATen/native/vulkan/ops/BinaryOp.cpp | 16 ++++ aten/src/ATen/test/vulkan_api_test.cpp | 88 +++++++++++++++++++ 3 files changed, 110 insertions(+) diff --git a/aten/src/ATen/native/vulkan/glsl/templates/binary_op_params.yaml b/aten/src/ATen/native/vulkan/glsl/templates/binary_op_params.yaml index c41a760a7e1a2..87bb76d43741f 100644 --- a/aten/src/ATen/native/vulkan/glsl/templates/binary_op_params.yaml +++ b/aten/src/ATen/native/vulkan/glsl/templates/binary_op_params.yaml @@ -40,6 +40,9 @@ binary_op_tensor: - NAME: pow IS_DIV: 0 OPERATOR: pow(X, Y) + - NAME: floor_divide + IS_DIV: 1 + OPERATOR: floor(X / Y) binary_op_tensor_inplace: parameter_names_with_default_values: @@ -59,3 +62,6 @@ binary_op_tensor_inplace: - NAME: pow_ IS_DIV: 0 OPERATOR: pow(X, Y) + - NAME: floor_divide_ + IS_DIV: 1 + OPERATOR: floor(X / Y) diff --git a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp index 754fa4982cc02..4bd6611b22889 100644 --- a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp +++ b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp @@ -539,6 +539,16 @@ Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(floor_mul_scalar_)); } +Tensor floor_divide_tensor(const Tensor& self, const Tensor& other) { + return binary_op_tensor( + self, other, c10::optional(), VK_KERNEL(floor_divide)); +} + +Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { + return binary_op_tensor_( + self, other_arg, c10::optional(), VK_KERNEL(floor_divide_)); +} + #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { @@ -572,6 +582,12 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl( TORCH_SELECTIVE_NAME("aten::floor_divide_.Scalar"), TORCH_FN(floor_divide_scalar_)); + m.impl( + TORCH_SELECTIVE_NAME("aten::floor_divide"), + TORCH_FN(floor_divide_tensor)); + m.impl( + TORCH_SELECTIVE_NAME("aten::floor_divide_.Tensor"), + TORCH_FN(floor_divide_tensor_)); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index f61aa060091fb..02a6a60d991e8 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -3720,6 +3720,94 @@ TEST_F(VulkanAPITest, floor_divide_scalar_inplace) { test_floor_divide_scalar_inplace({3, 3, 12, 12}, 0.3, 0.08); } +TEST_F(VulkanAPITest, floor_divide_zero_dim_tensor) { + c10::InferenceMode mode; + + std::vector input_shape{5, 3, 4, 5}; + float input_scale = 100.0; + + auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + in_cpu = at::mul(in_cpu, input_scale); + auto in_vk = in_cpu.vulkan(); + + auto other_cpu = at::zeros({}, at::device(at::kCPU).dtype(at::kFloat)) + 10.0f; + auto other_vk = other_cpu.vulkan(); + + auto out_cpu = at::floor_divide(in_cpu, other_cpu); + auto out_vk = at::floor_divide(in_vk, other_vk); + + // max tolerance is 1.0 due to floor. + // may consider adding extra check on number of violation. it should be rare. + const auto check = checkRtol(out_cpu - out_vk.cpu(), 1.0f); + if (!check) { + std::cout << "floor_divide test failed with " + << "scale: " << input_scale + << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, floor_divide_tensor) { + c10::InferenceMode mode; + + std::vector input_shape{6, 3, 5, 5}; + float input_scale = 10.0; + + auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + in_cpu = at::mul(in_cpu, input_scale); + // "other" is at least 0.5 to avoid rounding error causes by very small + // values. + auto other_cpu = + at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)) + 0.5; + + auto in_vk = in_cpu.vulkan(); + auto other_vk = other_cpu.vulkan(); + + auto out_cpu = at::floor_divide(in_cpu, other_cpu); + auto out_vk = at::floor_divide(in_vk, other_vk); + + // max tolerance is 1.0 due to floor. + // may consider adding extra check on number of violation. it should be rare. + const auto check = checkRtol(out_cpu - out_vk.cpu(), 1.0f); + if (!check) { + std::cout << "floor_divide test failed with " + << "scale: " << input_scale << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, floor_divide_tensor_inplace) { + c10::InferenceMode mode; + + std::vector input_shape{5, 3, 5, 5}; + float input_scale = 10.0; + + auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + in_cpu = at::mul(in_cpu, input_scale); + // "other" is at least 0.5 to avoid rounding error causes by very small + // values. + auto other_cpu = + at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)) + 0.5; + + auto in_vk = in_cpu.vulkan(); + auto other_vk = other_cpu.vulkan(); + + in_cpu.floor_divide_(other_cpu); + in_vk.floor_divide_(other_vk); + + // max tolerance is 1.0 due to floor. + // may consider adding extra check on number of violation. it should be rare. + const auto check = checkRtol(in_cpu - in_vk.cpu(), 1.0f); + if (!check) { + std::cout << "floor_divide test failed with " + << "scale: " << input_scale << std::endl; + } + + ASSERT_TRUE(check); +} + TEST_F(VulkanAPITest, relu) { const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); From 6d685ff54f99aa4b26ddab6681cf8589597d7a15 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 27 Oct 2023 20:48:32 +0000 Subject: [PATCH 65/78] [BE] Remove float8 from vec is_floating_type definition (#112196) As it's not supported yet, and it's also not clear, how support should look like Pull Request resolved: https://github.com/pytorch/pytorch/pull/112196 Approved by: https://github.com/drisspg --- aten/src/ATen/cpu/vec/vec_base.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 35283252c8b84..f0a6d624247ce 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -70,9 +70,7 @@ struct is_floating_point: std::integral_constant::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value> { + std::is_same::value> { }; template From f7dc0ae16c4637be0a7f20a1d9cd4311e9a6d3e8 Mon Sep 17 00:00:00 2001 From: chilli Date: Thu, 26 Oct 2023 19:15:40 -0700 Subject: [PATCH 66/78] Some cleanups in pattern matcher (#112101) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112101 Approved by: https://github.com/eellison ghstack dependencies: #112093 --- test/inductor/test_pattern_matcher.py | 113 +++++------------- .../_inductor/fx_passes/freezing_patterns.py | 8 +- torch/_inductor/fx_passes/fuse_attention.py | 15 ++- torch/_inductor/fx_passes/pad_mm.py | 12 +- torch/_inductor/fx_passes/post_grad.py | 9 +- torch/_inductor/pattern_matcher.py | 73 +++++++---- 6 files changed, 102 insertions(+), 128 deletions(-) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index a27be82ac40b0..d1f7abff5a9ff 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -26,6 +26,24 @@ class TestPaternMatcher(TestCase): + def common(self, fn, args, expected_matches, expected_nodes): + counters.clear() + torch.manual_seed(42) + expected = fn(*args) + torch.manual_seed(42) + actual = torch.compile(fn)(*args) + torch.testing.assert_close(actual, expected) + if inductor_config.cpp_wrapper: + # CPP wrapper runs everything twice, so we'll match the pattern twice + expected_matches *= 2 + expected_nodes *= 2 + + self.assertEqual( + counters["inductor"]["pattern_matcher_count"], expected_matches + ) + self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes) + counters.clear() + def test_mm_plus_mm(self): def fn(a, b, c, d): return torch.add(torch.mm(a, b), torch.mm(c, d)) @@ -58,12 +76,7 @@ def fn(a, b, c, d): ), ] for args in args_list: - counters.clear() - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3) + self.common(fn, args, 1, 3) def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True): torch._dynamo.reset() @@ -467,11 +480,7 @@ def fn(a, b, c): torch.randn(16, 16, device="cuda"), torch.randn(16, 16, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5) + self.common(fn, args, 2, 5) def test_cat_addmm(self): def fn(a, b, c): @@ -489,11 +498,7 @@ def fn(a, b, c): torch.randn(16, 16, device="cuda"), torch.randn(16, 16, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5) + self.common(fn, args, 2, 5) def test_cat_slice_cat(self): def check_counter(counter, expected): @@ -513,17 +518,13 @@ def fn(a, b): torch.randn(2, 32, device="cuda"), torch.randn(2, 16, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - check_counter(counters["inductor"]["pattern_matcher_count"], 1) - check_counter(counters["inductor"]["pattern_matcher_nodes"], 3) + self.common(fn, args, 1, 3) - counters.clear() args = [ torch.randn(2, 8, device="cuda"), torch.randn(2, 16, device="cuda"), ] + counters.clear() expected = fn(*args) actual = torch.compile(fn)(*args) torch.testing.assert_close(actual, expected) @@ -539,16 +540,11 @@ def fn(a, b): slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1) return torch.ops.aten.cat.default([cat_1, slice_2], 1) - counters.clear() args = [ torch.randn(2, 8, device="cuda"), torch.randn(2, 16, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - check_counter(counters["inductor"]["pattern_matcher_count"], 1) - check_counter(counters["inductor"]["pattern_matcher_nodes"], 3) + self.common(fn, args, 1, 3) def test_pointless_convert(self): def fn1(x): @@ -624,12 +620,7 @@ def fn(a): args = [ torch.randn(2, 32, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4) - counters.clear() + self.common(fn, args, 1, 4) # Not all getitems are passed to cat def fn(a): @@ -643,12 +634,7 @@ def fn(a): args = [ torch.randn(2, 32, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) - counters.clear() + self.common(fn, args, 0, 0) # Different dimensions (TODO this case should be handled by replacing with a reshape) def fn(a): @@ -661,11 +647,7 @@ def fn(a): args = [ torch.randn(2, 32, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) + self.common(fn, args, 0, 0) # https://github.com/pytorch/pytorch/issues/99686. def fn(a): @@ -676,11 +658,7 @@ def fn(a): args = [ torch.randn(1, 8, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) + self.common(fn, args, 0, 0) def test_cat_splitwithsizes(self): # good case @@ -696,12 +674,7 @@ def fn(a, b, c): torch.randn(2, 3, device="cuda"), torch.randn(2, 5, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 2) - counters.clear() + self.common(fn, args, 1, 2) # cat node has other users def fn(a, b, c): @@ -716,12 +689,7 @@ def fn(a, b, c): torch.randn(2, 3, device="cuda"), torch.randn(2, 5, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) - counters.clear() + self.common(fn, args, 0, 0) # cat and split dims are different def fn(a, b, c): @@ -736,12 +704,7 @@ def fn(a, b, c): torch.randn(10, 3, device="cuda"), torch.randn(10, 5, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) - counters.clear() + self.common(fn, args, 0, 0) # cat and split lenghts are different def fn(a, b, c): @@ -754,12 +717,7 @@ def fn(a, b, c): torch.randn(2, 3, device="cuda"), torch.randn(2, 5, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) - counters.clear() + self.common(fn, args, 0, 0) # cat input sizes and split sizes are different def fn(a, b, c): @@ -774,12 +732,7 @@ def fn(a, b, c): torch.randn(2, 3, device="cuda"), torch.randn(2, 5, device="cuda"), ] - expected = fn(*args) - actual = torch.compile(fn)(*args) - torch.testing.assert_close(actual, expected) - self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0) - self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0) - counters.clear() + self.common(fn, args, 0, 0) def test_match_with_mutation(self): from torch._inductor.pattern_matcher import ( diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index dc0eb65dc3269..cafd7aa4eee83 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -8,8 +8,8 @@ from ..pattern_matcher import ( _return_true, CallFunction, + fwd_only, Ignored, - inference_graph, init_once_fakemode, KeywordArg, Match, @@ -144,7 +144,7 @@ def matmul_replacement(inp, w1, w2, w3): matmul_fuse_pattern, matmul_replacement, [val(), val(), val(), val()], - inference_graph, + fwd_only, pass_patterns[0], extra_check=check_concat_weights, exclusive_arg_names=("w1", "w2", "w3"), @@ -162,7 +162,7 @@ def matmul_replacement_two(inp, w1, w2): matmul_fuse_pattern_two, matmul_replacement_two, [val(), val(), val()], - inference_graph, + fwd_only, pass_patterns[0], extra_check=check_concat_weights, exclusive_arg_names=("w1", "w2"), @@ -184,7 +184,7 @@ def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): addmm_fuse_pattern_second, addmm_fuse_replacement_second, [val() for _ in range(7)], - inference_graph, + fwd_only, pass_patterns[0], extra_check=check_concat_weights, exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 8001d589852a0..207f64558a0e7 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -7,9 +7,9 @@ from ..._dynamo.utils import counters from ..pattern_matcher import ( filter_nodes, - inference_graph, + fwd_only, + joint_fwd_bwd, register_replacement, - training_graph, ) log = logging.getLogger(__name__) @@ -513,7 +513,6 @@ def _get_sfdp_patterns(): # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern # gets serialized to a python file and does not require tracing at runtime. assert isinstance(workaround, dict) - training_args = [*args, *workaround.values()] name = pattern.__name__ training_name = ( @@ -522,9 +521,9 @@ def _get_sfdp_patterns(): yield training_name, { "search_fn": pattern, "replace_fn": replacement, - "example_inputs": training_args, - "trace_fn": training_graph, - "pass_dict": patterns, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, "extra_check": extra_check, "scalar_workaround": workaround, } @@ -547,8 +546,8 @@ def _get_sfdp_patterns(): "search_fn": pattern, "replace_fn": replacement, "example_inputs": args, - "trace_fn": inference_graph, - "pass_dict": patterns, + "trace_fn": fwd_only, + "pass_dicts": patterns, "extra_check": extra_check, "scalar_workaround": workaround, } diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 7f01bdcef9a6b..3b6d6d180f86f 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -8,12 +8,7 @@ from torch.utils._mode_utils import no_dispatch from torch.utils._triton import has_triton -from ..pattern_matcher import ( - inference_graph, - Match, - register_replacement, - training_graph, -) +from ..pattern_matcher import fwd_only, joint_fwd_bwd, Match, register_replacement aten = torch.ops.aten @@ -453,12 +448,11 @@ def _pad_mm_init(): ), ]: assert isinstance(workaround, dict) # mypy is unable to infer the type properly - args = [*args, *workaround.values()] register_replacement( pattern, replacement, args, - training_graph, + joint_fwd_bwd, patterns, extra_check=extra_check, scalar_workaround=workaround, @@ -467,7 +461,7 @@ def _pad_mm_init(): pattern, replacement, args, - inference_graph, + fwd_only, patterns, extra_check=extra_check, scalar_workaround=workaround, diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 1defdab5abb13..7b2797f978c6d 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -566,6 +566,11 @@ def cat_noop(inputs, dim=0): return len(inputs) == 1 +@register_noop_decomp(aten.view) +def view_noop(arg, size): + return arg.shape == size + + # Note, we also always have a check for identical metadata, which is why these # are safe @register_noop_decomp([aten.copy], nop_arg=1) @@ -576,9 +581,7 @@ def true_noop(*args, **kwargs): def remove_noop_ops(graph: torch.fx.Graph): """ - Removes aten.clone and aten.alias ops from the graph when it's safe. - - Other no-ops should be done as decompositions that selectively turn into aten.clone or aten.alias + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. """ input_storages = set() output_storages = set() diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index c90d178b549e0..67fb3849ad720 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -124,7 +124,7 @@ def replace_with_graph(self, replacement_graph, args): def replace_by_example(self, replacement_fn, args, trace_fn=None): assert self.ctx if trace_fn is None: - trace_fn = inference_graph + trace_fn = fwd_only replacement = trace_fn( replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) ) @@ -842,7 +842,7 @@ def register_replacement( replace_fn, example_inputs: Iterable[Any], trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - pass_dict, + pass_dicts, extra_check=_return_true, scalar_workaround=(), exclusive_arg_names=(), @@ -857,10 +857,11 @@ def register_replacement( search_fn: traced to give original pattern replace_fn: traced to give replacement graph example_inputs: example inputs for initial trace - trace_fn: inference_graph or training_graph + trace_fn: fwd_only or joint_fwd_bwd pass_dict: dict of passes to register to extra_check: additional check to run on match(using real shapes) """ + argnames = [*inspect.signature(search_fn).parameters.keys()] def check_fn(match: Match): """ @@ -869,17 +870,24 @@ def check_fn(match: Match): Recheck the match with the correct shapes. """ + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. Perhaps one " + f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}" + ) + args = list( torch.fx.map_arg( [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] # type: ignore[has-type] ) ) - for i, grad in enumerate(requires_grad): - if isinstance(args[i], torch.Tensor): - if grad and is_integer_dtype(args[i].dtype): - return False + with torch._dynamo.utils.detect_fake_mode(args): + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False - with torch._dynamo.utils.detect_fake_mode(args): args[i] = torch.empty_strided( args[i].size(), args[i].stride(), @@ -887,27 +895,32 @@ def check_fn(match: Match): device=args[i].device, requires_grad=grad, ) - specific_graph = trace_fn(search_fn, args) - specific_pattern = fx_to_pattern( - specific_graph, argnames=argnames, exclusive_arg_names=exclusive_arg_names # type: ignore[has-type] - ) - specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) - if specific_pattern_match and extra_check(specific_pattern_match): - # trace the pattern using the shapes form the user program - match.replacement_graph = trace_fn(replace_fn, args) - return True - return False + specific_graph = trace_fn(search_fn, args) + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, # type: ignore[has-type] + scalar_workaround=scalar_workaround, + ) + specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) + if specific_pattern_match and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user program + match.replacement_graph = trace_fn(replace_fn, args) + return True + return False def normalize_args(**kwargs): args = [] for name in argnames: # type: ignore[has-type] args.append(kwargs.pop(name)) for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break args.append(kwargs.pop(f"tangents_{i}")) assert not kwargs, f"leftover kwargs: {kwargs!r}" return args - if trace_fn is training_graph: + if trace_fn is joint_fwd_bwd: # If inference mode is enabled during compilation, assume that we don't # want to match on any training graph patterns if torch.is_inference_mode_enabled(): @@ -915,7 +928,6 @@ def normalize_args(**kwargs): # TODO: Revisit the functionalize_rng_ops for lowmem dropout with functorch_config.patch(functionalize_rng_ops=False): - argnames = [*inspect.signature(search_fn).parameters.keys()] requires_grad: List[bool] = [ isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs ] @@ -938,7 +950,7 @@ def normalize_args(**kwargs): extra_check=check_fn, normalize_args=normalize_args, ) - pattern.register(pass_dict) + pattern.register(pass_dicts) return pattern.pattern @@ -947,7 +959,20 @@ def gen_pattern( search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=() ) -> PatternExpr: argnames = [*inspect.signature(search_fn).parameters.keys()] - search_gm = trace_fn(search_fn, example_inputs) + + if scalar_workaround == (): + scalar_workaround = {} + flat_inputs = [] + input_idx = 0 # Positional arguments index + + for argname in argnames: + if argname in scalar_workaround: + flat_inputs.append(scalar_workaround[argname]) + else: + flat_inputs.append(example_inputs[input_idx]) + input_idx += 1 + + search_gm = trace_fn(search_fn, flat_inputs) return fx_to_pattern( search_gm, ignore_types=(int, float, list, torch.device, torch.dtype), @@ -1175,7 +1200,7 @@ def run_node(self, n): @torch.no_grad() -def inference_graph(fn, args) -> torch.fx.GraphModule: +def fwd_only(fn, args) -> torch.fx.GraphModule: """Build a normalized inference graph, for use with fx_to_pattern""" # TODO - look into using aot autograd, asserting no mutating ops here with enable_python_dispatcher(): @@ -1186,7 +1211,7 @@ def inference_graph(fn, args) -> torch.fx.GraphModule: @torch.enable_grad() -def training_graph(fn, args) -> torch.fx.GraphModule: +def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule: """Build a normalized training graph, for use with fx_to_pattern""" gm: Optional[torch.fx.GraphModule] = None From 1460e5b7f5e85de532d18ad0c2ac7d48211100e8 Mon Sep 17 00:00:00 2001 From: agunapal Date: Fri, 27 Oct 2023 21:09:32 +0000 Subject: [PATCH 67/78] updated aarch64 maintainers in docs (#112047) This PR adds a new section for maintainers of `aarch64`. Adding @snadampal to the list Pull Request resolved: https://github.com/pytorch/pytorch/pull/112047 Approved by: https://github.com/atalman --- docs/source/community/persons_of_interest.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index 36975e33f9d06..39aedc7fc3048 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -271,6 +271,11 @@ PowerPC - Alfredo Mendoza (`avmgithub `__) +AArch64 CPU +~~~~~~~~~~~~ + +- Sunita Nadampalli (`snadampal `__) + Docs / Tutorials ~~~~~~~~~~~~~~~~ From 061bf1a153ecb81dbe614499da19d30d9d5e8f95 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 27 Oct 2023 21:26:54 +0000 Subject: [PATCH 68/78] [5/N] Make torch context manager a TorchCtxManagerClassVariable (#111622) Major change in this PR is to make torch context manager class a separate ```TorchCtxManagerClassVariable```, since we have dynamo implementation for these ctx managers. I was thinking to wrap them as ```UserDefinedClassVariable``` and do dispatch at ```USCVariable.call_function```, but it seems almost the same amount of work and this way is more clear. This is on the way of moving ```TorchVariable``` to ```TorchFunctionVariable``` which will only handle the functions who would be allowed in graph (e.g, ```torch.sin```) and constant folded (e.g, ```torch.is_floating_point```). All other torch functions would be go through skip/inline rules, and would be wrapped as ```UserFunctionVariable``` (for inlined) and ```SkipFilesVariable``` (for skipped). The next steps: * Wrap torch modules, classes, objects as regular ```PythonModuleVariable```, ```UserDefinedClassVariable``` and ```UserDefinedObjectVariable```. * Generate the allow in graph torch functions list and wrap them as ```TorchFunctionVariable```. * Finally merge ```skipfiles.check``` and ```is_allowed``` into one function ```allow_skip.check(fn)``` which would return a Enum of allow, skip and inline. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111622 Approved by: https://github.com/jansel --- test/dynamo/test_allow_inline_skip.py | 99 ------ test/dynamo/test_trace_rules.py | 397 ++++++++++++++++++++++++ torch/_dynamo/allowed_functions.py | 3 + torch/_dynamo/trace_rules.py | 74 +++++ torch/_dynamo/utils.py | 8 + torch/_dynamo/variables/__init__.py | 5 +- torch/_dynamo/variables/builder.py | 6 +- torch/_dynamo/variables/builtin.py | 3 + torch/_dynamo/variables/torch.py | 165 ++++++---- torch/_dynamo/variables/user_defined.py | 16 +- 10 files changed, 602 insertions(+), 174 deletions(-) delete mode 100644 test/dynamo/test_allow_inline_skip.py create mode 100644 test/dynamo/test_trace_rules.py create mode 100644 torch/_dynamo/trace_rules.py diff --git a/test/dynamo/test_allow_inline_skip.py b/test/dynamo/test_allow_inline_skip.py deleted file mode 100644 index e65dbdb12dd2d..0000000000000 --- a/test/dynamo/test_allow_inline_skip.py +++ /dev/null @@ -1,99 +0,0 @@ -# Owner(s): ["module: dynamo"] -import importlib -import types -import unittest - -import torch -import torch._dynamo.test_case -from torch._dynamo.skipfiles import ( - FUNC_INLINELIST, - LEGACY_MOD_INLINELIST, - MOD_INLINELIST, -) -from torch._dynamo.utils import istype - -try: - from .utils import create_dummy_module_and_function -except ImportError: - from utils import create_dummy_module_and_function - - -def gen_get_func_inlinelist(dummy_func_inlinelist): - def get_func_inlinelist(): - inlinelist = set() - for f in dummy_func_inlinelist: - module_name, fn_name = f.rsplit(".", 1) - m = importlib.import_module(module_name) - fn = getattr(m, fn_name) - inlinelist.add(fn.__code__) - return inlinelist - - return get_func_inlinelist - - -class AllowInlineSkipTests(torch._dynamo.test_case.TestCase): - # We are using python function and module string names for these inlinelist, - # this unit test is to make sure the functions/modules can be correctly imported - # or loaded in case there is typo in the strings. - def test_skipfiles_inlinelist_correctness(self): - for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST): - self.assertTrue(isinstance(importlib.import_module(m), types.ModuleType)) - for f in FUNC_INLINELIST: - module_name, fn_name = f.rsplit(".", 1) - m = importlib.import_module(module_name) - self.assertTrue(isinstance(getattr(m, fn_name), types.FunctionType)) - - def test_func_inlinelist_torch_function(self): - def fn(x): - if istype(x, torch.Tensor): - return x + 1 - else: - return x - 1 - - func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() - func_inlinelist.add("torch._dynamo.utils.istype") - - self.assertTrue( - "torch._dynamo" not in torch._dynamo.skipfiles.LEGACY_MOD_INLINELIST - ) - self.assertTrue("torch._dynamo" not in torch._dynamo.skipfiles.MOD_INLINELIST) - - with unittest.mock.patch( - "torch._dynamo.skipfiles.get_func_inlinelist", - gen_get_func_inlinelist(func_inlinelist), - ): - x = torch.rand(3) - opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) - ref = fn(x) - res = opt_fn(x) - self.assertEqual(ref, res) - - def test_func_inlinelist_third_party_function(self): - mod, func = create_dummy_module_and_function() - - def fn(x): - return func(x) - - func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() - func_inlinelist.add(f"{mod.__name__}.{func.__name__}") - - with unittest.mock.patch( - "torch._dynamo.skipfiles.get_func_inlinelist", - gen_get_func_inlinelist(func_inlinelist), - ), unittest.mock.patch( - "torch._dynamo.skipfiles.SKIP_DIRS", - torch._dynamo.skipfiles.SKIP_DIRS.copy(), - ): - # First adding the module to SKIP_DIRS so that it will be skipped. - torch._dynamo.skipfiles.add(mod.__name__) - x = torch.rand(3) - opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) - ref = fn(x) - res = opt_fn(x) - self.assertEqual(ref, res) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py new file mode 100644 index 0000000000000..f34e95152db2d --- /dev/null +++ b/test/dynamo/test_trace_rules.py @@ -0,0 +1,397 @@ +# Owner(s): ["module: dynamo"] +import collections +import copy +import importlib +import inspect +import math +import types +import unittest +import warnings + +import torch +import torch._dynamo.config as config +import torch._dynamo.test_case +import torch._functorch.deprecated as deprecated_func +from torch._dynamo.external_utils import is_compiling +from torch._dynamo.skipfiles import ( + FUNC_INLINELIST, + LEGACY_MOD_INLINELIST, + MOD_INLINELIST, +) +from torch._dynamo.trace_rules import get_torch_obj_rule_map, load_object +from torch._dynamo.utils import is_safe_constant, istype +from torch.fx._symbolic_trace import is_fx_tracing + +try: + from .utils import create_dummy_module_and_function +except ImportError: + from utils import create_dummy_module_and_function + + +ignored_torch_name_rule_set = { + "torch.ExcludeDispatchKeyGuard", + "torch._C.DisableTorchFunction", + "torch._C._AutoDispatchBelowAutograd", + "torch._C._DisableAutocast", + "torch._C._DisableFuncTorch", + "torch._C._DisablePythonDispatcher", + "torch._C._DisableTorchDispatch", + "torch._C._EnablePreDispatch", + "torch._C._EnablePythonDispatcher", + "torch._C._EnableTorchFunction", + "torch._C._ExcludeDispatchKeyGuard", + "torch._C._ForceDispatchKeyGuard", + "torch._C._IncludeDispatchKeyGuard", + "torch._C._InferenceMode", + "torch._C._RestorePythonTLSSnapshot", + "torch._C._SetExcludeDispatchKeyGuard", + "torch._C._profiler._RecordFunctionFast", + "torch._decomp.decompositions_for_rng.PhiloxStateTracker", + "torch._subclasses.fake_tensor.FakeTensorMode", + "torch._subclasses.functional_tensor.FunctionalTensorMode", + "torch.ao.nn.sparse.quantized.utils.LinearBlockSparsePattern", + "torch.autograd.anomaly_mode.detect_anomaly", + "torch.autograd.anomaly_mode.set_detect_anomaly", + "torch.autograd.forward_ad._set_fwd_grad_enabled", + "torch.autograd.forward_ad.dual_level", + "torch.autograd.grad_mode._force_original_view_tracking", + "torch.autograd.grad_mode._unsafe_preserve_version_counter", + "torch.autograd.grad_mode.set_multithreading_enabled", + "torch.autograd.graph.saved_tensors_hooks", + "torch.autograd.profiler.emit_itt", + "torch.autograd.profiler.emit_nvtx", + "torch.autograd.profiler_legacy.profile", + "torch.backends.mkl.verbose", + "torch.backends.mkldnn.verbose", + "torch.cpu.StreamContext", + "torch.cuda.StreamContext", + "torch.cuda._DeviceGuard", + "torch.cuda.device", + "torch.cuda.graphs.graph", + "torch.device", # constant folding + "torch.distributed.autograd.context", + "torch.hub._Faketqdm", + "torch.jit._ir_utils._InsertPoint", + "torch.jit._script.RecursiveScriptClass", + "torch.jit.strict_fusion", + "torch.onnx._internal.diagnostics.infra.context.DiagnosticContext", + "torch.onnx._internal.fx.patcher.ONNXTorchPatcher", + "torch.overrides.TorchFunctionMode", + "torch.package.package_exporter.PackageExporter", + "torch.serialization._opener", + "torch.sparse.check_sparse_tensor_invariants", + "torch.utils._contextlib._DecoratorContextManager", + "torch.utils._device.DeviceContext", + "torch.utils._python_dispatch.TorchDispatchMode", + "torch.utils.data.datapipes._decorator.guaranteed_datapipes_determinism", + "torch.utils.data.datapipes._decorator.runtime_validation_disabled", + "torch.utils.data.datapipes.dataframe.dataframes.CaptureLikeMock", + "torch.utils.hooks.RemovableHandle", +} + + +if torch.distributed.is_available(): + ignored_torch_name_rule_set |= { + "torch.distributed.rpc.server_process_global_profiler._server_process_global_profile", + } + + +def gen_get_func_inlinelist(dummy_func_inlinelist): + def get_func_inlinelist(): + inlinelist = set() + for f in dummy_func_inlinelist: + module_name, fn_name = f.rsplit(".", 1) + m = importlib.import_module(module_name) + fn = getattr(m, fn_name) + inlinelist.add(fn.__code__) + return inlinelist + + return get_func_inlinelist + + +def _disallowed_function_ids(): + remove = [ + True, + False, + None, + collections.OrderedDict, + copy.copy, + copy.deepcopy, + inspect.signature, + math.__package__, + torch.__builtins__, + torch.autocast_decrement_nesting, + torch.autocast_increment_nesting, + torch.autograd.grad, + torch.clear_autocast_cache, + torch.cuda.current_device, + torch.cuda.set_device, + torch.distributions.constraints.is_dependent, + torch.distributions.normal.Normal, + torch.inference_mode, + torch.jit.isinstance, + torch.set_anomaly_enabled, + torch.set_autocast_cache_enabled, + torch.set_autocast_cpu_dtype, + torch.set_autocast_cpu_enabled, + torch.set_autocast_enabled, + torch.set_autocast_gpu_dtype, + warnings.warn, + torch._C._dynamo.eval_frame.unsupported, + torch.Tensor.__init__, + ] + + # extract all dtypes from torch + dtypes = [ + obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32)) + ] + remove += dtypes + storage = [ + obj + for obj in torch.__dict__.values() + if isinstance(obj, type(torch.FloatStorage)) + ] + remove += storage + + # Distributed APIs don't work well with torch.compile. + if torch.distributed.is_available(): + remove.extend( + torch.distributed.distributed_c10d.dynamo_unsupported_distributed_c10d_ops + ) + + return {id(x) for x in remove} + + +def generate_allow_list(): + """ + Walk torch.* and get the ids of all the stuff in it + """ + warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed") + torch_object_ids = dict() + torch_objects = set() + + def _is_allowed_module_prefix(obj): + allowed_modules = ("torch", "math") + # torch.nn.modules.rnn is disallowed because these modules internally + # flatten their parameters. This flattening process will call + # Tensor.set_ with a Storage, and Storages cannot be traced with + # AOTAutograd; so we need to graph-break. To ensure this, we inline + # these functions, rather than keep them opaque-ly in the graph. + disallowed_modules = [ + "torch.optim.", + "torch.utils._foreach_utils", # omit the period so we match all the functions in this module + "torch.utils._pytree", + "torch.nn.modules.rnn.", + "torch._dynamo.", + "torch._C._dynamo.", + "torch._inductor.", + "torch._C.inductor.", + "torch.fx.", + "torch.distributed.fsdp.", + "torch.distributed._tensor.", + # Inline through the ActivationWrapper in + # torch.distributed.algorithms._checkpoint.checkpoint_wrapper. This + # nn module calls torch.utils.checkpoint internally. If Dynamo does + # not trace this, AOT Autograd will try to trace this and can cause + # issues observed in + # https://github.com/pytorch/pytorch/issues/108269 + "torch.distributed.algorithms.", + ] + if config.trace_distributed: + disallowed_modules.append("torch.distributed.") + + allowed_modules_dot = tuple([x + "." for x in allowed_modules]) + module = inspect.getmodule(obj) + if module is None: + return False + + mod_name = module.__name__ + + if any(mod_name.startswith(m) for m in disallowed_modules): + return False + + return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot) + + def _find_torch_objects(module): + if any( + module.__name__.startswith(mod_name) + for mod_name in config.allowed_functions_module_string_ignorelist + ): + return + torch_object_ids[id(module)] = module.__name__ + for name, obj in list(module.__dict__.items()): + if id(obj) not in torch_object_ids: + # Dynamo allows all builtins into the graph and does not attempt + # to introspect into them. We don't want to allow instances of + # HigherOrderOperator into the graph all the time (Dynamo needs + # to introspect the body functions of these HigherOrderOperator + # first, decide they are safe, and then allow them into the graph). + # So we exclude HigherOrderOperator from being a builtin. + import torch._ops + + if isinstance(obj, torch._ops.HigherOrderOperator): + continue + + # We want to trace through `grad` and `vmap` + if obj in ( + torch.func.grad, + deprecated_func.grad, + torch.func.vmap, + deprecated_func.vmap, + torch.nn.functional.triplet_margin_with_distance_loss, + torch.cond, + ): + continue + + if isinstance(obj, types.ModuleType): + if obj.__name__.startswith("torch.") and _is_allowed_module_prefix( + obj + ): + torch_object_ids[id(obj)] = f"{module.__name__}.{name}" + _find_torch_objects(obj) + elif _is_allowed_module_prefix(obj): + torch_object_ids[id(obj)] = f"{module.__name__}.{name}" + if ( + issubclass(type(obj), type) + and "__enter__" in obj.__dict__ + and "__exit__" in obj.__dict__ + ): + torch_objects.add(obj) + elif inspect.getmodule(obj) is None and not is_safe_constant(obj): + torch_object_ids[id(obj)] = f"{module.__name__}.{name}" + if ( + issubclass(type(obj), type) + and "__enter__" in obj.__dict__ + and "__exit__" in obj.__dict__ + ): + torch_objects.add(obj) + + _find_torch_objects(torch) + _find_torch_objects(math) + + if config.trace_distributed: + from torch.distributed import _functional_collectives_impl as fci + + for f in [ + fci._all_gather_into_tensor, + fci._all_reduce, + fci._reduce_scatter_tensor, + fci._all_reduce_coalesced, + fci._all_gather_into_tensor_coalesced, + fci._reduce_scatter_tensor_coalesced, + ]: + torch_object_ids[id(f)] = repr(f) + + # torch.Tensor.{fn} + for name in dir(torch.Tensor): + method = getattr(torch.Tensor, name) + if isinstance( + method, (types.MethodDescriptorType, types.WrapperDescriptorType) + ): + torch_object_ids[id(method)] = f"torch.Tensor.{name}" + + for idx in _disallowed_function_ids(): + if idx in torch_object_ids: + del torch_object_ids[idx] + + for extra in (is_fx_tracing, is_compiling): + torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}" + + return torch_objects + + +class TraceRuleTests(torch._dynamo.test_case.TestCase): + # We are using python function and module string names for these inlinelist, + # this unit test is to make sure the functions/modules can be correctly imported + # or loaded in case there is typo in the strings. + def test_skipfiles_inlinelist(self): + for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST): + self.assertTrue( + isinstance(importlib.import_module(m), types.ModuleType), + f"{m} from skipfiles.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.", + ) + for f in FUNC_INLINELIST: + module_name, fn_name = f.rsplit(".", 1) + m = importlib.import_module(module_name) + self.assertTrue( + isinstance(getattr(m, fn_name), types.FunctionType), + f"{f} from skipfiles.FUNC_INLINELIST is not a python function, please check and correct it.", + ) + + def test_torch_name_rule_map(self): + generated_torch_name_rule_set = generate_allow_list() + ignored_torch_obj_rule_set = { + load_object(x) for x in ignored_torch_name_rule_set + } + used_torch_name_rule_set = ( + set(get_torch_obj_rule_map().keys()) | ignored_torch_obj_rule_set + ) + x = generated_torch_name_rule_set - used_torch_name_rule_set + y = used_torch_name_rule_set - generated_torch_name_rule_set + msg1 = ( + f"New torch objects: {x} " + "were not added to trace_rules.torch_name_rule_map or test_trace_rules.ignored_torch_name_rule_set. " + "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details." + ) + msg2 = ( + f"Existing torch objects: {y} were removed. " + "Please remove them from trace_rules.torch_name_rule_map or test_trace_rules.ignored_torch_name_rule_set. " + "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details." + ) + self.assertTrue(len(x) == 0, msg1) + self.assertTrue(len(y) == 0, msg2) + + def test_func_inlinelist_torch_function(self): + def fn(x): + if istype(x, torch.Tensor): + return x + 1 + else: + return x - 1 + + func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() + func_inlinelist.add("torch._dynamo.utils.istype") + + self.assertTrue( + "torch._dynamo" not in torch._dynamo.skipfiles.LEGACY_MOD_INLINELIST + ) + self.assertTrue("torch._dynamo" not in torch._dynamo.skipfiles.MOD_INLINELIST) + + with unittest.mock.patch( + "torch._dynamo.skipfiles.get_func_inlinelist", + gen_get_func_inlinelist(func_inlinelist), + ): + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_func_inlinelist_third_party_function(self): + mod, func = create_dummy_module_and_function() + + def fn(x): + return func(x) + + func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() + func_inlinelist.add(f"{mod.__name__}.{func.__name__}") + + with unittest.mock.patch( + "torch._dynamo.skipfiles.get_func_inlinelist", + gen_get_func_inlinelist(func_inlinelist), + ), unittest.mock.patch( + "torch._dynamo.skipfiles.SKIP_DIRS", + torch._dynamo.skipfiles.SKIP_DIRS.copy(), + ): + # First adding the module to SKIP_DIRS so that it will be skipped. + torch._dynamo.skipfiles.add(mod.__name__) + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py index b620c6d4887c3..df0a3a1c01e71 100644 --- a/torch/_dynamo/allowed_functions.py +++ b/torch/_dynamo/allowed_functions.py @@ -148,6 +148,9 @@ def _disallowed_function_ids() -> Set[int]: return {id(x) for x in remove} +# We are in progress of refactoring and moving the following functions to test_trace_rules.py. +# If you made any change to the following functions, please also update there as well. +# If you are not clear of how to update, please contact @yanboliang. @FunctionIdSet def _allowed_function_ids() -> Dict[int, str]: """ diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py new file mode 100644 index 0000000000000..053ecfaad7a74 --- /dev/null +++ b/torch/_dynamo/trace_rules.py @@ -0,0 +1,74 @@ +import functools +import importlib + +from .utils import hashable + +from .variables import TorchCtxManagerClassVariable + + +""" +Map of torch objects to their tracing rules (Dynamo variables). +* TorchVariable: The functions should be put into the FX graph or can be constant folded. E.g., + - torch.add: should be put into the FX graph. + - torch.is_floating_point: constant folded. +* TorchCtxManagerClassVariable: The context manager classes are supported by Dynamo. E.g., torch.no_grad +* SkipFilesVariable: The objects should be skipped from tracing. +* UserFunctionVariable: The functions should be inlined. + +We explicitly list torch objects which should be wrapped as TorchCtxManagerClassVariable. +The initial list comes from the heuristic in test/dynamo/test_trace_rules.py:generate_allow_list. + +For developers: If you add/remove a torch level API, it may trigger failures from +test/dynamo/test_trace_rules.py:test_torch_name_rule_map. To fix the failures: +If you are adding a new torch level API or Dynamo implementation: +* Add the name with TorchCtxManagerClassVariable to this map + if you are adding Dynamo implementation for that context manager. +* Remove the object name from test/dynamo/test_trace_rules.ignored_torch_name_rule_set if it's there. + +If you are removing an existing torch level API: +* Remove the entry represented the API from this map or test/dynamo/test_trace_rules.ignored_torch_name_rule_set + depends on where it is. + +TODO: Add torch object names mapping to TorchVariable for in graph and constant fold functions. +TODO: We would consolidate the skipfiles.check rules into trace_rules.lookup later. +TODO: We would support explictly list objects treated as skip/inline after the skipfiles.check +and trace_rules.lookup consolidation is done. Then the explicit listing of skip/inline objects have +a higher priority, which can be used to override the skipfiles.check rules in some cases. +""" +torch_name_rule_map = { + "torch._C.DisableTorchFunctionSubclass": TorchCtxManagerClassVariable, + "torch.amp.autocast_mode.autocast": TorchCtxManagerClassVariable, + "torch.autograd.grad_mode.enable_grad": TorchCtxManagerClassVariable, + "torch.autograd.grad_mode.inference_mode": TorchCtxManagerClassVariable, + "torch.autograd.grad_mode.no_grad": TorchCtxManagerClassVariable, + "torch.autograd.grad_mode.set_grad_enabled": TorchCtxManagerClassVariable, + "torch.autograd.profiler.profile": TorchCtxManagerClassVariable, + "torch.autograd.profiler.record_function": TorchCtxManagerClassVariable, + "torch.cpu.amp.autocast_mode.autocast": TorchCtxManagerClassVariable, + "torch.cuda.amp.autocast_mode.autocast": TorchCtxManagerClassVariable, + "torch.profiler.profiler.profile": TorchCtxManagerClassVariable, +} + + +@functools.lru_cache(None) +def get_torch_obj_rule_map(): + d = dict() + for k, v in torch_name_rule_map.items(): + obj = load_object(k) + assert obj not in d + d[obj] = v + return d + + +def load_object(name): + mod_name, obj_name = name.rsplit(".", 1) + mod = importlib.import_module(mod_name) + obj = getattr(mod, obj_name) + return obj + + +def lookup(obj): + if not hashable(obj): + return None + rule = get_torch_obj_rule_map().get(obj, None) + return rule diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 737c1cb9f3dd9..18c87ed1e19e2 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -392,6 +392,14 @@ def identity(x): return x +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + + def nothing(*args, **kwargs): pass diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index bdeddd62947dc..ea21dada5cf71 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -45,6 +45,7 @@ NewGlobalVariable, NumpyVariable, PythonModuleVariable, + SkipFilesVariable, SuperVariable, UnknownVariable, ) @@ -56,7 +57,7 @@ TensorVariable, UnspecializedPythonVariable, ) -from .torch import TorchVariable +from .torch import TorchCtxManagerClassVariable, TorchVariable from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable __all__ = [ @@ -91,8 +92,10 @@ "PythonModuleVariable", "RangeVariable", "SliceVariable", + "SkipFilesVariable", "SuperVariable", "TensorVariable", + "TorchCtxManagerClassVariable", "TorchVariable", "TupleVariable", "UnknownVariable", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f830c695606cd..54778e3dec6ed 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -33,7 +33,7 @@ from torch.fx.immutable_collections import immutable_list from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils.weak import TensorWeakRef, WeakIdRef -from .. import config, mutation_guard, replay_record, skipfiles +from .. import config, mutation_guard, replay_record, skipfiles, trace_rules from ..allowed_functions import ( is_allowed, is_builtin_callable, @@ -724,6 +724,10 @@ def index_source(key): source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) + elif trace_rules.lookup(value) is not None: + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) elif is_allowed(value): if is_user_defined_allowed(value): self.tx.output.has_user_defined_allowed_in_graph = True diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5d18891083ae7..ebc053200e905 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1047,6 +1047,7 @@ def call_reduce(self, tx, function, iterable, initializer=None): def call_getattr( self, tx, obj: VariableTracker, name_var: VariableTracker, default=None ): + from .. import trace_rules from . import ( ConstantVariable, GetAttrVariable, @@ -1142,6 +1143,8 @@ def call_getattr( if is_utils_checkpoint(member): options["source"] = source return build_checkpoint_variable(**options) + elif trace_rules.lookup(member) is not None: + return trace_rules.lookup(member)(member, **options) elif is_allowed(member): return TorchVariable(member, **options) elif ConstantVariable.is_literal(member): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1b4280ba563f5..5ca619e8d8375 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -20,12 +20,12 @@ import torch.nn import torch.onnx.operators from torch._dynamo.variables import UserFunctionVariable -from torch._logging import warning_once from .. import config, variables from ..allowed_functions import torch_get_name from ..device_interface import device_interfaces from ..exc import unimplemented +from ..guards import GuardBuilder from ..utils import ( check_constant_args, check_unspec_python_args, @@ -180,6 +180,95 @@ def check_allowed_op(value): ) +def torch_reconstruct(codegen, value): + name = torch_get_name(value, f"allowed_fn_{id(value)}") + unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name) + return codegen.setup_globally_cached(unique_var_name, value, False) + + +class TorchCtxManagerClassVariable(VariableTracker): + """Points to a context manager class in torch.* that dynamo has implementations""" + + @classmethod + def create_with_source(cls, value, source): + return TorchCtxManagerClassVariable( + value, + source=source, + guards={source.make_guard(GuardBuilder.FUNCTION_MATCH)}, + ) + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def reconstruct(self, codegen): + return torch_reconstruct(codegen, self.value) + + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . import GradModeVariable, InferenceModeVariable, StreamVariable + + options = VariableTracker.propagate(self, args, kwargs.values()) + + if self.value is torch.no_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, False, initialized=False, **options) + return ctx.call_function(tx, args, kwargs) + else: + return GradModeVariable.create(tx, False, **options) + elif self.value is torch.enable_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, True, initialized=False, **options) + return ctx.call_function(tx, args, kwargs) + return GradModeVariable.create(tx, True, **options) + elif self.value is torch.set_grad_enabled and len(args) == 1: + return GradModeVariable.create(tx, args[0].as_python_constant(), **options) + elif self.value is torch.inference_mode: + return InferenceModeVariable.create( + tx, args[0].as_python_constant(), **options + ) + elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase): + return wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + **options, + ) + elif self.value in [ + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ]: + return AutocastModeVariable.create(self.value, args, kwargs) + elif self.value in ( + torch.profiler.profile, + torch.profiler.record_function, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + ): + log.warning("Profiler function %s will be ignored", self.value) + return NullContextVariable(**options) + elif self.value is torch._C.DisableTorchFunctionSubclass: + assert not (args or kwargs) + return TorchFunctionDisableVariable.create(tx, **options) + + class TorchVariable(VariableTracker): """Points to a module or method in torch.*""" @@ -226,12 +315,8 @@ def call_hasattr(self, tx, name): result = hasattr(self.value, name) return variables.ConstantVariable.create(result).add_options(self) - def unique_var_name(self): - name = torch_get_name(self.value, f"allowed_fn_{id(self.value)}") - return "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name) - def reconstruct(self, codegen): - return codegen.setup_globally_cached(self.unique_var_name(), self.value, False) + return torch_reconstruct(codegen, self.value) def as_proxy(self): return self.value @@ -259,9 +344,7 @@ def call_function( DeterministicAlgorithmsVariable, DisabledSavedTensorsHooksVariable, GradModeVariable, - InferenceModeVariable, StreamContextVariable, - StreamVariable, SymNodeVariable, TensorVariable, UserDefinedObjectVariable, @@ -371,23 +454,6 @@ def call_function( torch.nn.modules.utils._ntuple, ): return self._call_ntuple(tx, args, kwargs, options) - elif self.value is torch.no_grad: - if len(args) == 1 and isinstance( - args[0], variables.functions.BaseUserFunctionVariable - ): - ctx = GradModeVariable.create(tx, False, initialized=False, **options) - return ctx.call_function(tx, args, kwargs) - else: - return GradModeVariable.create(tx, False, **options) - elif self.value is torch.enable_grad: - if len(args) == 1 and isinstance( - args[0], variables.functions.BaseUserFunctionVariable - ): - ctx = GradModeVariable.create(tx, True, initialized=False, **options) - return ctx.call_function(tx, args, kwargs) - return GradModeVariable.create(tx, True, **options) - elif self.value is torch.set_grad_enabled and len(args) == 1: - return GradModeVariable.create(tx, args[0].as_python_constant(), **options) elif self.value is torch.is_grad_enabled: assert not (args or kwargs) return ConstantVariable.create( @@ -397,10 +463,6 @@ def call_function( return DeterministicAlgorithmsVariable.create( tx, args[0].as_python_constant(), **options ) - elif self.value is torch.inference_mode: - return InferenceModeVariable.create( - tx, args[0].as_python_constant(), **options - ) elif self.value is torch.are_deterministic_algorithms_enabled: assert not (args or kwargs) return ConstantVariable.create( @@ -416,29 +478,6 @@ def call_function( return ConstantVariable.create( tx.output.torch_function_enabled, **options ).add_guards(TorchFunctionDisableVariable._guards_singleton) - elif self.value is torch._C.DisableTorchFunctionSubclass: - assert not (args or kwargs) - return TorchFunctionDisableVariable.create(tx, **options) - elif any( - self.value is method - for method in [ - interface_elem.stream for interface_elem in device_interfaces.values() - ] - ): - assert len(args) == 1 - return StreamContextVariable.create(tx, args[0], **options) - elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase): - return wrap_fx_proxy_cls( - StreamVariable, - tx, - tx.output.create_proxy( - "call_function", - self.value, - (), - {}, - ), - **options, - ) elif self.value in ( torch.overrides.has_torch_function_variadic, torch.overrides.has_torch_function_unary, @@ -447,6 +486,14 @@ def call_function( return ConstantVariable.create( any(has_torch_function(a) for a in args), **options ) + elif any( + self.value is method + for method in [ + interface_elem.stream for interface_elem in device_interfaces.values() + ] + ): + assert len(args) == 1 + return StreamContextVariable.create(tx, args[0], **options) elif self.value is torch.from_numpy: if not config.trace_numpy: unimplemented("torch.from_numpy. config.trace_numpy is False") @@ -474,20 +521,6 @@ def call_function( unimplemented(f"torch.from_numpy(<{type(t)}>)") elif can_dispatch_torch_function(tx, args, kwargs): return dispatch_torch_function(tx, self, args, kwargs) - elif self.value in [ - torch.amp.autocast_mode.autocast, - torch.cuda.amp.autocast, - torch.cpu.amp.autocast, - ]: - return AutocastModeVariable.create(self.value, args, kwargs) - elif self.value in ( - torch.profiler.profile, - torch.profiler.record_function, - torch.autograd.profiler.profile, - torch.autograd.profiler.record_function, - ): - warning_once(log, "Profiler function %s will be ignored", self.value) - return NullContextVariable(**options) elif self.value is torch.autograd._profiler_enabled: unimplemented("torch.autograd._profiler_enabled not supported yet") elif self.value is torch.jit.annotate: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0f9d2283aa932..c6f7a27c3e364 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -326,6 +326,7 @@ def is_supported_random(self): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .. import trace_rules from .builder import VariableBuilder if ( @@ -347,12 +348,13 @@ def call_function( obj = self.value.__self__ if ( func is torch.utils._contextlib._DecoratorContextManager.clone - and is_allowed(obj.__class__) + and trace_rules.lookup(obj.__class__) + == variables.TorchCtxManagerClassVariable and not (args or kwargs) ): - return variables.TorchVariable(obj.__class__).call_function( - tx, args, kwargs - ) + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, args, kwargs) if ( func is torch.autograd.grad_mode.inference_mode.clone @@ -360,9 +362,9 @@ def call_function( ): # simulate the inference_mode.clone implementation var = variables.ConstantVariable(obj.mode) - return variables.TorchVariable(obj.__class__).call_function( - tx, [var], kwargs - ) + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, [var], kwargs) elif ( istype(self.value, functools.partial) and is_allowed(self.value.func) From 089e7aa4ac97a6557b488e4d4020dc10656d91ef Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 27 Oct 2023 21:32:30 +0000 Subject: [PATCH 69/78] Revert "[dynamo] `ExecutorchCallDelegateHigherOrderVariable` - add sanity check that input and output tensors are disjoint (#111960)" This reverts commit 27cf49549a35dd78475098b7de02c0a5ab1367ea. Reverted https://github.com/pytorch/pytorch/pull/111960 on behalf of https://github.com/izaitsevfb due to Fails internal executorch tests with module 'torch.utils._pytree' has no attribute 'tree_flatten_only' ([comment](https://github.com/pytorch/pytorch/pull/111960#issuecomment-1783532843)) --- torch/_dynamo/variables/higher_order_ops.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 776561e1eb84b..77833bff5a59d 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -82,18 +82,6 @@ def only_consist_of(var, types): return False -def _assert_tensors_nonaliasing(inputs, outputs): - input_tensor_ids = set( - pytree.tree_flatten_only(torch.Tensor, lambda t: id(t), inputs) - ) - output_tensor_ids = set( - pytree.tree_flatten_only(torch.Tensor, lambda t: id(t), outputs) - ) - assert input_tensor_ids.isdisjoint( - output_tensor_ids - ), "inputs to function body cannot alias outputs" - - def validate_args_and_maybe_create_graph_inputs( sub_args, tracer, tx, manually_set_subgraph_inputs ): @@ -720,14 +708,7 @@ def call_function( real_sub_args = pytree.tree_map_only( torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args ) - example_res = lowered_module.original_module(*real_sub_args) - - # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: - # executorch modules promise not to alias inputs and outputs. - # Thus, output FakeTensors will correctly not alias input FakeTensors. - _assert_tensors_nonaliasing(real_sub_args, example_res) - example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) p_args = (lowered_node,) + p_args From c67236a05de55e04eb3e9c8a1087bf13b19324ac Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 27 Oct 2023 21:37:48 +0000 Subject: [PATCH 70/78] Revert "[dynamo] Be stricter about `HigherOrderOperator` kwargs (#111938)" This reverts commit edafe2ddb99dd721021262fdfd58c3f796c7da0c. Reverted https://github.com/pytorch/pytorch/pull/111938 on behalf of https://github.com/izaitsevfb due to Fails meta internal executorch tests with `torch._dynamo.exc.InternalTorchDynamoError: name 'p_kwargs' is not defined` ([comment](https://github.com/pytorch/pytorch/pull/111938#issuecomment-1783538268)) --- torch/_dynamo/variables/higher_order_ops.py | 45 ++++++++++----------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 77833bff5a59d..056702ec17096 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -363,6 +363,12 @@ def make(value, source=None, **kwargs): else: unimplemented(f"HigherOrderOperator {value.__name__}") + def check_kwargs(self, kwargs, supported_types): + if not all(isinstance(value, supported_types) for value in kwargs.values()): + raise unimplemented( + f"Only kwargs of the following types are supported: {supported_types}" + ) + def call_function( self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] ) -> VariableTracker: @@ -610,11 +616,6 @@ def call_function( ) from .builder import wrap_fx_proxy - if len(kwargs) > 0: - unsupported( - "torch.ops.higher_order.map: kwargs are not supported in the map operator." - ) - assert type(args[0]) in (UserFunctionVariable, NestedUserFunctionVariable) assert type(args[1]) is TensorVariable @@ -671,6 +672,8 @@ def call_function( r = body_r.as_proxy().node.meta["example_value"] example_value = r.new_empty([sample_shape[0], *r.shape]) + _, p_kwargs = proxy_args_kwargs([], kwargs) + # Store the invocation as a call return wrap_fx_proxy( tx=tx, @@ -678,7 +681,7 @@ def call_function( "call_function", self.value, args=tuple(p_args), - kwargs={}, + kwargs=p_kwargs, ), example_value=example_value, ) @@ -688,18 +691,17 @@ class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from . import ConstantVariable from .builder import wrap_fx_proxy + self.check_kwargs(kwargs, ConstantVariable) + # This is operator for delegation within Executorch which calls a # specific function in the given lowered module with the given # operators. The actual operator is defined in the Executorch codebase. # This is a bad hierarchical violation since # executorch_call_delegate sits at a higher level than dynamo, but # there's no real solution to this issue yet. - if len(kwargs) > 0: - unimplemented( - "executorch_call_delegate: kwargs arguments were not enabled." - ) lowered_module = tx.output.get_submodule(args[0].module_key) lowered_node = make_attr(tx, args[0].module_key) @@ -713,6 +715,8 @@ def call_function( p_args = (lowered_node,) + p_args + _, p_kwargs = proxy_args_kwargs([], kwargs) + # Store the invocation as a call return wrap_fx_proxy( tx=tx, @@ -930,7 +934,7 @@ def call_function( if not isinstance(out_dims, (ConstantVariable, TupleVariable)): unimplemented("torch.func.vmap: out_dims is not an int or tuple variable.") - if len(kwargs) > 0: + if kwargs: unimplemented( "NYI - torch.func.vmap: kwargs arguments are currently unsupported." ) @@ -1082,15 +1086,12 @@ def __init__( def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": - from . import UserFunctionVariable + from . import ConstantVariable, UserFunctionVariable from .builder import wrap_fx_proxy tracer = self.fwd_bwd_tracer - if len(kwargs) > 0: - unimplemented( - "kwargs have not been implemented for torch.autograd.Function" - ) + self.check_kwargs(kwargs, ConstantVariable) from . import TorchVariable @@ -1147,6 +1148,8 @@ def call_function( r = body_r.as_proxy().node.meta["example_value"] example_value = r + _, p_kwargs = proxy_args_kwargs([], kwargs) + # Store the invocation as a call return wrap_fx_proxy( tx=tx, @@ -1154,7 +1157,7 @@ def call_function( "call_function", self.value, args=tuple(p_args), - kwargs={}, + kwargs=p_kwargs, ), example_value=example_value, ) @@ -1210,14 +1213,10 @@ def call_function( ) -> "VariableTracker": from .builder import wrap_fx_proxy - # This flattens the kwargs into lifted args p_args, p_kwargs, example_value, treespec = self.create_wrapped_node( tx, args, kwargs, "wrap" ) - if len(p_kwargs) > 0: - unimplemented("kwargs should have been flattened into lifted args") - # Store the invocation as a call variable = wrap_fx_proxy( tx=tx, @@ -1225,7 +1224,7 @@ def call_function( "call_function", self.value, args=tuple(p_args), - kwargs={}, + kwargs=p_kwargs, ), example_value=example_value, ) @@ -1247,7 +1246,7 @@ def call_function( ) -> "VariableTracker": from .builder import wrap_fx_proxy - if len(kwargs) > 0: + if len(kwargs) != 0: unimplemented("out_dtype does not handle kwargs") p_args = tuple(arg.as_proxy() for arg in args) From dbb31a2984fa616b4bb6fac7abb2a06ec0533eb1 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 27 Oct 2023 00:48:33 -0700 Subject: [PATCH 71/78] [Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids (#112228) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112228 Approved by: https://github.com/jansel --- test/dynamo/test_functions.py | 43 ++++++++++++++++++++ torch/_dynamo/variables/builder.py | 6 ++- torch/_dynamo/variables/functions.py | 17 +++++++- torch/_inductor/codegen/wrapper.py | 28 ++++++------- torch/_inductor/ir.py | 10 ++++- torch/_inductor/triton_heuristics.py | 60 ++++++++++++++++++++++++---- 6 files changed, 140 insertions(+), 24 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 8837c4cc13985..f19b689783daf 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1402,6 +1402,30 @@ def add_kernel( output = x + y tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), + ], + key=[], + ) + @triton.jit + def add_kernel_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + @triton.jit def mul2_kernel( in_ptr0, @@ -1987,6 +2011,25 @@ def call_triton( # reset back CONSTANT_C = prev_c + @requires_cuda() + @requires_triton() + @common_utils.parametrize("grad", [False, True]) + @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) + def test_triton_kernel_autotune(self, grad, backend): + def call_triton(x: torch.Tensor, y: torch.Tensor): + output = torch.zeros_like(x, requires_grad=grad) + n_elements = output.numel() + grid = (n_elements,) + add_kernel_autotuned[grid](x, y, output, n_elements) + return output + + t1 = torch.rand(5, device="cuda", requires_grad=grad) + t2 = torch.rand(5, device="cuda", requires_grad=grad) + + torch_add = t1 + t2 + compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True) + self.assertEqual(compiled_func(t1, t2), torch_add) + @requires_cuda() @requires_triton() @common_utils.parametrize("grad", [False, True]) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 54778e3dec6ed..5773aeca6acc5 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -362,12 +362,16 @@ def _wrap(self, value): from torch.utils._triton import has_triton if has_triton(): + from triton.runtime.autotuner import Autotuner from triton.runtime.jit import JITFunction else: class JITFunction: pass + class Autotuner: + pass + make_guards = self.make_guards # Handle exact type() match @@ -716,7 +720,7 @@ def index_source(key): sym_node_proxy, new_symint == 1, ) - elif isinstance(value, JITFunction): + elif isinstance(value, (JITFunction, Autotuner)): return TritonKernelVariable( value, None, # No kernel idx provided diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 1bd06a0727ab1..138b82aeed741 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -652,10 +652,12 @@ def get_val(v): class TritonKernelVariable(VariableTracker): def __init__(self, kernel, kernel_idx, grid, **kwargs): - super().__init__(**kwargs) + from triton.runtime.autotuner import Autotuner from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + super().__init__(**kwargs) + assert kernel is not None self.kernel = kernel @@ -665,6 +667,19 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs): self.grid = grid + if isinstance(kernel, Autotuner): + # We only support configs and keys arguments of triton.autotune + # Make sure other arguments are defaulted + defaults = inspect.signature(Autotuner).parameters + if ( + defaults["warmup"].default != kernel.warmup + or defaults["rep"].default != kernel.rep + or defaults["prune_configs_by"].default != kernel.early_config_prune + ): + raise Unsupported( + "Only configs and keys are supported for triton.autotune" + ) + def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 8bef2105f6266..3674c6f817357 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -778,7 +778,7 @@ def get_unique_kernel_name(self, name: str) -> str: self.user_defined_kernel_count += 1 return new_name - def define_user_defined_triton_kernel(self, name, kernel, kwargs): + def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs): original_name = kernel.__name__ compile_wrapper = IndentedBuffer() compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") @@ -788,26 +788,18 @@ def define_user_defined_triton_kernel(self, name, kernel, kwargs): import triton import triton.language as tl from torch._inductor.utils import instance_descriptor - from torch._inductor.triton_heuristics import template + from torch._inductor.triton_heuristics import user_autotune """, strip=True, ) compile_wrapper.newline() - # TODO(oulgen): num_stages and num_warps are default values of - # triton.Config. Can we do better? Or ask the user to provide? - num_stages = 2 - num_warps = 4 - from ..ir import Buffer from .common import SizeArg, TensorArg signature: List[Union[TensorArg, SizeArg]] = [] constants = {} for key, arg in kwargs.items(): - # Not a real argument - if key == "grid": - continue if ( key in kernel.__annotations__ and "constexpr" in kernel.__annotations__[key] @@ -829,12 +821,20 @@ def define_user_defined_triton_kernel(self, name, kernel, kwargs): "configs": [config_of(signature)], "kernel_name": name, } + configs = [ + { + "kwargs": config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + for config in configs + ] compile_wrapper.splice( f""" - @template( - num_stages={num_stages}, - num_warps={num_warps}, - meta={triton_meta!r} + @user_autotune( + configs={configs!r}, + meta={triton_meta!r}, + filename=__file__ ) @triton.jit """ diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 7be3836e0d75e..4e20ce362b0c0 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3768,9 +3768,15 @@ def apply_constraint(self): class UserDefinedTritonKernel(ExternKernel): def codegen(self, wrapper): + from triton.runtime.autotuner import Autotuner + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + if isinstance(kernel, Autotuner): + configs = kernel.configs + kernel = kernel.fn new_name = wrapper.get_unique_kernel_name(kernel.__name__) self.codegen_comment(wrapper) @@ -3779,7 +3785,9 @@ def codegen(self, wrapper): self.grid, self.codegen_kwargs(), ) - wrapper.define_user_defined_triton_kernel(new_name, kernel, self.kwargs) + wrapper.define_user_defined_triton_kernel( + new_name, kernel, configs, self.kwargs + ) def should_allocate(self): return False diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py index be062e8eb3cbe..218d97d2ab7d6 100644 --- a/torch/_inductor/triton_heuristics.py +++ b/torch/_inductor/triton_heuristics.py @@ -12,7 +12,7 @@ import re import threading from enum import auto, Enum -from typing import Any, Callable, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import torch @@ -62,6 +62,7 @@ class HeuristicType(Enum): REDUCTION = auto() PERSISTENT_REDUCTION = auto() TEMPLATE = auto() + USER_AUTOTUNE = auto() class AutotuneHint(Enum): @@ -344,7 +345,7 @@ def launcher({', '.join(def_args)}, grid, stream): return binary, launcher - def bench(self, launcher, *args, grid): + def bench(self, launcher, *args, grid, **kwargs): """Measure the performance of a given launcher""" if launcher.n_spills > config.triton.spill_threshold: log.debug( @@ -362,16 +363,17 @@ def kernel_call(): {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} ) - cloned_args = self.clone_args(*args) + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) launcher( *cloned_args, + **cloned_kwargs, grid=grid, stream=stream, ) return do_bench(kernel_call, rep=40, fast_flush=True) - def clone_args(self, *args): + def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: from .compile_fx import clone_preserve_strides # clone inplace buffers to avoid autotune contaminating them if @@ -385,7 +387,15 @@ def clone_args(self, *args): else: cloned_args.append(arg) - return cloned_args + cloned_kwargs: Dict[str, Any] = {} + for name, arg in kwargs.items(): + if name in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_kwargs[name] = clone_preserve_strides(arg) + else: + cloned_kwargs[name] = arg + + return cloned_args, cloned_kwargs @dynamo_timed def benchmark_all_configs(self, *args, **kwargs): @@ -451,11 +461,14 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): Then if coordinate descnt tuning is run with max-autotune disabled, it will start from C1; while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. """ - if self.heuristic_type == HeuristicType.TEMPLATE: + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): # skip triton template return launcher - cloned_args = self.clone_args(*args) + cloned_args, _ = self.clone_args(*args) config2launcher = {launcher.config: launcher} def benchmark_one_config(config): @@ -1130,6 +1143,39 @@ def template(num_stages, num_warps, meta, filename=None): ) +def user_autotune(configs, meta, filename=None): + """ + Compile a user defined triton kernel + """ + defaults = inspect.signature(triton.Config).parameters + default_num_stages = defaults["num_stages"].default + default_num_warps = defaults["num_warps"].default + + if len(configs) == 0: + configs = [ + triton.Config( + {}, num_stages=default_num_stages, num_warps=default_num_warps + ) + ] + else: + configs = [ + triton.Config( + c.get("kwargs", {}), + num_stages=c.get("num_stages", default_num_stages), + num_warps=c.get("num_warps", default_num_warps), + ) + for c in configs + ] + + return cached_autotune( + None, + configs, + meta=meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + ) + + def foreach(meta, num_warps, filename=None): """ Compile a triton foreach kernel From 3d2041b34210bef3902f6ba86881b38ac0fbc57e Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 27 Oct 2023 09:18:23 -0700 Subject: [PATCH 72/78] [inductor] Fix bug handling output_strides in fx graph cache (#112041) Summary: The current implementation is not properly attaching output strides to the tracing context when an fx graph is loaded from the cache. That bugs leads to assertion failures like `AssertionError: expected size 3==3, stride 1==9 at dim=1`. This change saves the output strides in the serialized object cached on disk and inserts them into the tracing context whether the graph is loaded from cache or compiled. Test Plan: * New unit test using resnet18 (which repros the problem) * Ran the timm benchmark suite with `--training` Differential Revision: [D50756653](https://our.internmc.facebook.com/intern/diff/D50756653) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112041 Approved by: https://github.com/ezyang --- test/inductor/test_codecache.py | 46 +++++++++++++++++++++++---------- torch/_inductor/codecache.py | 1 + torch/_inductor/compile_fx.py | 32 +++++++++++++++++------ 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 1a30576449edb..28f7fdefffa99 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -60,6 +60,19 @@ def test_codecache_fork(): _run_codecache_test("fork") +class MyModelConv2d(torch.nn.Module): + def __init__(self, dim=512): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False) + self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False) + + def forward(self, x): + x = self.conv1(x) + torch._dynamo.graph_break() + x = self.conv2(x) + return x + + @instantiate_parametrized_tests class TestFxGraphCache(TestCase): @classmethod @@ -124,33 +137,40 @@ def fn(x, y): @requires_triton() @config.patch({"fx_graph_cache": True}) @parametrize("device", ("cuda", "cpu")) - @parametrize("dtype", (torch.float32, torch.bfloat16)) + @parametrize("dtype", (torch.float32, torch.float16)) def test_cache_load_model(self, device, dtype): """ Verify that we can populate and load models from the cache. """ if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("requires CUDA") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: - raise unittest.SkipTest("requires SM80 or later") - model = MyModel().to(dtype=dtype, device=device) + def fn(mod, x): + mod.zero_grad() + mod(x).sum().backward() + return [p.grad for p in mod.parameters()] - a = torch.rand(10, 10, dtype=dtype, device=device) + compiled_fn = torch.compile(fn, dynamic=False) - compiled_model = torch.compile(model, dynamic=False) + mod = MyModelConv2d().to(device=device, dtype=dtype) + inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype) - # A first call shold miss in the cache. - self.assertEqual(model(a), compiled_model(a)) - self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + # The first call should see all cache misses. + counters.clear() + grads1 = compiled_fn(mod, inp) + self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) - # A second call should hit. (First reset so in-memory guards + # The second should see all hits. (First reset so in-memory guards # don't prevent compilation). + counters.clear() torch._dynamo.reset() - self.assertEqual(model(a), compiled_model(a)) - self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) - self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + grads2 = compiled_fn(mod, inp) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0) + + # And the results should be the same. + self.assertEqual(grads1, grads2) class TestFxGraphCacheHashing(TestCase): diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 9193781b8ae35..428555b3e0d52 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -659,6 +659,7 @@ class CompiledFxGraph: mutated_inputs: Set[str] = field(default_factory=set) mutated_input_idxs: Set[int] = field(default_factory=set) constants: Dict[str, torch.Tensor] = field(default_factory=dict) + output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None _boxed_call: Optional[bool] = None diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b3b59f88e76e8..9cfd4cde1ef41 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -8,7 +8,17 @@ import warnings from itertools import count -from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Tuple, + Union, +) from unittest import mock from functorch.compile import min_cut_rematerialization_partition @@ -384,6 +394,12 @@ def compile_fx_inner( log.debug("FX codegen and compilation took %.3fs", time.time() - start) + # Return the output strides to the caller via TracingContext + context = torch._guards.TracingContext.get() + if context is not None and context.output_strides is not None: + assert len(context.output_strides) == 0 + context.output_strides.extend(compiled_graph.output_strides) + if aot_mode: return compiled_graph @@ -582,20 +598,19 @@ def fx_codegen_and_compile( ) with V.set_graph_handler(graph): graph.run(*example_inputs) - context = torch._guards.TracingContext.get() - if context is not None and context.output_strides is not None: - # Return the output strides to the caller via TracingContext - assert len(context.output_strides) == 0 - assert graph.graph_outputs is not None + output_strides: List[Optional[Tuple[int, ...]]] = [] + if graph.graph_outputs is not None: + # We'll put the output strides in the compiled graph so we + # can later return them to the caller via TracingContext for out in graph.graph_outputs: if hasattr(out, "layout"): - context.output_strides.append( + output_strides.append( tuple( # type: ignore[arg-type] V.graph.sizevars.size_hint(s) for s in out.layout.stride ) ) else: - context.output_strides.append(None) + output_strides.append(None) compiled_fn = graph.compile_to_fn() @@ -615,6 +630,7 @@ def fx_codegen_and_compile( mutated_inputs=graph.mutated_inputs, mutated_input_idxs=set(graph.mutated_input_idxs), constants=graph.constants, + output_strides=output_strides, ) return compiled_graph From 128f4db77ebb0037523559b59c4531ad0b505d03 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Fri, 27 Oct 2023 23:08:38 +0000 Subject: [PATCH 73/78] A small fix in "do_bench_using_profiling" (#112223) This is a small fix in "do_bench_using_profiling()". When CUDA kernels are executed in a non-default CUDA stream, if cuda.synchronize() is called, a CUDA kernel named "Context Sync" will be launched to the default stream to wait until all other streams are finished. This CUDA kernel has "CUDA time" but is not a real kernel to profile. This fix excludes "Context Sync" when calculating kernel total time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112223 Approved by: https://github.com/int3, https://github.com/chenyang78 --- torch/_inductor/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a3b2136e0131e..f443ec7bfef85 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -100,7 +100,11 @@ def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) filtered_events = EventList( - [event for event in p.events() if event.device_type == DeviceType.CUDA] + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] ) if len(filtered_events) % n_repeat != 0: raise RuntimeError( From 46a6435203c8db55c07df08e62c3b3ce154b0b23 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 27 Oct 2023 23:53:32 +0000 Subject: [PATCH 74/78] Make numpy/lib vendored tests dynamo traceable (#112147) Follow up https://github.com/pytorch/pytorch/pull/112146 and #112141 : make numpy/lib vendored tests dynamo traceable Pull Request resolved: https://github.com/pytorch/pytorch/pull/112147 Approved by: https://github.com/lezcano --- pytest.ini | 2 + .../torch_np/numpy_tests/lib/test_arraypad.py | 26 +- .../numpy_tests/lib/test_arraysetops.py | 58 ++-- .../numpy_tests/lib/test_function_base.py | 274 ++++++++++++------ .../numpy_tests/lib/test_histograms.py | 73 +++-- .../numpy_tests/lib/test_index_tricks.py | 47 ++- .../numpy_tests/lib/test_shape_base_.py | 77 +++-- .../numpy_tests/lib/test_twodim_base.py | 102 ++++--- .../numpy_tests/lib/test_type_check.py | 54 +++- 9 files changed, 481 insertions(+), 232 deletions(-) diff --git a/pytest.ini b/pytest.ini index 67a691290076d..532e3bce098f3 100644 --- a/pytest.ini +++ b/pytest.ini @@ -13,3 +13,5 @@ testpaths = junit_logging_reruns = all filterwarnings = ignore:Module already imported so cannot be rewritten.*hypothesis:pytest.PytestAssertRewriteWarning + +xfail_strict = True diff --git a/test/torch_np/numpy_tests/lib/test_arraypad.py b/test/torch_np/numpy_tests/lib/test_arraypad.py index 54745e8316d51..befa9d76ac467 100644 --- a/test/torch_np/numpy_tests/lib/test_arraypad.py +++ b/test/torch_np/numpy_tests/lib/test_arraypad.py @@ -1,15 +1,27 @@ # Owner(s): ["module: dynamo"] -from unittest import expectedFailure as xfail, skipIf as skipif +from unittest import skipIf as skipif -import torch._numpy as np -from torch._numpy.testing import assert_allclose, assert_array_equal +from torch.testing._internal.common_utils import ( + run_tests, + TEST_WITH_TORCHDYNAMO, + TestCase, + xpassIfTorchDynamo, +) -from torch.testing._internal.common_utils import run_tests, TestCase + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy.testing import assert_allclose, assert_array_equal +else: + import torch._numpy as np + from torch._numpy.testing import assert_allclose, assert_array_equal class TestConstant(TestCase): - @xfail # (reason="tuple values") + @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant(self): a = np.arange(100) a = np.pad(a, (25, 20), "constant", constant_values=(10, 20)) @@ -357,7 +369,7 @@ def test_check_constant_float2(self): ) assert_allclose(test, expected) - @xfail # (reason="tuple values") + @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant_float3(self): a = np.arange(100, dtype=float) a = np.pad(a, (25, 20), "constant", constant_values=(-1.1, -1.2)) @@ -528,7 +540,7 @@ def test_check_constant_odd_pad_amount(self): ) assert_allclose(test, expected) - @xfail # (reason="tuple values") + @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant_pad_2d(self): arr = np.arange(4).reshape(2, 2) test = np.lib.pad( diff --git a/test/torch_np/numpy_tests/lib/test_arraysetops.py b/test/torch_np/numpy_tests/lib/test_arraysetops.py index 0f9773ece6dfa..e046558078591 100644 --- a/test/torch_np/numpy_tests/lib/test_arraysetops.py +++ b/test/torch_np/numpy_tests/lib/test_arraysetops.py @@ -3,24 +3,39 @@ """Test functions for 1D array set operations. """ -from unittest import expectedFailure as xfail +from unittest import skipIf -import torch._numpy as np -from pytest import raises as assert_raises - -from torch._numpy import unique +import numpy -from torch._numpy.testing import assert_array_equal, assert_equal +from pytest import raises as assert_raises from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + subtest, + TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, + xpassIfTorchDynamo, ) -@xfail # (reason="TODO") +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ediff1d, in1d, intersect1d, setdiff1d, setxor1d, union1d, unique + from numpy.testing import assert_array_equal, assert_equal, assert_raises_regex + +else: + import torch._numpy as np + from torch._numpy import unique + from torch._numpy.testing import assert_array_equal, assert_equal + + +@skipIf(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") +@xpassIfTorchDynamo # (reason="TODO") @instantiate_parametrized_tests class TestSetOps(TestCase): def test_intersect1d(self): @@ -145,11 +160,14 @@ def test_ediff1d(self): (np.array([1, 2, 3], dtype=np.int64), None, np.nan, "to_end"), # should fail because attempting # to downcast to int type: - ( - np.array([1, 2, 3], dtype=np.int64), - np.array([5, 7, 2], dtype=np.float32), - None, - "to_begin", + subtest( + ( + np.array([1, 2, 3], dtype=np.int64), + np.array([5, 7, 2], dtype=np.float32), + None, + "to_begin", + ), + decorators=[xfailIfTorchDynamo], ), # should fail because attempting to cast # two special floating point values @@ -205,6 +223,7 @@ def test_ediff1d_scalar_handling(self, ary, prepend, append, expected): assert_equal(actual, expected) assert actual.dtype == expected.dtype + @skipIf(True, reason="NP_VER: fails with NumPy 1.22.x") @parametrize("kind", [None, "sort", "table"]) def test_isin(self, kind): # the tests for in1d cover most of isin's behavior @@ -217,7 +236,7 @@ def _isin_slow(a, b): isin_slow = np.vectorize(_isin_slow, otypes=[bool], excluded={1}) def assert_isin_equal(a, b): - x = isin(a, b, kind=kind) + x = np.isin(a, b, kind=kind) y = isin_slow(a, b) assert_array_equal(x, y) @@ -444,7 +463,7 @@ def test_in1d_table_timedelta_fails(self): a = np.array([0, 1, 2], dtype="timedelta64[s]") b = a # Make sure it raises a value error: - with pytest.raises(ValueError): + with assert_raises(ValueError): in1d(a, b, kind="table") @parametrize( @@ -475,7 +494,7 @@ def test_in1d_mixed_dtype(self, dtype1, dtype2, kind): ) if expect_failure: - with pytest.raises(RuntimeError, match="exceed the maximum"): + with assert_raises(RuntimeError, match="exceed the maximum"): in1d(ar1, ar2, kind=kind) else: assert_array_equal(in1d(ar1, ar2, kind=kind), expected) @@ -744,7 +763,7 @@ def check_all(a, b, i1, i2, c, dt): # assert_equal(a3_idx.dtype, np.intp) # assert_equal(a3_inv.dtype, np.intp) - @xfail # (reason="unique with nans") + @xpassIfTorchDynamo # (reason="unique with nans") def test_unique_1d_2(self): # test for ticket 2111 - float a = [2.0, np.nan, 1.0, np.nan] @@ -790,7 +809,7 @@ def test_unique_axis_list(self): assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg) assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg) - @xfail # _run_axis_tests xfails with the message + @xpassIfTorchDynamo # _run_axis_tests xfails with the message # torch has different unique ordering behaviour" def test_unique_axis(self): types = [] @@ -816,7 +835,7 @@ def test_unique_1d_with_axis(self, axis): uniq = unique(x, axis=axis) assert_array_equal(uniq, [1, 2, 3, 4]) - @xfail # (reason="unique / return_index") + @xpassIfTorchDynamo # (reason="unique / return_index") def test_unique_axis_zeros(self): # issue 15559 single_zero = np.empty(shape=(2, 0), dtype=np.int8) @@ -923,7 +942,8 @@ def _run_axis_tests(self, dtype): msg = "Unique's return_counts=True failed with axis=1" assert_array_equal(cnt, np.array([2, 1, 1]), msg) - @xfail # (reason="unique / return_index / nans") + @skipIf(True, reason="NP_VER: fails on CI with older NumPy") + @xpassIfTorchDynamo # (reason="unique / return_index / nans") def test_unique_nanequals(self): # issue 20326 a = np.array([1, 1, np.nan, np.nan, np.nan]) diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index 3934613a64fc4..0c38df1d12c03 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -11,29 +11,21 @@ import hypothesis import hypothesis.strategies as st -import pytest -import torch._numpy as np +import numpy + +import pytest from hypothesis.extra.numpy import arrays from pytest import raises as assert_raises -from torch._numpy.testing import ( - assert_, - assert_allclose, # IS_PYPY, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - assert_raises_regex, - assert_warns, - suppress_warnings, # HAS_REFCOUNT, IS_WASM -) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, subtest, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) skip = functools.partial(skipif, True) @@ -47,25 +39,79 @@ # from numpy lib import digitize, piecewise, trapz, select, trim_zeros, interp from numpy.lib import delete, extract, insert, msort, place, setxor1d, unwrap, vectorize -from torch._numpy import ( - angle, - bartlett, - blackman, - corrcoef, - cov, - diff, - flipud, - gradient, - hamming, - hanning, - i0, - kaiser, - meshgrid, - sinc, - unique, -) -from torch._numpy._util import normalize_axis_tuple -from torch._numpy.random import rand + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + angle, + bartlett, + blackman, + corrcoef, + cov, + diff, + digitize, + flipud, + gradient, + hamming, + hanning, + i0, + interp, + kaiser, + meshgrid, + sinc, + trapz, + trim_zeros, + unique, + ) + from numpy.core.numeric import normalize_axis_tuple + from numpy.random import rand + + from numpy.testing import ( + assert_, + assert_allclose, # IS_PYPY, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + assert_warns, + suppress_warnings, # HAS_REFCOUNT, IS_WASM + ) +else: + import torch._numpy as np + from torch._numpy import ( + angle, + bartlett, + blackman, + corrcoef, + cov, + diff, + flipud, + gradient, + hamming, + hanning, + i0, + kaiser, + meshgrid, + sinc, + unique, + ) + from torch._numpy._util import normalize_axis_tuple + from torch._numpy.random import rand + + from torch._numpy.testing import ( + assert_, + assert_allclose, # IS_PYPY, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + assert_warns, + suppress_warnings, # HAS_REFCOUNT, IS_WASM + ) def get_mat(n): @@ -251,7 +297,7 @@ def test_basic(self): assert_equal(a[0, 0], 1) assert_equal(a_copy[0, 0], 10) - @xfail # (reason="order='F' not implemented") + @xpassIfTorchDynamo # (reason="order='F' not implemented") def test_order(self): # It turns out that people rely on np.copy() preserving order by # default; changing this broke scikit-learn: @@ -290,6 +336,7 @@ def test_basic(self): assert_almost_equal(y5.mean(0), np.average(y5, 0)) assert_almost_equal(y5.mean(1), np.average(y5, 1)) + @skip(reason="NP_VER: fails on CI") @parametrize( "x, axis, expected_avg, weights, expected_wavg, expected_wsum", [ @@ -323,6 +370,7 @@ def test_basic_keepdims( assert wsum.shape == np.shape(expected_wsum) assert_array_equal(wsum, expected_wsum) + @skip(reason="NP_VER: fails on CI") def test_weights(self): y = np.arange(10) w = np.arange(10) @@ -477,7 +525,7 @@ def test_many_arguments(self): select(conditions, choices) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") @instantiate_parametrized_tests class TestInsert(TestCase): def test_basic(self): @@ -562,6 +610,7 @@ def test_index_floats(self): with pytest.raises(IndexError): np.insert([0, 1, 2], np.array([], dtype=float), []) + @skip(reason="NP_VER: fails on CI") @parametrize("idx", [4, -4]) def test_index_out_of_bounds(self, idx): with pytest.raises(IndexError, match="out of bounds"): @@ -795,7 +844,7 @@ def test_append(self): assert_raises(np.AxisError, diff, x, append=0, axis=3) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") @instantiate_parametrized_tests class TestDelete(TestCase): def setUp(self): @@ -867,7 +916,9 @@ def test_index_floats(self): with pytest.raises(IndexError): np.delete([0, 1, 2], np.array([], dtype=float)) - @parametrize("indexer", [np.array([1]), [1]]) + @parametrize( + "indexer", [subtest(np.array([1]), name="array([1])"), subtest([1], name="[1]")] + ) def test_single_item_array(self, indexer): a_del_int = delete(self.a, 1) a_del = delete(self.a, indexer) @@ -1142,7 +1193,7 @@ def test_basic(self): assert_array_almost_equal(z, zo, 11) -@xfail # (reason="trim_zeros not implemented") +@xpassIfTorchDynamo @instantiate_parametrized_tests class TestTrimZeros(TestCase): a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) @@ -1151,7 +1202,11 @@ class TestTrimZeros(TestCase): # d = a.astype(object) def values(self): - attr_names = ("a", "b", "c", "d") + attr_names = ( + "a", + "b", + "c", + ) # "d") return (getattr(self, name) for name in attr_names) def test_basic(self): @@ -1210,7 +1265,7 @@ def test_list_to_list(self): assert isinstance(res, list) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestExtins(TestCase): def test_basic(self): a = np.array([1, 3, 2, 1, 2, 3, 3]) @@ -1612,7 +1667,7 @@ def test_size_zero_output(self): f(x) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestDigitize(TestCase): def test_forward(self): x = np.arange(-6, 5) @@ -1716,7 +1771,9 @@ def test_period(self): @instantiate_parametrized_tests class TestFilterwindows(TestCase): - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_hanning(self, dtype: str, M: int) -> None: scalar = M @@ -1736,7 +1793,9 @@ def test_hanning(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.500, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_hamming(self, dtype: str, M: int) -> None: scalar = M @@ -1756,7 +1815,9 @@ def test_hamming(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.9400, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_bartlett(self, dtype: str, M: int) -> None: scalar = M @@ -1776,7 +1837,9 @@ def test_bartlett(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 4.4444, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_blackman(self, dtype: str, M: int) -> None: scalar = M @@ -1796,7 +1859,9 @@ def test_blackman(self, dtype: str, M: int) -> None: else: assert_almost_equal(np.sum(w, axis=0), 3.7800, 4) - @parametrize("dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]) + @parametrize( + "dtype", "Bbhil" + "efd" + ) # np.typecodes["AllInteger"] + np.typecodes["Float"]) @parametrize("M", [0, 1, 10]) def test_kaiser(self, dtype: str, M: int) -> None: scalar = M @@ -1817,7 +1882,7 @@ def test_kaiser(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 10, 15) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestTrapz(TestCase): def test_simple(self): x = np.arange(-10, 10, 0.1) @@ -1886,13 +1951,13 @@ def test_simple(self): assert_(unique(np.array([1, 1, 1, 1, 1])) == np.array([1])) - @xfail # (reason="unique not implemented for 'ComplexDouble'") + @xpassIfTorchDynamo # (reason="unique not implemented for 'ComplexDouble'") def test_simple_complex(self): x = np.array([5 + 6j, 1 + 1j, 1 + 10j, 10, 5 + 6j]) assert_(np.all(unique(x) == [1 + 1j, 1 + 10j, 5 + 6j, 10])) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestCheckFinite(TestCase): def test_simple(self): a = [1, 2, 3] @@ -2537,7 +2602,19 @@ def test_error_not_1d(self, vals): np.bincount(vals) -@xfail # (reason="TODO: implement") +parametrize_interp_sc = parametrize( + "sc", + [ + subtest(lambda x: np.float_(x), name="real"), + subtest(lambda x: _make_complex(x, 0), name="complex-real"), + subtest(lambda x: _make_complex(0, x), name="complex-imag"), + subtest(lambda x: _make_complex(x, np.multiply(x, -2)), name="complex-both"), + ], +) + + +@xpassIfTorchDynamo # (reason="TODO: implement") +@instantiate_parametrized_tests class TestInterp(TestCase): def test_exceptions(self): assert_raises(ValueError, interp, 0, [], []) @@ -2612,19 +2689,7 @@ def test_non_finite_behavior_exact_x(self): fp = [1, 2, np.nan, 4] assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.nan, np.nan, 4]) - @pytest.fixture( - params=[ - lambda x: np.float_(x), - lambda x: _make_complex(x, 0), - lambda x: _make_complex(0, x), - lambda x: _make_complex(x, np.multiply(x, -2)), - ], - ids=["real", "complex-real", "complex-imag", "complex-both"], - ) - def sc(self, request): - """scale function used by the below tests""" - return request.param - + @parametrize_interp_sc def test_non_finite_any_nan(self, sc): """test that nans are propagated""" assert_equal(np.interp(0.5, [np.nan, 1], sc([0, 10])), sc(np.nan)) @@ -2632,6 +2697,7 @@ def test_non_finite_any_nan(self, sc): assert_equal(np.interp(0.5, [0, 1], sc([np.nan, 10])), sc(np.nan)) assert_equal(np.interp(0.5, [0, 1], sc([0, np.nan])), sc(np.nan)) + @parametrize_interp_sc def test_non_finite_inf(self, sc): """Test that interp between opposite infs gives nan""" assert_equal(np.interp(0.5, [-np.inf, +np.inf], sc([0, 10])), sc(np.nan)) @@ -2641,6 +2707,7 @@ def test_non_finite_inf(self, sc): # unless the y values are equal assert_equal(np.interp(0.5, [-np.inf, +np.inf], sc([10, 10])), sc(10)) + @parametrize_interp_sc def test_non_finite_half_inf_xf(self, sc): """Test that interp where both axes have a bound at inf gives nan""" assert_equal(np.interp(0.5, [-np.inf, 1], sc([-np.inf, 10])), sc(np.nan)) @@ -2652,6 +2719,7 @@ def test_non_finite_half_inf_xf(self, sc): assert_equal(np.interp(0.5, [0, +np.inf], sc([0, -np.inf])), sc(np.nan)) assert_equal(np.interp(0.5, [0, +np.inf], sc([0, +np.inf])), sc(np.nan)) + @parametrize_interp_sc def test_non_finite_half_inf_x(self, sc): """Test interp where the x axis has a bound at inf""" assert_equal(np.interp(0.5, [-np.inf, -np.inf], sc([0, 10])), sc(10)) @@ -2659,6 +2727,7 @@ def test_non_finite_half_inf_x(self, sc): assert_equal(np.interp(0.5, [0, +np.inf], sc([0, 10])), sc(0)) assert_equal(np.interp(0.5, [+np.inf, +np.inf], sc([0, 10])), sc(0)) + @parametrize_interp_sc def test_non_finite_half_inf_f(self, sc): """Test interp where the f axis has a bound at inf""" assert_equal(np.interp(0.5, [0, 1], sc([0, -np.inf])), sc(-np.inf)) @@ -2739,6 +2808,7 @@ def test_period(self): @instantiate_parametrized_tests class TestPercentile(TestCase): + @skip(reason="NP_VER: fails on CI; no method=") def test_basic(self): x = np.arange(8) * 0.5 assert_equal(np.percentile(x, 0), 0.0) @@ -2786,7 +2856,8 @@ def test_2D(self): x = np.array([[1, 1, 1], [1, 1, 1], [4, 4, 3], [1, 1, 1], [1, 1, 1]]) assert_array_equal(np.percentile(x, 50, axis=0), [1, 1, 1]) - @xfail # (reason="TODO: implement") + @skip(reason="NP_VER: fails on CI; no method=") + @xpassIfTorchDynamo # (reason="TODO: implement") @parametrize("dtype", np.typecodes["Float"]) def test_linear_nan_1D(self, dtype): # METHOD 1 of H&F @@ -2796,14 +2867,14 @@ def test_linear_nan_1D(self, dtype): np.testing.assert_equal(res.dtype, arr.dtype) H_F_TYPE_CODES = [ - (int_type, np.float64) for int_type in np.typecodes["AllInteger"] + (int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"] ] + [ (np.float16, np.float16), (np.float32, np.float32), (np.float64, np.float64), ] - @xfail # (reason="TODO: implement percentile interpolations") + @skip(reason="NEP 50 is new in 1.24") @parametrize("input_dtype, expected_dtype", H_F_TYPE_CODES) @parametrize( "method, expected", @@ -2821,7 +2892,11 @@ def test_linear_nan_1D(self, dtype): ) def test_linear_interpolation(self, method, expected, input_dtype, expected_dtype): expected_dtype = np.dtype(expected_dtype) - if np._get_promotion_state() == "legacy": + + if ( + hasattr(np, "_get_promotion_state") + and np._get_promotion_state() == "legacy" + ): expected_dtype = np.promote_types(expected_dtype, np.float64) arr = np.asarray([15.0, 20.0, 35.0, 40.0, 50.0], dtype=input_dtype) @@ -2836,11 +2911,13 @@ def test_linear_interpolation(self, method, expected, input_dtype, expected_dtyp TYPE_CODES = np.typecodes["AllInteger"] + np.typecodes["Float"] + @skip(reason="NP_VER: fails on CI; no method=") @parametrize("dtype", TYPE_CODES) def test_lower_higher(self, dtype): assert_equal(np.percentile(np.arange(10, dtype=dtype), 50, method="lower"), 4) assert_equal(np.percentile(np.arange(10, dtype=dtype), 50, method="higher"), 5) + @skip(reason="NP_VER: fails on CI; no method=") @parametrize("dtype", TYPE_CODES) def test_midpoint(self, dtype): assert_equal( @@ -2856,6 +2933,7 @@ def test_midpoint(self, dtype): np.percentile(np.arange(11, dtype=dtype), 50, method="midpoint"), 5 ) + @skip(reason="NP_VER: fails on CI; no method=") @parametrize("dtype", TYPE_CODES) def test_nearest(self, dtype): assert_equal(np.percentile(np.arange(10, dtype=dtype), 51, method="nearest"), 5) @@ -2874,6 +2952,7 @@ def test_sequence(self): x = np.arange(8) * 0.5 assert_equal(np.percentile(x, [0, 100, 50]), [0, 3.5, 1.75]) + @skip(reason="NP_VER: fails on CI") def test_axis(self): x = np.arange(12).reshape(3, 4) @@ -2912,6 +2991,7 @@ def test_axis(self): np.percentile(x, (25, 50, 75), axis=1, method="higher").shape, (3, 3, 5, 6) ) + @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails with NumPy 1.21.2 on CI") def test_scalar_q(self): # test for no empty dimensions for compatibility with old percentile x = np.arange(12).reshape(3, 4) @@ -2963,6 +3043,7 @@ def test_scalar_q_2(self): assert_equal(c, r1) assert_equal(out, r1) + @skip(reason="NP_VER: fails on CI; no method=") def test_exception(self): assert_raises( (RuntimeError, ValueError), np.percentile, [1, 2], 56, method="foobar" @@ -2979,6 +3060,7 @@ def test_exception(self): def test_percentile_list(self): assert_equal(np.percentile([1, 2, 3], 0), 1) + @skip(reason="NP_VER: fails on CI; no method=") def test_percentile_out(self): x = np.array([1, 2, 3]) y = np.zeros((3,)) @@ -3019,6 +3101,7 @@ def test_percentile_out(self): assert_equal(c, r1) assert_equal(out, r1) + @skip(reason="NP_VER: fails on CI; no method=") def test_percentile_empty_dim(self): # empty dims are preserved d = np.arange(11 * 2).reshape(11, 1, 2, 1) @@ -3060,6 +3143,7 @@ def test_percentile_no_overwrite(self): np.percentile(a, [50]) assert_equal(a, np.array([2, 3, 4, 1])) + @skip(reason="NP_VER: fails on CI; no method=") def test_no_p_overwrite(self): p = np.linspace(0.0, 100.0, num=5) np.percentile(np.arange(100.0), p, method="midpoint") @@ -3076,7 +3160,7 @@ def test_percentile_overwrite(self): b = np.percentile([2, 3, 4, 1], [50], overwrite_input=True) assert_equal(b, np.array([2.5])) - @xfail # (reason="pytorch percentile does not support tuple axes.") + @xpassIfTorchDynamo # (reason="pytorch percentile does not support tuple axes.") def test_extended_axis(self): o = np.random.normal(size=(71, 23)) x = np.dstack([o] * 10) @@ -3165,6 +3249,7 @@ def test_keepdims_2(self): np.percentile(d, [1, 7], axis=(0, 3), keepdims=True).shape, (2, 1, 5, 7, 1) ) + @skipif(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") @parametrize( "q", [ @@ -3172,7 +3257,7 @@ def test_keepdims_2(self): subtest( [1, 7], decorators=[ - xfail, + xpassIfTorchDynamo, ], ), ], @@ -3186,13 +3271,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( (-3, -1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), ], @@ -3213,6 +3298,7 @@ def test_keepdims_out(self, q, axis): assert result is out assert_equal(result.shape, shape_out) + @skip(reason="NP_VER: fails on CI; no method=") def test_out(self): o = np.zeros((4,)) d = np.ones((3, 4)) @@ -3227,6 +3313,7 @@ def test_out(self): assert_equal(np.percentile(d, 2, out=o), o) assert_equal(np.percentile(d, 2, method="nearest", out=o), o) + @skip(reason="NP_VER: fails on CI; no method=") def test_out_nan(self): with warnings.catch_warnings(record=True): warnings.filterwarnings("always", "", RuntimeWarning) @@ -3242,7 +3329,8 @@ def test_out_nan(self): assert_equal(np.percentile(d, 1, out=o), o) assert_equal(np.percentile(d, 1, method="nearest", out=o), o) - @xfail # (reason="np.percentile undocumented nan weirdness") + @skip(reason="NP_VER: fails on CI; no method=") + @xpassIfTorchDynamo # (reason="np.percentile undocumented nan weirdness") def test_nan_behavior(self): a = np.arange(24, dtype=float) a[2] = np.nan @@ -3335,7 +3423,7 @@ def test_basic(self): assert_equal(np.quantile(x, 1), 3.5) assert_equal(np.quantile(x, 0.5), 1.75) - @xfail # (reason="quantile w/integers or bools") + @xpassIfTorchDynamo # (reason="quantile w/integers or bools") def test_correct_quantile_value(self): a = np.array([True]) tf_quant = np.quantile(True, False) @@ -3382,6 +3470,7 @@ def test_complex(self): arr_c = np.array([0.5 + 3.0j, 2.1 + 0.5j, 1.6 + 2.3j], dtype="F") assert_raises(TypeError, np.quantile, arr_c, 0.5) + @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails with NumPy 1.21.2 on CI") def test_no_p_overwrite(self): # this is worth retesting, because quantile does not make a copy p0 = np.array([0, 0.75, 0.25, 0.5, 1.0]) @@ -3394,62 +3483,64 @@ def test_no_p_overwrite(self): np.quantile(np.arange(100.0), p, method="midpoint") assert_array_equal(p, p0) - @xfail # (reason="TODO: make quantile preserve integers") - @parametrize("dtype", np.typecodes["AllInteger"]) + @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails with NumPy 1.21.2 on CI") + @xpassIfTorchDynamo # (reason="TODO: make quantile preserve integers") + @parametrize("dtype", "Bbhil") # np.typecodes["AllInteger"]) def test_quantile_preserve_int_type(self, dtype): res = np.quantile(np.array([1, 2], dtype=dtype), [0.5], method="nearest") assert res.dtype == dtype + @skipif(numpy.__version__ < "1.22", reason="NP_VER: fails with NumPy 1.21.2 on CI") @parametrize( "method", [ subtest( "inverted_cdf", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "averaged_inverted_cdf", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "closest_observation", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "interpolated_inverted_cdf", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "hazen", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "weibull", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), "linear", subtest( "median_unbiased", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( "normal_unbiased", decorators=[ - xfail, + xpassIfTorchDynamo, ], ), "nearest", @@ -3517,7 +3608,7 @@ def test_basic(self): a = np.array([0.0444502, 0.141249, 0.0463301]) assert_equal(a[-1], np.median(a)) - @xfail # (reason="median: scalar output vs 0-dim") + @xpassIfTorchDynamo # (reason="median: scalar output vs 0-dim") def test_basic_2(self): # check array scalar result a = np.array([0.0444502, 0.141249, 0.0463301]) @@ -3626,7 +3717,7 @@ def test_nan_behavior(self): b[1, 2] = np.nan assert_equal(np.median(a, 1), b) - @xfail # (reason="median: does not support tuple axes") + @xpassIfTorchDynamo # (reason="median: does not support tuple axes") def test_nan_behavior_2(self): a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan @@ -3638,7 +3729,7 @@ def test_nan_behavior_2(self): b[2] = np.nan assert_equal(np.median(a, (0, 2)), b) - @xfail # (reason="median: scalar vs 0-dim") + @xpassIfTorchDynamo # (reason="median: scalar vs 0-dim") def test_nan_behavior_3(self): a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan @@ -3647,7 +3738,7 @@ def test_nan_behavior_3(self): # no axis assert_equal(np.median(a).ndim, 0) - @xfail # (reason="median: torch.quantile does not handle empty tensors") + @xpassIfTorchDynamo # (reason="median: torch.quantile does not handle empty tensors") @skipif(IS_WASM, reason="fp errors don't work correctly") def test_empty(self): # mean(empty array) emits two warnings: empty slice and divide by 0 @@ -3678,7 +3769,7 @@ def test_empty(self): assert_equal(np.median(a, axis=2), b) assert_(w[0].category is RuntimeWarning) - @xfail # (reason="median: tuple axes not implemented") + @xpassIfTorchDynamo # (reason="median: tuple axes not implemented") def test_extended_axis(self): o = np.random.normal(size=(71, 23)) x = np.dstack([o] * 10) @@ -3728,7 +3819,7 @@ def test_keepdims(self): d = np.ones((3, 5, 7, 11)) assert_equal(np.median(d, axis=None, keepdims=True).shape, (1, 1, 1, 1)) - @xfail # (reason="median: tuple axis") + @xpassIfTorchDynamo # (reason="median: tuple axis") def test_keepdims_2(self): d = np.ones((3, 5, 7, 11)) assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape, (1, 1, 7, 11)) @@ -3737,6 +3828,7 @@ def test_keepdims_2(self): assert_equal(np.median(d, axis=(0, 1, 2, 3), keepdims=True).shape, (1, 1, 1, 1)) assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape, (1, 1, 7, 1)) + @skipif(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") @parametrize( "axis", [ @@ -3746,13 +3838,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), subtest( (-3, -1), decorators=[ - xfail, + xpassIfTorchDynamo, ], ), ], @@ -3772,7 +3864,7 @@ def test_keepdims_out(self, axis): assert_equal(result.shape, shape_out) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") @instantiate_parametrized_tests class TestSortComplex(TestCase): @parametrize( diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index 9d6b0364fc2d2..4b09ef5b207b4 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -3,32 +3,46 @@ # from numpy.testing._private.utils import requires_memory import functools -from unittest import expectedFailure as xfail, skipIf +from unittest import skipIf -import pytest -import torch._numpy as np from pytest import raises as assert_raises -from torch._numpy import histogram, histogramdd - -# from numpy.lib.histograms import histogram, histogramdd, histogram_bin_edges -from torch._numpy.testing import ( - assert_, - assert_allclose, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, - # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, -) + +skip = functools.partial(skipIf, True) + from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, slowTest as slow, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) -skip = functools.partial(skipIf, True) +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import histogram, histogram_bin_edges, histogramdd + from numpy.testing import ( + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, + ) +else: + import torch._numpy as np + from torch._numpy import histogram, histogramdd + from torch._numpy.testing import ( + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + # assert_array_max_ulp, #assert_raises_regex, suppress_warnings, + ) class TestHistogram(TestCase): @@ -189,7 +203,7 @@ def test_weights(self): ) assert_almost_equal(a, [0.2, 0.1, 0.1, 0.075]) - @xfail # (reason="histogram complex weights") + @xpassIfTorchDynamo # (reason="histogram complex weights") def test_exotic_weights(self): # Test the use of weights that are not integer or floats, but e.g. # complex numbers or object types. @@ -251,7 +265,7 @@ def test_invalid_range(self): with assert_raises((RuntimeError, ValueError)): np.histogram(vals, range=[0.1, 0.01]) - @xfail # (reason="edge cases") + @xpassIfTorchDynamo # (reason="edge cases") def test_bin_edge_cases(self): # Ensure that floating-point computations correctly place edge cases. arr = np.array([337, 404, 739, 806, 1007, 1811, 2012]) @@ -275,7 +289,7 @@ def test_bin_array_dims(self): with assert_raises((RuntimeError, ValueError)): np.histogram(vals, bins=bins) - @xfail # (reason="no uint64") + @xpassIfTorchDynamo # (reason="no uint64") def test_unsigned_monotonicity_check(self): # Ensures ValueError is raised if bins not increasing monotonically # when bins contain unsigned values (see #9222) @@ -301,7 +315,7 @@ def test_object_array_of_0d(self): np.histogram([np.array(0.5) for i in range(10)] + [0.500000000000001]) np.histogram([np.array(0.5) for i in range(10)] + [0.5]) - @xfail # (reason="bins='auto'") + @xpassIfTorchDynamo # (reason="bins='auto'") def test_some_nan_values(self): # gh-7503 one_nan = np.array([0, 1, np.nan]) @@ -339,7 +353,7 @@ def test_signed_overflow_bounds(self): self.do_signed_overflow_bounds(np.short) self.do_signed_overflow_bounds(np.intc) - @xfail # (reason="int->float conversin loses precision") + @xpassIfTorchDynamo # (reason="int->float conversin loses precision") def test_signed_overflow_bounds_2(self): self.do_signed_overflow_bounds(np.int_) self.do_signed_overflow_bounds(np.longlong) @@ -382,14 +396,14 @@ def do_precision(self, float_small, float_large): self.do_precision_lower_bound(float_small, float_large) self.do_precision_upper_bound(float_small, float_large) - @xfail # (reason="mixed dtypes") + @xpassIfTorchDynamo # (reason="mixed dtypes") def test_precision(self): # not looping results in a useful stack trace upon failure self.do_precision(np.half, np.single) self.do_precision(np.half, np.double) self.do_precision(np.single, np.double) - @xfail # (reason="histogram_bin_edges") + @xpassIfTorchDynamo # (reason="histogram_bin_edges") def test_histogram_bin_edges(self): hist, e = histogram([1, 2, 3, 4], [1, 2]) edges = histogram_bin_edges([1, 2, 3, 4], [1, 2]) @@ -405,7 +419,7 @@ def test_histogram_bin_edges(self): assert_array_equal(edges, e) # @requires_memory(free_bytes=1e10) - @xfail # (reason="pytorch does not support bins = [int, int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") @slow def test_big_arrays(self): sample = np.zeros([100000000, 3]) @@ -416,7 +430,7 @@ def test_big_arrays(self): assert_equal(type(hist), type((1, 2))) -@xfail # (reason="TODO") +@xpassIfTorchDynamo # (reason="TODO") @instantiate_parametrized_tests class TestHistogramOptimBinNums(TestCase): """ @@ -698,7 +712,6 @@ def test_simple_weighted(self): """ Check that weighted data raises a TypeError """ - pytest.xpass(reason="passes by chance") estimator_list = ["fd", "scott", "rice", "sturges", "auto"] for estimator in estimator_list: assert_raises(TypeError, histogram, [1, 2, 3], estimator, weights=[1, 2, 3]) @@ -840,13 +853,13 @@ def test_bins_errors(self): (RuntimeError, ValueError), np.histogramdd, x, bins=[1, 1, 1, [1, 2, 3, -3]] ) - @xfail # (reason="pytorch does not support bins = [int, int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") def test_bins_error_2(self): # mixing scalar (# of bins) and explicit bin arrays, ugh x = np.arange(8).reshape(2, 4) assert_(np.histogramdd(x, bins=[1, 1, 1, [1, 2, 3, 4]])) - @xfail # (reason="pytorch does not support bins = [int, int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, int, array]") def test_inf_edges(self): # Test using +/-inf bin edges works. See #1788. x = np.arange(6).reshape(3, 2) @@ -897,7 +910,7 @@ def test_finite_range(self): range=[[0.0, 1.0], [np.nan, 0.75], [0.25, 0.5]], ) - @xfail # (reason="pytorch does not allow equal entries") + @xpassIfTorchDynamo # (reason="pytorch does not allow equal entries") def test_equal_edges(self): """Test that adjacent entries in an edge array can be equal""" x = np.array([0, 1, 2]) @@ -928,7 +941,7 @@ def test_edge_dtype(self): def test_large_integers(self): big = 2**60 # Too large to represent with a full precision float - x = np.array([0], np.int64) + x = np.asarray([0], dtype=np.int64) x_edges = np.array([-1, +1], np.int64) y = big + x y_edges = big + x_edges diff --git a/test/torch_np/numpy_tests/lib/test_index_tricks.py b/test/torch_np/numpy_tests/lib/test_index_tricks.py index d3aac7663ec2e..e43e33be03946 100644 --- a/test/torch_np/numpy_tests/lib/test_index_tricks.py +++ b/test/torch_np/numpy_tests/lib/test_index_tricks.py @@ -4,29 +4,52 @@ from unittest import expectedFailure as xfail, skipIf -import torch._numpy as np - from pytest import raises as assert_raises # , assert_raises_regex, -from torch._numpy import diag_indices, diag_indices_from, fill_diagonal, index_exp, s_ -from torch._numpy.testing import ( - assert_, - assert_almost_equal, - assert_array_almost_equal, - assert_array_equal, - assert_equal, -) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) skip = functools.partial(skipIf, True) -@xfail # (reason="unravel_index not implemented") +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import diag_indices, diag_indices_from, fill_diagonal, index_exp, s_ + from numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_raises_regex, + ) +else: + import torch._numpy as np + from torch._numpy import ( + diag_indices, + diag_indices_from, + fill_diagonal, + index_exp, + s_, + ) + from torch._numpy.testing import ( + assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_equal, + ) + + +@xpassIfTorchDynamo # (reason="unravel_index not implemented") @instantiate_parametrized_tests class TestRavelUnravelIndex(TestCase): def test_basic(self): @@ -428,7 +451,7 @@ def test_repeated_input(self): class TestC(TestCase): - @xfail # (reason="c_ not implemented") + @xpassIfTorchDynamo # (reason="c_ not implemented") def test_c_(self): a = np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])] assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]]) diff --git a/test/torch_np/numpy_tests/lib/test_shape_base_.py b/test/torch_np/numpy_tests/lib/test_shape_base_.py index 673d1ed0b537e..728f756f9e999 100644 --- a/test/torch_np/numpy_tests/lib/test_shape_base_.py +++ b/test/torch_np/numpy_tests/lib/test_shape_base_.py @@ -5,34 +5,62 @@ from unittest import expectedFailure as xfail, skipIf as skipif -import torch._numpy as np - from pytest import raises as assert_raises -from torch._numpy import ( - array_split, - column_stack, - dsplit, - dstack, - expand_dims, - hsplit, - kron, - put_along_axis, - split, - take_along_axis, - tile, - vsplit, -) -from torch._numpy.random import rand, randint - -from torch._numpy.testing import assert_, assert_array_equal, assert_equal from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, + xpassIfTorchDynamo, ) + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + apply_along_axis, + array_split, + column_stack, + dsplit, + dstack, + expand_dims, + hsplit, + kron, + put_along_axis, + split, + take_along_axis, + tile, + vsplit, + ) + from numpy.random import rand, randint + + from numpy.testing import assert_, assert_array_equal, assert_equal + +else: + import torch._numpy as np + from torch._numpy import ( + array_split, + column_stack, + dsplit, + dstack, + expand_dims, + hsplit, + kron, + put_along_axis, + split, + take_along_axis, + tile, + vsplit, + ) + from torch._numpy.random import rand, randint + from torch._numpy.testing import assert_, assert_array_equal, assert_equal + + skip = functools.partial(skipif, True) @@ -126,7 +154,7 @@ def test_replace_max(self): assert_equal(i_min, i_max) - @xfail # ( + @xpassIfTorchDynamo # ( # reason="RuntimeError: Expected index [1, 2, 5] to be smaller than self [3, 4, 1] apart from dimension 1") def test_broadcast(self): """Test that non-indexing dimensions are broadcast in both directions""" @@ -136,7 +164,7 @@ def test_broadcast(self): assert_equal(take_along_axis(a, ai, axis=1), 20) -@xfail # (reason="apply_along_axis not implemented") +@xpassIfTorchDynamo # (reason="apply_along_axis not implemented") class TestApplyAlongAxis(TestCase): def test_simple(self): a = np.ones((20, 10), "d") @@ -679,6 +707,8 @@ def test_basic(self): assert_equal(res.ndim, 0) assert type(res) is np.ndarray + @xfailIfTorchDynamo + def test_basic_2(self): aa = np.ones((3, 1, 4, 1, 1)) assert aa.squeeze().tensor._base is aa.tensor @@ -712,7 +742,7 @@ def test_squeeze_contiguous(self): assert_(a.flags.f_contiguous) assert_(b.flags.f_contiguous) - @xfail # (reason="XXX: noop in torch, while numpy raises") + @xpassIfTorchDynamo # (reason="XXX: noop in torch, while numpy raises") def test_squeeze_axis_handling(self): with assert_raises(ValueError): np.squeeze(np.array([[1], [2], [3]]), axis=0) @@ -749,6 +779,7 @@ def test_basic(self): k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) assert_array_equal(np.kron(a, b), k) + @skip(reason="NP_VER: fails on CI") @parametrize( "shape_a,shape_b", [ @@ -810,7 +841,7 @@ def test_kroncompare(self): assert_equal(large, klarge) -@xfail # (reason="TODO: implement") +@xpassIfTorchDynamo # (reason="TODO: implement") class TestMayShareMemory(TestCase): def test_basic(self): d = np.ones((50, 60)) diff --git a/test/torch_np/numpy_tests/lib/test_twodim_base.py b/test/torch_np/numpy_tests/lib/test_twodim_base.py index bbf9fd1bbc5cd..dda807b556369 100644 --- a/test/torch_np/numpy_tests/lib/test_twodim_base.py +++ b/test/torch_np/numpy_tests/lib/test_twodim_base.py @@ -8,40 +8,72 @@ from unittest import expectedFailure as xfail, skipIf as skipif import pytest - -import torch._numpy as np from pytest import raises as assert_raises -from torch._numpy import ( - arange, - array, - diag, - eye, - fliplr, - flipud, - histogram2d, - ones, - tri, # mask_indices, - tril_indices, - tril_indices_from, - triu_indices, - triu_indices_from, - vander, - zeros, -) -from torch._numpy.testing import ( - assert_allclose, - assert_array_almost_equal, - assert_array_equal, # assert_array_max_ulp, - assert_equal, -) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_TORCHDYNAMO, TestCase, + xpassIfTorchDynamo, ) + +# If we are going to trace through these, we should use NumPy +# If testing on eager mode, we use torch._numpy +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + arange, + array, + diag, + eye, + fliplr, + flipud, + histogram2d, + ones, + tri, # mask_indices, + tril_indices, + tril_indices_from, + triu_indices, + triu_indices_from, + vander, + zeros, + ) + from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, # assert_array_max_ulp, + assert_equal, + ) +else: + import torch._numpy as np + from torch._numpy import ( + arange, + array, + diag, + eye, + fliplr, + flipud, + histogram2d, + ones, + tri, # mask_indices, + tril_indices, + tril_indices_from, + triu_indices, + triu_indices_from, + vander, + zeros, + ) + from torch._numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, # assert_array_max_ulp, + assert_equal, + ) + + skip = functools.partial(skipif, True) @@ -101,7 +133,7 @@ def test_eye_bounds(self): def test_bool(self): assert_equal(eye(2, 2, dtype=bool), [[True, False], [False, True]]) - @xfail # (reason="TODO: implement order=non-default") + @xpassIfTorchDynamo # (reason="TODO: implement order=non-default") def test_order(self): mat_c = eye(4, 3, k=-1) mat_f = eye(4, 3, k=-1, order="F") @@ -127,9 +159,10 @@ def test_vector(self): assert_equal(diag(vals, k=2), b) assert_equal(diag(vals, k=-2), c) - def test_matrix(self, vals=None): - if vals is None: - vals = (100 * get_mat(5) + 1).astype("l") + def test_matrix(self): + self.check_matrix(vals=(100 * get_mat(5) + 1).astype("l")) + + def check_matrix(self, vals): b = zeros((5,)) for k in range(5): b[k] = vals[k, k] @@ -142,10 +175,10 @@ def test_matrix(self, vals=None): b[k] = vals[k + 2, k] assert_equal(diag(vals, -2), b[:3]) - @xfail # (reason="TODO implement orders") + @xpassIfTorchDynamo # (reason="TODO implement orders") def test_fortran_order(self): vals = array((100 * get_mat(5) + 1), order="F", dtype="l") - self.test_matrix(vals) + self.check_matrix(vals) def test_diag_bounds(self): A = [[1, 2], [3, 4], [5, 6]] @@ -251,7 +284,7 @@ def test_empty(self): # assert_array_max_ulp(a, np.zeros((4, 4))) assert_allclose(a, np.zeros((4, 4)), atol=1e-15) - @xfail # (reason="pytorch does not support bins = [int, array]") + @xpassIfTorchDynamo # (reason="pytorch does not support bins = [int, array]") def test_binparameter_combination(self): x = array([0, 0.09207008, 0.64575234, 0.12875982, 0.47390599, 0.59944483, 1]) y = array([0, 0.14344267, 0.48988575, 0.30558665, 0.44700682, 0.15886423, 1]) @@ -285,6 +318,7 @@ def test_binparameter_combination(self): assert_array_equal(H, answer) assert_array_equal(xe, array([0.0, 0.25, 0.5, 0.75, 1])) + @skip(reason="NP_VER: fails on CI with older NumPy") @parametrize("x_len, y_len", [(10, 11), (20, 19)]) def test_bad_length(self, x_len, y_len): x, y = np.ones(x_len), np.ones(y_len) @@ -368,7 +402,7 @@ def test_mask_indices(self): iu1 = mask_indices(3, np.triu, 1) assert_array_equal(a[iu1], array([1, 2, 5])) - @xfail # (reason="np.tril_indices == our tuple(tril_indices)") + @xpassIfTorchDynamo # (reason="np.tril_indices == our tuple(tril_indices)") def test_tril_indices(self): # indices without and with offset il1 = tril_indices(4) @@ -428,7 +462,7 @@ def test_tril_indices(self): ) -@xfail # (reason="np.triu_indices == our tuple(triu_indices)") +@xpassIfTorchDynamo # (reason="np.triu_indices == our tuple(triu_indices)") class TestTriuIndices(TestCase): def test_triu_indices(self): iu1 = triu_indices(4) diff --git a/test/torch_np/numpy_tests/lib/test_type_check.py b/test/torch_np/numpy_tests/lib/test_type_check.py index 0afa518edb228..96c0ddbc9672b 100644 --- a/test/torch_np/numpy_tests/lib/test_type_check.py +++ b/test/torch_np/numpy_tests/lib/test_type_check.py @@ -5,22 +5,44 @@ from unittest import expectedFailure as xfail, skipIf as skipif -import torch._numpy as np from pytest import raises as assert_raises - -from torch._numpy import ( - common_type, - iscomplex, - iscomplexobj, - isneginf, - isposinf, - isreal, - isrealobj, - nan_to_num, - real_if_close, +from torch.testing._internal.common_utils import ( + run_tests, + TEST_WITH_TORCHDYNAMO, + TestCase, + xpassIfTorchDynamo, ) -from torch._numpy.testing import assert_, assert_array_equal, assert_equal -from torch.testing._internal.common_utils import run_tests, TestCase + + +if TEST_WITH_TORCHDYNAMO: + import numpy as np + from numpy import ( + common_type, + iscomplex, + iscomplexobj, + isneginf, + isposinf, + isreal, + isrealobj, + nan_to_num, + real_if_close, + ) + from numpy.testing import assert_, assert_array_equal, assert_equal +else: + import torch._numpy as np + from torch._numpy import ( + common_type, + iscomplex, + iscomplexobj, + isneginf, + isposinf, + isreal, + isrealobj, + nan_to_num, + real_if_close, + ) + from torch._numpy.testing import assert_, assert_array_equal, assert_equal + skip = functools.partial(skipif, True) @@ -29,7 +51,7 @@ def assert_all(x): assert_(np.all(x), x) -@xfail # (reason="common_type not implemented") +@xpassIfTorchDynamo # (reason="common_type not implemented") class TestCommonType(TestCase): def test_basic(self): ai32 = np.array([[1, 2], [3, 4]], dtype=np.int32) @@ -96,7 +118,7 @@ def test_default_3(self): assert_equal(mintypecode("idD"), "D") -@xfail # (reason="TODO: decide on if [1] is a scalar or not") +@xpassIfTorchDynamo # (reason="TODO: decide on if [1] is a scalar or not") class TestIsscalar(TestCase): def test_basic(self): assert_(np.isscalar(3)) From a1a765c195d42572864ffe7657e28edf9cec3e67 Mon Sep 17 00:00:00 2001 From: drisspg Date: Sat, 28 Oct 2023 00:06:08 +0000 Subject: [PATCH 75/78] Mirror of Xformers Fix (#112267) # Summary See https://github.com/fairinternal/xformers/pull/850 for more details Pull Request resolved: https://github.com/pytorch/pytorch/pull/112267 Approved by: https://github.com/cpuhrsch --- .../transformers/cuda/mem_eff_attention/kernel_forward.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index ff47b81a1b7c0..c126b7b2944f0 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -297,9 +297,11 @@ struct AttentionKernel { // 15/16th of tensor core compute In that case : // - we only launch kernels for head_id % kQueriesPerBlock == 0 // - we iterate over heads instead of queries (strideM = strideH) - if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { - if (head_id % kQueriesPerBlock != 0) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 && + logsumexp_ptr == nullptr) { + if (head_id % kQueriesPerBlock != 0) { return false; + } q_strideM = q_strideH; num_queries = num_heads; num_heads = 1; // unused but here for intent From 1ff0b82be977107ab67ad2817ea76d46d3478d8f Mon Sep 17 00:00:00 2001 From: chilli Date: Thu, 26 Oct 2023 19:15:41 -0700 Subject: [PATCH 76/78] Added patterns for randperm + index_add (#112102) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112102 Approved by: https://github.com/lezcano ghstack dependencies: #112093, #112101 --- test/inductor/test_cpp_wrapper.py | 2 +- test/inductor/test_pattern_matcher.py | 44 ++++++++++++++- test/inductor/test_perf.py | 25 +++++++++ torch/_inductor/fx_passes/joint_graph.py | 2 + torch/_inductor/fx_passes/misc_patterns.py | 64 ++++++++++++++++++++++ torch/_inductor/fx_passes/post_grad.py | 13 +++-- torch/_inductor/inductor_prims.py | 7 +++ torch/_inductor/lowering.py | 5 ++ 8 files changed, 153 insertions(+), 9 deletions(-) create mode 100644 torch/_inductor/fx_passes/misc_patterns.py diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py index e2a80276eeeb1..a6407698e50d0 100644 --- a/test/inductor/test_cpp_wrapper.py +++ b/test/inductor/test_cpp_wrapper.py @@ -287,7 +287,7 @@ class BaseTest(NamedTuple): BaseTest( "test_cat_slice_cat", device=None, - tests=test_pattern_matcher.TestPaternMatcher(), + tests=test_pattern_matcher.TestPatternMatcher(), ), BaseTest( "test_addmm", diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index d1f7abff5a9ff..0f96514428cfa 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -25,13 +25,22 @@ from torch.testing._internal.inductor_utils import HAS_CUDA -class TestPaternMatcher(TestCase): - def common(self, fn, args, expected_matches, expected_nodes): +class TestPatternMatcher(TestCase): + def common( + self, + fn, + args, + expected_matches, + expected_nodes, + additional_check=lambda code: None, + ): counters.clear() torch.manual_seed(42) expected = fn(*args) torch.manual_seed(42) - actual = torch.compile(fn)(*args) + actual, codes = run_and_get_code(torch.compile(fn), *args) + if len(codes) == 1: + codes = codes[0] torch.testing.assert_close(actual, expected) if inductor_config.cpp_wrapper: # CPP wrapper runs everything twice, so we'll match the pattern twice @@ -42,6 +51,7 @@ def common(self, fn, args, expected_matches, expected_nodes): counters["inductor"]["pattern_matcher_count"], expected_matches ) self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes) + additional_check(codes) counters.clear() def test_mm_plus_mm(self): @@ -907,6 +917,34 @@ def test_fuse_attention_all_patterns_serialized(self): msg=f"Found mismatched pattern {key}. Run gen_attention_patterns.py", ) + def test_randperm_index(self): + def scaled_index_add(x, y, scale_y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + out = torch.index_add(x, dim=0, source=y * scale_y, index=index) + return out + + dim = 4 + x = torch.randn([8, dim], requires_grad=True, device="cuda") + gO = torch.randn([8, dim], device="cuda") + y = torch.randn([4, dim], requires_grad=True, device="cuda") + scale = torch.randn([dim], requires_grad=True, device="cuda") + + with torch.no_grad(): + self.common(lambda *args: scaled_index_add(*args), (x, y, scale), 1, 3) + + def code_check(codes): + self.assertNotIn("device_assert", codes[0]) + self.assertNotIn("device_assert", codes[1]) + + # Ugh, there is an extra match here because of removing pointless views + self.common( + lambda *args: scaled_index_add(*args).backward(gO), + (x, y, scale), + 3, + 7, + additional_check=code_check, + ) + if __name__ == "__main__": if IS_LINUX and HAS_CUDA: diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 65978e241e52f..182a3ea3047e2 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -603,6 +603,20 @@ def f(a, b): inp = (T(10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """26""") + def f(a, b): + out = aten.index_put(a, (b,), torch.tensor(1.0)) + return a.copy_(out) + + inp = (T(10), TI(2, mx=5)) + self.assertExpectedInline(count_numel(f, *inp), """6""") + + def f(a, b): + out = aten._unsafe_index_put(a, (b,), torch.tensor(1.0)) + return a.copy_(out) + + inp = (T(10), TI(2, mx=5)) + self.assertExpectedInline(count_numel(f, *inp), """6""") + def test_inplace_scatter_noop_view(self): def f(a, b): a[:, b] = 1 @@ -652,6 +666,17 @@ def f(a, b): inp = (T(10, 1, 8), T(1, 10, 8)) self.assertExpectedInline(count_numel(f, *inp), """170""") + # We need more sophisticated decisions for inplacing + # This tests randperm + scatter pattern match as well as inplacing + def test_inplace_randperm_scatter(self): + def scaled_index_add(x, y, scale_y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + out = x.index_add_(dim=0, source=y * scale_y, index=index) + return out + + inp = (T(10, 10), T(5, 10), T(10)) + self.assertExpectedInline(count_numel(scaled_index_add, *inp), """240""") + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 8adb08b3accb1..1038c15032184 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -27,10 +27,12 @@ @init_once_fakemode def lazy_init(): from .fuse_attention import _sfdp_init + from .misc_patterns import _misc_patterns_init from .pad_mm import _pad_mm_init _pad_mm_init() _sfdp_init() + _misc_patterns_init() @torch.utils._python_dispatch._disable_current_modes() diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000..946d879cbc07a --- /dev/null +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,64 @@ +import functools + +import torch + +from ..pattern_matcher import fwd_only, register_replacement + +aten = torch.ops.aten + + +@functools.lru_cache(None) +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + randperm_index_add_pattern, + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + pattern = register_replacement( + randperm_index_pattern, + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7b2797f978c6d..2a2c7b1b1152c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -15,7 +15,7 @@ from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype from torch.fx.immutable_collections import immutable_dict -from .. import config, ir, pattern_matcher +from .. import config, inductor_prims, ir, pattern_matcher from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage from ..lowering import lowerings as L @@ -652,7 +652,7 @@ def reinplace_scatters(graph): assert node.args[0].op == "placeholder" mutated_inputs.add(node.args[0]) - def can_replace(node, mutated_arg): + def can_inplace(node, mutated_arg): if get_node_storage(mutated_arg) is None: return False shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] @@ -665,7 +665,7 @@ def can_replace(node, mutated_arg): if len(shared_view_nodes) > 2: # Arg aliases another node other than copy_ return False - # Check for any uses other than current node and copy_ epilogue + # # Check for any uses other than current node and copy_ epilogue if len(mutated_arg.users) > 2: return False @@ -682,12 +682,15 @@ def can_replace(node, mutated_arg): inplaceable_ops = { aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp( + inductor_prims._unsafe_index_put_, 0 + ), } inplaceable_triton_ops = {triton_kernel_wrapper_functional} for node in graph.nodes: if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: - if can_replace(node, node.args[inplaceable_op.mutated_arg]): + if can_inplace(node, node.args[inplaceable_op.mutated_arg]): node.target = inplaceable_op.inplace_op elif node.target in inplaceable_triton_ops: # inplaceable_triton_ops take an additional argument called @@ -697,7 +700,7 @@ def can_replace(node, mutated_arg): tensors_to_clone = [] for arg in node.kwargs["tensors_to_clone"]: assert arg in node.kwargs["kwargs"] - if not can_replace(node, node.kwargs["kwargs"][arg]): + if not can_inplace(node, node.kwargs["kwargs"][arg]): tensors_to_clone.append(arg) kwargs = dict(node.kwargs) kwargs["tensors_to_clone"] = tensors_to_clone diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 277e20736181d..499496c3d7ee5 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -81,3 +81,10 @@ def eager_force_stride(input_tensor: Tensor, stride) -> tuple[int, ...]: ), doc="masked_scatter with precomputed indices", ) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 5b08b107bb0fc..3ae3b8752714c 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2911,6 +2911,11 @@ def index_put_(self, indices, values, accumulate=False): return index_put_impl_(self, indices, values, accumulate, add_asserts=True) +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, add_asserts=False) + + def index_put_impl_(self, indices, values, accumulate, add_asserts): # Dispatch to masked fill for single boolean index with single value if ( From 668c3b3f3b21eb0e1c81f829ff64fa5a374bfe61 Mon Sep 17 00:00:00 2001 From: Antoni Viros Date: Fri, 27 Oct 2023 22:23:34 +0000 Subject: [PATCH 77/78] Add embedding op to jagged NT (#112288) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112288 Approved by: https://github.com/cpuhrsch --- test/test_nestedtensor.py | 7 +++---- torch/nested/_internal/ops.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 825ec2fe23316..95b1d7c8b46d6 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -790,14 +790,13 @@ def test_layer_norm_breaking(self, device, dtype): lambda: layer_norm(nt), ) - @skipMeta - @torch.inference_mode() - def test_embedding(self, device): + @parametrize("layout", [torch.strided, torch.jagged]) + def test_embedding(self, device, layout): inputs = [ torch.randint(100, (L,), device=device, dtype=torch.int64) for L in torch.randint(5, 50, (8,)) ] - x = torch.nested.nested_tensor(inputs, device=device, dtype=torch.int64) + x = torch.nested.nested_tensor(inputs, device=device, dtype=torch.int64, layout=layout) emb = torch.nn.Embedding(100, 8, device=device) y = emb(x) ys = y.unbind() diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index cb68cea722fee..95d930f373558 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -674,3 +674,21 @@ def stack_default(func, *args, **kwargs): return NestedTensor( func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0]) ) + + +@register_jagged_func( + torch.ops.aten.embedding.default, + "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?", +) +def embedding_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + # guaranteed this is non-empty if we got here + indices = new_kwargs.pop("indices") + weight = new_kwargs.pop("weight") + + return NestedTensor( + func(weight, indices._values, **new_kwargs), **extract_kwargs(indices) + ) From 8d44999183e564e849519215f423a8ea5c4918ea Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 28 Oct 2023 01:51:32 +0000 Subject: [PATCH 78/78] Revert "[Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids (#112228)" This reverts commit dbb31a2984fa616b4bb6fac7abb2a06ec0533eb1. Reverted https://github.com/pytorch/pytorch/pull/112228 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing ROCm test in trunk https://hud.pytorch.org/pytorch/pytorch/commit/dbb31a2984fa616b4bb6fac7abb2a06ec0533eb1 ([comment](https://github.com/pytorch/pytorch/pull/112228#issuecomment-1783660326)) --- test/dynamo/test_functions.py | 43 -------------------- torch/_dynamo/variables/builder.py | 6 +-- torch/_dynamo/variables/functions.py | 17 +------- torch/_inductor/codegen/wrapper.py | 28 ++++++------- torch/_inductor/ir.py | 10 +---- torch/_inductor/triton_heuristics.py | 60 ++++------------------------ 6 files changed, 24 insertions(+), 140 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index f19b689783daf..8837c4cc13985 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1402,30 +1402,6 @@ def add_kernel( output = x + y tl.store(out_ptr + offsets, output, mask=mask) - @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), - ], - key=[], - ) - @triton.jit - def add_kernel_autotuned( - in_ptr0, - in_ptr1, - out_ptr, - n_elements, - BLOCK_SIZE: "tl.constexpr", - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(in_ptr0 + offsets, mask=mask) - y = tl.load(in_ptr1 + offsets, mask=mask) - output = x + y - tl.store(out_ptr + offsets, output, mask=mask) - @triton.jit def mul2_kernel( in_ptr0, @@ -2011,25 +1987,6 @@ def call_triton( # reset back CONSTANT_C = prev_c - @requires_cuda() - @requires_triton() - @common_utils.parametrize("grad", [False, True]) - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - def test_triton_kernel_autotune(self, grad, backend): - def call_triton(x: torch.Tensor, y: torch.Tensor): - output = torch.zeros_like(x, requires_grad=grad) - n_elements = output.numel() - grid = (n_elements,) - add_kernel_autotuned[grid](x, y, output, n_elements) - return output - - t1 = torch.rand(5, device="cuda", requires_grad=grad) - t2 = torch.rand(5, device="cuda", requires_grad=grad) - - torch_add = t1 + t2 - compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True) - self.assertEqual(compiled_func(t1, t2), torch_add) - @requires_cuda() @requires_triton() @common_utils.parametrize("grad", [False, True]) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5773aeca6acc5..54778e3dec6ed 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -362,16 +362,12 @@ def _wrap(self, value): from torch.utils._triton import has_triton if has_triton(): - from triton.runtime.autotuner import Autotuner from triton.runtime.jit import JITFunction else: class JITFunction: pass - class Autotuner: - pass - make_guards = self.make_guards # Handle exact type() match @@ -720,7 +716,7 @@ def index_source(key): sym_node_proxy, new_symint == 1, ) - elif isinstance(value, (JITFunction, Autotuner)): + elif isinstance(value, JITFunction): return TritonKernelVariable( value, None, # No kernel idx provided diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 138b82aeed741..1bd06a0727ab1 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -652,12 +652,10 @@ def get_val(v): class TritonKernelVariable(VariableTracker): def __init__(self, kernel, kernel_idx, grid, **kwargs): - from triton.runtime.autotuner import Autotuner + super().__init__(**kwargs) from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table - super().__init__(**kwargs) - assert kernel is not None self.kernel = kernel @@ -667,19 +665,6 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs): self.grid = grid - if isinstance(kernel, Autotuner): - # We only support configs and keys arguments of triton.autotune - # Make sure other arguments are defaulted - defaults = inspect.signature(Autotuner).parameters - if ( - defaults["warmup"].default != kernel.warmup - or defaults["rep"].default != kernel.rep - or defaults["prune_configs_by"].default != kernel.early_config_prune - ): - raise Unsupported( - "Only configs and keys are supported for triton.autotune" - ) - def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 3674c6f817357..8bef2105f6266 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -778,7 +778,7 @@ def get_unique_kernel_name(self, name: str) -> str: self.user_defined_kernel_count += 1 return new_name - def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs): + def define_user_defined_triton_kernel(self, name, kernel, kwargs): original_name = kernel.__name__ compile_wrapper = IndentedBuffer() compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") @@ -788,18 +788,26 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs): import triton import triton.language as tl from torch._inductor.utils import instance_descriptor - from torch._inductor.triton_heuristics import user_autotune + from torch._inductor.triton_heuristics import template """, strip=True, ) compile_wrapper.newline() + # TODO(oulgen): num_stages and num_warps are default values of + # triton.Config. Can we do better? Or ask the user to provide? + num_stages = 2 + num_warps = 4 + from ..ir import Buffer from .common import SizeArg, TensorArg signature: List[Union[TensorArg, SizeArg]] = [] constants = {} for key, arg in kwargs.items(): + # Not a real argument + if key == "grid": + continue if ( key in kernel.__annotations__ and "constexpr" in kernel.__annotations__[key] @@ -821,20 +829,12 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs): "configs": [config_of(signature)], "kernel_name": name, } - configs = [ - { - "kwargs": config.kwargs, - "num_warps": config.num_warps, - "num_stages": config.num_stages, - } - for config in configs - ] compile_wrapper.splice( f""" - @user_autotune( - configs={configs!r}, - meta={triton_meta!r}, - filename=__file__ + @template( + num_stages={num_stages}, + num_warps={num_warps}, + meta={triton_meta!r} ) @triton.jit """ diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4e20ce362b0c0..7be3836e0d75e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3768,15 +3768,9 @@ def apply_constraint(self): class UserDefinedTritonKernel(ExternKernel): def codegen(self, wrapper): - from triton.runtime.autotuner import Autotuner - from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table kernel = kernel_side_table.get_kernel(self.kernel_idx) - configs = [] - if isinstance(kernel, Autotuner): - configs = kernel.configs - kernel = kernel.fn new_name = wrapper.get_unique_kernel_name(kernel.__name__) self.codegen_comment(wrapper) @@ -3785,9 +3779,7 @@ def codegen(self, wrapper): self.grid, self.codegen_kwargs(), ) - wrapper.define_user_defined_triton_kernel( - new_name, kernel, configs, self.kwargs - ) + wrapper.define_user_defined_triton_kernel(new_name, kernel, self.kwargs) def should_allocate(self): return False diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py index 218d97d2ab7d6..be062e8eb3cbe 100644 --- a/torch/_inductor/triton_heuristics.py +++ b/torch/_inductor/triton_heuristics.py @@ -12,7 +12,7 @@ import re import threading from enum import auto, Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, List, Optional, Set, Tuple import torch @@ -62,7 +62,6 @@ class HeuristicType(Enum): REDUCTION = auto() PERSISTENT_REDUCTION = auto() TEMPLATE = auto() - USER_AUTOTUNE = auto() class AutotuneHint(Enum): @@ -345,7 +344,7 @@ def launcher({', '.join(def_args)}, grid, stream): return binary, launcher - def bench(self, launcher, *args, grid, **kwargs): + def bench(self, launcher, *args, grid): """Measure the performance of a given launcher""" if launcher.n_spills > config.triton.spill_threshold: log.debug( @@ -363,17 +362,16 @@ def kernel_call(): {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} ) - cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + cloned_args = self.clone_args(*args) launcher( *cloned_args, - **cloned_kwargs, grid=grid, stream=stream, ) return do_bench(kernel_call, rep=40, fast_flush=True) - def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: + def clone_args(self, *args): from .compile_fx import clone_preserve_strides # clone inplace buffers to avoid autotune contaminating them if @@ -387,15 +385,7 @@ def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: else: cloned_args.append(arg) - cloned_kwargs: Dict[str, Any] = {} - for name, arg in kwargs.items(): - if name in self.mutated_arg_names: - assert isinstance(arg, torch.Tensor) - cloned_kwargs[name] = clone_preserve_strides(arg) - else: - cloned_kwargs[name] = arg - - return cloned_args, cloned_kwargs + return cloned_args @dynamo_timed def benchmark_all_configs(self, *args, **kwargs): @@ -461,14 +451,11 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): Then if coordinate descnt tuning is run with max-autotune disabled, it will start from C1; while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. """ - if ( - self.heuristic_type == HeuristicType.TEMPLATE - or self.heuristic_type == HeuristicType.USER_AUTOTUNE - ): + if self.heuristic_type == HeuristicType.TEMPLATE: # skip triton template return launcher - cloned_args, _ = self.clone_args(*args) + cloned_args = self.clone_args(*args) config2launcher = {launcher.config: launcher} def benchmark_one_config(config): @@ -1143,39 +1130,6 @@ def template(num_stages, num_warps, meta, filename=None): ) -def user_autotune(configs, meta, filename=None): - """ - Compile a user defined triton kernel - """ - defaults = inspect.signature(triton.Config).parameters - default_num_stages = defaults["num_stages"].default - default_num_warps = defaults["num_warps"].default - - if len(configs) == 0: - configs = [ - triton.Config( - {}, num_stages=default_num_stages, num_warps=default_num_warps - ) - ] - else: - configs = [ - triton.Config( - c.get("kwargs", {}), - num_stages=c.get("num_stages", default_num_stages), - num_warps=c.get("num_warps", default_num_warps), - ) - for c in configs - ] - - return cached_autotune( - None, - configs, - meta=meta, - heuristic_type=HeuristicType.USER_AUTOTUNE, - filename=filename, - ) - - def foreach(meta, num_warps, filename=None): """ Compile a triton foreach kernel