Update on "make torch.amp.autocast more generic"
# Motivation
As discussed in [#124479](#124479), `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, because `torch.amp.autocast` does not apply the default `dtype` for CPU (`torch.bfloat16`) and CUDA (`torch.float16`) respectively. We would like `torch.amp.autocast` to be more generic so that developers and customers can write device-agnostic code, since there are not enough reasons to add a device-specific `torch.xxx.amp.autocast` for each device backend.

# Solution
When `None` is passed to `dtype`, we should use `torch.get_autocast_dtype` to pick up the default dtype for the given backend. Meanwhile, `torch.get_autocast_dtype` also needs to be supported in the JIT path for backward compatibility (BC).
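
A minimal sketch of the dtype-resolution step described above (not the in-tree implementation); it assumes a PyTorch build that exposes `torch.get_autocast_dtype`, and the helper name `resolve_autocast_dtype` is purely illustrative:

```python
from typing import Optional

import torch


def resolve_autocast_dtype(device_type: str, dtype: Optional[torch.dtype] = None) -> torch.dtype:
    # When the caller does not pass a dtype, fall back to the autocast dtype
    # registered for that backend, e.g. torch.bfloat16 for 'cpu' and
    # torch.float16 for 'cuda'.
    if dtype is None:
        dtype = torch.get_autocast_dtype(device_type)
    return dtype


print(resolve_autocast_dtype("cpu"))  # torch.bfloat16 unless reconfigured
```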

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
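
A usage sketch of that equivalence (assumes a CUDA-enabled build; tensor names and shapes are arbitrary):

```python
import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

# Device-agnostic form: dtype now defaults to the CUDA autocast dtype (torch.float16).
with torch.amp.autocast(device_type="cuda"):
    out_generic = torch.mm(a, b)

# Device-specific form it is expected to match.
with torch.cuda.amp.autocast():
    out_cuda = torch.mm(a, b)

assert out_generic.dtype == out_cuda.dtype == torch.float16
```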

cc mcarilli ptrblck leslie-fang-intel jgong5 ezyang msaroufim bdhirsh anijain2305 chauhang voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
guangyey committed May 7, 2024
2 parents b00a0bd + 42fb93f commit 73d9468
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/test_jit_autocast.py
@@ -36,19 +36,19 @@ def tearDown(self):
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_generic_jit_autocast(self):
         @torch.jit.script
-        def fn_cuda(a, b):
+        def fn_cuda_autocast(a, b):
             with autocast():
                 x = torch.mm(a, b)
                 y = torch.sum(x)
                 return x, y
 
         @torch.jit.script
-        def fn_generic(a, b):
+        def fn_generic_autocast(a, b):
             with torch.amp.autocast(device_type='cpu'):
                 x = torch.mm(a, b)
                 y = torch.sum(x)
                 return x, y
-        self.assertEqual(fn_cuda(self.a_fp32, self.b_fp32), fn_generic(self.a_fp32, self.b_fp32))
+        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))
 
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_minimal(self):