[MPS] Fixes GELU, LeakyRELU and MISH on non-contiguous tensors #123049
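This PR makes the MPS GELU, LeakyReLU and Mish kernels handle non-contiguous input tensors by gathering them into a contiguous scratch tensor before running the MPSGraph and scattering the result back into the caller's output. A minimal sketch of the kind of input the fix targets (assuming a PyTorch build with MPS available; the shape, the ops exercised, and the negative slope are illustrative):

```python
# Illustrative repro sketch, assuming a PyTorch build with MPS available;
# the shape and negative_slope are arbitrary.
import torch

x = torch.randn(5, 4, device="mps").transpose(0, 1)  # transposed view => non-contiguous
assert not x.is_contiguous()

for op in (torch.nn.GELU(), torch.nn.LeakyReLU(0.2), torch.nn.Mish()):
    mps_out = op(x)
    cpu_out = op(x.cpu())
    torch.testing.assert_close(mps_out.cpu(), cpu_out)  # should match the CPU reference
```

Before this change, the MPS results on such transposed inputs could diverge from the CPU reference, which is what the updated tests further down check.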

Closed · wants to merge 16 commits
61 changes: 52 additions & 9 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -131,8 +131,17 @@ Tensor relu_mps(const Tensor& self) {
using CachedGraph = MPSUnaryCachedGraph;
TORCH_CHECK(output.is_mps());

if (self.numel() == 0) {
return;
}

MPSStream* stream = getCurrentMPSStream();

bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

@autoreleasepool {
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -152,13 +161,17 @@ Tensor relu_mps(const Tensor& self) {
newCachedGraph->outputTensor_ = outputTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);

// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
output.copy_(output_);
}
}

TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
@@ -171,8 +184,14 @@ Tensor relu_mps(const Tensor& self) {
using CachedGraph = MPSUnaryGradCachedGraph;
TORCH_CHECK(output.is_mps());

if (self.numel() == 0) {
return;
}

MPSStream* stream = getCurrentMPSStream();

Tensor output_ = at::empty_like(self, self.suggest_memory_format());

@autoreleasepool {
string key =
"leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
@@ -202,12 +221,13 @@ Tensor relu_mps(const Tensor& self) {

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output_);

// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
output.copy_(output_);
}

TORCH_IMPL_FUNC(log_softmax_mps_out)
@@ -656,6 +676,11 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c
auto approximate_type = get_gelutype_enum(approximate);
MPSStream* stream = getCurrentMPSStream();

bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

@autoreleasepool {
const auto key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + gelutype_to_string(approximate_type);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -672,12 +697,17 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c
newCachedGraph->outputTensor_ = outputTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);

auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}

if (executeGatherOp) {
output.copy_(output_);
}
}

TORCH_IMPL_FUNC(gelu_backward_out_mps)
@@ -686,8 +716,11 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c
using CachedGraph = MPSUnaryGradCachedGraph;

// Empty output
if (grad_input.numel() == 0)
if (self.numel() == 0) {
return;
}

Tensor grad_input_ = at::empty_like(self, self.suggest_memory_format());

auto approximate_type = get_gelutype_enum(approximate);
MPSStream* stream = getCurrentMPSStream();
@@ -761,11 +794,12 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c

Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_);

auto feeds = dictionaryFromPlaceholders(gradPlaceholder, selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
grad_input.copy_(grad_input_);
}

static void elu_variants_out_mps(const Tensor& self,
@@ -1241,6 +1275,11 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int

MPSStream* stream = getCurrentMPSStream();

bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

@autoreleasepool {
string key = "mish_out_mps:" + getTensorsStringKey({self});

@@ -1257,12 +1296,16 @@
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);

auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
result.copy_(result_);
}
}

Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
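Each of the three forward kernels changed above applies the same fallback: detect that the input is not contiguous in any supported memory format, allocate a contiguous scratch output, run the cached graph on the gathered input, and copy the result back into the caller's output. A rough Python analogue of that control flow (the helper and its name are illustrative, not the actual ATen internals):

```python
# Rough Python analogue of the gather/scatter fallback added in Activation.mm.
# `op` stands in for running the cached MPSGraph; the helper name is illustrative.
import torch

def unary_mps_like(op, x, out):
    if x.numel() == 0:          # the diff also early-returns on empty inputs
        return out
    execute_gather_op = not (
        x.is_contiguous()
        or x.is_contiguous(memory_format=torch.channels_last)
        or x.is_contiguous(memory_format=torch.channels_last_3d)
    )
    if execute_gather_op:
        out_ = torch.empty_like(x, memory_format=torch.contiguous_format)
        out_.copy_(op(x.contiguous()))  # run on a gathered, contiguous copy
        out.copy_(out_)                 # scatter back into the caller's output
    else:
        out.copy_(op(x))                # write straight into the output
    return out
```

The backward kernels (leaky_relu_backward and gelu_backward) take a simpler route in this diff: they always write into a scratch tensor (output_ / grad_input_) and copy back unconditionally.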
95 changes: 79 additions & 16 deletions test/test_mps.py
@@ -1472,29 +1472,44 @@ def testNpLeakyRelu(self):
0.9]]),
negative_slope=0.1))

def _testLeakyRelu(self, np_features, negative_slope, device):
cpu_x = torch.from_numpy(np_features).requires_grad_()
mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
mps_x = cpu_x.detach().clone().to('mps')

if not contiguous and not (0 in shape or len(shape) < 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
mps_x = mps_x.transpose(0, 1)
assert not mps_x.is_contiguous()

cpu_x.requires_grad_()
mps_x.requires_grad_()

relu_op = torch.nn.LeakyReLU(negative_slope)

cpu_leaky_relu = relu_op(cpu_x)
mps_leaky_relu = relu_op(mps_x)
torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

# test backward pass

cpu_grad = torch.ones_like(cpu_leaky_relu)
mps_grad = cpu_grad.to('mps')
cpu_leaky_relu.backward(gradient=cpu_grad)

mps_leaky_relu.backward(gradient=mps_grad)
torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
cpu_leaky_relu.backward(gradient=cpu_grad)

def testNumbersCPU(self):
for t in [np.float32]:
self._testLeakyRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
negative_slope=0.2,
device="cpu")
assert cpu_x.grad is not None # Check that the grad is well-populated
self.assertEqual(cpu_x.grad, mps_x.grad)

def testNumbersCPU(self):
for t in [torch.float, torch.half]:
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
self._testLeakyRelu(shape,
dtype=t,
negative_slope=0.2,
contiguous=contiguous)

class TestAvgPool(TestCaseMPS):
def _sum_pool2d(self, x, kernel_size):
@@ -6633,9 +6648,18 @@ def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
helper((2, 16, 16), (4, 4), return_indices, dtype)

def test_gelu_simple(self):
def helper(shape, dtype=torch.float):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
def helper(shape, dtype=torch.float, contiguous=True):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
x = cpu_x.detach().clone().to('mps')

if not contiguous and (0 not in shape and len(shape) >= 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
x = x.transpose(0, 1)
assert not x.is_contiguous()

cpu_x.requires_grad_()
x.requires_grad_()

gelu_result = torch.nn.GELU()(x)
# GELU is not supported on CPU, so cast it to float
@@ -6650,16 +6674,55 @@ def helper(shape, dtype=torch.float):
atol = 1e-5 if dtype == torch.float else 1e-2
rtol = 1e-3 if dtype == torch.float else 1e-2
self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)

assert x.grad is not None # Check that the grad is well-populated
self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

# Test empty shape too
for dtype in [torch.float, torch.half]:
for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
helper(shape, dtype)
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
helper(shape, dtype, contiguous)
# Test that gelu would raise an assert for integral types
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))

def test_mish_simple(self):
def helper(shape, dtype=torch.float, contiguous=True):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
x = cpu_x.detach().clone().to('mps')

if not contiguous and (0 not in shape and len(shape) >= 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
x = x.transpose(0, 1)
assert not x.is_contiguous()

cpu_x.requires_grad_()
x.requires_grad_()

mish_result = torch.nn.Mish()(x)
mish_result_cpu = torch.nn.Mish()(cpu_x)

cpu_grad = torch.ones_like(mish_result_cpu)
grad = cpu_grad.to('mps')

mish_result.backward(gradient=grad)
mish_result_cpu.backward(gradient=cpu_grad)

atol = 1e-5 if dtype == torch.float else 1e-2
rtol = 1e-3 if dtype == torch.float else 1e-2
self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)

assert x.grad is not None # Check that the grad is well-populated
self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

# Test empty shape too
for dtype in [torch.float, torch.half]:
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
helper(shape, dtype, contiguous)

def test_gelu(self):
def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
numpy_dtype = {
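The updated tests exercise forward and backward for both contiguous and transposed inputs in float and half. A condensed, standalone version of the check they perform (assuming MPS is available; Mish is used here, but the same pattern applies to GELU and LeakyReLU):

```python
# Condensed forward+backward check on a non-contiguous input, assuming MPS is available.
import torch

cpu_x = torch.randn(5, 4)
mps_x = cpu_x.detach().clone().to("mps")

cpu_x = cpu_x.transpose(0, 1)   # transposing makes the tensors non-contiguous
mps_x = mps_x.transpose(0, 1)
assert not mps_x.is_contiguous()

cpu_x.requires_grad_()
mps_x.requires_grad_()

cpu_y = torch.nn.Mish()(cpu_x)
mps_y = torch.nn.Mish()(mps_x)
torch.testing.assert_close(mps_y.cpu(), cpu_y)

cpu_y.backward(gradient=torch.ones_like(cpu_y))
mps_y.backward(gradient=torch.ones_like(mps_y))
assert mps_x.grad is not None   # grad should be populated on the MPS side
torch.testing.assert_close(mps_x.grad.cpu(), cpu_x.grad)
```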