26 changes: 26 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -4857,6 +4857,19 @@ TEST_F(AtenXlaTensorTest, TestCeluInPlace) {
ExpectCounterChanged("xla::elu_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestGelu) {
torch::Tensor input =
torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::gelu(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::gelu(xla_input);
AllClose(output, xla_output);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::gelu", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestAddMatMul) {
int in_channels = 32;
int out_channels = 320;
@@ -7809,6 +7822,19 @@ TEST_F(AtenXlaTensorTest, TestEluBackward) {
});
}

TEST_F(AtenXlaTensorTest, TestGeluBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::gelu(inputs[0]);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({2, 3},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});
ExpectCounterChanged("xla::gelu_backward", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLeakyReluBackward) {
double negative_slope = 0.01;
auto testfn = [=](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
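The two new tests above compare the XLA forward result against the CPU reference and drive the backward lowering through TestBackward, while the counter checks assert that xla::gelu / xla::gelu_backward are hit rather than an aten:: fallback. As an additional, standalone sanity check of the same formulas against plain LibTorch (no XLA device involved; the main() wrapper and the fwd/bwd names are illustrative only, not part of this change):

```cpp
#include <torch/torch.h>

#include <cmath>
#include <iostream>

int main() {
  torch::Tensor x = torch::rand(
      {2, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
  torch::Tensor y = torch::gelu(x);
  y.sum().backward();

  torch::Tensor xd = x.detach();
  // Forward reference: x * 0.5 * (1 + erf(x / sqrt(2))).
  torch::Tensor fwd = xd * 0.5 * (1.0 + torch::erf(xd * M_SQRT1_2));
  // Backward reference: Phi(x) + x * phi(x), with Phi the standard normal
  // CDF and phi its density.
  torch::Tensor phi = torch::exp(-0.5 * xd * xd) / std::sqrt(2.0 * M_PI);
  torch::Tensor bwd = 0.5 * (1.0 + torch::erf(xd * M_SQRT1_2)) + xd * phi;

  std::cout << torch::allclose(y, fwd) << " "
            << torch::allclose(x.grad(), bwd) << std::endl;
  return 0;
}
```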
12 changes: 12 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1327,6 +1327,18 @@ at::Tensor& AtenXlaType::ge_(at::Tensor& self, const at::Tensor& other) {
return self;
}

at::Tensor AtenXlaType::gelu(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::gelu(bridge::GetXlaTensor(self)));
}

at::Tensor AtenXlaType::gelu_backward(const at::Tensor& grad,
const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::gelu_backward(
bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self)));
}

at::Tensor AtenXlaType::gt(const at::Tensor& self, at::Scalar other) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -410,6 +410,11 @@ class AtenXlaType {

static at::Tensor& ge_(at::Tensor& self, const at::Tensor& other);

static at::Tensor gelu(const at::Tensor& self);

static at::Tensor gelu_backward(const at::Tensor& grad,
const at::Tensor& self);

static at::Tensor gt(const at::Tensor& self, at::Scalar other);

static at::Tensor gt(const at::Tensor& self, const at::Tensor& other);
16 changes: 16 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -488,6 +488,22 @@ NodePtr EluBackward(const Value& grad_output, const Value& output,
positive_output_branch, negative_output_branch);
}

NodePtr Gelu(const Value& input) {
// input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
const xla::Shape& shape = input.shape();
return input * ScalarOp(0.5, shape) *
(Erf(input * ScalarOp(M_SQRT1_2, shape)) + ScalarOp(1.0, shape));
}

NodePtr GeluBackward(const Value& grad, const Value& input) {
  // d/dx gelu(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
  const float kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;  // == 1 / sqrt(2 * pi)
  const xla::Shape& shape = input.shape();
  NodePtr scratch = Erf(input * ScalarOp(M_SQRT1_2, shape));
  NodePtr dinput = Exp(input * input * ScalarOp(-0.5, shape));
  return grad * (ScalarOp(0.5, shape) * (ScalarOp(1.0, shape) + scratch) +
                 input * dinput * ScalarOp(kAlpha, shape));
}

NodePtr Lshift(const Value& input, at::Scalar other) {
return input * ScalarOp(pow(2, other.to<double>()), input.shape());
}
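The GeluBackward node above encodes the derivative of this erf-based GELU. As a sketch of the underlying algebra (standard calculus, not taken from the patch itself):

$$
\mathrm{gelu}(x) = x\,\Phi(x),\qquad
\Phi(x) = \tfrac{1}{2}\left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right),\qquad
\frac{d}{dx}\,\mathrm{gelu}(x) = \Phi(x) + x\,\frac{e^{-x^{2}/2}}{\sqrt{2\pi}}.
$$

In the code, `scratch` holds erf(x/√2), `dinput` holds e^{-x²/2}, and `kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5 = (2/√π)·(1/√2)·(1/2) = 1/√(2π)` is exactly the coefficient of the normal density.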
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -178,6 +178,10 @@ NodePtr Elu(const Value& input, at::Scalar alpha, at::Scalar scale,
NodePtr EluBackward(const Value& grad_output, const Value& output,
at::Scalar alpha, at::Scalar scale, at::Scalar input_scale);

NodePtr Gelu(const Value& input);

NodePtr GeluBackward(const Value& grad, const Value& input);

NodePtr Lshift(const Value& input, at::Scalar other);

NodePtr Lshift(const Value& input, const Value& other);
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -455,6 +455,9 @@ class XLATensor {
static XLATensor ge(const XLATensor& input, const XLATensor& other);
static void ge_(XLATensor& input, const XLATensor& other);

static XLATensor gelu(const XLATensor& input);
  static XLATensor gelu_backward(const XLATensor& grad,
                                 const XLATensor& input);

static XLATensor gt(const XLATensor& input, at::Scalar other);
static void gt_(XLATensor& input, at::Scalar other);

10 changes: 10 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1059,6 +1059,16 @@ void XLATensor::ge_(XLATensor& input, const XLATensor& other) {
input.SetIrValue(ir::MakeNode<ir::ops::Cast>(cmp_result, input.dtype()));
}

XLATensor XLATensor::gelu(const XLATensor& input) {
return input.CreateFrom(ir::ops::Gelu(input.GetIrValue()));
}

XLATensor XLATensor::gelu_backward(const XLATensor& grad,
const XLATensor& input) {
return input.CreateFrom(
ir::ops::GeluBackward(grad.GetIrValue(), input.GetIrValue()));
}

XLATensor XLATensor::gt(const XLATensor& input, at::Scalar other) {
return DispatchComparisonOp(at::aten::gt, input, other);
}