6 changes: 3 additions & 3 deletions test/inductor/test_compiled_optimizers.py
@@ -86,13 +86,13 @@ class KernelCounts(NamedTuple):
 KERNEL_COUNTS = {
     Adam: KernelCounts(multitensor=2, singletensor=8),
     AdamW: KernelCounts(multitensor=2, singletensor=8),
-    NAdam: KernelCounts(multitensor=2, singletensor=8),
+    NAdam: KernelCounts(multitensor=2, singletensor=11),
     Rprop: KernelCounts(multitensor=2, singletensor=8),
     RMSprop: KernelCounts(multitensor=2, singletensor=8),
     Adadelta: KernelCounts(multitensor=2, singletensor=8),
     Adagrad: KernelCounts(multitensor=5, singletensor=8),
     SGD: KernelCounts(multitensor=1, singletensor=8),
-    ASGD: KernelCounts(multitensor=2, singletensor=8),
+    ASGD: KernelCounts(multitensor=2, singletensor=11),
     RAdam: KernelCounts(multitensor=2, singletensor=8),
     Adamax: KernelCounts(multitensor=2, singletensor=8),
 }
@@ -436,7 +436,7 @@ def check_cudagraphs_ran(self):
     test_adagrad_recompile = make_recompile_test(Adagrad, kernel_count=5, lr=0.01)
     test_asgd_recompile_default = make_recompile_test(ASGD, kernel_count=2, lr=0.01)
     test_asgd_recompile_single = make_recompile_test(
-        ASGD, kernel_count=8, lr=0.01, foreach=False
+        ASGD, kernel_count=11, lr=0.01, foreach=False
     )
     test_asgd_recompile_foreach = make_recompile_test(
         ASGD, kernel_count=2, lr=0.01, foreach=True
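The updated expectations above raise the single-tensor (foreach=False) kernel counts for NAdam and ASGD from 8 to 11. Below is a minimal sketch of how one might observe such a count locally, assuming the torch._inductor.metrics.generated_kernel_count counter that the test suite consults; the model, sizes, and function names are illustrative, not from the PR.

# Hypothetical reproduction sketch, not part of the PR. Assumes the
# torch._inductor.metrics.generated_kernel_count counter used by the test suite.
import torch
from torch._inductor import metrics
from torch.optim import ASGD

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8, device=device)
opt = ASGD(model.parameters(), lr=0.01, foreach=False)  # single-tensor path

# Populate gradients so opt.step() has work to do.
model(torch.randn(4, 8, device=device)).sum().backward()

@torch.compile()
def step():
    opt.step()

metrics.reset()
step()
# The updated table above expects 11 generated kernels for single-tensor ASGD.
print(metrics.generated_kernel_count)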
42 changes: 31 additions & 11 deletions torch/_dynamo/variables/optimizer.py
@@ -179,28 +179,48 @@ def mark_static(x):
 
         # Recursively realize the variable trackers for optim.state and
         # optim.param_groups, which recursively install the necessary guards.
-
-        # NB: Its necessary to install the guards for optim.state first.
-        # optim.state is a dict with parameters as keys. Therefore, we just put
-        # ID_MATCH on parameters, instead of TENSOR_MATCH. When we install the
-        # guards for param_groups later, VariableTrackerCache just ensures that
-        # we directly return the cached tensor variable tracker without
-        # inserting the TENSOR_MATCH guard.
-        state_vt = LazyVariableTracker.realize_all(
-            VariableBuilder(tx, AttrSource(self.source, "state"))(self.value.state)
-        )
-
         param_groups_vt = LazyVariableTracker.realize_all(
             VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                 self.value.param_groups
             )
         )
 
+        state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
+            self.value.state
+        )
+
+        # We need to realize the top level state dict to populate
+        # the guard locals
+        state_vt.realize()
+
         # Populate self.grad_to_source and self.tensor_to_source so that we can
         # manually update_list_args
         for g_ind, (group, group_vt) in enumerate(
             zip(self.value.param_groups, param_groups_vt.items)
         ):
+            # we assume here that all params within a param group
+            # are initialized similarly
+            if len(group["params"]) > 0:
+                for param in group["params"]:
+                    if param.grad is not None:
+                        key_index = None
+                        for i, k in enumerate(self.value.state.keys()):
+                            if k is param:
+                                key_index = i
+                                break
+                        if key_index:
+                            state_source = AttrSource(self.source, "state")
+                            LazyVariableTracker.realize_all(
+                                VariableBuilder(
+                                    tx,
+                                    GetItemSource(
+                                        state_source,
+                                        ConstDictKeySource(state_source, key_index),
+                                    ),
+                                )(self.value.state[param])
+                            )
+                            break
+
             group_source = group_vt.source
             params_vt = group_vt.getitem_const(ConstantVariable.create("params"))
             for p_ind, (p, p_vt) in enumerate(
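The added block realizes, for each param group, only the state entry of the first parameter that has a gradient (relying on the comment's assumption that all params in a group are initialized similarly), instead of realizing the entire optim.state dict as before. A small standalone illustration of the structures that loop walks, using only public torch.optim APIs; the model and optimizer choice are illustrative, not from the PR.

# Standalone illustration of the data inspected by the loop above; not part of the PR.
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.ASGD(model.parameters(), lr=0.01)
model(torch.randn(2, 4)).sum().backward()
opt.step()  # populates opt.state for each parameter

# optimizer.state is a dict keyed by the parameter tensors themselves, so the
# new code records the *position* of the first gradient-bearing param among
# state.keys() and guards only that entry via ConstDictKeySource.
first_param = next(
    p for group in opt.param_groups for p in group["params"] if p.grad is not None
)
key_index = next(i for i, k in enumerate(opt.state.keys()) if k is first_param)
print(key_index)                              # index later wrapped in ConstDictKeySource
print(sorted(opt.state[first_param].keys()))  # e.g. ['ax', 'eta', 'mu', 'step']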