Update on "Treat has_torch_function and object_has_torch_function as …
Browse files Browse the repository at this point in the history
…static False when scripting"


This PR lets us skip the `if not torch.jit.is_scripting():` guards on `functional` and `nn.functional` by registering `has_torch_function` and `object_has_torch_function` with the JIT as statically False.
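As a minimal sketch of the pattern in question (assuming the public helpers in `torch.overrides`; `my_relu` is a hypothetical stand-in for an `nn.functional` op, not code from this PR):

```python
import torch
from torch.overrides import has_torch_function, handle_torch_function

def my_relu(input):
    # Before this PR, the dispatch check below had to be wrapped in
    # `if not torch.jit.is_scripting():` because the JIT could not compile it.
    # With `has_torch_function` registered as statically False when scripting,
    # the JIT folds this branch away at compile time.
    if has_torch_function((input,)):
        return handle_torch_function(my_relu, (input,), input)
    return torch.relu(input)

scripted = torch.jit.script(my_relu)  # the guard compiles away to a no-op
print(scripted(torch.randn(3)))
```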

**Benchmarks**

The benchmark script is long because it tests all four PRs in the stack and includes threading and subprocess machinery, so the benchmark can utilize multiple cores while still collecting reliable numbers. Both wall times and instruction counts were collected. This stack changes dozens of operators / functions, but very mechanically, so there are only a handful of distinct code paths. Each row in the results exercises a slightly different code path (e.g. dispatch in Python, dispatch in the arg parser, different input types, etc.).
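As an illustrative sketch (not part of the PR), the core of each measurement below is a single `torch.utils.benchmark.Timer`, used for both wall time and, when Valgrind is available, Callgrind instruction counts:

```python
import torch
from torch.utils.benchmark import Timer

# One Timer per (environment, task) pair; `stmt` and `setup` mirror
# entries from TASKS and SETUP in the full script below.
timer = Timer(
    stmt="torch.nn.functional.relu(x)",
    setup="x = torch.ones((1, 1))",
    sub_label="nn/functional.py: (unary)  `relu`",
    description="HEAD (current)",
)
print(timer.blocked_autorange(min_run_time=3))    # wall-time statistics
# stats = timer.collect_callgrind(number=10_000)  # instruction counts (needs Valgrind)
```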

<details>

<summary> Test script </summary>

```python
import argparse
import multiprocessing
import multiprocessing.dummy
import os
import pickle
import queue
import random
import sys
import subprocess
import tempfile
import time

import torch
from torch.utils.benchmark import Timer, Compare, Measurement


NUM_CORES = multiprocessing.cpu_count()
ENVS = {
    "ref": "HEAD (current)",
    "torch_fn_overhead_stack_0": "#48963",
    "torch_fn_overhead_stack_1": "#48964",
    "torch_fn_overhead_stack_2": "#48965",
    "torch_fn_overhead_stack_3": "#48966",
}

CALLGRIND_ENVS = tuple(ENVS.keys())


MIN_RUN_TIME = 3
REPLICATES = {
    "longer": 1_000,
    "long": 300,
    "short": 50,
}

CALLGRIND_NUMBER = {
    "overnight": 500_000,
    "long": 250_000,
    "short": 10_000,
}

CALLGRIND_TIMEOUT = {
    "overnight": 800,
    "long": 400,
    "short": 100,
}

SETUP = """
    x = torch.ones((1, 1))
    y = torch.ones((1, 1))
    w_tensor = torch.ones((1, 1), requires_grad=True)
    linear = torch.nn.Linear(1, 1, bias=False)
    linear_w = linear.weight
"""

TASKS = {
    "C++: unary                 `.t()`": "w_tensor.t()",
    "C++: unary  (Parameter)    `.t()`": "linear_w.t()",
    "C++: binary (Parameter)    `mul` ": "x + linear_w",
    "tensor.py: _wrap_type_error_to_not_implemented `__floordiv__`": "x // y",
    "tensor.py: method          `__hash__`": "hash(x)",
    "Python scalar              `__rsub__`": "1 - x",
    "functional.py: (unary)     `unique`": "torch.functional.unique(x)",
    "functional.py: (args)      `atleast_1d`": "torch.functional.atleast_1d((x, y))",
    "nn/functional.py: (unary)  `relu`": "torch.nn.functional.relu(x)",
    "nn/functional.py: (args)   `linear`": "torch.nn.functional.linear(x, w_tensor)",
    "nn/functional.py: (args)   `linear (Parameter)`": "torch.nn.functional.linear(x, linear_w)",
    "Linear(..., bias=False)": "linear(x)",
}


def _worker_main(argv, fn):
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--single_task", type=int, default=None)
    parser.add_argument("--length", type=str)
    args = parser.parse_args(argv)
    single_task = args.single_task

    conda_prefix = os.getenv("CONDA_PREFIX")
    assert torch.__file__.startswith(conda_prefix)

    env = os.path.split(conda_prefix)[1]
    assert env in ENVS

    results = []
    for i, (k, stmt) in enumerate(TASKS.items()):
        if single_task is not None and single_task != i:
            continue

        timer = Timer(
            stmt=stmt,
            setup=SETUP,
            sub_label=k,
            description=ENVS[env],
        )
        results.append(fn(timer, args.length))

    with open(args.output_file, "wb") as f:
        pickle.dump(results, f)


def worker_main(argv):
    _worker_main(
        argv,
        lambda timer, _: timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
    )


def callgrind_worker_main(argv):
    _worker_main(
        argv,
        lambda timer, length: timer.collect_callgrind(number=CALLGRIND_NUMBER[length], collect_baseline=False))


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--long", action="store_true")
    parser.add_argument("--longer", action="store_true")
    args = parser.parse_args(argv)

    if args.longer:
        length = "longer"
    elif args.long:
        length = "long"
    else:
        length = "short"
    replicates = REPLICATES[length]

    num_workers = int(NUM_CORES // 2)
    tasks = list(ENVS.keys()) * replicates
    random.shuffle(tasks)
    task_queue = queue.Queue()
    for _ in range(replicates):
        envs = list(ENVS.keys())
        random.shuffle(envs)
        for e in envs:
            task_queue.put((e, None))

    callgrind_task_queue = queue.Queue()
    for e in CALLGRIND_ENVS:
        for i, _ in enumerate(TASKS):
            callgrind_task_queue.put((e, i))

    results = []
    callgrind_results = []

    def map_fn(worker_id):
        # Adjacent cores often share cache and maxing out a machine can distort
        # timings so we space them out.
        callgrind_cores = f"{worker_id * 2}-{worker_id * 2 + 1}"
        time_cores = str(worker_id * 2)
        _, output_file = tempfile.mkstemp(suffix=".pkl")
        try:
            loop_tasks = (
                # Callgrind is long running, and then the workers can help with
                # timing after they finish collecting counts.
                (callgrind_task_queue, callgrind_results, "callgrind_worker", callgrind_cores, CALLGRIND_TIMEOUT[length]),
                (task_queue, results, "worker", time_cores, None))

            for queue_i, results_i, mode_i, cores, timeout in loop_tasks:
                while True:
                    try:
                        env, task_i = queue_i.get_nowait()
                    except queue.Empty:
                        break

                    remaining_attempts = 3
                    while True:
                        try:
                            subprocess.run(
                                " ".join([
                                    "source", "activate", env, "&&",
                                    "taskset", "--cpu-list", cores,
                                    "python", os.path.abspath(__file__),
                                    "--mode", mode_i,
                                    "--length", length,
                                    "--output_file", output_file
                                ] + ([] if task_i is None else ["--single_task", str(task_i)])),
                                shell=True,
                                check=True,
                                timeout=timeout,
                            )
                            break

                        except subprocess.TimeoutExpired:
                            # Sometimes Valgrind will hang if there are too many
                            # concurrent runs.
                            remaining_attempts -= 1
                            if not remaining_attempts:
                                print("Too many failed attempts.")
                                raise
                            print(f"Timeout after {timeout} sec. Retrying.")

                    # We don't need a lock, as the GIL is enough.
                    with open(output_file, "rb") as f:
                        results_i.extend(pickle.load(f))

        finally:
            os.remove(output_file)

    with multiprocessing.dummy.Pool(num_workers) as pool:
        st, st_estimate, eta, n_total = time.time(), None, "", len(tasks) * len(TASKS)
        map_job = pool.map_async(map_fn, range(num_workers))
        while not map_job.ready():
            n_complete = len(results)
            if n_complete and len(callgrind_results):
                if st_estimate is None:
                    st_estimate = time.time()
                else:
                    sec_per_element = (time.time() - st_estimate) / n_complete
                    n_remaining = n_total - n_complete
                    eta = f"ETA: {n_remaining * sec_per_element:.0f} sec"

            print(
                f"\r{n_complete} / {n_total}  "
                f"({len(callgrind_results)} / {len(CALLGRIND_ENVS) * len(TASKS)})   "
                f"{eta}".ljust(40), end="")
            sys.stdout.flush()
            time.sleep(2)
    total_time = int(time.time() - st)
    print(f"\nTotal time: {int(total_time // 60)} min, {total_time % 60} sec")

    desc_to_ind = {k: i for i, k in enumerate(ENVS.values())}
    results.sort(key=lambda r: desc_to_ind[r.description])

    # TODO: Compare should be richer and more modular.
    compare = Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)

    # Manually add master vs. overall relative delta t.
    merged_results = {
        (r.description, r.sub_label): r
        for r in Measurement.merge(results)
    }

    cmp_lines = str(compare).splitlines(False)
    print(cmp_lines[0][:-1] + "-" * 15 + "]")
    print(f"{cmp_lines[1]} |{'':>10}\u0394t")
    print(cmp_lines[2] + "-" * 15)
    for l, t in zip(cmp_lines[3:3 + len(TASKS)], TASKS.keys()):
        assert l.strip().startswith(t)
        t0 = merged_results[(ENVS["ref"], t)].median
        t1 = merged_results[(ENVS["torch_fn_overhead_stack_3"], t)].median
        print(f"{l} |{'':>5}{(t1 / t0 - 1) * 100:>6.1f}%")
    print("\n".join(cmp_lines[3 + len(TASKS):]))


    counts_dict = {
        (r.task_spec.description, r.task_spec.sub_label): r.counts(denoise=True)
        for r in callgrind_results
    }

    def rel_diff(x, x0):
        return f"{(x / x0 - 1) * 100:>6.1f}%"

    task_pad = max(len(t) for t in TASKS)
    print(f"\n\nInstruction % change (relative to `{CALLGRIND_ENVS[0]}`)")
    print(" " * (task_pad + 8)  + (" " * 7).join([ENVS[env] for env in CALLGRIND_ENVS[1:]]))
    for t in TASKS:
        values = [counts_dict[(ENVS[env], t)] for env in CALLGRIND_ENVS]

        print(t.ljust(task_pad + 3) + "  ".join([
            rel_diff(v, values[0]).rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]))

        print("\033[4m" + "    Instructions per invocation".ljust(task_pad + 3) + "  ".join([
            f"{v // CALLGRIND_NUMBER[length]:.0f}".rjust(len(ENVS[env]) + 5)
            for v, env in zip(values[1:], CALLGRIND_ENVS[1:])]) + "\033[0m")
        print()

    import pdb
    pdb.set_trace()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=("main", "worker", "callgrind_worker"), default="main")
    args, remaining = parser.parse_known_args()

    if args.mode == "main":
        main(remaining)

    elif args.mode == "callgrind_worker":
        callgrind_worker_main(remaining)

    else:
        worker_main(remaining)

```

</details>

**Wall time**
<img width="1178" alt="Screen Shot 2020-12-12 at 12 28 13 PM" src="https://user-images.githubusercontent.com/13089297/101994419-284f6a00-3c77-11eb-8dc8-4f69a890302e.png">

<details>

<summary> Longer run (`python test.py --long`) is basically identical. </summary>

<img width="1184" alt="Screen Shot 2020-12-12 at 5 02 47 PM" src="https://user-images.githubusercontent.com/13089297/102000425-2350e180-3c9c-11eb-999e-a95b37e9ef54.png">

</details>

**Callgrind**
<img width="936" alt="Screen Shot 2020-12-12 at 12 28 54 PM" src="https://user-images.githubusercontent.com/13089297/101994421-2e454b00-3c77-11eb-9cd3-8cde550f536e.png">

Test plan: existing unit tests.

Differential Revision: [D25590731](https://our.internmc.facebook.com/intern/diff/D25590731)

[ghstack-poisoned]
Taylor Robie committed Dec 16, 2020
2 parents f0fc7bd + 5061aef commit 49f002a
Showing 88 changed files with 3,119 additions and 1,283 deletions.
5 changes: 5 additions & 0 deletions .circleci/scripts/binary_linux_test.sh
@@ -7,6 +7,11 @@ set -eux -o pipefail
python_nodot="\$(echo $DESIRED_PYTHON | tr -d m.u)"
# There was a bug that was introduced in conda-package-handling >= 1.6.1 that makes archives
# above a certain size fail out when attempting to extract
# see: https://github.com/conda/conda-package-handling/issues/71
conda install -y conda-package-handling=1.6.0
# Set up Python
if [[ "$PACKAGE_TYPE" == conda ]]; then
retry conda create -qyn testenv python="$DESIRED_PYTHON"
22 changes: 15 additions & 7 deletions aten/src/ATen/BatchedTensorImpl.cpp
@@ -76,13 +76,6 @@ void BatchedTensorImpl::checkInvariants() const {
}

// The following are publically exposed as methods of Tensor
IntArrayRef BatchedTensorImpl::strides() const {
TORCH_CHECK(false, "NYI: Getting tensor strides inside of vmap");
}
int64_t BatchedTensorImpl::stride(int64_t d) const {
TORCH_CHECK(false, "NYI: Getting tensor strides inside of vmap");
}

bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
"NYI: querying is_contiguous inside of vmap for memory_format ",
@@ -139,4 +132,19 @@ Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
return makeBatched(batched->value(), std::move(new_bdims));
}

bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
const auto* other_batched = maybeGetBatchedImpl(other);
if (!other_batched) {
return true;
}
const auto* self_batched = maybeGetBatchedImpl(self);
if (!self_batched) {
// self is not batched but other is batched
return false;
}
auto self_levels = createVmapLevelsBitset(self_batched->bdims());
auto other_levels = createVmapLevelsBitset(other_batched->bdims());
return self_levels == (self_levels | other_levels);
}

} // namespace at
6 changes: 4 additions & 2 deletions aten/src/ATen/BatchedTensorImpl.h
@@ -74,8 +74,6 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {

// Override a bunch of methods inherited from TensorImpl to return error messages.
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
IntArrayRef strides() const override;
int64_t stride(int64_t d) const override;
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
@@ -143,5 +141,9 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);

// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);


}
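For context, a hedged sketch (not part of this diff) of the case the new `inplaceIsVmapCompatible` helper rejects, written against the prototype `torch.vmap` API: writing a batched `other` into an unbatched `self` in place cannot work, because `self` has no batch dimension to hold per-example results.

```python
import torch

def inplace_add(x, y):
    return x.add_(y)  # in-place: the result must fit in x

x = torch.zeros(3)     # unbatched self
y = torch.randn(5, 3)  # batched other (batch dim 0)
try:
    torch.vmap(inplace_add, in_dims=(None, 0))(x, y)
except RuntimeError as err:
    print("vmap-incompatible in-place op:", err)
```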
8 changes: 4 additions & 4 deletions aten/src/ATen/VmapModeRegistrations.cpp
@@ -79,15 +79,15 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {

m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>);
m.impl_UNBOXED("rand.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, const TensorOptions&>);
m.impl_UNBOXED("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, const TensorOptions&>);
m.impl("rand.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>);
m.impl("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>);
m.impl("rand.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
m.impl("rand.generator_out", unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>);

m.impl("randn", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
m.impl("randn.generator", unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>);
m.impl_UNBOXED("randn.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, const TensorOptions&>);
m.impl_UNBOXED("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, const TensorOptions&>);
m.impl("randn.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>);
m.impl("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>);
m.impl("randn.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
m.impl("randn.generator_out", unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>);

21 changes: 7 additions & 14 deletions aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
@@ -265,20 +265,13 @@ namespace impl {
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(std::move(v));
}
};
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<optional<ArrayRef<int64_t>>, AllowDeprecatedTypes> final {
// If an argument is optional<ArrayRef<int64_t>>, convert the IValue to a optional<std::vector<int64_t>> and pass that
// to the operator.
static OptionalArray<int64_t> call(IValue&& v) {
return std::move(v).toOptionalIntArray();
}
};
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<optional<ArrayRef<double>>, AllowDeprecatedTypes> final {
// If an argument is optional<ArrayRef<T>>, convert the IValue to a optional<std::vector<T>> and pass that
// to the operator.
static OptionalArray<double> call(IValue&& v) {
return std::move(v).toOptionalDoubleArray();
template<class T, bool AllowDeprecatedTypes>
struct ivalue_to_arg<optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
// If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
// to the operator. OptionalArray<T> is basically an optional<std::vector<T>> but implicitly convertible
// to optional<ArrayRef<T>>.
static OptionalArray<T> call(IValue&& v) {
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(std::move(v));
}
};

8 changes: 8 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
@@ -76,6 +76,8 @@ TypePtr IValue::type() const {
return NoneType::get();
case Tag::Tensor:
return TensorType::create(toTensor());
case Tag::Storage:
return StorageType::get();
case Tag::Double:
return FloatType::get();
case Tag::Int:
@@ -260,6 +262,8 @@ IValue IValue::equals(const IValue& rhs) const {
return false;
}
return lhs.toTensor().eq(rhs.toTensor());
case Tag::Storage:
return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl();
case Tag::Double:
return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
case Tag::Int:
@@ -310,6 +314,8 @@ size_t IValue::hash(const IValue& v) {
// Tensor __hash__ is equivalent to `id()`, so take the pointer value of
// the tensor to emulate it
return c10::get_hash(v.payload.as_int);
case Tag::Storage:
return c10::get_hash(v.payload.as_int);
case Tag::Int:
return c10::get_hash(v.payload.as_int);
case Tag::String:
@@ -647,6 +653,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return out << v.toNone();
case IValue::Tag::Tensor:
return out << v.toTensor();
case IValue::Tag::Storage:
return out << v.toStorage().unsafeGetStorageImpl();
case IValue::Tag::Double: {
double d = v.toDouble();
int c = std::fpclassify(d);
23 changes: 15 additions & 8 deletions aten/src/ATen/core/ivalue.h
@@ -105,6 +105,7 @@ struct Capsule {
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) \
_(Storage) \
_(Double) \
_(Int) \
_(Bool) \
@@ -314,6 +315,20 @@ struct CAFFE2_API IValue final {
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
}

IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast<bool>(s)) {
// Note: a null storage (like the undefined tensor) is not refcounted,
// so while it is tagged as storage, is_intrusive_ptr is set from
// static_cast<bool>(s). This is not an optional optimization: our
// incref call *will not* do the right thing when called on a
// null storage impl.
payload.as_intrusive_ptr = s.unsafeReleaseStorageImpl();
}
bool isStorage() const {
return Tag::Storage == tag;
}
c10::Storage toStorage() &&;
c10::Storage toStorage() const&;

const IValue& toIValue() const {
return *this;
}
@@ -705,14 +720,6 @@ struct CAFFE2_API IValue final {
template <typename T>
optional<T> toOptional();

/// @private [doxygen private]
/// Only for use in generated code.
OptionalArray<int64_t> toOptionalIntArray();

/// @private [doxygen private]
/// Only for use in generated code.
OptionalArray<double> toOptionalDoubleArray();

/// @private [doxygen private]
/// this is a shallow comparison of two IValues to test the object identity
bool isSameIdentity(const IValue& rhs) const;
64 changes: 40 additions & 24 deletions aten/src/ATen/core/ivalue_inl.h
@@ -137,6 +137,15 @@ inline at::Tensor IValue::toTensor() const& {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
}
inline c10::Storage IValue::toStorage() && {
AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
return c10::Storage(
moveToIntrusivePtr<at::StorageImpl>());
}
inline c10::Storage IValue::toStorage() const& {
AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
return c10::Storage(toIntrusivePtr<at::StorageImpl>());
}
inline c10::Stream IValue::toStream() && {
return c10::Stream::unpack(payload.as_int);
}
@@ -743,6 +752,7 @@ inline const ivalue::Object& IValue::toObjectRef() const {
return this->method_name(); \
}
DEFINE_TO(at::Tensor, toTensor)
DEFINE_TO(at::Storage, toStorage)
DEFINE_TO(c10::Stream, toStream)
DEFINE_TO(float, toDouble)
DEFINE_TO(double, toDouble)
@@ -861,6 +871,36 @@ c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
return impl::toTypedList<Elem>(std::move(ivalue).toList());
}

template <typename T>
static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
std::vector<T> result;
result.reserve(impl->list.size());
for (size_t i = 0, N = impl->list.size(); i < N; ++i) {
result.push_back(impl->list[i].to<T>());
}
return result;
}

template <typename T>
static std::vector<T> createVectorFromList(const c10::List<T>& impl) {
std::vector<T> result;
result.reserve(impl.size());
for (size_t i = 0, N = impl.size(); i < N; ++i) {
result.push_back(impl[i]);
}
return result;
}

template <typename T>
OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
if (ivalue.isNone()) {
return {};
}
return createVectorFromList<T>(
std::move(ivalue).to<c10::List<T>>()
);
}

namespace detail {
template <typename Elem, size_t... I>
std::array<Elem, sizeof...(I)> generic_to_array(
@@ -952,16 +992,6 @@ inline T IValue::to() const& {
return generic_to(*this, _fake_type<T>{});
}

template <typename T>
static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
std::vector<T> result;
result.reserve(impl->list.size());
for (size_t i = 0, N = impl->list.size(); i < N; ++i) {
result.push_back(impl->list[i].to<T>());
}
return result;
}

inline c10::List<int64_t> IValue::toIntList() && {
AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
@@ -1211,20 +1241,6 @@ inline optional<T> IValue::toOptional() {
return this->to<T>();
}

inline OptionalArray<int64_t> IValue::toOptionalIntArray() {
if (this->isNone()) {
return {};
}
return this->toIntVector();
}

inline OptionalArray<double> IValue::toOptionalDoubleArray() {
if (this->isNone()) {
return {};
}
return this->toDoubleVector();
}

inline bool IValue::isCustomClass() const {
return torch::isCustomClass(*this);
}
30 changes: 30 additions & 0 deletions aten/src/ATen/core/jit_type.h
@@ -31,6 +31,7 @@ using OptNameList = c10::optional<std::vector<std::string>>;
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
@@ -1407,6 +1408,29 @@ struct CAFFE2_API StringType : public Type {
StringType() : Type(TypeKind::StringType) {}
};

struct StorageType;
using StorageTypePtr = std::shared_ptr<StorageType>;
struct CAFFE2_API StorageType : public Type {
static StorageTypePtr create() {
return StorageTypePtr(new StorageType()); // NOLINT(modernize-make-shared)
}
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
std::string str() const override {
return annotation_str();
}
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "Storage";
}
static const TypeKind Kind = TypeKind::StorageType;
// global singleton
static StorageTypePtr get();

private:
StorageType() : Type(TypeKind::StorageType) {}
};

struct FunctionType;
using FunctionTypePtr = std::shared_ptr<FunctionType>;
struct CAFFE2_API FunctionType : public NamedType {
@@ -1757,6 +1781,12 @@ struct getTypePtr_<at::Tensor> final {
}
};
template <>
struct getTypePtr_<c10::Storage> final {
static TypePtr call() {
return StorageType::get();
}
};
template <>
struct getTypePtr_<c10::Stream> final {
static TypePtr call() {
return StreamObjType::get();
4 changes: 4 additions & 0 deletions aten/src/ATen/core/type.cpp
@@ -134,6 +134,10 @@ BoolTypePtr BoolType::get() {
static auto value = BoolType::create();
return value;
}
StorageTypePtr StorageType::get() {
static auto value = StorageType::create();
return value;
}
NoneTypePtr NoneType::get() {
static auto value = NoneType::create();
return value;
