-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[WIP]: track remaining runtime time asserts for backward coddgen instead of trying to regenerate all #151919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151919
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 1 Unrelated FailureAs of commit 39dfbc3 with merge base 264e8fb ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…me asserts" address #151879. The runtime assertion code generation tracks defined unbacked symints, when all defined unabacked symints for a given assert are seen(defined) the runtime assertion is emitted. Before this PR, unbacked symints that are input to the graph were not detected as defined, hence dependent assertions used to never be triggered. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…me asserts" address #151879. The runtime assertion code generation tracks defined unbacked symints, when all defined unabacked symints for a given assert are seen(defined) the runtime assertion is emitted. Before this PR, unbacked symints that are input to the graph were not detected as defined, hence dependent assertions used to never be triggered. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype) | ||
|
||
if "val" in node.meta: | ||
if node.op == "get_attr" or node.target is torch.ops.aten._assert_scalar.default: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for torch.ops.aten._assert_scalar.default node.meta["val"] could be none so we check this first.
# We are going to start code generating runtime asserts, so make sure | ||
# you don't start adding new ones in the lowering process | ||
graph.freeze_runtime_asserts() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
freezing should happen righe before we start codegen of forward and not after
codegen start with graph.run(*example_inputs)
remaining_ras = graph.ras_by_symbol | ||
# if lowering was done for backward assert that all runtime asserts has been lowered. | ||
if V.graph.is_backward: | ||
assert len(remaining_ras) == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hope this wont disxcover more bugs of unlowered assertions
I wil go with another short term fix and keep this for the future as the proper fix. |
…en for runtime asserts" So when we use mark_unbacked the graph forward graph will have an unbacked inputs symInt. right now the runtime assertions that uses that is never codegen. The way we generate runtime assertion is by emitting them when all the defined unbacked symbols used in them are emitted. We previously skipped placeholder, because for backward we have a wacky approach were we skip input defined unbacked symbols and assumes assertions that uses them are already emitted in forward and we try to emit all runtime assertions again otherwise. see [Note [Backwards runtime asserts] with that we would end up only emitting those that depends on things defined soleley in backward, but we could miss checks that spans inputs defined in both backward and forward, i.e one symbol defined in forward passed as input to backward., and another that is defined in backward.) .This is not ideal an ideal approach could be something like this #151919 but it require more work . That said, there is no reason though to skip place holders for forward so for now doing that. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…en for runtime asserts" So when we use mark_unbacked the graph will have an unbacked inputs symInt. Right now, deferred runtime assertions that uses those is never generated. This PR changes that, such that in the forward graph we consider those and generate the corresponding runtime assertions of them. We still ignore them for backward which is not ideal The way we generate runtime assertion is by emitting them when all the defined unbacked symbols used in them are seen. We previously skipped placeholder, because for backward we have a wacky approach were we ignore input defined unbacked symbols and assumes assertions that uses them are already emitted in forward and we try to emit all other runtime assertions again. see [Note [Backwards runtime asserts] Doing that we ends up only emitting the runtime assertions that depends on things defined solely in backward, but we could miss checks that spans inputs defined in both backward and forward, i.e one symbol defined in forward passed as input to backward., and another that is defined in backward.) .This is not ideal an ideal approach could be something like this #151919 but it require more work . cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…ime asserts (#152231) So when we use mark_unbacked the graph will have an unbacked inputs symInt. Right now, deferred runtime assertions that uses those is never generated. This PR changes that, such that in the forward graph we consider those and generate the corresponding runtime assertions of them. We still ignore them for backward which is not ideal The way we generate runtime assertion is by emitting them when all the defined unbacked symbols used in them are seen. We previously skipped placeholder, because for backward we have a wacky approach were we ignore input defined unbacked symbols and assumes assertions that uses them are already emitted in forward and we try to emit all other runtime assertions again. see [Note [Backwards runtime asserts] Doing that we ends up only emitting the runtime assertions that depends on things defined solely in backward, but we could miss checks that spans inputs defined in both backward and forward, i.e one symbol defined in forward passed as input to backward., and another that is defined in backward.) .This is not ideal an ideal approach could be something like this #151919 but it require more work . Pull Request resolved: #152231 Approved by: https://github.com/aorenste
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Stack from ghstack (oldest at bottom):
address #151879.
The runtime assertion code generation tracks defined unbacked symints, when all defined unabacked summits for a given assert are seen(defined), the runtime assertion is emitted.
Before this PR, unbacked symints that are input to the graph were not detected as defined, hence dependent assertions used to never be triggered.
One issue with the fix is handling emitting runtime assertion for backward .Before this PR, backward will try to regenerate all the assertions again, not considering input defined unabcked symints used to operate as a proxy to avoid generating assertions that should have been defined in the forward (based on the assumption that the unbacked symint is coming from forward output and it would have been defined).
However now as i removed that check we start failing. While I can say (for backward do not consider input defined unbacked symint, this sounds risky and not a complete solution)? What if an assertion depends on both forward .item() call and backward .item() call?
The proposed fix is,
(1) when we finish forward codegen we store the remaining runtime assertions and we only
try to emit those in backward.
(2) after backward we ensure all runtime assertions are emitted.
My only concern is if caching forward and not backward can interfere here. basically if we cache hit forward and miss backward! we wont know what is remaining to lower.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben