diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index f99dc3c07058..a65a48d601dc 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -294,6 +294,7 @@ namespace c10 {
   _(aten, swapdims_) \
   _(aten, movedim) \
   _(aten, moveaxis) \
+  _(aten, has_torch_function) \
   FORALL_ATEN_BASE_SYMBOLS(_) \
   _(onnx, Add) \
   _(onnx, Concat) \
diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp
index d3aa0ba7295a..0dd84e4bb257 100644
--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -1224,8 +1224,11 @@ struct to_ir {
     }
     auto expr_out = emitToBool(expr.range(), emitExpr(expr));
     c10::optional<bool> static_if = c10::nullopt;
-    if (expr_out->node()->kind() == aten::is_scripting) {
+    auto kind = expr_out->node()->kind();
+    if (kind == aten::is_scripting) {
       static_if = true;
+    } else if (kind == aten::has_torch_function) {
+      static_if = false;
     }
     // MetaCompile on boolean literals and constants
     if (auto maybe_ivalue = toIValue(expr_out)) {
diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp
index 28a4136ba829..2cd5a13d3f4b 100644
--- a/torch/csrc/jit/runtime/register_special_ops.cpp
+++ b/torch/csrc/jit/runtime/register_special_ops.cpp
@@ -372,6 +372,10 @@ RegisterOperators reg({
         TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"),
         [](Stack* stack) { push(stack, true); },
         aliasAnalysisFromSchema()),
+    OperatorGenerator(
+        TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"),
+        [](Stack* stack) { push(stack, false); },
+        aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA(
             "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"),
diff --git a/torch/functional.py b/torch/functional.py
index 43fa0a3df546..1d3403e65304 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -69,9 +69,8 @@ def broadcast_tensors(*tensors):
             tensor([[0, 1, 2],
                     [0, 1, 2]])
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function(tensors):
-            return handle_torch_function(broadcast_tensors, tensors, *tensors)
+    if has_torch_function(tensors):
+        return handle_torch_function(broadcast_tensors, tensors, *tensors)
     return _VF.broadcast_tensors(tensors)  # type: ignore


@@ -147,10 +146,9 @@ def split(tensor, split_size_or_sections, dim=0):
              [6, 7],
              [8, 9]]))
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(tensor):
-            return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,
-                                         dim=dim)
+    if has_torch_function_unary(tensor):
+        return handle_torch_function(
+            split, (tensor,), tensor, split_size_or_sections, dim=dim)
     # Overwriting reason:
     # This dispatches to two ATen functions depending on the type of
     # split_size_or_sections.
The branching code is in tensor.py, which we @@ -236,11 +234,11 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): >>> torch.norm(A_ - A) tensor(2.9802e-08) """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(LU_data, LU_pivots): - return handle_torch_function( - lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, unpack_data=unpack_data, - unpack_pivots=unpack_pivots) + if has_torch_function_variadic(LU_data, LU_pivots): + return handle_torch_function( + lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, + unpack_data=unpack_data, + unpack_pivots=unpack_pivots) shape = LU_data.shape # In generalized LU factorization, the following shape relations hold: # A.shape[-2:] == (m, n) @@ -301,7 +299,7 @@ def einsum(equation, *operands): based on the Einstein summation convention. Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them - in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of + in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of this format are described below, but the general idea is to label every dimension of the input :attr:`operands` with some subscript and define which subscripts are part of the output. The output is then computed by summing the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the @@ -387,7 +385,7 @@ def einsum(equation, *operands): # batch permute >>> A = torch.randn(2, 3, 4, 5) - >>> torch.einsum('...ij->...ji', A).shape + >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) # equivalent to torch.nn.functional.bilinear @@ -398,9 +396,8 @@ def einsum(equation, *operands): tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]]) """ - if not torch.jit.is_scripting(): - if has_torch_function(operands): - return handle_torch_function(einsum, operands, equation, *operands) + if has_torch_function(operands): + return handle_torch_function(einsum, operands, equation, *operands) if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument _operands = operands[0] @@ -448,9 +445,8 @@ def meshgrid(*tensors): def _meshgrid(*tensors): - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(meshgrid, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(meshgrid, tensors, *tensors) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): # the old interface of passing the operands as one list argument tensors = tensors[0] # type: ignore @@ -568,12 +564,11 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Tensor: A tensor containing the STFT result with shape described above """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, pad_mode=pad_mode, normalized=normalized, - onesided=onesided, return_complex=return_complex) + if has_torch_function_unary(input): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex) # TODO: after having proper ways to map Python strings to ATen Enum, 
move # this and F.pad to ATen. if center: @@ -650,12 +645,11 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Returns: Tensor: Least squares estimation of the original signal of size (..., signal_length) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, normalized=normalized, onesided=onesided, - length=length, return_complex=return_complex) + if has_torch_function_unary(input): + return handle_torch_function( + istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, + length=length, return_complex=return_complex) return _VF.istft(input, n_fft, hop_length, win_length, window, center, # type: ignore normalized, onesided, length, return_complex) @@ -734,11 +728,10 @@ def _unique_impl(input: Tensor, sorted: bool = True, [ 1, 2]]) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - unique, (input,), input, sorted=sorted, return_inverse=return_inverse, - return_counts=return_counts, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function( + unique, (input,), input, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) if dim is not None: output, inverse_indices, counts = _VF.unique_dim( # type: ignore @@ -810,11 +803,10 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, >>> counts tensor([2, 2, 1, 2, 1]) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - unique_consecutive, (input,), input, return_inverse=return_inverse, - return_counts=return_counts, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function( + unique_consecutive, (input,), input, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore input, return_inverse=return_inverse, return_counts=return_counts, dim=dim) return output, inverse_indices, counts @@ -823,9 +815,8 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output, counts @@ -834,9 +825,8 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output @@ -845,9 +835,8 @@ def _return_output(input, sorted=True, return_inverse=False, 
return_counts=False def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output, inverse_indices @@ -888,9 +877,8 @@ def _return_inverse(input, sorted=True, return_inverse=False, return_counts=Fals def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output, counts @@ -899,9 +887,8 @@ def _consecutive_return_counts(input, return_inverse=False, return_counts=False, def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output @@ -910,9 +897,8 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False, def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output, inverse_indices @@ -1000,9 +986,8 @@ def tensordot(a, b, dims=2, out=None): [ 1.5513, -14.4737, -6.5113], [ -0.2850, 4.2573, -3.5997]]) """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(a, b): - return handle_torch_function(tensordot, (a, b), a, b, dims=dims) + if has_torch_function_variadic(a, b): + return handle_torch_function(tensordot, (a, b), a, b, dims=dims) if isinstance(dims, (list, tuple)) or \ (isinstance(dims, torch.Tensor) and dims.numel() > 1): dims_a, dims_b = dims @@ -1046,9 +1031,8 @@ def cartesian_prod(*tensors): [3, 4], [3, 5]]) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(cartesian_prod, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(cartesian_prod, tensors, *tensors) return _VF.cartesian_prod(tensors) # type: ignore def block_diag(*tensors): @@ -1128,10 +1112,9 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'): [2.7138, 3.8322], [2.2830, 0.3791]]) """ - if not torch.jit.is_scripting(): - if 
has_torch_function_variadic(x1, x2): - return handle_torch_function( - cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) + if has_torch_function_variadic(x1, x2): + return handle_torch_function( + cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) if compute_mode == 'use_mm_for_euclid_dist_if_necessary': return _VF.cdist(x1, x2, p, None) # type: ignore elif compute_mode == 'use_mm_for_euclid_dist': @@ -1168,9 +1151,8 @@ def atleast_1d(*tensors): >>> torch.atleast_1d((x,y)) (tensor([0.5000]), tensor([1.])) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(atleast_1d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_1d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_1d(tensors) # type: ignore @@ -1203,9 +1185,8 @@ def atleast_2d(*tensors): >>> torch.atleast_2d((x,y)) (tensor([[0.5000]]), tensor([[1.]])) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(atleast_2d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_2d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_2d(tensors) # type: ignore @@ -1247,9 +1228,8 @@ def atleast_3d(*tensors): >>> torch.atleast_3d((x,y)) (tensor([[[0.5000]]]), tensor([[[1.]]])) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(atleast_3d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_3d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_3d(tensors) # type: ignore @@ -1380,10 +1360,9 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa (tensor(3.7417), tensor(11.2250)) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function( + norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) ndim = input.dim() @@ -1476,9 +1455,8 @@ def chain_matmul(*matrices): .. 
_`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition """ - if not torch.jit.is_scripting(): - if has_torch_function(matrices): - return handle_torch_function(chain_matmul, matrices, *matrices) + if has_torch_function(matrices): + return handle_torch_function(chain_matmul, matrices, *matrices) return _VF.chain_matmul(matrices) # type: ignore @@ -1596,10 +1574,9 @@ def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: def _lu_with_infos(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(A): - return handle_torch_function( - lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) result = _lu_impl(A, pivot, get_infos, out) if out is not None: _check_list_size(len(out), get_infos, out) @@ -1612,10 +1589,9 @@ def _lu_with_infos(A, pivot=True, get_infos=False, out=None): def _lu_no_infos(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] # need to check for torch_function here so that we exit if - if not torch.jit.is_scripting(): - if has_torch_function_unary(A): - return handle_torch_function( - lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) result = _lu_impl(A, pivot, get_infos, out) if out is not None: _check_list_size(len(out), get_infos, out) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 8bc8c6117c1b..bdf00e21c515 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -32,6 +32,8 @@ _set_jit_function_cache, _set_jit_overload_cache, ) +from torch.overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic) torch._C.ScriptMethod.graph_for = _graph_for # type: ignore torch._C.ScriptFunction.graph_for = _graph_for # type: ignore @@ -1119,3 +1121,6 @@ def _unwrap_optional(x): _register_builtin(_unwrap_optional, "aten::_unwrap_optional") _register_builtin(_jit_internal.is_scripting, "aten::is_scripting") +_register_builtin(has_torch_function, "aten::has_torch_function") +_register_builtin(has_torch_function_unary, "aten::has_torch_function") +_register_builtin(has_torch_function_variadic, "aten::has_torch_function") diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 2cfc1c2b9393..ca2aaa5f9a40 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -412,18 +412,17 @@ def fractional_max_pool2d_with_indices( .. 
_Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool2d_with_indices, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool2d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: raise ValueError("fractional_max_pool2d requires specifying either " "an output_size or an output_ratio") if output_size is None: @@ -440,18 +439,17 @@ def _fractional_max_pool2d( input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None ): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool2d, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool2d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) return fractional_max_pool2d_with_indices( input, kernel_size, output_size, output_ratio, return_indices, _random_samples )[0] @@ -502,18 +500,17 @@ def fractional_max_pool3d_with_indices( .. 
_Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool3d_with_indices, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool3d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: raise ValueError("fractional_max_pool3d requires specifying either " "an output_size or an output_ratio") if output_size is None: @@ -534,18 +531,17 @@ def _fractional_max_pool3d( input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None ): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool3d, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool3d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) return fractional_max_pool3d_with_indices( input, kernel_size, output_size, output_ratio, return_indices, _random_samples )[0] @@ -571,19 +567,18 @@ def max_pool1d_with_indices( See :class:`~torch.nn.MaxPool1d` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool1d_with_indices, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -591,19 +586,18 @@ def max_pool1d_with_indices( def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool1d, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -629,19 +623,18 @@ def max_pool2d_with_indices( See :class:`~torch.nn.MaxPool2d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool2d_with_indices, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -649,19 +642,18 @@ def max_pool2d_with_indices( def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool2d, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -687,19 +679,18 @@ def max_pool3d_with_indices( See :class:`~torch.nn.MaxPool3d` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool3d_with_indices, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -707,19 +698,18 @@ def max_pool3d_with_indices( def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool3d, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -775,18 +765,17 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_siz See :class:`~torch.nn.MaxUnpool1d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_unpool1d, - (input,), - input, - indices, - kernel_size, - stride=stride, - padding=padding, - output_size=output_size, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _single(kernel_size) if stride is not None: _stride = _single(stride) @@ -807,18 +796,17 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_siz See :class:`~torch.nn.MaxUnpool2d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_unpool2d, - (input,), - input, - indices, - kernel_size, - stride=stride, - padding=padding, - output_size=output_size, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _pair(kernel_size) if stride is not None: _stride = _pair(stride) @@ -835,18 +823,17 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_siz See :class:`~torch.nn.MaxUnpool3d` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_unpool3d, - (input,), - input, - indices, - kernel_size, - stride=stride, - padding=padding, - output_size=output_size, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _triple(kernel_size) if stride is not None: _stride = _triple(stride) @@ -865,11 +852,10 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): See :class:`~torch.nn.LPPool2d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode - ) + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) kw, kh = utils._pair(kernel_size) if stride is not None: out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) @@ -887,11 +873,10 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): See :class:`~torch.nn.LPPool1d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode - ) + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) if stride is not None: out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: @@ -911,21 +896,19 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): output_size: the target output size (single integer) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + ) return torch.adaptive_max_pool1d(input, output_size) def _adaptive_max_pool1d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList1[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool1d_with_indices(input, output_size)[0] @@ -952,22 +935,20 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): double-integer tuple) return_indices: whether to return pooling indices. 
Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) def _adaptive_max_pool2d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList2[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool2d_with_indices(input, output_size)[0] @@ -994,22 +975,20 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): triple-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) def _adaptive_max_pool3d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList3[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool3d_with_indices(input, output_size)[0] @@ -1052,9 +1031,8 @@ def adaptive_avg_pool2d(input, output_size): output_size: the target output size (single integer or double-integer tuple) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -1071,9 +1049,8 @@ def adaptive_avg_pool3d(input, output_size): output_size: the target output size (single integer or triple-integer tuple) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) @@ -1092,9 +1069,8 @@ def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool training: apply dropout if is ``True``. 
Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) @@ -1105,9 +1081,8 @@ def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace See :class:`~torch.nn.AlphaDropout` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) @@ -1128,9 +1103,8 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) @@ -1153,9 +1127,8 @@ def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo """ # This is 100% the same code as dropout2d. We duplicate this code so that # stack traces are not confusing. - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) @@ -1181,11 +1154,10 @@ def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. 
Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace - ) + if has_torch_function_unary(input): + return handle_torch_function( + feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) @@ -1196,9 +1168,8 @@ def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = Fa See :class:`~torch.nn.Threshold` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) if inplace: result = _VF.threshold_(input, threshold, value) else: @@ -1227,9 +1198,8 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: Applies the rectified linear unit function element-wise. See :class:`~torch.nn.ReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(relu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(relu, (input,), input, inplace=inplace) if inplace: result = torch.relu_(input) else: @@ -1265,9 +1235,8 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: input (Tensor): input tensor dim (int): dimension on which to split the input. Default: -1 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(glu, (input,), input, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function(glu, (input,), input, dim=dim) if input.dim() == 0: raise RuntimeError("glu does not support scalars because halving size must be even") return torch._C._nn.glu(input, dim) @@ -1280,9 +1249,8 @@ def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) if inplace: result = torch._C._nn.hardtanh_(input, min_val, max_val) else: @@ -1307,9 +1275,8 @@ def relu6(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.ReLU6` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(relu6, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(relu6, (input,), input, inplace=inplace) return hardtanh(input, 0.0, 6.0, inplace) @@ -1319,9 +1286,8 @@ def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: See :class:`~torch.nn.ELU` for more details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch._C._nn.elu_(input, alpha) else: @@ -1349,9 +1315,8 @@ def selu(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.SELU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(selu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(selu, (input,), input, inplace=inplace) if inplace: result = torch.selu_(input) else: @@ -1377,9 +1342,8 @@ def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: See :class:`~torch.nn.CELU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch.celu_(input, alpha) else: @@ -1406,9 +1370,8 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals See :class:`~torch.nn.LeakyReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) if inplace: result = torch._C._nn.leaky_relu_(input, negative_slope) else: @@ -1435,9 +1398,8 @@ def prelu(input: Tensor, weight: Tensor) -> Tensor: See :class:`~torch.nn.PReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(prelu, (input,), input, weight) + if has_torch_function_unary(input): + return handle_torch_function(prelu, (input,), input, weight) return torch.prelu(input, weight) @@ -1450,11 +1412,10 @@ def rrelu( See :class:`~torch.nn.RReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace - ) + if has_torch_function_unary(input): + return handle_torch_function( + rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + ) if inplace: result = torch.rrelu_(input, lower, upper, training) else: @@ -1493,9 +1454,8 @@ def gelu(input): See `Gaussian Error Linear Units (GELUs) `_. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(gelu, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(gelu, (input,), input) return torch._C._nn.gelu(input) @@ -1507,9 +1467,8 @@ def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor: See :class:`~torch.nn.Hardshrink` for more details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardshrink, (input,), input, lambd=lambd) + if has_torch_function_unary(input): + return handle_torch_function(hardshrink, (input,), input, lambd=lambd) return torch.hardshrink(input, lambd) @@ -1520,9 +1479,8 @@ def tanhshrink(input): See :class:`~torch.nn.Tanhshrink` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(tanhshrink, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(tanhshrink, (input,), input) return input - input.tanh() @@ -1533,9 +1491,8 @@ def softsign(input): See :class:`~torch.nn.Softsign` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(softsign, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(softsign, (input,), input) return input / (input.abs() + 1) @@ -1582,9 +1539,8 @@ def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) if dtype is None: @@ -1619,9 +1575,8 @@ def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp Use log_softmax instead (it's faster and has better numerical properties). """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) if dtype is None: @@ -1671,9 +1626,8 @@ def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: floa .. _Link 2: https://arxiv.org/abs/1611.01144 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(logits): - return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + if has_torch_function_unary(logits): + return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") @@ -1710,9 +1664,8 @@ def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) if dtype is None: @@ -1774,9 +1727,8 @@ def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.Hardsigmoid` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) if inplace: return torch._C._nn.hardsigmoid_(input) return torch._C._nn.hardsigmoid(input) @@ -1796,9 +1748,8 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens - Bias: :math:`(out\_features)` - Output: :math:`(N, *, out\_features)` """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, weight): - return handle_torch_function(linear, (input, weight), input, weight, bias=bias) + if has_torch_function_variadic(input, weight): + return handle_torch_function(linear, (input, weight), input, weight, bias=bias) if input.dim() == 2 and bias is not None: # fused op is marginally faster ret = torch.addmm(bias, input, weight.t()) @@ -1846,9 +1797,8 @@ def silu(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.SiLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(silu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(silu, (input,), input, inplace=inplace) if inplace: return torch._C._nn.silu_(input) return torch._C._nn.silu(input) @@ -1871,9 +1821,8 @@ def hardswish(input: Tensor, inplace: bool = False) -> Tensor: .. _`Searching for MobileNetV3`: https://arxiv.org/abs/1905.02244 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) if inplace: return torch._C._nn.hardswish_(input) return torch._C._nn.hardswish(input) @@ -2058,23 +2007,21 @@ def embedding_bag( tensor([[ 0.3397, 0.3552, 0.5545], [ 0.5893, 0.4386, 0.5882]]) """ - - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, weight): - return handle_torch_function( - embedding_bag, - (input, weight), - input, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - ) + if has_torch_function_variadic(input, weight): + return handle_torch_function( + embedding_bag, + (input, weight), + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + ) # Check for backward compatibility. # Used to be embedding_bag(weight, input, ...) # Now is embedding_bag(input, weight, ...) 
@@ -2188,20 +2135,19 @@ def batch_norm(
     See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
     :class:`~torch.nn.BatchNorm3d` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(
-                batch_norm,
-                (input,),
-                input,
-                running_mean,
-                running_var,
-                weight=weight,
-                bias=bias,
-                training=training,
-                momentum=momentum,
-                eps=eps,
-            )
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            batch_norm,
+            (input,),
+            input,
+            running_mean,
+            running_var,
+            weight=weight,
+            bias=bias,
+            training=training,
+            momentum=momentum,
+            eps=eps,
+        )
     if training:
         _verify_batch_size(input.size())
 
@@ -2227,20 +2173,19 @@ def instance_norm(
     See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
     :class:`~torch.nn.InstanceNorm3d` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(
-                instance_norm,
-                (input,),
-                input,
-                running_mean=running_mean,
-                running_var=running_var,
-                weight=weight,
-                bias=bias,
-                use_input_stats=use_input_stats,
-                momentum=momentum,
-                eps=eps,
-            )
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            instance_norm,
+            (input,),
+            input,
+            running_mean=running_mean,
+            running_var=running_var,
+            weight=weight,
+            bias=bias,
+            use_input_stats=use_input_stats,
+            momentum=momentum,
+            eps=eps,
+        )
     _verify_batch_size(input.size())
     return torch.instance_norm(
         input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled
@@ -2258,11 +2203,10 @@ def layer_norm(
 
     See :class:`~torch.nn.LayerNorm` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(
-                layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps
-            )
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps
+        )
     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
 
 
@@ -2273,9 +2217,8 @@ def group_norm(
 
     See :class:`~torch.nn.GroupNorm` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps)
+    if has_torch_function_unary(input):
+        return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps)
     _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))
     return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
 
@@ -2287,9 +2230,8 @@ def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: flo
 
     See :class:`~torch.nn.LocalResponseNorm` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k)
+    if has_torch_function_unary(input):
+        return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k)
     dim = input.dim()
     if dim < 3:
         raise ValueError(
@@ -2425,19 +2367,18 @@ def nll_loss(
         >>> output = F.nll_loss(F.log_softmax(input), target)
         >>> output.backward()
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                nll_loss,
-                (input, target),
-                input,
-                target,
-                weight=weight,
-                size_average=size_average,
-                ignore_index=ignore_index,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            nll_loss,
+            (input, target),
+            input,
+            target,
+            weight=weight,
+            size_average=size_average,
+            ignore_index=ignore_index,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     dim = input.dim()
@@ -2521,20 +2462,19 @@ def poisson_nll_loss(
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
 
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                poisson_nll_loss,
-                (input, target),
-                input,
-                target,
-                log_input=log_input,
-                full=full,
-                size_average=size_average,
-                eps=eps,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            poisson_nll_loss,
+            (input, target),
+            input,
+            target,
+            log_input=log_input,
+            full=full,
+            size_average=size_average,
+            eps=eps,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     if reduction != "none" and reduction != "mean" and reduction != "sum":
@@ -2591,18 +2531,17 @@ def kl_div(
         :attr:``reduction`` = ``'batchmean'`` which aligns with KL math definition.
        In the next major release, ``'mean'`` will be changed to be the same as 'batchmean'.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                kl_div,
-                (input, target),
-                input,
-                target,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-                log_target=log_target,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            kl_div,
+            (input, target),
+            input,
+            target,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+            log_target=log_target,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -2676,19 +2615,18 @@ def cross_entropy(
         >>> loss = F.cross_entropy(input, target)
         >>> loss.backward()
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                cross_entropy,
-                (input, target),
-                input,
-                target,
-                weight=weight,
-                size_average=size_average,
-                ignore_index=ignore_index,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            cross_entropy,
+            (input, target),
+            input,
+            target,
+            weight=weight,
+            size_average=size_average,
+            ignore_index=ignore_index,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
@@ -2735,18 +2673,17 @@ def binary_cross_entropy(
         >>> loss = F.binary_cross_entropy(F.sigmoid(input), target)
         >>> loss.backward()
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                binary_cross_entropy,
-                (input, target),
-                input,
-                target,
-                weight=weight,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            binary_cross_entropy,
+            (input, target),
+            input,
+            target,
+            weight=weight,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -2808,19 +2745,18 @@ def binary_cross_entropy_with_logits(
         >>> loss = F.binary_cross_entropy_with_logits(input, target)
         >>> loss.backward()
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                binary_cross_entropy_with_logits,
-                (input, target),
-                input,
-                target,
-                weight=weight,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-                pos_weight=pos_weight,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            binary_cross_entropy_with_logits,
+            (input, target),
+            input,
+            target,
+            weight=weight,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+            pos_weight=pos_weight,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -2845,18 +2781,17 @@ def smooth_l1_loss(
 
     See :class:`~torch.nn.SmoothL1Loss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                smooth_l1_loss,
-                (input, target),
-                input,
-                target,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-                beta=beta,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            smooth_l1_loss,
+            (input, target),
+            input,
+            target,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+            beta=beta,
+        )
     if not (target.size() == input.size()):
         warnings.warn(
             "Using a target size ({}) that is different to the input size ({}). "
@@ -2884,11 +2819,10 @@ def l1_loss(
 
     See :class:`~torch.nn.L1Loss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
+        )
     if not (target.size() == input.size()):
         warnings.warn(
             "Using a target size ({}) that is different to the input size ({}). "
@@ -2916,11 +2850,10 @@ def mse_loss(
 
     See :class:`~torch.nn.MSELoss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
+        )
     if not (target.size() == input.size()):
         warnings.warn(
             "Using a target size ({}) that is different to the input size ({}). "
@@ -2948,19 +2881,18 @@ def margin_ranking_loss(
 
     See :class:`~torch.nn.MarginRankingLoss` for details.
     """  # noqa
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input1, input2, target):
-            return handle_torch_function(
-                margin_ranking_loss,
-                (input1, input2, target),
-                input1,
-                input2,
-                target,
-                margin=margin,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input1, input2, target):
+        return handle_torch_function(
+            margin_ranking_loss,
+            (input1, input2, target),
+            input1,
+            input2,
+            target,
+            margin=margin,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -2987,18 +2919,17 @@ def hinge_embedding_loss(
 
     See :class:`~torch.nn.HingeEmbeddingLoss` for details.
     """  # noqa
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                hinge_embedding_loss,
-                (input, target),
-                input,
-                target,
-                margin=margin,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            hinge_embedding_loss,
+            (input, target),
+            input,
+            target,
+            margin=margin,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -3017,17 +2948,16 @@ def multilabel_margin_loss(
 
     See :class:`~torch.nn.MultiLabelMarginLoss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                multilabel_margin_loss,
-                (input, target),
-                input,
-                target,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            multilabel_margin_loss,
+            (input, target),
+            input,
+            target,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -3046,11 +2976,10 @@ def soft_margin_loss(
 
     See :class:`~torch.nn.SoftMarginLoss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -3070,18 +2999,17 @@ def multilabel_soft_margin_loss(
 
     See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                multilabel_soft_margin_loss,
-                (input, target),
-                input,
-                target,
-                weight=weight,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            multilabel_soft_margin_loss,
+            (input, target),
+            input,
+            target,
+            weight=weight,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
 
@@ -3117,19 +3045,18 @@ def cosine_embedding_loss(
 
     See :class:`~torch.nn.CosineEmbeddingLoss` for details.
    """  # noqa
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input1, input2, target):
-            return handle_torch_function(
-                cosine_embedding_loss,
-                (input1, input2, target),
-                input1,
-                input2,
-                target,
-                margin=margin,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input1, input2, target):
+        return handle_torch_function(
+            cosine_embedding_loss,
+            (input1, input2, target),
+            input1,
+            input2,
+            target,
+            margin=margin,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -3152,20 +3079,19 @@ def multi_margin_loss(
 
     See :class:`~torch.nn.MultiMarginLoss` for details.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, target):
-            return handle_torch_function(
-                multi_margin_loss,
-                (input, target),
-                input,
-                target,
-                p=p,
-                margin=margin,
-                weight=weight,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(input, target):
+        return handle_torch_function(
+            multi_margin_loss,
+            (input, target),
+            input,
+            target,
+            p=p,
+            margin=margin,
+            weight=weight,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -3444,18 +3370,17 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
     Note:
         {backward_reproducibility_note}
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(
-                interpolate,
-                (input,),
-                input,
-                size=size,
-                scale_factor=scale_factor,
-                mode=mode,
-                align_corners=align_corners,
-                recompute_scale_factor=recompute_scale_factor,
-            )
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            interpolate,
+            (input,),
+            input,
+            size=size,
+            scale_factor=scale_factor,
+            mode=mode,
+            align_corners=align_corners,
+            recompute_scale_factor=recompute_scale_factor,
+        )
 
     if mode in ("nearest", "area"):
         if align_corners is not None:
@@ -3808,11 +3733,10 @@ def grid_sample(
     .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
     .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(input, grid):
-            return handle_torch_function(
-                grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
-            )
+    if has_torch_function_variadic(input, grid):
+        return handle_torch_function(
+            grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
+        )
     if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
         raise ValueError(
             "nn.functional.grid_sample(): expected mode to be "
@@ -3899,9 +3823,8 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] =
         along a unit dimension are considered to be at ```0`` (the center of the input image).
 
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(theta):
-            return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners)
+    if has_torch_function_unary(theta):
+        return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners)
     if align_corners is None:
         warnings.warn(
             "Default grid_sample and affine_grid behavior has changed "
@@ -4008,9 +3931,8 @@ def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0
         torch.Size([3, 9, 7, 3])
 
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
+    if has_torch_function_unary(input):
+        return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
     assert len(pad) % 2 == 0, "Padding length must be divisible by 2"
     assert len(pad) // 2 <= input.dim(), "Padding length too large"
     if mode == "constant":
@@ -4191,22 +4113,21 @@ def triplet_margin_loss(
     r"""
     See :class:`~torch.nn.TripletMarginLoss` for details
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_variadic(anchor, positive, negative):
-            return handle_torch_function(
-                triplet_margin_loss,
-                (anchor, positive, negative),
-                anchor,
-                positive,
-                negative,
-                margin=margin,
-                p=p,
-                eps=eps,
-                swap=swap,
-                size_average=size_average,
-                reduce=reduce,
-                reduction=reduction,
-            )
+    if has_torch_function_variadic(anchor, positive, negative):
+        return handle_torch_function(
+            triplet_margin_loss,
+            (anchor, positive, negative),
+            anchor,
+            positive,
+            negative,
+            margin=margin,
+            p=p,
+            eps=eps,
+            swap=swap,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction,
+        )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
@@ -4285,9 +4206,8 @@ def normalize(input: Tensor, p: float = 2, dim: int = 1, eps: float = 1e-12, out
         out (Tensor, optional): the output tensor. If :attr:`out` is used, this
                                 operation won't be differentiable.
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out)
+    if has_torch_function_unary(input):
+        return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out)
     if out is None:
         denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
         return input / denom
@@ -4318,11 +4238,10 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
 
     See :class:`torch.nn.Unfold` for details
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(
-                unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride
-            )
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride
+        )
     if input.dim() == 4:
         msg = "{} must be int or 2-tuple for 4D input"
         assert_int_or_pair(kernel_size, "kernel_size", msg)
@@ -4346,11 +4265,10 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
 
     See :class:`torch.nn.Fold` for details
     """
-    if not torch.jit.is_scripting():
-        if has_torch_function_unary(input):
-            return handle_torch_function(
-                fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
-            )
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
+        )
     if input.dim() == 3:
         msg = "{} must be int or 2-tuple for 3D input"
         assert_int_or_pair(output_size, "output_size", msg)
@@ -4613,36 +4531,35 @@ def multi_head_attention_forward(
         - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
           L is the target sequence length, S is the source sequence length.
     """
-    if not torch.jit.is_scripting():
-        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
-        if has_torch_function(tens_ops):
-            return handle_torch_function(
-                multi_head_attention_forward,
-                tens_ops,
-                query,
-                key,
-                value,
-                embed_dim_to_check,
-                num_heads,
-                in_proj_weight,
-                in_proj_bias,
-                bias_k,
-                bias_v,
-                add_zero_attn,
-                dropout_p,
-                out_proj_weight,
-                out_proj_bias,
-                training=training,
-                key_padding_mask=key_padding_mask,
-                need_weights=need_weights,
-                attn_mask=attn_mask,
-                use_separate_proj_weight=use_separate_proj_weight,
-                q_proj_weight=q_proj_weight,
-                k_proj_weight=k_proj_weight,
-                v_proj_weight=v_proj_weight,
-                static_k=static_k,
-                static_v=static_v,
-            )
+    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+    if has_torch_function(tens_ops):
+        return handle_torch_function(
+            multi_head_attention_forward,
+            tens_ops,
+            query,
+            key,
+            value,
+            embed_dim_to_check,
+            num_heads,
+            in_proj_weight,
+            in_proj_bias,
+            bias_k,
+            bias_v,
+            add_zero_attn,
+            dropout_p,
+            out_proj_weight,
+            out_proj_bias,
+            training=training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+            use_separate_proj_weight=use_separate_proj_weight,
+            q_proj_weight=q_proj_weight,
+            k_proj_weight=k_proj_weight,
+            v_proj_weight=v_proj_weight,
+            static_k=static_k,
+            static_v=static_v,
+        )
     tgt_len, bsz, embed_dim = query.size()
     assert embed_dim == embed_dim_to_check
     # allow MHA to have different sizes for the feature dimension