Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update on "Deprecate device-specific GradScaler autocast API"
# Motivation ## for `torch.amp.GradScaler`, - `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch. amp.GradScaler("cpu", args...)`. - `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`. So, we intend to depreate them and **strongly recommend** developer to use `torch.amp.GradScaler`. ## for `custom_fwd` and `custom_bwd`, this is a good solution to make the custom function run with or without effect even in an autocast-enabled region and can be shared by other backends, like CPU and XPU. So we generalize it to be device-agnostic and put them int `torch/amp/autocast_mode.py` and re-expose to `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`. # Additional Context Add UT to cover the deprecated warning. No need for more UTs to cover the functionality of `torch.amp.custom_f/bwd`, the existing UTs that previously covered the functionality of `torch.cuda.amp.custom_f/bwd` can cover them. To facilitate the review, we separate these code changes to two PRs. The first PR cover `torch.amp.GradScaler`. The follow-up covers `custom_fwd` and `custom_bwd`. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 mcarilli ptrblck leslie-fang-intel voznesenskym EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng [ghstack-poisoned]
- Loading branch information