update tracing codegen to use redispatch API #52009
Conversation
tools/autograd/gen_trace_type.py
Outdated
```
assign_return_values = f'{tie_return_values(f)} = ' \
    if f.func.kind() == SchemaKind.functional and f.func.returns else ''
# ...
api_name = cpp.name(f.func, faithful_name_for_out_overloads=True)
if f.manual_cpp_binding:
```
`manual_cpp_binding` ops are hardcoded directly into `Functions.h`, so they aren't currently part of the redispatch API :( I opted instead to call their codegen'd `__dispatch_*` variants, which ARE available.

It would probably be nice if I could just call the fast-path variants instead. I opted not to for two reasons:
- I can't think of an elegant way to do it. One option is to additionally tell the codegen to generate variants in the redispatch API that directly call the hardcoded fast-path variant for ops that have `manual_cpp_binding` set, appropriately calling the function/method variant depending on the op. Another would just be to duplicate each hardcoded function from `Functions.h` into both namespaces (bad).
- The perf benefit is probably minimal, since tracing kernels aren't called that frequently.

If anyone thinks that tradeoff is worth it though, lmk and I'm happy to add it in!
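For illustration only, a rough Python sketch of the choice described above, i.e. whether a generated tracing kernel calls into `at::redispatch::*` or falls back to the codegen'd `__dispatch_*` variant. The helper name and signature are hypothetical, not the actual `gen_trace_type.py` code:

```
# Hypothetical sketch (not the real gen_trace_type.py logic): pick the C++
# call that the generated tracing kernel should emit for one operator.
def dispatch_call_target(api_name: str, manual_cpp_binding: bool) -> str:
    if manual_cpp_binding:
        # manual_cpp_binding ops are hardcoded into Functions.h and have no
        # at::redispatch:: counterpart, so fall back to their codegen'd
        # __dispatch_* variant, which does accept a DispatchKeySet.
        return f'at::__dispatch_{api_name}'
    # Everything else can go through the redispatch API directly.
    return f'at::redispatch::{api_name}'

# e.g. dispatch_call_target('add', manual_cpp_binding=False)
#   -> 'at::redispatch::add'
```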
Reasoning seems fine.

But consider constructing a full `CppSignatureGroup` and then extracting the name from the faithful signature (if it exists). A helper method in `CppSignatureGroup` for returning the faithful signature if it exists and the regular one if not would be helpful. Then you can axe the duplication of the `__dispatch` logic.
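A minimal sketch of the kind of helper being suggested, in Python. The class shapes and field names below (`signature`, `faithful_signature`, `most_faithful_signature`) are stand-ins for the real codegen types, not a quote of `tools/codegen`:

```
from dataclasses import dataclass
from typing import Optional

# Stand-ins for the real codegen signature types; only the bits the sketch needs.
@dataclass(frozen=True)
class CppSignature:
    name: str

@dataclass(frozen=True)
class CppSignatureGroup:
    signature: CppSignature
    faithful_signature: Optional[CppSignature]

    def most_faithful_signature(self) -> CppSignature:
        # Prefer the faithful signature when the op has one; otherwise use
        # the regular signature, so callers never have to branch on it.
        if self.faithful_signature is not None:
            return self.faithful_signature
        return self.signature
```

With something like this, the tracing codegen could pull the name off `group.most_faithful_signature()` in one place instead of duplicating the `__dispatch` special-casing.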
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:

```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);
    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result = c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after (note the lack of `Dispatcher::` calls):

```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);
    jit::tracer::setTracingState(nullptr);
  }
  auto result = at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```
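For a rough sense of the codegen side of this change, here is a hypothetical Python helper that assembles the redispatch call seen in the "after" kernel. It is an illustration, not the actual `gen_trace_type.py` implementation, and the `DispatchKey::Tracer` exclusion key is an assumption filled in where the pasted line above is cut off:

```
# Hypothetical sketch: build the string for the emitted redispatch call.
# The key-masking expression keeps only the keys after Tracer, so the
# redispatch skips the tracing kernel and goes to the next kernel in line.
def emit_redispatch_call(api_name: str, args: list) -> str:
    keyset = ('ks & c10::DispatchKeySet('
              'c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)')
    return f'at::redispatch::{api_name}({keyset}, {", ".join(args)})'

# e.g. emit_redispatch_call('add', ['self', 'other', 'alpha']) produces the
# call expression assigned to `result` in the kernel above.
```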
💊 CI failures summary and remediations
As of commit 07f759f (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚
(This comment was automatically generated by Dr. CI.)
the copy paste here got botched by screen
Codecov Report

```
@@               Coverage Diff                @@
##   gh/bdhirsh/77/base    #52009       +/-   ##
================================================
-  Coverage        80.76%    80.76%    -0.01%
================================================
   Files             1969      1969
   Lines           216037    216037
================================================
-  Hits            174480    174478        -2
-  Misses           41557     41559        +2
```
Summary: Pull Request resolved: pytorch#52009
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D26356078
Pulled By: bdhirsh
fbshipit-source-id: bc96ca4c6d90903f1e265859160d4b13a8cc7310
Stack from ghstack:

Differential Revision: D26356078