[JIT] Optimize FunctionSchema::checkArg for the Tensor case. #48034
Conversation
This results in a ~25% improvement on the DeepAndWide model and would improve other models as well.

Before the change:

```
522[ms] 507[ms] 559[ms] 512[ms] 510[ms] 548[ms] 515[ms] 600[ms] 518[ms] 494[ms]
```

After the change:

```
388[ms] 622[ms] 404[ms] 405[ms] 380[ms] 379[ms] 579[ms] 417[ms] 377[ms] 409[ms]
```

[ghstack-poisoned]
💊 CI failures summary: as of commit b48cd83 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚 (This comment was automatically generated by Dr. CI and has been revised 8 times.)
It's a good idea to try to speed up this pathway for common cases. The caching approach here has a number of issues, which I commented on below. Maybe tackle improving the performance of checkAndNormalizeInputs first. For instance, isSubtypeOf is very slow and probably does accidental atomic reference counting of types. For common types like Tensors and tuples of them, this can be made much faster by a fast path in checkArg, at the line `if (!value.type()->isSubtypeOf(argument.type())) {`.
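For concreteness, a minimal sketch of what such a fast path could look like, mirroring the checkArg signature that appears in the diff below (an illustration of the idea, not the exact committed code):

```cpp
// Sketch: when the value is a Tensor and the argument is annotated with the
// plain, unspecialized Tensor type (a singleton, so pointer equality on the
// TypePtr is sufficient), skip the expensive isSubtypeOf() walk entirely.
inline void FunctionSchema::checkArg(
    const IValue& value,
    const Argument& argument,
    optional<size_t> pos) const {
  if (value.isTensor() && argument.type() == TensorType::get()) {
    return; // common case: nothing else to verify
  }
  if (!value.type()->isSubtypeOf(argument.type())) {
    // ... existing slow path with detailed error reporting ...
  }
}
```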
torch/csrc/jit/api/function_impl.cpp (outdated):

```cpp
bool need_schema_check = true;
if (!kwargs.size()) { // Fast path
  size_t input_types_hash = computeInputTypesHash(stack);
  if (!schema_checks_cache_.count(input_types_hash)) {
```
If the hash collides, this check produces wrong results. To make the fast path (a hit) sound, one would need to check the equality of the actual types, which would require more computation.
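For illustration, a collision-safe variant would have to store the actual input types alongside the hash and compare them on every hit. A sketch with hypothetical names, against PyTorch's TypePtr:

```cpp
#include <algorithm>
#include <unordered_map>
#include <vector>

// The hash only selects a bucket; the cached type sequence must still be
// compared against the actual inputs, otherwise two different argument-type
// sequences that happen to hash alike would be treated as the same.
std::unordered_map<size_t, std::vector<TypePtr>> schema_checks_cache_;

bool cacheHit(size_t hash, const std::vector<TypePtr>& input_types) {
  auto it = schema_checks_cache_.find(hash);
  return it != schema_checks_cache_.end() &&
      it->second.size() == input_types.size() &&
      // this full equality walk is exactly the extra computation that a
      // hash-only lookup was trying to avoid
      std::equal(
          it->second.begin(), it->second.end(), input_types.begin(),
          [](const TypePtr& a, const TypePtr& b) { return *a == *b; });
}
```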
torch/csrc/jit/api/function_impl.cpp
Outdated
if (!kwargs.size()) { // Fast path | ||
size_t input_types_hash = computeInputTypesHash(stack); | ||
if (!schema_checks_cache_.count(input_types_hash)) { | ||
getSchema().checkAndNormalizeInputs(stack, kwargs); |
If the schema has default arguments, or other things in the 'NormalizeInputs' bucket, then caching this is invalid, because those actions need to be applied on each invocation.
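To make the point concrete, a simplified sketch (not the actual checkAndNormalizeInputs) of the per-invocation work that a cached "already checked" bit would incorrectly skip:

```cpp
// For a schema like foo(Tensor x, int y=1), a caller that passes only `x`
// relies on normalization to push the default for `y` onto the stack on
// every single call, not just on the first call whose types were checked.
void pushDefaults(Stack& stack, const FunctionSchema& schema) {
  for (size_t i = stack.size(); i < schema.arguments().size(); ++i) {
    const auto& arg = schema.arguments()[i];
    if (arg.default_value()) {
      stack.push_back(*arg.default_value()); // side effect needed each time
    }
  }
}
```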
torch/csrc/jit/api/function_impl.cpp
Outdated
size_t input_types_hash = computeInputTypesHash(stack); | ||
if (!schema_checks_cache_.count(input_types_hash)) { | ||
getSchema().checkAndNormalizeInputs(stack, kwargs); | ||
schema_checks_cache_.insert(input_types_hash); |
Mutating the GraphFunction data structure requires holding a lock, because it is invoked from multiple threads.
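A sketch of what thread-safe insertion would require, assuming the hash-set cache from the diff above:

```cpp
#include <mutex>
#include <unordered_set>

std::mutex schema_checks_mutex_;
std::unordered_set<size_t> schema_checks_cache_;

// Returns true if the hash was newly inserted (i.e., not seen before).
// The lock_guard serializes concurrent callers mutating the shared set.
bool markChecked(size_t input_types_hash) {
  std::lock_guard<std::mutex> guard(schema_checks_mutex_);
  return schema_checks_cache_.insert(input_types_hash).second;
}
```

Even with the lock, the correctness issues above remain; this only addresses the data race.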
The Tensor case is one of the most common and the existing check can be made faster. This results in a ~21% improvement on the DeepAndWide model and would improve other models as well.

Before the change:

```
505[ms] 491[ms] 514[ms] 538[ms] 514[ms] 554[ms] 556[ms] 512[ms] 516[ms] 527[ms]
```

After the change:

```
406[ms] 394[ms] 414[ms] 423[ms] 449[ms] 397[ms] 410[ms] 389[ms] 395[ms] 414[ms]
```

Differential Revision: [D24999486](https://our.internmc.facebook.com/intern/diff/D24999486)

[ghstack-poisoned]
Thanks for the feedback! Indeed, we could achieve similar gains with a safer change in `checkArg`.
Cool! I have additional suggestions for further speed improvement below.
```diff
@@ -151,6 +151,10 @@ inline void FunctionSchema::checkArg(
     const IValue& value,
     const Argument& argument,
     optional<size_t> pos) const {
+  if (value.isTensor() && argument.type() == TensorType::get()) {
```
This is correct (and faster, yay!) but still does 2 shared_ptr operations: `argument.type()` and `TensorType::get()` both return a shared_ptr, which has to be destructed. Using `argument.type()->kind() == TensorType::kind()` would eliminate one shared_ptr creation, and modifying `argument.type()` to return a `const TypePtr&` would eliminate the other. Not sure if they will speed things up further, but worth checking.
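A sketch of both tweaks together (signatures assumed for illustration, not the actual headers):

```cpp
// 1. Returning the stored TypePtr by const reference avoids a shared_ptr
//    copy, and therefore an atomic refcount bump, on every call.
struct ArgumentSketch {
  const TypePtr& type() const { return type_; }
  TypePtr type_;
};

// 2. TypeKind is a plain enum, so comparing kinds touches no refcounts at all.
bool isPlainTensorArg(const IValue& value, const ArgumentSketch& argument) {
  return value.isTensor() &&
      argument.type()->kind() == TypeKind::TensorType;
}
```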
I thought about checking only the type kind, but wasn't sure it would work correctly in a theoretical case where the argument type had, say, specialized shapes, i.e. if the graph looked like `graph(%input : Float(10))`. In that case, IIUC, `argument.type()->kind()` would equal `TensorType::kind()`, but `argument.type()` would not be equal to `TensorType::get()`.
As for your second suggestion, let me try that. FWIW, I also tried changing the signature of `Type::isSubtypeOfExt` to take a reference for the first argument, but I didn't notice improved performance from that.
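For reference, a sketch of the specialization concern, assuming `TensorType::createContiguous` can build the shape/dtype-specialized type corresponding to the `Float(10)` annotation:

```cpp
#include <cassert>

void specializationExample() {
  auto specialized =
      TensorType::createContiguous(at::kFloat, at::kCPU, /*sizes=*/{10});
  auto plain = TensorType::get();

  assert(specialized->kind() == plain->kind()); // same TypeKind...
  assert(specialized != plain); // ...but not the same type, so a kind-only
                                // fast path would wrongly skip the full check
}
```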
Codecov Report

```
@@            Coverage Diff                           @@
##    gh/ZolotukhinM/378/base   #48034          +/-  ##
========================================================
  Coverage             81.30%   81.30%
========================================================
  Files                  1839     1839
  Lines                198446   198446
========================================================
+ Hits                 161337   161338          +1
+ Misses                37109    37108          -1
```
@ZolotukhinM merged this pull request in 3611d26.
Stack from ghstack:
The Tensor case is one of the most common and the existing check can be
made faster. This results in a ~21% improvement on DeepAndWide model and
would improve other models as well.
Before the change: 505[ms] 491[ms] 514[ms] 538[ms] 514[ms] 554[ms] 556[ms] 512[ms] 516[ms] 527[ms]
After the change: 406[ms] 394[ms] 414[ms] 423[ms] 449[ms] 397[ms] 410[ms] 389[ms] 395[ms] 414[ms]
Differential Revision: D24999486