[MPS] Add support for Softshrink to MPS Backend (#110814)
Adds the softshrink activation function to the MPS backend.
Pull Request resolved: #110814
Approved by: https://github.com/kulinseth
igm503 authored and pytorchmergebot committed Oct 11, 2023
1 parent de370eb commit 95ff51d
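
For reference, a minimal usage sketch of what this change enables (not part of the diff; assumes a machine where torch.backends.mps.is_available() returns True):

import torch
import torch.nn.functional as F

# Both calls below should now dispatch to the new MPS kernels.
x = torch.randn(4, device="mps", requires_grad=True)
y = F.softshrink(x, lambd=0.5)  # forward: softshrink_out_mps
y.sum().backward()              # backward: softshrink_backward_out_mps
print(y.cpu(), x.grad.cpu())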
Showing 3 changed files with 146 additions and 1 deletion.
144 changes: 144 additions & 0 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -34,6 +34,8 @@
#include <ATen/ops/silu_native.h>
#include <ATen/ops/softplus_backward_native.h>
#include <ATen/ops/softplus_native.h>
#include <ATen/ops/softshrink_backward_native.h>
#include <ATen/ops/softshrink_native.h>
#include <ATen/ops/tanh_backward_native.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/threshold_native.h>
@@ -1455,6 +1457,148 @@ Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
}
}

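// Softshrink(x) = x - lambd for x > lambd, x + lambd for x < -lambd, and 0
// otherwise; the graph below builds this from two comparison masks and a
// pair of selects.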
TORCH_IMPL_FUNC(softshrink_out_mps)
(const Tensor& self, const Scalar& lambd, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());

if (result.numel() == 0)
return;

MPSScalar lambd_scalar = getMPSScalar(lambd, self.scalar_type());

struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* lambdTensor_ = nil;
};

MPSStream* stream = getCurrentMPSStream();

@autoreleasepool {
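    // The graph is cached per input shape/dtype and per lambd value; lambd is
    // still supplied at run time through a scalar placeholder.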
string key = "softshrink_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());

auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* lambdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);

MPSGraphTensor* negativeLambdTensor = [mpsGraph negativeWithTensor:lambdTensor name:nil];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
MPSGraphTensor* positiveLambdPredicateTensor = [mpsGraph greaterThanOrEqualToWithPrimaryTensor:inputTensor
secondaryTensor:lambdTensor
name:nil];
MPSGraphTensor* negativeLambdPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:negativeLambdTensor
name:nil];
MPSGraphTensor* outputTensor =
[mpsGraph selectWithPredicateTensor:positiveLambdPredicateTensor
truePredicateTensor:[mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:lambdTensor
name:nil]
falsePredicateTensor:zeroTensor
name:nil];
outputTensor = [mpsGraph selectWithPredicateTensor:negativeLambdPredicateTensor
truePredicateTensor:[mpsGraph additionWithPrimaryTensor:inputTensor
secondaryTensor:lambdTensor
name:nil]
falsePredicateTensor:outputTensor
name:nil];

newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
newCachedGraph->lambdTensor_ = lambdTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
cachedGraph->lambdTensor_ : getMPSGraphTensorFromScalar(stream, lambd_scalar),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}

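// Backward for shrink-style ops, parameterized on op_name: the incoming
// gradient passes through wherever input >= lambd or input < -lambd and is
// zeroed inside the shrink band.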
static void shrink_backward_out_mps(const Tensor& grad_output,
const Tensor& self,
const Scalar& lambd,
const Tensor& grad_input,
std::string op_name) {
using namespace mps;
TORCH_CHECK(self.is_mps());

if (grad_input.numel() == 0)
return;

MPSScalar lambd_scalar = getMPSScalar(lambd, self.scalar_type());

struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
MPSGraphTensor* lambdTensor_ = nil;
};

MPSStream* stream = getCurrentMPSStream();

@autoreleasepool {
string key = op_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());

auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* lambdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);

MPSGraphTensor* negativeLambdTensor = [mpsGraph negativeWithTensor:lambdTensor name:nil];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
MPSGraphTensor* positiveLambdPredicateTensor = [mpsGraph greaterThanOrEqualToWithPrimaryTensor:inputTensor
secondaryTensor:lambdTensor
name:nil];
MPSGraphTensor* negativeLambdPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:negativeLambdTensor
name:nil];
MPSGraphTensor* gradInputTensor = [mpsGraph selectWithPredicateTensor:positiveLambdPredicateTensor
truePredicateTensor:gradOutputTensor
falsePredicateTensor:zeroTensor
name:nil];
gradInputTensor = [mpsGraph selectWithPredicateTensor:negativeLambdPredicateTensor
truePredicateTensor:gradOutputTensor
falsePredicateTensor:gradInputTensor
name:nil];

newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->gradInputTensor_ = gradInputTensor;
newCachedGraph->lambdTensor_ = lambdTensor;
});
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);

// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
cachedGraph->lambdTensor_ : getMPSGraphTensorFromScalar(stream, lambd_scalar),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}

TORCH_IMPL_FUNC(softshrink_backward_out_mps)
(const Tensor& grad_output, const Tensor& self, const Scalar& lambd, const Tensor& grad_input) {
return shrink_backward_out_mps(grad_output, self, lambd, grad_input, "softshrink_backward_out_mps");
}

Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
using namespace mps;

2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -11653,6 +11653,7 @@
python_module: nn
dispatch:
CPU, CUDA: softshrink_out
MPS: softshrink_out_mps

- func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
structured_delegate: softshrink.out
@@ -11665,6 +11666,7 @@
python_module: nn
dispatch:
CPU, CUDA: softshrink_backward_out
MPS: softshrink_backward_out_mps

- func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
structured_delegate: softshrink_backward.grad_input
1 change: 0 additions & 1 deletion test/test_mps.py
@@ -592,7 +592,6 @@ def mps_ops_modifier(ops):
'nn.functional.multilabel_margin_loss': None,
'nn.functional.pdist': None,
'nn.functional.rrelu': None,
'nn.functional.softshrink': None,
'nn.functional.norm': None,
'ormqr': None,
'pca_lowrank': None,