
Commit

Fixing most of the linting issues
jayanthd04 committed Apr 27, 2024
1 parent c7b33f4 commit af8653c
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions test/test_cuda.py
@@ -4614,10 +4614,10 @@ def _test_graphed_optims(self, steps_warmup, steps_train, optimizer_ctor, kwargs
torch.optim.Adamax,
torch.optim.ASGD,
torch.optim.Adadelta,
-torch.optim.RMSprop
+torch.optim.RMSprop,
]
],
-dtypes=[torch.float32]
+dtypes=[torch.float32],
)
def test_graph_optims(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
@@ -4629,26 +4629,27 @@ def test_graph_optims(self, device, dtype, optim_info):
"betas": (0.8, 0.7),
"foreach": foreach,
"decoupled_weight_decay": decoupled_weight_decay,
-"weight_decay": weight_decay
+"weight_decay": weight_decay,
}
for foreach, decoupled_weight_decay, weight_decay in product(
(
False,
True,
),
(False, True),
-(0.0, 0.1)
+(0.0, 0.1),
)
)
],
torch.optim.RAdam: [
(
{"lr": 0.1,
"betas": (0.8, 0.7),
"foreach": foreach,
"decoupled_weight_decay": decoupled_weight_decay,
"weight_decay": weight_decay
}
{
"lr": 0.1,
"betas": (0.8, 0.7),
"foreach": foreach,
"decoupled_weight_decay": decoupled_weight_decay,
"weight_decay": weight_decay
}
for foreach, decoupled_weight_decay, weight_decay in product(
(False, True), (False, True), (0.0, 0.1)
)
@@ -4666,37 +4667,37 @@ def test_graph_optims(self, device, dtype, optim_info):
"lr": 0.1,
"betas": (0.8, 0.7),
"foreach": foreach,
-"amsgrad": amsgrad
+"amsgrad": amsgrad,
}
for foreach, amsgrad in product((False, True), (False, True))
),
(
{"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}
for amsgrad in (False, True)
-)
+),
],
torch.optim.AdamW: [
(
{
"lr": 0.1,
"betas": (0.8, 0.7),
"foreach": foreach,
-"amsgrad": amsgrad
+"amsgrad": amsgrad,
}
for foreach, amsgrad in product((False, True), (False, True))
),
(
{"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}
for amsgrad in (False, True)
-)
+),
],
torch.optim.Adamax: [
(
{
"lr": 0.1,
"foreach": foreach,
"maximize": maximize,
-"weight_decay": weight_decay
+"weight_decay": weight_decay,
}
for foreach, maximize, weight_decay in product(
(False, True), (False, True), (0, 0.1)
@@ -4709,7 +4710,7 @@ def test_graph_optims(self, device, dtype, optim_info):
"lr": 0.1,
"foreach": foreach,
"maximize": maximize,
-"weight_decay": weight_decay
+"weight_decay": weight_decay,
}
for foreach, maximize, weight_decay in product(
(False, True), (False, True), (0, 0.1)
@@ -4722,9 +4723,10 @@ def test_graph_optims(self, device, dtype, optim_info):
"lr": 0.1,
"foreach": foreach,
"maximize": maximize,
-"weight_decay": weight_decay}
-for foreach, maximize, weight_decay in product(
-(False, True), (False, True), (0, 0.1)
+"weight_decay": weight_decay
+}
+for foreach, maximize, weight_decay in product(
+(False, True), (False, True), (0, 0.1)
)
)
],
@@ -4734,7 +4736,7 @@ def test_graph_optims(self, device, dtype, optim_info):
"lr": 0.1,
"foreach": foreach,
"maximize": maximize,
-"weight_decay": weight_decay
+"weight_decay": weight_decay,
}
for foreach, maximize, weight_decay in product(
(False, True), (False, True), (0, 0.1)
@@ -4781,7 +4783,7 @@ def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
"dampening": d,
"weight_decay": w,
"nesterov": n,
-"fused": True
+"fused": True,
}
for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
),
@@ -4792,8 +4794,9 @@ def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
"dampening": d,
"weight_decay": w,
"nesterov": n,
-"fused": True}
-for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
+"fused": True,
+}
+for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
),
],
}
@@ -4827,14 +4830,14 @@ def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
with torch.no_grad():
scaler_for_control._lazy_init_scale_growth_tracker(
torch.device("cuda")
-)
+)

scaler_for_graphed = torch.cuda.amp.GradScaler()
scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
with torch.no_grad():
scaler_for_graphed._lazy_init_scale_growth_tracker(
torch.device("cuda")
-)
+)

# Control (capturable=False)
if has_capturable_arg:
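Every hunk in this diff edits the same structure: a dict that maps each optimizer class to generator expressions of keyword-argument combinations built with itertools.product. As a reading aid, a minimal, self-contained sketch of that pattern follows; the dict name optim_inputs and the loop at the end are illustrative and are not taken from test_cuda.py.

from itertools import product

import torch

# Sketch only: mirrors the shape of the kwargs tables edited above.
optim_inputs = {
    torch.optim.RAdam: [
        (
            {
                "lr": 0.1,
                "betas": (0.8, 0.7),
                "foreach": foreach,
                "decoupled_weight_decay": decoupled_weight_decay,
                "weight_decay": weight_decay,
            }
            for foreach, decoupled_weight_decay, weight_decay in product(
                (False, True), (False, True), (0.0, 0.1)
            )
        )
    ],
}

for optim_cls, kwarg_generators in optim_inputs.items():
    for gen in kwarg_generators:
        for kwargs in gen:
            # Each kwargs dict would be splatted into optim_cls(params, **kwargs)
            # when the test constructs the optimizer under CUDA graph capture.
            print(optim_cls.__name__, kwargs)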
