Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] Optimize FunctionSchema::checkArg for the Tensor case. #48034

Closed
wants to merge 3 commits into from
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions aten/src/ATen/core/function_schema_inl.h
Expand Up @@ -151,6 +151,10 @@ inline void FunctionSchema::checkArg(
const IValue& value,
const Argument& argument,
optional<size_t> pos) const {
if (value.isTensor() && argument.type() == TensorType::get()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct (and faster, yay!) but still does 2 shared_ptr operations (argument.type() and TypeTensor::get() both return shared_ptr which has to be destructed).
argument.type()->kind() == TensorType::kind() would eliminate 1 shared_ptr creation. Modifying argument.type() to return a const TypePtr& would eliminate the other. Not sure if they will speed things up further, but worth checking.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about checking only the type kind, but wasn't sure it would correctly work in a theoretical case where argument type would have, say, shapes specialized. I.e. if the graph would look like:

graph(%input : Float(10))

In that case, IIUC, argument.type()->kind() would equal TensorType::kind(), but argument.type() would not be equal TensorType::get().

As for your second suggestion, let me try that. FWIW, I also tried replacing signature of Type::isSubtypeOfExt to use a reference for the first argument, but I didn't notice improved performance from that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Submitted #48061 for that, PTAL @zdevito!

// Fast-path for the common case
return;
}
if (!value.type()->isSubtypeOf(argument.type())) {
TORCH_CHECK(
false,
Expand Down