Skip to content

Conversation

StrongerXi
Copy link
Contributor

@StrongerXi StrongerXi commented Oct 8, 2024

See test_inline_closure_returned_by_another_function_and_captures and #136814 for more context.

In #90286, we introduced an optimization so that for captured cells that are unmodified during a Dynamo trace, UserFunctionVariable will represent them as variable of the cell's actual value, rather than a NewCellVariable.

Later on we introduced more mechanisms to model such cells across function calls (#104222), and across function calls where NestedUserFunctionVariable::bind_args need to look up further in the parent frames (#106491) to find these cells' values.

This patch removes InlinedClosureVariable in favor of a simpler modelling, which is also more consistent with what was introduced in #90286, i.e., just model these cells as their contents, in symbolic_locals.

This fixes #136814 because resolution of InlinedClosureVariable to the underlying cell content value happens in
NestedUserFunctionVariable::bind_args, which requires Dynamo to have the value in scope at the function call site (when Dynamo does inlining), but's not always the case (as the test case shows). However, if we model the cells in symbolic_locals, we never need such resolution, and the values are directly stored into the NestedUserFunctionVariable::closure upon the function creation, at which point Dynamo always has the cell value in symbolic_locals for look up.

Fixes #136814.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

Copy link

pytorch-bot bot commented Oct 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137510

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c788357 with merge base 4aed81c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@StrongerXi StrongerXi force-pushed the fix-inlined-closure-var-resolution branch from 21301cc to cf9336c Compare October 8, 2024 20:58
@StrongerXi
Copy link
Contributor Author

1 line fix, in NestedUserFunctionVariable::bind_args, in the regular path without the cell->content optimization, variables in NestedUserFunctionVariable::closure could be ClosureVariable as well, not just NewCellVariable.

@StrongerXi
Copy link
Contributor Author

Rebasing to fix failure from another reverted PR #137447.

…nother function

See #136814 for more context.

In #90286, we introduced an optimization so that for captured cells that
are unmodified during a Dynamo trace, `UserFunctionVariable` will
represent them as variable of the cell's actual value, rather than a
`NewCellVariable`.

Later on we introduced more mechanisms to model such cells across
function calls (#104222), and across function calls where
`NestedUserFunctionVariable::bind_args` need to look up further in the
parent frames (#106491) to find these cells' values.

This patch removes `InlinedClosureVariable` in favor of a simpler
modelling which is also more consistent with what was introduced in #90286,
i.e., just model these cells as their contents, in `symbolic_locals`.

This fixes #136814 because resolution of `InlinedClosureVariable` to the
underlying cell content value happens in
`NestedUserFunctionVariable::bind_args`, which requires Dynamo to have
the value in scope at the function call site (when Dynamo starts
inlining), but's not always the case (as the test case shows). However,
if we model the cells in `symbolic_locals`, we never need such
resolution, and the values are directly stored into the
`NestedUserFunctionVariable::closure` upon the function creation, at
which point Dynamo always has the cell value in `symbolic_locals` for
look up.

Fixes #136814.
@StrongerXi
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the fix-inlined-closure-var-resolution branch November 9, 2024 02:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Dynamo inlining errors with some calls to nested functions that use captured variables

3 participants