This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Conversation

@yushangdi (Contributor) commented Jul 27, 2022

Creates an FX graph of the buffers generated by lowering.

With this patch you can run, for example:

INDUCTOR_SCHEDULER_GRAPH=1 python benchmarks/torchbench.py --training --devices=cuda --inductor --skip-accuracy-check -n 1 --isolate -k hf_Bert

to dump the forward and backward graphs of the compute buffers of hf_Bert. The resulting SVG files will be in torchbenchmark/.

The dumped files are named compute_buffer_{num}.svg.

Dumping is also quite slow, so for large models like hf_Bert you may want to reduce the number of layers to a small number.

The dumped graph looks like this: https://www.svgviewer.dev/s/PpmacAjw

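Conceptually, the hook looks something like the sketch below (illustrative, not the exact diff; draw_graph is the functorch helper the patch imports, while maybe_dump_graph and its arguments are hypothetical names):

import os

from functorch._src.partitioners import draw_graph

def maybe_dump_graph(fx_graph_module, index):
    # Opt in via INDUCTOR_SCHEDULER_GRAPH=1; draw_graph renders the FX
    # graph, yielding compute_buffer_{index}.svg in the working directory.
    if os.environ.get("INDUCTOR_SCHEDULER_GRAPH") == "1":
        draw_graph(fx_graph_module, f"compute_buffer_{index}")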

@yushangdi (Contributor, Author) commented Jul 27, 2022

Changed to be gated by a config flag instead of using a whole different path, per the request.
Difference from PR #635: we now use the groups already computed in the scheduler nodes instead of recomputing them for dumping.

@yushangdi requested a review from jansel on July 27, 2022 17:23
@Chillee (Contributor) left a comment

Mostly LGTM, although I defer to Jason if he has any strong opinions on how debug logging should be structured.


import numpy as np
import torch
from functorch._src.partitioners import draw_graph
Contributor:

Might want to move this import into the function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved
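
i.e., roughly this pattern (a sketch, not the exact code):

def create_fx_graph(nodes, fname, print_graph=False):
    # Imported inside the function so the functorch dependency is only
    # touched on the debug path; module import time is unaffected.
    from functorch._src.partitioners import draw_graph
    ...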

assert False, node
self.name_to_node = {node.get_name(): node for node in self.nodes}

if bool(os.environ.get("INDUCTOR_DEBUG", False)):
Contributor:

Rename this to INDUCTOR_SCHEDULER_GRAPH=1 or something like that for now.

We'll need a better system eventually (cc: @jansel ).

Contributor Author:

renamed
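
Worth noting in passing: bool(os.environ.get(...)) is truthy for any non-empty string, including "0". A stricter check, sketched here as an assumption rather than what actually landed, could be:

import os

def env_flag(name):
    # Only explicit "1"/"true"/"yes" enable the flag; bool() on the raw
    # env string would also treat "0" or "false" as enabled.
    return os.environ.get(name, "").strip().lower() in ("1", "true", "yes")

if env_flag("INDUCTOR_SCHEDULER_GRAPH"):
    ...  # dump the graph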

self.name_to_node = {node.get_name(): node for node in self.nodes}

if bool(os.environ.get("INDUCTOR_DEBUG", False)):
    global graph_dump_index
Contributor:

So... we only create a single scheduler per graph we pass to inductor. I'd propose instead that we make a global variable that anybody can access for the purposes of logging.

Contributor Author:

Changed; we also need this PR: pytorch/pytorch#82368
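
The proposed pattern is roughly this (a sketch with illustrative names):

graph_dump_index = 0  # module-level state any logging site can import

def next_dump_name(prefix="compute_buffer"):
    # Bumping a shared counter keeps file names unique even though one
    # Scheduler is created per graph passed to inductor.
    global graph_dump_index
    name = f"{prefix}_{graph_dump_index}"
    graph_dump_index += 1
    return name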

return func1


def create_fx_graph(nodes, fname, print_graph=False):
Contributor:

By the way, let's move this code to utils.py

Contributor Author:

Resolved offline. We'll keep this here, because moving it to utils.py might cause a circular dependency in the future.

Contributor:

Also, let's rename this to create_fx_from_buffers

Contributor Author:

renamed
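
For intuition, here is a self-contained toy of what a buffer-to-FX conversion can look like: one FX node per buffer, with edges following each buffer's reads. This is illustrative only, not the PR's implementation, and the (name, deps) input format is an assumption:

import torch
import torch.fx as fx

def _buffer(*inputs):
    # Dummy call target so every buffer has something to "call".
    return inputs

def create_fx_from_buffers(buffers):
    # buffers: list of (name, dep_names) pairs in topological order.
    graph = fx.Graph()
    env = {}
    for name, deps in buffers:
        if not deps:
            env[name] = graph.placeholder(name)
        else:
            args = tuple(env[d] for d in deps)
            env[name] = graph.create_node("call_function", _buffer, args, {}, name=name)
    graph.output(env[buffers[-1][0]])
    return fx.GraphModule(torch.nn.Module(), graph)

gm = create_fx_from_buffers([("arg0", []), ("buf0", ["arg0"]), ("buf1", ["buf0", "arg0"])])
print(gm.graph)  # the GraphModule can then be rendered with functorch's draw_graph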

@Chillee (Contributor) left a comment

Other than one final note, LGTM.

Not sure if anybody else has any opinions :)

(might want to wait until other folks chime in)

cc: @anijain2305 @ngimel @desertfire

@yushangdi requested a review from ngimel on July 28, 2022 00:08
from functorch._src.aot_autograd import get_graph_being_compiled

graph_name = get_graph_being_compiled()
create_fx_graph(self.nodes, graph_name, print_graph=True)
Reviewer:

not sure you want print_graph=True here?

Contributor Author:

As @Chillee said below, this line is only called when we are in debugging mode, so I think printing out the graph is probably helpful?


if print_graph:
    print(graph)
print("starting creating module")
Reviewer:

stray prints?

Contributor:

This is within create_fx_graph, which is for debugging purposes, so I think it's fine.

@yushangdi merged commit 5e32ac2 into main on Jul 28, 2022
anijain2305 added a commit that referenced this pull request Jul 28, 2022
yushangdi pushed a commit that referenced this pull request Jul 28, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 28, 2022
### Description
Add utilities to get the graph being compiled for debugging purposes.

Used in this PR: pytorch/torchdynamo#665

Pull Request resolved: #82368
Approved by: https://github.com/Chillee
facebook-github-bot pushed a commit to pytorch/functorch that referenced this pull request Aug 1, 2022
Summary: same description as pytorch/pytorch#82368 above.
X-link: pytorch/pytorch#82368
Approved by: https://github.com/Chillee

Reviewed By: osalpekar

Differential Revision: D38290516

Pulled By: yushangdi

fbshipit-source-id: 77ba38176683e27450307f23740cfc74725e8311
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Aug 1, 2022
Summary: same description as pytorch/pytorch#82368 above.
Pull Request resolved: #82368
Approved by: https://github.com/Chillee

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/46eeb78c5694fe5d2ba3cb13268abcfc3d53997d

Reviewed By: osalpekar

Differential Revision: D38290516

Pulled By: yushangdi

fbshipit-source-id: 77ba38176683e27450307f23740cfc74725e8311