Creating test_new_graph_optims under TestCudaOptims class and using OptimizerInfos
jayanthd04 committed Apr 23, 2024
1 parent fe29e60 commit f6cec68
Showing 1 changed file with 86 additions and 7 deletions.
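
The commit replaces the hand-maintained `cases` list in `test_graph_optims` with `OptimizerInfo`-driven parametrization: the `@optims` decorator selects entries from `optim_db` and injects one `optim_info` per optimizer class into the test. Below is a minimal sketch of that pattern, assuming the `optims`/`optim_db` helpers from `torch.testing._internal.common_optimizer` that the diff relies on (`test_example` is a hypothetical name):

import torch
from torch.testing._internal.common_optimizer import optim_db, optims

# Sketch only: @optims runs the test once per selected OptimizerInfo,
# passing device/dtype/optim_info the same way the diff's test receives them.
@optims(
    [o for o in optim_db if o.optim_cls is torch.optim.Adam],
    dtypes=[torch.float32],
)
def test_example(self, device, dtype, optim_info):
    optim_cls = optim_info.optim_cls
    # optim_inputs_func yields canonical constructor inputs for this
    # optimizer; the diff's commented-out line hints at using it eventually.
    for optim_input in optim_info.optim_inputs_func(device=device):
        optim_cls([torch.zeros(1, device=device)], **optim_input.kwargs)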
93 changes: 86 additions & 7 deletions test/test_cuda.py
@@ -2905,9 +2905,6 @@ def _test_graphed_optimizer(self, steps_warmup, steps_train, optimizer_ctor, kwargs):
self.assertEqual(p_control, p_graphed)

@unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
def test_graph_optims(self):
# Needs generalization if we want to extend this test to non-Adam-like optimizers.
cases = [
@@ -2934,7 +2931,7 @@ def test_graph_optims(self):

        for optimizer_ctor, kwargs in cases:
            with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
                print(optimizer_ctor, kwargs)
                self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)

@unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
@@ -4163,13 +4160,95 @@ def test_no_triton_on_import(self):

@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
    def _test_graphed_optims(self, steps_warmup, steps_train, optimizer_ctor, kwargs):
        for actually_do_graphs in (True, False):
            params = [
                torch.randn((i + 5, i + 5), device="cuda") for i in range(2)
            ] + [torch.randn((), device="cuda")]
            params_control = [p.clone().requires_grad_() for p in params]
            params_graphed = [p.clone().requires_grad_() for p in params]

            grads = [
                [torch.randn_like(p) for p in params]
                for _ in range(steps_warmup + steps_train)
            ]

            # Control run: eager steps with a non-capturable optimizer.
            opt = optimizer_ctor(params_control, capturable=False, **kwargs)

            for i in range(steps_warmup + steps_train):
                for j, p in enumerate(params_control):
                    p.grad = grads[i][j]
                opt.step()

            # Graphed run: warm up eagerly, then capture and replay.
            opt = optimizer_ctor(params_graphed, capturable=True, **kwargs)

            for i in range(steps_warmup):
                for j, p in enumerate(params_graphed):
                    p.grad = grads[i][j]
                opt.step()

            if actually_do_graphs:
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    opt.step()

            for i in range(steps_train):
                if actually_do_graphs:
                    # Replay reuses the captured memory, so copy grads in place.
                    for j, p in enumerate(params_graphed):
                        p.grad.copy_(grads[i + steps_warmup][j])
                    g.replay()
                else:
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i + steps_warmup][j]
                    opt.step()

            for p_control, p_graphed in zip(params_control, params_graphed):
                self.assertEqual(p_control, p_graphed)

    @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
    @optims(
        [
            optim for optim in optim_db
            if optim.optim_cls in [
                torch.optim.NAdam, torch.optim.RAdam, torch.optim.Rprop,
                torch.optim.Adam, torch.optim.AdamW, torch.optim.Adamax,
                torch.optim.ASGD, torch.optim.Adadelta, torch.optim.RMSprop,
            ]
        ],
        dtypes=[torch.float32],
    )
    def test_new_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        optKwargs = {
            torch.optim.NAdam: [
                ({"lr": 0.1, "betas": (0.8, 0.7), "foreach": foreach,
                  "decoupled_weight_decay": decoupled_weight_decay,
                  "weight_decay": weight_decay}
                 for foreach, decoupled_weight_decay, weight_decay
                 in product((False, True), (False, True), (0.0, 0.1)))
            ],
            torch.optim.RAdam: [
                ({"lr": 0.1, "betas": (0.8, 0.7), "foreach": foreach,
                  "decoupled_weight_decay": decoupled_weight_decay,
                  "weight_decay": weight_decay}
                 for foreach, decoupled_weight_decay, weight_decay
                 in product((False, True), (False, True), (0.0, 0.1)))
            ],
            torch.optim.Rprop: [
                ({"lr": 0.1, "foreach": foreach, "maximize": maximize}
                 for foreach, maximize in product((False, True), (False, True)))
            ],
            torch.optim.Adam: [
                ({"lr": 0.1, "betas": (0.8, 0.7), "foreach": foreach, "amsgrad": amsgrad}
                 for foreach, amsgrad in product((False, True), (False, True))),
                ({"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}
                 for amsgrad in (False, True)),
            ],
            torch.optim.AdamW: [
                ({"lr": 0.1, "betas": (0.8, 0.7), "foreach": foreach, "amsgrad": amsgrad}
                 for foreach, amsgrad in product((False, True), (False, True))),
                ({"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}
                 for amsgrad in (False, True)),
            ],
            torch.optim.Adamax: [
                ({"lr": 0.1, "foreach": foreach, "maximize": maximize, "weight_decay": weight_decay}
                 for foreach, maximize, weight_decay in product((False, True), (False, True), (0, 0.1)))
            ],
            torch.optim.ASGD: [
                ({"lr": 0.1, "foreach": foreach, "maximize": maximize, "weight_decay": weight_decay}
                 for foreach, maximize, weight_decay in product((False, True), (False, True), (0, 0.1)))
            ],
            torch.optim.Adadelta: [
                ({"lr": 0.1, "foreach": foreach, "maximize": maximize, "weight_decay": weight_decay}
                 for foreach, maximize, weight_decay in product((False, True), (False, True), (0, 0.1)))
            ],
            torch.optim.RMSprop: [
                ({"lr": 0.1, "foreach": foreach, "maximize": maximize, "weight_decay": weight_decay}
                 for foreach, maximize, weight_decay in product((False, True), (False, True), (0, 0.1)))
            ],
        }
        # optim_inputs = optim_info.optim_inputs_func(device=device)
        for kwargs in optKwargs[optim_cls]:
            for kwarg in kwargs:
                with self.subTest(optimizer_ctor=optim_cls, kwargs=kwarg):
                    print(optim_cls, kwarg)
                    self._test_graphed_optims(3, 2, optim_cls, kwarg)
instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_device_type_tests(TestCudaOptims, globals(), only_for='cuda')
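
For reference, `instantiate_device_type_tests` stamps out a concrete CUDA-only test class from the `TestCudaOptims` template, so each `@optims` entry becomes its own test case. The capture/replay pattern that `_test_graphed_optims` exercises looks like this as a standalone sketch (assumes a build where CUDA graphs are available, per the skip condition above):

import torch

# One parameter and a capturable optimizer (capturable=True is what
# lets the step be recorded into a CUDA graph).
p = torch.randn(5, 5, device="cuda", requires_grad=True)
opt = torch.optim.Adam([p], lr=0.1, capturable=True)

# Warm-up step outside the graph so lazily-created optimizer state
# is allocated before capture.
p.grad = torch.randn_like(p)
opt.step()

# Capture a single optimizer step.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    opt.step()

# Replay reuses the captured memory, so fresh gradients must be
# copied into p.grad in place rather than rebound.
for _ in range(2):
    p.grad.copy_(torch.randn_like(p))
    g.replay()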
