diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6b8f4412cbf7..60382e37b6ff 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -22,7 +22,7 @@ namespace ivalue { // This is in ivalue.cpp because we need to access Type::annotation_str, which // is declared in jit_type.h -void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { +void checkCustomClassType(const Type* expected_type, const Type* actual_type) { // NB: doing pointer comparison here // If in the future there ever arises a need to call operator== on custom class // Type's, this needs to be changed! diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 9ea18dc8482d..d2e72933b532 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -949,8 +949,8 @@ TORCH_API ska::flat_hash_map& getCustomClassTypeMap(); template -c10::ClassTypePtr getCustomClassType() { - auto tmap = c10::getCustomClassTypeMap(); +c10::ClassTypePtr getCustomClassTypeImpl() { + auto& tmap = c10::getCustomClassTypeMap(); auto res = tmap.find(std::type_index(typeid(T))); if (res == tmap.end()) { throw c10::Error("Can't find class id in custom class type map", ""); @@ -959,9 +959,13 @@ c10::ClassTypePtr getCustomClassType() { } template -inline bool isCustomClassRegistered() { - auto tmap = c10::getCustomClassTypeMap(); - return tmap.find(std::type_index(typeid(T))) != tmap.end(); +const c10::ClassTypePtr& getCustomClassType() { + // Classes are never unregistered from getCustomClassTypeMap and the + // hash lookup can be a hot path, so just cache. + // For the same reason, it's fine If this ends up getting duplicated across + // DSO boundaries for whatever reason. + static c10::ClassTypePtr cache = getCustomClassTypeImpl(); + return cache; } TORCH_API std::unordered_map>& diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 3068bda5f5a5..e5752d8e7b9b 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -172,7 +172,7 @@ inline at::Generator IValue::toGenerator() const& { namespace ivalue { void CAFFE2_API -checkCustomClassType(TypePtr expected_type, TypePtr actual_type); +checkCustomClassType(const Type* expected_type, const Type* actual_type); template using Shared = c10::intrusive_ptr; @@ -756,8 +756,8 @@ c10::intrusive_ptr IValue::toCustomClass() && { obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; @@ -774,8 +774,8 @@ c10::intrusive_ptr IValue::toCustomClass() const& { obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; @@ -1085,13 +1085,16 @@ template < typename T, std::enable_if_t::value, int>> IValue::IValue(c10::intrusive_ptr custom_class) { - if (!c10::isCustomClassRegistered>()) { - throw c10::Error( - "Trying to instantiate a class that isn't a registered custom class: " + - std::string(c10::util::get_fully_qualified_type_name()), - ""); - } - auto classType = c10::getCustomClassType>(); + TypePtr classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + + std::string(c10::util::get_fully_qualified_type_name()), + ""); + } + }(); auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 40c2ec7f443d..1f6907818a46 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1725,13 +1725,18 @@ namespace detail { template struct getTypePtr_ final { static TypePtr call() { - TORCH_CHECK( - isCustomClassRegistered(), - "Type ", - c10::util::get_fully_qualified_type_name(), - " could not be converted to any of the known types." - ); - auto res = getCustomClassType(); + TypePtr res = []() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." + ); + } + }(); return std::dynamic_pointer_cast(std::move(res)); } };