Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#81827
26 changes: 0 additions & 26 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1491,32 +1491,6 @@ at::Tensor XLANativeFunctions::hardshrink(const at::Tensor& self,
XLATensor::hardshrink(bridge::GetXlaTensor(self), lambda));
}

at::Tensor XLANativeFunctions::hardsigmoid(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::hardsigmoid(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::hardsigmoid_backward(
const at::Tensor& grad_output, const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::hardsigmoid_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::hardswish(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::hardswish(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::hardswish_backward(const at::Tensor& grad_output,
const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::hardswish_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::hardshrink_backward(const at::Tensor& grad_out,
const at::Tensor& self,
const at::Scalar& lambda) {
Expand Down
48 changes: 0 additions & 48 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,54 +168,6 @@ torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
GetXlaShape(input), std::move(lower_fn));
}

torch::lazy::NodePtr HardSigmoid(const torch::lazy::Value& input) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
return node.ReturnOp(BuildHardSigmoid(xla_input), loctx);
};
return GenericOp(torch::lazy::OpKind(at::aten::hardsigmoid), {input},
GetXlaShape(input), std::move(lower_fn));
}

torch::lazy::NodePtr HardSigmoidBackward(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(BuildHardSigmoidBackward(xla_grad_output, xla_input),
loctx);
};
return GenericOp(torch::lazy::OpKind(at::aten::hardsigmoid_backward),
{grad_output, input}, GetXlaShape(input),
std::move(lower_fn));
}

torch::lazy::NodePtr HardSwish(const torch::lazy::Value& input) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
return node.ReturnOp(BuildHardSwish(xla_input), loctx);
};
return GenericOp(torch::lazy::OpKind(at::aten::hardswish), {input},
GetXlaShape(input), std::move(lower_fn));
}

torch::lazy::NodePtr HardSwishBackward(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(BuildHardSwishBackward(xla_grad_output, xla_input),
loctx);
};
return GenericOp(torch::lazy::OpKind(at::aten::hardswish_backward),
{grad_output, input}, GetXlaShape(input),
std::move(lower_fn));
}

torch::lazy::NodePtr LogSigmoid(const torch::lazy::Value& input) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
Expand Down
10 changes: 0 additions & 10 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,6 @@ torch::lazy::NodePtr Fmod(const torch::lazy::Value& dividend,

torch::lazy::NodePtr Not(const torch::lazy::Value& input);

torch::lazy::NodePtr HardSigmoid(const torch::lazy::Value& input);

torch::lazy::NodePtr HardSigmoidBackward(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input);

torch::lazy::NodePtr HardSwish(const torch::lazy::Value& input);

torch::lazy::NodePtr HardSwishBackward(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input);

torch::lazy::NodePtr LogSigmoid(const torch::lazy::Value& input);

torch::lazy::NodePtr LogSigmoidBackward(const torch::lazy::Value& grad_output,
Expand Down
23 changes: 23 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,29 @@ torch_xla::XlaOpVector Floor::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Floor(xla_input), loctx);
}

torch_xla::XlaOpVector Hardsigmoid::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildHardSigmoid(xla_input), loctx);
}

torch_xla::XlaOpVector HardsigmoidBackward::Lower(
LoweringContext* loctx) const {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
return ReturnOp(BuildHardSigmoidBackward(xla_grad_output, xla_input), loctx);
}

torch_xla::XlaOpVector Hardswish::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildHardSwish(xla_input), loctx);
}

torch_xla::XlaOpVector HardswishBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
return ReturnOp(BuildHardSwishBackward(xla_grad_output, xla_input), loctx);
}

torch_xla::XlaOpVector Inverse::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildInverse(xla_input), loctx);
Expand Down
18 changes: 18 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ xla::Shape FloorOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape HardsigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape HardswishOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape HardswishBackwardOutputShape(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape InverseOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ xla::Shape ExpOutputShape(const torch::lazy::Value& input);

xla::Shape FloorOutputShape(const torch::lazy::Value& input);

xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input);

xla::Shape HardsigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input);

xla::Shape HardswishOutputShape(const torch::lazy::Value& input);

xla::Shape HardswishBackwardOutputShape(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input);

xla::Shape InverseOutputShape(const torch::lazy::Value& input);

xla::Shape LogdetOutputShape(const torch::lazy::Value& input);
Expand Down
10 changes: 0 additions & 10 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,16 +717,6 @@ class XLATensor : public c10::intrusive_ptr_target {
const XLATensorPtr& input,
const at::Scalar& lambda);

static XLATensorPtr hardsigmoid(const XLATensorPtr& input);

static XLATensorPtr hardsigmoid_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input);

static XLATensorPtr hardswish(const XLATensorPtr& input);

static XLATensorPtr hardswish_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input);

static XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
const at::Scalar& min_val,
Expand Down
20 changes: 0 additions & 20 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1590,26 +1590,6 @@ XLATensorPtr XLATensor::hardshrink_backward(const XLATensorPtr& grad_out,
grad_out->GetIrValue(), input->GetIrValue(), lambda));
}

XLATensorPtr XLATensor::hardsigmoid(const XLATensorPtr& input) {
return input->CreateFrom(HardSigmoid(input->GetIrValue()));
}

XLATensorPtr XLATensor::hardsigmoid_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input) {
return input->CreateFrom(
HardSigmoidBackward(grad_output->GetIrValue(), input->GetIrValue()));
}

XLATensorPtr XLATensor::hardswish(const XLATensorPtr& input) {
return input->CreateFrom(HardSwish(input->GetIrValue()));
}

XLATensorPtr XLATensor::hardswish_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input) {
return input->CreateFrom(
HardSwishBackward(grad_output->GetIrValue(), input->GetIrValue()));
}

XLATensorPtr XLATensor::hardtanh_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
const at::Scalar& min_val,
Expand Down
8 changes: 4 additions & 4 deletions xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ full_codegen:
- exp
- floor
- inverse
- hardsigmoid
- hardsigmoid_backward
- hardswish
- hardswish_backward
- logdet
- maximum
- minimum
Expand Down Expand Up @@ -149,10 +153,6 @@ supported:
- gt.Tensor
- hardshrink
- hardshrink_backward
- hardsigmoid
- hardsigmoid_backward
- hardswish
- hardswish_backward
- hardtanh
- hardtanh_backward
- index.Tensor
Expand Down