Conversation

jjsjann123 (Collaborator)

named_modules() returns a generator, which is not subscriptable and causes the node support query to fail
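The root cause can be sketched in plain Python (using a hypothetical `fake_named_modules` stand-in, not the actual partitioner code): a generator does not support indexing, so a support query that does `named_modules()[name]` fails, and the fix is to materialize the pairs into a dict first.

```python
import types

def fake_named_modules():
    """Stand-in for nn.Module.named_modules(): yields (name, module) pairs."""
    yield "", object()
    yield "sub", object()

gen = fake_named_modules()
assert isinstance(gen, types.GeneratorType)

# Subscripting a generator raises TypeError; this is the failing query.
try:
    gen["sub"]
except TypeError as e:
    print(e)  # 'generator' object is not subscriptable

# The fix: materialize the pairs into a dict, then look up by name.
modules = dict(fake_named_modules())
submod = modules["sub"]
```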

@facebook-github-bot (Contributor) commented Jul 11, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit c0f077e (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.

@george-qi george-qi added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jul 11, 2022
import torch
from torch.fx import symbolic_trace

# Model, dtype, and device are defined earlier in the test.
inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device)
m = Model().to(dtype=dtype, device=device)

traced = symbolic_trace(m)
@SherlockNoMad (Contributor) Jul 12, 2022
This symbolic_trace will result in an FX graph with torch ops. You can print(traced) to check this. For this graph, .compile() would not take effect.

To get an FX graph with ATen ops (which the partitioner is based on), you need to use make_fx. See https://github.com/pytorch/pytorch/pull/81311/files#diff-de76da6bb7d45d8a243d50773b91340bc6feb177ad734ff4bdfba6def8e4c801R162 for an example.
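As a sketch of that suggestion (assuming a PyTorch build where make_fx is importable from torch.fx.experimental.proxy_tensor; the toy Model below is hypothetical, not the module from this PR's test):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class Model(torch.nn.Module):
    # Hypothetical stand-in for the module under test.
    def forward(self, x):
        return torch.relu(x) + 1.0

m = Model()
inp = torch.randn(2, 3, 4, 5)

# symbolic_trace would record torch-level calls (and call_module nodes);
# make_fx traces through to ATen-level ops, which the partitioner expects.
aten_graph = make_fx(m)(inp)
print(aten_graph.graph)  # nodes target aten ops such as aten.relu
```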

@jjsjann123 (Collaborator, Author)
Yeah, I understand that compiling this graph would not give us any fusion.

But currently this script fails with a runtime error. I don't think that is proper behavior when compiling a valid FX graph containing call_module nodes.

@SherlockNoMad (Contributor)

I see...
Can you add an explicit comment saying that this symbolic trace will not result in any fusion, but that it should pass the compiler without error?

@SherlockNoMad (Contributor) left a review comment:

LGTM, with a minor comment.

@jjsjann123 (Collaborator, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

@github-actions (Contributor)

Hey @jjsjann123.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@jjsjann123 jjsjann123 deleted the fx_nvfuser_patch branch July 14, 2022 18:55
facebook-github-bot pushed a commit that referenced this pull request Jul 15, 2022
Summary:
named_modules() returns a generator, which is not subscriptable and causes the node support query to fail

Pull Request resolved: #81258
Approved by: https://github.com/SherlockNoMad

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/cc67a92e749d076f4594410f2b23caed76aba80c

Reviewed By: DanilBaibak

Differential Revision: D37876448

Pulled By: DanilBaibak

fbshipit-source-id: 15c39ac0ac7d0f742a7d148cca82ac1633f39d94
Labels: cla signed, Merged, module: fx, open source, triaged