Add Optimizer state_dict hooks #105953

Closed
wants to merge 9 commits
147 changes: 146 additions & 1 deletion torch/optim/optimizer.py
@@ -200,6 +200,10 @@ def __init__(self, params, defaults):
self.defaults = defaults
self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
self._optimizer_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
self._optimizer_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()
self._optimizer_load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
self._optimizer_load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()

self._patch_step_function()

@@ -239,6 +243,14 @@ def __setstate__(self, state):
self._optimizer_step_pre_hooks = OrderedDict()
if '_optimizer_step_post_hooks' not in self.__dict__:
self._optimizer_step_post_hooks = OrderedDict()
if '_optimizer_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_state_dict_pre_hooks = OrderedDict()
if '_optimizer_state_dict_post_hooks' not in self.__dict__:
self._optimizer_state_dict_post_hooks = OrderedDict()
if '_optimizer_load_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
if '_optimizer_load_state_dict_post_hooks' not in self.__dict__:
self._optimizer_load_state_dict_post_hooks = OrderedDict()
self._patch_step_function() # To support multiprocessing pickle/unpickle
self.defaults.setdefault('differentiable', False)

@@ -385,6 +397,59 @@ def register_step_post_hook(self, hook: Callable[..., None]) -> RemovableHandle:
self._optimizer_step_post_hooks[handle.id] = hook
return handle


def register_state_dict_pre_hook(self, hook: Callable[["Optimizer"], None]) -> RemovableHandle:
r"""Register a state dict pre hook which will be called before
:meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
following signature::

hook(optimizer) -> None

The ``optimizer`` argument is the optimizer instance being used.
The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
The registered hook can be used to perform pre-processing before the ``state_dict``
call is made.

Args:
hook (Callable): The user defined hook to be registered.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
self._optimizer_state_dict_pre_hooks[handle.id] = hook
return handle
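
# A minimal usage sketch of the pre hook registered by the method above. This is
# not part of the PR; the model, optimizer, and hook names are hypothetical.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

def note_serialization(optimizer):
    # called with the optimizer instance, before state_dict() assembles its result
    print(f"building state_dict for {type(optimizer).__name__}")

handle = opt.register_state_dict_pre_hook(note_serialization)
sd = opt.state_dict()  # the hook fires first, then the dict is built
handle.remove()        # the returned RemovableHandle unregisters the hook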


def register_state_dict_post_hook(
self, hook: Callable[["Optimizer", Dict[str, Any]], Optional[Dict[str, Any]]]
) -> RemovableHandle:
r"""Register a state dict post hook which will be called after
:meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
following signature::

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments ``self`` and ``state_dict`` after generating
a ``state_dict`` on ``self``. The hook may modify the ``state_dict`` in place or optionally
return a new one. The registered hook can be used to perform post-processing
on the ``state_dict`` before it is returned.

Args:
hook (Callable): The user defined hook to be registered.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
self._optimizer_state_dict_post_hooks[handle.id] = hook
return handle
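
# A usage sketch for the post hook above, again not part of the PR. Here the hook
# mutates the produced state_dict in place; the "extra_metadata" key is hypothetical.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

def attach_metadata(optimizer, state_dict):
    # mutate in place and return None, so the mutated dict is what state_dict() returns
    state_dict["extra_metadata"] = {"lr_at_save": optimizer.param_groups[0]["lr"]}

opt.register_state_dict_post_hook(attach_metadata)
assert "extra_metadata" in opt.state_dict()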


@torch._disable_dynamo
def state_dict(self):
r"""Returns the state of the optimizer as a :class:`dict`.
@@ -396,6 +461,10 @@ def state_dict(self):
* param_groups - a list containing all parameter groups where each
parameter group is a dict
"""

for hook in self._optimizer_state_dict_pre_hooks.values():
hook(self)

# Save order indices instead of Tensors
param_mappings = {}
start_index = 0
@@ -412,11 +481,18 @@ def pack_group(group):
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()}
return {

state_dict = {
'state': packed_state,
'param_groups': param_groups,
}

for hook in self._optimizer_state_dict_post_hooks.values():
hook_result = hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result
return state_dict
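
# A sketch (not from the PR) of the other post-hook contract handled in the loop
# above: a hook that returns a new dict, which then replaces the original result.
# The "format_version" key is hypothetical.
import torch

opt = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1)

def wrap_with_version(optimizer, state_dict):
    # return a fresh dict; state_dict() hands back this one instead of the original
    return {**state_dict, "format_version": 1}

opt.register_state_dict_post_hook(wrap_with_version)
assert opt.state_dict()["format_version"] == 1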

@staticmethod
def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: Optional[int] = None,
param_groups: Optional[List[Dict[Any, Any]]] = None, key=None) -> Tensor:
@@ -444,6 +520,64 @@ def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param
else:
return value.to(device=param.device)


def register_optimizer_load_state_dict_pre_hook(
self, hook: Callable[["Optimizer", Dict[str, Any]], None]
) -> RemovableHandle:
r"""Register a load_state_dict pre hook which will be called before
:meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
following signature::

hook(optimizer, state_dict) -> None

The ``optimizer`` argument is the optimizer instance being used and the
``state_dict`` argument is a COPY of the input ``state_dict`` to
``load_state_dict``, so it can be modified.

The hook will be called with arguments ``self`` and ``state_dict`` before
calling ``load_state_dict`` on ``self``. The registered hook can be used to
perform pre-processing before the ``load_state_dict`` call is made.

Args:
hook (Callable): The user defined hook to be registered.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
return handle
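
# A usage sketch for the pre hook above (not part of the PR). The hook receives
# the deep copy made by load_state_dict, so edits here change what gets loaded.
import torch

opt = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1)
sd = opt.state_dict()

def override_lr(optimizer, state_dict):
    # rewrite the saved hyperparameters before they are validated and applied
    for group in state_dict["param_groups"]:
        group["lr"] = 0.01

opt.register_optimizer_load_state_dict_pre_hook(override_lr)
opt.load_state_dict(sd)
assert opt.param_groups[0]["lr"] == 0.01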


def register_optimizer_load_state_dict_post_hook(self, hook: Callable[["Optimizer"], None]) -> RemovableHandle:
r"""Register a load_state_dict post hook which will be called after
:meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
following signature::

hook(optimizer) -> None

The ``optimizer`` argument is the optimizer instance being used.

The hook will be called with argument ``self`` after calling
``load_state_dict`` on ``self``. The registered hook can be used to
perform post-processing after ``load_state_dict`` has loaded the
``state_dict``.

Args:
hook (Callable): The user defined hook to be registered.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
return handle
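
# A usage sketch for the post hook above (not part of the PR): it runs once the
# state has been restored and should return None.
import torch

opt = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1)
sd = opt.state_dict()

def verify_loaded(optimizer):
    # post-processing / sanity checks after load_state_dict has finished
    assert len(optimizer.param_groups) == 1

opt.register_optimizer_load_state_dict_post_hook(verify_loaded)
opt.load_state_dict(sd)  # verify_loaded runs after the state is applied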


@torch._disable_dynamo
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@@ -454,6 +588,10 @@ def load_state_dict(self, state_dict):
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)

for hook in self._optimizer_load_state_dict_pre_hooks.values():
hook(self, state_dict)

# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
@@ -501,6 +639,13 @@ def update_group(group, new_group):
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

for hook in self._optimizer_load_state_dict_post_hooks.values():
out = hook(self)
assert out is None, (
"Hooks registered with ``register_optimizer_load_state_dict_post_hook`` are not"
"expected to return new values."
)
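
# A round-trip sketch (not from the PR) tying the save- and load-side hooks
# together; the "format_version" key is hypothetical.
import torch

src = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1)
dst = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1)

def tag_version(optimizer, state_dict):
    state_dict["format_version"] = 1  # added on save, in place

def strip_version(optimizer, state_dict):
    state_dict.pop("format_version", None)  # removed again before loading

src.register_state_dict_post_hook(tag_version)
dst.register_optimizer_load_state_dict_pre_hook(strip_version)
dst.load_state_dict(src.state_dict())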

@torch._disable_dynamo
def zero_grad(self, set_to_none: bool = True):
r"""Resets the gradients of all optimized :class:`torch.Tensor` s.