
Commit d4815a7

Rearranging common_optimizers.py configs, cleaning up test_cuda.py and getting rid of explicitly assigning betas

jayanthd04 committed May 9, 2024
1 parent 37e6979 commit d4815a7
Showing 2 changed files with 39 additions and 58 deletions.
25 changes: 2 additions & 23 deletions test/test_cuda.py
@@ -4459,23 +4459,13 @@ def test_graph_optims(self, device, dtype, optim_info):
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
             device, dtype, optim_info, skip=("differentiable",)
         )
-        has_betas = any(
-            "betas" in error_inp.optimizer_error_input.kwargs
-            for error_inp in optim_info.optim_error_inputs_func(
-                device="cpu", dtype=dtype
-            )
-        )
 
         steps_warmup = 3
         steps_train = 2
 
         for optim_input in all_optim_inputs:
             kwargs = optim_input.kwargs
-            if "lr" in kwargs:
-                del kwargs["lr"]
-            kwargs["lr"] = 0.1
-            if has_betas and optim_cls != torch.optim.Adamax:
-                kwargs["betas"] = (0.8, 0.7)
+            kwargs["lr"]=0.1
 
             for actually_do_graphs in (True, False):
                 params = [
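
For context, test_graph_optims exercises capturing optimizer steps in a CUDA graph. A minimal sketch of that capture pattern, assuming a CUDA device (illustrative code, not the test body):

import torch

params = [torch.randn(4, 4, device="cuda", requires_grad=True) for _ in range(2)]
opt = torch.optim.Adam(params, lr=0.1, capturable=True)

# Warmup steps must run on a side stream before capture (per the torch.cuda.graph docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        for p in params:
            p.grad = torch.randn_like(p)
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one optimizer step, then replay it without re-dispatching kernels.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    opt.step()
g.replay()
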
@@ -4543,26 +4533,15 @@ def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
         steps_train = 2
 
         optim_inputs = optim_info.optim_inputs_func(device=device)
-        has_betas = any(
-            "betas" in error_inp.optimizer_error_input.kwargs
-            for error_inp in optim_info.optim_error_inputs_func(
-                device="cpu", dtype=dtype
-            )
-        )
 
         for optim_input in optim_inputs:
             kwargs = optim_input.kwargs
             kwargs["fused"] = True
-            if "lr" in kwargs:
-                del kwargs["lr"]
             kwargs["lr"] = 0.1
-            if has_betas:
-                kwargs["betas"] = (0.8, 0.7)
 
             for actually_do_graphs in (
                 (True, False) if optim_info.has_capturable_arg else (True,)
             ):
-                params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
+                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
                 params_control = [p.clone().requires_grad_() for p in params]
                 params_graphed = [p.clone().requires_grad_() for p in params]
 
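test_graph_scaling_fused_optimizers pairs fused optimizers with gradient scaling. The flow under test looks roughly like this sketch (assuming a CUDA device; the toy loss is illustrative):

import torch

params = [torch.randn(4, 4, device="cuda", requires_grad=True) for _ in range(2)]
opt = torch.optim.Adam(params, lr=0.1, fused=True)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    opt.zero_grad(set_to_none=True)
    loss = sum((p * p).sum() for p in params)
    scaler.scale(loss).backward()
    # With a fused optimizer, the scaler can hand the scale/found_inf tensors
    # to the fused kernel instead of unscaling gradients in a separate pass.
    scaler.step(opt)
    scaler.update()
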
72 changes: 37 additions & 35 deletions torch/testing/_internal/common_optimizers.py
@@ -123,7 +123,7 @@ def __init__(
         supported_impls: Tuple[str] = ("foreach", "differentiable"),
         # the optim supports passing in sparse gradients as well as dense grads
         supports_sparse: bool = False,
-        # the optim is capturable in a CUDA graph
+        # the optimizer constructor supports passing in capturable as a kwarg
         has_capturable_arg: bool = False,
         # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
         only_supports_sparse_grads: bool = False,
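
The corrected comment matches how the flag is consumed: has_capturable_arg marks optimizers whose constructor accepts a capturable kwarg, for example:

import torch

p = torch.nn.Parameter(torch.randn(3, device="cuda"))
# capturable=True keeps optimizer state on-device so step() is CUDA-graph safe.
opt = torch.optim.Adam([p], lr=0.1, capturable=True)
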
@@ -314,6 +314,7 @@ def optim_inputs_func_adadelta(device, dtype=None):
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
         ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "maximize": True},
@@ -322,7 +323,7 @@ def optim_inputs_func_adadelta(device, dtype=None):
         OptimizerInput(
             params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
         ),
-        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+
     ] + (cuda_supported_configs if "cuda" in str(device) else [])
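
Each OptimizerInput above is a named bundle of constructor kwargs that the optimizer tests unpack; a rough usage sketch (the variable names are illustrative):

import torch
from torch.testing._internal.common_optimizers import OptimizerInput

cfg = OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize")
p = torch.nn.Parameter(torch.randn(3))
# params=None tells the harness to supply its own parameters; kwargs feed the constructor.
opt = torch.optim.Adadelta([p], **cfg.kwargs)
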


@@ -532,14 +533,15 @@ def optim_inputs_func_adamax(device, dtype=None):
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize, weight_decay",
+            kwargs={"maximize": True},
+            desc="maximize",
         ),
         OptimizerInput(
             params=None,
-            kwargs={"maximize": True},
-            desc="maximize",
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
         ),
+
     ] + (cuda_supported_configs if "cuda" in str(device) else [])


@@ -689,6 +691,13 @@ def optim_inputs_func_nadam(device, dtype=None):
kwargs={"momentum_decay": 6e-3},
desc="non-zero momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
},
desc="weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
@@ -702,13 +711,6 @@
             },
             desc="decoupled_weight_decay",
         ),
-        OptimizerInput(
-            params=None,
-            kwargs={
-                "weight_decay": 0.1,
-            },
-            desc="weight_decay",
-        ),
     ] + (cuda_supported_configs if "cuda" in str(device) else [])
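
The nadam configs distinguish plain L2 weight_decay from AdamW-style decoupled decay; both are NAdam constructor kwargs:

import torch

p = torch.nn.Parameter(torch.randn(3))
# L2-style decay, folded into the gradient:
opt_l2 = torch.optim.NAdam([p], weight_decay=0.1)
# AdamW-style decay, applied directly to the weights:
opt_decoupled = torch.optim.NAdam([p], weight_decay=0.1, decoupled_weight_decay=True)
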


@@ -834,38 +836,38 @@ def optim_inputs_func_rmsprop(device, dtype=None):
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "centered": True},
-            desc="centered",
+            kwargs={
+                "maximize": True,
+            },
+            desc="maximize",
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
-            desc="momentum",
+            kwargs={"weight_decay": 0.1, "centered": True},
+            desc="centered",
         ),
         OptimizerInput(
             params=None,
             kwargs={
-                "weight_decay": 0.1,
-                "centered": True,
-                "momentum": 0.1,
                 "maximize": True,
+                "weight_decay": 0.1,
             },
-            desc="maximize, centered, weight_decay, w/ momentum",
+            desc="maximize, weight_decay",
         ),
         OptimizerInput(
             params=None,
-            kwargs={
-                "maximize": True,
-            },
-            desc="maximize",
+            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
+            desc="momentum",
         ),
         OptimizerInput(
             params=None,
             kwargs={
-                "maximize": True,
                 "weight_decay": 0.1,
+                "centered": True,
+                "momentum": 0.1,
+                "maximize": True,
             },
-            desc="maximize, weight_decay",
+            desc="maximize, centered, weight_decay, w/ momentum",
         ),
     ] + (cuda_supported_configs if "cuda" in str(device) else [])

@@ -936,7 +938,15 @@ def optim_inputs_func_sgd(device, dtype=None):
         OptimizerInput(
             params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
         ),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+        ),
         OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
         OptimizerInput(
             params=None,
             kwargs={"momentum": 0.9, "dampening": 0.5},
@@ -952,14 +962,6 @@
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
]
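
Taken together, each optim_inputs_func_* returns the sweep of kwarg combinations for one optimizer. A hypothetical consumer might look like this (assuming SGD's lr default applies when a config omits lr):

import torch
from torch.testing._internal.common_optimizers import optim_inputs_func_sgd

p = torch.nn.Parameter(torch.randn(3))
for cfg in optim_inputs_func_sgd(device="cpu"):
    opt = torch.optim.SGD([p], **cfg.kwargs)
    print(cfg.desc, opt.defaults)
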

