New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JIT] Optimize FunctionSchema::checkArg for the Tensor case. #48034
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,10 +45,33 @@ c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync( | |
return get_executor().runAsync(stack, std::move(taskLauncher)); | ||
} | ||
|
||
size_t GraphFunction::computeInputTypesHash( | ||
const std::vector<IValue>& stack) const { | ||
// Use an algorithm similar to boost::hash_combine to compute the vector hash | ||
size_t r = 0; | ||
const size_t magic_number = 0x9e3779b9; | ||
for (const IValue& iv : stack) { | ||
r ^= std::hash<uint32_t>{}(iv.tagAsInt()) + magic_number + (r << 6) + | ||
(r >> 2); | ||
} | ||
return r; | ||
} | ||
|
||
IValue GraphFunction::operator()( | ||
std::vector<IValue> stack, | ||
const Kwargs& kwargs) { | ||
getSchema().checkAndNormalizeInputs(stack, kwargs); | ||
bool need_schema_check = true; | ||
if (!kwargs.size()) { // Fast path | ||
size_t input_types_hash = computeInputTypesHash(stack); | ||
if (!schema_checks_cache_.count(input_types_hash)) { | ||
getSchema().checkAndNormalizeInputs(stack, kwargs); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the schema has default arguments, or other things in the 'NormalizeInputs' bucket, then caching this is invalid because these actions need to be applied to each invocation. |
||
schema_checks_cache_.insert(input_types_hash); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mutating the GraphFunction data structure requires holding a lock because it is invoked from multiple threads. |
||
} | ||
need_schema_check = false; | ||
} | ||
if (need_schema_check) { | ||
getSchema().checkAndNormalizeInputs(stack, kwargs); | ||
} | ||
run(stack); | ||
return stack.front(); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the hash collides this check produces wrong results. In the fast path (a hit), one would need to check the equality of the types, which would require more computation.