add Optimizer state_dict hooks
ghstack-source-id: 620bac6defa1cf2d1d51b91d3b293f4cc5ac265b
Pull Request resolved: #105953
janeyx99 committed Jul 27, 2023
1 parent ca7ece9 commit 2b848c1
Showing 2 changed files with 280 additions and 3 deletions.
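For orientation, this commit adds four hook registration methods to torch.optim.Optimizer. A minimal usage sketch follows (the hook names and bodies are illustrative and not part of the diff; the registration API matches the code below):

import torch
from torch.optim import SGD

opt = SGD([torch.rand(2, 3, requires_grad=True)], lr=0.01)

# Runs right before state_dict() assembles its result; receives only the optimizer.
def my_state_dict_pre_hook(optimizer):
    optimizer.state["checkpoint_version"] = 1

# Runs after the dict is assembled; may edit it in place or return a new one.
def my_state_dict_post_hook(optimizer, state_dict):
    state_dict["saved_lr"] = optimizer.param_groups[0]["lr"]
    return state_dict

# Runs before load_state_dict() applies the dict; gets a shallow copy it may edit or replace.
def my_load_state_dict_pre_hook(optimizer, state_dict):
    state_dict.pop("saved_lr", None)

# Runs after the new state has been loaded; receives only the optimizer.
def my_load_state_dict_post_hook(optimizer):
    print("loaded lr:", optimizer.param_groups[0]["lr"])

opt.register_state_dict_pre_hook(my_state_dict_pre_hook)
opt.register_state_dict_post_hook(my_state_dict_post_hook)
opt.register_load_state_dict_pre_hook(my_load_state_dict_pre_hook)
opt.register_load_state_dict_post_hook(my_load_state_dict_post_hook)

opt.load_state_dict(opt.state_dict())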
93 changes: 93 additions & 0 deletions test/optim/test_optim.py
@@ -1913,6 +1913,99 @@ def test_fused_optimizer_raises(self):
            with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
                optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)

    @staticmethod
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
        optimizer.state["test"] = 1

    @staticmethod
    def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        if "test" in state_dict["state"]:
            state_dict["state"].pop("test")
            state_dict["ran_state_dict_pre_hook"] = True
        else:
            state_dict["ran_state_dict_pre_hook"] = False
        return state_dict

    @staticmethod
    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
        state_dict["param_groups"][0]["lr"] = 0.002

    @staticmethod
    def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # The typical use case for returning a state dict is to make substantial modifications to it.
        # Simulate that with a deep copy and verify that the returned my_state_dict is the one that gets used.
        my_state_dict = deepcopy(state_dict)
        my_state_dict["param_groups"][0]["lr"] = 0.003
        return my_state_dict

    @staticmethod
    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
        optimizer.state["ran_load_state_dict_post_hook"] = True

    def test_state_dict_pre_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
        state_dict = opt.state_dict()
        self.assertEqual(state_dict["state"]["test"], 1)

    def test_state_dict_post_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        opt.register_state_dict_post_hook(self._state_dict_post_hook)
        state_dict = opt.state_dict()
        self.assertEqual(state_dict["ran_state_dict_pre_hook"], False)

    def test_state_dict_pre_post_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
        opt.register_state_dict_post_hook(self._state_dict_post_hook)
        state_dict = opt.state_dict()
        self.assertFalse("test" in state_dict["state"])
        self.assertEqual(state_dict["ran_state_dict_pre_hook"], True)

    def test_load_state_dict_pre_hook_and_prepend(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        state_dict = opt.state_dict()

        # Normally a fresh optimizer instance would load the state dict, but reusing the same one suffices here.
        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
        opt.load_state_dict(state_dict)
        self.assertEqual(opt.param_groups[0]["lr"], 0.002)

        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
        opt.load_state_dict(state_dict)
        # If prepend were False, hook2 would run last and the lr would be 0.003; since prepend is True,
        # hook1 runs after hook2 and overrides the lr back to 0.002.
        self.assertEqual(opt.param_groups[0]["lr"], 0.002)

    def test_load_state_dict_post_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)

        opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
        opt.load_state_dict(opt.state_dict())
        self.assertFalse(opt.state["ran_load_state_dict_pre_hook2"])
        self.assertTrue(opt.state["ran_load_state_dict_post_hook"])

    def test_load_state_dict_pre_post_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)

        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
        opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
        opt.load_state_dict(opt.state_dict())
        self.assertTrue(opt.state["ran_load_state_dict_pre_hook2"])
        self.assertTrue(opt.state["ran_load_state_dict_post_hook"])


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # Ignored is the list of values in `opt_differentiable_state`, we do this
190 changes: 187 additions & 3 deletions torch/optim/optimizer.py
Expand Up @@ -36,6 +36,8 @@

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]

GlobalOptimizerPreHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

@@ -229,12 +231,20 @@ class Optimizer:

    _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'

    def __init__(self, params: params_t, defaults: Dict[str, Any]) -> None:
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()

        self._patch_step_function()

@@ -273,6 +283,14 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
            self._optimizer_step_pre_hooks = OrderedDict()
        if '_optimizer_step_post_hooks' not in self.__dict__:
            self._optimizer_step_post_hooks = OrderedDict()
        if '_optimizer_state_dict_pre_hooks' not in self.__dict__:
            self._optimizer_state_dict_pre_hooks = OrderedDict()
        if '_optimizer_state_dict_post_hooks' not in self.__dict__:
            self._optimizer_state_dict_post_hooks = OrderedDict()
        if '_optimizer_load_state_dict_pre_hooks' not in self.__dict__:
            self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        if '_optimizer_load_state_dict_post_hooks' not in self.__dict__:
            self._optimizer_load_state_dict_post_hooks = OrderedDict()
        self._patch_step_function()  # To support multiprocessing pickle/unpickle
        self.defaults.setdefault('differentiable', False)

@@ -427,8 +445,77 @@ def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
        self._optimizer_step_post_hooks[handle.id] = hook
        return handle


    def register_state_dict_pre_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:
        r"""Register a state dict pre-hook which will be called before
        :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.
        The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
        The registered hook can be used to perform pre-processing before the ``state_dict``
        call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the existing registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
        self._optimizer_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle
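As an illustration of the pre-hook pattern the tests above exercise (the hook name below is hypothetical): anything the hook stashes in optimizer.state shows up under state_dict()["state"]:

import torch
from torch.optim import SGD

opt = SGD([torch.rand(2, 2, requires_grad=True)], lr=0.1)

def stamp_format_version(optimizer):
    # Pre-hooks receive only the optimizer, so extra data is surfaced via its state.
    optimizer.state["format_version"] = 2

handle = opt.register_state_dict_pre_hook(stamp_format_version)
assert opt.state_dict()["state"]["format_version"] == 2
handle.remove()  # the returned RemovableHandle deregisters the hook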


    def register_state_dict_post_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a state dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The hook will be called with arguments ``self`` and ``state_dict`` after generating
        a ``state_dict`` on ``self``. The hook may modify the state_dict in place or optionally
        return a new one. The registered hook can be used to perform post-processing
        on the ``state_dict`` before it is returned.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the existing registered post-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the existing registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
        self._optimizer_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
        return handle

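A sketch of the in-place variant (hook name hypothetical): because state_dict() only swaps in a post-hook's result when it is not None, editing the dict and returning nothing keeps the edits:

import torch
from torch.optim import SGD

opt = SGD([torch.rand(2, 2, requires_grad=True)], lr=0.1)

def attach_lr_metadata(optimizer, state_dict):
    # In-place edit; implicitly returning None keeps this same dict object.
    state_dict["lr_at_save_time"] = optimizer.param_groups[0]["lr"]

opt.register_state_dict_post_hook(attach_lr_metadata)
assert opt.state_dict()["lr_at_save_time"] == 0.1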

    @torch._disable_dynamo
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> StateDict:
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:
@@ -438,6 +525,10 @@ def state_dict(self) -> Dict[str, Any]:
        * param_groups - a list containing all parameter groups where each
          parameter group is a dict
        """

        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
            pre_hook(self)

        # Save order indices instead of Tensors
        param_mappings: Dict[int, int] = {}
        start_index = 0
Expand All @@ -454,11 +545,18 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
-        return {
+        state_dict = {
            'state': packed_state,
            'param_groups': param_groups,
        }

        for post_hook in self._optimizer_state_dict_post_hooks.values():
            hook_result = post_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result
        return state_dict
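Since post-hooks run in registration order (each seeing the dict produced so far) and prepend only moves a hook to the front of that order, the last-registered non-prepended hook has the final word. A sketch with hypothetical tagging hooks:

import torch
from torch.optim import SGD

opt = SGD([torch.rand(2, 2, requires_grad=True)], lr=0.1)

def tag_a(optimizer, state_dict):
    state_dict["tag"] = "a"

def tag_b(optimizer, state_dict):
    state_dict["tag"] = "b"

opt.register_state_dict_post_hook(tag_a)
opt.register_state_dict_post_hook(tag_b)
assert opt.state_dict()["tag"] == "b"   # tag_b runs last and wins

opt.register_state_dict_post_hook(tag_a, prepend=True)
assert opt.state_dict()["tag"] == "b"   # the prepended tag_a runs first, so tag_b still wins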

    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
@@ -491,8 +589,84 @@ def _process_value_according_to_param_policy(
        else:
            return value.to(device=param.device)


    def register_load_state_dict_pre_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a load_state_dict pre-hook which will be called before
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The ``optimizer`` argument is the optimizer instance being used and the
        ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
        passed in to ``load_state_dict``. The hook may modify the state_dict in place
        or optionally return a new one. If a state_dict is returned, it is the one that
        will be loaded into the optimizer.

        The hook will be called with arguments ``self`` and ``state_dict`` before
        calling ``load_state_dict`` on ``self``. The registered hook can be used to
        perform pre-processing before the ``load_state_dict`` call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the existing registered pre-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the existing registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
        self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

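A sketch of the return-a-replacement pattern, mirroring _load_state_dict_pre_hook2 in the tests above (hook name hypothetical): the dict returned by the hook, not the one passed in, is what ends up being loaded:

from copy import deepcopy

import torch
from torch.optim import SGD

opt = SGD([torch.rand(2, 2, requires_grad=True)], lr=0.1)

def swap_in_modified_copy(optimizer, state_dict):
    # Leave the incoming dict untouched and hand back a modified copy instead.
    replacement = deepcopy(state_dict)
    replacement["param_groups"][0]["lr"] = 0.003
    return replacement

opt.register_load_state_dict_pre_hook(swap_in_modified_copy)
opt.load_state_dict(opt.state_dict())
assert opt.param_groups[0]["lr"] == 0.003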

    def register_load_state_dict_post_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:
        r"""Register a load_state_dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        The hook will be called with argument ``self`` after calling
        ``load_state_dict`` on ``self``. The registered hook can be used to
        perform post-processing after ``load_state_dict`` has loaded the
        ``state_dict``.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the existing registered post-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the existing registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
        self._optimizer_load_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

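A sketch of a load post-hook (hook name hypothetical), following the same pattern as _load_state_dict_post_hook in the tests above: the hook sees only the optimizer, after the new state has been applied:

import torch
from torch.optim import SGD

opt = SGD([torch.rand(2, 2, requires_grad=True)], lr=0.1)

def note_reload(optimizer):
    # Record in the optimizer's state that a checkpoint was applied.
    optimizer.state["reloaded"] = True

opt.register_load_state_dict_post_hook(note_reload)
opt.load_state_dict(opt.state_dict())
assert opt.state["reloaded"] is True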

    @torch._disable_dynamo
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Loads the optimizer state.

        Args:
@@ -501,6 +675,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)

for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
hook_result = pre_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result

# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
@@ -548,6 +728,10 @@ def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str,
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
            post_hook(self)


    @torch._disable_dynamo
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
