Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,15 @@ NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output,
XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output.shape().rank()));
}

// Builds the IR node for element-wise clamp: each element of `input` is
// constrained to the interval [min, max]. Both bounds arrive already
// materialized as scalar IR values (defaulting of absent bounds to the
// element type's numeric limits is handled by the caller), so lowering
// is a direct three-operand xla::Clamp.
NodePtr Clamp(const Value& input, const Value& min, const Value& max) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1));
    xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2));
    // Note xla::Clamp's argument order: (min, operand, max).
    return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx);
  };
  return GenericOp(OpKind(at::aten::clamp), OpList{input, min, max},
                   input.shape(), std::move(lower_fn));
}

NodePtr AddMatMulOp(const Value& input, const Value& weight,
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,
NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output,
xla::int64 dim);

// Element-wise clamp of `input` to [min, max]; both bounds are
// pre-materialized scalar IR values supplied by the caller.
NodePtr Clamp(const Value& input, const Value& min, const Value& max);

NodePtr Ceil(const Value& input);

Expand Down
31 changes: 29 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,29 @@
namespace torch_xla {
namespace {

// Resolved lower/upper clamp bounds, materialized as scalar IR values.
struct MinMaxValues {
  ir::Value min;
  ir::Value max;
};

// Resolves the effective clamp bounds for `tensor` and materializes them
// as scalar IR values on the tensor's device. A missing bound falls back
// to the numeric limit of the tensor's element type, so the resulting
// pair always describes a fully-specified interval.
MinMaxValues GetMinMaxValues(const XLATensor& tensor,
                             c10::optional<at::Scalar> min,
                             c10::optional<at::Scalar> max) {
  auto shape = tensor.shape();
  auto element_type = shape.get().element_type();
  XlaHelpers::MinMax type_limits = XlaHelpers::MinMaxValues(element_type);
  // Fill in whichever bound the caller left unspecified.
  if (!min) {
    min = type_limits.min;
  }
  if (!max) {
    max = type_limits.max;
  }
  const auto& device = tensor.GetDevice();
  MinMaxValues resolved{
      XLATensor::GetIrValueForScalar(*min, element_type, device),
      XLATensor::GetIrValueForScalar(*max, element_type, device)};
  return resolved;
}

void CheckRank(const XLATensor& t, xla::int64 expected_rank,
const std::string& tag, const std::string& arg_name,
int arg_number) {
Expand Down Expand Up @@ -598,12 +621,16 @@ XLATensor XLATensor::cholesky(const XLATensor& input, bool upper) {
// Returns a new tensor whose elements are those of `input` clamped to
// [min, max]; an unset bound defaults to the element type's numeric limit
// (resolved by GetMinMaxValues before building the IR node).
XLATensor XLATensor::clamp(const XLATensor& input,
                           c10::optional<at::Scalar> min,
                           c10::optional<at::Scalar> max) {
  MinMaxValues min_max = GetMinMaxValues(input, min, max);
  return input.CreateFrom(
      ir::ops::Clamp(input.GetIrValue(), min_max.min, min_max.max));
}

// In-place variant of clamp: rebinds `input`'s IR value to the clamped
// result. Unset bounds default to the element type's numeric limits.
void XLATensor::clamp_(XLATensor& input, c10::optional<at::Scalar> min,
                       c10::optional<at::Scalar> max) {
  MinMaxValues min_max = GetMinMaxValues(input, min, max);
  input.SetIrValue(
      ir::ops::Clamp(input.GetIrValue(), min_max.min, min_max.max));
}

XLATensor XLATensor::clone(const XLATensor& input) {
Expand Down