Skip to content

Commit

Permalink
[PyTorch] Fix getCustomClassType() perf
Browse files Browse the repository at this point in the history
Pull Request resolved: #48981

1) It was copying the entire hash table every time.
2) We don't need to do a hash lookup at all.
ghstack-source-id: 118164406

Differential Revision: [D25385543](https://our.internmc.facebook.com/intern/diff/D25385543/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25385543/)!
  • Loading branch information
swolchok committed Dec 9, 2020
1 parent 274ce26 commit bb4c891
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 25 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/ivalue.cpp
Expand Up @@ -22,7 +22,7 @@ namespace ivalue {

// This is in ivalue.cpp because we need to access Type::annotation_str, which
// is declared in jit_type.h
void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
void checkCustomClassType(const Type* expected_type, const Type* actual_type) {
// NB: doing pointer comparison here
// If in the future there ever arises a need to call operator== on custom class
// Type's, this needs to be changed!
Expand Down
14 changes: 9 additions & 5 deletions aten/src/ATen/core/ivalue.h
Expand Up @@ -949,8 +949,8 @@ TORCH_API ska::flat_hash_map<std::type_index, c10::ClassTypePtr>&
getCustomClassTypeMap();

template <typename T>
c10::ClassTypePtr getCustomClassType() {
auto tmap = c10::getCustomClassTypeMap();
c10::ClassTypePtr getCustomClassTypeImpl() {
auto& tmap = c10::getCustomClassTypeMap();
auto res = tmap.find(std::type_index(typeid(T)));
if (res == tmap.end()) {
throw c10::Error("Can't find class id in custom class type map", "");
Expand All @@ -959,9 +959,13 @@ c10::ClassTypePtr getCustomClassType() {
}

template <typename T>
inline bool isCustomClassRegistered() {
auto tmap = c10::getCustomClassTypeMap();
return tmap.find(std::type_index(typeid(T))) != tmap.end();
const c10::ClassTypePtr& getCustomClassType() {
// Classes are never unregistered from getCustomClassTypeMap and the
// hash lookup can be a hot path, so just cache.
// For the same reason, it's fine If this ends up getting duplicated across
// DSO boundaries for whatever reason.
static c10::ClassTypePtr cache = getCustomClassTypeImpl<T>();
return cache;
}

TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
Expand Down
27 changes: 15 additions & 12 deletions aten/src/ATen/core/ivalue_inl.h
Expand Up @@ -172,7 +172,7 @@ inline at::Generator IValue::toGenerator() const& {
namespace ivalue {

void CAFFE2_API
checkCustomClassType(TypePtr expected_type, TypePtr actual_type);
checkCustomClassType(const Type* expected_type, const Type* actual_type);

template <typename T>
using Shared = c10::intrusive_ptr<T>;
Expand Down Expand Up @@ -756,8 +756,8 @@ c10::intrusive_ptr<T> IValue::toCustomClass() && {
obj->slots().size() == 1,
"Tried to cast IValue to custom class but it did "
"not contain a custom class!");
auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
ivalue::checkCustomClassType(expected_type, type());
const Type* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
ivalue::checkCustomClassType(expected_type, type().get());
auto userObj =
c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
return userObj;
Expand All @@ -774,8 +774,8 @@ c10::intrusive_ptr<T> IValue::toCustomClass() const& {
obj->slots().size() == 1,
"Tried to cast IValue to custom class but it did "
"not contain a custom class!");
auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
ivalue::checkCustomClassType(expected_type, type());
const Type* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
ivalue::checkCustomClassType(expected_type, type().get());
auto userObj =
c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
return userObj;
Expand Down Expand Up @@ -1085,13 +1085,16 @@ template <
typename T,
std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
IValue::IValue(c10::intrusive_ptr<T> custom_class) {
if (!c10::isCustomClassRegistered<c10::intrusive_ptr<T>>()) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
auto classType = c10::getCustomClassType<c10::intrusive_ptr<T>>();
TypePtr classType = []() {
try {
return c10::getCustomClassType<c10::intrusive_ptr<T>>();
} catch (const c10::Error&) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
}();
auto ivalue_obj = c10::ivalue::Object::create(
c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
Expand Down
19 changes: 12 additions & 7 deletions aten/src/ATen/core/jit_type.h
Expand Up @@ -1725,13 +1725,18 @@ namespace detail {
template <typename T>
struct getTypePtr_ final {
static TypePtr call() {
TORCH_CHECK(
isCustomClassRegistered<T>(),
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
auto res = getCustomClassType<T>();
TypePtr res = []() {
try {
return getCustomClassType<T>();
} catch(const c10::Error&) {
TORCH_CHECK(
false,
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
}
}();
return std::dynamic_pointer_cast<Type>(std::move(res));
}
};
Expand Down

0 comments on commit bb4c891

Please sign in to comment.