From 6a3fc0c21cd14686a412f6a606cb458906f12fac Mon Sep 17 00:00:00 2001
From: Taylor Robie
Date: Sun, 10 Jan 2021 19:16:04 -0800
Subject: [PATCH] Treat has_torch_function and object_has_torch_function as
 static False when scripting (#48966)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48966

This PR lets us skip the `if not torch.jit.is_scripting():` guards on `functional` and `nn.functional` by directly registering `has_torch_function` and `object_has_torch_function` with the JIT as statically False.

**Benchmarks**

The benchmark script is fairly long because it tests all four PRs in the stack and uses threading plus subprocesses so that it can utilize multiple cores while still collecting good numbers. Both wall times and instruction counts were collected. This stack changes dozens of operators / functions, but the changes are so mechanical that only a handful of code paths are actually affected. Each benchmark row exercises a slightly different code path (e.g. testing in Python, testing in the arg parser, different input types, etc.).
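To make the mechanism concrete, here is a minimal sketch (not part of the patch; `my_relu` is a made-up example) of the guard pattern this stack removes and of why removing it is safe under scripting:

```
import torch
from torch.overrides import has_torch_function_unary, handle_torch_function

# Pattern before this stack (the outer guard was needed because
# handle_torch_function is not scriptable):
#
#     if not torch.jit.is_scripting():
#         if has_torch_function_unary(input):
#             return handle_torch_function(relu, (input,), input, inplace=inplace)
#
# Pattern after: the outer guard is dropped. When a function is scripted,
# `aten::has_torch_function` is emitted as statically False, so the branch
# (including the unscriptable handle_torch_function call) is never compiled,
# just like the body of an `if not torch.jit.is_scripting():` block.

def my_relu(x: torch.Tensor) -> torch.Tensor:
    if has_torch_function_unary(x):
        return handle_torch_function(my_relu, (x,), x)
    return torch.relu(x)

scripted = torch.jit.script(my_relu)  # compiles; the __torch_function__ branch folds away
```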
**Test script**

```
import argparse
import multiprocessing
import multiprocessing.dummy
import os
import pickle
import queue
import random
import sys
import subprocess
import tempfile
import time

import torch
from torch.utils.benchmark import Timer, Compare, Measurement


NUM_CORES = multiprocessing.cpu_count()

ENVS = {
    "ref": "HEAD (current)",
    "torch_fn_overhead_stack_0": "#48963",
    "torch_fn_overhead_stack_1": "#48964",
    "torch_fn_overhead_stack_2": "#48965",
    "torch_fn_overhead_stack_3": "#48966",
}

CALLGRIND_ENVS = tuple(ENVS.keys())

MIN_RUN_TIME = 3
REPLICATES = {
    "longer": 1_000,
    "long": 300,
    "short": 50,
}

CALLGRIND_NUMBER = {
    "overnight": 500_000,
    "long": 250_000,
    "short": 10_000,
}

CALLGRIND_TIMEOUT = {
    "overnight": 800,
    "long": 400,
    "short": 100,
}

SETUP = """
x = torch.ones((1, 1))
y = torch.ones((1, 1))
w_tensor = torch.ones((1, 1), requires_grad=True)
linear = torch.nn.Linear(1, 1, bias=False)
linear_w = linear.weight
"""

TASKS = {
    "C++: unary `.t()`": "w_tensor.t()",
    "C++: unary (Parameter) `.t()`": "linear_w.t()",
    "C++: binary (Parameter) `mul` ": "x + linear_w",
    "tensor.py: _wrap_type_error_to_not_implemented `__floordiv__`": "x // y",
    "tensor.py: method `__hash__`": "hash(x)",
    "Python scalar `__rsub__`": "1 - x",
    "functional.py: (unary) `unique`": "torch.functional.unique(x)",
    "functional.py: (args) `atleast_1d`": "torch.functional.atleast_1d((x, y))",
    "nn/functional.py: (unary) `relu`": "torch.nn.functional.relu(x)",
    "nn/functional.py: (args) `linear`": "torch.nn.functional.linear(x, w_tensor)",
    "nn/functional.py: (args) `linear (Parameter)`": "torch.nn.functional.linear(x, linear_w)",
    "Linear(..., bias=False)": "linear(x)",
}


def _worker_main(argv, fn):
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--single_task", type=int, default=None)
    parser.add_argument("--length", type=str)
    args = parser.parse_args(argv)
    single_task = args.single_task

    conda_prefix = os.getenv("CONDA_PREFIX")
    assert torch.__file__.startswith(conda_prefix)

    env = os.path.split(conda_prefix)[1]
    assert env in ENVS

    results = []
    for i, (k, stmt) in enumerate(TASKS.items()):
        if single_task is not None and single_task != i:
            continue

        timer = Timer(
            stmt=stmt,
            setup=SETUP,
            sub_label=k,
            description=ENVS[env],
        )
        results.append(fn(timer, args.length))

    with open(args.output_file, "wb") as f:
        pickle.dump(results, f)


def worker_main(argv):
    _worker_main(
        argv,
        lambda timer, _: timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
    )


def callgrind_worker_main(argv):
    _worker_main(
        argv,
        lambda timer, length: timer.collect_callgrind(number=CALLGRIND_NUMBER[length], collect_baseline=False))


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--long", action="store_true")
    parser.add_argument("--longer", action="store_true")
    args = parser.parse_args(argv)

    if args.longer:
        length = "longer"
    elif args.long:
        length = "long"
    else:
        length = "short"
    replicates = REPLICATES[length]

    num_workers = int(NUM_CORES // 2)
    tasks = list(ENVS.keys()) * replicates
    random.shuffle(tasks)

    task_queue = queue.Queue()
    for _ in range(replicates):
        envs = list(ENVS.keys())
        random.shuffle(envs)
        for e in envs:
            task_queue.put((e, None))

    callgrind_task_queue = queue.Queue()
    for e in CALLGRIND_ENVS:
        for i, _ in enumerate(TASKS):
            callgrind_task_queue.put((e, i))

    results = []
    callgrind_results = []

    def map_fn(worker_id):
        # Adjacent cores often share cache and maxing out a machine can distort
        # timings so we space them out.
        callgrind_cores = f"{worker_id * 2}-{worker_id * 2 + 1}"
        time_cores = str(worker_id * 2)
        _, output_file = tempfile.mkstemp(suffix=".pkl")
        try:
            loop_tasks = (
                # Callgrind is long running, and then the workers can help with
                # timing after they finish collecting counts.
                (callgrind_task_queue, callgrind_results, "callgrind_worker", callgrind_cores, CALLGRIND_TIMEOUT[length]),
                (task_queue, results, "worker", time_cores, None))

            for queue_i, results_i, mode_i, cores, timeout in loop_tasks:
                while True:
                    try:
                        env, task_i = queue_i.get_nowait()
                    except queue.Empty:
                        break

                    remaining_attempts = 3
                    while True:
                        try:
                            subprocess.run(
                                " ".join([
                                    "source", "activate", env, "&&",
                                    "taskset", "--cpu-list", cores,
                                    "python", os.path.abspath(__file__),
                                    "--mode", mode_i,
                                    "--length", length,
                                    "--output_file", output_file
                                ] + ([] if task_i is None else ["--single_task", str(task_i)])),
                                shell=True,
                                check=True,
                                timeout=timeout,
                            )
                            break

                        except subprocess.TimeoutExpired:
                            # Sometimes Valgrind will hang if there are too many
                            # concurrent runs.
                            remaining_attempts -= 1
                            if not remaining_attempts:
                                print("Too many failed attempts.")
                                raise
                            print(f"Timeout after {timeout} sec. Retrying.")

                    # We don't need a lock, as the GIL is enough.
                    with open(output_file, "rb") as f:
                        results_i.extend(pickle.load(f))

        finally:
            os.remove(output_file)

    with multiprocessing.dummy.Pool(num_workers) as pool:
        st, st_estimate, eta, n_total = time.time(), None, "", len(tasks) * len(TASKS)
        map_job = pool.map_async(map_fn, range(num_workers))
        while not map_job.ready():
            n_complete = len(results)
            if n_complete and len(callgrind_results):
                if st_estimate is None:
                    st_estimate = time.time()
                else:
                    sec_per_element = (time.time() - st_estimate) / n_complete
                    n_remaining = n_total - n_complete
                    eta = f"ETA: {n_remaining * sec_per_element:.0f} sec"

            print(
                f"\r{n_complete} / {n_total} "
                f"({len(callgrind_results)} / {len(CALLGRIND_ENVS) * len(TASKS)}) "
                f"{eta}".ljust(40), end="")
            sys.stdout.flush()
            time.sleep(2)

    total_time = int(time.time() - st)
    print(f"\nTotal time: {int(total_time // 60)} min, {total_time % 60} sec")

    desc_to_ind = {k: i for i, k in enumerate(ENVS.values())}
    results.sort(key=lambda r: desc_to_ind[r.description])

    # TODO: Compare should be richer and more modular.
    compare = Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)

    # Manually add master vs. overall relative delta t.
    merged_results = {
        (r.description, r.sub_label): r
        for r in Measurement.merge(results)
    }

    cmp_lines = str(compare).splitlines(False)
    print(cmp_lines[0][:-1] + "-" * 15 + "]")
    print(f"{cmp_lines[1]} |{'':>10}\u0394t")
    print(cmp_lines[2] + "-" * 15)
    for l, t in zip(cmp_lines[3:3 + len(TASKS)], TASKS.keys()):
        assert l.strip().startswith(t)
        t0 = merged_results[(ENVS["ref"], t)].median
        t1 = merged_results[(ENVS["torch_fn_overhead_stack_3"], t)].median
        print(f"{l} |{'':>5}{(t1 / t0 - 1) * 100:>6.1f}%")
    print("\n".join(cmp_lines[3 + len(TASKS):]))

    counts_dict = {
        (r.task_spec.description, r.task_spec.sub_label): r.counts(denoise=True)
        for r in callgrind_results
    }

    def rel_diff(x, x0):
        return f"{(x / x0 - 1) * 100:>6.1f}%"

    task_pad = max(len(t) for t in TASKS)
    print(f"\n\nInstruction % change (relative to `{CALLGRIND_ENVS[0]}`)")
    print(" " * (task_pad + 8) + (" " * 7).join([ENVS[env] for env in CALLGRIND_ENVS[1:]]))
    for t in TASKS:
        values = [counts_dict[(ENVS[env], t)] for env in CALLGRIND_ENVS]

        print(t.ljust(task_pad + 3) + " ".join([
            rel_diff(v, values[0]).rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]))

        print("\033[4m" + " Instructions per invocation".ljust(task_pad + 3) + " ".join([
            f"{v // CALLGRIND_NUMBER[length]:.0f}".rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]) + "\033[0m")
        print()

    import pdb
    pdb.set_trace()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=("main", "worker", "callgrind_worker"), default="main")
    args, remaining = parser.parse_known_args()

    if args.mode == "main":
        main(remaining)
    elif args.mode == "callgrind_worker":
        callgrind_worker_main(remaining)
    else:
        worker_main(remaining)
```
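The harness above drives everything through `torch.utils.benchmark.Timer`. As a smaller, self-contained reference (again not part of the patch), the measurement pattern for a single entry of `TASKS` looks roughly like this:

```
# Standalone sketch of one measurement from the harness above. Assumes only a
# working PyTorch install; the conda-env comparison and multiprocessing are omitted.
import torch
from torch.utils.benchmark import Timer

setup = """
x = torch.ones((1, 1))
w_tensor = torch.ones((1, 1), requires_grad=True)
"""

timer = Timer(
    stmt="torch.nn.functional.relu(x)",
    setup=setup,
    sub_label="nn/functional.py: (unary) `relu`",
)

# Wall time: sample in blocks until at least MIN_RUN_TIME (3 s) of data is collected.
print(timer.blocked_autorange(min_run_time=3))

# Instruction counts (requires Valgrind); this is what the callgrind workers run.
# stats = timer.collect_callgrind(number=10_000, collect_baseline=False)
# print(stats.counts(denoise=True))
```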
**Wall time**

(screenshot: Screen Shot 2020-12-12 at 12 28 13 PM)
Longer run (`python test.py --long`) is basically identical.

(screenshot: Screen Shot 2020-12-12 at 5 02 47 PM)
**Callgrind** Screen Shot 2020-12-12 at 12 28 54 PM Test Plan: existing unit tests. Reviewed By: ezyang Differential Revision: D25590731 Pulled By: robieta fbshipit-source-id: fe05305ff22b0e34ced44b60f2e9f07907a099dd --- aten/src/ATen/core/interned_strings.h | 1 + torch/csrc/jit/frontend/ir_emitter.cpp | 5 +- .../csrc/jit/runtime/register_special_ops.cpp | 4 + torch/functional.py | 164 +-- torch/jit/_script.py | 5 + torch/nn/functional.py | 1143 ++++++++--------- 6 files changed, 614 insertions(+), 708 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index f99dc3c07058..a65a48d601dc 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -294,6 +294,7 @@ namespace c10 { _(aten, swapdims_) \ _(aten, movedim) \ _(aten, moveaxis) \ + _(aten, has_torch_function) \ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index d3aa0ba7295a..0dd84e4bb257 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1224,8 +1224,11 @@ struct to_ir { } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); c10::optional static_if = c10::nullopt; - if (expr_out->node()->kind() == aten::is_scripting) { + auto kind = expr_out->node()->kind(); + if (kind == aten::is_scripting) { static_if = true; + } else if (kind == aten::has_torch_function) { + static_if = false; } // MetaCompile on boolean literals and constants if (auto maybe_ivalue = toIValue(expr_out)) { diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 28a4136ba829..2cd5a13d3f4b 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -372,6 +372,10 @@ RegisterOperators reg({ TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"), [](Stack* stack) { push(stack, true); }, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"), + [](Stack* stack) { push(stack, false); }, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"), diff --git a/torch/functional.py b/torch/functional.py index 43fa0a3df546..1d3403e65304 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -69,9 +69,8 @@ def broadcast_tensors(*tensors): tensor([[0, 1, 2], [0, 1, 2]]) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(broadcast_tensors, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(broadcast_tensors, tensors, *tensors) return _VF.broadcast_tensors(tensors) # type: ignore @@ -147,10 +146,9 @@ def split(tensor, split_size_or_sections, dim=0): [6, 7], [8, 9]])) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(tensor): - return handle_torch_function(split, (tensor,), tensor, split_size_or_sections, - dim=dim) + if has_torch_function_unary(tensor): + return handle_torch_function( + split, (tensor,), tensor, split_size_or_sections, dim=dim) # Overwriting reason: # This dispatches to two ATen functions depending on the type of # split_size_or_sections. 
The branching code is in tensor.py, which we @@ -236,11 +234,11 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): >>> torch.norm(A_ - A) tensor(2.9802e-08) """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(LU_data, LU_pivots): - return handle_torch_function( - lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, unpack_data=unpack_data, - unpack_pivots=unpack_pivots) + if has_torch_function_variadic(LU_data, LU_pivots): + return handle_torch_function( + lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, + unpack_data=unpack_data, + unpack_pivots=unpack_pivots) shape = LU_data.shape # In generalized LU factorization, the following shape relations hold: # A.shape[-2:] == (m, n) @@ -301,7 +299,7 @@ def einsum(equation, *operands): based on the Einstein summation convention. Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them - in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of + in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of this format are described below, but the general idea is to label every dimension of the input :attr:`operands` with some subscript and define which subscripts are part of the output. The output is then computed by summing the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the @@ -387,7 +385,7 @@ def einsum(equation, *operands): # batch permute >>> A = torch.randn(2, 3, 4, 5) - >>> torch.einsum('...ij->...ji', A).shape + >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) # equivalent to torch.nn.functional.bilinear @@ -398,9 +396,8 @@ def einsum(equation, *operands): tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]]) """ - if not torch.jit.is_scripting(): - if has_torch_function(operands): - return handle_torch_function(einsum, operands, equation, *operands) + if has_torch_function(operands): + return handle_torch_function(einsum, operands, equation, *operands) if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument _operands = operands[0] @@ -448,9 +445,8 @@ def meshgrid(*tensors): def _meshgrid(*tensors): - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(meshgrid, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(meshgrid, tensors, *tensors) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): # the old interface of passing the operands as one list argument tensors = tensors[0] # type: ignore @@ -568,12 +564,11 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Tensor: A tensor containing the STFT result with shape described above """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, pad_mode=pad_mode, normalized=normalized, - onesided=onesided, return_complex=return_complex) + if has_torch_function_unary(input): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex) # TODO: after having proper ways to map Python strings to ATen Enum, 
move # this and F.pad to ATen. if center: @@ -650,12 +645,11 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Returns: Tensor: Least squares estimation of the original signal of size (..., signal_length) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, normalized=normalized, onesided=onesided, - length=length, return_complex=return_complex) + if has_torch_function_unary(input): + return handle_torch_function( + istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, + length=length, return_complex=return_complex) return _VF.istft(input, n_fft, hop_length, win_length, window, center, # type: ignore normalized, onesided, length, return_complex) @@ -734,11 +728,10 @@ def _unique_impl(input: Tensor, sorted: bool = True, [ 1, 2]]) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - unique, (input,), input, sorted=sorted, return_inverse=return_inverse, - return_counts=return_counts, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function( + unique, (input,), input, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) if dim is not None: output, inverse_indices, counts = _VF.unique_dim( # type: ignore @@ -810,11 +803,10 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, >>> counts tensor([2, 2, 1, 2, 1]) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - unique_consecutive, (input,), input, return_inverse=return_inverse, - return_counts=return_counts, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function( + unique_consecutive, (input,), input, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore input, return_inverse=return_inverse, return_counts=return_counts, dim=dim) return output, inverse_indices, counts @@ -823,9 +815,8 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output, counts @@ -834,9 +825,8 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output @@ -845,9 +835,8 @@ def _return_output(input, sorted=True, return_inverse=False, 
return_counts=False def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output, inverse_indices @@ -888,9 +877,8 @@ def _return_inverse(input, sorted=True, return_inverse=False, return_counts=Fals def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output, counts @@ -899,9 +887,8 @@ def _consecutive_return_counts(input, return_inverse=False, return_counts=False, def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output @@ -910,9 +897,8 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False, def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output, inverse_indices @@ -1000,9 +986,8 @@ def tensordot(a, b, dims=2, out=None): [ 1.5513, -14.4737, -6.5113], [ -0.2850, 4.2573, -3.5997]]) """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(a, b): - return handle_torch_function(tensordot, (a, b), a, b, dims=dims) + if has_torch_function_variadic(a, b): + return handle_torch_function(tensordot, (a, b), a, b, dims=dims) if isinstance(dims, (list, tuple)) or \ (isinstance(dims, torch.Tensor) and dims.numel() > 1): dims_a, dims_b = dims @@ -1046,9 +1031,8 @@ def cartesian_prod(*tensors): [3, 4], [3, 5]]) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(cartesian_prod, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(cartesian_prod, tensors, *tensors) return _VF.cartesian_prod(tensors) # type: ignore def block_diag(*tensors): @@ -1128,10 +1112,9 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'): [2.7138, 3.8322], [2.2830, 0.3791]]) """ - if not torch.jit.is_scripting(): - if 
has_torch_function_variadic(x1, x2): - return handle_torch_function( - cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) + if has_torch_function_variadic(x1, x2): + return handle_torch_function( + cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) if compute_mode == 'use_mm_for_euclid_dist_if_necessary': return _VF.cdist(x1, x2, p, None) # type: ignore elif compute_mode == 'use_mm_for_euclid_dist': @@ -1168,9 +1151,8 @@ def atleast_1d(*tensors): >>> torch.atleast_1d((x,y)) (tensor([0.5000]), tensor([1.])) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(atleast_1d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_1d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_1d(tensors) # type: ignore @@ -1203,9 +1185,8 @@ def atleast_2d(*tensors): >>> torch.atleast_2d((x,y)) (tensor([[0.5000]]), tensor([[1.]])) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(atleast_2d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_2d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_2d(tensors) # type: ignore @@ -1247,9 +1228,8 @@ def atleast_3d(*tensors): >>> torch.atleast_3d((x,y)) (tensor([[[0.5000]]]), tensor([[[1.]]])) """ - if not torch.jit.is_scripting(): - if has_torch_function(tensors): - return handle_torch_function(atleast_3d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_3d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_3d(tensors) # type: ignore @@ -1380,10 +1360,9 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa (tensor(3.7417), tensor(11.2250)) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function( + norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) ndim = input.dim() @@ -1476,9 +1455,8 @@ def chain_matmul(*matrices): .. 
_`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition """ - if not torch.jit.is_scripting(): - if has_torch_function(matrices): - return handle_torch_function(chain_matmul, matrices, *matrices) + if has_torch_function(matrices): + return handle_torch_function(chain_matmul, matrices, *matrices) return _VF.chain_matmul(matrices) # type: ignore @@ -1596,10 +1574,9 @@ def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: def _lu_with_infos(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor] - if not torch.jit.is_scripting(): - if has_torch_function_unary(A): - return handle_torch_function( - lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) result = _lu_impl(A, pivot, get_infos, out) if out is not None: _check_list_size(len(out), get_infos, out) @@ -1612,10 +1589,9 @@ def _lu_with_infos(A, pivot=True, get_infos=False, out=None): def _lu_no_infos(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] # need to check for torch_function here so that we exit if - if not torch.jit.is_scripting(): - if has_torch_function_unary(A): - return handle_torch_function( - lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) result = _lu_impl(A, pivot, get_infos, out) if out is not None: _check_list_size(len(out), get_infos, out) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 8bc8c6117c1b..bdf00e21c515 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -32,6 +32,8 @@ _set_jit_function_cache, _set_jit_overload_cache, ) +from torch.overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic) torch._C.ScriptMethod.graph_for = _graph_for # type: ignore torch._C.ScriptFunction.graph_for = _graph_for # type: ignore @@ -1119,3 +1121,6 @@ def _unwrap_optional(x): _register_builtin(_unwrap_optional, "aten::_unwrap_optional") _register_builtin(_jit_internal.is_scripting, "aten::is_scripting") +_register_builtin(has_torch_function, "aten::has_torch_function") +_register_builtin(has_torch_function_unary, "aten::has_torch_function") +_register_builtin(has_torch_function_variadic, "aten::has_torch_function") diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 2cfc1c2b9393..ca2aaa5f9a40 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -412,18 +412,17 @@ def fractional_max_pool2d_with_indices( .. 
_Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool2d_with_indices, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool2d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: raise ValueError("fractional_max_pool2d requires specifying either " "an output_size or an output_ratio") if output_size is None: @@ -440,18 +439,17 @@ def _fractional_max_pool2d( input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None ): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool2d, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool2d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) return fractional_max_pool2d_with_indices( input, kernel_size, output_size, output_ratio, return_indices, _random_samples )[0] @@ -502,18 +500,17 @@ def fractional_max_pool3d_with_indices( .. 
_Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool3d_with_indices, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool3d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: raise ValueError("fractional_max_pool3d requires specifying either " "an output_size or an output_ratio") if output_size is None: @@ -534,18 +531,17 @@ def _fractional_max_pool3d( input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None ): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fractional_max_pool3d, - (input,), - input, - kernel_size, - output_size=output_size, - output_ratio=output_ratio, - return_indices=return_indices, - _random_samples=_random_samples, - ) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool3d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) return fractional_max_pool3d_with_indices( input, kernel_size, output_size, output_ratio, return_indices, _random_samples )[0] @@ -571,19 +567,18 @@ def max_pool1d_with_indices( See :class:`~torch.nn.MaxPool1d` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool1d_with_indices, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -591,19 +586,18 @@ def max_pool1d_with_indices( def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool1d, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -629,19 +623,18 @@ def max_pool2d_with_indices( See :class:`~torch.nn.MaxPool2d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool2d_with_indices, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -649,19 +642,18 @@ def max_pool2d_with_indices( def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool2d, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -687,19 +679,18 @@ def max_pool3d_with_indices( See :class:`~torch.nn.MaxPool3d` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool3d_with_indices, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -707,19 +698,18 @@ def max_pool3d_with_indices( def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_pool3d, - (input,), - input, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - return_indices=return_indices, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) @@ -775,18 +765,17 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_siz See :class:`~torch.nn.MaxUnpool1d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_unpool1d, - (input,), - input, - indices, - kernel_size, - stride=stride, - padding=padding, - output_size=output_size, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _single(kernel_size) if stride is not None: _stride = _single(stride) @@ -807,18 +796,17 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_siz See :class:`~torch.nn.MaxUnpool2d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_unpool2d, - (input,), - input, - indices, - kernel_size, - stride=stride, - padding=padding, - output_size=output_size, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _pair(kernel_size) if stride is not None: _stride = _pair(stride) @@ -835,18 +823,17 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_siz See :class:`~torch.nn.MaxUnpool3d` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - max_unpool3d, - (input,), - input, - indices, - kernel_size, - stride=stride, - padding=padding, - output_size=output_size, - ) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _triple(kernel_size) if stride is not None: _stride = _triple(stride) @@ -865,11 +852,10 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): See :class:`~torch.nn.LPPool2d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode - ) + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) kw, kh = utils._pair(kernel_size) if stride is not None: out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) @@ -887,11 +873,10 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): See :class:`~torch.nn.LPPool1d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode - ) + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) if stride is not None: out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: @@ -911,21 +896,19 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): output_size: the target output size (single integer) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + ) return torch.adaptive_max_pool1d(input, output_size) def _adaptive_max_pool1d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList1[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool1d_with_indices(input, output_size)[0] @@ -952,22 +935,20 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): double-integer tuple) return_indices: whether to return pooling indices. 
Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) def _adaptive_max_pool2d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList2[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool2d_with_indices(input, output_size)[0] @@ -994,22 +975,20 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): triple-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) def _adaptive_max_pool3d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList3[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices - ) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool3d_with_indices(input, output_size)[0] @@ -1052,9 +1031,8 @@ def adaptive_avg_pool2d(input, output_size): output_size: the target output size (single integer or double-integer tuple) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -1071,9 +1049,8 @@ def adaptive_avg_pool3d(input, output_size): output_size: the target output size (single integer or triple-integer tuple) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) @@ -1092,9 +1069,8 @@ def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool training: apply dropout if is ``True``. 
Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) @@ -1105,9 +1081,8 @@ def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace See :class:`~torch.nn.AlphaDropout` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) @@ -1128,9 +1103,8 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) @@ -1153,9 +1127,8 @@ def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo """ # This is 100% the same code as dropout2d. We duplicate this code so that # stack traces are not confusing. - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) @@ -1181,11 +1154,10 @@ def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. 
Default: ``False`` """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace - ) + if has_torch_function_unary(input): + return handle_torch_function( + feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) @@ -1196,9 +1168,8 @@ def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = Fa See :class:`~torch.nn.Threshold` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) if inplace: result = _VF.threshold_(input, threshold, value) else: @@ -1227,9 +1198,8 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: Applies the rectified linear unit function element-wise. See :class:`~torch.nn.ReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(relu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(relu, (input,), input, inplace=inplace) if inplace: result = torch.relu_(input) else: @@ -1265,9 +1235,8 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: input (Tensor): input tensor dim (int): dimension on which to split the input. Default: -1 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(glu, (input,), input, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function(glu, (input,), input, dim=dim) if input.dim() == 0: raise RuntimeError("glu does not support scalars because halving size must be even") return torch._C._nn.glu(input, dim) @@ -1280,9 +1249,8 @@ def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) if inplace: result = torch._C._nn.hardtanh_(input, min_val, max_val) else: @@ -1307,9 +1275,8 @@ def relu6(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.ReLU6` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(relu6, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(relu6, (input,), input, inplace=inplace) return hardtanh(input, 0.0, 6.0, inplace) @@ -1319,9 +1286,8 @@ def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: See :class:`~torch.nn.ELU` for more details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch._C._nn.elu_(input, alpha) else: @@ -1349,9 +1315,8 @@ def selu(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.SELU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(selu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(selu, (input,), input, inplace=inplace) if inplace: result = torch.selu_(input) else: @@ -1377,9 +1342,8 @@ def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: See :class:`~torch.nn.CELU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch.celu_(input, alpha) else: @@ -1406,9 +1370,8 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals See :class:`~torch.nn.LeakyReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) if inplace: result = torch._C._nn.leaky_relu_(input, negative_slope) else: @@ -1435,9 +1398,8 @@ def prelu(input: Tensor, weight: Tensor) -> Tensor: See :class:`~torch.nn.PReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(prelu, (input,), input, weight) + if has_torch_function_unary(input): + return handle_torch_function(prelu, (input,), input, weight) return torch.prelu(input, weight) @@ -1450,11 +1412,10 @@ def rrelu( See :class:`~torch.nn.RReLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace - ) + if has_torch_function_unary(input): + return handle_torch_function( + rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + ) if inplace: result = torch.rrelu_(input, lower, upper, training) else: @@ -1493,9 +1454,8 @@ def gelu(input): See `Gaussian Error Linear Units (GELUs) `_. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(gelu, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(gelu, (input,), input) return torch._C._nn.gelu(input) @@ -1507,9 +1467,8 @@ def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor: See :class:`~torch.nn.Hardshrink` for more details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardshrink, (input,), input, lambd=lambd) + if has_torch_function_unary(input): + return handle_torch_function(hardshrink, (input,), input, lambd=lambd) return torch.hardshrink(input, lambd) @@ -1520,9 +1479,8 @@ def tanhshrink(input): See :class:`~torch.nn.Tanhshrink` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(tanhshrink, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(tanhshrink, (input,), input) return input - input.tanh() @@ -1533,9 +1491,8 @@ def softsign(input): See :class:`~torch.nn.Softsign` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(softsign, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(softsign, (input,), input) return input / (input.abs() + 1) @@ -1582,9 +1539,8 @@ def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) if dtype is None: @@ -1619,9 +1575,8 @@ def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp Use log_softmax instead (it's faster and has better numerical properties). """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) if dtype is None: @@ -1671,9 +1626,8 @@ def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: floa .. _Link 2: https://arxiv.org/abs/1611.01144 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(logits): - return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + if has_torch_function_unary(logits): + return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") @@ -1710,9 +1664,8 @@ def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) if dtype is None: @@ -1774,9 +1727,8 @@ def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.Hardsigmoid` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) if inplace: return torch._C._nn.hardsigmoid_(input) return torch._C._nn.hardsigmoid(input) @@ -1796,9 +1748,8 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens - Bias: :math:`(out\_features)` - Output: :math:`(N, *, out\_features)` """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, weight): - return handle_torch_function(linear, (input, weight), input, weight, bias=bias) + if has_torch_function_variadic(input, weight): + return handle_torch_function(linear, (input, weight), input, weight, bias=bias) if input.dim() == 2 and bias is not None: # fused op is marginally faster ret = torch.addmm(bias, input, weight.t()) @@ -1846,9 +1797,8 @@ def silu(input: Tensor, inplace: bool = False) -> Tensor: See :class:`~torch.nn.SiLU` for more details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(silu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(silu, (input,), input, inplace=inplace) if inplace: return torch._C._nn.silu_(input) return torch._C._nn.silu(input) @@ -1871,9 +1821,8 @@ def hardswish(input: Tensor, inplace: bool = False) -> Tensor: .. _`Searching for MobileNetV3`: https://arxiv.org/abs/1905.02244 """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) if inplace: return torch._C._nn.hardswish_(input) return torch._C._nn.hardswish(input) @@ -2058,23 +2007,21 @@ def embedding_bag( tensor([[ 0.3397, 0.3552, 0.5545], [ 0.5893, 0.4386, 0.5882]]) """ - - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, weight): - return handle_torch_function( - embedding_bag, - (input, weight), - input, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - ) + if has_torch_function_variadic(input, weight): + return handle_torch_function( + embedding_bag, + (input, weight), + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + ) # Check for backward compatibility. # Used to be embedding_bag(weight, input, ...) # Now is embedding_bag(input, weight, ...) 
@@ -2188,20 +2135,19 @@ def batch_norm( See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, :class:`~torch.nn.BatchNorm3d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - batch_norm, - (input,), - input, - running_mean, - running_var, - weight=weight, - bias=bias, - training=training, - momentum=momentum, - eps=eps, - ) + if has_torch_function_unary(input): + return handle_torch_function( + batch_norm, + (input,), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) if training: _verify_batch_size(input.size()) @@ -2227,20 +2173,19 @@ def instance_norm( See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, :class:`~torch.nn.InstanceNorm3d` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - instance_norm, - (input,), - input, - running_mean=running_mean, - running_var=running_var, - weight=weight, - bias=bias, - use_input_stats=use_input_stats, - momentum=momentum, - eps=eps, - ) + if has_torch_function_unary(input): + return handle_torch_function( + instance_norm, + (input,), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) _verify_batch_size(input.size()) return torch.instance_norm( input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled @@ -2258,11 +2203,10 @@ def layer_norm( See :class:`~torch.nn.LayerNorm` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps - ) + if has_torch_function_unary(input): + return handle_torch_function( + layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps + ) return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) @@ -2273,9 +2217,8 @@ def group_norm( See :class:`~torch.nn.GroupNorm` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) + if has_torch_function_unary(input): + return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) @@ -2287,9 +2230,8 @@ def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: flo See :class:`~torch.nn.LocalResponseNorm` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + if has_torch_function_unary(input): + return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) dim = input.dim() if dim < 3: raise ValueError( @@ -2425,19 +2367,18 @@ def nll_loss( >>> output = F.nll_loss(F.log_softmax(input), target) >>> output.backward() """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - nll_loss, - (input, target), - input, - target, - weight=weight, - size_average=size_average, - ignore_index=ignore_index, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + nll_loss, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) dim = input.dim() @@ -2521,20 +2462,19 @@ def poisson_nll_loss( specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - poisson_nll_loss, - (input, target), - input, - target, - log_input=log_input, - full=full, - size_average=size_average, - eps=eps, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + poisson_nll_loss, + (input, target), + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) if reduction != "none" and reduction != "mean" and reduction != "sum": @@ -2591,18 +2531,17 @@ def kl_div( :attr:``reduction`` = ``'batchmean'`` which aligns with KL math definition. In the next major release, ``'mean'`` will be changed to be the same as 'batchmean'. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - kl_div, - (input, target), - input, - target, - size_average=size_average, - reduce=reduce, - reduction=reduction, - log_target=log_target, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + kl_div, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2676,19 +2615,18 @@ def cross_entropy( >>> loss = F.cross_entropy(input, target) >>> loss.backward() """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - cross_entropy, - (input, target), - input, - target, - weight=weight, - size_average=size_average, - ignore_index=ignore_index, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + cross_entropy, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) @@ -2735,18 +2673,17 @@ def binary_cross_entropy( >>> loss = F.binary_cross_entropy(F.sigmoid(input), target) >>> loss.backward() """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - binary_cross_entropy, - (input, target), - input, - target, - weight=weight, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + binary_cross_entropy, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2808,19 +2745,18 @@ def binary_cross_entropy_with_logits( >>> loss = F.binary_cross_entropy_with_logits(input, target) >>> loss.backward() """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - binary_cross_entropy_with_logits, - (input, target), - input, - target, - weight=weight, - size_average=size_average, - reduce=reduce, - reduction=reduction, - pos_weight=pos_weight, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + binary_cross_entropy_with_logits, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2845,18 +2781,17 @@ def smooth_l1_loss( See :class:`~torch.nn.SmoothL1Loss` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - smooth_l1_loss, - (input, target), - input, - target, - size_average=size_average, - reduce=reduce, - reduction=reduction, - beta=beta, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + smooth_l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) if not (target.size() == input.size()): warnings.warn( "Using a target size ({}) that is different to the input size ({}). " @@ -2884,11 +2819,10 @@ def l1_loss( See :class:`~torch.nn.L1Loss` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): warnings.warn( "Using a target size ({}) that is different to the input size ({}). " @@ -2916,11 +2850,10 @@ def mse_loss( See :class:`~torch.nn.MSELoss` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): warnings.warn( "Using a target size ({}) that is different to the input size ({}). " @@ -2948,19 +2881,18 @@ def margin_ranking_loss( See :class:`~torch.nn.MarginRankingLoss` for details. """ # noqa - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input1, input2, target): - return handle_torch_function( - margin_ranking_loss, - (input1, input2, target), - input1, - input2, - target, - margin=margin, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + margin_ranking_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2987,18 +2919,17 @@ def hinge_embedding_loss( See :class:`~torch.nn.HingeEmbeddingLoss` for details. """ # noqa - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - hinge_embedding_loss, - (input, target), - input, - target, - margin=margin, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + hinge_embedding_loss, + (input, target), + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -3017,17 +2948,16 @@ def multilabel_margin_loss( See :class:`~torch.nn.MultiLabelMarginLoss` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - multilabel_margin_loss, - (input, target), - input, - target, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -3046,11 +2976,10 @@ def soft_margin_loss( See :class:`~torch.nn.SoftMarginLoss` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -3070,18 +2999,17 @@ def multilabel_soft_margin_loss( See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - multilabel_soft_margin_loss, - (input, target), - input, - target, - weight=weight, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_soft_margin_loss, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -3117,19 +3045,18 @@ def cosine_embedding_loss( See :class:`~torch.nn.CosineEmbeddingLoss` for details. """ # noqa - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input1, input2, target): - return handle_torch_function( - cosine_embedding_loss, - (input1, input2, target), - input1, - input2, - target, - margin=margin, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + cosine_embedding_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -3152,20 +3079,19 @@ def multi_margin_loss( See :class:`~torch.nn.MultiMarginLoss` for details. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, target): - return handle_torch_function( - multi_margin_loss, - (input, target), - input, - target, - p=p, - margin=margin, - weight=weight, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(input, target): + return handle_torch_function( + multi_margin_loss, + (input, target), + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -3444,18 +3370,17 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne Note: {backward_reproducibility_note} """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - interpolate, - (input,), - input, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor, - ) + if has_torch_function_unary(input): + return handle_torch_function( + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) if mode in ("nearest", "area"): if align_corners is not None: @@ -3808,11 +3733,10 @@ def grid_sample( .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(input, grid): - return handle_torch_function( - grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners - ) + if has_torch_function_variadic(input, grid): + return handle_torch_function( + grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners + ) if mode != "bilinear" and mode != "nearest" and mode != "bicubic": raise ValueError( "nn.functional.grid_sample(): expected mode to be " @@ -3899,9 +3823,8 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = along a unit dimension are considered to be at ```0`` (the center of the input image). 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(theta): - return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) + if has_torch_function_unary(theta): + return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) if align_corners is None: warnings.warn( "Default grid_sample and affine_grid behavior has changed " @@ -4008,9 +3931,8 @@ def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0 torch.Size([3, 9, 7, 3]) """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value) + if has_torch_function_unary(input): + return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value) assert len(pad) % 2 == 0, "Padding length must be divisible by 2" assert len(pad) // 2 <= input.dim(), "Padding length too large" if mode == "constant": @@ -4191,22 +4113,21 @@ def triplet_margin_loss( r""" See :class:`~torch.nn.TripletMarginLoss` for details """ - if not torch.jit.is_scripting(): - if has_torch_function_variadic(anchor, positive, negative): - return handle_torch_function( - triplet_margin_loss, - (anchor, positive, negative), - anchor, - positive, - negative, - margin=margin, - p=p, - eps=eps, - swap=swap, - size_average=size_average, - reduce=reduce, - reduction=reduction, - ) + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -4285,9 +4206,8 @@ def normalize(input: Tensor, p: float = 2, dim: int = 1, eps: float = 1e-12, out out (Tensor, optional): the output tensor. If :attr:`out` is used, this operation won't be differentiable. 
""" - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) + if has_torch_function_unary(input): + return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) if out is None: denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) return input / denom @@ -4318,11 +4238,10 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): See :class:`torch.nn.Unfold` for details """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride - ) + if has_torch_function_unary(input): + return handle_torch_function( + unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 4: msg = "{} must be int or 2-tuple for 4D input" assert_int_or_pair(kernel_size, "kernel_size", msg) @@ -4346,11 +4265,10 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): See :class:`torch.nn.Fold` for details """ - if not torch.jit.is_scripting(): - if has_torch_function_unary(input): - return handle_torch_function( - fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride - ) + if has_torch_function_unary(input): + return handle_torch_function( + fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 3: msg = "{} must be int or 2-tuple for 3D input" assert_int_or_pair(output_size, "output_size", msg) @@ -4613,36 +4531,35 @@ def multi_head_attention_forward( - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. """ - if not torch.jit.is_scripting(): - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) - if has_torch_function(tens_ops): - return handle_torch_function( - multi_head_attention_forward, - tens_ops, - query, - key, - value, - embed_dim_to_check, - num_heads, - in_proj_weight, - in_proj_bias, - bias_k, - bias_v, - add_zero_attn, - dropout_p, - out_proj_weight, - out_proj_bias, - training=training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, - k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, - static_k=static_k, - static_v=static_v, - ) + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check # allow MHA to have different sizes for the feature dimension