diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index a6b23e318a71..071bc3900384 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -204,27 +204,15 @@ NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output,
       XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output.shape().rank()));
 }
 
-NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
-              c10::optional<at::Scalar> max) {
-  const xla::Shape& input_shape = input.shape();
-  XlaHelpers::MinMax min_max =
-      XlaHelpers::MinMaxValues(input_shape.element_type());
-  if (!min) {
-    min = min_max.min;
-  }
-  if (!max) {
-    max = min_max.max;
-  }
-  NodePtr min_value = ScalarOp(*min, input_shape.element_type());
-  NodePtr max_value = ScalarOp(*max, input_shape.element_type());
+NodePtr Clamp(const Value& input, const Value& min, const Value& max) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
     xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1));
     xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2));
     return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx);
   };
-  return GenericOp(OpKind(at::aten::clamp), OpList{input, min_value, max_value},
-                   input_shape, std::move(lower_fn));
+  return GenericOp(OpKind(at::aten::clamp), OpList{input, min, max},
+                   input.shape(), std::move(lower_fn));
 }
 
 NodePtr AddMatMulOp(const Value& input, const Value& weight,
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 8f16aad1eeff..e8cf10f004c4 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -129,8 +129,7 @@ NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,
 NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output,
                           xla::int64 dim);
 
-NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
-              c10::optional<at::Scalar> max);
+NodePtr Clamp(const Value& input, const Value& min, const Value& max);
 
 NodePtr Ceil(const Value& input);
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index cca375c7c115..775aeae3ca2a 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -93,6 +93,29 @@ namespace torch_xla {
 
 namespace {
 
+struct MinMaxValues {
+  ir::Value min;
+  ir::Value max;
+};
+
+MinMaxValues GetMinMaxValues(const XLATensor& tensor,
+                             c10::optional<at::Scalar> min,
+                             c10::optional<at::Scalar> max) {
+  auto shape = tensor.shape();
+  XlaHelpers::MinMax min_max =
+      XlaHelpers::MinMaxValues(shape.get().element_type());
+  if (!min) {
+    min = min_max.min;
+  }
+  if (!max) {
+    max = min_max.max;
+  }
+  return {XLATensor::GetIrValueForScalar(*min, shape.get().element_type(),
+                                         tensor.GetDevice()),
+          XLATensor::GetIrValueForScalar(*max, shape.get().element_type(),
+                                         tensor.GetDevice())};
+}
+
 void CheckRank(const XLATensor& t, xla::int64 expected_rank,
                const std::string& tag, const std::string& arg_name,
                int arg_number) {
@@ -598,12 +621,16 @@ XLATensor XLATensor::cholesky(const XLATensor& input, bool upper) {
 
 XLATensor XLATensor::clamp(const XLATensor& input, c10::optional<at::Scalar> min,
                            c10::optional<at::Scalar> max) {
-  return input.CreateFrom(ir::ops::Clamp(input.GetIrValue(), min, max));
+  MinMaxValues min_max = GetMinMaxValues(input, min, max);
+  return input.CreateFrom(
+      ir::ops::Clamp(input.GetIrValue(), min_max.min, min_max.max));
 }
 
 void XLATensor::clamp_(XLATensor& input, c10::optional<at::Scalar> min,
                        c10::optional<at::Scalar> max) {
-  input.SetIrValue(ir::ops::Clamp(input.GetIrValue(), min, max));
+  MinMaxValues min_max = GetMinMaxValues(input, min, max);
+  input.SetIrValue(
+      ir::ops::Clamp(input.GetIrValue(), min_max.min, min_max.max));
 }
 
 XLATensor XLATensor::clone(const XLATensor& input) {
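
Note: the defaulting behavior that the new GetMinMaxValues helper implements (an absent bound falls back to the extreme value of the element type, so that side of the clamp becomes a no-op) can be illustrated with the self-contained C++ sketch below. The clamp_with_defaults function is hypothetical; it only mirrors the pattern with std::optional and std::clamp and is not part of this patch or of the torch_xla API.

#include <algorithm>
#include <iostream>
#include <limits>
#include <optional>

// Hypothetical stand-in that mirrors GetMinMaxValues + ir::ops::Clamp:
// an absent bound defaults to the extreme value of the element type,
// making the clamp a no-op on that side.
template <typename T>
T clamp_with_defaults(T value, std::optional<T> min, std::optional<T> max) {
  // Same defaulting as the new GetMinMaxValues helper in tensor_methods.cpp.
  T lo = min.value_or(std::numeric_limits<T>::lowest());
  T hi = max.value_or(std::numeric_limits<T>::max());
  return std::clamp(value, lo, hi);
}

int main() {
  // Only the upper bound is given; the lower bound defaults to lowest().
  std::cout << clamp_with_defaults<float>(5.0f, std::nullopt, 2.0f) << "\n";  // 2
  // Both bounds given.
  std::cout << clamp_with_defaults<float>(-3.0f, 0.0f, 2.0f) << "\n";  // 0
}

One reading of the design choice here: materializing the scalar bounds in tensor_methods.cpp instead of inside ir::ops::Clamp leaves the IR node operating purely on Value operands, so min and max no longer have to be compile-time scalars.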