From 20e65b6b2243bbe5b73ba045f92c3fd99f7ff931 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 16 Oct 2025 21:40:34 +0100
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/objectives/llm/grpo.py | 160 +++++++++++++++++++++++++--------
 1 file changed, 122 insertions(+), 38 deletions(-)

diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 9633fd451f6..08d93ae4c58 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import contextlib
+
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Literal
@@ -15,19 +17,19 @@
     TensorClass,
     TensorDict,
     TensorDictBase,
-    TensorDictParams,
 )
 from tensordict.nn import (
+    CompositeDistribution,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictSequential,
-    TensorDictModule,
+    set_composite_lp_aggregate,
 )
 from tensordict.utils import expand_as_right
 from torch import distributions as d
 
-from torchrl._utils import logger as torchrl_logger
+from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs.transforms.transforms import Transform
 from torchrl.modules.llm import LLMWrapperBase
-from torchrl.objectives.ppo import ClipPPOLoss
+from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import _reduce, _sum_td_features
 
@@ -46,7 +48,7 @@ class GRPOLossOutput(TensorClass["nocast"]):
     kl_to_inference: torch.Tensor | None = None
 
 
-class GRPOLoss(ClipPPOLoss):
+class GRPOLoss(LossModule):
     """GRPO loss.
 
     The clipped importance weighted loss is computed as follows:
@@ -116,20 +118,18 @@ class GRPOLoss(ClipPPOLoss):
     """
 
     actor_network: LLMWrapperBase
-    critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams
 
     @dataclass
-    class _AcceptedKeys(ClipPPOLoss._AcceptedKeys):
+    class _AcceptedKeys(LossModule._AcceptedKeys):
         """Maintains default values for all configurable tensordict keys.
 
         This class defines which tensordict keys can be set using
         '.set_keys(key_name=key_value)' and their default values
         """
 
+        advantage: NestedKey = "advantage"
+        action: NestedKey = ("tokens", "full")
+        sample_log_prob: NestedKey = ("log_probs", "full")
         ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")
 
     def __init__(
@@ -149,32 +149,85 @@ def __init__(
         masking_strategy: Literal["sft", "rlhf", "generic"] = "sft",
         **kwargs,
     ):
-        # Define clipping of the value loss
-        if isinstance(clip_value, bool):
-            clip_value = clip_epsilon if clip_value else None
-
-        super().__init__(
-            actor_network,
-            critic_network=None,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
-            entropy_coeff=entropy_coeff,
-            gamma=gamma,
-            separate_losses=False,
-            reduction=reduction,
-            clip_value=clip_value,
-            functional=False,
-            device=device,
-            **kwargs,
-        )
-        # We don't want to use the string action but the tokens
-        self._set_in_keys()
+        super().__init__()
+        # Core modules and hyper-parameters
+        self.actor_network = actor_network
+        self.entropy_bonus = entropy_bonus
+        self.samples_mc_entropy = samples_mc_entropy
+        self.entropy_coeff = entropy_coeff
+        self.reduction = reduction
+
+        # Determine device and register clip epsilon as buffer
+        if device is None:
+            try:
+                device = next(self.parameters()).device
+            except (AttributeError, StopIteration):
+                device = getattr(
+                    torch, "get_default_device", lambda: torch.device("cpu")
+                )()
+        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+
         self.masking_strategy = masking_strategy
-        # Always use the full tokens for the action
+        # Defaults for keys
         self.set_keys(sample_log_prob=("log_probs", "full"), action=("tokens", "full"))
-        # TODO: make this a buffer
+        # KL coefficients
         self.kl_to_ref_coeff = kl_to_ref_coeff
         self.kl_to_inference_coeff = kl_to_inference_coeff
+        # Prepare IO keys
+        self._set_in_keys()
 
+    @property
+    def _clip_bounds(self):
+        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+
+    def _set_in_keys(self):
+        keys = []
+        if getattr(self, "actor_network", None) is not None and hasattr(
+            self.actor_network, "in_keys"
+        ):
+            in_keys = self.actor_network.in_keys
+            if isinstance(in_keys, (list, tuple)):
+                keys.extend(in_keys)
+        keys.append(self.tensor_keys.action)
+        keys.append(self.tensor_keys.sample_log_prob)
+        keys.append(self.tensor_keys.advantage)
+        keys.append(self.tensor_keys.ref_log_probs)
+        self._in_keys = list(dict.fromkeys(keys))
+
+    @property
+    def in_keys(self):
+        if getattr(self, "_in_keys", None) is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @property
+    def out_keys(self):
+        if getattr(self, "_out_keys", None) is None:
+            keys = ["loss_objective", "clip_fraction", "ESS", "kl_approx"]
+            if self.entropy_bonus:
+                keys.extend(["entropy", "loss_entropy"])
+            keys.extend(
+                [
+                    "loss_kl_to_ref",
+                    "kl_to_ref",
+                    "loss_kl_to_inference",
+                    "kl_to_inference",
+                ]
+            )
+            self._out_keys = keys
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, values):
+        self._out_keys = values
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        # No value estimator in GRPO; simply refresh input keys
+        self._set_in_keys()
+
     def _get_cur_log_prob(self, tensordict):
         """Override to use LLM-specific distribution with explicit masking strategy.
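Note on the new `_clip_bounds` property above: it returns the PPO-style clipping
thresholds in log-space, `(log1p(-eps), log1p(eps))`, so that clamping the log
importance weight is numerically equivalent to clamping the probability ratio to
`[1 - eps, 1 + eps]`. A minimal standalone sketch of that equivalence (the values
are made up for illustration; this snippet is not part of the patch):

    import torch

    clip_epsilon = torch.tensor(0.2)
    log_weight = torch.tensor([-0.5, 0.0, 0.4])  # log(pi_new / pi_old)

    # Bounds in log-space, computed with log1p for numerical stability.
    low, high = (-clip_epsilon).log1p(), clip_epsilon.log1p()
    ratio_from_log = log_weight.clamp(low, high).exp()

    # Same result as clamping the ratio directly to [1 - eps, 1 + eps].
    ratio_direct = log_weight.exp().clamp(1 - 0.2, 1 + 0.2)
    assert torch.allclose(ratio_from_log, ratio_direct)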
@@ -281,11 +334,6 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
-        if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
-            td_out.set("loss_critic", loss_critic)
-            if value_clip_fraction is not None:
-                td_out.set("value_clip_fraction", value_clip_fraction)
 
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
         td_out = td_out.named_apply(
@@ -323,6 +371,42 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)
 
+    def _get_entropy(
+        self, dist: d.Distribution, adv_shape: torch.Size
+    ) -> torch.Tensor | TensorDict:
+        try:
+            entropy = dist.entropy()
+            if not entropy.isfinite().all():
+                del entropy
+                if VERBOSE:
+                    torchrl_logger.info(
+                        "Entropy is not finite. Using Monte Carlo sampling."
+                    )
+                raise NotImplementedError
+        except NotImplementedError:
+            if VERBOSE:
+                torchrl_logger.warning(
+                    f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
+                )
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
+            with set_composite_lp_aggregate(False) if isinstance(
+                dist, CompositeDistribution
+            ) else contextlib.nullcontext():
+                log_prob = dist.log_prob(x)
+                if is_tensor_collection(log_prob):
+                    if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+                        log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+                    else:
+                        log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
+            entropy = -log_prob.mean(0)
+        if is_tensor_collection(entropy) and entropy.batch_size != adv_shape:
+            entropy.batch_size = adv_shape
+        return entropy.unsqueeze(-1)
+
     def _kl_to_ref(
         self,
         tensordict: TensorDictBase,
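Note on `_get_entropy` above: when `dist.entropy()` is unavailable or returns
non-finite values, the method falls back to a Monte Carlo estimate,
`-E[log p(x)]`, averaged over `samples_mc_entropy` draws (preferring `rsample`
when the distribution supports it). A short standalone sketch of that fallback,
using a hypothetical distribution and sample count (not part of the patch):

    import torch
    from torch import distributions as d

    samples_mc_entropy = 16  # stands in for the loss module's attribute
    dist = d.Categorical(logits=torch.randn(4, 10))

    # Categorical has no rsample, so plain sampling is used here.
    x = dist.sample((samples_mc_entropy,))
    mc_entropy = -dist.log_prob(x).mean(0)  # Monte Carlo estimate, shape (4,)

    # Categorical also has a closed form, so the estimate can be sanity-checked.
    print(mc_entropy, dist.entropy())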