Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jit] Polymorphic IValue::type() for DynamicType. #70120

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
ee09a7c
[jit] Polymorphic IValue::type() for DynamicType.
zhxchen17 Dec 17, 2021
45e0d0c
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 20, 2021
ae138ab
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 20, 2021
dcbf069
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 20, 2021
e8b5fe4
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 21, 2021
07783fc
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 21, 2021
883ea7a
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 21, 2021
e7a8f03
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 21, 2021
9987809
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 23, 2021
6b32746
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 24, 2021
09e1f54
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 29, 2021
81ad873
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Dec 30, 2021
2c16266
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Jan 4, 2022
1d2b253
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Jan 5, 2022
6df124a
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Jan 7, 2022
00f1ad8
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Jan 7, 2022
9553998
Update on "[jit] Polymorphic IValue::type() for DynamicType."
zhxchen17 Jan 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 37 additions & 0 deletions aten/src/ATen/core/dynamic_type.cpp
Expand Up @@ -181,4 +181,41 @@ bool DynamicType::LabeledDynamicType::equals(
return (label == other.label) && (*ty == *other.ty);
}

DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
switch (v.tag) {
case Tag::None:
return NoneType::get();
case Tag::Tensor:
return TensorType::get();
case Tag::Double:
return FloatType::get();
case Tag::ComplexDouble:
return ComplexType::get();
case Tag::Int:
return IntType::get();
case Tag::Bool:
return BoolType::get();
case Tag::String:
return StringType::get();
case Tag::GenericDict: {
auto d = v.toGenericDict();
return DictType::create(d.keyType(), d.valueType());
}
case Tag::GenericList:
return ListType::create(v.toList().elementType());
case Tag::Device:
return DeviceObjType::get();
case Tag::Stream:
return StreamObjType::get();
case Tag::Object:
return v.toObjectRef().type();
case Tag::Capsule:
return CapsuleType::get();
case Tag::Tuple:
return v.toTupleRef().type<c10::DynamicType>();
default:
return AnyType::get();
}
}

} // namespace c10
8 changes: 8 additions & 0 deletions aten/src/ATen/core/dynamic_type.h
Expand Up @@ -2,6 +2,7 @@

#include <memory>

#include <ATen/core/ivalue.h>
#include <ATen/core/class_type.h>
#include <ATen/core/jit_type_base.h>
#include <c10/util/Optional.h>
Expand Down Expand Up @@ -104,6 +105,8 @@ class DynamicType : public SharedType {
};

public:
// TODO Change Ptr to DynamicTypePtr when all migrations are done.
using Ptr = TypePtr;
~DynamicType() override;

struct Arguments {
Expand Down Expand Up @@ -156,4 +159,9 @@ class DynamicType : public SharedType {
};
};

template <>
struct IValue::TagType<c10::DynamicType> {
static DynamicType::Ptr get(const c10::IValue& v);
};

} // namespace c10
108 changes: 50 additions & 58 deletions aten/src/ATen/core/ivalue.cpp
Expand Up @@ -62,14 +62,6 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
_fastEqualsForContainer);
}

TupleTypePtr Tuple::type() const {
if (!type_) {
type_ = TupleType::create(
fmap(elements(), [&](const IValue& v) { return v.type(); }));
}
return type_;
}

bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) {
return lhs.name() == rhs.name() && *rhs.type() == *lhs.type();
}
Expand All @@ -84,56 +76,56 @@ const std::string ivalue::EnumHolder::unqualifiedClassName() const {

} // namespace ivalue

TypePtr IValue::type() const {
switch (tag) {
case Tag::None:
return NoneType::get();
case Tag::Tensor:
return TensorType::create(toTensor());
case Tag::Storage:
return StorageType::get();
case Tag::Double:
return FloatType::get();
case Tag::ComplexDouble:
return ComplexType::get();
case Tag::Int:
return IntType::get();
case Tag::Bool:
return BoolType::get();
case Tag::String:
return StringType::get();
case Tag::Blob:
return AnyType::get();
case Tag::GenericDict: {
auto d = toGenericDict();
return DictType::create(d.keyType(), d.valueType());
}
case Tag::GenericList:
return ListType::create(toList().elementType());
case Tag::Future:
return FutureType::create(toFuture()->elementType());
case Tag::RRef:
return RRefType::create(toRRef()->type());
case Tag::Device:
return DeviceObjType::get();
case Tag::Stream:
return StreamObjType::get();
case Tag::Object:
return toObjectRef().type();
case Tag::PyObject:
return PyObjectType::get();
case Tag::Uninitialized:
return AnyType::get();
case Tag::Capsule:
return CapsuleType::get();
case Tag::Tuple:
return toTupleRef().type();
case Tag::Generator:
return GeneratorType::get();
case Tag::Quantizer:
return QuantizerType::get();
case Tag::Enum:
return toEnumHolder()->type();
c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
switch (v.tag) {
case Tag::None:
return NoneType::get();
case Tag::Tensor:
return TensorType::create(v.toTensor());
case Tag::Storage:
return StorageType::get();
case Tag::Double:
return FloatType::get();
case Tag::ComplexDouble:
return ComplexType::get();
case Tag::Int:
return IntType::get();
case Tag::Bool:
return BoolType::get();
case Tag::String:
return StringType::get();
case Tag::Blob:
return AnyType::get();
case Tag::GenericDict: {
auto d = v.toGenericDict();
return DictType::create(d.keyType(), d.valueType());
}
case Tag::GenericList:
return ListType::create(v.toList().elementType());
case Tag::Future:
return FutureType::create(v.toFuture()->elementType());
case Tag::RRef:
return RRefType::create(v.toRRef()->type());
case Tag::Device:
return DeviceObjType::get();
case Tag::Stream:
return StreamObjType::get();
case Tag::Object:
return v.toObjectRef().type();
case Tag::PyObject:
return PyObjectType::get();
case Tag::Uninitialized:
return AnyType::get();
case Tag::Capsule:
return CapsuleType::get();
case Tag::Tuple:
return v.toTupleRef().type();
case Tag::Generator:
return GeneratorType::get();
case Tag::Quantizer:
return QuantizerType::get();
case Tag::Enum:
return v.toEnumHolder()->type();
}
// switch above is complete but this silences compiler warnings
TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/core/ivalue.h
Expand Up @@ -889,7 +889,8 @@ struct TORCH_API IValue final {
}
}

TypePtr type() const;
template <typename T = c10::Type>
typename T::Ptr type() const;

// Detect aliased tensors.
struct HashAliasedIValue {
Expand Down Expand Up @@ -1048,6 +1049,9 @@ struct TORCH_API IValue final {
}
}

template <typename T>
struct TagType {};

friend MaybeOwnedTraits<IValue>;

Payload payload;
Expand Down
25 changes: 22 additions & 3 deletions aten/src/ATen/core/ivalue_inl.h
Expand Up @@ -12,17 +12,18 @@
#include <ATen/core/qualified_name.h>
#include <ATen/core/rref_interface.h>
#include <ATen/core/symbol.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/Scalar.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/hash.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10/util/hash.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -684,7 +685,15 @@ struct TORCH_API Tuple : c10::intrusive_ptr_target {
return elements_.size();
}

std::shared_ptr<TupleType> type() const;
template <typename T = c10::Type>
std::shared_ptr<TupleType> type() const {
if (!type_) {
type_ = TupleType::create(fmap(elements(), [&](const IValue& v) {
return v.type<T>();
}));
}
return type_;
}

static size_t hash(const Tuple& t) {
return c10::get_hash(t.elements());
Expand Down Expand Up @@ -2234,4 +2243,14 @@ struct MaybeOwnedTraits<IValue> {
}
};

template <>
struct IValue::TagType<c10::Type> {
static TORCH_API c10::TypePtr get(const IValue&);
};

template <typename T>
typename T::Ptr IValue::type() const {
return IValue::TagType<T>::get(*this);
}

} // namespace c10
1 change: 1 addition & 0 deletions aten/src/ATen/core/jit_type_base.h
Expand Up @@ -373,6 +373,7 @@ struct TORCH_API Type {
};

using TypePtr = SingletonOrSharedTypePtr<Type>;
using Ptr = TypePtr;

// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/type.cpp
@@ -1,6 +1,7 @@
#include <ATen/core/Dict.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/enum_type.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/mobile/interpreter.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/jit/mobile/interpreter.h>

#include <ATen/core/class_type.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
Expand Down Expand Up @@ -33,7 +34,7 @@ void createObject(Stack& stack, const at::ClassTypePtr& type) {
}

void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
at::TypePtr ty = pop(stack).type();
at::TypePtr ty = pop(stack).type<c10::DynamicType>();
for (const at::TypePtr& candidate : types) {
if (ty->isSubtypeOf(*candidate)) {
push(stack, true);
Expand Down