laithsakka
Contributor
@laithsakka laithsakka commented Apr 22, 2025

Stack from ghstack (oldest at bottom):

address #151879.

The runtime assertion code generation tracks defined unbacked symints; when all unbacked symints that a given assert depends on have been seen (defined), the runtime assertion is emitted.
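The emission scheme described above can be sketched roughly as follows; `ras_by_symbol` echoes a name used later in this thread, but the surrounding scaffolding (`RuntimeAssert`, the `emitted` list) is hypothetical:

```python
# Hypothetical sketch of deferred runtime-assert emission: each assert is
# keyed under one of its unbacked symbols, and is emitted only once every
# symbol it mentions has been defined.
class RuntimeAssert:
    def __init__(self, free_symbols):
        self.free_symbols = set(free_symbols)

def emit_ready_asserts(defined, ras_by_symbol, emitted):
    """Move every assert whose symbols are all in `defined` to `emitted`."""
    for sym in [s for s in list(ras_by_symbol) if s in defined]:
        for ra in ras_by_symbol.pop(sym):
            if ra.free_symbols <= defined:
                emitted.append(ra)
            else:
                # Re-key under a still-undefined symbol and wait.
                missing = next(iter(ra.free_symbols - defined))
                ras_by_symbol.setdefault(missing, []).append(ra)
```

Under this scheme, an assert over `{u0, u1}` stays pending while only `u0` is defined and fires as soon as `u1` is also seen, which is exactly why an input symbol that is never marked as defined blocks its asserts forever.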

Before this PR, unbacked symints that are inputs to the graph were not detected as defined, so dependent assertions were never emitted.
One issue with the fix is emitting runtime assertions for the backward. Before this PR, the backward would try to regenerate all the assertions again; not considering input-defined unbacked symints acted as a proxy to avoid generating assertions that should already have been emitted in the forward (based on the assumption that the unbacked symint comes from a forward output and would have been defined there).
Now that I removed that check, we start failing. I could say "for the backward, do not consider input-defined unbacked symints", but that sounds risky and incomplete: what if an assertion depends on both a forward .item() call and a backward .item() call?
The proposed fix is,
(1) when we finish forward codegen we store the remaining runtime assertions and we only
try to emit those in backward.
(2) after backward we ensure all runtime assertions are emitted.
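Steps (1) and (2) above can be sketched minimally as follows, with illustrative names (`codegen_pass`, `lower_joint`) rather than the PR's actual API, and asserts modeled as frozensets of the symbols they depend on:

```python
# Hypothetical sketch: the forward pass emits what it can and hands the
# remaining runtime asserts to the backward pass, which must finish them.
def codegen_pass(defined_syms, ras):
    emitted, remaining = [], []
    for ra in ras:
        # An assert can be emitted once all symbols it mentions are defined.
        (emitted if ra <= defined_syms else remaining).append(ra)
    return emitted, remaining

def lower_joint(fw_syms, bw_syms, all_ras):
    fw_emitted, remaining = codegen_pass(fw_syms, all_ras)            # (1)
    bw_emitted, remaining = codegen_pass(fw_syms | bw_syms, remaining)
    assert not remaining, "unlowered runtime asserts after backward"  # (2)
    return fw_emitted, bw_emitted
```

Note how an assert spanning a forward-defined symbol and a backward-defined one naturally lands in the backward pass here, which is the case the old "regenerate everything" approach could miss.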

My only concern is whether caching the forward but not the backward can interfere here: if we cache-hit the forward and cache-miss the backward, we won't know which assertions remain to lower.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Apr 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151919

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Unrelated Failure

As of commit 39dfbc3 with merge base 264e8fb:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@laithsakka laithsakka changed the title fix handling unbacked input during graph lowering of runtime asserts Fix handling unbacked input during graph lowering of runtime asserts Apr 22, 2025
…me asserts"


address #151879.
The runtime assertion code generation tracks defined unbacked symints; when all unbacked symints for a
given assert are seen (defined), the runtime assertion is emitted.
Before this PR, unbacked symints that are inputs to the graph were not detected as defined, so dependent assertions
were never emitted.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Apr 22, 2025
laithsakka added a commit that referenced this pull request Apr 25, 2025
@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Apr 25, 2025
@laithsakka laithsakka requested a review from bdhirsh April 25, 2025 19:46
@laithsakka laithsakka changed the title Fix handling unbacked input during graph lowering of runtime asserts Fix: input defined unbacked symbols skipped during lowering. Apr 25, 2025
return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype)

if "val" in node.meta:
if node.op == "get_attr" or node.target is torch.ops.aten._assert_scalar.default:
Contributor Author
For torch.ops.aten._assert_scalar.default, node.meta["val"] could be None, so we check this first.
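The ordering the comment describes can be sketched like so; `node_val_nbytes` and the string `target` comparison are stand-ins for the real FX-node walk, not the PR's actual code:

```python
# Hypothetical stand-in for the check above: for get_attr nodes and
# aten._assert_scalar.default nodes, meta["val"] may be None, so bail out
# on the op/target test before ever dereferencing the value.
def node_val_nbytes(node, tensor_nbytes):
    if "val" not in node.meta:
        return 0
    if node.op == "get_attr" or node.target == "aten._assert_scalar.default":
        return 0  # meta["val"] may be None here; do not touch it
    return tensor_nbytes(node.meta["val"])
```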

# We are going to start code generating runtime asserts, so make sure
# you don't start adding new ones in the lowering process
graph.freeze_runtime_asserts()

Contributor Author

Freezing should happen right before we start codegen of the forward, not after;
codegen starts with graph.run(*example_inputs).

remaining_ras = graph.ras_by_symbol
# if lowering was done for backward, assert that all runtime asserts have been lowered.
if V.graph.is_backward:
assert len(remaining_ras) == 0
Contributor Author
I hope this won't discover more bugs of unlowered assertions.

@laithsakka laithsakka marked this pull request as draft April 25, 2025 22:57
@laithsakka laithsakka changed the title Fix: input defined unbacked symbols skipped during lowering. [WIP]: track remaining runtime asserts for backward codegen instead of trying to regenerate all Apr 25, 2025
@laithsakka
Contributor Author

I will go with another short-term fix and keep this for the future as the proper fix.

laithsakka added a commit that referenced this pull request Apr 28, 2025
…en for runtime asserts"


So when we use mark_unbacked, the forward graph will have an unbacked input SymInt.
Right now, the runtime assertions that use it are never code-generated.

The way we generate runtime assertions is by emitting them once all the defined unbacked symbols used
in them have been emitted.

We previously skipped placeholders, because for backward we have a wacky approach where we
skip input-defined unbacked symbols and assume assertions that use them were already emitted
in the forward, and we try to emit all runtime assertions again otherwise. See [Note [Backwards runtime asserts]].

With that, we would end up only emitting the assertions that depend on things defined solely in the backward, but we could miss checks that span inputs defined in both the backward and the forward (i.e., one symbol defined in the forward and passed as input to the backward, and another defined in the backward). This is not ideal; an ideal approach could be something like
#151919, but it requires more work.

That said, there is no reason to skip placeholders for the forward, so for now we do that.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
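The short-term fix above (count placeholder-defined unbacked symbols in the forward but keep skipping them in the backward) can be sketched as follows; `input_defined_unbacked` and the injected `free_unbacked_symbols` callback are illustrative stand-ins (the real helper lives in torch.fx.experimental.symbolic_shapes):

```python
# Hypothetical sketch: collect unbacked symbols defined by graph inputs.
# Forward placeholders now count as defining their unbacked symbols; the
# backward still skips its placeholders (the short-term compromise).
def input_defined_unbacked(nodes, free_unbacked_symbols, is_backward):
    defined = set()
    if is_backward:
        return defined  # backward inputs are still ignored for now
    for node in nodes:
        if node.op != "placeholder":
            continue
        val = node.meta.get("val")
        if val is not None:
            defined |= free_unbacked_symbols(val)
    return defined
```

Seeding the emission pass with this set is what lets asserts over mark_unbacked inputs fire in the forward, while leaving the backward behavior unchanged.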
laithsakka added a commit that referenced this pull request May 1, 2025
…en for runtime asserts"


So when we use mark_unbacked, the graph will have an unbacked input SymInt. Right now,
deferred runtime assertions that use it are never generated.

This PR changes that: in the forward graph we consider those inputs and generate the corresponding
runtime assertions for them. We still ignore them for the backward, which is not ideal.

The way we generate runtime assertions is by emitting them once all the defined unbacked symbols used
in them have been seen.

We previously skipped placeholders, because for backward we have a wacky approach where we
ignore input-defined unbacked symbols and assume assertions that use them were already emitted
in the forward, and we try to emit all other runtime assertions again. See [Note [Backwards runtime asserts]].

Doing that, we end up only emitting the runtime assertions that depend on things defined solely in the backward, but we could miss checks that span inputs defined in both the backward and the forward (i.e., one symbol defined in the forward and passed as input to the backward, and another defined in the backward). This is not ideal; an ideal approach could be something like #151919, but it requires more work.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request May 2, 2025
…ime asserts (#152231)


Pull Request resolved: #152231
Approved by: https://github.com/aorenste
@eellison eellison removed their request for review June 4, 2025 12:54
@github-actions
Contributor

github-actions bot commented Aug 3, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 3, 2025
@github-actions github-actions bot closed this Sep 2, 2025
@github-actions github-actions bot deleted the gh/laithsakka/160/head branch October 3, 2025 02:07
