Skip to content

Commit

Permalink
Make sure that while caching values we don't invoke any Aten operator (
Browse files Browse the repository at this point in the history
…pytorch#99050)

Summary:
Pull Request resolved: pytorch#99050

title
also change catch to catch all so we can make it wont fail

Test Plan: existing tests

Reviewed By: harishs88ss

Differential Revision: D44945942

fbshipit-source-id: bed0c757b414f1e83f4394318bbf578ac082d5cc
  • Loading branch information
qihqi authored and facebook-github-bot committed Apr 13, 2023
1 parent dda7ce4 commit 549d953
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions torch/csrc/jit/serialization/flatbuffer_serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,23 @@ class FlatbufferSerializer {
// but without relying on aten::nonzero operator being present in the
// binary.
bool operator()(const IValue& lhs, const IValue& rhs) const {
// The only case we don't return bool is for tensor comparison. Lets do
// pointer comparison here.
if (lhs.isTensor() || rhs.isTensor()) {
if (lhs.isTensor() && rhs.isTensor()) {
return (&lhs.toTensor()) == (&rhs.toTensor());
}
return false;
}
IValue eq = lhs.equals(rhs);
if (eq.isBool()) {
return eq.toBool();
}
// The only case we don't return bool is for tensor comparison. Lets do
// pointer comparison here.
return (&lhs.toTensor()) == (&rhs.toTensor());
return false;
}
};

std::unordered_map<IValue, uint32_t, IValueHash, IValueEqual> cached_ivalues_;

const mobile::CompilationUnit* mcu_ = nullptr;
};

Expand Down Expand Up @@ -678,18 +683,20 @@ uint32_t FlatbufferSerializer::storeIValueAndGetIndex(
if (iter != cached_ivalues_.end()) {
return iter->second;
}
} catch (const std::runtime_error&) {
// Threw if ivalue is not hashable
} catch (const c10::Error&) {
// Threw if ivalue is don't have proper operator==
} catch (...) {
// Threw if ivalue is not hashable or
// if ivalue is don't have proper operator==
// we don't care catchall because either case we want to skip hashing
}

auto offset = iValueToFB(fbb, ivalue);
uint32_t index = insertIValue(offset);
try {
cached_ivalues_[ivalue] = index;
} catch (const std::runtime_error&) {
} catch (const c10::Error&) {
} catch (...) {
// Threw if ivalue is not hashable or
// if ivalue is don't have proper operator==
// we don't care catchall because either case we want to skip hashing
}

return index;
Expand Down

0 comments on commit 549d953

Please sign in to comment.