diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index cf3dd09cdef3..d2e72933b532 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -968,12 +968,6 @@ const c10::ClassTypePtr& getCustomClassType() { return cache; } -template -inline bool isCustomClassRegistered() { - auto& tmap = c10::getCustomClassTypeMap(); - return tmap.find(std::type_index(typeid(T))) != tmap.end(); -} - TORCH_API std::unordered_map>& getClassConverter(); } // namespace c10 diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 981033ac9f1e..e5752d8e7b9b 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -1085,13 +1085,16 @@ template < typename T, std::enable_if_t::value, int>> IValue::IValue(c10::intrusive_ptr custom_class) { - if (!c10::isCustomClassRegistered>()) { - throw c10::Error( - "Trying to instantiate a class that isn't a registered custom class: " + - std::string(c10::util::get_fully_qualified_type_name()), - ""); - } - auto classType = c10::getCustomClassType>(); + TypePtr classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + + std::string(c10::util::get_fully_qualified_type_name()), + ""); + } + }(); auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 40c2ec7f443d..1f6907818a46 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1725,13 +1725,18 @@ namespace detail { template struct getTypePtr_ final { static TypePtr call() { - TORCH_CHECK( - isCustomClassRegistered(), - "Type ", - c10::util::get_fully_qualified_type_name(), - " could not be converted to any of the known types." - ); - auto res = getCustomClassType(); + TypePtr res = []() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." + ); + } + }(); return std::dynamic_pointer_cast(std::move(res)); } };