From efc2d08eacbf9ab5e070e4df1c64512cadceb4b3 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 2 Aug 2022 18:17:27 -0400 Subject: [PATCH] Revert #75195 (#82504) (#82662) This is a short-term fix for a serious regression in functorch (https://github.com/pytorch/functorch/issues/989). Additional things this PR does: - the out= tests for nn.functional.linear fail after the revert. I added some xfails. These xfails were present in the original PR (#75195). - the profiler tests fail on the revert, so I updated the expecttests for the profiler tests Test Plan: - test offline that the functorch regression was fixed Pull Request resolved: https://github.com/pytorch/pytorch/pull/82504 Approved by: https://github.com/ngimel, https://github.com/ezyang, https://github.com/atalman --- aten/src/ATen/native/LinearAlgebra.cpp | 3 ++- torch/testing/_internal/common_methods_invocations.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 8721080816de..4bf90198dd8b 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1712,7 +1712,8 @@ Tensor _matmul_impl( } else if (dim_tensor1 == 2 && dim_tensor2 == 1) { return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2); } else if (dim_tensor1 == 1 && dim_tensor2 == 2) { - return has_out ? at::mv_out(out, tensor2.t(), tensor1) : tensor2.t().mv(tensor1); + return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0) + : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0); } else if (dim_tensor1 == 2 && dim_tensor2 == 2) { return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2); } else if (should_fold(tensor1, dim_tensor2) || should_fold(tensor2, dim_tensor1)) { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 939ca9077e5b..00eb3f7f09b5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12192,6 +12192,8 @@ def error_inputs_mean(op_info, device, **kwargs): 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), ], skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), # https://github.com/pytorch/pytorch/issues/67470 DecorateInfo(unittest.skip("67470!"), 'TestCommon', 'test_noncontiguous_samples', @@ -13517,6 +13519,8 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), ), decorators=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), )),