Gelu Backward, Contribution from Kevin Stephano (#58249)
Summary: Pull Request resolved: #58249

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D28425381

Pulled By: Krovatkin

fbshipit-source-id: 21b7ac972220b6c35b285e3b66f05eb392002408
Krovatkin authored and facebook-github-bot committed May 13, 2021
1 parent 3a898c2 commit d304bb0
Showing 4 changed files with 13 additions and 9 deletions.
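For background on the op the commit title refers to (a sketch only, not code from this commit): the exact GELU is gelu(x) = x * Φ(x), so its backward scales the incoming gradient by Φ(x) + x * φ(x), where Φ and φ are the standard normal CDF and PDF. A minimal check of that formula against autograd:

```python
# Background sketch only; not part of this diff.
# gelu(x) = x * Phi(x)  =>  gelu'(x) = Phi(x) + x * phi(x)
import math
import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.gelu(x)                      # exact (erf-based) GELU
(grad_autograd,) = torch.autograd.grad(y.sum(), x)

Phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))            # standard normal CDF
phi = torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)     # standard normal PDF
grad_manual = Phi + x * phi

print(torch.allclose(grad_autograd, grad_manual))    # expected: True
```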
test/test_ops.py: 2 additions & 5 deletions
@@ -362,9 +362,6 @@ def _test_inplace_preserve_storage(samples, variants):
     def test_variant_consistency_jit(self, device, dtype, op):
         _requires_grad = op.supports_autograd and (dtype.is_floating_point or
                                                    op.supports_complex_autograd(torch.device(device).type))
-        # TODO: fix this
-        if _requires_grad and not op.supports_gradgrad:
-            self.skipTest("skipped! This test does not handle ops that don't support gragrad properly")

         samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)

@@ -404,7 +401,7 @@ def out_fn(output):
                                     out_fn,
                                     (sample.input,) + sample.args,
                                     sample.kwargs,
-                                    no_grad=not _requires_grad)
+                                    no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

             # Check traced forward, grad, and grad grad
             traced_fn = create_traced_fn(self, variant)
@@ -414,7 +411,7 @@ def out_fn(output):
                                     out_fn,
                                     (sample.input,) + sample.args,
                                     sample.kwargs,
-                                    no_grad=not _requires_grad)
+                                    no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

             # Check alias annotation schema for correctness (make
             # sure inputs that aren't supposed to be modified aren't)
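The change above stops skipping the whole test for ops that lack double-backward support and instead passes the op's supports_gradgrad flag through to check_against_reference. For context, a rough sketch (assumed, not taken from the test suite) of what the "grad grad" step exercises:

```python
# Minimal illustration of a second-order (grad-of-grad) check; ops declared with
# supports_gradgrad=False would fail at the second torch.autograd.grad call.
import torch

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = torch.tanh(x).sum()

(g,) = torch.autograd.grad(y, x, create_graph=True)  # first-order grad, graph kept
(gg,) = torch.autograd.grad(g.sum(), x)              # second-order grad ("grad grad")
print(gg)
```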
torch/testing/_internal/common_jit.py: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ def check_output_types(self, func, ref_outputs, args, kwargs):
         ])

     def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
-                                allow_unused=True, check_types=True, no_grad=False):
+                                allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
         kwargs = kwargs if kwargs else {}

         def allSum(vs):
@@ -104,7 +104,7 @@ def get_recording_tensors(args):
         self.assertEqual(outputs, outputs_test)
         self.assertEqual(grads, grads_test)
         # test the grad grad case
-        if self._testMethodName in nn_functional_single_grad:
+        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
             return

         outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
torch/testing/_internal/common_methods_invocations.py: 6 additions & 2 deletions
@@ -1104,7 +1104,9 @@ def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs):

     samples = []
     for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
-        for p in [0, 1, 2, 3, 0.5, 1.5, 2.5, float("inf")]:
+        # FIXME add an override for JIT and revert 0. back to 0
+        # since it's accepted by eager
+        for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]:
             for t1_size, t2_size in test_cases:
                 # The args should never be non-contiguous as this is not supported in the backward
                 samples.append(SampleInput(
@@ -4102,6 +4104,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypes=floating_types(),
            supports_out=False,
            supports_gradgrad=False,
+           assert_autodiffed=False,
            sample_inputs_func=sample_inputs_cdist),
     UnaryUfuncInfo('ceil',
                    ref=np.ceil,
@@ -4995,13 +4998,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_max_min_binary,),
     OpInfo('nn.functional.hardswish',
+           aten_name="hardswish",
            supports_autograd=True,
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_hardswish,
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            supports_gradgrad=False,
            supports_out=False,
-           autodiff_fusible_nodes=["aten::hardswish"]),
+           autodiff_nonfusible_nodes=["aten::hardswish"]),
     OpInfo('topk',
            dtypes=all_types(),
            dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
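A small usage sketch of the sample_inputs_cdist change above: the p values are written as floats (0., 1., ...) so the scripted signature is satisfied, and eager torch.cdist accepts them just the same (the FIXME notes that integer 0 should eventually be restored).

```python
# Illustration only; mirrors the p values used by sample_inputs_cdist above.
import torch

x1 = torch.randn(3, 4)
x2 = torch.randn(5, 4)
for p in (0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")):
    d = torch.cdist(x1, x2, p=p)
    print(p, tuple(d.shape))   # every p yields a (3, 5) pairwise-distance matrix
```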
torch/testing/_internal/jit_metaprogramming_utils.py: 3 additions & 0 deletions
@@ -483,6 +483,9 @@ def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='m
     call = get_call(method_name, func_type, actuals, kwargs)
     script = script_template.format(', '.join(formals), call)
     CU = torch.jit.CompilationUnit(script)
+    # to clean up IR
+    torch._C._jit_pass_inline(CU.the_method.graph)
+    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
     torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)

 def get_nn_module_name_from_kwargs(**kwargs):
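The two passes added above are run on a TorchScript graph before the alias-annotation check. A standalone sketch of the same calls (the scripted function here is an arbitrary example, not from the test suite):

```python
# Illustration of the cleanup passes used above, applied to any scripted graph.
import torch

@torch.jit.script
def f(x):
    return torch.nn.functional.gelu(x) * 2.0

graph = f.graph
torch._C._jit_pass_inline(graph)                 # inline calls so aten ops appear directly
torch._C._jit_pass_constant_propagation(graph)   # fold constants, simplifying the IR
print(graph)
```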
