Apply default precision from bench models to XLA #6133
Conversation
This PR fixes a disparity in instruction precision between XLA and TorchInductor. For example, after a model calls torch.half() (or performs any other buffer/parameter cast), TorchInductor casts all non-fp16 floating-point types to fp16, and the program inputs correctly reflect this. XLA, on the other hand, needs additional environment variables to perform the cast during lowering. This PR reads the default precision that each TorchBench model sets for inference and training, and sets the appropriate environment variables.
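To make the difference concrete, here is a minimal, hypothetical sketch (assuming torch_xla's `XLA_USE_FP16` / `XLA_USE_BF16` lowering flags; the PR's actual handling may differ):

```python
# Minimal sketch, not the PR's code. Assumes torch_xla reads the
# XLA_USE_FP16 / XLA_USE_BF16 environment variables when lowering.
import os
import torch

model = torch.nn.Linear(16, 16)

# Inductor path: an eager cast is enough; fp32 params/buffers become fp16
# and the compiled program's inputs reflect the new dtype.
model = model.half()

# XLA path: the eager cast alone is not reflected at lowering; a flag must
# be set before the XLA computation is built so fp32 ops are lowered to fp16.
os.environ["XLA_USE_FP16"] = "1"
```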
36e7cf3 to 42d44eb
Thanks Greg! Please land this and I'll run another nightly with it today.
del benchmark
gc.collect()


def apply_default_precision_config(self, test, benchmark):
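For context, a minimal sketch of what such a helper might do is below. The attribute names (`DEFAULT_TRAIN_CUDA_PRECISION`, `DEFAULT_EVAL_CUDA_PRECISION`) and the environment variables are assumptions based on TorchBench conventions, not necessarily this PR's exact implementation:

```python
import os

def apply_default_precision_config(self, test, benchmark):
    # Pick the precision the TorchBench model declares for this phase
    # (assumed attribute names; the real ones may differ).
    attr = ("DEFAULT_TRAIN_CUDA_PRECISION" if test == "train"
            else "DEFAULT_EVAL_CUDA_PRECISION")
    precision = getattr(benchmark, attr, None)

    # Map the declared precision onto XLA's lowering flags (assumed env vars).
    if precision == "fp16":
        os.environ["XLA_USE_FP16"] = "1"
    elif precision == "bf16":
        os.environ["XLA_USE_BF16"] = "1"
```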
Thanks for finding this and adding it to the benchmarks.
I am wondering if this is the right place to fix this.
Is this something we can add in the dynamo bridge when lowering to XLA?
This is a workaround. See b/316126405
I wonder if we should add these things to benchmarking or if that is just more confusing overall. We should at least add a comment in these cases.