diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst
index ad1399ca23e..d6e411dc2f1 100644
--- a/docs/source/reference/trainers.rst
+++ b/docs/source/reference/trainers.rst
@@ -29,9 +29,9 @@ The :obj:`trainer.train()` method can be sketched as follows:
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  # "post_steps_log"
 
-There are 9 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
+There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
 :obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
-:obj:`"post_steps_log"` and :obj:`"post_optim_log"`. They are indicated in the comments where they are applied.
+:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied.
 Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
 **logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hook
 (:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).
@@ -139,6 +139,7 @@ Trainer and hooks
     BatchSubSampler
     CountFramesLog
     LogReward
+    OptimizerHook
     Recorder
     ReplayBuffer
     RewardNormalizer
diff --git a/test/test_trainer.py b/test/test_trainer.py
index 5fbedaa8137..bd0c8a8ea59 100644
--- a/test/test_trainer.py
+++ b/test/test_trainer.py
@@ -40,6 +40,7 @@
     CountFramesLog,
     LogReward,
     mask_batch,
+    OptimizerHook,
     ReplayBufferTrainer,
     RewardNormalizer,
     SelectKeys,
@@ -82,7 +83,10 @@ class MockingLossModule(nn.Module):
     pass
 
 
-def mocking_trainer(file=None) -> Trainer:
+_mocking_optim = MockingOptim()
+
+
+def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer:
     trainer = Trainer(
         MockingCollector(),
         *[
@@ -90,7 +94,7 @@ def mocking_trainer(file=None) -> Trainer:
         ]
         * 2,
         loss_module=MockingLossModule(),
-        optimizer=MockingOptim(),
+        optimizer=optimizer,
         save_trainer_file=file,
     )
     trainer._pbar_str = OrderedDict()
@@ -472,6 +476,159 @@ def make_storage():
         TensorDict.load_state_dict = TensorDict_load_state_dict
 
 
+class TestOptimizer:
+    @staticmethod
+    def _setup():
+        torch.manual_seed(0)
+        x = torch.randn(5, 10)
+        model1 = nn.Linear(10, 20)
+        model2 = nn.Linear(10, 20)
+        td = TensorDict(
+            {
+                "loss_1": model1(x).sum(),
+                "loss_2": model2(x).sum(),
+            },
+            batch_size=[],
+        )
+        model1_params = list(model1.parameters())
+        model2_params = list(model2.parameters())
+        all_params = model1_params + model2_params
+        return model1_params, model2_params, all_params, td
+
+    def test_optimizer_set_as_argument(self):
+        _, _, all_params, td = self._setup()
+
+        optimizer = torch.optim.SGD(all_params, lr=1e-3)
+        trainer = mocking_trainer(optimizer=optimizer)
+
+        params_before = [torch.clone(p) for p in all_params]
+        td_out = trainer._optimizer_hook(td)
+        params_after = all_params
+
+        assert "grad_norm_0" in td_out.keys()
+        assert all(
+            not torch.equal(p_before, p_after)
+            for p_before, p_after in zip(params_before, params_after)
+        )
+
+    def test_optimizer_set_as_hook(self):
+        _, _, all_params, td = self._setup()
+
+        optimizer = torch.optim.SGD(all_params, lr=1e-3)
+        trainer = mocking_trainer(optimizer=None)
+        hook = OptimizerHook(optimizer)
+        hook.register(trainer)
+
+        params_before = [torch.clone(p) for p in all_params]
+        td_out = trainer._optimizer_hook(td)
+        params_after = all_params
+
"grad_norm_0" in td_out.keys() + assert all( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(params_before, params_after) + ) + + def test_optimizer_no_optimizer(self): + _, _, all_params, td = self._setup() + + trainer = mocking_trainer(optimizer=None) + + params_before = [torch.clone(p) for p in all_params] + td_out = trainer._optimizer_hook(td) + params_after = all_params + + assert not [key for key in td_out.keys() if key.startswith("grad_norm_")] + assert all( + torch.equal(p_before, p_after) + for p_before, p_after in zip(params_before, params_after) + ) + + def test_optimizer_hook_loss_components_empty(self): + model = nn.Linear(10, 20) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + with pytest.raises(ValueError, match="loss_components list cannot be empty"): + OptimizerHook(optimizer, loss_components=[]) + + def test_optimizer_hook_loss_components_partial(self): + model1_params, model2_params, all_params, td = self._setup() + + optimizer = torch.optim.SGD(all_params, lr=1e-3) + trainer = mocking_trainer(optimizer=None) + hook = OptimizerHook(optimizer, loss_components=["loss_1"]) + hook.register(trainer) + + model1_params_before = [torch.clone(p) for p in model1_params] + model2_params_before = [torch.clone(p) for p in model2_params] + td_out = trainer._optimizer_hook(td) + model1_params_after = model1_params + model2_params_after = model2_params + + assert "grad_norm_0" in td_out.keys() + assert all( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(model1_params_before, model1_params_after) + ) + assert all( + torch.equal(p_before, p_after) + for p_before, p_after in zip(model2_params_before, model2_params_after) + ) + + def test_optimizer_hook_loss_components_none(self): + model1_params, model2_params, all_params, td = self._setup() + + optimizer = torch.optim.SGD(all_params, lr=1e-3) + trainer = mocking_trainer(optimizer=None) + hook = OptimizerHook(optimizer, loss_components=None) + hook.register(trainer) + + model1_params_before = [torch.clone(p) for p in model1_params] + model2_params_before = [torch.clone(p) for p in model2_params] + td_out = trainer._optimizer_hook(td) + model1_params_after = model1_params + model2_params_after = model2_params + + assert "grad_norm_0" in td_out.keys() + assert all( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(model1_params_before, model1_params_after) + ) + assert all( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(model2_params_before, model2_params_after) + ) + + def test_optimizer_multiple_hooks(self): + model1_params, model2_params, _, td = self._setup() + + trainer = mocking_trainer(optimizer=None) + + optimizer1 = torch.optim.SGD(model1_params, lr=1e-3) + hook1 = OptimizerHook(optimizer1, loss_components=["loss_1"]) + hook1.register(trainer, name="optimizer1") + + optimizer2 = torch.optim.Adam(model2_params, lr=1e-4) + hook2 = OptimizerHook(optimizer2, loss_components=["loss_2"]) + hook2.register(trainer, name="optimizer2") + + model1_params_before = [torch.clone(p) for p in model1_params] + model2_params_before = [torch.clone(p) for p in model2_params] + td_out = trainer._optimizer_hook(td) + model1_params_after = model1_params + model2_params_after = model2_params + + assert "grad_norm_0" in td_out.keys() + assert "grad_norm_1" in td_out.keys() + assert all( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(model1_params_before, model1_params_after) + ) + assert all( + not torch.equal(p_before, p_after) + for p_before, 
+            for p_before, p_after in zip(model2_params_before, model2_params_after)
+        )
+
+
 class TestLogReward:
     @pytest.mark.parametrize("logname", ["a", "b"])
     @pytest.mark.parametrize("pbar", [True, False])
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 1909df1370a..d45c0305298 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -155,9 +155,6 @@ def __init__(
         self.loss_module = loss_module
         self.optimizer = optimizer
         self.logger = logger
-        self._params = []
-        for p in self.optimizer.param_groups:
-            self._params += p["params"]
 
         # seeding
         self.seed = seed
@@ -186,10 +183,15 @@ def __init__(
         self._post_optim_log_ops = []
         self._pre_optim_ops = []
         self._post_loss_ops = []
+        self._optimizer_ops = []
         self._process_optim_batch_ops = []
         self._post_optim_ops = []
         self._modules = {}
 
+        if self.optimizer is not None:
+            optimizer_hook = OptimizerHook(self.optimizer)
+            optimizer_hook.register(self)
+
     def register_module(self, module_name: str, module: Any) -> None:
         if module_name in self._modules:
             raise RuntimeError(
@@ -317,6 +319,12 @@ def register_op(self, dest: str, op: Callable, **kwargs) -> None:
             )
             self._post_loss_ops.append((op, kwargs))
 
+        elif dest == "optimizer":
+            _check_input_output_typehint(
+                op, input=[TensorDictBase, bool, float, int], output=TensorDictBase
+            )
+            self._optimizer_ops.append((op, kwargs))
+
         elif dest == "post_steps":
             _check_input_output_typehint(op, input=None, output=None)
             self._post_steps_ops.append((op, kwargs))
@@ -386,6 +394,13 @@ def _post_loss_hook(self, batch):
                 batch = out
         return batch
 
+    def _optimizer_hook(self, batch):
+        for i, (op, kwargs) in enumerate(self._optimizer_ops):
+            out = op(batch, self.clip_grad_norm, self.clip_norm, i, **kwargs)
+            if isinstance(out, TensorDictBase):
+                batch = out
+        return batch.detach()
+
     def _post_optim_hook(self):
         for op, kwargs in self._post_optim_ops:
             op(**kwargs)
@@ -440,16 +455,6 @@ def shutdown(self):
         print("shutting down collector")
         self.collector.shutdown()
 
-    def _optimizer_step(self, losses_td: TensorDictBase) -> TensorDictBase:
-        # sum all keys that start with 'loss_'
-        loss = sum([item for key, item in losses_td.items() if key.startswith("loss")])
-        loss.backward()
-
-        grad_norm = self._grad_clip()
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        return losses_td.detach().set("grad_norm", grad_norm)
-
     def optim_steps(self, batch: TensorDictBase) -> None:
         average_losses = None
 
@@ -462,7 +467,7 @@ def optim_steps(self, batch: TensorDictBase) -> None:
             losses_td = self.loss_module(sub_batch)
             self._post_loss_hook(sub_batch)
 
-            losses_detached = self._optimizer_step(losses_td)
+            losses_detached = self._optimizer_hook(losses_td)
             self._post_optim_hook()
             self._post_optim_log(sub_batch)
 
@@ -480,16 +485,6 @@ def optim_steps(self, batch: TensorDictBase) -> None:
                 **average_losses,
             )
 
-    def _grad_clip(self) -> float:
-        if self.clip_grad_norm:
-            gn = nn.utils.clip_grad_norm_(self._params, self.clip_norm)
-        else:
-            gn = sum(
-                [p.grad.pow(2).sum() for p in self._params if p.grad is not None]
-            ).sqrt()
-            nn.utils.clip_grad_value_(self._params, self.clip_norm)
-        return float(gn)
-
     def _log(self, log_pbar=False, **kwargs) -> None:
         collected_frames = self.collected_frames
         for key, item in kwargs.items():
@@ -696,6 +691,85 @@ def register(self, trainer: Trainer, name: str = "replay_buffer"):
         trainer.register_module(name, self)
 
 
+class OptimizerHook(TrainerHookBase):
+    """Add an optimizer for one or more loss components.
+
+    Args:
+        optimizer (optim.Optimizer): An optimizer to apply to the loss_components.
+        loss_components (Sequence[str], optional): The keys in the loss TensorDict
+            whose values the optimizer should be applied to. If omitted, the
+            optimizer is applied to all components whose names start with
+            `loss_`.
+
+    Examples:
+        >>> optimizer_hook = OptimizerHook(optimizer, ["loss_actor"])
+        >>> trainer.register_op("optimizer", optimizer_hook)
+
+    """
+
+    def __init__(
+        self,
+        optimizer: optim.Optimizer,
+        loss_components: Optional[Sequence[str]] = None,
+    ):
+        if loss_components is not None and not loss_components:
+            raise ValueError(
+                "loss_components list cannot be empty. "
+                "Set to None to act on all components of the loss."
+            )
+
+        self.optimizer = optimizer
+        self.loss_components = loss_components
+        if self.loss_components is not None:
+            self.loss_components = set(self.loss_components)
+
+    def _grad_clip(self, clip_grad_norm: bool, clip_norm: float) -> float:
+        params = []
+        for param_group in self.optimizer.param_groups:
+            params += param_group["params"]
+
+        if clip_grad_norm:
+            gn = nn.utils.clip_grad_norm_(params, clip_norm)
+        else:
+            gn = sum([p.grad.pow(2).sum() for p in params if p.grad is not None]).sqrt()
+            nn.utils.clip_grad_value_(params, clip_norm)
+
+        return float(gn)
+
+    def __call__(
+        self,
+        losses_td: TensorDictBase,
+        clip_grad_norm: bool,
+        clip_norm: float,
+        index: int,
+    ) -> TensorDictBase:
+        loss_components = (
+            [item for key, item in losses_td.items() if key in self.loss_components]
+            if self.loss_components is not None
+            else [item for key, item in losses_td.items() if key.startswith("loss")]
+        )
+        loss = sum(loss_components)
+        loss.backward()
+
+        grad_norm = self._grad_clip(clip_grad_norm, clip_norm)
+        losses_td[f"grad_norm_{index}"] = torch.tensor(grad_norm)
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return losses_td
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        pass
+
+    def register(self, trainer, name="optimizer") -> None:
+        trainer.register_op("optimizer", self)
+        trainer.register_module(name, self)
+
+
 class ClearCudaCache(TrainerHookBase):
     """Clears cuda cache at a given interval.
     Args:
@@ -1208,7 +1282,9 @@ def load_state_dict(self, state_dict) -> None:
         self.frame_count = state_dict["frame_count"]
 
 
-def _check_input_output_typehint(func: Callable, input: Type, output: Type):
+def _check_input_output_typehint(
+    func: Callable, input: Type | List[Type], output: Type
+):
     # Placeholder for a function that checks the types input / output against expectations
     return
 
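Usage sketch (not part of the patch): with the changes above, a `Trainer` built with `optimizer=None` no longer auto-registers a default optimizer hook, so separate optimizers can be attached per loss component. The snippet below is a minimal illustration; it assumes a pre-built `trainer` whose loss module returns `loss_actor` and `loss_value` entries, plus hypothetical `actor_network` / `value_network` modules. Only `OptimizerHook` and its `register` method come from this patch.

    >>> import torch
    >>> from torchrl.trainers.trainers import OptimizerHook
    >>> # one optimizer per loss component; hook number i writes "grad_norm_<i>" into the loss TensorDict
    >>> actor_hook = OptimizerHook(
    ...     torch.optim.Adam(actor_network.parameters(), lr=3e-4),
    ...     loss_components=["loss_actor"],
    ... )
    >>> actor_hook.register(trainer, name="optimizer_actor")
    >>> value_hook = OptimizerHook(
    ...     torch.optim.Adam(value_network.parameters(), lr=1e-3),
    ...     loss_components=["loss_value"],
    ... )
    >>> value_hook.register(trainer, name="optimizer_value")
    >>> trainer.train()  # each optim step now runs both hooks in registration order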