From c7e9abb66abef127f8cebccbe0aa27c6ded9ead6 Mon Sep 17 00:00:00 2001
From: Sebastian Messmer
Date: Mon, 4 Jan 2021 05:01:02 -0800
Subject: [PATCH] Making ops c10-full: list of optional tensors (#49138)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49138

See for details: https://fb.quip.com/QRtJAin66lPN

We need to model optional types explicitly, mostly for schema inference. So we cannot pass a `Tensor?[]` as `ArrayRef<Tensor>`; instead we need to pass it as an optional type. This PR changes it to `torch::List<c10::optional<Tensor>>`. It also makes the ops c10-full that were previously blocked by this.

## Backwards Compatibility

- This should not break the Python API because the representation in Python is the same and python_arg_parser just transforms the Python list into a `List<optional<Tensor>>` instead of into a `List<Tensor>`.
- This should not break serialized models because there's some logic that allows loading a serialized `List<Tensor>` as `List<optional<Tensor>>`, see https://github.com/pytorch/pytorch/pull/49138/files#diff-9315f5dd045f47114c677174dcaa2f982721233eee1aa19068a42ff3ef775315R57
- This will break backwards compatibility for the C++ API. There is no implicit conversion from `ArrayRef<Tensor>` (which was the old argument type) to `List<optional<Tensor>>`. One common call pattern is `tensor.index({indices_tensor})`, where `indices_tensor` is another `Tensor`; that will continue working because the `{}` initializer_list constructor for `List<optional<Tensor>>` can take `Tensor` elements that are implicitly converted to `optional<Tensor>`. But another common call pattern was `tensor.index(indices_tensor)`, where previously the `Tensor` got implicitly converted to an `ArrayRef<Tensor>`. Converting `Tensor -> optional<Tensor> -> List<optional<Tensor>>` would be two implicit conversions, and C++ doesn't allow chaining two implicit conversions. So those call sites have to be rewritten to `tensor.index({indices_tensor})`.

ghstack-source-id: 119269131
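As a concrete illustration of the last bullet, here is a minimal migration sketch (added for clarity, not part of the original summary; `x` and `indices_tensor` are placeholder names):

```cpp
#include <torch/torch.h>

int main() {
  auto x = torch::ones({4, 4, 4});
  auto indices_tensor = torch::tensor({0, 2});

  // Old pattern, no longer compiles after this PR:
  // Tensor -> ArrayRef<Tensor> was a single implicit conversion, but
  // Tensor -> optional<Tensor> -> List<optional<Tensor>> would chain two
  // implicit conversions, which C++ forbids.
  // auto y = x.index(indices_tensor);

  // New pattern: the braces invoke the initializer_list constructor of
  // List<optional<Tensor>>, so each element needs only the single implicit
  // conversion Tensor -> optional<Tensor>.
  auto y = x.index({indices_tensor});
  return 0;
}
```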
Test Plan:
## Benchmarks (C++ instruction counts):

### Forward
#### Script
```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4});
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```

#### Results
| Op call                                                             | before    | after     | delta   | delta % |
|---------------------------------------------------------------------|-----------|-----------|---------|---------|
| x[0] = 1                                                            | 11566015  | 11566015  | 0       | 0.00%   |
| x.index({0})                                                        | 6807019   | 6801019   | -6000   | -0.09%  |
| x.index({0, 0})                                                     | 13529019  | 13557019  | 28000   | 0.21%   |
| x.index({0, 0, 0})                                                  | 10677004  | 10692004  | 15000   | 0.14%   |
| x.index({"..."})                                                    | 5512015   | 5506015   | -6000   | -0.11%  |
| x.index({Slice(None, None, None)})                                  | 6866016   | 6936016   | 70000   | 1.02%   |
| x.index({None})                                                     | 8554015   | 8548015   | -6000   | -0.07%  |
| x.index({false})                                                    | 22400000  | 22744000  | 344000  | 1.54%   |
| x.index({true})                                                     | 27624088  | 27264393  | -359695 | -1.30%  |
| x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) | 123472000 | 123463306 | -8694   | -0.01%  |

### Autograd
#### Script
```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4}, torch::requires_grad());
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```

Note: the script measures the **forward** path of an op call with autograd enabled (i.e. calls into VariableType). It does not measure the backward path.

#### Results
| Op call                                                             | before    | after     | delta    | delta % |
|---------------------------------------------------------------------|-----------|-----------|----------|---------|
| x.index({0})                                                        | 14839019  | 14833019  | -6000    | 0.00%   |
| x.index({0, 0})                                                     | 28342019  | 28370019  | 28000    | 0.00%   |
| x.index({0, 0, 0})                                                  | 24434004  | 24449004  | 15000    | 0.00%   |
| x.index({"..."})                                                    | 12773015  | 12767015  | -6000    | 0.00%   |
| x.index({Slice(None, None, None)})                                  | 14837016  | 14907016  | 70000    | 0.47%   |
| x.index({None})                                                     | 15926015  | 15920015  | -6000    | 0.00%   |
| x.index({false})                                                    | 36958000  | 37477000  | 519000   | 1.40%   |
| x.index({true})                                                     | 41971408  | 42426094  | 454686   | 1.08%   |
| x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) | 168184392 | 164545682 | -3638710 | -2.16%  |

Reviewed By: bhosmer

Differential Revision: D25454632

fbshipit-source-id: 28ab0cffbbdbdff1c40b4130ca62ee72f981b76d
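For reference, a small usage sketch of the new argument type (my own illustration, not part of the original patch; it relies on the behavior, visible in `expandTensors` in the diff below, that an empty/`c10::nullopt` entry means the corresponding dimension is not indexed, like `:` in Python):

```cpp
#include <torch/torch.h>

int main() {
  auto x = torch::rand({3, 4});
  auto idx = torch::tensor({0, 2});

  // `Tensor?[]` in the operator schema corresponds to
  // torch::List<c10::optional<torch::Tensor>> in C++.
  torch::List<c10::optional<torch::Tensor>> indices;
  indices.reserve(2);
  indices.push_back(c10::nullopt);  // dim 0: not indexed (like `:`)
  indices.push_back(idx);           // dim 1: select rows 0 and 2 of that dim

  auto y = at::index(x, indices);   // expected to match x[:, idx] in Python
  return 0;
}
```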
---
 aten/src/ATen/ATen.h                           |   1 +
 aten/src/ATen/ParallelOpenMP.cpp               |   1 +
 aten/src/ATen/TensorIndexing.h                 |  11 +-
 aten/src/ATen/autocast_mode.cpp                |   2 +-
 aten/src/ATen/core/List.h                      |   2 +-
 aten/src/ATen/core/List_inl.h                  |  16 +-
 aten/src/ATen/core/Variadic.h                  |  10 +
 aten/src/ATen/core/jit_type.h                  | 187 +----------------
 aten/src/ATen/core/jit_type_base.h             | 195 ++++++++++++++++++
 aten/src/ATen/native/Embedding.cpp             |   2 +-
 aten/src/ATen/native/IndexingUtils.h           |  63 +++---
 aten/src/ATen/native/LinearAlgebra.cpp         |   4 +-
 .../ATen/native/TensorAdvancedIndexing.cpp     |  26 ++-
 aten/src/ATen/native/TensorAdvancedIndexing.h  |   4 +-
 aten/src/ATen/native/cuda/IndexKernel.cu       |   2 +-
 aten/src/ATen/native/cuda/Indexing.cu          |   4 +-
 aten/src/ATen/native/native_functions.yaml     |   4 +
 aten/src/ATen/native/sparse/SparseTensor.cpp   |   4 +-
 aten/src/ATen/templates/TensorBody.h           |   1 +
 caffe2/contrib/aten/aten_op.cc                 |  16 +-
 caffe2/contrib/aten/aten_op_template.h         |  12 +-
 caffe2/contrib/aten/gen_op.py                  |  18 +-
 test/cpp/api/tensor_indexing.cpp               |  16 +-
 test/test_overrides.py                         |   2 +
 tools/autograd/gen_autograd_functions.py       |  11 +-
 tools/autograd/gen_trace_type.py               |   3 +-
 tools/autograd/gen_variable_type.py            |  51 ++++-
 tools/autograd/templates/Functions.h           |   9 +
 tools/autograd/templates/VariableType.h        |   1 -
 tools/codegen/api/cpp.py                       |   8 +-
 tools/codegen/api/native.py                    |   2 +-
 tools/codegen/api/python.py                    |  11 +-
 torch/csrc/autograd/FunctionsManual.cpp        |  35 ++--
 torch/csrc/autograd/FunctionsManual.h          |   2 +-
 torch/csrc/autograd/VariableTypeManual.cpp     |  11 +-
 torch/csrc/autograd/VariableTypeUtils.h        |  19 ++
 torch/csrc/jit/backends/backend_detail.h       |   1 +
 torch/csrc/jit/frontend/tracer.cpp             |  13 ++
 torch/csrc/jit/frontend/tracer.h               |   4 +
 torch/csrc/jit/mobile/module.h                 |   1 +
 torch/csrc/jit/runtime/interpreter.h           |   1 +
 torch/csrc/jit/runtime/register_prim_ops.cpp   |   8 +-
 torch/csrc/jit/runtime/vararg_functions.h      |   1 +
 torch/csrc/utils/python_arg_parser.cpp         |   3 +-
 torch/csrc/utils/python_arg_parser.h           |  17 ++
 45 files changed, 510 insertions(+), 305 deletions(-)
 create mode 100644 aten/src/ATen/core/jit_type_base.h

diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index ae95ef43f21c..8d29a9204420 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -31,3 +31,4 @@
 #include
 #include
 #include
+#include

diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp
index 07fc4e279557..261f6cdd46b5 100644
--- a/aten/src/ATen/ParallelOpenMP.cpp
+++ b/aten/src/ATen/ParallelOpenMP.cpp
@@ -1,4 +1,5 @@
 #include
+#include

 #if AT_PARALLEL_OPENMP
 #include

diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h
index 3890662123a2..f6c3bbbe09cc 100644
--- a/aten/src/ATen/TensorIndexing.h
+++ b/aten/src/ATen/TensorIndexing.h
@@ -10,6 +10,8 @@
 // There is some back story, see https://github.com/pytorch/pytorch/issues/48684
 #include
+#include
+
 namespace at {
 namespace indexing {

@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
   (*dim_ptr)++;
 };

-static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
-  std::vector<Tensor> converted_inds(indices.size());
+static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
+  c10::List<c10::optional<Tensor>> converted_inds;
+  converted_inds.reserve(indices.size());
   for (size_t i = 0; i < indices.size(); ++i) {
     const auto &ind = indices[i];
     if (ind.defined()) {
-      converted_inds[i] = ind.to(ind.options().device(self.device()));
+      converted_inds.push_back(ind.to(ind.options().device(self.device())));
     } else {
-      converted_inds[i] = std::move(indices[i]);
+      converted_inds.push_back(std::move(indices[i]));
     }
   }
   return converted_inds;

diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 8c82f965ef0f..dfb8e3ac0f32 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
   KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
   KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
-  KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
+  KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
   KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
   KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)

diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h
index 40f733784fe5..f911722c51e1 100644
--- a/aten/src/ATen/core/List.h
+++ b/aten/src/ATen/core/List.h
@@ -243,7 +243,7 @@ class List final {
    * Example:
    *   List<int> a({2, 3, 4});
    */
-  explicit List(std::initializer_list<T> initial_values);
+  List(std::initializer_list<T> initial_values);
   explicit List(ArrayRef<T> initial_values);

diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h
index 3cbd7a310275..ab3ddae55770 100644
--- a/aten/src/ATen/core/List_inl.h
+++ b/aten/src/ATen/core/List_inl.h
@@ -1,7 +1,7 @@
 #pragma once

+#include
 #include
-#include

 namespace c10 {

@@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
 namespace impl {
 template<class T>
 List<T> toTypedList(impl::GenericList list) {
-  TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
+  // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
+  // because upcasting would allow people to add types into the new list that would break the old list.
+  // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
+  // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
+  // without having to copy it. This is also used to provide backwards compatibility with some old models
+  // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
+  // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
+  // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
+  TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
+    || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
+    , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
   return List<T>(std::move(list.impl_));
 }

@@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
   impl_->elementType = std::move(t);
 }
 }
+
+#include

diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h
index b49d94bba1c8..d33f3d575177 100644
--- a/aten/src/ATen/core/Variadic.h
+++ b/aten/src/ATen/core/Variadic.h
@@ -6,6 +6,7 @@
 #include
 #include
+#include

 namespace at {

@@ -56,6 +57,15 @@ struct IterArgs {
     }
   }

+  template <typename T>
+  void operator()(const torch::List<T>& args) {
+    for (const auto& arg : args) {
+      self()(arg);
+      if (self().short_circuit())
+        return;
+    }
+  }
+
   // NB: we need to specify std::vector manually as C++ won't
   // do an implicit conversion to make a template deduction go through.
   template <typename T>

diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index f6902cd4beb6..a3ae813616e0 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -1,10 +1,11 @@
 #pragma once

+#include
 #include
 #include
 #include
-#include
 #include
+#include
 #include
 #include

@@ -17,197 +18,17 @@ struct ClassType;
 namespace torch {
 namespace jit {
 struct CompilationUnit;
+struct Function;
 } // namespace jit
 } // namespace torch

 namespace c10 {

+struct IValue;
 struct FunctionSchema;
 struct NamedType;

 using OptNameList = c10::optional<std::vector<std::string>>;

-#define C10_FORALL_TYPES(_) \
-  _(AnyType)                \
-  _(EnumType)               \
-  _(AnyEnumType)            \
-  _(TensorType)             \
-  _(StorageType)            \
-  _(TupleType)              \
-  _(ListType)               \
-  _(DictType)               \
-  _(NumberType)             \
-  _(FloatType)              \
-  _(FutureType)             \
-  _(RRefType)               \
-  _(IntType)                \
-  _(NoneType)               \
-  _(StringType)             \
-  _(GeneratorType)          \
-  _(QuantizerType)          \
-  _(BoolType)               \
-  _(OptionalType)           \
-  _(VarType)                \
-  _(DeviceObjType)          \
-  _(StreamObjType)          \
-  _(FunctionType)           \
-  _(ClassType)              \
-  _(PyObjectType)           \
-  _(CapsuleType)            \
-  _(InterfaceType)          \
-  _(QSchemeType)            \
-  _(LayoutType)             \
-  _(ScalarTypeType)         \
-  _(AnyListType)            \
-  _(AnyTupleType)           \
-  _(AnyClassType)
-
-enum class TypeKind {
-#define DEFINE_TYPE(T) T,
-  C10_FORALL_TYPES(DEFINE_TYPE)
-#undef DEFINE_TYPE
-};
-
-TORCH_API const char* typeKindToString(TypeKind kind);
-
-struct Type;
-using TypePtr = std::shared_ptr<Type>;
-using ConstTypePtr = std::shared_ptr<const Type>;
-
-// Use this to customize how a Type is printed using `annotation_str()`. If
-// c10::nullopt is returned, `annotation_str()` falls through to its default
-// implementation.
-using TypePrinter =
-    std::function<c10::optional<std::string>(const ConstTypePtr&)>;
-
-struct TORCH_API Type : std::enable_shared_from_this<Type> {
- private:
-  TypeKind kind_;
-
- protected:
-  Type(TypeKind kind) : kind_(kind) {}
-
-  virtual std::string annotation_str_impl(TypePrinter printer) const {
-    return str();
-  }
-
- public:
-  virtual bool operator==(const Type& rhs) const = 0;
-
-  // subtyping relation. By default, we return true for the case
-  // when the type is exactly equal or if this <: T where rhs = Optional[T]
-
-  // if this returns false and the why_not stream is non-null, it contains
-  // additional details that describe why this is not a subtype of 'rhs'.
-  // This additional information should only contain details that are not obvious
-  // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
-  // but not clear why `Foo <: InterfaceBar` might be false.
-  virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
-  virtual bool is_module() const;
-  bool isSubtypeOf(const TypePtr& rhs) const {
-    return isSubtypeOfExt(rhs, nullptr);
-  }
-
-  // How this type will appear in FunctionSchema declarations
-  virtual std::string str() const = 0;
-
-  // How this type will appear as if it were a type annotation in Python
-  // which is sometimes different than how it appears in declarations (e.g.
-  // int[] vs List[int])
-  //
-  // Takes a custom printer that users can pass in to customize the output of
-  // this method.
-  std::string annotation_str(TypePrinter printer) const {
-    if (printer) {
-      // the printer can return nullopt to fall through to the default impl
-      if (auto renamed = printer(shared_from_this())) {
-        return *renamed;
-      }
-    }
-    return annotation_str_impl(printer);
-  }
-  std::string annotation_str() const {
-    // Overload instead of define a default value for `printer` to help
-    // debuggers out.
-    return annotation_str(nullptr);
-  }
-
-  // Returns a human readable string that includes additional information like
-  // "type is inferred rather than explictly defined" to help construct more
-  // user-friendly messages.
-  virtual std::string repr_str() const {
-    return annotation_str();
-  }
-
-  TypeKind kind() const {
-    return kind_;
-  }
-
-  virtual bool requires_grad() const {
-    for (const auto& ct : containedTypes()) {
-      if (ct->requires_grad()) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  // Dynamically cast this object to the subclass indicated by the
-  // template variable, returning nullptr if the cast is invalid.
-  template <typename T>
-  std::shared_ptr<T> cast() {
-    if (T::Kind == kind()) {
-      return std::static_pointer_cast<T>(shared_from_this());
-    }
-    return nullptr;
-  }
-  template <typename T>
-  std::shared_ptr<const T> cast() const {
-    if (T::Kind == kind()) {
-      return std::static_pointer_cast<const T>(shared_from_this());
-    }
-    return nullptr;
-  }
-  template <typename T>
-  std::shared_ptr<T> expect() {
-    auto r = cast<T>();
-    AT_ASSERT(r);
-    return r;
-  }
-  template <typename T>
-  std::shared_ptr<const T> expect() const {
-    auto r = cast<const T>();
-    AT_ASSERT(r);
-    return r;
-  }
-  virtual ~Type() = default;
-  virtual bool hasFreeVariables() const {
-    return false;
-  }
-  // list of types this type contains, e.g. for a List then element type of a
-  // list for a tuple, the types of the tuple elements
-  virtual at::ArrayRef<TypePtr> containedTypes() const {
-    return {};
-  }
-  // create a new version of this type, replacing its contained types with
-  // contained_types
-  TypePtr withContained(std::vector<TypePtr> contained_types) {
-    auto current_contained = containedTypes();
-    AT_ASSERT(current_contained.size() == contained_types.size());
-    if (current_contained.equals(contained_types)) {
-      return shared_from_this();
-    }
-    return createWithContained(std::move(contained_types));
-  }
-  // per-type constructor, you only need to override this if the
-  // containedTypes() is not empty
-  virtual TypePtr createWithContained(
-      std::vector<TypePtr> contained_types) const {
-    AT_ERROR(
-        "type with contained types did not overload createWithContained: ",
-        str());
-  }
-};
-
 struct AnyType;
 using AnyTypePtr = std::shared_ptr<AnyType>;
 // Any is the top of the type hierarchy, all other types are subtypes

diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h
new file mode 100644
index 000000000000..37da9ad7ef8d
--- /dev/null
+++ b/aten/src/ATen/core/jit_type_base.h
@@ -0,0 +1,195 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace c10 {
+
+#define C10_FORALL_TYPES(_) \
+  _(AnyType)                \
+  _(EnumType)               \
+  _(AnyEnumType)            \
+  _(TensorType)             \
+  _(StorageType)            \
+  _(TupleType)              \
+  _(ListType)               \
+  _(DictType)               \
+  _(NumberType)             \
+  _(FloatType)              \
+  _(FutureType)             \
+  _(RRefType)               \
+  _(IntType)                \
+  _(NoneType)               \
+  _(StringType)             \
+  _(GeneratorType)          \
+  _(QuantizerType)          \
+  _(BoolType)               \
+  _(OptionalType)           \
+  _(VarType)                \
+  _(DeviceObjType)          \
+  _(StreamObjType)          \
+  _(FunctionType)           \
+  _(ClassType)              \
+  _(PyObjectType)           \
+  _(CapsuleType)            \
+  _(InterfaceType)          \
+  _(QSchemeType)            \
+  _(LayoutType)             \
+  _(ScalarTypeType)         \
+  _(AnyListType)            \
+  _(AnyTupleType)           \
+  _(AnyClassType)
+
+enum class TypeKind {
+#define DEFINE_TYPE(T) T,
+  C10_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
+
+TORCH_API const char* typeKindToString(TypeKind kind);
+
+struct Type;
+using TypePtr = std::shared_ptr<Type>;
+using ConstTypePtr = std::shared_ptr<const Type>;
+
+// Use this to customize how a Type is printed using `annotation_str()`. If
+// c10::nullopt is returned, `annotation_str()` falls through to its default
+// implementation.
+using TypePrinter =
+    std::function<c10::optional<std::string>(const ConstTypePtr&)>;
+
+struct TORCH_API Type : std::enable_shared_from_this<Type> {
+ private:
+  TypeKind kind_;
+
+ protected:
+  Type(TypeKind kind) : kind_(kind) {}
+
+  virtual std::string annotation_str_impl(TypePrinter printer) const {
+    return str();
+  }
+
+ public:
+  virtual bool operator==(const Type& rhs) const = 0;
+
+  // subtyping relation. By default, we return true for the case
+  // when the type is exactly equal or if this <: T where rhs = Optional[T]
+
+  // if this returns false and the why_not stream is non-null, it contains
+  // additional details that describe why this is not a subtype of 'rhs'.
+  // This additional information should only contain details that are not obvious
+  // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
+  // but not clear why `Foo <: InterfaceBar` might be false.
+  virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
+  virtual bool is_module() const;
+  bool isSubtypeOf(const TypePtr& rhs) const {
+    return isSubtypeOfExt(rhs, nullptr);
+  }
+
+  // How this type will appear in FunctionSchema declarations
+  virtual std::string str() const = 0;
+
+  // How this type will appear as if it were a type annotation in Python
+  // which is sometimes different than how it appears in declarations (e.g.
+  // int[] vs List[int])
+  //
+  // Takes a custom printer that users can pass in to customize the output of
+  // this method.
+  std::string annotation_str(TypePrinter printer) const {
+    if (printer) {
+      // the printer can return nullopt to fall through to the default impl
+      if (auto renamed = printer(shared_from_this())) {
+        return *renamed;
+      }
+    }
+    return annotation_str_impl(printer);
+  }
+  std::string annotation_str() const {
+    // Overload instead of define a default value for `printer` to help
+    // debuggers out.
+    return annotation_str(nullptr);
+  }
+
+  // Returns a human readable string that includes additional information like
+  // "type is inferred rather than explictly defined" to help construct more
+  // user-friendly messages.
+  virtual std::string repr_str() const {
+    return annotation_str();
+  }
+
+  TypeKind kind() const {
+    return kind_;
+  }
+
+  virtual bool requires_grad() const {
+    for (const auto& ct : containedTypes()) {
+      if (ct->requires_grad()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Dynamically cast this object to the subclass indicated by the
+  // template variable, returning nullptr if the cast is invalid.
+  template <typename T>
+  std::shared_ptr<T> cast() {
+    if (T::Kind == kind()) {
+      return std::static_pointer_cast<T>(shared_from_this());
+    }
+    return nullptr;
+  }
+  template <typename T>
+  std::shared_ptr<const T> cast() const {
+    if (T::Kind == kind()) {
+      return std::static_pointer_cast<const T>(shared_from_this());
+    }
+    return nullptr;
+  }
+  template <typename T>
+  std::shared_ptr<T> expect() {
+    auto r = cast<T>();
+    AT_ASSERT(r);
+    return r;
+  }
+  template <typename T>
+  std::shared_ptr<const T> expect() const {
+    auto r = cast<const T>();
+    AT_ASSERT(r);
+    return r;
+  }
+  virtual ~Type() = default;
+  virtual bool hasFreeVariables() const {
+    return false;
+  }
+  // list of types this type contains, e.g. for a List then element type of a
+  // list for a tuple, the types of the tuple elements
+  virtual at::ArrayRef<TypePtr> containedTypes() const {
+    return {};
+  }
+  // create a new version of this type, replacing its contained types with
+  // contained_types
+  TypePtr withContained(std::vector<TypePtr> contained_types) {
+    auto current_contained = containedTypes();
+    AT_ASSERT(current_contained.size() == contained_types.size());
+    if (current_contained.equals(contained_types)) {
+      return shared_from_this();
+    }
+    return createWithContained(std::move(contained_types));
+  }
+  // per-type constructor, you only need to override this if the
+  // containedTypes() is not empty
+  virtual TypePtr createWithContained(
+      std::vector<TypePtr> contained_types) const {
+    AT_ERROR(
+        "type with contained types did not overload createWithContained: ",
+        str());
+  }
+};
+
+}

diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index bf74e8b356c7..a4854e1ced4d 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -68,7 +68,7 @@ Tensor embedding_sparse_backward(
   Tensor indices = indices_;
   Tensor grad = grad_;
   if (padding_idx != -1) {
-    auto c = indices != padding_idx;
+    torch::List<c10::optional<Tensor>> c({indices != padding_idx});
     indices = indices.index(c);
     grad = grad.index(c);
   }

diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h
index 94d61b02dd0b..92f6957f25ad 100644
--- a/aten/src/ATen/native/IndexingUtils.h
+++ b/aten/src/ATen/native/IndexingUtils.h
@@ -1,6 +1,7 @@
 #pragma once
 #include
 #include
+#include
 #include

@@ -15,40 +16,45 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
 }

-static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
+static std::vector<Tensor> expandTensors(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
   std::vector<Tensor> result;
-  for (const auto & index : indices) {
-    if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
-      if (index.scalar_type() == kByte) {
-        TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
-        " please use a dtype torch.bool instead.");
-      }
-      // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
-      // corresponding dimensions in self
-      for (int64_t j = 0; j < index.dim(); j++) {
-        int64_t srcIdx = result.size() + j;
-        if (index.size(j) != self.size(srcIdx)) {
-          invalid_mask(self, srcIdx, index, j);
+  for (c10::optional<Tensor> index_opt : indices) {
+    if (!index_opt.has_value()) {
+      result.emplace_back();
+    } else {
+      Tensor index = std::move(*index_opt);
+      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+        if (index.scalar_type() == kByte) {
+          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
+          " please use a dtype torch.bool instead.");
         }
+        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
+        // corresponding dimensions in self
+        for (int64_t j = 0; j < index.dim(); j++) {
+          int64_t srcIdx = result.size() + j;
+          if (index.size(j) != self.size(srcIdx)) {
+            invalid_mask(self, srcIdx, index, j);
+          }
+        }
+        // Replace with nonzeros
+        auto nonzero = index.nonzero();
+        for (int64_t j = 0; j < index.dim(); j++) {
+          result.emplace_back(nonzero.select(1, j));
+        }
+      } else {
+        result.emplace_back(std::move(index));
       }
-      // Replace with nonzeros
-      auto nonzero = index.nonzero();
-      for (int64_t j = 0; j < index.dim(); j++) {
-        result.emplace_back(nonzero.select(1, j));
-      }
-    } else {
-      result.emplace_back(index);
     }
   }
   return result;
 }

-static void checkIndexTensorTypes(TensorList indices) {
-  for (auto& tensor : indices) {
-    if (tensor.defined()) {
-      auto scalarType = tensor.scalar_type();
+static void checkIndexTensorTypes(const torch::List<c10::optional<Tensor>>& indices) {
+  for (c10::optional<Tensor> tensor : indices) {
+    if (tensor.has_value() && tensor->defined()) {
+      auto scalarType = tensor->scalar_type();
       if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
         TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
       }
@@ -56,6 +62,15 @@ static void checkIndexTensorTypes(TensorList indices) {
   }
 }

+inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
+  torch::List<c10::optional<Tensor>> result;
+  result.reserve(list.size());
+  for (const Tensor& a : list) {
+    result.push_back(a);
+  }
+  return result;
+}
+
 static bool hasContiguousSubspace(TensorList tl) {
   // true if all the non-null tensors are adjacent
   auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index da8d2bd6db47..a37d1046bac2 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include

@@ -73,7 +74,8 @@ Tensor logdet(const Tensor& self) {
   // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)].
   Tensor logdet_vals = diag_U.abs_().log_().sum(-1);
   if (self.dim() > 2) {
-    logdet_vals.index_put_((det_sign < 0).nonzero_numpy(), at::full({}, NAN, self.options()));
+    auto indices = toListOfOptionalTensors((det_sign < 0).nonzero_numpy());
+    logdet_vals.index_put_(std::move(indices), at::full({}, NAN, self.options()));
   } else if (det_sign.item<double>() < 0) {
     logdet_vals.fill_(NAN);
   }

diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 1d9f9d9d2a12..2d79a4e3713f 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -206,7 +206,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
   }
 }

-static AdvancedIndex make_info(Tensor self, TensorList orig) {
+static AdvancedIndex make_info(Tensor self, const torch::List<c10::optional<Tensor>>& orig) {
   checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);

@@ -281,7 +281,7 @@ static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& result) {
   return config.build();
 }

-Tensor index(const Tensor & self, TensorList indices) {
+Tensor index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");

   auto info = make_info(self, indices);
@@ -290,7 +290,7 @@ Tensor index(const Tensor & self, TensorList indices) {
   return iter.output();
 }

-Tensor quantized_index(const Tensor & self, TensorList indices) {
+Tensor quantized_index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   TORCH_INTERNAL_ASSERT(
       self.qscheme() == c10::kPerTensorAffine ||
       self.qscheme() == c10::kPerTensorSymmetric,
@@ -311,12 +311,14 @@ Tensor quantized_index(const Tensor & self, TensorList indices) {
       res, self.q_scale(), self.q_zero_point(), self.scalar_type());
 }

-Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
+Tensor& index_out(Tensor& result, const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   at::assert_no_internal_overlap(result);
   at::assert_no_overlap(result, self);
-  for (auto& index: indices) {
-    at::assert_no_overlap(result, index);
+  for (const c10::optional<Tensor>& index: indices) {
+    if (index.has_value()) {
+      at::assert_no_overlap(result, *index);
+    }
   }

   auto info = make_info(self, indices);
@@ -325,11 +327,11 @@ Tensor& index_out(Tensor& result, const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   return result;
 }

-Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
+Tensor index_put(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate) {
   return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate);
 }

-Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate, const bool unsafe) {
+Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate, const bool unsafe) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   if (at::has_internal_overlap(self) == MemOverlap::YES) {
     TORCH_WARN(
       "Use of index_put_ on expanded tensors is deprecated. Please clone() the tensor before performing this operation. "
       "This also applies to advanced indexing e.g. tensor[indices] = tensor");
   }
   at::assert_no_overlap(self, value);
-  for (auto& index: indices) {
-    at::assert_no_overlap(self, index);
+  for (const c10::optional<Tensor>& index: indices) {
+    if (index.has_value()) {
+      at::assert_no_overlap(self, *index);
+    }
   }

   if (accumulate && self.device().type() == kCUDA) {
@@ -356,7 +360,7 @@
 }

-Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate) {
+Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
   return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
 }

diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h
index 560b46162546..0e0958606de1 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.h
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -15,7 +15,7 @@ enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY};

 using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
 using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
-using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);
+using index_put_accum_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool unsafe);
 using masked_fill_fn = void(*)(TensorIterator &, Scalar scalar);
 using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);

@@ -42,6 +42,6 @@ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
 DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
 DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);

-TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices);
+TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<Tensor>>& indices);

 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
index cb4aa644fee2..d88f202487af 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
@@ -190,7 +190,7 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
   Tensor _mask = (mask.dim() == 0) ? mask.unsqueeze(0) : mask;
   Tensor _self = (self.dim() == 0) ? self.unsqueeze(0) : self;
   std::tie(_mask, _self) = expand_outplace(_mask, _self);
-  at::native::index_out(result, _self, _mask);
+  at::native::index_out(result, _self, c10::List<c10::optional<Tensor>>({_mask}));
   return result;
 }

diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index d630d727019f..e372f8bdb697 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -160,7 +160,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
 }

-static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, TensorList orig, bool check_range) {
+static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, const c10::List<c10::optional<at::Tensor>>& orig, bool check_range) {
   checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);

@@ -184,7 +184,7 @@
-void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & value, bool unsafe) {
+void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a5b945399da8..6b0aaa8f4d9b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2226,6 +2226,7 @@
   use_c10_dispatcher: full

 - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+  use_c10_dispatcher: full
   variants: function, method
   dispatch:
     CPU, CUDA: index

@@ -2254,6 +2255,7 @@
   variants: function, method

 - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
+  use_c10_dispatcher: full
   variants: function, method
   dispatch:
     DefaultBackend: index_put_

@@ -2264,9 +2266,11 @@
 #   - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Scalar v)

 - func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+  use_c10_dispatcher: full
   variants: function, method

 - func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
+  use_c10_dispatcher: full
   variants: function
   dispatch:
     CPU, CUDA: _index_put_impl_

diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index d621efafee41..fb7e16539c15 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include

 #include

@@ -14,7 +15,6 @@
 namespace at { namespace native {

 using namespace at::sparse;
-
 /******************************************************************************
  * access methods
  ******************************************************************************/

@@ -328,7 +328,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
   Tensor values;
   if (self.dim() > 0) {
-    std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
+    auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0));
     values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
   } else {
     AT_ASSERT(nz.sizes().equals({0, 1}));

diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index d42c8c23fe9c..1c0a04a318d0 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -28,6 +28,7 @@ class Tensor;
 }
 namespace c10{
 struct TensorOptions;
+template<class T> class List;
 }
 namespace at {
 struct Generator;

diff --git a/caffe2/contrib/aten/aten_op.cc b/caffe2/contrib/aten/aten_op.cc
index 9e7479141ad4..dba68d21c2dd 100644
--- a/caffe2/contrib/aten/aten_op.cc
+++ b/caffe2/contrib/aten/aten_op.cc
@@ -6,13 +6,17 @@ namespace caffe2 {
 namespace internal {
 at::Tensor index_with_uint8_handling(
     const at::Tensor& self,
-    at::TensorList indices) {
+    const torch::List<c10::optional<at::Tensor>>& indices) {
   // Support BC only for the simplest case of mask indexing
-  if (indices.size() == 1 && indices[0].scalar_type() == at::kByte) {
-    TORCH_WARN(
-        "Indexing with uint8 mask tensor in ATenOp is now deprecated,"
-        " please use a bool mask instead.");
-    return at::index(self, {indices[0].to(at::kBool)});
+  if (indices.size() == 1) {
+    c10::optional<at::Tensor> first = indices[0];
+    if (first.has_value()
+        && first->scalar_type() == at::kByte) {
+      TORCH_WARN(
+          "Indexing with uint8 mask tensor in ATenOp is now deprecated,"
+          " please use a bool mask instead.");
+      return at::index(self, {first->to(at::kBool)});
+    }
   }
   return at::index(self, indices);
 }

diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index f3a42dbd8f59..cd1ce7651b48 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -21,7 +21,7 @@ using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...)
 namespace internal {
 TORCH_API at::Tensor index_with_uint8_handling(
     const at::Tensor& self,
-    at::TensorList indices);
+    const torch::List<c10::optional<at::Tensor>>& indices);
 }

 template <class Context>
 class ATenOp : public Operator<Context> {
@@ -86,6 +86,16 @@ class ATenOp : public Operator<Context> {

   std::vector<at::Tensor> peekSlice(size_t i, size_t len, size_t N) {
     std::vector<at::Tensor> results;
+    results.reserve(len);
+    for (size_t ii = i; ii < i + len; ++ii) {
+      results.push_back(peek(ii, N));
+    }
+    return results;
+  }
+
+  torch::List<c10::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
+    torch::List<c10::optional<at::Tensor>> results;
+    results.reserve(len);
     for (size_t ii = i; ii < i + len; ++ii) {
       results.push_back(peek(ii, N));
     }
     return results;
   }

diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py
index 2a822058bfdf..ba29ab933da9 100755
--- a/caffe2/contrib/aten/gen_op.py
+++ b/caffe2/contrib/aten/gen_op.py
@@ -68,7 +68,7 @@ def value_has_tensors(v):

 def value_is_tensor_type(v):
-    return value_has_tensors(v) and v['dynamic_type'] != 'TensorList'
+    return value_has_tensors(v) and v['dynamic_type'] not in ['TensorList', 'const c10::List<c10::optional<at::Tensor>> &']

 # for each aten type, how do we handle a return value of that type?

@@ -208,7 +208,7 @@ def self_as_first_argument(arguments):
 def get_num_inputs(o):
     args = 0
     for a in o['arguments']:
-        if a['type'] == 'TensorList':
+        if a['type'] in ['TensorList', 'const c10::List<c10::optional<at::Tensor>> &']:
             return '*'
         elif value_has_tensors(a):
             args += 1

@@ -277,10 +277,10 @@ def emit_assignments(o, env):
     # e.g. "Float" is at::kFloat
     assert('Type' in o['method_of'])

-    static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments'])
-    has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments'])
+    static_tensor_inputs = sum(arg['type'] not in ['TensorList', 'const c10::List<c10::optional<at::Tensor>> &'] and value_is_tensor_type(arg) for arg in o['arguments'])
+    has_tensorlist = any(arg['type'] in ['TensorList', 'const c10::List<c10::optional<at::Tensor>> &'] for arg in o['arguments'])
     if has_tensorlist:
-        tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] == 'TensorList'][0]
+        tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['TensorList', 'const c10::List<c10::optional<at::Tensor>> &']][0]

     real_inputs = 0
     for i, arg in enumerate(o['arguments']):
@@ -290,10 +290,16 @@ def emit_assignments(o, env):
         view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
         if arg['type'] == 'TensorList':
             # NOTE: do not advance real_inputs here. After this we will
-            # switch to indexing the "stack" from the end as if we only had
+            # switch to indexing the "stack" from the end
             env['statements'].append(
                 'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
                 .format(arg['name'], real_inputs, static_tensor_inputs))
+        elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
+            # NOTE: do not advance real_inputs here. After this we will
+            # switch to indexing the "stack" from the end
+            env['statements'].append(
+                'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
+                .format(arg['name'], real_inputs, static_tensor_inputs))
         elif value_is_tensor_type(arg):
             # load tensor inputs from Caffe2
             env['statements'].append(

diff --git a/test/cpp/api/tensor_indexing.cpp b/test/cpp/api/tensor_indexing.cpp
index efb153fbf481..03600c5c882e 100644
--- a/test/cpp/api/tensor_indexing.cpp
+++ b/test/cpp/api/tensor_indexing.cpp
@@ -83,27 +83,27 @@ TEST(TensorIndexingTest, TestNoIndices) {
   ASSERT_THROWS_WITH(tensor.index_put_(indices, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
 }

-TEST(TensorIndexingTest, TestAdvancedIndexingWithArrayRefOfTensor) {
+TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) {
   {
     torch::Tensor tensor = torch::randn({20, 20});
     torch::Tensor index = torch::arange(10, torch::kLong).cpu();
-    torch::Tensor result_with_array_ref = tensor.index(at::ArrayRef<torch::Tensor>({index}));
+    torch::Tensor result = at::index(tensor, {index});
     torch::Tensor result_with_init_list = tensor.index({index});
-    ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
+    ASSERT_TRUE(result.equal(result_with_init_list));
   }
   {
     torch::Tensor tensor = torch::randn({20, 20});
     torch::Tensor index = torch::arange(10, torch::kLong).cpu();
-    torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef<torch::Tensor>({index}), torch::ones({20}));
+    torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20}));
     torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({20}));
-    ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
+    ASSERT_TRUE(result.equal(result_with_init_list));
   }
   {
     torch::Tensor tensor = torch::randn({20, 20});
     torch::Tensor index = torch::arange(10, torch::kLong).cpu();
-    torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef<torch::Tensor>({index}), torch::ones({1, 20}));
+    torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({1, 20}));
     torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({1, 20}));
-    ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
+    ASSERT_TRUE(result.equal(result_with_init_list));
   }
 }

@@ -173,7 +173,7 @@ TEST(TensorIndexingTest, TestBoolIndices) {
 TEST(TensorIndexingTest, TestBoolIndicesAccumulate) {
   auto mask = torch::zeros({10}, torch::kBool);
   auto y = torch::ones({10, 10});
-  y.index_put_({mask}, y.index({mask}), /*accumulate=*/true);
+  y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true);
   assert_tensor_equal(y, torch::ones({10, 10}));
 }

diff --git a/test/test_overrides.py b/test/test_overrides.py
index 95f94504d84e..f32b04cb2e53 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -563,6 +563,8 @@ def instance_gen():
                 func_args.append(instance_gen())
             elif t == 'TensorList':
                 func_args.append([instance_gen(), instance_gen()])
+            elif t == 'c10::List<c10::optional<Tensor>>':
+                func_args.append([instance_gen(), instance_gen()])
             elif t == 'IntArrayRef':
                 size = arg.get('size', 2)
                 if size == 1:

diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index a22154b5c01d..4724b99a8742 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -141,7 +141,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:

     compute_index_ranges: List[str] = []
     for arg in info.args_with_derivatives:
-        if arg.type == 'TensorList':
+        if arg.type == 'TensorList' or arg.type == 'const c10::List<c10::optional<Tensor>> &':
             size = f'{arg.name}_size_'
             saved_list_sizes.append(f'size_t {arg.name}_size_;')
         else:

@@ -166,6 +166,15 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
             release_variables.append(f'{name}_released_ = true;')
             unpack.append(f'auto {name} = unpack_list({name}_);')
             asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
+        elif var.type == 'c10::List<c10::optional<Tensor>>':
+            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
+            saved_variables.append(f'bool {name}_released_ = false;')
+            # Just clear() is sufficient, we don't need to loop and clear each variable.
+            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
+            release_variables.append(f'{name}_.clear();')
+            release_variables.append(f'{name}_released_ = true;')
+            unpack.append(f'auto {name} = unpack_opt_list({name}_);')
+            asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
         elif var.type == 'IntArrayRef':
             saved_variables.append(f'std::vector<int64_t> {name};')
         elif var.type == 'c10::optional':

diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py
index b2dfe2667128..78c460843d94 100644
--- a/tools/autograd/gen_trace_type.py
+++ b/tools/autograd/gen_trace_type.py
@@ -112,9 +112,8 @@ def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequence[str]:
         ]
     else:
         name = arg.name
-        # XXX: For arg that have type of Tensor?[], tracer will pass allow_undefined to addInputs
         if str(arg.type) == 'Tensor?[]':
-            return [f'jit::tracer::addInputs(node, "{name}", {name}, true);']
+            return [f'jit::tracer::addInputs(node, "{name}", {name});']
         else:
             return [ADD_TRACE_INPUT.substitute(name=name, input=name)]

diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 1d75ae46e9c9..f97fb55ab012 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -118,6 +118,21 @@
 }
 """)

+SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
+std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
+for (const c10::optional<Tensor>& tensor : ${tensorlist_name})
+  ${tensorlist_name}_storage_saved.push_back(
+    tensor.has_value() && tensor->has_storage() ? c10::optional<Storage>(tensor->storage()) : c10::nullopt);
+""")
+
+ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  if (${tensorlist_name}_storage_saved[i].has_value())
+    AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
+        static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->storage()));
+}
+""")
+
 SAVE_TENSOR_IMPL = CodeTemplate("""\
 c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
 if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
@@ -140,6 +155,21 @@
 }
 """)

+SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
+std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  c10::optional<Tensor> t = ${tensorlist_name}[i];
+  if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
+}
+""")
+
+ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  if (${tensorlist_name}_impl_saved[i])
+    AT_ASSERT(${tensorlist_name}_impl_saved[i] == static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
+}
+""")
+
 # The following list contains functions that we don't enforce the invariant on.
 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
     # These functions are expected to change impl or storage of input tensors

@@ -466,7 +496,8 @@ def emit_save_inputs():
     if func is None:
         return setup

-    has_tensorlist_arg = any(arg.type == 'TensorList' for arg in func.args_with_derivatives)
+    has_tensorlist_arg = \
+        any(arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in func.args_with_derivatives)

     # We don't want to save tensors if we know that they will never be used
     # when computing the derivative, so we add guards to those statements

@@ -515,7 +546,7 @@ def guard_for(arg: SavedAttribute) -> Optional[str]:
     setup.extend(save_variables(func.all_saved_inputs, False, guard_for))
     for arg in func.args_with_derivatives:
-        if arg.type == 'TensorList':
+        if arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
             setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')

     return setup

@@ -554,7 +585,7 @@ def emit_check_if_in_complex_autograd_allowlist():
         return body
     for arg in differentiable_outputs:
         name = arg['name']
-        if arg['type'] == 'Tensor' or arg['type'] == 'TensorList':
+        if arg['type'] in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
            body.append('throw_error_for_complex_autograd({}, "{}");'.format(name, base_name))
     return body

@@ -599,7 +630,7 @@ def save_variables(
                 expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
             else:
                 expr = f'SavedVariable({var}, {str(is_output).lower()})'
-        elif arg.type == 'TensorList':
+        elif arg.type in ['TensorList', 'c10::List<c10::optional<Tensor>>']:
            name += '_'
            expr = f'make_saved_variable_list({arg.name})'
        elif arg.type == 'IntArrayRef':

@@ -699,7 +730,7 @@ def wrap_output(return_values, var):
     # Only allow rebasing of the history if we return a single Tensor
     # If we are in a no grad block, raise a warning
     # See NOTE [ View + Inplace detection ] for more details about this logic
-    if return_info['dynamic_type'] == 'TensorList':
+    if return_info['dynamic_type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
         if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS:
             creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
         else:

@@ -736,6 +767,11 @@ def enforce_same_tensorimpl_and_storage(env, call):
                                 SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
             enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                         ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
+        elif simple_type == 'c10::List<c10::optional<Tensor>>':
+            save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                                SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
+            enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                                        ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
         elif simple_type == 'Tensor':
             save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                                 SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]

@@ -836,7 +872,7 @@ def emit_increment_version():
 def unpack_args(env, declaration):
     def requires_unpack(arg):
-        return 'Tensor' in arg['dynamic_type']
+        return 'Tensor' in arg['dynamic_type'] and 'c10::optional' not in arg['type']

     body = []
     unpacked_args = []

@@ -855,9 +891,8 @@ def requires_unpack(arg):
         dynamic_type = arg['dynamic_type']
         if 'TensorOptions' not in dynamic_type:
             is_nullable = arg.get('is_nullable', False)
-            ref = (not is_nullable) and dynamic_type not in ['TensorList']
+            ref = (not is_nullable) and dynamic_type != 'TensorList'
             suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''
-
             body.append(UNPACK_TENSOR.substitute(
                 arg_name=arg['name'],
                 arg_pos=i,

diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h
index 03240e2a5a2b..0540bb65b33b 100644
--- a/tools/autograd/templates/Functions.h
+++ b/tools/autograd/templates/Functions.h
@@ -32,6 +32,15 @@ inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs) {
   });
 }

+inline c10::List<c10::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs) {
+  torch::List<c10::optional<Tensor>> result;
+  result.reserve(xs.size());
+  for (const SavedVariable& v : xs) {
+    result.push_back(v.unpack());
+  }
+  return result;
+}
+
 struct TypeAndSize {
   TypeAndSize() : options(at::TensorOptions()) {}
   /* implicit */

diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h
index 9062a4d08e34..fc8ffa5799c1 100644
--- a/tools/autograd/templates/VariableType.h
+++ b/tools/autograd/templates/VariableType.h
@@ -49,7 +49,6 @@ namespace VariableType {
   at::Tensor & unpack(Tensor & t, const char * name, int pos);
   const at::Tensor & unpack(const Tensor & t, const char * name, int pos);
   at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
-  c10::optional<at::Tensor> unpack_opt(const c10::optional<at::Tensor> & t, const char * name, int pos);
   std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
 };

diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py
index ffd9626601a0..c27a6768300a 100644
--- a/tools/codegen/api/cpp.py
+++ b/tools/codegen/api/cpp.py
@@ -104,9 +104,11 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
             return BaseCType("TensorList", binds)
         elif str(t.elem) == 'Dimname':
             return BaseCType("DimnameList", binds)
-        # TODO: do something reasonable about lists of optional tensors
-        elif (not local.use_c10_dispatcher().dispatcher_uses_new_style()) and str(t.elem) == 'Tensor?':
-            return BaseCType("TensorList", binds)
+        elif str(t.elem) == 'Tensor?':
+            if local.use_c10_dispatcher().dispatcher_uses_new_style():
+                return BaseCType("const c10::List<c10::optional<Tensor>> &", binds)
+            else:
+                return BaseCType("TensorList", binds)
         elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
         # TODO: explicitly qualify namespace here
         return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)

diff --git a/tools/codegen/api/native.py b/tools/codegen/api/native.py
index 3b793527edd9..9781c46884e7 100644
--- a/tools/codegen/api/native.py
+++ b/tools/codegen/api/native.py
@@ -34,7 +34,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
         else:
             return ConstRefCType(BaseCType('Tensor', binds))
     elif str(t) == 'Tensor?[]':
-        return BaseCType('TensorList', binds)
+        return BaseCType('const c10::List<c10::optional<Tensor>> &', binds)
     return cpp.argumenttype_type(t, mutable=mutable, binds=binds)

 def returns_type(rs: Sequence[Return]) -> str:

diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py
index 059032869675..bdb31d4d8616 100644
--- a/tools/codegen/api/python.py
+++ b/tools/codegen/api/python.py
@@ -228,7 +228,7 @@ class PythonArgument:

     # Compute argument formal for python argument parsing.
     # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
     def argument_str(self, *, method: bool = False) -> str:
-        type_str = argument_type_str(self.type)
+        type_str = argument_type_str(self.type).replace('const ', '').replace(' &', '')

         name = self.name
         # s/self/input/ outside method bindings

@@ -624,10 +624,9 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
             return f'ScalarList[{size}]' if size is not None else 'ScalarList'
         elif str(t.elem) == 'Tensor?':
             if simple_type:
-                return 'TensorList'
+                return 'c10::List<c10::optional<Tensor>>'
             else:
-                # TODO: clone the old codegen behavior but does it make sense?
-                return 'TensorList?'
+                return 'const c10::List<c10::optional<Tensor>> &'
         elif str(t.elem) == 'Dimname':
             return f'DimnameList[{size}]' if size is not None else 'DimnameList'
         elem = argument_type_str(t.elem, simple_type=simple_type)

@@ -1051,12 +1050,14 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
             return 'toDimnameListOptional'

     elif isinstance(t, ListType):
-        if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?':
+        if str(t.elem) == 'Tensor':
             # accept and use definite size
             if t.size is not None:
                 return f'tensorlist_n<{t.size}>'
             else:
                 return 'tensorlist'
+        elif str(t.elem) == 'Tensor?':
+            return 'list_of_optional_tensors'
         elif str(t.elem) == 'Dimname':
             # accept definite size
             return 'dimnamelist'

diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 3c84a0da4a99..6558295d58cb 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include

@@ -2211,15 +2212,17 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
       return nonsingular_case_backward(grad, self, det);
     }
   } else {
-    auto nonzero_det_indices = at::where(det);
+    auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det));
+    c10::optional<Tensor> first_nonzero_det_index = nonzero_det_indices[0];

-    if (nonzero_det_indices[0].size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
+    if (first_nonzero_det_index->size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
       return nonsingular_case_backward(grad, self, det);
     }

-    auto zero_det_indices = at::where(det == 0);
+    auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0));
+    c10::optional<Tensor> first_zero_det_index = zero_det_indices[0];

-    if (zero_det_indices[0].size(0) == det.numel()) {  // all determinants are zero (singular)
+    if (first_zero_det_index->size(0) == det.numel()) {  // all determinants are zero (singular)
       return singular_case_backward(grad, self, det);
     }

@@ -2261,15 +2264,17 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
       return singular_case_backward(grad, self);
     }
   } else {
-    auto finite_logdet_indices = at::where(logdet != -INFINITY);
+    auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY));
+    c10::optional<Tensor> first_finite_logdet_index = finite_logdet_indices[0];

-    if (finite_logdet_indices[0].size(0) == logdet.numel()) {  // all log determinants are finite (non-singular)
+    if (first_finite_logdet_index->size(0) == logdet.numel()) {  // all log determinants are finite (non-singular)
       return nonsingular_case_backward(grad, self);
     }

-    auto neginf_logdet_indices = at::where(logdet == -INFINITY);
+    auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY));
+    c10::optional<Tensor> first_neginf_logdet_index = neginf_logdet_indices[0];

-    if (neginf_logdet_indices[0].size(0) == logdet.numel()) {  // all log determinants are -inf (singular)
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 3c84a0da4a99..6558295d58cb 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -2211,15 +2212,17 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
       return nonsingular_case_backward(grad, self, det);
     }
   } else {
-    auto nonzero_det_indices = at::where(det);
+    auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det));
+    c10::optional<Tensor> first_nonzero_det_index = nonzero_det_indices[0];
 
-    if (nonzero_det_indices[0].size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
+    if (first_nonzero_det_index->size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
       return nonsingular_case_backward(grad, self, det);
     }
 
-    auto zero_det_indices = at::where(det == 0);
+    auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0));
+    c10::optional<Tensor> first_zero_det_index = zero_det_indices[0];
 
-    if (zero_det_indices[0].size(0) == det.numel()) {  // all determinants are zero (singular)
+    if (first_zero_det_index->size(0) == det.numel()) {  // all determinants are zero (singular)
       return singular_case_backward(grad, self, det);
     }
 
@@ -2261,15 +2264,17 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
       return singular_case_backward(grad, self);
     }
   } else {
-    auto finite_logdet_indices = at::where(logdet != -INFINITY);
+    auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY));
+    c10::optional<Tensor> first_finite_logdet_index = finite_logdet_indices[0];
 
-    if (finite_logdet_indices[0].size(0) == logdet.numel()) {  // all log determinants are finite (non-singular)
+    if (first_finite_logdet_index->size(0) == logdet.numel()) {  // all log determinants are finite (non-singular)
       return nonsingular_case_backward(grad, self);
     }
 
-    auto neginf_logdet_indices = at::where(logdet == -INFINITY);
+    auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY));
+    c10::optional<Tensor> first_neginf_logdet_index = neginf_logdet_indices[0];
 
-    if (neginf_logdet_indices[0].size(0) == logdet.numel()) {  // all log determinants are -inf (singular)
+    if (first_neginf_logdet_index->size(0) == logdet.numel()) {  // all log determinants are -inf (singular)
       return singular_case_backward(grad, self);
     }
 
@@ -2313,15 +2318,17 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
       return nonsingular_case_backward(grad_logabsdet, self);
     }
   } else {
-    auto nonzero_signdet_indices = at::where(signdet);
+    auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet));
+    c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0];
 
-    if (nonzero_signdet_indices[0].size(0) == logabsdet.numel()) {  // all log determinants are finite (non-singular)
+    if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are finite (non-singular)
       return nonsingular_case_backward(grad_logabsdet, self);
     }
 
-    auto zero_signdet_indices = at::where(signdet == 0);
+    auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0));
+    c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0];
 
-    if (zero_signdet_indices[0].size(0) == logabsdet.numel()) {  // all log determinants are -inf (singular)
+    if (first_zero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are -inf (singular)
       return singular_case_backward(grad_logabsdet, self);
     }
 
@@ -2873,8 +2880,8 @@ Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
   return gg_weight.view(size);
 }
 
-Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) {
-  return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
+Tensor index_backward(Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const Tensor& grad) {
+  return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
 }
 
 Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 3814e8078b23..30736e13f58a 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -124,7 +124,7 @@ at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor&
 at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
 at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape);
 at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
-at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
+at::Tensor index_backward(at::Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const at::Tensor& grad);
 at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
 
 Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
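(The autograd helpers changed below iterate the new list type. `c10::List` boxes its elements, so iteration yields `c10::optional<Tensor>` by value rather than by reference — a minimal self-contained sketch, with a made-up function name:)

```cpp
#include <ATen/ATen.h>

// Count defined entries, mirroring how the Tensor?[] overloads below walk the list.
int64_t count_defined(const c10::List<c10::optional<at::Tensor>>& tensors) {
  int64_t n = 0;
  for (c10::optional<at::Tensor> t : tensors) {  // copies the element out of the list
    if (t.has_value() && t->defined()) {
      ++n;
    }
  }
  return n;
}
```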
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index 0663d7f46fa8..d1f15fff3669 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -66,10 +66,6 @@ Tensor unpack_opt(const Tensor & t, const char * name, int pos) {
   return unpack(t, name, pos);
 }
 
-c10::optional<Tensor> unpack_opt(const c10::optional<Tensor> & t, const char * name, int pos) {
-  return t;
-}
-
 std::vector<Tensor> unpack(at::TensorList tl, const char *name, int pos) {
   std::vector<Tensor> ret(tl.size());
   for (size_t i = 0; i < tl.size(); ++i) {
@@ -94,7 +90,7 @@ void _backward(
   // instead of us having to unwrap it to Tensor _gradient here.
   Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
   std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
-  torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph, input_vars);
+  torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
 }
 
 void set_data(Tensor & self, const Tensor & new_data) {
@@ -230,7 +226,6 @@ Tensor _fw_primal(const Tensor & self, int64_t level) {
 
 // We don't have an outplace copy, so this can't be generated automatically
 Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
-  jit::Value* output = nullptr;
   // TODO: once copy is exposed in Declarations.yaml we may be able to bind
   // it automatically
   auto& self_ = unpack(self, "self", 0);
@@ -282,7 +277,7 @@ Tensor& resize_(
   }
   {
     at::AutoNonVariableTypeMode non_var_type_mode(true);
-    self_.resize_(size, std::move(optional_memory_format));
+    self_.resize_(size, optional_memory_format);
   }
 
   if (self.fw_grad(/* level */ 0).defined()) {
@@ -303,7 +298,7 @@ Tensor& resize_as_(
   }
   {
     at::AutoNonVariableTypeMode non_var_type_mode(true);
-    at::resize_as_(self_, the_template_, std::move(optional_memory_format));
+    at::resize_as_(self_, the_template_, optional_memory_format);
   }
 
   // Handle fw grad
diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h
index af02de68fc27..509a12e01140 100644
--- a/torch/csrc/autograd/VariableTypeUtils.h
+++ b/torch/csrc/autograd/VariableTypeUtils.h
@@ -266,12 +266,31 @@ inline void check_no_requires_grad(TensorList tensors, const char* name) {
   }
 }
 
+inline void check_no_requires_grad(const c10::List<c10::optional<Tensor>>& tensors, const char* name) {
+  for (c10::optional<Tensor> tensor : tensors) {
+    if (tensor.has_value()) {
+      check_no_requires_grad(*tensor, name);
+    }
+  }
+}
+
 // Assumed that saved tensor lists are never inplace outputs
 inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
   return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
       return SavedVariable{tensor, false /* is output */}; });
 }
 
+// Assumed that saved tensor lists are never inplace outputs
+inline std::vector<SavedVariable> make_saved_variable_list(const c10::List<c10::optional<Tensor>>& tensors) {
+  return fmap(tensors, [](const c10::optional<Tensor>& tensor) -> SavedVariable {
+    if (tensor.has_value()) {
+      return SavedVariable{*tensor, false /* is output */};
+    } else {
+      return SavedVariable{Tensor(), false /* is output */};
+    }
+  });
+}
+
 inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
   std::vector<std::vector<int64_t>> args_sizes(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
diff --git a/torch/csrc/jit/backends/backend_detail.h b/torch/csrc/jit/backends/backend_detail.h
index 2d19f2ed8950..00f0f2f9eb44 100644
--- a/torch/csrc/jit/backends/backend_detail.h
+++ b/torch/csrc/jit/backends/backend_detail.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include
 #include
 
 namespace torch {
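(Together with `unpack_opt_list` in the generated `Functions.h` above, the `make_saved_variable_list` overload gives `Tensor?[]` arguments a save/unpack round trip: `c10::nullopt` is saved as an undefined `Tensor` and comes back as an optional holding an undefined tensor. A standalone sketch under those assumptions, with hypothetical helper names:)

```cpp
#include <ATen/ATen.h>
#include <vector>

// Save step: nullopt becomes an undefined Tensor (as in make_saved_variable_list).
std::vector<at::Tensor> save_list(const c10::List<c10::optional<at::Tensor>>& xs) {
  std::vector<at::Tensor> saved;
  for (c10::optional<at::Tensor> x : xs) {
    saved.push_back(x.has_value() ? *x : at::Tensor());
  }
  return saved;
}

// Unpack step: every entry comes back as an optional, possibly holding an
// undefined Tensor (as in unpack_opt_list).
c10::List<c10::optional<at::Tensor>> unpack_saved(const std::vector<at::Tensor>& xs) {
  c10::List<c10::optional<at::Tensor>> out;
  out.reserve(xs.size());
  for (const at::Tensor& x : xs) {
    out.push_back(x);  // Tensor converts implicitly to c10::optional<Tensor>
  }
  return out;
}
```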
diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp
index 72ccd77f2220..1bab391bd393 100644
--- a/torch/csrc/jit/frontend/tracer.cpp
+++ b/torch/csrc/jit/frontend/tracer.cpp
@@ -103,6 +103,9 @@ void TracingState::delValue(const IValue& var) {
 Value* getValueTrace(const IValue& var) {
   return getTracingState()->getValue(var);
 }
+Value* getOptTensorValueTrace(const c10::optional<at::Tensor>& var) {
+  return getValueTrace(IValue(var));
+}
 
 Value* TracingState::getValue(const IValue& var) {
   // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
   // arguments
@@ -686,6 +689,16 @@ void addInputs(
   }
   n->addInput(list_node->output());
 }
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const List<c10::optional<Tensor>>& value) {
+  Graph* g = n->owningGraph();
+  Node* list_node = nullptr;
+  list_node = g->insertNode(g->createList(
+      OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
+  n->addInput(list_node->output());
+}
 
 void addInputs(
     Node* n,
diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h
index 61d79cb3efd2..f5cbd821bda4 100644
--- a/torch/csrc/jit/frontend/tracer.h
+++ b/torch/csrc/jit/frontend/tracer.h
@@ -255,6 +255,10 @@ TORCH_API void addInputs(
     const char* name,
     ArrayRef<at::Tensor> value,
     bool allow_undefined = false);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const List<c10::optional<Tensor>>& value);
 TORCH_API void addInputs(
     Node* n,
     const char* name,
diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h
index 8b7da739df9a..2be75c61b6b5 100644
--- a/torch/csrc/jit/mobile/module.h
+++ b/torch/csrc/jit/mobile/module.h
@@ -1,5 +1,6 @@
 #pragma once
 //#include
+#include
 #include
 #include
diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h
index 120a3ffb7507..a4bb209cd17e 100644
--- a/torch/csrc/jit/runtime/interpreter.h
+++ b/torch/csrc/jit/runtime/interpreter.h
@@ -5,6 +5,7 @@
 #include
 #include
+#include
 #include
 #include
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index f23b09dc0e74..fe75ec52046e 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -908,7 +908,7 @@ RegisterOperators reg(
         TORCH_SELECTIVE_SCHEMA(
             "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
         [](Stack* stack) {
-          auto indices = pop(stack).toTensorVector();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::index(self, indices);
           push(stack, std::move(result));
@@ -921,7 +921,7 @@
           auto unsafe = pop(stack).toBool();
           auto accumulate = pop(stack).toBool();
           auto values = pop(stack).toTensor();
-          auto indices = pop(stack).toTensorVector();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::_index_put_impl_(self, indices, values, accumulate, unsafe);
@@ -934,7 +934,7 @@
         [](Stack* stack) {
           auto accumulate = pop(stack).toBool();
           auto values = pop(stack).toTensor();
-          auto indices = pop(stack).toTensorVector();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::index_put_(self, indices, values, accumulate);
           push(stack, std::move(result));
@@ -946,7 +946,7 @@
         [](Stack* stack) {
           auto accumulate = pop(stack).toBool();
           auto values = pop(stack).toTensor();
-          auto indices = pop(stack).toTensorVector();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::index_put_(self, indices, values, accumulate);
           push(stack, std::move(result));
diff --git a/torch/csrc/jit/runtime/vararg_functions.h b/torch/csrc/jit/runtime/vararg_functions.h
index 36bef721d626..d6eba7f5d191 100644
--- a/torch/csrc/jit/runtime/vararg_functions.h
+++ b/torch/csrc/jit/runtime/vararg_functions.h
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include
 #include
 
 namespace torch {
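(The interpreter ops above switch from `toTensorVector()` to `IValue`'s generic list conversion; a minimal sketch of that round trip, with an illustrative function name:)

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

void ivalue_round_trip() {
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(c10::optional<at::Tensor>(at::arange(2)));
  indices.push_back(c10::nullopt);
  c10::IValue boxed(indices);  // boxed, as it would sit on the JIT stack
  auto unboxed = boxed.to<c10::List<c10::optional<at::Tensor>>>();
  TORCH_CHECK(unboxed.size() == 2 && !unboxed.get(1).has_value());
}
```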
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index c7fdf844945e..ee3a0bc71f2f 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -24,6 +24,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"double", ParameterType::DOUBLE},
   {"complex", ParameterType::COMPLEX},
   {"TensorList", ParameterType::TENSOR_LIST},
+  {"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
   {"IntArrayRef", ParameterType::INT_LIST},
   {"ArrayRef<double>", ParameterType::FLOAT_LIST},
   {"Generator", ParameterType::GENERATOR},
@@ -390,7 +391,7 @@ bool is_float_or_complex_list(PyObject* obj) {
   }
 
   auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
-  if (size > 0) {
+  if (size > 0) {
     PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
     if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
       return false;
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index ccf3ba6b42c4..9fa490139cbd 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -160,6 +160,7 @@ struct PythonArgs {
   inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
   inline std::vector<at::Scalar> scalarlist(int i);
   inline std::vector<at::Tensor> tensorlist(int i);
+  inline torch::List<c10::optional<at::Tensor>> list_of_optional_tensors(int i);
   template<int N>
   inline std::array<at::Tensor, N> tensorlist_n(int i);
   inline std::vector<int64_t> intlist(int i);
@@ -327,6 +328,22 @@ inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
   return res;
 }
 
+inline torch::List<c10::optional<at::Tensor>> PythonArgs::list_of_optional_tensors(int i) {
+  if (!args[i]) return torch::List<c10::optional<at::Tensor>>();
+  auto tuple = six::isTuple(args[i]);
+  THPObjectPtr arg = six::maybeAsTuple(args[i]);
+  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
+  torch::List<c10::optional<at::Tensor>> res;
+  res.reserve(size);
+  for (int idx = 0; idx < size; idx++) {
+    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
+    // This is checked by the argument parser so it's safe to cast without checking
+    // if this is a tensor first
+    res.push_back(reinterpret_cast<THPVariable*>(obj)->cdata);
+  }
+  return res;
+}
+
 template<int N>
 inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
   auto res = std::array<at::Tensor, N>();