Implement Tanh Gelu Approximation #3039

Merged: 7 commits, Feb 14, 2022
2 changes: 1 addition & 1 deletion scripts/apply_patches.sh
@@ -14,7 +14,7 @@ if [ -f "$TORCH_PIN" ]; then
if [[ $CID = \#* ]]; then
PRNUM="${CID//[!0-9]/}"
set +x
MCHECK=$(git -C $PTDIR log -1000)
MCHECK=$(git -C $PTDIR log -100)
if [[ $MCHECK != *"Pull Request resolved: https://github.com/pytorch/pytorch/pull/$PRNUM"* ]]; then
echo "Fetching PyTorch PR #$PRNUM"
pushd "$PTDIR"
41 changes: 23 additions & 18 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -6356,14 +6356,16 @@ TEST_F(AtenXlaTensorTest, TestCeluInPlace) {
TEST_F(AtenXlaTensorTest, TestGelu) {
torch::Tensor input =
torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::gelu(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::gelu(xla_input);
AllClose(output, xla_output);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::gelu", cpp_test::GetIgnoredCounters());
for (const auto& approximate : {"none", "tanh"}) {
torch::Tensor output = torch::gelu(input, approximate);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::gelu(xla_input, approximate);
AllClose(output, xla_output);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::gelu", cpp_test::GetIgnoredCounters());
}
}
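
As a side note (not part of the diff), here is a minimal standalone sketch of what the updated test exercises: comparing the exact and tanh-approximated GELU through the public ATen API. It assumes a libtorch build that already carries the upstream `approximate` argument from pytorch/pytorch#61439 (the PR pinned below); the tolerance check is purely illustrative.

```cpp
// Standalone sketch: compare exact and tanh-approximated GELU on CPU.
// Assumes torch::gelu accepts the `approximate` string argument (#61439).
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor input = torch::randn({2, 3});
  torch::Tensor exact = torch::gelu(input, /*approximate=*/"none");
  torch::Tensor approx = torch::gelu(input, /*approximate=*/"tanh");
  // The tanh form only approximates the erf-based GELU, so report the
  // difference rather than expecting bitwise equality.
  std::cout << "max abs diff: "
            << (exact - approx).abs().max().item<float>() << std::endl;
  return 0;
}
```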

TEST_F(AtenXlaTensorTest, TestAddMatMul) {
@@ -10173,16 +10175,19 @@ TEST_F(AtenXlaTensorTest, TestEluBackward) {
}

TEST_F(AtenXlaTensorTest, TestGeluBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::gelu(inputs[0]);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({2, 3},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});
ExpectCounterChanged("xla::gelu_backward", cpp_test::GetIgnoredCounters());
for (const auto& approximate : {"none", "tanh"}) {
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::gelu(inputs[0], approximate);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand(
{2, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});
ExpectCounterChanged("xla::gelu_backward", cpp_test::GetIgnoredCounters());
}
}

TEST_F(AtenXlaTensorTest, TestLeakyReluBackward) {
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
@@ -0,0 +1 @@
#61439
23 changes: 19 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -16,6 +16,7 @@
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/gelu.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/index_ops.h"
@@ -199,6 +200,16 @@ void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other,
bin_op_out(operands.first, operands.second, out_tensor);
}

GeluType GetXlaGeluType(const c10::string_view approximate) {
if (approximate == "none") {
return GeluType::None;
} else if (approximate == "tanh") {
return GeluType::Tanh;
} else {
XLA_ERROR() << "Unknown gelu type: " << approximate;
}
}

} // namespace

at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self,
@@ -1506,16 +1517,20 @@ at::Tensor XLANativeFunctions::ge(const at::Tensor& self,
XLATensor::ge(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor XLANativeFunctions::gelu(const at::Tensor& self) {
at::Tensor XLANativeFunctions::gelu(const at::Tensor& self,
c10::string_view approximate) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::gelu(bridge::GetXlaTensor(self)));
return bridge::AtenFromXlaTensor(
XLATensor::gelu(bridge::GetXlaTensor(self), GetXlaGeluType(approximate)));
}

at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
const at::Tensor& self) {
const at::Tensor& self,
c10::string_view approximate) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::gelu_backward(
bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self)));
bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self),
GetXlaGeluType(approximate)));
}

at::Tensor XLANativeFunctions::ger(const at::Tensor& self,
12 changes: 12 additions & 0 deletions torch_xla/csrc/gelu.h
@@ -0,0 +1,12 @@
#pragma once

namespace torch_xla {

// These constants control the approximation behavior of gelu function.
enum GeluType {
None, // Baseline Gelu
Tanh, // Tanh Gelu Approximation
END
};

} // namespace torch_xla
58 changes: 48 additions & 10 deletions torch_xla/csrc/ops/ops.cpp
@@ -12,6 +12,7 @@
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/elementwise.h"
#include "torch_xla/csrc/gelu.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/matrix.h"
@@ -650,22 +651,59 @@ NodePtr EluBackward(const Value& grad_output, const Value& output,
positive_output_branch, negative_output_branch);
}

NodePtr Gelu(const Value& input) {
NodePtr Gelu(const Value& input, GeluType approximate) {
ScopePusher ir_scope("aten::gelu");
// input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
const xla::Shape& shape = input.shape();
return input * ScalarOp(0.5, shape) *
(Erf(input * ScalarOp(M_SQRT1_2, shape)) + ScalarOp(1.0, shape));
if (approximate == GeluType::Tanh) {
// inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(input, 3))
// input * 0.5 * (1.0 + torch.tanh(inner))
const float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
Review comment (Collaborator): maybe better to mark it as static?
auto beta = ScalarOp(kBeta, shape);
auto kappa = ScalarOp(0.044715, shape);
auto three = ScalarOp(3, shape);
auto one = ScalarOp(1, shape);
auto half = ScalarOp(0.5, shape);
NodePtr inner = beta * (input + kappa * Pow(input, three));
return half * input * (one + Tanh(inner));
} else {
// input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
return input * ScalarOp(0.5, shape) *
(Erf(input * ScalarOp(M_SQRT1_2, shape)) + ScalarOp(1.0, shape));
}
}
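
For reference, the two branches above correspond to the exact erf-based GELU and its tanh approximation; note that kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 evaluates to the constant sqrt(2/pi) used below:

$$\mathrm{GELU}(x) = \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\bigl(x/\sqrt{2}\bigr)\right), \qquad \mathrm{GELU}_{\tanh}(x) = \tfrac{1}{2}\,x\left(1 + \tanh\!\Bigl(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\Bigr)\right)$$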

NodePtr GeluBackward(const Value& grad, const Value& input) {
NodePtr GeluBackward(const Value& grad, const Value& input,
GeluType approximate) {
ScopePusher ir_scope("aten::gelu_backward");
const float kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
const xla::Shape& shape = input.shape();
NodePtr scratch = Erf(input * ScalarOp(M_SQRT1_2, shape));
NodePtr dinput = Exp(input * input * ScalarOp(-0.5, shape));
return grad * (ScalarOp(0.5, shape) * (ScalarOp(1.0, shape) + scratch) +
input * dinput * ScalarOp(kAlpha, shape));
if (approximate == GeluType::Tanh) {
constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
auto beta = ScalarOp(kBeta, shape);
auto kappa = ScalarOp(0.044715, shape);
auto one = ScalarOp(1, shape);
auto two = ScalarOp(2, shape);
auto three = ScalarOp(3, shape);
auto half = ScalarOp(0.5, shape);
NodePtr inner = beta * (input + kappa * Pow(input, three));
NodePtr tanh_inner = Tanh(inner);

NodePtr left = half * input;
NodePtr right = one + tanh_inner;

NodePtr left_derivative = half * right;

NodePtr tanh_derivative = one - tanh_inner * tanh_inner;
NodePtr inner_derivative = beta * (one + three * kappa * Pow(input, two));
NodePtr right_derivative = left * tanh_derivative * inner_derivative;
Review comment (Collaborator): hmm, this really should be lowered as a node class and a lowering function if it gets this complicated. I will let this one go since the upstream PR needs to merge, and fix it later.

return grad * (left_derivative + right_derivative);
} else {
constexpr float kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
NodePtr scratch = Erf(input * ScalarOp(M_SQRT1_2, shape));
NodePtr dinput = Exp(input * input * ScalarOp(-0.5, shape));
return grad * (ScalarOp(0.5, shape) * (ScalarOp(1.0, shape) + scratch) +
input * dinput * ScalarOp(kAlpha, shape));
}
}
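
For reference, the tanh branch above is the direct derivative of the forward formula. Writing $u(x) = \sqrt{2/\pi}\,(x + 0.044715\,x^{3})$ and $t = \tanh(u(x))$, the product rule gives

$$\frac{d}{dx}\Bigl[\tfrac{1}{2}\,x\,(1 + t)\Bigr] = \underbrace{\tfrac{1}{2}\,(1 + t)}_{\text{left\_derivative}} \;+\; \underbrace{\tfrac{1}{2}\,x\,(1 - t^{2})\,u'(x)}_{\text{right\_derivative}}, \qquad u'(x) = \sqrt{2/\pi}\,\bigl(1 + 3\cdot 0.044715\,x^{2}\bigr),$$

and the incoming grad multiplies this sum by the chain rule, matching the return statement above.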

NodePtr Lshift(const Value& input, const at::Scalar& other) {
6 changes: 4 additions & 2 deletions torch_xla/csrc/ops/ops.h
@@ -5,6 +5,7 @@

#include <memory>

#include "torch_xla/csrc/gelu.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ops/constant.h"
#include "torch_xla/csrc/ops/generic.h"
@@ -194,9 +195,10 @@ NodePtr EluBackward(const Value& grad_output, const Value& output,
const at::Scalar& alpha, const at::Scalar& scale,
const at::Scalar& input_scale);

NodePtr Gelu(const Value& input);
NodePtr Gelu(const Value& input, GeluType approximate);

NodePtr GeluBackward(const Value& grad, const Value& input);
NodePtr GeluBackward(const Value& grad, const Value& input,
GeluType approximate);

NodePtr Lshift(const Value& input, const at::Scalar& other);

7 changes: 5 additions & 2 deletions torch_xla/csrc/tensor.h
@@ -17,6 +17,7 @@
#include "torch_xla/csrc/computation.h"
#include "torch_xla/csrc/cross_replica_reduces.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/gelu.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_util.h"
#include "torch_xla/csrc/lowering_context.h"
@@ -595,8 +596,10 @@ class XLATensor {

static XLATensor ge(const XLATensor& input, const XLATensor& other);

static XLATensor gelu(const XLATensor& input);
static XLATensor gelu_backward(const XLATensor& grad, const XLATensor& input);
static XLATensor gelu(const XLATensor& input, GeluType approximate);

static XLATensor gelu_backward(const XLATensor& grad, const XLATensor& input,
GeluType approximate);

static XLATensor ger(const XLATensor& input, const XLATensor& vec2);

12 changes: 7 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
@@ -14,6 +14,7 @@
#include "torch/csrc/lazy/core/helpers.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/gelu.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir_util.h"
#include "torch_xla/csrc/layout_manager.h"
@@ -1381,14 +1382,15 @@ XLATensor XLATensor::ge(const XLATensor& input, const XLATensor& other) {
return DispatchComparisonOp(at::aten::ge, input, other);
}

XLATensor XLATensor::gelu(const XLATensor& input) {
return input.CreateFrom(ir::ops::Gelu(input.GetIrValue()));
XLATensor XLATensor::gelu(const XLATensor& input, GeluType approximate) {
return input.CreateFrom(ir::ops::Gelu(input.GetIrValue(), approximate));
}

XLATensor XLATensor::gelu_backward(const XLATensor& grad,
const XLATensor& input) {
return input.CreateFrom(
ir::ops::GeluBackward(grad.GetIrValue(), input.GetIrValue()));
const XLATensor& input,
GeluType approximate) {
return input.CreateFrom(ir::ops::GeluBackward(
grad.GetIrValue(), input.GetIrValue(), approximate));
}

XLATensor XLATensor::ger(const XLATensor& input, const XLATensor& vec2) {
Expand Down