Skip to content

Commit

Permalink
Adding kwargs to common_optimizers.py for added test coverability
Browse files Browse the repository at this point in the history
  • Loading branch information
jayanthd04 committed May 3, 2024
1 parent 287f741 commit 39d99b5
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions torch/testing/_internal/common_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ def optim_inputs_func_adadelta(device, dtype=None):
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
),
OptimizerInput(
params=None,kwargs={'maximize':True},desc="maximize, no weight_decay"
),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down Expand Up @@ -516,6 +519,11 @@ def optim_inputs_func_adamax(device, dtype=None):
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"maximize":True},
desc="maximize, no weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down Expand Up @@ -658,6 +666,28 @@ def optim_inputs_func_nadam(device, dtype=None):
},
desc="decoupled_weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay":0.1,
},
desc="weight_decay, no momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"decoupled_weight_decay":True,
},
desc="decoupled_weight_decay, no weight_decay, no momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"decoupled_weight_decay":True,
"weight_decay":0.1,
},
desc="decoupled_weight_decay, weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down Expand Up @@ -721,6 +751,13 @@ def optim_inputs_func_radam(device=None, dtype=None):
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
desc="decoupled_weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"decoupled_weight_decay":True,
},
desc="decoupled_weight_decay, no weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down Expand Up @@ -791,6 +828,21 @@ def optim_inputs_func_rmsprop(device, dtype=None):
},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={
"maximize":True,
},
desc="maximize, no weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"maximize":True,
"weight_decay":0.1,
},
desc="maximize, weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])


Expand Down

0 comments on commit 39d99b5

Please sign in to comment.