Option to hard fail for cudagraphs and/or more cudagraph logging #124506

Open
xmfan opened this issue Apr 19, 2024 · 1 comment

Comments

xmfan commented Apr 19, 2024

🚀 The feature, motivation and pitch

There is no way to know how much of a torch.compile'd model has actually been cudagraphed, since we silently fall back ("skipping cudagraphs due to ..."). Today's debugging workflow relies on grepping the logs for cudagraph messages and matching them against graph id logs.

A hard-fail config option for users who want to guarantee their model is always fully cudagraphed, or a final cudagraph report logged at the end of a run, would help; a rough, illustrative sketch of the hard-fail option follows the log excerpt below.

> TORCH_LOGS="+cudagraphs" wp python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-memory --use-warm-peak-memory --compiled-autograd --only hf_Whisper
loading model: 0it [00:08, ?it/s]
cuda train hf_Whisper                         
I0419 11:12:04.474000 140576568861824 torch/_inductor/cudagraph_trees.py:360] [__cudagraphs] recording cudagraph tree for graph without symints
V0419 11:12:04.669000 140576568861824 torch/_inductor/cudagraph_trees.py:1984] [__cudagraphs] Running warmup of function 0
I0419 11:12:08.337000 140576568861824 torch/_inductor/cudagraph_trees.py:360] [__cudagraphs] recording cudagraph tree for graph without symints
V0419 11:12:08.337000 140576568861824 torch/_inductor/cudagraph_trees.py:1984] [__cudagraphs] Running warmup of function 1
I0419 11:12:11.532000 140576568861824 torch/_inductor/cudagraph_trees.py:360] [__cudagraphs] recording cudagraph tree for graph without symints
V0419 11:12:11.532000 140576568861824 torch/_inductor/cudagraph_trees.py:1984] [__cudagraphs] Running warmup of function 2
I0419 11:12:12.893000 140576568861824 torch/_inductor/cudagraph_trees.py:360] [__cudagraphs] recording cudagraph tree for graph without symints
V0419 11:12:12.893000 140576568861824 torch/_inductor/cudagraph_trees.py:1984] [__cudagraphs] Running warmup of function 3
I0419 11:12:13.406000 140576568861824 torch/_inductor/cudagraph_trees.py:360] [__cudagraphs] recording cudagraph tree for graph without symints
V0419 11:12:13.406000 140576568861824 torch/_inductor/cudagraph_trees.py:1984] [__cudagraphs] Running warmup of function 4
skipping cudagraphs due to skipping cudagraphs due to cpu device
W0419 11:12:28.727000 140576568861824 torch/_logging/_internal.py:1016] [18/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
I0419 11:12:42.339000 140576568861824 torch/_inductor/cudagraph_trees.py:360] [__cudagraphs] recording cudagraph tree for graph without symints
V0419 11:12:42.340000 140576568861824 torch/_inductor/cudagraph_trees.py:1984] [__cudagraphs] Running warmup of function 5
V0419 11:12:42.363000 140576568861824 torch/_inductor/cudagraph_trees.py:1947] [__cudagraphs] Recording function 0 of graph recording id 0
V0419 11:12:42.756000 140576568861824 torch/_inductor/cudagraph_trees.py:1947] [__cudagraphs] Recording function 1 of graph recording id 1
V0419 11:12:43.127000 140576568861824 torch/_inductor/cudagraph_trees.py:1947] [__cudagraphs] Recording function 2 of graph recording id 2
V0419 11:12:43.482000 140576568861824 torch/_inductor/cudagraph_trees.py:1947] [__cudagraphs] Recording function 3 of graph recording id 3
V0419 11:12:43.843000 140576568861824 torch/_inductor/cudagraph_trees.py:1947] [__cudagraphs] Recording function 4 of graph recording id 4
skipping cudagraphs due to skipping cudagraphs due to cpu device
V0419 11:12:57.067000 140576568861824 torch/_inductor/cudagraph_trees.py:1947] [__cudagraphs] Recording function 5 of graph recording id 5
skipping cudagraphs due to skipping cudagraphs due to cpu device
memory: eager: 1.12 GB, dynamo: 0.76 GB, ratio: 1.47
running benchmark: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 30.27it/s]
1.360x
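
For illustration only, here is a minimal sketch of the hard-fail idea. The strict flag named below is hypothetical and does not exist in Inductor today; only torch.compile(mode="reduce-overhead"), which enables cudagraph trees, is real.

import torch

# Hypothetical knob (does NOT exist today): raise instead of silently logging
# "skipping cudagraphs due to ..." and running the graph without cudagraphs.
# torch._inductor.config.triton.cudagraphs_hard_fail = True

model = torch.nn.Linear(8, 8).cuda()
compiled = torch.compile(model, mode="reduce-overhead")  # "reduce-overhead" enables cudagraph trees

x = torch.randn(4, 8, device="cuda")
out = compiled(x)  # with a hard-fail option, any skipped cudagraph here would raise

# The alternative proposal, a final cudagraph report, could instead summarize at exit
# how many compiled graphs were cudagraphed vs. skipped, along with the skip reasons.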

Alternatives

No response

Additional context

No response

cc @mcarilli @ezyang @eellison @peterbell10 @msaroufim @bdhirsh @anijain2305 @chauhang @BoyuanFeng

xmfan added the triaged, module: cuda graphs, and oncall: pt2 labels on Apr 19, 2024

ezyang commented Apr 19, 2024

tlparse could help here! cc @jamesjwu
