[Feature] Enable parameter reset in loss module #2017

Status: Draft. Wants to merge 1 commit into main.
examples/td3/td3.py (3 changes: 2 additions & 1 deletion)

@@ -123,7 +123,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
             q_losses,
         ) = ([], [])
         for _ in range(num_updates):
-
+            if update_counter % 5 == 0:
+                loss_module.reset_parameters()
             # Update actor every delayed_updates
             update_counter += 1
             update_actor = update_counter % delayed_updates == 0
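In isolation, the schedule this hunk introduces looks like the minimal, runnable sketch below (a plain nn.Linear stands in for the TD3 loss module; the interval of 5 mirrors the value hardcoded in this draft).

import torch.nn as nn

net = nn.Linear(4, 1)  # stand-in for loss_module
num_updates, update_counter = 20, 0
for _ in range(num_updates):
    # Re-initialize every 5th update, as the diff above does.
    if update_counter % 5 == 0:
        net.reset_parameters()  # stand-in for loss_module.reset_parameters()
    update_counter += 1
    # ... the actual TD3 losses and optimizer steps follow here ...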
torchrl/objectives/common.py (32 changes: 31 additions & 1 deletion)

@@ -9,9 +9,10 @@
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Tuple
+from typing import Callable, Iterator, List, Optional, Tuple

 import torch
+import torch.nn.init as init
 from tensordict import is_tensor_collection, TensorDict, TensorDictBase

 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
@@ -363,6 +364,35 @@ def reset(self) -> None:
         # mainly used for PPO with KL target
         pass

+    def reset_parameters(
+        self,
+        module_names: Optional[List[Parameter]] = None,
Contributor suggested change:
-        module_names: Optional[List[Parameter]] = None,
+        module_names: List[Parameter] | None = None,
+        init_func: Optional[Callable] = None,
Contributor suggested change:
-        init_func: Optional[Callable] = None,
+        init_func: Callable[[torch.Tensor], None] | None = None,

+    ) -> None:
+        """Reset the parameters of the specified modules.
+        Args:
+            module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
Contributor suggested change:
-            module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
+            module_names (list of nn.Parameter, optional): A list of module names to reset the parameters for.

+                If None, all modules with names ending in "_params" will be reset.
+            init_func (Optional[Callable]): A function to initialize the parameters.
Contributor suggested change:
-            init_func (Optional[Callable]): A function to initialize the parameters.
+            init_func (Callable[[torch.Tensor], None]): A function to initialize the parameters.

+                If None, the parameters will be initialized with uniform random values between -1 and 1.
Contributor:
Seems very unlikely that anyone would want to use that init IMO. Shouldn't we use the init method from the corresponding nn.Module if there is?
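One possible reading of that suggestion, sketched outside this PR (illustrative only, not the author's implementation): recurse into the network and re-run each built-in layer's own default initializer where one exists.

import torch.nn as nn

def _reset_with_module_defaults(module: nn.Module) -> None:
    # Built-in layers such as nn.Linear and nn.Conv2d define
    # reset_parameters(); layers without one (e.g. nn.ReLU) are skipped.
    for m in module.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
_reset_with_module_defaults(net)  # re-runs each layer's default init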

"""
if module_names is None:
params_2_reset = [
name for name in self._modules.keys() if name.endswith("_params")
]
else:
params_2_reset = [name + "_params" for name in module_names]

+        def _reset_params(param):
Contributor:
Having one single reset function will be hard to handle, we need a way to tie the reset function and the module.
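A hypothetical shape for what this comment asks for (not part of the PR; assumes each "*_params" entry behaves like an nn.Module exposing .parameters()): accept a mapping from module name to init function instead of one shared callable.

import torch.nn as nn
import torch.nn.init as init
from typing import Callable, Dict

def reset_parameters_per_module(
    loss_module: nn.Module,
    init_funcs: Dict[str, Callable],
) -> None:
    # Hypothetical variant: look up an init function per module name,
    # falling back to the draft's uniform init when none is registered.
    for name, child in loss_module.named_children():
        if not name.endswith("_params"):
            continue
        fn = init_funcs.get(name[: -len("_params")],
                            lambda t: init.uniform_(t, -1, 1))
        for p in child.parameters():
            fn(p.data)

# Usage, e.g. a custom init for the actor only:
# reset_parameters_per_module(loss, {"actor_network": lambda t: init.normal_(t, std=0.01)})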

+            if init_func is not None:
+                init_func(param.data)
+            else:
+                init.uniform_(param.data, -1, 1)
+
+        for name in params_2_reset:
+            getattr(self, name).apply(_reset_params)
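For readers outside TorchRL, the selection-and-reset logic above can be reproduced standalone. In the sketch below, the SimpleLoss class and its two "*_params" children are hypothetical; only the "_params" suffix convention and the uniform fallback mirror the diff.

import torch.nn as nn
import torch.nn.init as init
from typing import Callable, List, Optional

class SimpleLoss(nn.Module):
    """Toy stand-in for a loss module holding '<name>_params' submodules."""

    def __init__(self):
        super().__init__()
        self.actor_network_params = nn.Linear(4, 4)
        self.qvalue_network_params = nn.Linear(4, 1)

    def reset_parameters(
        self,
        module_names: Optional[List[str]] = None,
        init_func: Optional[Callable] = None,
    ) -> None:
        # Same selection rule as the diff: default to every submodule
        # whose name ends in "_params".
        if module_names is None:
            names = [n for n in self._modules if n.endswith("_params")]
        else:
            names = [n + "_params" for n in module_names]
        for n in names:
            for p in getattr(self, n).parameters():
                if init_func is not None:
                    init_func(p.data)
                else:
                    init.uniform_(p.data, -1, 1)  # the draft's fallback

loss = SimpleLoss()
loss.reset_parameters(["actor_network"], lambda t: init.normal_(t, std=0.01))
loss.reset_parameters()  # uniform fallback on every "*_params" child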

     @property
     def value_estimator(self) -> ValueEstimatorBase:
         """The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network."""