
[JIT][ProfilingExecutor] A reliable and consistent API/protocol to query & propagate profiling data #55999

Open
jjsjann123 opened this issue Apr 14, 2021 · 1 comment
Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue)

jjsjann123 (Collaborator) commented Apr 14, 2021

🚀 Feature

As more modules in the pipeline consume profiling data, it is increasingly difficult to locate the proper profile node for a given value in a graph.
A consistent API/protocol to query and propagate profile nodes would be beneficial in two ways: 1. it allows pass writers to implement robust and short code (less code to search for profile nodes, hence fewer bugs); 2. it orchestrates subsequent passes so that they manipulate profile nodes in compatible ways.

Motivation

The problem arises when we try to improve autodiff by utilizing requires_grad from the profiled tensor type. requires_grad on the input/output tensors of a DifferentiableGraph is used to: 1. prune computation of grad_inputs in the backward graph; 2. mark requires_grad on output tensors in the forward graph.

This works when we have a DifferentiableGraph that preserves profile nodes for all input/output tensors within its subgraph. However, that invariant is not enforced by any explicit check and can easily be broken by optimization passes, since each pass writer mutates the graph freely.

E.g., for autodiff we can look at a made-up graph that mimics the graph produced at this stage:

auto diff_nodes = CreateAutodiffSubgraphs(
    copy,
    getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);

graph(%0 : Tensor,
      %1 : Bool):
  ..., %2 : Tensor = prim::DifferentiableGraph_0(%0)
  %3 : Tensor = prim::If(%1)
    block0():
      %4 : Tensor = prim::DifferentiableGraph_1(%2)
      -> (%4)
    block1():
      %5 : Tensor = prim::DifferentiableGraph_2(%2)
      -> (%5)
  return (%3)
with prim::DifferentiableGraph_0 = graph(%0 : Tensor):
  ...
  %out : Tensor = aten::operation(...)
  ...
  return (..., %out)
with prim::DifferentiableGraph_1 = graph(%0 : Tensor):
  %temp : Tensor = prim::profile[profiled_type=Tensor](%0)
  ...
with prim::DifferentiableGraph_2 = graph(%0 : Tensor):
  %temp : Tensor = prim::profile[profiled_type=Float(...)](%0)
  ...

We notice a few complications above:

  1. For edges connecting two DifferentiableGraphs, the profile node is absorbed by one of them, making it harder for the other to retrieve the profiling information. In the example above, the last output tensor of prim::DifferentiableGraph_0 doesn't come directly from a profile node; the profile node is instead located inside its consumer(s). When we query the profiled tensor type of an output inside the subgraph of a DifferentiableGraph, there are three places we need to look (see the sketch after this list): i. within the subgraph, look for a profile node that feeds the tensor to the output node; ii. in the block where the DifferentiableGraph lives, check the users of its output for profile nodes; iii. finally, for users in case ii that are DifferentiableGraphs themselves, look inside their subgraphs for profile nodes on the corresponding inputs as well.
  2. We could have multiple profile nodes in different branches with conflicting type information. Conflicting type information is tricky to handle in general, but in the restricted use case where only a single profiling run is executed, the conflict is merely a concrete tensor type vs. an empty tensor type, which can be resolved by simply iterating through all branches until a concrete type is found.
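
A minimal sketch of that three-place lookup, assuming the Node/Value/Use helpers from torch/csrc/jit/ir/ir.h. findProfiledType is a hypothetical name, not an existing API; de-duplication and nested-block traversal are omitted for brevity:

#include <torch/csrc/jit/ir/ir.h>

using namespace torch::jit;

// Returns the profiled type recorded for `v`, searching (i) v's producer,
// (ii) direct users of `v`, and (iii) the corresponding input inside a
// DifferentiableGraph user's subgraph.
c10::TypePtr findProfiledType(Value* v) {
  auto profiledTypeOf = [](Node* n) -> c10::TypePtr {
    if (n->kind() == prim::profile && n->hasAttribute(attr::profiled_type)) {
      return n->ty(attr::profiled_type);
    }
    return nullptr;
  };

  // (i) the value is produced by a profile node in the current block
  if (auto t = profiledTypeOf(v->node())) {
    return t;
  }

  for (const Use& use : v->uses()) {
    // (ii) a direct user of the value is a profile node
    if (auto t = profiledTypeOf(use.user)) {
      return t;
    }
    // (iii) the user is a DifferentiableGraph; look for a profile node
    // guarding the corresponding input of its subgraph
    if (use.user->kind() == prim::DifferentiableGraph) {
      Value* subgraphInput =
          use.user->g(attr::Subgraph)->inputs().at(use.offset);
      for (const Use& innerUse : subgraphInput->uses()) {
        if (auto t = profiledTypeOf(innerUse.user)) {
          return t;
        }
      }
    }
  }
  return nullptr; // no profiling information found for this value
}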

Pitch

In the current protocol for using profiling information, where users insert custom guards and realize the guarded information in optimizations, the overhead of a profile node is relatively small: profiling information that isn't used is simply discarded, with no runtime penalty.
This is drastically different from the earlier use of profile nodes in conjunction with BailOut nodes, where the existence of a profile node implied BailOut overhead and also blocked fusion. Therefore, our strategy of merging profile nodes across blocks should be recalibrated as well.
I think that by simply cloning profile nodes instead of moving them in optimization passes, it would be much easier for subsequent passes to perform similar profiling-information-dependent optimizations. However, without a validation pass it is hard to keep up this protocol. Hence the request to formalize our APIs for manipulating and validating profile nodes in a graph. Given that a profile node is mostly just a pass-through node without side effects at runtime, we can safely propagate profile nodes in a generic way after a graph mutation to facilitate future optimization passes, as well as strip profile nodes from a graph before execution.
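
To illustrate the clone-instead-of-move idea, a minimal sketch (cloneProfileNodeFor is a hypothetical helper, not an existing API): before a pass absorbs a profile node into a subgraph, it leaves a copy behind at the original site.

#include <torch/csrc/jit/ir/ir.h>

using namespace torch::jit;

// Leave a cloned profile node guarding `v` in v's own block, so the original
// profile node is free to be merged into a subgraph without losing the
// information outside. Assumes `v` is produced by a regular node in the block.
Node* cloneProfileNodeFor(Value* v, Node* profile) {
  // createClone copies the node (including its profiled_type attribute)
  // without inserting it; remap its single input to `v`.
  Node* copy = v->owningGraph()->createClone(
      profile, [&](Value*) -> Value* { return v; });
  copy->insertAfter(v->node());
  return copy;
}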
Simple APIs that I think would be useful:

// Propagate profile nodes across boundaries between Blocks.
void propagateProfileNode(std::shared_ptr<Graph> graph);
// Validate that the given graph satisfies the assumption made by
// `retrieveProfileInformation`: for every use of a value in the graph, the
// profiling information can be extracted.
bool validateProfileNode(std::shared_ptr<Graph> graph);
// Note that we take a `Use { Node* user; size_t offset; }` here instead of a
// `Value`; this should return the profile node for the specific use when
// applicable.
Node* retrieveProfileInformation(Use use);
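
And a hypothetical usage from a pass writer's perspective (myRequiresGradPruningPass and its body are illustrative only; the three functions above are the proposal, not an existing PyTorch API):

void myRequiresGradPruningPass(std::shared_ptr<Graph> graph) {
  // Re-establish the invariant after earlier passes mutated the graph.
  propagateProfileNode(graph);
  TORCH_INTERNAL_ASSERT(validateProfileNode(graph));

  // Nested blocks omitted for brevity.
  for (Node* n : graph->nodes()) {
    for (size_t i = 0; i < n->inputs().size(); ++i) {
      // Query per-use: the same Value may be profiled differently per branch.
      if (Node* profile = retrieveProfileInformation(Use(n, i))) {
        auto ttype =
            profile->ty(attr::profiled_type)->cast<c10::TensorType>();
        // ... use the requires_grad bit of `ttype` to prune grad_input
        // computation in the backward graph ...
      }
    }
  }
}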

cc @gmagogsfm

facebook-github-bot added the oncall: jit label on Apr 14, 2021
github-actions bot added this to Need triage in JIT Triage on Apr 14, 2021
gmagogsfm removed this from Need triage in JIT Triage on Apr 15, 2021
Krovatkin (Contributor) commented:

@jjsjann123 I like the idea! Let me give it a bit more thought on whether we could run into any issues with this approach, and then we can get to implementing it.
