[WIP] add Optimizer state_dict hooks
ghstack-source-id: 0e74a91dcc186d97dcdc76b308f3634bd18171fb
Pull Request resolved: #105953
janeyx99 committed Jul 27, 2023
1 parent 6f74c2c commit 973cf03
Showing 1 changed file with 186 additions and 3 deletions.
189 changes: 186 additions & 3 deletions torch/optim/optimizer.py
@@ -36,6 +36,8 @@

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]

GlobalOptimizerPreHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

@@ -229,12 +231,20 @@ class Optimizer:

_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: Dict[int, Callable[[Self], None]]
_optimizer_state_dict_post_hooks: Dict[int, Callable[[Self, StateDict], Optional[StateDict]]]
_optimizer_load_state_dict_pre_hooks: Dict[int, Callable[[Self, StateDict], None]]
_optimizer_load_state_dict_post_hooks: Dict[int, Callable[[Self], None]]

def __init__(self, params: params_t, defaults: Dict[str, Any]) -> None:
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults
self._optimizer_step_pre_hooks = OrderedDict()
self._optimizer_step_post_hooks = OrderedDict()
self._optimizer_state_dict_pre_hooks = OrderedDict()
self._optimizer_state_dict_post_hooks = OrderedDict()
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
self._optimizer_load_state_dict_post_hooks = OrderedDict()

self._patch_step_function()

@@ -273,6 +283,14 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self._optimizer_step_pre_hooks = OrderedDict()
if '_optimizer_step_post_hooks' not in self.__dict__:
self._optimizer_step_post_hooks = OrderedDict()
if '_optimizer_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_state_dict_pre_hooks = OrderedDict()
if '_optimizer_state_dict_post_hooks' not in self.__dict__:
self._optimizer_state_dict_post_hooks = OrderedDict()
if '_optimizer_load_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
if '_optimizer_load_state_dict_post_hooks' not in self.__dict__:
self._optimizer_load_state_dict_post_hooks = OrderedDict()
self._patch_step_function() # To support multiprocessing pickle/unpickle
self.defaults.setdefault('differentiable', False)

@@ -427,8 +445,77 @@ def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
self._optimizer_step_post_hooks[handle.id] = hook
return handle


def register_state_dict_pre_hook(
self, hook: Callable[[Self], None], prepend: bool = False
) -> RemovableHandle:
r"""Register a state dict pre-hook which will be called before
:meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
following signature::
hook(optimizer) -> None
The ``optimizer`` argument is the optimizer instance being used.
The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
The registered hook can be used to perform pre-processing before the ``state_dict``
call is made.
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided pre ``hook`` will be fired before
all the existing registered pre-hooks on ``state_dict``. Otherwise,
the provided ``hook`` will be fired after all the existing registered
pre-hooks. (default: False)
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
self._optimizer_state_dict_pre_hooks[handle.id] = hook
if prepend:
self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
return handle
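
A minimal usage sketch of the pre-hook registration above (editor's illustration, not part of the diff); the Linear model, SGD optimizer, and printed message are assumptions chosen for demonstration:

    import torch

    model = torch.nn.Linear(2, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    def log_pre_hook(optimizer):
        # Receives only the optimizer, right before the state dict is assembled.
        print(f"serializing {len(optimizer.param_groups)} param group(s)")

    handle = opt.register_state_dict_pre_hook(log_pre_hook)
    sd = opt.state_dict()  # fires the pre-hook, then returns the usual dict
    handle.remove()        # detach the hook once it is no longer needed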


def register_state_dict_post_hook(
self,
hook: Callable[[Self, StateDict], Optional[StateDict]],
prepend: bool = False,
) -> RemovableHandle:
r"""Register a state dict post-hook which will be called after
:meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
following signature::
hook(optimizer, state_dict) -> state_dict or None
The hook will be called with arguments ``self`` and ``state_dict`` after generating
a ``state_dict`` on ``self``. The hook may modify the state_dict in place or optionally
return a new one. The registered hook can be used to perform post-processing
on the ``state_dict`` before it is returned.
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided post ``hook`` will be fired before
all the existing registered post-hooks on ``state_dict``. Otherwise,
the provided ``hook`` will be fired after all the existing registered
post-hooks. (default: False)
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
self._optimizer_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
return handle
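
Continuing the sketch above (still an editor's illustration), the post-hook may edit the produced dict in place or return a replacement; the extra "my_metadata" entry is purely hypothetical:

    def add_metadata_post_hook(optimizer, state_dict):
        # A returned dict replaces the one state_dict() hands back to the caller;
        # returning None keeps the (possibly mutated) original.
        state_dict["my_metadata"] = {"format_version": 1}  # hypothetical entry
        return state_dict

    post_handle = opt.register_state_dict_post_hook(add_metadata_post_hook)
    sd = opt.state_dict()
    assert "my_metadata" in sd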


@torch._disable_dynamo
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> StateDict:
r"""Returns the state of the optimizer as a :class:`dict`.
It contains two entries:
@@ -438,6 +525,10 @@ def state_dict(self) -> Dict[str, Any]:
* param_groups - a list containing all parameter groups where each
parameter group is a dict
"""

for hook in self._optimizer_state_dict_pre_hooks.values():
hook(self)

# Save order indices instead of Tensors
param_mappings: Dict[int, int] = {}
start_index = 0
@@ -454,11 +545,18 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()}
return {

state_dict = {
'state': packed_state,
'param_groups': param_groups,
}

for hook in self._optimizer_state_dict_post_hooks.values():
hook_result = hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result
return state_dict
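
For reference, a sketch (an assumption about a typical run, not taken from the diff) of what the packed dict looks like for SGD with momentum; parameters are remapped to integer indices by pack_group above:

    sgd = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(1, 2)).sum().backward()
    sgd.step()                 # first step populates the momentum buffers
    packed = sgd.state_dict()
    # packed["param_groups"] -> [{"lr": 0.1, "momentum": 0.9, ..., "params": [0, 1]}]
    # packed["state"]        -> {0: {"momentum_buffer": tensor(...)}, 1: {"momentum_buffer": tensor(...)}}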

@staticmethod
def _process_value_according_to_param_policy(
param: torch.Tensor,
@@ -491,8 +589,83 @@ def _process_value_according_to_param_policy(
else:
return value.to(device=param.device)


def register_optimizer_load_state_dict_pre_hook(
self,
hook: Callable[[Self, StateDict], None],
prepend: bool = False,
) -> RemovableHandle:
r"""Register a load_state_dict pre-hook which will be called before
:meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
following signature::
hook(optimizer, state_dict) -> None
The ``optimizer`` argument is the optimizer instance being used and the
``state_dict`` argument is a shallow copy of the input ``state_dict`` to
``load_state_dict``, so it CAN be modified. The modifications will persist
in the ``state_dict`` that is actually loaded.
The hook will be called with argument ``self`` and ``state_dict`` before
calling ``load_state_dict`` on ``self``. The registered hook can be used to
perform pre-processing before the ``load_state_dict`` call is made.
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided pre ``hook`` will be fired before
all the existing registered pre-hooks on ``load_state_dict``. Otherwise,
the provided ``hook`` will be fired after all the existing registered
pre-hooks. (default: False)
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
return handle
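
A sketch of the load-side pre-hook under the name used in this WIP diff (``register_optimizer_load_state_dict_pre_hook``); it strips the hypothetical "my_metadata" entry from the shallow copy before loading and returns nothing:

    def strip_metadata_pre_hook(optimizer, state_dict):
        # state_dict here is the shallow copy made inside load_state_dict, so
        # in-place edits are what actually gets loaded; no return value expected.
        state_dict.pop("my_metadata", None)

    pre_handle = opt.register_optimizer_load_state_dict_pre_hook(strip_metadata_pre_hook)
    opt.load_state_dict(sd)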


def register_optimizer_load_state_dict_post_hook(
self, hook: Callable[[Self], None], prepend: bool = False
) -> RemovableHandle:
r"""Register a load_state_dict post-hook which will be called after
:meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
following signature::
hook(optimizer) -> None
The ``optimizer`` argument is the optimizer instance being used.
The hook will be called with argument ``self`` after calling
``load_state_dict`` on ``self``. The registered hook can be used to
perform post-processing after ``load_state_dict`` has loaded the
``state_dict``.
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided post ``hook`` will be fired before
all the existing registered post-hooks on ``load_state_dict``. Otherwise,
the provided ``hook`` will be fired after all the existing registered
post-hooks. (default: False)
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
return handle
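
And a matching sketch (again an illustration, not part of the commit) for the load-side post-hook, which receives only the optimizer and must return None:

    def announce_post_hook(optimizer):
        # Runs after __setstate__ has applied the loaded state.
        print("optimizer state restored")

    opt.register_optimizer_load_state_dict_post_hook(announce_post_hook)
    opt.load_state_dict(sd)  # prints the message once the state is applied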


@torch._disable_dynamo
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: StateDict) -> None:
r"""Loads the optimizer state.
Args:
@@ -502,6 +675,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# shallow copy, to be consistent with module API
state_dict = state_dict.copy()

for hook in self._optimizer_load_state_dict_pre_hooks.values():
hook(self, state_dict)

# Validate the state_dict
groups = self.param_groups

@@ -551,6 +727,13 @@ def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str,
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

for hook in self._optimizer_load_state_dict_post_hooks.values():
out = hook(self)
assert out is None, (
"Hooks registered with ``register_optimizer_load_state_dict_post_hook`` are not"
"expected to return new values."
)

@torch._disable_dynamo
def zero_grad(self, set_to_none: bool = True) -> None:
r"""Resets the gradients of all optimized :class:`torch.Tensor` s.