Skip to content

update tracing codegen to use redispatch API #52009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from

Conversation

bdhirsh
Copy link
Contributor

@bdhirsh bdhirsh commented Feb 9, 2021

Taking advantage of the new redispatch API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the redispatch API directly.

One small benefit to this: hopefully the compiler is more likely to inline Dispatcher::redispatch(), since it's now used in fewer call-sites. After this change, the only places it's used are:

  • the redispatch API (RedispatchFunctions.cpp)
  • BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with manual_cpp_binding ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:

Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}

after: (note the lack of Dispatcher:: calls)

Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}

Stack from ghstack:

Differential Revision: D26356078


assign_return_values = f'{tie_return_values(f)} = ' \
if f.func.kind() == SchemaKind.functional and f.func.returns else ''

api_name = cpp.name(f.func, faithful_name_for_out_overloads=True)
if f.manual_cpp_binding:
Copy link
Contributor Author

@bdhirsh bdhirsh Feb 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

manual_cpp_binding ops are hardcoded directly into Functions.h, so they aren't currently part of the redispatch API :( I opted instead to call their codegen'd __dispatch_*, which ARE available.

It would probably be nice if I could just call the fast-path variants instead. I opted not to for two reasons:

  • I can't think of an elegant way to do it. One option is to additionally tell the codegen to generate variants in the redispatch API that directly call the hardcoded fast-path variant for ops that have manual_cpp_binding set, appropriately calling the function/method variant depending on the op.
  • Another would just be to duplicate each hardcoded function from Functions.h into both namespaces (bad).
  • The perf benefit is probably minimal, since tracing kernels aren't called that frequently.

If anyone thinks that tradeoff is worth it though, lmk and I'm happy to add it in!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reasoning seems fine.

But consider constructing a full CppSignatureGroup and then extracting name from the faithful signature (if it exists). Helper method in CppSignatureGroup for returning faithful if it exists and regular if not will be helpful. Then you can axe the duplication of __dispatch logic.

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Feb 10, 2021
ghstack-source-id: c7fca55
Pull Request resolved: #52009
@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Feb 10, 2021

💊 CI failures summary and remediations

As of commit 07f759f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@ezyang
Copy link
Contributor

ezyang commented Feb 10, 2021

the copy paste here got botched by screen


  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```


Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078)

[ghstack-poisoned]
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```


Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Feb 11, 2021
ghstack-source-id: 33c5e1c
Pull Request resolved: #52009
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```


Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Feb 12, 2021
ghstack-source-id: 2e76b7c
Pull Request resolved: #52009
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```


Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Feb 12, 2021
Pull Request resolved: #52009

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078/)
ghstack-source-id: 121639905
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```


Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Feb 20, 2021
Pull Request resolved: #52009

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078/)
ghstack-source-id: 122157408
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```


Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Feb 21, 2021
Pull Request resolved: #52009

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

Differential Revision: [D26356078](https://our.internmc.facebook.com/intern/diff/D26356078/)
ghstack-source-id: 122171881
@codecov
Copy link

codecov bot commented Feb 22, 2021

Codecov Report

Merging #52009 (07f759f) into gh/bdhirsh/77/base (ad8e906) will decrease coverage by 0.00%.
The diff coverage is n/a.

@@                  Coverage Diff                   @@
##           gh/bdhirsh/77/base   #52009      +/-   ##
======================================================
- Coverage               80.76%   80.76%   -0.01%     
======================================================
  Files                    1969     1969              
  Lines                  216037   216037              
======================================================
- Hits                   174480   174478       -2     
- Misses                  41557    41559       +2     

@facebook-github-bot
Copy link
Contributor

@bdhirsh merged this pull request in 947225c.

@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/77/head branch February 26, 2021 15:18
aocsa pushed a commit to Quansight/pytorch that referenced this pull request Mar 15, 2021
Summary:
Pull Request resolved: pytorch#52009

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D26356078

Pulled By: bdhirsh

fbshipit-source-id: bc96ca4c6d90903f1e265859160d4b13a8cc7310
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
Summary:
Pull Request resolved: pytorch#52009

Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.

One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.

One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.

Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
  auto result =c10::Dispatcher::singleton()
      .redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::add");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    jit::tracer::addInputs(node, "alpha", alpha);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D26356078

Pulled By: bdhirsh

fbshipit-source-id: bc96ca4c6d90903f1e265859160d4b13a8cc7310
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants