21 changes: 16 additions & 5 deletions test/functorch/test_ops.py
@@ -253,7 +253,6 @@ def wrapped(*args):
return wrapped, args



def get_jvp_variant(f, sample):
# We want this higher-order variant of jvp, so that it can
# be used to wrap vmap
@@ -403,9 +402,14 @@ class TestOperators(TestCase):
tol1('masked.cumprod',
{torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1('svd_lowrank',
{torch.float32: tol(atol=3e-05, rtol=3e-04)}, device_type='cuda'),
{torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),
tol1('linalg.tensorsolve',
{torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),
tol1('__rmatmul__',
{torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),
tol1('matmul',
{torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),

))
def test_grad(self, device, dtype, op):
if op.name in vjp_fail:
Expand Down Expand Up @@ -541,7 +545,6 @@ def test_jvp(self, device, dtype, op):
clone_inputs=True,
fixme_ref_jvp_local=fixme_ref_jvp_local)


def jvp_opinfo_test(self, fn, sample, output_process_fn,
clone_inputs, fixme_ref_jvp_local):
# NB: we used requires_grad=True to determine where the primals are,
Expand Down Expand Up @@ -618,8 +621,12 @@ def maybe_clone_inputs():
{torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1('linalg.tensorsolve',
{torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1('linalg.multi_dot',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1('svd_lowrank',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1('pca_lowrank',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
))
def test_vjp(self, device, dtype, op):
if not op.supports_autograd:
Expand Down Expand Up @@ -815,6 +822,8 @@ def fn(inp, *args, **kwargs):
{torch.float32: tol(atol=2e-03, rtol=2e-02)}),
tol1('svd',
{torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1('matrix_exp',
{torch.float32: tol(atol=1e-03, rtol=5e-04)}),
))
@skipOps('TestOperators', 'test_vmapvjpvjp', {
xfail('as_strided', 'partial_views'),
Expand Down Expand Up @@ -935,6 +944,8 @@ def vjp_of_vjp(*args_and_cotangents):
{torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"),
tol1('linalg.householder_product',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1('matrix_exp',
{torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"),
))
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail.union({
xfail('as_strided', 'partial_views'),
Expand Down Expand Up @@ -1553,6 +1564,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('svd',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('matrix_exp',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
))
def test_vmapjvpvjp(self, device, dtype, op):
# Since we test `jvpvjp` separately,
Expand Down Expand Up @@ -1597,7 +1610,6 @@ def jvp_of_vjp(*args):
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)


def _make_extremal_inputs(self, shape, device):
if shape is None:
return (None,)
Expand Down Expand Up @@ -1670,7 +1682,6 @@ def test_extremal_numerics_softmax(self, device):
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(torch.nn.functional.softmax, (cotangents, input))


def test_extremal_numerics_log_softmax(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
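The diff above loosens or adds per-op float32 tolerance entries (`tol1(...)` with `tol(atol=..., rtol=...)`, optionally restricted via `device_type='cuda'`) for ops such as `svd_lowrank`, `matmul`, `__rmatmul__`, `linalg.multi_dot`, `pca_lowrank`, and `matrix_exp`. The snippet below is a minimal standalone sketch of how such override entries could be looked up when comparing results; it is not the PyTorch test-suite implementation, and the `tol1`, `OVERRIDES`, and `lookup_tolerance` definitions here are illustrative assumptions modeled on the entries shown in the diff.

```python
# Standalone sketch (assumed, not PyTorch's actual helpers) of per-op tolerance
# overrides in the style of the tol1(...) entries added in this PR.
from collections import namedtuple

import torch

# Mirrors the tol(atol=..., rtol=...) calls in the diff.
tol = namedtuple('tol', ['atol', 'rtol'])


def tol1(op_name, overrides, *, device_type=None):
    # One override entry: (op name, {dtype: tol}, optional device restriction).
    return (op_name, overrides, device_type)


# Hypothetical table built from a couple of the entries above.
OVERRIDES = (
    tol1('matmul', {torch.float32: tol(atol=3e-04, rtol=3e-04)}, device_type='cuda'),
    tol1('matrix_exp', {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
)


def lookup_tolerance(op_name, dtype, device_type,
                     default=tol(atol=1e-05, rtol=1.3e-06)):
    """Return the (atol, rtol) to use for an op/dtype/device combination."""
    for name, overrides, dev in OVERRIDES:
        if name == op_name and (dev is None or dev == device_type) and dtype in overrides:
            return overrides[dtype]
    return default  # fall back to the framework's default float32 tolerances


# Usage: widen the comparison only where an override applies.
t = lookup_tolerance('matmul', torch.float32, 'cuda')
torch.testing.assert_close(torch.ones(2), torch.ones(2), atol=t.atol, rtol=t.rtol)
```

The design point being illustrated: instead of globally relaxing test tolerances, each entry widens the comparison only for a specific op, dtype, and (optionally) device, so numerically noisy transforms such as `vmapvjp` over `matrix_exp` or batched `matmul` on CUDA can pass without weakening every other check.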