Update on "Treat has_torch_function and object_has_torch_function as static False when scripting"


This PR lets us skip the `if not torch.jit.is_scripting():` guards on `functional` and `nn.functional` by directly registering `has_torch_function` and `object_has_torch_function` to the JIT as statically False.
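
For context, here is a minimal sketch of the guard pattern in question (simplified and illustrative only, not the actual library source; `my_relu` is a hypothetical stand-in for a `functional`-style wrapper):

```python
import torch
from torch.overrides import has_torch_function, handle_torch_function

def my_relu(input):
    # Previously the __torch_function__ check had to sit behind an
    # is_scripting() guard because the JIT could not compile it.
    if not torch.jit.is_scripting():
        if has_torch_function((input,)):
            return handle_torch_function(my_relu, (input,), input)
    # With has_torch_function registered to the JIT as statically False,
    # the outer guard can be dropped: the branch is pruned when scripting.
    return torch.relu(input)
```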

**Benchmarks**

The benchmark script is fairly long because it tests all four PRs in the stack and uses threads and subprocesses so the benchmark can utilize multiple cores while still collecting good numbers. Both wall times and instruction counts were collected. This stack changes dozens of operators / functions, but very mechanically, so there are only a handful of distinct codepath changes. Each row exercises a slightly different code path (e.g. the check in Python, the check in the arg parser, different input types, etc.).
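
For reference, a minimal sketch of the two collection modes the script uses (assumes `valgrind` is installed for the instruction counts; the full script below is the source of truth):

```python
import torch
from torch.utils.benchmark import Timer

timer = Timer(
    stmt="torch.nn.functional.relu(x)",
    setup="x = torch.ones((1, 1))",
)

# Wall time: returns a Measurement (median, iqr, etc.).
wall = timer.blocked_autorange(min_run_time=1)

# Instruction counts: runs the statement under callgrind, so it is much
# slower than timing but essentially noise-free.
stats = timer.collect_callgrind(number=10_000)

print(wall.median, stats.counts(denoise=True))
```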

<details>

<summary> Test script </summary>

```python
import argparse
import multiprocessing
import multiprocessing.dummy
import os
import pickle
import queue
import random
import sys
import subprocess
import tempfile
import time

import torch
from torch.utils.benchmark import Timer, Compare, Measurement


NUM_CORES = multiprocessing.cpu_count()
ENVS = {
    "ref": "HEAD (current)",
    "torch_fn_overhead_stack_0": "#48963",
    "torch_fn_overhead_stack_1": "#48964",
    "torch_fn_overhead_stack_2": "#48965",
    "torch_fn_overhead_stack_3": "#48966",
}

CALLGRIND_ENVS = tuple(ENVS.keys())


MIN_RUN_TIME = 3
REPLICATES = {
    "longer": 1_000,
    "long": 300,
    "short": 50,
}

CALLGRIND_NUMBER = {
    "overnight": 500_000,
    "long": 250_000,
    "short": 10_000,
}

CALLGRIND_TIMEOUT = {
    "overnight": 800,
    "long": 400,
    "short": 100,
}

SETUP = """
    x = torch.ones((1, 1))
    y = torch.ones((1, 1))
    w_tensor = torch.ones((1, 1), requires_grad=True)
    linear = torch.nn.Linear(1, 1, bias=False)
    linear_w = linear.weight
"""

TASKS = {
    "C++: unary                 `.t()`": "w_tensor.t()",
    "C++: unary  (Parameter)    `.t()`": "linear_w.t()",
    "C++: binary (Parameter)    `mul` ": "x + linear_w",
    "tensor.py: _wrap_type_error_to_not_implemented `__floordiv__`": "x // y",
    "tensor.py: method          `__hash__`": "hash(x)",
    "Python scalar              `__rsub__`": "1 - x",
    "functional.py: (unary)     `unique`": "torch.functional.unique(x)",
    "functional.py: (args)      `atleast_1d`": "torch.functional.atleast_1d((x, y))",
    "nn/functional.py: (unary)  `relu`": "torch.nn.functional.relu(x)",
    "nn/functional.py: (args)   `linear`": "torch.nn.functional.linear(x, w_tensor)",
    "nn/functional.py: (args)   `linear (Parameter)`": "torch.nn.functional.linear(x, linear_w)",
    "Linear(..., bias=False)": "linear(x)",
}


def _worker_main(argv, fn):
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--single_task", type=int, default=None)
    parser.add_argument("--length", type=str)
    args = parser.parse_args(argv)
    single_task = args.single_task

    conda_prefix = os.getenv("CONDA_PREFIX")
    assert torch.__file__.startswith(conda_prefix)

    env = os.path.split(conda_prefix)[1]
    assert env in ENVS

    results = []
    for i, (k, stmt) in enumerate(TASKS.items()):
        if single_task is not None and single_task != i:
            continue

        timer = Timer(
            stmt=stmt,
            setup=SETUP,
            sub_label=k,
            description=ENVS[env],
        )
        results.append(fn(timer, args.length))

    with open(args.output_file, "wb") as f:
        pickle.dump(results, f)


def worker_main(argv):
    _worker_main(
        argv,
        lambda timer, _: timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
    )


def callgrind_worker_main(argv):
    _worker_main(
        argv,
        lambda timer, length: timer.collect_callgrind(number=CALLGRIND_NUMBER[length], collect_baseline=False))


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--long", action="store_true")
    parser.add_argument("--longer", action="store_true")
    args = parser.parse_args(argv)

    if args.longer:
        length = "longer"
    elif args.long:
        length = "long"
    else:
        length = "short"
    replicates = REPLICATES[length]

    num_workers = int(NUM_CORES // 2)
    tasks = list(ENVS.keys()) * replicates
    random.shuffle(tasks)
    task_queue = queue.Queue()
    for _ in range(replicates):
        envs = list(ENVS.keys())
        random.shuffle(envs)
        for e in envs:
            task_queue.put((e, None))

    callgrind_task_queue = queue.Queue()
    for e in CALLGRIND_ENVS:
        for i, _ in enumerate(TASKS):
            callgrind_task_queue.put((e, i))

    results = []
    callgrind_results = []

    def map_fn(worker_id):
        # Adjacent cores often share cache and maxing out a machine can distort
        # timings so we space them out.
        callgrind_cores = f"{worker_id * 2}-{worker_id * 2 + 1}"
        time_cores = str(worker_id * 2)
        _, output_file = tempfile.mkstemp(suffix=".pkl")
        try:
            loop_tasks = (
                # Callgrind is long running, and then the workers can help with
                # timing after they finish collecting counts.
                (callgrind_task_queue, callgrind_results, "callgrind_worker", callgrind_cores, CALLGRIND_TIMEOUT[length]),
                (task_queue, results, "worker", time_cores, None))

            for queue_i, results_i, mode_i, cores, timeout in loop_tasks:
                while True:
                    try:
                        env, task_i = queue_i.get_nowait()
                    except queue.Empty:
                        break

                    remaining_attempts = 3
                    while True:
                        try:
                            subprocess.run(
                                " ".join([
                                    "source", "activate", env, "&&",
                                    "taskset", "--cpu-list", cores,
                                    "python", os.path.abspath(__file__),
                                    "--mode", mode_i,
                                    "--length", length,
                                    "--output_file", output_file
                                ] + ([] if task_i is None else ["--single_task", str(task_i)])),
                                shell=True,
                                check=True,
                                timeout=timeout,
                            )
                            break

                        except subprocess.TimeoutExpired:
                            # Sometimes Valgrind will hang if there are too many
                            # concurrent runs.
                            remaining_attempts -= 1
                            if not remaining_attempts:
                                print("Too many failed attempts.")
                                raise
                            print(f"Timeout after {timeout} sec. Retrying.")

                    # We don't need a lock, as the GIL is enough.
                    with open(output_file, "rb") as f:
                        results_i.extend(pickle.load(f))

        finally:
            os.remove(output_file)

    with multiprocessing.dummy.Pool(num_workers) as pool:
        st, st_estimate, eta, n_total = time.time(), None, "", len(tasks) * len(TASKS)
        map_job = pool.map_async(map_fn, range(num_workers))
        while not map_job.ready():
            n_complete = len(results)
            if n_complete and len(callgrind_results):
                if st_estimate is None:
                    st_estimate = time.time()
                else:
                    sec_per_element = (time.time() - st_estimate) / n_complete
                    n_remaining = n_total - n_complete
                    eta = f"ETA: {n_remaining * sec_per_element:.0f} sec"

            print(
                f"\r{n_complete} / {n_total}  "
                f"({len(callgrind_results)} / {len(CALLGRIND_ENVS) * len(TASKS)})   "
                f"{eta}".ljust(40), end="")
            sys.stdout.flush()
            time.sleep(2)
    total_time = int(time.time() - st)
    print(f"\nTotal time: {int(total_time // 60)} min, {total_time % 60} sec")

    desc_to_ind = {k: i for i, k in enumerate(ENVS.values())}
    results.sort(key=lambda r: desc_to_ind[r.description])

    # TODO: Compare should be richer and more modular.
    compare = Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)

    # Manually add master vs. overall relative delta t.
    merged_results = {
        (r.description, r.sub_label): r
        for r in Measurement.merge(results)
    }

    cmp_lines = str(compare).splitlines(False)
    print(cmp_lines[0][:-1] + "-" * 15 + "]")
    print(f"{cmp_lines[1]} |{'':>10}\u0394t")
    print(cmp_lines[2] + "-" * 15)
    for l, t in zip(cmp_lines[3:3 + len(TASKS)], TASKS.keys()):
        assert l.strip().startswith(t)
        t0 = merged_results[(ENVS["ref"], t)].median
        t1 = merged_results[(ENVS["torch_fn_overhead_stack_3"], t)].median
        print(f"{l} |{'':>5}{(t1 / t0 - 1) * 100:>6.1f}%")
    print("\n".join(cmp_lines[3 + len(TASKS):]))


    counts_dict = {
        (r.task_spec.description, r.task_spec.sub_label): r.counts(denoise=True)
        for r in callgrind_results
    }

    def rel_diff(x, x0):
        return f"{(x / x0 - 1) * 100:>6.1f}%"

    task_pad = max(len(t) for t in TASKS)
    print(f"\n\nInstruction % change (relative to `{CALLGRIND_ENVS[0]}`)")
    print(" " * (task_pad + 8)  + (" " * 7).join([ENVS[env] for env in CALLGRIND_ENVS[1:]]))
    for t in TASKS:
        values = [counts_dict[(ENVS[env], t)] for env in CALLGRIND_ENVS]

        print(t.ljust(task_pad + 3) + "  ".join([
            rel_diff(v, values[0]).rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]))

        print("\033[4m" + "    Instructions per invocation".ljust(task_pad + 3) + "  ".join([
            f"{v // CALLGRIND_NUMBER[length]:.0f}".rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]) + "\033[0m")
        print()

    import pdb
    pdb.set_trace()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=("main", "worker", "callgrind_worker"), default="main")
    args, remaining = parser.parse_known_args()

    if args.mode == "main":
        main(remaining)

    elif args.mode == "callgrind_worker":
        callgrind_worker_main(remaining)

    else:
        worker_main(remaining)

```

</details>

**Wall time**
<img width="1178" alt="Screen Shot 2020-12-12 at 12 28 13 PM" src="https://user-images.githubusercontent.com/13089297/101994419-284f6a00-3c77-11eb-8dc8-4f69a890302e.png">

<details>

<summary> Longer run (`python test.py --long`) is basically identical. </summary>

<img width="1184" alt="Screen Shot 2020-12-12 at 5 02 47 PM" src="https://user-images.githubusercontent.com/13089297/102000425-2350e180-3c9c-11eb-999e-a95b37e9ef54.png">

</details>

**Callgrind**
<img width="936" alt="Screen Shot 2020-12-12 at 12 28 54 PM" src="https://user-images.githubusercontent.com/13089297/101994421-2e454b00-3c77-11eb-9cd3-8cde550f536e.png">

Test plan: existing unit tests.

Differential Revision: [D25590731](https://our.internmc.facebook.com/intern/diff/D25590731)

[ghstack-poisoned]
Taylor Robie committed Jan 6, 2021
2 parents a7c9cfd + ba09e88 commit 3a3a097
Showing 369 changed files with 5,288 additions and 2,014 deletions.
9 changes: 8 additions & 1 deletion .circleci/scripts/binary_linux_test.sh
@@ -51,7 +51,14 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then
else
cu_ver="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4}"
fi
retry conda install \${EXTRA_CONDA_FLAGS} -yq -c nvidia -c pytorch "cudatoolkit=\${cu_ver}"
(
# For some reason conda likes to re-activate the conda environment when attempting this install
# which means that a deactivate is run and some variables might not exist when that happens,
# namely CONDA_MKL_INTERFACE_LAYER_BACKUP from libblas so let's just ignore unbound variables when
# it comes to the conda installation commands
set +u
retry conda install \${EXTRA_CONDA_FLAGS} -yq -c nvidia -c pytorch "cudatoolkit=\${cu_ver}"
)
fi
elif [[ "$PACKAGE_TYPE" != libtorch ]]; then
pip install "\$pkg"
1 change: 1 addition & 0 deletions .github/workflows/update_s3_htmls.yml
@@ -9,6 +9,7 @@ on:
jobs:
update-html:
runs-on: ubuntu-latest
if: ${{ github.repository_owner == 'pytorch' }}
strategy:
matrix:
prefix: ["whl", "whl/test", "whl/nightly"]
6 changes: 3 additions & 3 deletions .jenkins/pytorch/README.md
@@ -10,9 +10,9 @@ it is very easy to run these tests yourself:
``registry.pytorch.org/pytorch/pytorch-$BUILD_ENVIRONMENT:$DOCKER_VERSION``,
where ``$BUILD_ENVIRONMENT`` is one of the build environments
enumerated in
[pytorch-dockerfiles](https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh)
[pytorch-dockerfiles](https://github.com/pytorch/pytorch/blob/master/.circleci/docker/build.sh). The dockerfile used by jenkins can be found under the `.circle` [directory](https://github.com/pytorch/pytorch/blob/master/.circleci/docker)

2. Run ``docker -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and
2. Run ``docker run -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and
run one of the scripts in this directory.

The Docker images are designed so that any "reasonable" build commands
@@ -38,5 +38,5 @@ mechanisms we use:
build scripts.

- We reroute well known paths like `/usr/bin/gcc` to alternate
implementations with `update-alternatives, instead of setting
implementations with `update-alternatives`, instead of setting
`CC` and `CXX` in our implementations.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -316,7 +316,7 @@ set(OP_DEPENDENCY "" CACHE STRING
# symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk
# https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu
if(LINUX)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed")
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed ${CMAKE_SHARED_LINKER_FLAGS}")
endif()

if(MSVC)
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
@@ -31,3 +31,4 @@
#include <c10/util/Exception.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
28 changes: 28 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -287,6 +287,25 @@ Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_batching_rule(const Tensor& self) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
// Batched Diagonal View
auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
auto result = at::sum(self_diag, -1);
return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
// Batched Diagonal View
auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
// Append a dimension of size one to the grad output
auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
grad_input_diag.copy_(grad_physical_tensor);
return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
// PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
// for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
@@ -1029,6 +1048,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("squeeze", squeeze_batching_rule);
m.impl("squeeze.dim", squeeze_dim_batching_rule);
m.impl("t", native::t); // composite wrt autograd
m.impl("trace", trace_batching_rule);
m.impl("transpose.int", transpose_int_batching_rule);
m.impl("unbind.int", unbind_batching_rule);
m.impl("unfold", unfold_batching_rule);
@@ -1089,6 +1109,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
#undef TO_BATCHING_RULE
m.impl("clone", clone_batching_rule);

using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, Scalar);

@@ -1115,6 +1136,12 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);

m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
m.impl(
"threshold_backward",
binary_pointwise_batching_rule<
TensorTensorScalarType,
at::threshold_backward,
Scalar>);

// for at::result_type, call the native::result_type implementation.
// We don't have to do anything special because native::result_type operates
@@ -1150,6 +1177,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
// backward operators
m.impl("select_backward", select_backward_batching_rule);
m.impl("slice_backward", slice_backward_batching_rule);
m.impl("trace_backward", trace_backward_batching_rule);
m.impl("diagonal_backward", diagonal_backward_batching_rule);

// Tensor.new_* operators
20 changes: 0 additions & 20 deletions aten/src/ATen/Dispatch.h
@@ -93,26 +93,6 @@ inline constexpr bool should_include_kernel_dtype(
return __VA_ARGS__(); \
}

// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
// should be removed once the bfloat16 bringup is complete on other platforms.
// This is supposed to be used as a wrapper around the lambda function passed to
// the dispatch macro and will conditionally dispatch ops with bfloat16 type
// only on ROCm.
#if !defined(__HIP_PLATFORM_HCC__)
#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) \
if (std::is_same<SCALARTYPE, at::BFloat16>::value) { \
AT_ERROR( \
#NAME, \
" not implemented for '", \
toString(at::ScalarType::BFloat16), \
"'"); \
} else { \
return __VA_ARGS__(); \
}
#else
#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) return __VA_ARGS__()
#endif

namespace detail {

inline at::ScalarType scalar_type(at::ScalarType s) {
1 change: 1 addition & 0 deletions aten/src/ATen/ParallelOpenMP.cpp
@@ -1,4 +1,5 @@
#include <ATen/Config.h>
#include <ATen/core/jit_type.h>
#if AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>

11 changes: 7 additions & 4 deletions aten/src/ATen/TensorIndexing.h
@@ -10,6 +10,8 @@
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>

#include <ATen/core/List.h>

namespace at {
namespace indexing {

@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
(*dim_ptr)++;
};

static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
std::vector<Tensor> converted_inds(indices.size());
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
c10::List<c10::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
const auto &ind = indices[i];
if (ind.defined()) {
converted_inds[i] = ind.to(ind.options().device(self.device()));
converted_inds.push_back(ind.to(ind.options().device(self.device())));
} else {
converted_inds[i] = std::move(indices[i]);
converted_inds.push_back(std::move(indices[i]));
}
}
return converted_inds;
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)

2 changes: 1 addition & 1 deletion aten/src/ATen/core/List.h
@@ -243,7 +243,7 @@ class List final {
* Example:
* List<int> a({2, 3, 4});
*/
explicit List(std::initializer_list<T> initial_values);
List(std::initializer_list<T> initial_values);
explicit List(ArrayRef<T> initial_values);

/**
16 changes: 14 additions & 2 deletions aten/src/ATen/core/List_inl.h
@@ -1,7 +1,7 @@
#pragma once

#include <ATen/core/jit_type_base.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

namespace c10 {

@@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
namespace impl {
template<class T>
List<T> toTypedList(impl::GenericList list) {
TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
// If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
// because upcasting would allow people to add types into the new list that would break the old list.
// However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
// allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
// without having to copy it. This is also used to provide backwards compatibility with some old models
// that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
return List<T>(std::move(list.impl_));
}

@@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
impl_->elementType = std::move(t);
}
}

#include <ATen/core/jit_type.h>
10 changes: 10 additions & 0 deletions aten/src/ATen/core/Variadic.h
@@ -6,6 +6,7 @@
#include <utility>

#include <c10/util/ArrayRef.h>
#include <ATen/core/List.h>

namespace at {

@@ -56,6 +57,15 @@ struct IterArgs {
}
}

template <typename T>
void operator()(const torch::List<T>& args) {
for (const auto& arg : args) {
self()(arg);
if (self().short_circuit())
return;
}
}

// NB: we need to specify std::vector manually as C++ won't
// do an implicit conversion to make a template deduction go through.
template <typename T>
9 changes: 9 additions & 0 deletions aten/src/ATen/core/builtin_function.h
@@ -101,8 +101,17 @@ struct BuiltinOpFunction : public Function {
}

std::string pretty_print_schema() const override {
#ifdef __NVCC__
// Disable the "statement is unreachable" warning
#pragma diag_suppress code_is_unreachable
#endif

TORCH_INTERNAL_ASSERT(false);
return "";

#ifdef __NVCC__
#pragma diag_default code_is_unreachable
#endif
}

Function& setSchema(c10::FunctionSchema schema) override {
9 changes: 9 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -17,6 +17,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
_(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@@ -284,6 +285,9 @@ namespace c10 {
_(aten, zero_) \
_(aten, fill_) \
_(aten, masked_fill_) \
_(cuda, _set_device) \
_(cuda, set_stream) \
_(cuda, _current_device) \
_(aten, swapaxes) \
_(aten, swapaxes_) \
_(aten, swapdims) \
@@ -384,6 +388,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
_(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@@ -454,6 +459,7 @@ struct TORCH_API Symbol {
// (and if it's not, you should add it to the built-ins list above.)
static Symbol attr(const std::string & s);
static Symbol aten(const std::string & s);
static Symbol cuda(const std::string & s);
static Symbol onnx(const std::string & s);
static Symbol prim(const std::string & s);
static Symbol user(const std::string & s);
@@ -464,6 +470,7 @@

bool is_attr() const;
bool is_aten() const;
bool is_cuda() const;
bool is_prim() const;
bool is_onnx() const;
bool is_user() const;
@@ -524,6 +531,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL)

inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); }
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@@ -532,6 +540,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
