Treat has_torch_function and object_has_torch_function as static False when scripting (#48966)

Summary:
Pull Request resolved: #48966

This PR lets us drop the `if not torch.jit.is_scripting():` guards in `functional` and `nn.functional` by directly registering `has_torch_function` and `object_has_torch_function` with the JIT as statically False.
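As a rough sketch of the pattern change (simplified stand-ins for the real wrappers; the actual `torch/functional.py` diffs are further down), the wrappers go from hiding the `__torch_function__` check behind an `is_scripting()` guard to calling it unconditionally and letting the JIT fold the dead branch away:

```python
import torch
from torch.overrides import has_torch_function, handle_torch_function

def broadcast_before(*tensors):
    # Old pattern: the override check is hidden from the TorchScript
    # compiler behind an is_scripting() guard.
    if not torch.jit.is_scripting():
        if has_torch_function(tensors):
            return handle_torch_function(broadcast_before, tensors, *tensors)
    return torch.broadcast_tensors(*tensors)

def broadcast_after(*tensors):
    # New pattern: no guard. Eager mode still dispatches __torch_function__
    # overrides; under scripting, has_torch_function is statically False, so
    # the compiler prunes this branch entirely.
    if has_torch_function(tensors):
        return handle_torch_function(broadcast_after, tensors, *tensors)
    return torch.broadcast_tensors(*tensors)

x, y = torch.ones(1, 3), torch.ones(2, 1)
print(broadcast_after(x, y)[0].shape)  # torch.Size([2, 3])
```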

**Benchmarks**

The benchmark script is fairly long because it tests all four PRs in the stack and uses threads plus subprocesses so it can spread work across multiple cores while still collecting clean numbers. Both wall times and instruction counts were collected. The stack changes dozens of operators / functions, but very mechanically, so only a handful of code paths actually change; each row in the results exercises a slightly different one (e.g. the check in Python vs. in the arg parser, different input types).
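The harness boils down to two `torch.utils.benchmark` calls per task: `Timer.blocked_autorange` for wall time and `Timer.collect_callgrind` for instruction counts. Here is a minimal standalone sketch of just those calls (the full script in the details block below adds conda env management, core pinning, and retry logic):

```python
import torch
from torch.utils.benchmark import Timer

timer = Timer(
    stmt="torch.nn.functional.relu(x)",
    setup="x = torch.ones((1, 1))",
)

# Wall time: run the statement in blocks until at least ~3 s of samples exist.
wall_time = timer.blocked_autorange(min_run_time=3)
print(wall_time)

# Instruction counts: replay the statement under Callgrind (requires valgrind
# to be installed); denoise=True drops known-noisy interpreter bookkeeping.
stats = timer.collect_callgrind(number=10_000, collect_baseline=False)
print(stats.counts(denoise=True))
```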

<details>

<summary> Test script </summary>

```python
import argparse
import multiprocessing
import multiprocessing.dummy
import os
import pickle
import queue
import random
import sys
import subprocess
import tempfile
import time

import torch
from torch.utils.benchmark import Timer, Compare, Measurement

NUM_CORES = multiprocessing.cpu_count()
ENVS = {
    "ref": "HEAD (current)",
    "torch_fn_overhead_stack_0": "#48963",
    "torch_fn_overhead_stack_1": "#48964",
    "torch_fn_overhead_stack_2": "#48965",
    "torch_fn_overhead_stack_3": "#48966",
}

CALLGRIND_ENVS = tuple(ENVS.keys())

MIN_RUN_TIME = 3
REPLICATES = {
    "longer": 1_000,
    "long": 300,
    "short": 50,
}

CALLGRIND_NUMBER = {
    "overnight": 500_000,
    "long": 250_000,
    "short": 10_000,
}

CALLGRIND_TIMEOUT = {
    "overnight": 800,
    "long": 400,
    "short": 100,
}

SETUP = """
    x = torch.ones((1, 1))
    y = torch.ones((1, 1))
    w_tensor = torch.ones((1, 1), requires_grad=True)
    linear = torch.nn.Linear(1, 1, bias=False)
    linear_w = linear.weight
"""

TASKS = {
    "C++: unary                 `.t()`": "w_tensor.t()",
    "C++: unary  (Parameter)    `.t()`": "linear_w.t()",
    "C++: binary (Parameter)    `mul` ": "x + linear_w",
    "tensor.py: _wrap_type_error_to_not_implemented `__floordiv__`": "x // y",
    "tensor.py: method          `__hash__`": "hash(x)",
    "Python scalar              `__rsub__`": "1 - x",
    "functional.py: (unary)     `unique`": "torch.functional.unique(x)",
    "functional.py: (args)      `atleast_1d`": "torch.functional.atleast_1d((x, y))",
    "nn/functional.py: (unary)  `relu`": "torch.nn.functional.relu(x)",
    "nn/functional.py: (args)   `linear`": "torch.nn.functional.linear(x, w_tensor)",
    "nn/functional.py: (args)   `linear (Parameter)`": "torch.nn.functional.linear(x, linear_w)",
    "Linear(..., bias=False)": "linear(x)",
}

def _worker_main(argv, fn):
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--single_task", type=int, default=None)
    parser.add_argument("--length", type=str)
    args = parser.parse_args(argv)
    single_task = args.single_task

    conda_prefix = os.getenv("CONDA_PREFIX")
    assert torch.__file__.startswith(conda_prefix)

    env = os.path.split(conda_prefix)[1]
    assert env in ENVS

    results = []
    for i, (k, stmt) in enumerate(TASKS.items()):
        if single_task is not None and single_task != i:
            continue

        timer = Timer(
            stmt=stmt,
            setup=SETUP,
            sub_label=k,
            description=ENVS[env],
        )
        results.append(fn(timer, args.length))

    with open(args.output_file, "wb") as f:
        pickle.dump(results, f)

def worker_main(argv):
    _worker_main(
        argv,
        lambda timer, _: timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
    )

def callgrind_worker_main(argv):
    _worker_main(
        argv,
        lambda timer, length: timer.collect_callgrind(number=CALLGRIND_NUMBER[length], collect_baseline=False))

def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--long", action="store_true")
    parser.add_argument("--longer", action="store_true")
    args = parser.parse_args(argv)

    if args.longer:
        length = "longer"
    elif args.long:
        length = "long"
    else:
        length = "short"
    replicates = REPLICATES[length]

    num_workers = int(NUM_CORES // 2)
    tasks = list(ENVS.keys()) * replicates
    random.shuffle(tasks)
    task_queue = queue.Queue()
    for _ in range(replicates):
        envs = list(ENVS.keys())
        random.shuffle(envs)
        for e in envs:
            task_queue.put((e, None))

    callgrind_task_queue = queue.Queue()
    for e in CALLGRIND_ENVS:
        for i, _ in enumerate(TASKS):
            callgrind_task_queue.put((e, i))

    results = []
    callgrind_results = []

    def map_fn(worker_id):
        # Adjacent cores often share cache and maxing out a machine can distort
        # timings so we space them out.
        callgrind_cores = f"{worker_id * 2}-{worker_id * 2 + 1}"
        time_cores = str(worker_id * 2)
        _, output_file = tempfile.mkstemp(suffix=".pkl")
        try:
            loop_tasks = (
                # Callgrind is long running, and then the workers can help with
                # timing after they finish collecting counts.
                (callgrind_task_queue, callgrind_results, "callgrind_worker", callgrind_cores, CALLGRIND_TIMEOUT[length]),
                (task_queue, results, "worker", time_cores, None))

            for queue_i, results_i, mode_i, cores, timeout in loop_tasks:
                while True:
                    try:
                        env, task_i = queue_i.get_nowait()
                    except queue.Empty:
                        break

                    remaining_attempts = 3
                    while True:
                        try:
                            subprocess.run(
                                " ".join([
                                    "source", "activate", env, "&&",
                                    "taskset", "--cpu-list", cores,
                                    "python", os.path.abspath(__file__),
                                    "--mode", mode_i,
                                    "--length", length,
                                    "--output_file", output_file
                                ] + ([] if task_i is None else ["--single_task", str(task_i)])),
                                shell=True,
                                check=True,
                                timeout=timeout,
                            )
                            break

                        except subprocess.TimeoutExpired:
                            # Sometimes Valgrind will hang if there are too many
                            # concurrent runs.
                            remaining_attempts -= 1
                            if not remaining_attempts:
                                print("Too many failed attempts.")
                                raise
                            print(f"Timeout after {timeout} sec. Retrying.")

                    # We don't need a lock, as the GIL is enough.
                    with open(output_file, "rb") as f:
                        results_i.extend(pickle.load(f))

        finally:
            os.remove(output_file)

    with multiprocessing.dummy.Pool(num_workers) as pool:
        st, st_estimate, eta, n_total = time.time(), None, "", len(tasks) * len(TASKS)
        map_job = pool.map_async(map_fn, range(num_workers))
        while not map_job.ready():
            n_complete = len(results)
            if n_complete and len(callgrind_results):
                if st_estimate is None:
                    st_estimate = time.time()
                else:
                    sec_per_element = (time.time() - st_estimate) / n_complete
                    n_remaining = n_total - n_complete
                    eta = f"ETA: {n_remaining * sec_per_element:.0f} sec"

            print(
                f"\r{n_complete} / {n_total}  "
                f"({len(callgrind_results)} / {len(CALLGRIND_ENVS) * len(TASKS)})   "
                f"{eta}".ljust(40), end="")
            sys.stdout.flush()
            time.sleep(2)
    total_time = int(time.time() - st)
    print(f"\nTotal time: {int(total_time // 60)} min, {total_time % 60} sec")

    desc_to_ind = {k: i for i, k in enumerate(ENVS.values())}
    results.sort(key=lambda r: desc_to_ind[r.description])

    # TODO: Compare should be richer and more modular.
    compare = Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)

    # Manually add master vs. overall relative delta t.
    merged_results = {
        (r.description, r.sub_label): r
        for r in Measurement.merge(results)
    }

    cmp_lines = str(compare).splitlines(False)
    print(cmp_lines[0][:-1] + "-" * 15 + "]")
    print(f"{cmp_lines[1]} |{'':>10}\u0394t")
    print(cmp_lines[2] + "-" * 15)
    for l, t in zip(cmp_lines[3:3 + len(TASKS)], TASKS.keys()):
        assert l.strip().startswith(t)
        t0 = merged_results[(ENVS["ref"], t)].median
        t1 = merged_results[(ENVS["torch_fn_overhead_stack_3"], t)].median
        print(f"{l} |{'':>5}{(t1 / t0 - 1) * 100:>6.1f}%")
    print("\n".join(cmp_lines[3 + len(TASKS):]))

    counts_dict = {
        (r.task_spec.description, r.task_spec.sub_label): r.counts(denoise=True)
        for r in callgrind_results
    }

    def rel_diff(x, x0):
        return f"{(x / x0 - 1) * 100:>6.1f}%"

    task_pad = max(len(t) for t in TASKS)
    print(f"\n\nInstruction % change (relative to `{CALLGRIND_ENVS[0]}`)")
    print(" " * (task_pad + 8)  + (" " * 7).join([ENVS[env] for env in CALLGRIND_ENVS[1:]]))
    for t in TASKS:
        values = [counts_dict[(ENVS[env], t)] for env in CALLGRIND_ENVS]

        print(t.ljust(task_pad + 3) + "  ".join([
            rel_diff(v, values[0]).rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]))

        print("\033[4m" + "    Instructions per invocation".ljust(task_pad + 3) + "  ".join([
            f"{v // CALLGRIND_NUMBER[length]:.0f}".rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]) + "\033[0m")
        print()

    import pdb
    pdb.set_trace()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=("main", "worker", "callgrind_worker"), default="main")
    args, remaining = parser.parse_known_args()

    if args.mode == "main":
        main(remaining)

    elif args.mode == "callgrind_worker":
        callgrind_worker_main(remaining)

    else:
        worker_main(remaining)

```

</details>

**Wall time**
<img width="1178" alt="Screen Shot 2020-12-12 at 12 28 13 PM" src="https://user-images.githubusercontent.com/13089297/101994419-284f6a00-3c77-11eb-8dc8-4f69a890302e.png">

<details>

<summary> Longer run (`python test.py --long`) is basically identical. </summary>

<img width="1184" alt="Screen Shot 2020-12-12 at 5 02 47 PM" src="https://user-images.githubusercontent.com/13089297/102000425-2350e180-3c9c-11eb-999e-a95b37e9ef54.png">

</details>

**Callgrind**
<img width="936" alt="Screen Shot 2020-12-12 at 12 28 54 PM" src="https://user-images.githubusercontent.com/13089297/101994421-2e454b00-3c77-11eb-9cd3-8cde550f536e.png">

Test Plan: existing unit tests.

Reviewed By: ezyang

Differential Revision: D25590731

Pulled By: robieta

fbshipit-source-id: fe05305ff22b0e34ced44b60f2e9f07907a099dd
Taylor Robie authored and facebook-github-bot committed Jan 11, 2021
1 parent d31a760 commit 6a3fc0c
Showing 6 changed files with 614 additions and 708 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -294,6 +294,7 @@ namespace c10 {
_(aten, swapdims_) \
_(aten, movedim) \
_(aten, moveaxis) \
_(aten, has_torch_function) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
5 changes: 4 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
@@ -1224,8 +1224,11 @@ struct to_ir {
}
auto expr_out = emitToBool(expr.range(), emitExpr(expr));
c10::optional<bool> static_if = c10::nullopt;
if (expr_out->node()->kind() == aten::is_scripting) {
auto kind = expr_out->node()->kind();
if (kind == aten::is_scripting) {
static_if = true;
} else if (kind == aten::has_torch_function) {
static_if = false;
}
// MetaCompile on boolean literals and constants
if (auto maybe_ivalue = toIValue(expr_out)) {
4 changes: 4 additions & 0 deletions torch/csrc/jit/runtime/register_special_ops.cpp
@@ -372,6 +372,10 @@ RegisterOperators reg({
TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"),
[](Stack* stack) { push(stack, true); },
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"),
[](Stack* stack) { push(stack, false); },
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"),
164 changes: 70 additions & 94 deletions torch/functional.py
@@ -69,9 +69,8 @@ def broadcast_tensors(*tensors):
tensor([[0, 1, 2],
[0, 1, 2]])
"""
if not torch.jit.is_scripting():
if has_torch_function(tensors):
return handle_torch_function(broadcast_tensors, tensors, *tensors)
if has_torch_function(tensors):
return handle_torch_function(broadcast_tensors, tensors, *tensors)
return _VF.broadcast_tensors(tensors) # type: ignore


@@ -147,10 +146,9 @@ def split(tensor, split_size_or_sections, dim=0):
[6, 7],
[8, 9]]))
"""
if not torch.jit.is_scripting():
if has_torch_function_unary(tensor):
return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,
dim=dim)
if has_torch_function_unary(tensor):
return handle_torch_function(
split, (tensor,), tensor, split_size_or_sections, dim=dim)
# Overwriting reason:
# This dispatches to two ATen functions depending on the type of
# split_size_or_sections. The branching code is in tensor.py, which we
@@ -236,11 +234,11 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
>>> torch.norm(A_ - A)
tensor(2.9802e-08)
"""
if not torch.jit.is_scripting():
if has_torch_function_variadic(LU_data, LU_pivots):
return handle_torch_function(
lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, unpack_data=unpack_data,
unpack_pivots=unpack_pivots)
if has_torch_function_variadic(LU_data, LU_pivots):
return handle_torch_function(
lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots,
unpack_data=unpack_data,
unpack_pivots=unpack_pivots)
shape = LU_data.shape
# In generalized LU factorization, the following shape relations hold:
# A.shape[-2:] == (m, n)
@@ -301,7 +299,7 @@ def einsum(equation, *operands):
based on the Einstein summation convention.
Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
with some subscript and define which subscripts are part of the output. The output is then computed by summing
the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
@@ -387,7 +385,7 @@ def einsum(equation, *operands):
# batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])
# equivalent to torch.nn.functional.bilinear
@@ -398,9 +396,8 @@ def einsum(equation, *operands):
tensor([[-0.3430, -5.2405, 0.4494],
[ 0.3311, 5.5201, -3.0356]])
"""
if not torch.jit.is_scripting():
if has_torch_function(operands):
return handle_torch_function(einsum, operands, equation, *operands)
if has_torch_function(operands):
return handle_torch_function(einsum, operands, equation, *operands)
if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
# the old interface of passing the operands as one list argument
_operands = operands[0]
@@ -448,9 +445,8 @@ def meshgrid(*tensors):


def _meshgrid(*tensors):
if not torch.jit.is_scripting():
if has_torch_function(tensors):
return handle_torch_function(meshgrid, tensors, *tensors)
if has_torch_function(tensors):
return handle_torch_function(meshgrid, tensors, *tensors)
if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
# the old interface of passing the operands as one list argument
tensors = tensors[0] # type: ignore
@@ -568,12 +564,11 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
Tensor: A tensor containing the STFT result with shape described above
"""
if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return handle_torch_function(
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex)
if has_torch_function_unary(input):
return handle_torch_function(
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex)
# TODO: after having proper ways to map Python strings to ATen Enum, move
# this and F.pad to ATen.
if center:
@@ -650,12 +645,11 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
Returns:
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return handle_torch_function(
istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, normalized=normalized, onesided=onesided,
length=length, return_complex=return_complex)
if has_torch_function_unary(input):
return handle_torch_function(
istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, normalized=normalized, onesided=onesided,
length=length, return_complex=return_complex)

return _VF.istft(input, n_fft, hop_length, win_length, window, center, # type: ignore
normalized, onesided, length, return_complex)
@@ -734,11 +728,10 @@ def _unique_impl(input: Tensor, sorted: bool = True,
[ 1, 2]])
"""
if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return handle_torch_function(
unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
if has_torch_function_unary(input):
return handle_torch_function(
unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)

if dim is not None:
output, inverse_indices, counts = _VF.unique_dim( # type: ignore
@@ -810,11 +803,10 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
>>> counts
tensor([2, 2, 1, 2, 1])
"""
if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return handle_torch_function(
unique_consecutive, (input,), input, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
if has_torch_function_unary(input):
return handle_torch_function(
unique_consecutive, (input,), input, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore
input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
return output, inverse_indices, counts
@@ -823,9 +815,8 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)

output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
return output, counts
@@ -834,9 +825,8 @@ def _return_output(input, sorted=True, return_inverse=False, return_counts=False
def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)

output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
return output
@@ -845,9 +835,8 @@ def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)

output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
return output, inverse_indices
@@ -888,9 +877,8 @@ def _return_inverse(input, sorted=True, return_inverse=False, return_counts=Fals
def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
return output, counts
@@ -899,9 +887,8 @@ def _consecutive_return_counts(input, return_inverse=False, return_counts=False,
def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, Optional[int]) -> Tensor

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
return output
@@ -910,9 +897,8 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False,
def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
return output, inverse_indices
@@ -1000,9 +986,8 @@ def tensordot(a, b, dims=2, out=None):
[ 1.5513, -14.4737, -6.5113],
[ -0.2850, 4.2573, -3.5997]])
"""
if not torch.jit.is_scripting():
if has_torch_function_variadic(a, b):
return handle_torch_function(tensordot, (a, b), a, b, dims=dims)
if has_torch_function_variadic(a, b):
return handle_torch_function(tensordot, (a, b), a, b, dims=dims)
if isinstance(dims, (list, tuple)) or \
(isinstance(dims, torch.Tensor) and dims.numel() > 1):
dims_a, dims_b = dims
@@ -1046,9 +1031,8 @@ def cartesian_prod(*tensors):
[3, 4],
[3, 5]])
"""
if not torch.jit.is_scripting():
if has_torch_function(tensors):
return handle_torch_function(cartesian_prod, tensors, *tensors)
if has_torch_function(tensors):
return handle_torch_function(cartesian_prod, tensors, *tensors)
return _VF.cartesian_prod(tensors) # type: ignore

def block_diag(*tensors):
@@ -1128,10 +1112,9 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
[2.7138, 3.8322],
[2.2830, 0.3791]])
"""
if not torch.jit.is_scripting():
if has_torch_function_variadic(x1, x2):
return handle_torch_function(
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
if has_torch_function_variadic(x1, x2):
return handle_torch_function(
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
return _VF.cdist(x1, x2, p, None) # type: ignore
elif compute_mode == 'use_mm_for_euclid_dist':
@@ -1168,9 +1151,8 @@ def atleast_1d(*tensors):
>>> torch.atleast_1d((x,y))
(tensor([0.5000]), tensor([1.]))
"""
if not torch.jit.is_scripting():
if has_torch_function(tensors):
return handle_torch_function(atleast_1d, tensors, *tensors)
if has_torch_function(tensors):
return handle_torch_function(atleast_1d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
return _VF.atleast_1d(tensors) # type: ignore
@@ -1203,9 +1185,8 @@ def atleast_2d(*tensors):
>>> torch.atleast_2d((x,y))
(tensor([[0.5000]]), tensor([[1.]]))
"""
if not torch.jit.is_scripting():
if has_torch_function(tensors):
return handle_torch_function(atleast_2d, tensors, *tensors)
if has_torch_function(tensors):
return handle_torch_function(atleast_2d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
return _VF.atleast_2d(tensors) # type: ignore
@@ -1247,9 +1228,8 @@ def atleast_3d(*tensors):
>>> torch.atleast_3d((x,y))
(tensor([[[0.5000]]]), tensor([[[1.]]]))
"""
if not torch.jit.is_scripting():
if has_torch_function(tensors):
return handle_torch_function(atleast_3d, tensors, *tensors)
if has_torch_function(tensors):
return handle_torch_function(atleast_3d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
return _VF.atleast_3d(tensors) # type: ignore
@@ -1380,10 +1360,9 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa
(tensor(3.7417), tensor(11.2250))
"""

if not torch.jit.is_scripting():
if has_torch_function_unary(input):
return handle_torch_function(
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
if has_torch_function_unary(input):
return handle_torch_function(
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)

ndim = input.dim()

@@ -1476,9 +1455,8 @@ def chain_matmul(*matrices):
.. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
"""
if not torch.jit.is_scripting():
if has_torch_function(matrices):
return handle_torch_function(chain_matmul, matrices, *matrices)
if has_torch_function(matrices):
return handle_torch_function(chain_matmul, matrices, *matrices)
return _VF.chain_matmul(matrices) # type: ignore


@@ -1596,10 +1574,9 @@ def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:

def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
if not torch.jit.is_scripting():
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
result = _lu_impl(A, pivot, get_infos, out)
if out is not None:
_check_list_size(len(out), get_infos, out)
@@ -1612,10 +1589,9 @@ def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
# need to check for torch_function here so that we exit if
if not torch.jit.is_scripting():
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
result = _lu_impl(A, pivot, get_infos, out)
if out is not None:
_check_list_size(len(out), get_infos, out)
