Skip to content

Commit

Permalink
remove overly restrictive checks for cudagraph (#80881)
Browse files Browse the repository at this point in the history
Finish fixing #80809
Pull Request resolved: #80881
Approved by: https://github.com/jbschlosser
  • Loading branch information
albanD authored and pytorchmergebot committed Jul 6, 2022
1 parent 393f7f6 commit 9d20af5
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions torch/optim/adamw.py
Expand Up @@ -256,8 +256,6 @@ def _single_tensor_adamw(params: List[Tensor],

if capturable:
assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."
else:
assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."

# update step
step_t += 1
Expand Down Expand Up @@ -335,9 +333,6 @@ def _multi_tensor_adamw(params: List[Tensor],
if capturable:
assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
"If capturable=True, params and state_steps must be CUDA tensors."
else:
assert all(not step.is_cuda for step in state_steps), \
"If capturable=False, state_steps should not be CUDA tensors."

if maximize:
grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]
Expand Down

0 comments on commit 9d20af5

Please sign in to comment.