Skip to content

Commit

Permalink
get rid of isCustomClassRegistered in favor of try/catch on "[PyTorch…
Browse files Browse the repository at this point in the history
…] Fix getCustomClassType() perf"

1) It was copying the entire hash table every time.
2) We don't need to do a hash lookup at all.

Differential Revision: [D25385543](https://our.internmc.facebook.com/intern/diff/D25385543/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25385543/)!

[ghstack-poisoned]
  • Loading branch information
swolchok committed Dec 9, 2020
1 parent 0deedff commit 8525492
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
6 changes: 0 additions & 6 deletions aten/src/ATen/core/ivalue.h
Expand Up @@ -968,12 +968,6 @@ const c10::ClassTypePtr& getCustomClassType() {
return cache;
}

template <typename T>
inline bool isCustomClassRegistered() {
auto& tmap = c10::getCustomClassTypeMap();
return tmap.find(std::type_index(typeid(T))) != tmap.end();
}

TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
getClassConverter();
} // namespace c10
Expand Down
17 changes: 10 additions & 7 deletions aten/src/ATen/core/ivalue_inl.h
Expand Up @@ -1085,13 +1085,16 @@ template <
typename T,
std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
IValue::IValue(c10::intrusive_ptr<T> custom_class) {
if (!c10::isCustomClassRegistered<c10::intrusive_ptr<T>>()) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
auto classType = c10::getCustomClassType<c10::intrusive_ptr<T>>();
TypePtr classType = []() {
try {
return c10::getCustomClassType<c10::intrusive_ptr<T>>();
} catch (const c10::Error&) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
}();
auto ivalue_obj = c10::ivalue::Object::create(
c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
Expand Down
19 changes: 12 additions & 7 deletions aten/src/ATen/core/jit_type.h
Expand Up @@ -1725,13 +1725,18 @@ namespace detail {
template <typename T>
struct getTypePtr_ final {
static TypePtr call() {
TORCH_CHECK(
isCustomClassRegistered<T>(),
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
auto res = getCustomClassType<T>();
TypePtr res = []() {
try {
return getCustomClassType<T>();
} catch(const c10::Error&) {
TORCH_CHECK(
false,
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
}
}();
return std::dynamic_pointer_cast<Type>(std::move(res));
}
};
Expand Down

0 comments on commit 8525492

Please sign in to comment.