[Feature] Enable parameter reset in loss module #2017
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2017.
Note: links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 1 Unrelated Failure as of commit 4b29473 with merge base 87f3437.
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was also present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for this!
We'll need tests for the feature.
How do we handle the target parameters?
Wouldn't something like this be a bit more robust?
```python
from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda module: module.reset_parameters() if hasattr(module, "reset_parameters") else None)
```
```python
            module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
                If None, all modules with names ending in "_params" will be reset.
            init_func (Optional[Callable]): A function to initialize the parameters.
                If None, the parameters will be initialized with uniform random values between -1 and 1.
```
Seems very unlikely that anyone would want to use that init, IMO. Shouldn't we use the init method from the corresponding nn.Module if there is one?
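As an illustration of that idea (a sketch only; the helper name and fallback behaviour below are assumptions, not code from this PR), the reset could defer to each submodule's own `reset_parameters()` and only fall back to a user-provided `init_func`:

```python
from typing import Callable, Optional
import torch
from torch import nn

def _reset_module(module: nn.Module, init_func: Optional[Callable[[torch.Tensor], None]] = None) -> None:
    # Hypothetical helper: prefer the module's own initialization when it defines one ...
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
    elif init_func is not None:
        # ... and only fall back to the user-provided initializer on this module's own parameters.
        for param in module.parameters(recurse=False):
            init_func(param.data)

# Usage sketch, assuming `loss` is the DQNLoss from the example above:
# with loss.value_network_params.to_module(loss.value_network):
#     loss.value_network.apply(_reset_module)
```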
```python
    def reset_parameters(
        self,
        module_names: Optional[List[Parameter]] = None,
        init_func: Optional[Callable] = None,
```
Suggested change:

```diff
-        init_func: Optional[Callable] = None,
+        init_func: Callable[[torch.Tensor], None] | None = None,
```
```diff
@@ -363,6 +364,35 @@ def reset(self) -> None:
         # mainly used for PPO with KL target
         pass

+    def reset_parameters(
+        self,
+        module_names: Optional[List[Parameter]] = None,
```
Suggested change:

```diff
-        module_names: Optional[List[Parameter]] = None,
+        module_names: List[Parameter] | None = None,
```
"""Reset the parameters of the specified modules. | ||
|
||
Args: | ||
module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for. |
Suggested change:

```diff
-            module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
+            module_names (list of nn.Parameter, optional): A list of module names to reset the parameters for.
```
```python
        Args:
            module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for.
                If None, all modules with names ending in "_params" will be reset.
            init_func (Optional[Callable]): A function to initialize the parameters.
```
Suggested change:

```diff
-            init_func (Optional[Callable]): A function to initialize the parameters.
+            init_func (Callable[[torch.Tensor], None]): A function to initialize the parameters.
```
```python
        else:
            params_2_reset = [name + "_params" for name in module_names]

        def _reset_params(param):
```
Having one single reset function will be hard to handle; we need a way to tie the reset function to the module.
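One possible way to tie them together (purely a sketch; the mapping-based approach below is an assumption, not what this PR implements) is to key a reset callable by network name, so each sub-network gets its own reset behaviour:

```python
from typing import Callable, Dict
from torch import nn

# Hypothetical: one reset callable per sub-network, keyed by the module name.
init_funcs: Dict[str, Callable[[nn.Module], None]] = {
    "value_network": lambda net: net.apply(
        lambda m: m.reset_parameters() if hasattr(m, "reset_parameters") else None
    ),
}

# Assuming `loss` is the DQNLoss from the example above:
for name, reset_fn in init_funcs.items():
    network = getattr(loss, name)
    params = getattr(loss, f"{name}_params")
    # Temporarily re-populate the module with the functional params, then reset in place.
    with params.to_module(network):
        reset_fn(network)
```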
I like the solution! But we are accessing the parameters directly in your example, so we would need to define a reset function manually, which I think is perfectly fine because then the user has to decide how to reset weights and biases:

```python
def reset_parameters(params):
    """User-specified resetting function depending on their needs for initialization."""
    if len(params.shape) > 1:
        # weights
        nn.init.xavier_uniform_(params)
    elif len(params.shape) == 1:
        # biases
        nn.init.zeros_(params)
    else:
        raise ValueError("Unknown parameter shape: {}".format(params.shape))

with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda x: reset_parameters(x.data) if hasattr(x, "data") else None)
```

And for handling the target_network_params I think we could simply do something like:

```python
loss.target_value_network_params.update(loss.value_network_params)
```

What do you think? I think we can close the draft, but we might want to mention the way to reset parameters somewhere in the docs.
This won't work because the target params are locked (you can't update them). They're locked because we want to avoid this kind of operation :)

```python
loss.target_value_network_params.apply(
    lambda dest, src: dest.data.copy_(src), loss.value_network_params
)
```
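For context (a hedged aside, not part of this thread): the usual way to refresh target parameters in torchrl is through a target-network updater rather than writing into the locked target params. A minimal sketch, assuming the `SoftUpdate` API behaves as in current torchrl releases:

```python
from torchrl.objectives import SoftUpdate

# Assuming `loss` is the DQNLoss built earlier in this thread.
updater = SoftUpdate(loss, eps=0.995)  # eps close to 1 gives slow (polyak) target updates

# ... after each optimizer step in the training loop:
updater.step()  # moves the target params toward the online params
```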
Regarding the `reset_parameters` snippet above: unfortunately this isn't very generic. Not all modules are "weights" and "biases", and "biases" can be 2d (my point is: the dimension is a very indirect determinator of the tensor's role in a model). The way I usually see this work is to use the module. Maybe we could allow the user to pass a reset function, but in that case we don't even need to re-populate the module (we can just do

```python
def reset(name, tensor):
    if name == "bias":
        tensor.data.zero_()
    if name == "weight":
        nn.init.xavier_uniform_(tensor)

tensordict.apply(reset, named=True)
```

which is more straightforward IMO).
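To make that concrete, here is a minimal sketch of applying such a named reset function directly to the loss parameters, assuming `loss` is the DQNLoss built earlier in this thread; iterating over the params TensorDict explicitly avoids committing to a particular `apply(..., named=...)` signature:

```python
from torch import nn

def reset(name, tensor):
    # `name` is the last element of the TensorDict key for this leaf tensor.
    if name == "bias":
        tensor.data.zero_()
    elif name == "weight" and tensor.dim() >= 2:
        nn.init.xavier_uniform_(tensor.data)

# Walk all leaves of the (possibly nested) params TensorDict and reset them in place.
for key, tensor in loss.value_network_params.items(include_nested=True, leaves_only=True):
    leaf_name = key[-1] if isinstance(key, tuple) else key
    reset(leaf_name, tensor)
```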
Description
Allows resetting the parameters in the loss module.
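Based on the `reset_parameters` signature and docstring shown in the review above, usage would presumably look like the following sketch (the exact names and defaults depend on the final implementation):

```python
from torch import nn
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

# Reset every parameter group whose name ends in "_params" with the default initializer ...
loss.reset_parameters()

# ... or reset only the value network with a custom initializer.
loss.reset_parameters(
    module_names=["value_network"],
    init_func=lambda t: nn.init.uniform_(t, -0.1, 0.1),
)
```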