From 4b29473a32561b6251e01eeb6d50f51adc957690 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 18 Mar 2024 15:53:20 +0100 Subject: [PATCH] init loss_module param reset --- examples/td3/td3.py | 3 ++- torchrl/objectives/common.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/examples/td3/td3.py b/examples/td3/td3.py index ef2edd578cb..6109de11510 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -123,7 +123,8 @@ def main(cfg: "DictConfig"): # noqa: F821 q_losses, ) = ([], []) for _ in range(num_updates): - + if update_counter % 5 == 0: + loss_module.reset_parameters() # Update actor every delayed_updates update_counter += 1 update_actor = update_counter % delayed_updates == 0 diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 6b6fd391560..576427ad88e 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -9,9 +9,10 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple +from typing import Callable, Iterator, List, Optional, Tuple import torch +import torch.nn.init as init from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -363,6 +364,35 @@ def reset(self) -> None: # mainly used for PPO with KL target pass + def reset_parameters( + self, + module_names: Optional[List[Parameter]] = None, + init_func: Optional[Callable] = None, + ) -> None: + """Reset the parameters of the specified modules. + + Args: + module_names (Optional[List[Parameter]]): A list of module names to reset the parameters for. + If None, all modules with names ending in "_params" will be reset. + init_func (Optional[Callable]): A function to initialize the parameters. + If None, the parameters will be initialized with uniform random values between -1 and 1. + """ + if module_names is None: + params_2_reset = [ + name for name in self._modules.keys() if name.endswith("_params") + ] + else: + params_2_reset = [name + "_params" for name in module_names] + + def _reset_params(param): + if init_func is not None: + init_func(param.data) + else: + init.uniform_(param.data, -1, 1) + + for name in params_2_reset: + getattr(self, name).apply(_reset_params) + @property def value_estimator(self) -> ValueEstimatorBase: """The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network."""