Skip to content

Commit

Permalink
init loss_module param reset
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Mar 18, 2024
1 parent 87f3437 commit 4b29473
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
3 changes: 2 additions & 1 deletion examples/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def main(cfg: "DictConfig"): # noqa: F821
q_losses,
) = ([], [])
for _ in range(num_updates):

if update_counter % 5 == 0:
loss_module.reset_parameters()
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0
Expand Down
32 changes: 31 additions & 1 deletion torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple
from typing import Callable, Iterator, List, Optional, Tuple

import torch
import torch.nn.init as init
from tensordict import is_tensor_collection, TensorDict, TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
Expand Down Expand Up @@ -363,6 +364,35 @@ def reset(self) -> None:
# mainly used for PPO with KL target
pass

def reset_parameters(
    self,
    module_names: Optional[List[str]] = None,
    init_func: Optional[Callable] = None,
) -> None:
    """Re-initialize, in place, the parameters held by the selected modules.

    Args:
        module_names (Optional[List[str]]): Names of the modules to reset,
            given WITHOUT the ``"_params"`` suffix (it is appended here).
            If None, every entry of ``self._modules`` whose name ends in
            ``"_params"`` is reset. An empty list resets nothing.
        init_func (Optional[Callable]): Callable invoked on each parameter's
            ``.data`` tensor. It is expected to modify that tensor in place;
            its return value is ignored. If None, values are drawn uniformly
            between -1 and 1 with ``torch.nn.init.uniform_``.

    Raises:
        AttributeError: If a requested module has no ``"<name>_params"``
            attribute on this object. All names are resolved before any
            parameter is touched, so a bad name leaves everything unchanged.
    """
    if module_names is None:
        params_2_reset = [name for name in self._modules if name.endswith("_params")]
    else:
        params_2_reset = [name + "_params" for name in module_names]

    # Resolve every container up front so that a typo in ``module_names``
    # cannot leave the object half-reset.
    containers = []
    for name in params_2_reset:
        try:
            containers.append(getattr(self, name))
        except AttributeError as err:
            raise AttributeError(
                f"Cannot reset parameters: {type(self).__name__} has no "
                f"attribute '{name}'."
            ) from err

    def _reset_params(param):
        # The initializer is handed ``param.data`` (not ``param``) so the
        # parameter object itself is kept and only its values are overwritten.
        if init_func is not None:
            init_func(param.data)
        else:
            init.uniform_(param.data, -1, 1)

    for container in containers:
        container.apply(_reset_params)

@property
def value_estimator(self) -> ValueEstimatorBase:
"""The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network."""
Expand Down

0 comments on commit 4b29473

Please sign in to comment.