
Add different pattern to show op in timeline #37074

Merged
merged 4 commits into from
Apr 23, 2020


zhuzilin
Contributor

This is a PR from JIZHI, the AI platform in Tencent.

When using the timeline in TensorFlow, we often observe large blanks in the "/job" row, which is expected to show operators executing back to back.

The following are timelines for a Transformer model, without and with XLA:
Transformer: [timeline screenshot]
XLA: [timeline screenshot]
These timelines are confusing, and the blanks may mislead users into thinking that the GPUs are sitting idle.

The cause is that the "/job" row shows only the scheduling time of each op. Because GPU kernels run asynchronously, an op's kernels may not have started computing by the time the op's scheduling has finished. In the Transformer example, the embedding matmul takes a long time and blocks the kernels queued behind it, which produces the large blank in the middle. With XLA, the fused kernels are scheduled early but have to wait to execute, which produces the large gap between where ops end in "/job" and where kernels end in "/stream:all".
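The mismatch between scheduling and execution can be illustrated with a toy model. This is a minimal sketch, assuming one host thread, one GPU stream, and that a kernel starts as soon as its op has been scheduled and the stream is free; the function name `simulate_async_launch`, the op names, and the costs (in made-up microseconds) are hypothetical and not taken from the profiler.

```python
from typing import List, Tuple

Span = Tuple[int, int]  # (start, end), toy microseconds


def simulate_async_launch(ops: List[Tuple[str, int, int]]) -> List[Tuple[str, Span, Span]]:
    """Toy model: one host thread schedules ops onto one GPU stream.

    Each op is (name, launch_cost, kernel_cost). The host spends launch_cost
    scheduling the op (the span the "/job" row shows), then moves on. The
    kernel runs once the stream is free, so it can finish long after the
    op's scheduling span has ended.
    """
    host_time = 0
    stream_free_at = 0
    rows = []
    for name, launch_cost, kernel_cost in ops:
        sched = (host_time, host_time + launch_cost)
        host_time = sched[1]
        # The kernel cannot start before it is launched or before the stream frees up.
        kernel_start = max(sched[1], stream_free_at)
        kernel = (kernel_start, kernel_start + kernel_cost)
        stream_free_at = kernel[1]
        rows.append((name, sched, kernel))
    return rows


if __name__ == "__main__":
    for name, sched, kernel in simulate_async_launch(
        [("embedding_matmul", 5, 100), ("add", 5, 2)]
    ):
        print(f"{name:18s} /job {sched}  stream {kernel}")
```

In this toy run, "add" is fully scheduled at t=10 but its kernel only runs from t=105 to t=107, behind the long matmul, which is the kind of disagreement between "/job" and the stream rows described above.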

Therefore, we propose two new patterns for showing op execution time:

  • "gpu" aligns each op with the execution span of all its kernels.
  • "all" changes only the end time of each op, extending it to the end of its last kernel.
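The two patterns, together with the default "schedule" mode that keeps the existing behaviour, can be sketched as a rule for choosing the span drawn for an op. This is a minimal sketch of the idea as described in this PR, not the code in `timeline.py`; the function name `op_display_span` is hypothetical, and falling back to the scheduling span for ops with no GPU kernels is an assumption.

```python
from typing import Sequence, Tuple

Span = Tuple[int, int]  # (start, end) in microseconds


def op_display_span(sched: Span, kernels: Sequence[Span], op_time: str = "schedule") -> Span:
    """Return the (start, end) to draw for an op in the "/job" row.

    sched:   when the op was scheduled on the host
    kernels: execution spans of the GPU kernels the op launched
    op_time: "schedule", "gpu", or "all", mirroring the modes in this PR
    """
    if op_time not in ("schedule", "gpu", "all"):
        raise ValueError(f"unknown op_time: {op_time!r}")
    if op_time == "schedule" or not kernels:
        # Original behaviour; also the assumed fallback for ops without GPU kernels.
        return sched
    first_start = min(start for start, _ in kernels)
    last_end = max(end for _, end in kernels)
    if op_time == "gpu":
        # Cover exactly the execution span of all the op's kernels.
        return (first_start, last_end)
    # "all": keep the scheduling start, move the end to the last kernel's end.
    return (sched[0], last_end)
```

For an op scheduled over (5, 10) whose only kernel ran over (105, 107), "schedule" draws (5, 10), "gpu" draws (105, 107), and "all" draws (5, 107).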

We added a new argument, op_time, to generate_chrome_trace_format to let users choose how op execution time is shown. The default value, "schedule", behaves the same as before; the other accepted values are "gpu" and "all", as explained above.

  def generate_chrome_trace_format(self, show_dataflow=True, show_memory=False, op_time="schedule"):
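A possible way to use the new argument is sketched below. It assumes a TF1-style session (`tf.compat.v1`) and a TensorFlow release that includes this change; the helper name `write_chrome_trace` and the output path are hypothetical. `Timeline` and `generate_chrome_trace_format` come from `tensorflow.python.client.timeline`, and TensorFlow is imported inside the helper so the validation step can run without it.

```python
# Valid values of op_time, as listed in this PR.
OP_TIME_MODES = ("schedule", "gpu", "all")


def write_chrome_trace(run_metadata, path, op_time="schedule"):
    """Hypothetical helper: dump one profiled step as a Chrome trace JSON file.

    run_metadata must come from a session.run() call made with
    trace_level=FULL_TRACE, so that run_metadata.step_stats is populated.
    """
    if op_time not in OP_TIME_MODES:
        raise ValueError(f"op_time must be one of {OP_TIME_MODES}, got {op_time!r}")
    # Imported lazily so the helper can be defined without TensorFlow installed.
    from tensorflow.python.client import timeline

    tl = timeline.Timeline(run_metadata.step_stats)
    # op_time requires a TensorFlow version that includes this change.
    trace = tl.generate_chrome_trace_format(op_time=op_time)
    with open(path, "w") as f:
        f.write(trace)


# Example usage with a TF1-style session (not executed here):
#   import tensorflow as tf
#   opts = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
#   meta = tf.compat.v1.RunMetadata()
#   sess.run(train_op, options=opts, run_metadata=meta)
#   write_chrome_trace(meta, "timeline_gpu.json", op_time="gpu")
```

The resulting JSON file can be opened in chrome://tracing to compare the three patterns on the same step.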

The result of the "gpu" pattern is:
Transformer: [timeline screenshot]
XLA: [timeline screenshot]

And the result of the "all" pattern is:
Transformer: [timeline screenshot]
XLA: [timeline screenshot]
Note that the illustrations above show only part of the "all" pattern, because it can produce a large amount of apparent parallelism (many ops are waiting to be executed at the same time).
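Under "all", each op stretches from its scheduling start to the end of its last kernel, so ops that are queued behind a long kernel overlap one another. One way to quantify this is to count the largest number of op spans that overlap at any instant, using a sweep over start and end events. This is a minimal sketch; the function name `max_concurrency` and the example spans are hypothetical.

```python
from typing import Iterable, Tuple


def max_concurrency(spans: Iterable[Tuple[int, int]]) -> int:
    """Largest number of (start, end) spans that overlap at any instant.

    A span ending at time t is treated as not overlapping one starting at t.
    """
    events = []
    for start, end in spans:
        events.append((start, 1))
        events.append((end, -1))
    # At equal timestamps, -1 sorts before +1, so ends are processed first.
    events.sort()
    current = best = 0
    for _, delta in events:
        current += delta
        best = max(best, current)
    return best
```

For example, three ops drawn as (0, 105), (5, 107) and (10, 109) under "all" overlap three deep, whereas back-to-back spans such as (0, 5) and (5, 10) never overlap.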

Additionally, the kernel names of XLA fusion kernels are not parsed correctly. Because we do not plan to change the C++ part of the profiler, we reparse them using the timeline_label attribute provided in RunMetadata.

Thank you for your time on this review.

@tensorflow-bot added the label "size:M" (CL Change Size: Medium) Feb 26, 2020
@googlebot

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@zhuzilin
Contributor Author

@googlebot I signed it!

@googlebot

CLAs look good, thanks!


@gbaned self-assigned this Feb 26, 2020
@gbaned requested a review from yifeif February 26, 2020 04:34
@yifeif
Contributor

yifeif commented Feb 27, 2020

Thanks for the PR, @zhuzilin. @prb12, would you mind taking a look?

@yifeif requested review from prb12 and removed the request for yifeif February 27, 2020 21:46
@gbaned added the label "awaiting review" (Pull request awaiting review) Mar 5, 2020
@xinan-jiang
Contributor

@prb12, have you had a chance to see this pull request?

@yifeif
Contributor

yifeif commented Mar 18, 2020

@xinan-jiang, sorry, it looks like @prb12 no longer works on TF. @qiuminxu, do you know if anybody on the performance side can take a look at this PR? Thank you!

@yifeif removed the request for review from prb12 March 18, 2020 16:03
@qiuminxu self-requested a review March 18, 2020 16:15
@yifeif added the label "ready to pull" (PR ready for merge process) Mar 18, 2020
Contributor

@qiuminxu left a comment


Thanks for sending the PR.

tensorflow/python/client/timeline.py: 5 review threads, all resolved (1 outdated)
@tensorflow-bot removed the label "ready to pull" (PR ready for merge process) Mar 19, 2020
@zhuzilin
Contributor Author


Thank you for your review. I have updated the file based on your suggestions. Could you take a second look? Thank you!

@gbaned requested a review from qiuminxu March 19, 2020 11:43
@gbaned added the labels "kokoro:force-run" (Tests on submitted change) and "ready to pull" (PR ready for merge process), and removed "ready to pull" (PR ready for merge process) Mar 23, 2020
@kokoro-team removed the label "kokoro:force-run" (Tests on submitted change) Mar 23, 2020
@zhuzilin
Contributor Author

@google-admin @googlebot Could you help me check what is wrong with the CLA check on this PR? The same GitHub account and email address were used on #37813, and that PR passed the CLA check.

@qiuminxu
Contributor

Can you run the check again and see if it passes?

@zhuzilin
Contributor Author

@googlebot I signed it!

@gbaned added the label "kokoro:force-run" (Tests on submitted change) Mar 30, 2020
@kokoro-team removed the label "kokoro:force-run" (Tests on submitted change) Mar 30, 2020
@zhuzilin
Contributor Author

zhuzilin commented Apr 1, 2020

@qiuminxu Could you help me find out what the migration error in the "import/copybara" check is? Thank you!

@qiuminxu
Contributor

qiuminxu commented Apr 1, 2020

@yifeif I couldn't see the details of the test error. Can someone help take a look?

@rthadur removed the label "ready to pull" (PR ready for merge process) Apr 7, 2020
@gbaned added the label "stat:awaiting tensorflower" (Status - Awaiting response from tensorflower) Apr 8, 2020
@gbaned
Contributor

gbaned commented Apr 22, 2020

@mihaimaruseac Can you please take a look at the cla/google test failure? Thanks!

@google-ml-butler added the labels "kokoro:force-run" (Tests on submitted change) and "ready to pull" (PR ready for merge process) Apr 22, 2020
@kokoro-team removed the label "kokoro:force-run" (Tests on submitted change) Apr 22, 2020
@gbaned added "ready to pull" (PR ready for merge process), and removed "stat:awaiting tensorflower" (Status - Awaiting response from tensorflower) and "ready to pull" (PR ready for merge process) Apr 23, 2020
@tensorflow-copybara merged commit 958d2c1 into tensorflow:master Apr 23, 2020
Labels
"cla: yes", "ready to pull" (PR ready for merge process), "size:M" (CL Change Size: Medium)