[Feature] Enable parameter reset in loss module #2017
base: main
@@ -9,9 +9,10 @@
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Tuple
+from typing import Callable, Iterator, List, Optional, Tuple

 import torch
+import torch.nn.init as init
 from tensordict import is_tensor_collection, TensorDict, TensorDictBase

 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
@@ -363,6 +364,35 @@ def reset(self) -> None:
         # mainly used for PPO with KL target
         pass

+    def reset_parameters(
+        self,
+        module_names: Optional[List[str]] = None,
+        init_func: Optional[Callable] = None,
+    ) -> None:
"""Reset the parameters of the specified modules. | ||||||
Args: | ||||||
module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
If None, all modules with names ending in "_params" will be reset. | ||||||
init_func (Optional[Callable]): A function to initialize the parameters. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
If None, the parameters will be initialized with uniform random values between -1 and 1. | ||||||
Review comment: Seems very unlikely that anyone would want to use that init IMO. Shouldn't we use the init method from the corresponding nn.Module if there is one?
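The reviewer's suggestion — prefer each submodule's own initializer over a blanket uniform init — could be sketched roughly as below. The helper name and fallback behavior are hypothetical, not torchrl API; the sketch only relies on the convention that modules such as nn.Linear expose a native reset_parameters method:

```python
import torch
import torch.nn as nn


def reset_with_native_init(module: nn.Module, fallback=None) -> None:
    """Reset parameters, preferring each submodule's own initializer."""
    for m in module.modules():
        if hasattr(m, "reset_parameters"):
            # Use the module's native init (e.g. nn.Linear's default reset).
            m.reset_parameters()
        elif fallback is not None:
            # No native initializer: fall back to a user-supplied function.
            for p in m.parameters(recurse=False):
                fallback(p.data)


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
before = net[0].weight.clone()
reset_with_native_init(net)
changed = not torch.equal(before, net[0].weight)
```

This keeps each layer's initialization consistent with how it was first constructed, rather than imposing a single distribution on every parameter.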
+        """
+        if module_names is None:
+            params_2_reset = [
+                name for name in self._modules.keys() if name.endswith("_params")
+            ]
+        else:
+            params_2_reset = [name + "_params" for name in module_names]
+
+        def _reset_params(param):
Review comment: Having one single reset function will be hard to handle; we need a way to tie the reset function to the module.
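The point about tying a reset function to its module could look something like the registry sketch below. All class and method names here are illustrative, not part of torchrl — the idea is simply that each module's reset behavior is bound at registration time instead of living in one global function:

```python
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn


class PerModuleReset:
    """Illustrative registry tying a reset function to each named module."""

    def __init__(self) -> None:
        self._modules: Dict[str, nn.Module] = {}
        self._reset_fns: Dict[str, Callable[[nn.Module], None]] = {}

    def register(self, name: str, module: nn.Module,
                 reset_fn: Optional[Callable[[nn.Module], None]] = None) -> None:
        # Bind the reset function to the module at registration time,
        # so there is no single global reset to maintain.
        self._modules[name] = module
        if reset_fn is not None:
            self._reset_fns[name] = reset_fn

    def reset_parameters(self, names: Optional[List[str]] = None) -> None:
        for name in names if names is not None else list(self._modules):
            module = self._modules[name]
            fn = self._reset_fns.get(name)
            if fn is not None:
                fn(module)  # module-specific reset
            else:
                for p in module.parameters():
                    nn.init.uniform_(p, -1.0, 1.0)  # fallback default


registry = PerModuleReset()
registry.register("actor", nn.Linear(3, 3),
                  reset_fn=lambda m: nn.init.ones_(m.weight))
registry.reset_parameters(["actor"])
```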
+            if init_func is not None:
+                init_func(param.data)
+            else:
+                init.uniform_(param.data, -1, 1)
+
+        for name in params_2_reset:
+            getattr(self, name).apply(_reset_params)
+
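As a standalone illustration of how the patch behaves, here is a minimal sketch using a plain nn.Module in place of a torchrl loss module. The class and attribute names are hypothetical, and it iterates parameters() directly rather than applying over TensorDictParams leaves as the patch does:

```python
import torch
import torch.nn as nn
import torch.nn.init as init


class TinyLoss(nn.Module):
    """Toy stand-in for a loss module; '_params' mirrors the naming convention."""

    def __init__(self) -> None:
        super().__init__()
        self.actor_params = nn.Linear(2, 2)
        self.critic_params = nn.Linear(2, 1)

    def reset_parameters(self, module_names=None, init_func=None):
        if module_names is None:
            # Default: every registered submodule named "*_params".
            names = [n for n in self._modules if n.endswith("_params")]
        else:
            names = [n + "_params" for n in module_names]
        for name in names:
            for p in getattr(self, name).parameters():
                if init_func is not None:
                    init_func(p.data)
                else:
                    init.uniform_(p.data, -1, 1)


loss = TinyLoss()
# Reset only the actor's parameters, with a custom constant init.
loss.reset_parameters(module_names=["actor"], init_func=lambda t: t.fill_(0.5))
```

Passing module_names=["actor"] leaves critic_params untouched, matching the patch's name-suffix lookup.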
     @property
     def value_estimator(self) -> ValueEstimatorBase:
         """The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network."""