Merge branch 'pytorch:main' into fix-resolve_neg-docstring
kato8966 committed Jun 25, 2023
2 parents a95ab54 + 86e0eda commit 54326ab
Showing 6 changed files with 79 additions and 42 deletions.
aten/src/ATen/native/cuda/Blas.cpp (4 changes: 2 additions & 2 deletions)
@@ -335,7 +335,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         at::relu_(const_cast<Tensor&>(*result_));
         break;
       case Activation::GELU:
-        at::gelu_(const_cast<Tensor&>(*result_));
+        at::gelu_(const_cast<Tensor&>(*result_), "tanh");
         break;
       default: break;
     }
@@ -347,7 +347,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // epilogue above.
 #if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
   if (useLtInterface && activation == Activation::GELU) {
-    at::gelu_(const_cast<Tensor&>(*result_));
+    at::gelu_(const_cast<Tensor&>(*result_), "tanh");
   }
 #endif

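Note: with this change the non-cuBLASLt fallback applies the same tanh-approximated GELU that the fused cuBLASLt epilogue computes. A minimal sketch (not part of the commit; `x` is an illustrative input) of how the two GELU variants relate on the Python side:

```python
# Minimal sketch: the tanh-approximated GELU that at::gelu_(..., "tanh") applies
# versus the default erf-based GELU. The gap is small but nonzero.
import torch
import torch.nn.functional as F

x = torch.randn(64, 128)                      # illustrative input
exact = F.gelu(x, approximate="none")         # erf-based GELU (default)
tanh_approx = F.gelu(x, approximate="tanh")   # matches the fused epilogue's GELU
print((exact - tanh_approx).abs().max())      # small but nonzero difference
```

Because the fallback and the fused path now agree, the tests below compare against the tanh variant on CUDA instead of special-casing which path ran.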
test/run_test.py (2 changes: 1 addition & 1 deletion)
@@ -983,7 +983,7 @@ def get_pytest_args(
     if not is_cpp_test:
         # C++ tests need to be run with pytest directly, not via python
         pytest_args.extend(["-p", "no:xdist", "--use-pytest"])
-        if not options.continue_through_error and HAVE_TEST_SELECTION_TOOLS:
+        if not options.continue_through_error and IS_CI and HAVE_TEST_SELECTION_TOOLS:
             pytest_args.append(f"--sc={stepcurrent_key}")
     else:
         # Use pytext-dist to run C++ tests in parallel as running them sequentially using run_test
test/test_linalg.py (32 changes: 8 additions & 24 deletions)
@@ -5415,35 +5415,19 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=
         else:
             f(t, m, v, alpha=alpha, beta=beta, out=res2)
         res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
-        res1_fused_epilogue = (t.is_cuda and t.dim() == 1 and beta == 1)
-        if TEST_WITH_ROCM or IS_WINDOWS or _get_torch_cuda_version() < (11, 8):
-            # epilogue fusion enabled only on CUDA >= 11.8
-            res1_fused_epilogue = False
-        res2_fused_epilogue = res1_fused_epilogue and res2.is_contiguous()
         if beta != 0:
             res3 += (beta * t).to(numpy_dtype).cpu().numpy()
         if activation == "relu":
             res3 = res3 * (res3 > 0)
         elif activation == "gelu":
             res3_t = torch.from_numpy(res3).to(dtype)
-            approximate = "none"
-            if res1_fused_epilogue:
-                # fused GELU epilogue used in CUDA utilizes
-                # the tanh approximation to compute GELU
-                approximate = "tanh"
+            approximate = "tanh" if t.is_cuda else "none"
             res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
             res3 = res3_t.to(numpy_dtype).cpu().numpy()
         else:
             assert activation is None, f"unsupported activation {activation}"
         res3 = torch.from_numpy(res3).to(dtype)
-        if activation == "gelu" and res1_fused_epilogue and not res2_fused_epilogue:
-            # when out=res2 is transposed (not contiguous), the epilogue is unfused;
-            # in this case, when the activation is GELU and res1's epilogue is fused,
-            # the difference between res1 and res2 will be larger due to the tanh
-            # approximation of GELU in res1 computation, but not in res2
-            self.assertEqual(res1, res2, atol=1e-3, rtol=0)
-        else:
-            self.assertEqual(res1, res2)
+        self.assertEqual(res1, res2)
         self.assertEqual(res1, res3)

     @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8,
@@ -5563,19 +5547,19 @@ def maybe_transpose(cond, m):
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)

-    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
-                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
+                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_types_and(
-        *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
+        *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
     @dtypes(*floating_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
     def test_addmm_relu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

-    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
-                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
+                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_types_and(
-        *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
+        *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
     @dtypes(*floating_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
     def test_addmm_gelu(self, device, dtype):
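The updated test now compares `torch._addmm_activation` against a reference built with the tanh-approximated GELU whenever the tensors live on CUDA. A minimal standalone sketch of that comparison (assumes a CUDA device is available; `torch._addmm_activation` is a private API, and the shapes and tolerances here are illustrative):

```python
# Minimal sketch of the comparison the updated test performs for the GELU path.
import torch
import torch.nn.functional as F

t = torch.randn(40, device="cuda")        # 1-D bias: the epilogue-eligible case
m = torch.randn(20, 30, device="cuda")
v = torch.randn(30, 40, device="cuda")

fused = torch._addmm_activation(t, m, v, use_gelu=True)       # may hit the cuBLASLt epilogue
reference = F.gelu(torch.addmm(t, m, v), approximate="tanh")  # tanh-approximated reference

torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-4)
```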
test/test_modules.py (31 changes: 27 additions & 4 deletions)
@@ -456,10 +456,6 @@ def _test_gradients_helper(self, device, dtype, module_info, training, check):
             else:
                 other_kwargs[name] = obj

-        grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
-
-        flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-
         def fn_to_gradcheck(*flat_input_and_params):
             input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
             new_input_args = input_and_params[:len(input_args)]
@@ -471,8 +467,35 @@ def fn_to_gradcheck(*flat_input_and_params):
             output_flattened, _ = torch.utils._pytree.tree_flatten(output)
             return output_flattened

+        # check total derivative
+        grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
+        flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
+
         self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

+        # check partial derivatives
+        old_params_requires_grad = [p.requires_grad for p in params]
+        for p in params:
+            p.requires_grad = False
+
+        old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
+        for (_, obj) in kwarg_tensors:
+            obj.requires_grad = False
+
+        for p, old in zip(params, old_params_requires_grad):
+            p.requires_grad = old
+            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
+            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
+            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            p.requires_grad = False
+
+        for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
+            obj.requires_grad = old
+            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
+            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
+            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            obj.requires_grad = False
+
     @modules(module_db, allowed_dtypes=[torch.double])
     def test_grad(self, device, dtype, module_info, training):
         self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
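The new block in `_test_gradients_helper` runs gradcheck once over everything (the total derivative) and then again with only one parameter group requiring grad at a time (partial derivatives). A minimal standalone sketch of the same pattern on a `torch.nn.Linear` (illustrative module and shapes, not taken from the test suite):

```python
# Minimal sketch: total-derivative gradcheck, then per-parameter partial checks,
# mirroring the requires_grad toggling in the new test code.
import torch
from torch.autograd import gradcheck

module = torch.nn.Linear(3, 2).double()
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
weight, bias = module.weight, module.bias

def fn(inp, w, b):
    # forward pass written over the flat gradcheck inputs, like fn_to_gradcheck
    return torch.nn.functional.linear(inp, w, b)

# check total derivative: input and all parameters require grad
assert gradcheck(fn, (x, weight, bias))

# check partial derivatives: freeze everything, re-enable one parameter at a time
for p in (weight, bias):
    p.requires_grad_(False)
for p in (weight, bias):
    p.requires_grad_(True)
    assert gradcheck(fn, (x, weight, bias))   # gradcheck only perturbs tensors that require grad
    p.requires_grad_(False)
for p in (weight, bias):
    p.requires_grad_(True)   # restore the original state
```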
torch/_inductor/lowering.py (48 changes: 37 additions & 11 deletions)
@@ -430,11 +430,14 @@ def has_type_promotion(*args):
         if len(args) < 2:
             return False
         else:
-            dtype = args[0].data.get_dtype()
-            return any(
-                isinstance(t, TensorBox) and t.data.get_dtype() != dtype
-                for t in args
-            )
+            dtype = None
+            for t in args:
+                if isinstance(t, TensorBox):
+                    if dtype is None:
+                        dtype = t.data.get_dtype()
+                    elif dtype != t.data.get_dtype():
+                        return True
+            return False

     # group by device, whether any of the inputs are dynamic, and whether their types match
     # (proxy for type promotion)
@@ -445,7 +448,15 @@ def group_args(arg_pairs):
         out = defaultdict(list)
         for i, args in enumerate(arg_pairs):
             use_foreach = not (is_dynamic(*args) or has_type_promotion(*args))
-            out[(args[0].get_device(), use_foreach)].append((i, args))
+            device = None
+            for t in args:
+                if isinstance(t, TensorBox):
+                    device = t.data.get_device()
+                    break
+            assert (
+                device is not None
+            ), "foreach op should have at least one tensor arg"
+            out[(device, use_foreach)].append((i, args))
         return out

     realize_outputs = False
@@ -454,12 +465,26 @@ def group_args(arg_pairs):
             if not (user.op == "call_function" and user.target in foreach_ops):
                 realize_outputs = True

-    # replicate scalar input to match lenghth of list input
-    if len(inputs) > 1 and not isinstance(inputs[1], (list, tuple)):
-        inputs = (inputs[0], [inputs[1] for _ in inputs[0]])
+    a_list_input = None
+    for input in inputs:
+        if isinstance(input, (list, tuple)):
+            a_list_input = input
+            break
+    assert (
+        a_list_input is not None
+    ), "at least one input must be a list to a foreach op"
+
+    # broadcast scalar inputs to match length of list inputs
+    broadcast_inputs = []
+    for input in inputs:
+        if not isinstance(input, (list, tuple)):
+            broadcast_inputs.append([input] * len(a_list_input))
+        else:
+            broadcast_inputs.append(input)
+
+    groups = group_args(zip(*broadcast_inputs))

-    groups = group_args(zip(*inputs))
-    outputs = [None] * len(inputs[0])
+    outputs = [None] * len(a_list_input)
     for (device, use_foreach), group in groups.items():
         buffer_list = []
         for (
@@ -4229,6 +4254,7 @@ def register_pointwise_numeric_ldf64(op):
 register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
 register_foreach_pointwise(aten._foreach_neg.default, neg)
 register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
+register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
 register_foreach_pointwise(aten._foreach_div.List, div)
 register_foreach_pointwise(aten._foreach_div.Scalar, div)
 register_foreach_pointwise(aten._foreach_sqrt, sqrt)
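The lowering previously assumed `inputs[0]` (and `args[0]` inside the grouping helpers) was always a tensor list, which does not hold for overloads such as `_foreach_pow.ScalarAndTensor`, where the scalar comes first. The rewritten code locates the first list input, broadcasts scalars to its length, and takes the device from the first `TensorBox` argument. A minimal eager-mode sketch of the call shapes the lowering now has to handle (the `_foreach_*` ops are private, and this assumes the ScalarAndTensor overload is exposed in this build):

```python
# Minimal sketch of foreach call patterns: the list argument is not always first.
import torch

xs = [torch.randn(4), torch.randn(4)]

torch._foreach_add(xs, 1.0)   # Scalar overload: list first, scalar second
torch._foreach_pow(xs, 2.0)   # _foreach_pow.Scalar: list first
torch._foreach_pow(2.0, xs)   # _foreach_pow.ScalarAndTensor: scalar first, roughly [2.0 ** x for x in xs]
```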
torch/testing/_internal/common_modules.py (4 changes: 4 additions & 0 deletions)
@@ -1367,6 +1367,10 @@ def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad,
             constructor_input=FunctionInput([5], 1e-3),
             forward_input=FunctionInput(make_input(((4, 5, 5)))),
             desc='1d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input(((128, 5, 5)))),
+            desc='1d_elementwise_affine_large_batch'),
         ModuleInput(
             constructor_input=FunctionInput([5], 1e-3, False),
             forward_input=FunctionInput(make_input(((4, 5, 5)))),
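The added ModuleInput only changes the forward-input shape. A minimal sketch of the configuration it exercises (constructor arguments and shape taken from the entry above):

```python
# Minimal sketch: LayerNorm over a trailing dimension of size 5 with eps=1e-3,
# applied to the larger (128, 5, 5) batch the new entry feeds it.
import torch

layer_norm = torch.nn.LayerNorm([5], eps=1e-3)
out = layer_norm(torch.randn(128, 5, 5))
print(out.shape)  # torch.Size([128, 5, 5])
```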
