Skip to content

Commit

Permalink
[MPS] Add linear inputs check
Browse files Browse the repository at this point in the history
Fixes #98211

[ghstack-poisoned]
  • Loading branch information
qqaatw committed Apr 15, 2023
1 parent 05809c7 commit 037dc23
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
23 changes: 20 additions & 3 deletions aten/src/ATen/native/mps/operations/Linear.mm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
Expand All @@ -21,11 +22,26 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::opt
auto input_size = input.sizes();
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
output_size.push_back(weight.size(0));

TORCH_CHECK(input.size(-1) == weight_arg.size(-1),
"linear(): input and weight.T shapes cannot be multiplied (",
input.size(-2),
"x",
input.size(-1),
" and ",
weight_arg.size(-1),
"x",
weight_arg.size(-2),
")");

if (is_bias_defined) {
// Check bias and output shapes compatibility only.
inferExpandGeometry_dimvector(bias.sizes(), bias.strides(), output_size);
}

Tensor output =
at::empty(output_size, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, input.suggest_memory_format());

TORCH_CHECK(output.is_mps());

if (output.numel() == 0) {
return output;
}
Expand Down Expand Up @@ -69,7 +85,8 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::opt
MPSGraphTensor* inputFlattened = inputTensor;
bool doReshape = false;
// workaround to improve the performance with 3D+ inputs
if (input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32) {
if (input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32 &&
bias.dim() == 1) {
doReshape = true;
inputFlattened = [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil];
}
Expand Down
12 changes: 12 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,18 @@ def test_linear_1d_weight(self):

self.assertEqual(linear, linear_mps)

def test_linear_2d_bias(self):
device = "cpu"
x = torch.randn(2, 2, 2, 64, device=device)
linear = torch.nn.Linear(64, 4, device=device)
linear.bias = torch.nn.Parameter(torch.arange(8, dtype=torch.float32, device=device).reshape(2, 4))
y = linear(x)
device = "mps"
x_mps = x.to(device)
linear.to(device)
y_mps = linear(x_mps)
self.assertEqual(y, y_mps)

def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
Expand Down

0 comments on commit 037dc23

Please sign in to comment.