From ee09a7c52577c01483d2f194b2b2832500f6e9a2 Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Fri, 17 Dec 2021 10:21:38 -0800 Subject: [PATCH] [jit] Polymorphic IValue::type() for DynamicType. Before the change: ``` c10::Type t = ivalue.type(); ``` After the change: ``` c10::Type t = ivalue.type(); c10::DynamicType d = ivalue.type(); // new path ``` The new path will be adopted in PyTorch Lite Interpreter to support lightweight type reflection. Note that type getters are selected at compile time so no performance overhead. The benefits of having a DynamicType will be elaborated in a separate document, but in short, DynamicType provides an isolated type system for controlling binary size bloat, and shrink down ~20 supported Type symbols into one so that the size taken by specializations and function name symbols are greatly reduced. Lite Interpreter should only use the `` variant of the interfaces from aten, to reduce binary size. Differential Revision: [D33102276](https://our.internmc.facebook.com/intern/diff/D33102276/) [ghstack-poisoned] --- aten/src/ATen/core/dynamic_type.cpp | 37 +++++++++ aten/src/ATen/core/dynamic_type.h | 8 ++ aten/src/ATen/core/ivalue.cpp | 108 ++++++++++++-------------- aten/src/ATen/core/ivalue.h | 6 +- aten/src/ATen/core/ivalue_inl.h | 25 +++++- aten/src/ATen/core/jit_type.h | 1 - aten/src/ATen/core/jit_type_base.h | 1 + aten/src/ATen/core/type.cpp | 1 + torch/csrc/jit/mobile/interpreter.cpp | 3 +- torch/csrc/jit/mobile/type_parser.h | 1 + 10 files changed, 127 insertions(+), 64 deletions(-) diff --git a/aten/src/ATen/core/dynamic_type.cpp b/aten/src/ATen/core/dynamic_type.cpp index 8d949415753ae..2476b7c903f88 100644 --- a/aten/src/ATen/core/dynamic_type.cpp +++ b/aten/src/ATen/core/dynamic_type.cpp @@ -181,4 +181,41 @@ bool DynamicType::LabeledDynamicType::equals( return (label == other.label) && (*ty == *other.ty); } +DynamicType::Ptr IValue::TagType::get(const c10::IValue& v) { + switch (v.tag) { + case Tag::None: + return NoneType::get(); + case Tag::Tensor: + return TensorType::get(); + case Tag::Double: + return FloatType::get(); + case Tag::ComplexDouble: + return ComplexType::get(); + case Tag::Int: + return IntType::get(); + case Tag::Bool: + return BoolType::get(); + case Tag::String: + return StringType::get(); + case Tag::GenericDict: { + auto d = v.toGenericDict(); + return DictType::create(d.keyType(), d.valueType()); + } + case Tag::GenericList: + return ListType::create(v.toList().elementType()); + case Tag::Device: + return DeviceObjType::get(); + case Tag::Stream: + return StreamObjType::get(); + case Tag::Object: + return v.toObjectRef().type(); + case Tag::Capsule: + return CapsuleType::get(); + case Tag::Tuple: + return v.toTupleRef().type(); + default: + return AnyType::get(); + } +} + } // namespace c10 diff --git a/aten/src/ATen/core/dynamic_type.h b/aten/src/ATen/core/dynamic_type.h index 86dcbd9ccf014..a7c0bb523e2ed 100644 --- a/aten/src/ATen/core/dynamic_type.h +++ b/aten/src/ATen/core/dynamic_type.h @@ -2,6 +2,7 @@ #include +#include #include #include #include @@ -104,6 +105,8 @@ class DynamicType : public Type { }; public: + // TODO Change Ptr to DynamicTypePtr when all migrations are done. + using Ptr = TypePtr; ~DynamicType() override; struct Arguments { @@ -153,4 +156,9 @@ class DynamicType : public Type { }; }; +template <> +struct IValue::TagType { + static DynamicType::Ptr get(const c10::IValue& v); +}; + } // namespace c10 diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 7b7aa4ebf80d9..1284c738b1773 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -61,14 +61,6 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) { _fastEqualsForContainer); } -TupleTypePtr Tuple::type() const { - if (!type_) { - type_ = TupleType::create( - fmap(elements(), [&](const IValue& v) { return v.type(); })); - } - return type_; -} - bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) { return lhs.name() == rhs.name() && *rhs.type() == *lhs.type(); } @@ -83,56 +75,56 @@ const std::string ivalue::EnumHolder::unqualifiedClassName() const { } // namespace ivalue -TypePtr IValue::type() const { - switch (tag) { - case Tag::None: - return NoneType::get(); - case Tag::Tensor: - return TensorType::create(toTensor()); - case Tag::Storage: - return StorageType::get(); - case Tag::Double: - return FloatType::get(); - case Tag::ComplexDouble: - return ComplexType::get(); - case Tag::Int: - return IntType::get(); - case Tag::Bool: - return BoolType::get(); - case Tag::String: - return StringType::get(); - case Tag::Blob: - return AnyType::get(); - case Tag::GenericDict: { - auto d = toGenericDict(); - return DictType::create(d.keyType(), d.valueType()); - } - case Tag::GenericList: - return ListType::create(toList().elementType()); - case Tag::Future: - return FutureType::create(toFuture()->elementType()); - case Tag::RRef: - return RRefType::create(toRRef()->type()); - case Tag::Device: - return DeviceObjType::get(); - case Tag::Stream: - return StreamObjType::get(); - case Tag::Object: - return toObjectRef().type(); - case Tag::PyObject: - return PyObjectType::get(); - case Tag::Uninitialized: - return AnyType::get(); - case Tag::Capsule: - return CapsuleType::get(); - case Tag::Tuple: - return toTupleRef().type(); - case Tag::Generator: - return GeneratorType::get(); - case Tag::Quantizer: - return QuantizerType::get(); - case Tag::Enum: - return toEnumHolder()->type(); +c10::TypePtr IValue::TagType::get(const IValue& v) { + switch (v.tag) { + case Tag::None: + return NoneType::get(); + case Tag::Tensor: + return TensorType::create(v.toTensor()); + case Tag::Storage: + return StorageType::get(); + case Tag::Double: + return FloatType::get(); + case Tag::ComplexDouble: + return ComplexType::get(); + case Tag::Int: + return IntType::get(); + case Tag::Bool: + return BoolType::get(); + case Tag::String: + return StringType::get(); + case Tag::Blob: + return AnyType::get(); + case Tag::GenericDict: { + auto d = v.toGenericDict(); + return DictType::create(d.keyType(), d.valueType()); + } + case Tag::GenericList: + return ListType::create(v.toList().elementType()); + case Tag::Future: + return FutureType::create(v.toFuture()->elementType()); + case Tag::RRef: + return RRefType::create(v.toRRef()->type()); + case Tag::Device: + return DeviceObjType::get(); + case Tag::Stream: + return StreamObjType::get(); + case Tag::Object: + return v.toObjectRef().type(); + case Tag::PyObject: + return PyObjectType::get(); + case Tag::Uninitialized: + return AnyType::get(); + case Tag::Capsule: + return CapsuleType::get(); + case Tag::Tuple: + return v.toTupleRef().type(); + case Tag::Generator: + return GeneratorType::get(); + case Tag::Quantizer: + return QuantizerType::get(); + case Tag::Enum: + return v.toEnumHolder()->type(); } // switch above is complete but this silences compiler warnings TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()"); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 32ae5755ed3f3..0645358ead4f6 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -878,7 +878,8 @@ struct TORCH_API IValue final { } } - TypePtr type() const; + template + typename T::Ptr type() const; // Detect aliased tensors. struct HashAliasedIValue { @@ -1032,6 +1033,9 @@ struct TORCH_API IValue final { } } + template + struct TagType {}; + friend MaybeOwnedTraits; Payload payload; diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index ee48693f9dfec..e1c5cfb42ad9d 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -12,7 +12,6 @@ #include #include #include -#include #include #include #include @@ -20,9 +19,11 @@ #include #include #include +#include +#include +#include #include #include -#include namespace torch { namespace jit { @@ -684,7 +685,15 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target { return elements_.size(); } - std::shared_ptr type() const; + template + std::shared_ptr type() const { + if (!type_) { + type_ = TupleType::create(fmap(elements(), [&](const IValue& v) { + return v.type(); + })); + } + return type_; + } static size_t hash(const Tuple& t) { return c10::get_hash(t.elements()); @@ -2233,4 +2242,14 @@ struct MaybeOwnedTraits { } }; +template <> +struct IValue::TagType { + static c10::TypePtr get(const IValue&); +}; + +template +typename T::Ptr IValue::type() const { + return IValue::TagType::get(*this); +} + } // namespace c10 diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 11883df54ea6b..3fa24da306cce 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index 457ead85efa83..87049bb0ffc65 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -79,6 +79,7 @@ struct TORCH_API Type : std::enable_shared_from_this { } public: + using Ptr = TypePtr; virtual bool operator==(const Type& rhs) const = 0; // subtyping relation. By default, we return true for the case diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index e81205ad853a0..fb2cf03b4ebc4 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 53660ea97b7f1..c4acbb0e30f48 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -32,7 +33,7 @@ void createObject(Stack& stack, const at::ClassTypePtr& type) { } void isinstance(Stack& stack, at::ArrayRef types) { - at::TypePtr ty = pop(stack).type(); + at::TypePtr ty = pop(stack).type(); for (const at::TypePtr& candidate : types) { if (ty->isSubtypeOf(*candidate)) { push(stack, true); diff --git a/torch/csrc/jit/mobile/type_parser.h b/torch/csrc/jit/mobile/type_parser.h index cf8e92602f27c..1a617918db111 100644 --- a/torch/csrc/jit/mobile/type_parser.h +++ b/torch/csrc/jit/mobile/type_parser.h @@ -1,3 +1,4 @@ +#include #include namespace c10 {