Skip to content

Commit

Permalink
Fix regression in torch.equal behavior for NaNs (#111699)
Browse files Browse the repository at this point in the history
`torch.equal(x, x)` should return false if one of `x` is a tenor of floats one of which is NaN.
So, it renders some of the optimization proposed in #100024 invalid, though as result `torch.equal` will become much slower for identical floating point tensors.

Add regression test that calls torch.equal for tensor containing NaN

Fixes #111251

Pull Request resolved: #111699
Approved by: https://github.com/Skylion007, https://github.com/albanD
  • Loading branch information
malfet authored and pytorchmergebot committed Oct 21, 2023
1 parent aa24459 commit 7709382
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
22 changes: 21 additions & 1 deletion aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2096,7 +2096,27 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
&& self.layout() == other.layout()
&& self.is_neg() == other.is_neg()
&& self.is_conj() == other.is_conj()) {
return true;
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
return true;
}
std::atomic<bool> result{true};
auto iter = TensorIteratorConfig().add_input(self).build();
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "equal_notnan_cpu", [&] {
iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
if (!result) {
return;
}
char* self_data = data[0];
for (C10_UNUSED const auto i : c10::irange(dim_size)) {
if (isnan_(c10::load<scalar_t>(self_data))) {
result = false;
return;
}
self_data += strides[0];
}
});
});
return result.load();
}

std::atomic<bool> result{true};
Expand Down
5 changes: 5 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6351,6 +6351,11 @@ def test_equal(self):
self.assertNotEqual(t_0.size(), t_1.size())
self.assertFalse(torch.equal(t_0, t_1))

# Fast path: tensor containing `nan` is not equal to self
for dtype in floating_and_complex_types():
t = torch.tensor([1., float('nan')], dtype=dtype)
self.assertFalse(torch.equal(t, t))

def test_element_size(self):
byte = torch.ByteStorage().element_size()
char = torch.CharStorage().element_size()
Expand Down

0 comments on commit 7709382

Please sign in to comment.