Update on "generalize custom_fwd&custom_bwd to be device-agnostic"
cc mcarilli ptrblck leslie-fang-intel jgong5 voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
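
For context on what the stack's title refers to: a minimal sketch of the generalized decorators, assuming the torch.amp.custom_fwd/custom_bwd variants that take an explicit device_type keyword. The MyMM autograd Function below is illustrative and not part of this commit:

```python
import torch

# Illustrative autograd Function; only the two decorators relate to the stack.
class MyMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def forward(ctx, a, b):
        # Floating-point inputs arrive cast to float16; autocast is
        # disabled inside forward when cast_inputs is given.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        # backward runs with the same autocast state that forward saw.
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)
```

The older torch.cuda.amp.custom_fwd/custom_bwd were hard-wired to CUDA; the device_type keyword is what lets the same decorators serve other backends.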
guangyey committed May 22, 2024
2 parents f8fc77d + df0198f commit fa5f5a5
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions test/test_torch.py

```diff
@@ -6152,10 +6152,7 @@ def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
     @onlyNativeDeviceTypes
     def test_grad_scaler_pass_itself(self, device):
         device = torch.device(device)
-        if device.type == "cuda":
-            GradScaler = partial(torch.amp.GradScaler, device="cuda")
-        else:
-            GradScaler = partial(torch.amp.GradScaler, device="cpu")
+        GradScaler = partial(torch.amp.GradScaler, device=device)
 
         class _PlaceHolderOptimizer(torch.optim.Optimizer):
             tester = self
```
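
The collapsed one-liner works because torch.amp.GradScaler takes the device as a constructor argument, so a single code path covers both backends. A minimal sketch of the device-agnostic pattern in a training step, assuming a toy linear model and SGD optimizer (none of which appear in this commit):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device=device)  # same line for cuda and cpu

inputs = torch.randn(4, 8, device=device)
targets = torch.randn(4, 1, device=device)

with torch.autocast(device_type=device):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales grads, skips step on inf/nan
scaler.update()                # adjusts the scale for the next iteration
```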
