Skip to content

Commit d17a91c

Browse files
committed
Add leaky_relu_ to ATen XLA tensor
1 parent 88d8efa commit d17a91c

File tree

5 files changed

+27
-0
lines changed

5 files changed

+27
-0
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,6 +1715,18 @@ TEST_F(AtenXlaTensorTest, TestLeakyRelu) {
17151715
});
17161716
}
17171717

1718+
TEST_F(AtenXlaTensorTest, TestLeakyReluInPlace) {
1719+
at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
1720+
double negative_slope = 0.01;
1721+
ForEachDevice([&](const Device& device) {
1722+
at::Tensor xla_input = bridge::CreateXlaTensor(input.clone(), device);
1723+
at::Tensor output = at::leaky_relu_(input, negative_slope);
1724+
at::Tensor xla_output = at::leaky_relu_(xla_input, negative_slope);
1725+
AllClose(output, xla_output);
1726+
AllClose(input, xla_input);
1727+
});
1728+
}
1729+
17181730
TEST_F(AtenXlaTensorTest, TestExp) {
17191731
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
17201732
at::Tensor b = at::exp(a);

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,13 @@ at::Tensor AtenXlaType::leaky_relu(const at::Tensor& self,
11471147
bridge::GetXlaTensor(self), negative_slope.to<double>()));
11481148
}
11491149

1150+
at::Tensor& AtenXlaType::leaky_relu_(at::Tensor& self,
1151+
at::Scalar negative_slope) const {
1152+
XLATensor self_tensor = bridge::GetXlaTensor(self);
1153+
XLATensor::leaky_relu_(self_tensor, negative_slope.to<double>());
1154+
return self;
1155+
}
1156+
11501157
at::Tensor AtenXlaType::threshold(const at::Tensor& self, at::Scalar threshold,
11511158
at::Scalar value) const {
11521159
return bridge::AtenFromXlaTensor(XLATensor::threshold(

torch_xla/csrc/aten_xla_type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ class AtenXlaType : public AtenXlaTypeBase {
348348

349349
at::Tensor leaky_relu(const at::Tensor& self,
350350
at::Scalar negative_slope) const override;
351+
at::Tensor& leaky_relu_(at::Tensor& self,
352+
at::Scalar negative_slope) const override;
351353

352354
at::Tensor threshold(const at::Tensor& self, at::Scalar threshold,
353355
at::Scalar value) const override;

torch_xla/csrc/tensor.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,11 @@ XLATensor XLATensor::leaky_relu(const XLATensor& input, double negative_slope) {
907907
ir::MakeNode<ir::ops::LeakyRelu>(input.GetIrValue(), negative_slope));
908908
}
909909

910+
void XLATensor::leaky_relu_(XLATensor& input, double negative_slope) {
911+
input.SetIrValue(
912+
ir::MakeNode<ir::ops::LeakyRelu>(input.GetIrValue(), negative_slope));
913+
}
914+
910915
XLATensor XLATensor::DispatchComparisonOp(c10::Symbol kind,
911916
const XLATensor& input,
912917
const at::Scalar& other) {

torch_xla/csrc/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ class XLATensor {
216216
static void relu_(XLATensor& input);
217217

218218
static XLATensor leaky_relu(const XLATensor& input, double negative_slope);
219+
static void leaky_relu_(XLATensor& input, double negative_slope);
219220

220221
static XLATensor threshold(const XLATensor& input, float threshold,
221222
float value);

0 commit comments

Comments
 (0)