Merged
20 changes: 18 additions & 2 deletions test/test_operations.py
@@ -153,7 +153,10 @@ class XlaTestCase(TestCase):

def assertEqualRel(self, out, expected, rel_err=1e-2, abs_err=1e-5):
try:
nan_mask = torch.isnan(expected)
self.assertTrue(torch.equal(nan_mask, torch.isnan(out)))
diff_tensor = (out - expected).abs().float()
diff_tensor[nan_mask] = 0
max_rel_err = torch.max(out.abs(), expected.abs()).float() * rel_err
# Allow higher relative differences as long as we're still below the
# absolute error.
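For context, a minimal standalone sketch (not part of the PR; tensor values are made up) of what the new NaN handling in assertEqualRel does: NaNs must occur at the same positions in out and expected, and those entries are then excluded from the relative-error check.

import torch

# Toy tensors chosen only to illustrate the check.
out = torch.tensor([1.0, float('nan'), 3.0])
expected = torch.tensor([1.0, float('nan'), 3.0001])

nan_mask = torch.isnan(expected)
assert torch.equal(nan_mask, torch.isnan(out))  # NaNs must be in the same places
diff_tensor = (out - expected).abs().float()
diff_tensor[nan_mask] = 0  # NaN positions no longer contribute to the error
print(diff_tensor)  # only the non-NaN mismatch at index 2 remains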
@@ -733,11 +736,24 @@ def test_fn(a):

self.runAtenTest(torch.zeros([4, 4]), test_fn)

def test_max_throw(self):
def test_reduction_zero_dim(self):
self.runAtenTest(torch.rand(2, 0, 4).bool(), lambda x : torch.all(x))
self.runAtenTest(torch.rand(2, 0, 4).bool(), lambda x : torch.any(x))
self.runAtenTest(torch.rand(2, 0, 4), lambda x : torch.sum(x))
self.runAtenTest(torch.rand(2, 0, 4), lambda x : torch.mean(x))
self.runAtenTest(torch.rand(2, 0, 4), lambda x : torch.prod(x))
# min & max throws
xla_device = xm.xla_device()
xla_a = torch.randn(2, 0, 4, device=xla_device)
a = torch.rand(2, 0, 4)
xla_a = a.to(xla_device)
self.assertRaises(RuntimeError, lambda: torch.max(a, dim=1))
self.assertRaises(RuntimeError, lambda: torch.max(a))
self.assertRaises(RuntimeError, lambda: torch.min(a, dim=1))
self.assertRaises(RuntimeError, lambda: torch.min(a))
self.assertRaises(RuntimeError, lambda: torch.max(xla_a, dim=1))
self.assertRaises(RuntimeError, lambda: torch.max(xla_a))
self.assertRaises(RuntimeError, lambda: torch.min(xla_a, dim=1))
self.assertRaises(RuntimeError, lambda: torch.min(xla_a))
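For reference, a short eager-mode sketch (not part of the PR) of the CPU behavior the test above expects XLA to match; exact error messages can vary across PyTorch versions.

import torch

a = torch.rand(2, 0, 4)
print(torch.sum(a))         # tensor(0.)   - the empty sum is 0
print(torch.prod(a))        # tensor(1.)   - the empty product is 1
print(torch.mean(a))        # tensor(nan)  - dividing by zero elements
print(torch.all(a.bool()))  # tensor(True)
print(torch.any(a.bool()))  # tensor(False)

# min/max have no identity element over an empty dimension, so they raise.
try:
    torch.max(a)
except RuntimeError as e:
    print("max raised:", e)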

def test_writeable_tensors_updates(self):

12 changes: 8 additions & 4 deletions torch_xla/csrc/reduction.cpp
@@ -1,5 +1,7 @@
#include "torch_xla/csrc/reduction.h"

#include <cmath>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
@@ -31,7 +33,6 @@ ReductionInfo GetReductionInfo(
rinfo.new_dimensions.push_back(shape.dimensions(i));
}
}
XLA_CHECK_GT(rinfo.element_count, 0);
return rinfo;
}

@@ -73,10 +74,11 @@ xla::XlaOp CreateSummation(
xla::XlaOp result = xla::Reduce(
input, init_value, XlaHelpers::CreateAddComputation(shape.element_type()),
dimensions);
if (scale && rinfo.element_count > 1) {
Collaborator:
Are you sure?
Coming in with count==0 means division by zero below...

Contributor Author (@ailzhang), Sep 5, 2019:
Yeah, NaN is the expected output for this case, in the sense that both NumPy and PyTorch return NaN...

In [1]: import numpy as np

In [2]: import torch

In [3]: a = torch.rand(2, 0, 4)

In [4]: a.mean()
Out[4]: tensor(nan)

In [5]: b = np.random.rand(2, 0, 4)

In [6]: b.mean()
/home/ubuntu/miniconda3/envs/maskrcnn36/bin/ipython:1: RuntimeWarning: Mean of empty slice.
  #!/home/ubuntu/miniconda3/envs/maskrcnn36/bin/python
/home/ubuntu/miniconda3/envs/maskrcnn36/lib/python3.6/site-packages/numpy/core/_methods.py:85: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Out[6]: nan

Collaborator:
What about the line:

1.0f / static_cast<float>(rinfo.element_count)

?

Contributor Author:
Sorry, I didn't get this - what's the issue with that line? It produces a NaN, which is then multiplied into the result.

Collaborator:
Division by zero is undefined behavior, and what happens often depends on compiler options (whether to SIGFPE or not).
So better not to rely on it returning NaN.

Contributor Author:
Ehh, what's a better way to get NaN in this case?

Contributor:
Not sure what the most reliable way of constructing a NaN in XLA is. I assume ConstantR0 with a NaN float value should work.

Collaborator:
ScalarValue already does that.
I am not sure how XLA "feels" about a NaN uploaded to the TPU device. We can check.
In any case, that should be:

rinfo.element_count > 0 ? 1.0f / static_cast<float>(rinfo.element_count) : NAN
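A hedged Python analogue of the arithmetic being proposed (illustration only; the actual fix below stays in C++): with zero elements the scale becomes NaN, so the zero-valued sum multiplied by it yields NaN, matching what torch.mean returns for an empty tensor.

import math

def scaled_sum(values):
    total = sum(values)  # reduce with identity 0, like the xla::Reduce above
    count = len(values)
    scale = 1.0 / count if count > 0 else math.nan  # mirrors the guarded expression
    return total * scale

print(scaled_sum([1.0, 2.0, 3.0]))  # 2.0
print(scaled_sum([]))               # nan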

if (scale) {
xla::XlaOp scale = XlaHelpers::ScalarValue<float>(
1.0f / static_cast<float>(rinfo.element_count), shape.element_type(),
input.builder());
rinfo.element_count > 0 ? 1.0f / static_cast<float>(rinfo.element_count)
: NAN,
Collaborator:
I'd #include <cmath> and run clang-format.

Contributor Author:
Hmmm, clang-format seems to be happy with this line ;)
Btw, somehow it builds and passes the tests without cmath... but I added it anyway.

Collaborator:
Oh, I know it compiles, but it relies on some other header to transitively pull in the proper include, which is wrong.
If that header stops including cmath, we will suddenly break.

shape.element_type(), input.builder());
result = xla::Mul(result, scale);
}
if (keep_reduced_dimensions) {
@@ -146,6 +148,7 @@ xla::XlaOp BuildMaxInDim(const xla::XlaOp& input, xla::int64 dim,
xla::XlaOp init_value = XlaHelpers::ScalarValue(
min_max.min, shape.element_type(), input.builder());
ReductionInfo rinfo = GetReductionInfo(shape, {dim}, keep_reduced_dimensions);
XLA_CHECK_GT(rinfo.element_count, 0);
xla::XlaOp result = xla::Reduce(
input, init_value, XlaHelpers::CreateMaxComputation(shape.element_type()),
{dim});
@@ -162,6 +165,7 @@ xla::XlaOp BuildMinInDim(const xla::XlaOp& input, xla::int64 dim,
xla::XlaOp init_value = XlaHelpers::ScalarValue(
min_max.max, shape.element_type(), input.builder());
ReductionInfo rinfo = GetReductionInfo(shape, {dim}, keep_reduced_dimensions);
XLA_CHECK_GT(rinfo.element_count, 0);
xla::XlaOp result = xla::Reduce(
input, init_value, XlaHelpers::CreateMinComputation(shape.element_type()),
{dim});