[functorch] test - vmapjvpvjp (#83375)
Adds `vmapjvpvjp` test to `functorch`
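
Conceptually, the test composes three functorch transforms: reverse-mode (`vjp`), forward-mode over it (`jvp`), and then `vmap` over the whole thing, comparing the batched result against a per-example loop. A minimal self-contained sketch of that composition, with a toy function and illustrative names (not the test's actual harness):

```
import torch
from functorch import vjp, jvp, vmap

def f(x):
    return x.sin().sum()

def push_vjp(x, cot):
    # reverse-mode: pull the cotangent `cot` back through f at x
    _, vjp_fn = vjp(f, x)
    return vjp_fn(cot)[0]

def jvp_of_vjp(x, cot, tx, tcot):
    # forward-mode through the reverse-mode computation
    return jvp(push_vjp, (x, cot), (tx, tcot))

# a batch of 5 independent problems
xs, txs = torch.randn(5, 3), torch.randn(5, 3)
cots, tcots = torch.randn(5), torch.randn(5)

batched = vmap(jvp_of_vjp)(xs, cots, txs, tcots)
looped = [jvp_of_vjp(*a) for a in zip(xs, cots, txs, tcots)]
# each slice of `batched` should match the corresponding loop result
```

The test below performs exactly this composition, but over every OpInfo sample and every `vmap` in_dims configuration.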

Runtime of the test:
```
= 856 passed, 250 skipped, 16175 deselected, 137 xfailed, 197 warnings in 2231.84s (0:37:11) =
```
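
The timing above comes from running the new test in isolation; judging by the file path in this diff, something like `pytest functorch/test/test_ops.py -k test_vmapjvpvjp` should reproduce it (exact invocation depends on your checkout).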

Pull Request resolved: #83375
Approved by: https://github.com/zou3519
kshitij12345 authored and pytorchmergebot committed Sep 13, 2022
1 parent b4a881a commit 53c71e2
Showing 1 changed file with 160 additions and 0 deletions: functorch/test/test_ops.py

@@ -1221,6 +1221,166 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)

@skipOps('TestOperators', 'test_vmapjvpvjp', vjp_fail.union({
# Following operators take too long, hence skipped
skip('atleast_1d'),
skip('atleast_2d'),
skip('atleast_3d'),
skip('meshgrid', 'list_of_tensors'),
skip('meshgrid', 'variadic_tensors'),
skip('broadcast_tensors'),
skip('linalg.lstsq'),
skip('nn.functional.bilinear'),
skip('native_layer_norm'),
# Potential bugs/errors
xfail('_masked.cumprod'), # calls item()
xfail('_masked.prod'), # calls item()
xfail('as_strided'), # AssertionError: Tensor-likes are not close!
xfail('as_strided_scatter'), # AssertionError: Tensor-likes are not close!
xfail('bernoulli'), # calls random op
xfail('bfloat16'), # required rank 4 tensor to use channels_last format
xfail('cdist'), # Forward AD not implemented and no decomposition
xfail('chalf'), # required rank 4 tensor to use channels_last format
xfail('cholesky'), # Forward AD not implemented and no decomposition
xfail('cumprod'), # calls item()
xfail('double'), # required rank 4 tensor to use channels_last format
xfail('float'), # required rank 4 tensor to use channels_last format
xfail('half'), # required rank 4 tensor to use channels_last format
xfail('index_reduce'), # Forward AD not implemented and no decomposition
xfail('linalg.eig'), # vmap over torch.allclose isn't supported yet.
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 2 / 120 (1.7%)
# Greatest absolute difference: 0.09438323974609375
# Greatest relative difference: 0.00115722746596277
xfail('linalg.householder_product', device_type='cuda'),
xfail('linalg.vander'), # calls item()
xfail('logcumsumexp'), # Forward AD not implemented and no decomposition
xfail('mvlgamma', 'mvlgamma_p_1'), # vmap: inplace into a regular tensor
xfail('mvlgamma', 'mvlgamma_p_3'), # vmap: inplace into a regular tensor
xfail('mvlgamma', 'mvlgamma_p_5'), # vmap: inplace into a regular tensor
xfail('nanquantile'), # Batching rule not implemented for aten::equal
# RuntimeError: Batch norm got a batched tensor as input while the
# running_mean or running_var, which will be updated in place,
# were not batched.
xfail('nn.functional.batch_norm'),
xfail('nn.functional.batch_norm', 'without_cudnn'),
xfail('nn.functional.binary_cross_entropy'), # vmap: inplace into a regular tensor
xfail('nn.functional.dropout2d'), # calls random op
xfail('nn.functional.dropout3d'), # calls random op
xfail('nn.functional.dropout'), # calls random op
xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op
xfail('nn.functional.fractional_max_pool2d'), # calls random op
xfail('nn.functional.fractional_max_pool3d'), # calls random op
xfail('nn.functional.gaussian_nll_loss'), # data-dependent control flow
xfail('nn.functional.grid_sample'), # Forward AD not implemented and no decomposition
xfail('nn.functional.hardsigmoid'), # Forward AD not implemented and no decomposition
xfail('nn.functional.hinge_embedding_loss'), # vmap: inplace into a regular tensor
xfail('nn.functional.huber_loss'), # Forward AD not implemented and no decomposition
# RuntimeError: Batch norm got a batched tensor as input while the
# running_mean or running_var, which will be updated in place,
# were not batched.
xfail('nn.functional.instance_norm'),
xfail('nn.functional.logsigmoid'), # Forward AD not implemented and no decomposition
# NYI: Tensor.clone(memory_format) inside vmap is only supported with
# memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
xfail('nn.functional.max_pool2d', device_type='cuda'), # AssertionError: Tensor-likes are not close!
xfail('nn.functional.max_unpool2d'),
xfail('nn.functional.max_unpool2d', 'grad'),
xfail('nn.functional.multi_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.multilabel_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.multilabel_soft_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.pdist'), # Forward AD not implemented and no decomposition
xfail('nn.functional.rrelu'), # vmap: we do not yet support aten::rrelu_with_noise.
xfail('nn.functional.soft_margin_loss'), # Forward AD not implemented and no decomposition
xfail('normal'), # calls random op
xfail('normal', 'number_mean'), # calls random op
xfail('pca_lowrank'), # calls random op
xfail('prod'), # Dynamic shape due to aten::nonzero call
xfail('quantile'), # Batching rule not implemented for aten::equal
xfail('renorm'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'amax'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'amin'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'mean'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'prod'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'sum'), # Forward AD not implemented and no decomposition
xfail('segment_reduce', 'lengths'), # Forward AD not implemented and no decomposition
xfail('segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
xfail('sparse.sampled_addmm'), # RuntimeError: Sparse CSR tensors do not have strides
xfail('svd_lowrank'), # calls random op
xfail('symeig'), # Forward AD not implemented and no decomposition
xfail('take'), # vmap: inplace into regular tensor
xfail('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('to_sparse'), # Forward AD not implemented and no decomposition
xfail('view_as_complex'), # RuntimeError: Tensor must have a last dimension with stride 1
}))
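# (`skip` omits an op entirely; `xfail` still runs it and expects failure,
# so a fixed op surfaces as an unexpected pass)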
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride('TestOperators', 'test_vmapjvpvjp', (
tol1('linalg.svd',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('linalg.householder_product',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('linalg.multi_dot',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('svd',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
))
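# (the overrides above loosen assertEqual's closeness check,
# which is |actual - expected| <= atol + rtol * |expected|)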
# linalg.svd and svd in particular need the manual tolerances
def test_vmapjvpvjp(self, device, dtype, op):
# Since we test `jvpvjp` separately, here we only check
# that vmap of `jvpvjp` is correct.
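# Concretely: for each sample we build a flat jvp-of-vjp function and
# compare vmap over it against a per-example fallback loop.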
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return

samples = op.sample_inputs(device, dtype, requires_grad=True)

# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return

for sample in samples:
fn, primals = normalize_op_input_output(op, sample)
result = fn(*primals)
cotangents = tree_map(lambda x: torch.randn_like(x), result)

primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)

if isinstance(primals[0], torch.Tensor) and primals[0].numel() == 0:
# typically the first primal arg is the input; if it has no elements,
# we tend to run into "Expected Tensor but got None"
continue

def push_vjp(primals, cotangents):
# compute the vjp of `fn` at `primals`, then apply it to `cotangents`
_, vjp_fn = vjp(fn, *primals)
return vjp_fn(cotangents)

# flatten everything into one positional-arg list so vmap can batch over each tensor
args, spec = tree_flatten(((primals, cotangents), (primals_tangents, cotangents_tangents)))

def jvp_of_vjp(*args):
(primals, tangents) = tree_unflatten(args, spec)
primals_out, tangents_out = jvp(push_vjp, primals, tangents)

if isinstance(primals_out, torch.Tensor):
return (primals_out, tangents_out)
else:
flat_primals_out, _ = tree_flatten(primals_out)
flat_tangents_out, _ = tree_flatten(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)

is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs)
# compare a per-example fallback loop against vmap over exhaustive in_dims choices
generator = get_fallback_and_vmap_exhaustive(
jvp_of_vjp, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)


def _make_extremal_inputs(self, shape, device):
if shape is None:
return (None,)
