9 changes: 5 additions & 4 deletions test/test_ops_gradients.py
@@ -6,6 +6,7 @@
from torch.testing._internal.common_utils import TestGradients, run_tests
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)

@@ -18,7 +19,7 @@

class TestBwdGradients(TestGradients):
# Tests that gradients are computed correctly
@_gradcheck_ops(op_db + control_flow_opinfo_db)
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
def test_fn_grad(self, device, dtype, op):
# This is verified by test_dtypes in test_ops.py
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@@ -33,7 +34,7 @@ def test_fn_grad(self, device, dtype, op):
# self._skip_helper(op, device, dtype)
# self._grad_test_helper(device, dtype, op, op.get_method())

@_gradcheck_ops(op_db)
@_gradcheck_ops(op_db + custom_op_db)
def test_inplace_grad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.inplace_variant:
@@ -52,7 +53,7 @@ def test_inplace_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db + control_flow_opinfo_db)
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.supports_gradgrad:
@@ -61,7 +62,7 @@ def test_fn_gradgrad(self, device, dtype, op):
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

# Test that gradients of gradients properly raise an error when unsupported
@_gradcheck_ops(op_db)
@_gradcheck_ops(op_db + custom_op_db)
def test_fn_fail_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if op.supports_gradgrad:
282 changes: 281 additions & 1 deletion test/test_python_dispatch.py
@@ -730,7 +730,7 @@ def foo2(x: torch.Tensor) -> torch.Tensor:

def test_private_ctor(self):
with self.assertRaisesRegex(RuntimeError, 'CustomOp constructor is private'):
CustomOp(None, None, None, None)
CustomOp(None, None, None, None, None)

def test_lifetime(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
@@ -789,6 +789,21 @@ def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
foo(y, x)
foo._destroy()

def test_autograd_notimplemented_gradmode(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu'])
def foo_impl(x, y):
return x * y

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
with torch.no_grad():
# Shouldn't raise, because we are in no_grad
foo(y, x)
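
For contrast, a minimal sketch of the path this test guards against — assuming, as the test name and comment suggest, that calling the op outside of no_grad with an input that requires grad raises a RuntimeError because no autograd formula was registered (that raising case itself is not shown in this diff):

# Hypothetical companion check, not part of this PR.
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
try:
    foo(y, x)  # presumably raises: no backward was registered via impl_backward
except RuntimeError:
    pass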

def test_impl_cpu(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
@@ -824,6 +839,271 @@ def foo_impl(x):
foo.impl(invalid_type)(foo_impl)
foo._destroy()

def test_backward_partially_registered(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return grad * saved.cos()

x = torch.randn([], requires_grad=True)
with self.assertRaisesRegex(RuntimeError, "unable to find a 'save_for_backward'"):
y = foo(x)
y.backward()

def test_save_for_backward_inputs_are_namedtuple(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()

hit = 0

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
nonlocal hit
hit += 1
self.assertTrue(isinstance(inputs, tuple))
self.assertEqual(list(inputs._asdict().keys()), ['x'])
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos()}

x = torch.randn([], requires_grad=True)
y = foo(x)
self.assertEqual(hit, 1)
y.backward()
self.assertEqual(hit, 1)

def test_backward_returns_dict(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return grad * saved.cos()

x = torch.randn([], requires_grad=True)
y = foo(x)
with self.assertRaisesRegex(RuntimeError, 'to be a dict'):
y.backward()

def test_backward_dict_invalid_keys(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos(), 'y': None}

x = torch.randn([], requires_grad=True)
y = foo(x)
with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
y.backward()

def test_backward_dict_grad_for_nontensor(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x, dim):
return x.sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos(), 'dim': None}

x = torch.randn([], requires_grad=True)
y = foo(x, 32)
with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
y.backward()

def test_backward_dict_requires_keys_for_input_tensors(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x, y):
return x.sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos()}

x = torch.randn([], requires_grad=True)
y = foo(x, x)
with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
y.backward()

def test_backward_dict_requires_keys_for_input_optional_tensors(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x, y):
return x.sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': grad * saved.cos()}

x = torch.randn([], requires_grad=True)
y = foo(x, None)
with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
y.backward()

def test_backward_grads_are_tensor_or_none(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
return x.sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.x

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'x': (grad * saved.cos(),)}

x = torch.randn([], requires_grad=True)
y = foo(x)
with self.assertRaisesRegex(RuntimeError, 'either None or a Tensor'):
y.backward()
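
Taken together, the dict-returning tests above pin down the expected contract: impl_backward must return a dict with one key per Tensor (or Optional[Tensor]) input, no keys for non-Tensor inputs such as dim, and each value must be either a Tensor or None. A sketch of a registration that should satisfy that contract for the earlier two-Tensor foo(x, y) signature (an illustration, not code from this PR):

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
    # One entry per Tensor-like input; None presumably marks an input
    # that receives no gradient.
    return {'x': grad * saved.cos(), 'y': None}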

def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(xs):
return xs[0].sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.xs[0]

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'xs': [grad * saved.cos(), None]}

xs = [torch.randn([], requires_grad=True) for _ in range(3)]
y = foo(xs)
with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
y.backward()

def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(xs):
return xs[0].sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.xs[0]

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'xs': [grad * saved.cos(), None, (None,)]}

xs = [torch.randn([], requires_grad=True) for _ in range(3)]
y = foo(xs)
with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
y.backward()

def test_backward_tensorlist_input_requires_list_grads(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...

@foo.impl(['cpu', 'cuda'])
def foo_impl(xs):
return xs[0].sin()

@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
return inputs.xs[0]

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
return {'xs': None}

xs = [torch.randn([], requires_grad=True) for _ in range(3)]
y = foo(xs)
with self.assertRaisesRegex(RuntimeError, "list of gradients"):
y.backward()
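
The three TensorList cases above only exercise failure paths. Based on their error messages, the accepted form is presumably a list with one entry per tensor in xs, each entry a Tensor or None — for example (a sketch, not part of this diff):

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
    # len(xs) == 3 in these tests, so three entries are expected.
    return {'xs': [grad * saved.cos(), None, None]}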

def test_backward_output_differentiability_type(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
...

with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@foo.impl_backward(output_differentiability=True)
def foo_backward(ctx, saved, grad):
return {'xs': None}

def test_backward_output_differentiability_numel(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
...

with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@foo.impl_backward(output_differentiability=[True])
def foo_backward(ctx, saved, grad):
return {'xs': None}
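
Both output_differentiability tests only assert the rejection paths; the accepted form is presumably a list of booleans with one entry per output, e.g. for the two-output variant above (a hypothetical sketch):

@foo.impl_backward(output_differentiability=[True, False])
def foo_backward(ctx, saved, grad):
    ...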

@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_impl_separate(self):
@custom_op(f'{TestCustomOp.test_ns}::foo')