Update base for Update on "make torch.amp.autocast more generic"
# Motivation
As discussed in [#124479](#124479), `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, because `torch.amp.autocast` does not apply the per-backend default `dtype` (`torch.bfloat16` for CPU and `torch.float16` for CUDA). We would like `torch.amp.autocast` to be more generic so that developers/customers can write device-agnostic code, since there is not enough justification to add a device-specific `torch.xxx.amp.autocast` for every device backend.

# Solution
When `None` is passed as `dtype`, we use `torch.get_autocast_dtype` to fetch the default dtype for the given backend. In addition, `torch.get_autocast_dtype` needs to be supported in the JIT path for backward compatibility. A minimal sketch of the fallback behavior follows.
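
For illustration only, a minimal sketch of the fallback logic (the helper name `resolve_autocast_dtype` is hypothetical, not the actual code in `torch/amp/autocast_mode.py`):

```python
import torch

# Hypothetical helper sketching the fallback: when the caller passes
# dtype=None, resolve it via the backend's registered autocast dtype.
def resolve_autocast_dtype(device_type: str, dtype=None):
    if dtype is None:
        # By default this returns torch.bfloat16 for "cpu" and
        # torch.float16 for "cuda".
        dtype = torch.get_autocast_dtype(device_type)
    return dtype
```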

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
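
A quick usage sketch (assumes a CUDA-capable build; with `dtype` left unset, the context manager picks the CUDA default, `torch.float16`):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

# Device-agnostic autocast: no dtype given, so the CUDA default
# (torch.float16) is used, matching torch.cuda.amp.autocast.
with torch.amp.autocast(device_type="cuda"):
    y = model(x)

print(y.dtype)  # torch.float16
```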

cc mcarilli ptrblck leslie-fang-intel jgong5 ezyang msaroufim bdhirsh anijain2305 chauhang voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
guangyey committed May 7, 2024 · 1 parent 72a8d86 · commit 42fb93f