From 9d20af50608b146fe1c3296210a05cd8e4c60af2 Mon Sep 17 00:00:00 2001 From: albanD Date: Wed, 6 Jul 2022 18:08:49 +0000 Subject: [PATCH] remove overly restrictive checks for cudagraph (#80881) Finish fixing https://github.com/pytorch/pytorch/issues/80809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80881 Approved by: https://github.com/jbschlosser --- torch/optim/adamw.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 9b0dace73fba..546586f6c22e 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -256,8 +256,6 @@ def _single_tensor_adamw(params: List[Tensor], if capturable: assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors." - else: - assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors." # update step step_t += 1 @@ -335,9 +333,6 @@ def _multi_tensor_adamw(params: List[Tensor], if capturable: assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \ "If capturable=True, params and state_steps must be CUDA tensors." - else: - assert all(not step.is_cuda for step in state_steps), \ - "If capturable=False, state_steps should not be CUDA tensors." if maximize: grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]