Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Optimizer state_dict hooks #105953

Closed
wants to merge 9 commits into from
93 changes: 93 additions & 0 deletions test/optim/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,99 @@ def test_fused_optimizer_raises(self):
with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)

@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
optimizer.state["test"] = 1

@staticmethod
def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
if "test" in state_dict["state"]:
state_dict["state"].pop("test")
state_dict["ran_state_dict_pre_hook"] = True
else:
state_dict["ran_state_dict_pre_hook"] = False
return state_dict

@staticmethod
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
state_dict["param_groups"][0]["lr"] = 0.002

@staticmethod
def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
# The typical use case for returning a state dict is to drastically modify the state dict.
# I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
my_state_dict = deepcopy(state_dict)
my_state_dict["param_groups"][0]["lr"] = 0.003
return my_state_dict

@staticmethod
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
optimizer.state["ran_load_state_dict_post_hook"] = True

def test_state_dict_pre_hook(self):
    """The marker planted by a ``state_dict`` pre-hook appears in the returned state dict."""
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)
    optimizer.register_state_dict_pre_hook(self._state_dict_pre_hook)
    serialized = optimizer.state_dict()
    # The pre-hook writes optimizer.state["test"] = 1 before serialization.
    self.assertEqual(serialized["state"]["test"], 1)

def test_state_dict_post_hook(self):
    """With only the post-hook registered, it reports that no pre-hook marker was found."""
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)
    optimizer.register_state_dict_post_hook(self._state_dict_post_hook)
    serialized = optimizer.state_dict()
    # No pre-hook ran, so the post-hook takes its "marker absent" branch.
    self.assertEqual(serialized["ran_state_dict_pre_hook"], False)

def test_state_dict_pre_post_hook(self):
    """Pre- and post-hooks together: the marker is planted, detected, and stripped again."""
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)
    optimizer.register_state_dict_pre_hook(self._state_dict_pre_hook)
    optimizer.register_state_dict_post_hook(self._state_dict_post_hook)
    serialized = optimizer.state_dict()
    # The post-hook removes the marker from the result ...
    self.assertFalse("test" in serialized["state"])
    # ... and records that it had been there.
    self.assertEqual(serialized["ran_state_dict_pre_hook"], True)

def test_load_state_dict_pre_hook_and_prepend(self):
    """``load_state_dict`` pre-hooks can rewrite the lr, and ``prepend=True`` runs a hook first."""
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)
    saved = optimizer.state_dict()

    # Normally the state dict would be loaded into a fresh optimizer instance;
    # reusing the same one makes no difference for this test.
    optimizer.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
    optimizer.load_state_dict(saved)
    self.assertEqual(optimizer.param_groups[0]["lr"], 0.002)

    optimizer.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
    optimizer.load_state_dict(saved)
    # hook2 is prepended, so it fires first (lr -> 0.003 on its copy) and hook1 then
    # overrides that with 0.002. Without prepend the final lr would be 0.003.
    self.assertEqual(optimizer.param_groups[0]["lr"], 0.002)

def test_load_state_dict_post_hook(self):
    """A ``load_state_dict`` post-hook runs after loading; no pre-hook has changed the lr."""
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)

    optimizer.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
    snapshot = optimizer.state_dict()
    optimizer.load_state_dict(snapshot)
    # lr is still 0.001, so the post-hook records that pre-hook2 did not run.
    self.assertFalse(optimizer.state["ran_load_state_dict_pre_hook2"])
    self.assertTrue(optimizer.state["ran_load_state_dict_post_hook"])

def test_load_state_dict_pre_post_hook(self):
    """Pre- and post-hooks on ``load_state_dict``: the post-hook observes the pre-hook's lr."""
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)

    optimizer.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
    optimizer.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
    snapshot = optimizer.state_dict()
    optimizer.load_state_dict(snapshot)
    # pre-hook2 returned a state dict with lr == 0.003, which the post-hook detects.
    self.assertTrue(optimizer.state["ran_load_state_dict_pre_hook2"])
    self.assertTrue(optimizer.state["ran_load_state_dict_post_hook"])


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
# Ignored is the list of values in `opt_differentiable_state`, we do this
Expand Down
191 changes: 188 additions & 3 deletions torch/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Callable,
DefaultDict,
Dict,
OrderedDict,
Hashable,
Iterable,
List,
Expand Down Expand Up @@ -36,6 +37,8 @@

# Positional-argument tuple and keyword-argument dict, as they appear in the
# global optimizer hook signatures below.
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
# String-keyed dict; the type used for optimizer state dicts in the
# state_dict / load_state_dict signatures and their hook callables.
StateDict: TypeAlias = Dict[str, Any]

# Global pre-hook: called with (optimizer, args, kwargs); returns either None
# or an (args, kwargs) pair.
GlobalOptimizerPreHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]]
# Global post-hook: called with (optimizer, args, kwargs); returns nothing.
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

Expand Down Expand Up @@ -229,12 +232,20 @@

# Per-instance hook registries. Each maps a RemovableHandle id to the hook
# registered under that handle; iteration order is the firing order.
_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
# Hooks run by state_dict(): pre-hooks receive the optimizer only; post-hooks
# receive (optimizer, state_dict) and may return a replacement state dict.
_optimizer_state_dict_pre_hooks: OrderedDict[int, Callable[["Optimizer"], None]]
_optimizer_state_dict_post_hooks: OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]
# Hooks run by load_state_dict(): pre-hooks receive (optimizer, state_dict) and
# may return a replacement state dict; post-hooks receive the optimizer only.
_optimizer_load_state_dict_pre_hooks: OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]
_optimizer_load_state_dict_post_hooks: OrderedDict[int, Callable[["Optimizer"], None]]

def __init__(self, params: params_t, defaults: Dict[str, Any]) -> None:
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults
self._optimizer_step_pre_hooks = OrderedDict()
self._optimizer_step_post_hooks = OrderedDict()
self._optimizer_state_dict_pre_hooks = OrderedDict()
self._optimizer_state_dict_post_hooks = OrderedDict()
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
self._optimizer_load_state_dict_post_hooks = OrderedDict()

self._patch_step_function()

Expand Down Expand Up @@ -273,6 +284,14 @@
self._optimizer_step_pre_hooks = OrderedDict()
if '_optimizer_step_post_hooks' not in self.__dict__:
self._optimizer_step_post_hooks = OrderedDict()
if '_optimizer_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_state_dict_pre_hooks = OrderedDict()
if '_optimizer_state_dict_post_hooks' not in self.__dict__:
self._optimizer_state_dict_post_hooks = OrderedDict()
if '_optimizer_load_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
if '_optimizer_load_state_dict_post_hooks' not in self.__dict__:
self._optimizer_load_state_dict_post_hooks = OrderedDict()
self._patch_step_function() # To support multiprocessing pickle/unpickle
self.defaults.setdefault('differentiable', False)

Expand Down Expand Up @@ -427,8 +446,77 @@
self._optimizer_step_post_hooks[handle.id] = hook
return handle


def register_state_dict_pre_hook(
    self, hook: Callable[["Optimizer"], None], prepend: bool = False
) -> RemovableHandle:
    r"""Register a state dict pre-hook which will be called before
    :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
    following signature::

        hook(optimizer) -> None

    The ``optimizer`` argument is the optimizer instance being used.
    The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
    The registered hook can be used to perform pre-processing before the ``state_dict``
    call is made.

    Args:
        hook (Callable): The user defined hook to be registered.
        prepend (bool): If True, the provided pre ``hook`` will be fired before
            all the already registered pre-hooks on ``state_dict``. Otherwise,
            the provided ``hook`` will be fired after all the existing registered
            pre-hooks. (default: False)

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._optimizer_state_dict_pre_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    if prepend:
        # The new entry was appended; move it to the front so it fires first.
        registry.move_to_end(handle.id, last=False)
    return handle


def register_state_dict_post_hook(
    self,
    hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
    prepend: bool = False,
) -> RemovableHandle:
    r"""Register a state dict post-hook which will be called after
    :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
    following signature::

        hook(optimizer, state_dict) -> state_dict or None

    The hook will be called with arguments ``self`` and ``state_dict`` after generating
    a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
    return a new one. The registered hook can be used to perform post-processing
    on the ``state_dict`` before it is returned.

    Args:
        hook (Callable): The user defined hook to be registered.
        prepend (bool): If True, the provided post ``hook`` will be fired before
            all the existing registered post-hooks on ``state_dict``. Otherwise,
            the provided ``hook`` will be fired after all the existing registered
            post-hooks. (default: False)

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._optimizer_state_dict_post_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    if prepend:
        # The new entry was appended; move it to the front so it fires first.
        registry.move_to_end(handle.id, last=False)
    return handle


@torch._disable_dynamo
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> StateDict:
r"""Returns the state of the optimizer as a :class:`dict`.

It contains two entries:
Expand All @@ -438,6 +526,10 @@
* param_groups - a list containing all parameter groups where each
parameter group is a dict
"""

for pre_hook in self._optimizer_state_dict_pre_hooks.values():
pre_hook(self)

# Save order indices instead of Tensors
param_mappings: Dict[int, int] = {}
start_index = 0
Expand All @@ -454,11 +546,18 @@
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()}
return {

state_dict = {
'state': packed_state,
'param_groups': param_groups,
}

for post_hook in self._optimizer_state_dict_post_hooks.values():
hook_result = post_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result
return state_dict

@staticmethod
def _process_value_according_to_param_policy(
param: torch.Tensor,
Expand Down Expand Up @@ -491,8 +590,84 @@
else:
return value.to(device=param.device)


def register_load_state_dict_pre_hook(
    self,
    hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
    prepend: bool = False,
) -> RemovableHandle:
    r"""Register a load_state_dict pre-hook which will be called before
    :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
    following signature::

        hook(optimizer, state_dict) -> None or state_dict

    The ``optimizer`` argument is the optimizer instance being used and the
    ``state_dict`` argument is a copy of the ``state_dict`` the user
    passed in to ``load_state_dict``. The hook may modify the state_dict inplace
    or optionally return a new one. If a state_dict is returned, it will be used
    to be loaded into the optimizer.

    The hook will be called with argument ``self`` and ``state_dict`` before
    calling ``load_state_dict`` on ``self``. The registered hook can be used to
    perform pre-processing before the ``load_state_dict`` call is made.

    Args:
        hook (Callable): The user defined hook to be registered.
        prepend (bool): If True, the provided pre ``hook`` will be fired before
            all the existing registered pre-hooks on ``load_state_dict``. Otherwise,
            the provided ``hook`` will be fired after all the pre-hooks registered
            before. (default: False)

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._optimizer_load_state_dict_pre_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    if prepend:
        # The new entry was appended; move it to the front so it fires first.
        registry.move_to_end(handle.id, last=False)
    return handle


def register_load_state_dict_post_hook(
    self, hook: Callable[["Optimizer"], None], prepend: bool = False
) -> RemovableHandle:
    r"""Register a load_state_dict post-hook which will be called after
    :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
    following signature::

        hook(optimizer) -> None

    The ``optimizer`` argument is the optimizer instance being used.

    The hook will be called with argument ``self`` after calling
    ``load_state_dict`` on ``self``. The registered hook can be used to
    perform post-processing after ``load_state_dict`` has loaded the
    ``state_dict``.

    Args:
        hook (Callable): The user defined hook to be registered.
        prepend (bool): If True, the provided post ``hook`` will be fired before
            all the existing registered post-hooks on ``load_state_dict``. Otherwise,
            the provided ``hook`` will be fired after all the existing registered
            post-hooks. (default: False)

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    registry = self._optimizer_load_state_dict_post_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    if prepend:
        # The new entry was appended; move it to the front so it fires first.
        registry.move_to_end(handle.id, last=False)
    return handle


@torch._disable_dynamo
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: StateDict) -> None:

Check notice on line 670 in torch/optim/optimizer.py

View workflow job for this annotation

GitHub Actions / bc_linter

Function Optimizer.load_state_dict: state_dict changed from Dict[str, Any] to StateDict
r"""Loads the optimizer state.

Args:
Expand All @@ -501,6 +676,12 @@
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)

for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
hook_result = pre_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result

# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
Expand Down Expand Up @@ -548,6 +729,10 @@
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

for post_hook in self._optimizer_load_state_dict_post_hooks.values():
post_hook(self)


@torch._disable_dynamo
def zero_grad(self, set_to_none: bool = True) -> None:
r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
Expand Down