Skip to content

Commit

Permalink
Cleaning up common_optimizers.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jayanthd04 committed May 7, 2024
1 parent 2a6c915 commit ad331e2
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions torch/testing/_internal/common_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,7 @@ def optim_inputs_func_adadelta(device, dtype=None):
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
),
OptimizerInput(
params=None, kwargs={"maximize": True}, desc="maximize"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down Expand Up @@ -697,7 +695,6 @@ def optim_inputs_func_nadam(device, dtype=None):
params=None,
kwargs={
"weight_decay": 0.1,
#"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
},
desc="decoupled_weight_decay",
Expand Down Expand Up @@ -782,13 +779,6 @@ def optim_inputs_func_radam(device=None, dtype=None):
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
desc="decoupled_weight_decay",
),
#OptimizerInput(
#params=None,
#kwargs={
#"decoupled_weight_decay": True,
#},
#desc="decoupled_weight_decay, no weight_decay",
#),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down Expand Up @@ -964,7 +954,9 @@ def optim_inputs_func_sgd(device, dtype=None):
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
]


Expand Down

0 comments on commit ad331e2

Please sign in to comment.