93 changes: 93 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -10706,5 +10706,98 @@ TEST_F(AtenXlaTensorTest, TestLerpScalarOut) {
ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNanToNum) {
for (torch::ScalarType scalar_type :
{torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt,
torch::kLong}) {
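    // For floating dtypes, exercise NaN/+inf/-inf explicitly; integral dtypes
    // cannot represent them, so use ordinary values and expect a pass-through
    // copy.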
torch::Tensor input =
isFloatingType(scalar_type)
? torch::tensor(
{1.0, std::nan("1"), std::numeric_limits<double>::infinity(),
-std::numeric_limits<double>::infinity()},
torch::TensorOptions(scalar_type))
: torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type));
torch::Tensor output = torch::nan_to_num(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::nan_to_num(xla_input);
AllClose(output, xla_output);
});
output =
torch::nan_to_num(input, /*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::nan_to_num(
xla_input, /*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0);
AllClose(output, xla_output);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNanToNumInplace) {
for (torch::ScalarType scalar_type :
{torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt,
torch::kLong}) {
torch::Tensor input =
isFloatingType(scalar_type)
? torch::tensor(
{1.0, std::nan("1"), std::numeric_limits<double>::infinity(),
-std::numeric_limits<double>::infinity()},
torch::TensorOptions(scalar_type))
: torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type));
torch::Tensor input_copy = input.clone();
input.nan_to_num_();
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input_copy, device);
xla_input.nan_to_num_();
AllClose(input, xla_input);
});
input = input_copy.clone();
input.nan_to_num_(/*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input_copy, device);
xla_input.nan_to_num_(/*nan=*/1.0, /*posinf=*/2.0, /*neginf=*/3.0);
AllClose(input, xla_input);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNanToNumOut) {
for (torch::ScalarType scalar_type :
{torch::kHalf, torch::kFloat, torch::kDouble, torch::kShort, torch::kInt,
torch::kLong}) {
torch::Tensor input =
isFloatingType(scalar_type)
? torch::tensor(
{1.0, std::nan("1"), std::numeric_limits<double>::infinity(),
-std::numeric_limits<double>::infinity()},
torch::TensorOptions(scalar_type))
: torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type));
torch::Tensor output = torch::zeros_like(input);
torch::nan_to_num_out(output, input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
      torch::Tensor xla_output = torch::zeros_like(xla_input);
torch::nan_to_num_out(xla_output, xla_input);
AllClose(output, xla_output);
});
torch::nan_to_num_out(output, input, /*nan=*/1.0, /*posinf=*/2.0,
/*neginf=*/3.0);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
      torch::Tensor xla_output = torch::zeros_like(xla_input);
torch::nan_to_num_out(xla_output, xla_input, /*nan=*/1.0, /*posinf=*/2.0,
/*neginf=*/3.0);
AllClose(output, xla_output);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::nan_to_num", cpp_test::GetIgnoredCounters());
}

} // namespace cpp_test
} // namespace torch_xla
27 changes: 27 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2196,6 +2196,33 @@ at::Tensor& XLANativeFunctions::mv_out(const at::Tensor& self,
return out;
}

at::Tensor XLANativeFunctions::nan_to_num(const at::Tensor& self,
c10::optional<double> nan,
c10::optional<double> posinf,
c10::optional<double> neginf) {
XLA_FN_COUNTER("xla::");
  // Integral tensors cannot hold NaN or infinities, so nan_to_num is a no-op
  // here; return a copy to preserve out-of-place semantics.
if (!at::native::is_floating_point(self)) {
return CopyTensor(self);
}
auto element_type = TensorTypeToRawXlaType(self.scalar_type());
XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type);
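  // Defaults follow torch.nan_to_num: NaN -> 0, +inf -> the dtype's largest
  // finite value, -inf -> the dtype's smallest finite value.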
at::Scalar nan_replacement = nan.has_value() ? *nan : 0.0;
at::Scalar posinf_replacement = posinf.has_value() ? *posinf : min_max.max;
at::Scalar neginf_replacement = neginf.has_value() ? *neginf : min_max.min;
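  // Reject replacement values that cannot be represented in the input dtype.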
for (const auto& replacement :
{nan_replacement, posinf_replacement, neginf_replacement}) {
XLA_CHECK(min_max.min.toDouble() <= replacement.toDouble() &&
replacement.toDouble() <= min_max.max.toDouble())
<< "Type " << self.scalar_type() << " replacement value "
<< replacement.toDouble() << " must be in the range ["
<< min_max.min.toDouble() << ", " << min_max.max.toDouble() << "].";
}
return bridge::AtenFromXlaTensor(
XLATensor::nan_to_num(bridge::GetXlaTensor(self), nan_replacement,
posinf_replacement, neginf_replacement));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
19 changes: 19 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -863,6 +863,25 @@ NodePtr LogicalOr(const Value& input, const Value& other) {
std::move(lower_fn));
}

NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
const Value& neginf) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp nan_replacement = loctx->GetOutputOp(node.operand(1));
xla::XlaOp posinf_replacement = loctx->GetOutputOp(node.operand(2));
xla::XlaOp neginf_replacement = loctx->GetOutputOp(node.operand(3));
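    // Replace NaN first, then +inf, then -inf; every other element passes
    // through unchanged.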
xla::XlaOp result =
xla::Select(xla::IsNan(xla_input), nan_replacement,
xla::Select(xla::IsPosInf(xla_input), posinf_replacement,
xla::Select(xla::IsNegInf(xla_input),
neginf_replacement, xla_input)));
return node.ReturnOp(result, loctx);
};
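  // The op is element-wise, so the node keeps the input shape.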
return GenericOp(OpKind(at::aten::nan_to_num), {input, nan, posinf, neginf},
input.shape(), std::move(lower_fn));
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -224,6 +224,9 @@ NodePtr LogicalAnd(const Value& input, const Value& other);

NodePtr LogicalOr(const Value& input, const Value& other);

NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
const Value& neginf);

} // namespace ops
} // namespace ir
} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -789,6 +789,10 @@ class XLATensor {
static void mv_out(XLATensor& out, const XLATensor& input,
const XLATensor& vec);

static XLATensor nan_to_num(const XLATensor& input, const at::Scalar& nan,
const at::Scalar& posinf,
const at::Scalar& neginf);

// Returns a new tensor that is a narrowed view of the input in the given
// dimension.
static XLATensor narrow(const XLATensor& input, xla::int64 dim,
13 changes: 13 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1810,6 +1810,19 @@ void XLATensor::mv_out(XLATensor& out, const XLATensor& input,
out.SetIrValue(ir::ops::Dot(input.GetIrValue(), vec.GetIrValue()));
}

XLATensor XLATensor::nan_to_num(const XLATensor& input, const at::Scalar& nan,
const at::Scalar& posinf,
const at::Scalar& neginf) {
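  // Materialize the replacement scalars as device constants broadcast to the
  // input shape so the lowering can select element-wise.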
ir::Value nan_value =
GetIrValueForScalar(nan, input.shape(), input.GetDevice());
ir::Value posinf_value =
GetIrValueForScalar(posinf, input.shape(), input.GetDevice());
ir::Value neginf_value =
GetIrValueForScalar(neginf, input.shape(), input.GetDevice());
return input.CreateFrom(ir::ops::NanToNum(input.GetIrValue(), nan_value,
posinf_value, neginf_value));
}

XLATensor XLATensor::narrow(const XLATensor& input, xla::int64 dim,
xla::int64 start, xla::int64 length) {
auto input_shape = input.shape();
1 change: 1 addition & 0 deletions xla_native_functions.yaml
@@ -322,6 +322,7 @@ supported:
- logical_xor
- logical_or
- logical_and
- nan_to_num
autograd:
- max_pool2d
- max_pool3d