3 changes: 3 additions & 0 deletions aten/src/ATen/TensorUtils.h
@@ -10,6 +10,9 @@

// These functions are NOT in Utils.h, because this file has a dep on Tensor.h

#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
  TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);

namespace at {

// The following are utility functions for checking that arguments
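For illustration, here is a minimal sketch (not part of this diff) of how an op implementation might use the new macro; checked_sqrt and its message are made up, and the condition handed to the macro must evaluate to a boolean tensor because _is_all_true() asserts a kBool input:

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

// Hypothetical op: reject negative elements with a tensor-wide check, then
// take the square root. The macro reduces the boolean condition with
// _is_all_true() and passes the resulting scalar to TORCH_CHECK, so the
// message below surfaces as a RuntimeError in Python if any element fails.
at::Tensor checked_sqrt(const at::Tensor& self) {
  TORCH_CHECK_TENSOR_ALL(
      self.ge(0),
      "checked_sqrt expects every element of `self` to be non-negative");
  return self.sqrt();
}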
1 change: 1 addition & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -250,6 +250,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(swapdims);
OP_DECOMPOSE(take_along_dim);
OP_DECOMPOSE(tensordot);
OP_DECOMPOSE(_test_check_tensor);
OP_DECOMPOSE(tile);
OP_DECOMPOSE2(trapezoid, x);
OP_DECOMPOSE2(trapezoid, dx);
6 changes: 6 additions & 0 deletions aten/src/ATen/functorch/BatchRulesReduceOps.cpp
@@ -20,6 +20,11 @@ Tensor sum_decomp(
  return at::sum(self, range(0, self.dim()), false, dtype);
}

std::tuple<Tensor, optional<int64_t>> _is_all_true_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim) {
  return std::make_tuple(at::_is_all_true(self), nullopt);
}

Tensor mean_decomp(
    const Tensor& self, optional<ScalarType> dtype) {
  return at::mean(self, range(0, self.dim()), false, dtype);
@@ -502,5 +507,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(aminmax, aminmax_batching_rule);
VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
}
}}
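For context, a rough sketch of the batch-rule contract used above (the example op and rule below are made up for illustration): a functorch batch rule receives the physical tensor together with the index of its batch dimension (nullopt if the input carries none) and returns the physical result plus the index of the result's batch dimension. Returning nullopt, as _is_all_true_batch_rule does, marks the output as unbatched, so the reduction runs over the whole underlying tensor, batch dimension included, and a failing element in any batch entry fails the check.

#include <tuple>
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Illustrative rule for an elementwise op: the op does not move or remove the
// batch dimension, so the rule applies the op to the physical tensor and
// passes the batch-dim index through unchanged. _is_all_true_batch_rule above
// instead collapses every dimension and therefore reports nullopt.
std::tuple<at::Tensor, c10::optional<int64_t>> abs_batch_rule_sketch(
    const at::Tensor& self, c10::optional<int64_t> self_bdim) {
  return std::make_tuple(self.abs(), self_bdim);
}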
4 changes: 4 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -2035,6 +2035,10 @@ Tensor all(const Tensor& self, Dimname dim, bool keepdim) {
Tensor& all_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
  reportNYIDimnameOverload("all");
}
Tensor _is_all_true(const Tensor& self) {
  TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
  return self.all();
}
Tensor logcumsumexp(const Tensor& self, Dimname dim) {
  return at::logcumsumexp(self, dimname_to_position(self, dim));
}
5 changes: 5 additions & 0 deletions aten/src/ATen/native/TestOps.cpp
@@ -106,6 +106,11 @@ Tensor _test_autograd_multiple_dispatch_view(const Tensor &self) {
  return self.view(-1);
}

Tensor _test_check_tensor(const Tensor& self) {
  TORCH_CHECK_TENSOR_ALL(self, "Test message for TORCH_CHECK_TENSOR_ALL");
  return self.clone();
}

} // namespace native

namespace functionalization {
9 changes: 9 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -617,6 +617,15 @@
- func: affine_grid_generator_backward(Tensor grad, int[] size, bool align_corners) -> Tensor
  variants: function

- func: _is_all_true(Tensor self) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _is_all_true

# Note: this function is only for testing.
- func: _test_check_tensor(Tensor self) -> Tensor
@kurtamohler (Collaborator, PR author), Dec 13, 2022:
I added this private op for testing. Is that alright? I couldn't think of any other way to test TORCH_CHECK_TENSOR_ALL on batched tensors, since we don't currently have a good way to do it in C++.

Contributor:
Seems fine.

One comment is that we should add _test_check_tensor to BatchRulesDecompositions.cpp:

What is going on right now is that _test_check_tensor falls through to the vmap fallback, and the fallback runs the function in a for-loop. That will not always be the case: if TORCH_CHECK_TENSOR_ALL is used inside a CompositeImplicitAutograd operation that gets decomposed by vmap, then I think it will error out. Adding _test_check_tensor to BatchRulesDecompositions.cpp tests the case where TORCH_CHECK_TENSOR_ALL is used inside a CompositeImplicitAutograd operation.

@kurtamohler (Collaborator, PR author), Jan 17, 2023:
When I add these, I get these errors:

FAIL: test_register_functorch_batched_decomposition_[aten::_is_all_true] (__main__.TestFunctorchDispatcher)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/work2/kurtamohler/development/pytorch-1/torch/testing/_internal/common_utils.py", line 252, in instantiated_test
    test(self, **param_kwargs)
  File "/work2/kurtamohler/development/pytorch-1/torch/testing/_internal/common_utils.py", line 391, in test_wrapper
    return test(*args, **kwargs)
  File "/work2/kurtamohler/development/pytorch-1/test/functorch/test_vmap_registrations.py", line 391, in test_register_functorch_batched_decomposition
    assert registration in CompositeImplicitAutogradRegistrations, (
AssertionError: The registrations in BatchedDecompositions.cpp must be for CompositeImplicitAutograd operations. If your operation aten::_is_all_true is not CompositeImplicitAutograd, then please register it to the FuncTorchBatched key in another file.

======================================================================
FAIL: test_register_functorch_batched_decomposition_[aten::_test_check_tensor] (__main__.TestFunctorchDispatcher)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/work2/kurtamohler/development/pytorch-1/torch/testing/_internal/common_utils.py", line 252, in instantiated_test
    test(self, **param_kwargs)
  File "/work2/kurtamohler/development/pytorch-1/torch/testing/_internal/common_utils.py", line 391, in test_wrapper
    return test(*args, **kwargs)
  File "/work2/kurtamohler/development/pytorch-1/test/functorch/test_vmap_registrations.py", line 391, in test_register_functorch_batched_decomposition
    assert registration in CompositeImplicitAutogradRegistrations, (
AssertionError: The registrations in BatchedDecompositions.cpp must be for CompositeImplicitAutograd operations. If your operation aten::_test_check_tensor is not CompositeImplicitAutograd, then please register it to the FuncTorchBatched key in another file.

I'm not sure what to do. Do we not actually need this OP_DECOMPOSE, since the functions are marked CompositeExplicitAutograd? For now, I've removed the OP_DECOMPOSE for _test_check_tensor and _is_all_true to get CI to pass.

Contributor:
_test_check_tensor should be CompositeImplicitAutograd, and _is_all_true should be CompositeExplicitAutograd. The OP_DECOMPOSE should be for _test_check_tensor. Does that configuration not work?

@kurtamohler (Collaborator, PR author):
Thanks, that works

  variants: function

- func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: all.out
79 changes: 79 additions & 0 deletions test/functorch/test_vmap.py
@@ -16,6 +16,7 @@
import itertools
import warnings
import unittest
import random
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
@@ -4903,6 +4904,83 @@ def f(x):
        with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
            transform(input)

class TestVmapDeviceType(Namespace.TestVmapBase):
    def _vmap_test(self, *args, **kwargs):
        return _vmap_test(self, *args, **kwargs)

    def test__is_all_true(self, device):
        def test():
            def f(x, *, expected_result):
                result = torch.ops.aten._is_all_true(x)
                self.assertFalse(torch._C._functorch.is_batchedtensor(result))
                self.assertEqual(result.shape, torch.Size([]))
                self.assertEqual(result.item(), expected_result)
                return result

            x = torch.rand(10, device=device)
            vmap(f)(x >= 0, expected_result=True)
            vmap(f)(x < 0, expected_result=False)

            x[random.choice(range(10))] *= -1
            vmap(f)(x >= 0, expected_result=False)
            vmap(f)(x < 0, expected_result=False)

            x = -torch.rand(10, device=device)
            vmap(f)(x > 0, expected_result=False)
            vmap(f)(x <= 0, expected_result=True)

        check_vmap_fallback(self, test, torch._is_all_true)

    def test_check_tensor(self, device):
        def test():
            test_sizes = [
                (1,),
                (10,),
                (1, 1),
                (1, 10),
                (10, 1),
                (10, 10),
                (1, 1, 1),
                (10, 1, 1),
                (1, 10, 1),
                (10, 10, 10),
            ]

            def check_gte_0(t):
                return torch._test_check_tensor(t >= 0)

            error_message = "Test message for TORCH_CHECK_TENSOR_ALL"

            for size in test_sizes:
                t_all_gte_0 = torch.rand(size, device=device)
                t_all_lt_0 = t_all_gte_0 - 1

                vmap(check_gte_0)(t_all_gte_0)

                if len(size) >= 2:
                    vmap(vmap(check_gte_0))(t_all_gte_0)

                with self.assertRaisesRegex(RuntimeError, error_message):
                    vmap(check_gte_0)(t_all_lt_0)

                if len(size) >= 2:
                    with self.assertRaisesRegex(RuntimeError, error_message):
                        vmap(vmap(check_gte_0))(t_all_lt_0)

                if t_all_gte_0.numel() > 1:
                    t_all_gte_0_but_one = t_all_gte_0.clone()
                    idx = (random.choice(range(dim_size)) for dim_size in size)
                    t_all_gte_0_but_one[(..., *idx)] = -1

                    with self.assertRaisesRegex(RuntimeError, error_message):
                        vmap(check_gte_0)(t_all_gte_0_but_one)

                    if len(size) >= 2:
                        with self.assertRaisesRegex(RuntimeError, error_message):
                            vmap(vmap(check_gte_0))(t_all_gte_0_but_one)

        check_vmap_fallback(self, test, torch._test_check_tensor)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)

@@ -4913,6 +4991,7 @@ def f(x):
)
instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)
instantiate_device_type_tests(TestVmapDeviceType, globals(), only_for=only_for)

if __name__ == '__main__':
    run_tests()
34 changes: 34 additions & 0 deletions test/test_torch.py
@@ -740,6 +740,40 @@ def test_scalar_check(self, device):
        self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
        self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)

    # Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
    def test_check_tensor(self, device):
        test_sizes = [
            (),
            (1,),
            (10,),
            (1, 1),
            (1, 10),
            (10, 1),
            (10, 10),
            (1, 1, 1),
            (10, 1, 1),
            (1, 10, 1),
            (10, 10, 10),
        ]
        for size in test_sizes:
            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)

            # Should not raise error
            torch._test_check_tensor(t_all_true)

            with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
                torch._test_check_tensor(t_all_false)

            if t_all_true.numel() > 1:
                t_all_true_but_one = t_all_true.clone()
                # Choose a random element to set to false
                idx = (random.choice(range(dim_size)) for dim_size in size)
                t_all_true_but_one[(..., *idx)] = False

                with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
                    torch._test_check_tensor(t_all_true_but_one)

    # Uses mismatched arange out size to trigger a warning
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering")
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -278,6 +278,9 @@
- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  output_differentiability: [False]

- name: _is_all_true(Tensor self) -> Tensor
  self: non_differentiable

- name: all(Tensor self) -> Tensor
  output_differentiability: [False]

1 change: 1 addition & 0 deletions torch/overrides.py
@@ -292,6 +292,7 @@ def get_ignored_functions() -> Set[Callable]:
        Tensor._conj_physical,
        Tensor._neg_view,
        Tensor._is_zerotensor,
        Tensor._is_all_true,
        Tensor._addmm_activation,
        Tensor.to_padded_tensor,
    }