Skip to content

Commit

Permalink
Adding comments to test_cuda.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jayanthd04 committed May 14, 2024
1 parent d4815a7 commit a0af2fe
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4465,7 +4465,10 @@ def test_graph_optims(self, device, dtype, optim_info):

for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
kwargs["lr"]=0.1

# lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam
# and torch.optim.adamw
kwargs["lr"] = 0.1

for actually_do_graphs in (True, False):
params = [
Expand All @@ -4480,7 +4483,6 @@ def test_graph_optims(self, device, dtype, optim_info):
]

# Control (capturable=False)

kwargs["capturable"] = False

opt = optim_cls(params_control, **kwargs)
Expand Down

0 comments on commit a0af2fe

Please sign in to comment.