diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm
index 39685d334edf..83d1d75a7eee 100644
--- a/aten/src/ATen/native/mps/operations/Linear.mm
+++ b/aten/src/ATen/native/mps/operations/Linear.mm
@@ -1,5 +1,6 @@
 // Copyright © 2022 Apple Inc.
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/ExpandUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 
 namespace at::native {
@@ -21,11 +22,26 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::opt
   auto input_size = input.sizes();
   std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
   output_size.push_back(weight.size(0));
+
+  TORCH_CHECK(input.size(-1) == weight_arg.size(-1),
+              "linear(): input and weight.T shapes cannot be multiplied (",
+              input.size(-2),
+              "x",
+              input.size(-1),
+              " and ",
+              weight_arg.size(-1),
+              "x",
+              weight_arg.size(-2),
+              ")");
+
+  if (is_bias_defined) {
+    // Check bias and output shapes compatibility only.
+    inferExpandGeometry_dimvector(bias.sizes(), bias.strides(), output_size);
+  }
+
   Tensor output =
       at::empty(output_size, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, input.suggest_memory_format());
 
-  TORCH_CHECK(output.is_mps());
-
   if (output.numel() == 0) {
     return output;
   }
@@ -69,7 +85,8 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::opt
   MPSGraphTensor* inputFlattened = inputTensor;
   bool doReshape = false;
   // workaround to improve the performance with 3D+ inputs
-  if (input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32) {
+  if (input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32 &&
+      bias.dim() == 1) {
     doReshape = true;
     inputFlattened = [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil];
   }
diff --git a/test/test_mps.py b/test/test_mps.py
index 0ad5acab5851..9d0cfbfaab18 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1495,6 +1495,18 @@ def test_linear_1d_weight(self):
         self.assertEqual(linear, linear_mps)
 
+    def test_linear_2d_bias(self):
+        device = "cpu"
+        x = torch.randn(2, 2, 2, 64, device=device)
+        linear = torch.nn.Linear(64, 4, device=device)
+        linear.bias = torch.nn.Parameter(torch.arange(8, dtype=torch.float32, device=device).reshape(2, 4))
+        y = linear(x)
+        device = "mps"
+        x_mps = x.to(device)
+        linear.to(device)
+        y_mps = linear(x_mps)
+        self.assertEqual(y, y_mps)
+
    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
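
For context, a minimal sketch of the two behaviors this patch covers. The shapes and variable names below are illustrative assumptions, not taken from the PR: case 1 exercises the new TORCH_CHECK for a mismatched inner dimension, and case 2 mirrors the new test_linear_2d_bias test, where a 2D bias broadcasts against the output shape.

import torch

# Case 1: mismatched inner dimension. With the patch, MPS raises the same
# "shapes cannot be multiplied" error that CPU/CUDA emit; the shapes here
# are made up for illustration.
x = torch.randn(3, 8, device="mps")
lin_mps = torch.nn.Linear(16, 4).to("mps")  # expects input.size(-1) == 16
try:
    lin_mps(x)
except RuntimeError as e:
    # linear(): input and weight.T shapes cannot be multiplied (3x8 and 16x4)
    print(e)

# Case 2: a 2D bias of shape (2, 4) that broadcasts against the output
# shape (2, 2, 2, 4); the patched inferExpandGeometry_dimvector call
# validates this, and the reshape workaround is skipped (bias.dim() != 1).
x = torch.randn(2, 2, 2, 64)
lin = torch.nn.Linear(64, 4)
lin.bias = torch.nn.Parameter(torch.arange(8, dtype=torch.float32).reshape(2, 4))
y_cpu = lin(x)
y_mps = lin.to("mps")(x.to("mps"))
torch.testing.assert_close(y_cpu, y_mps.cpu())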