diff --git a/test/llm/test_objectives.py b/test/llm/test_objectives.py
index f6216129c51..3c09a252ea8 100644
--- a/test/llm/test_objectives.py
+++ b/test/llm/test_objectives.py
@@ -86,9 +86,7 @@ def make_silly_trajectory(n_steps=None):
 # Mock infrastructure moved to conftest.py
 
 
-def _mock_data_grpo(
-    vocab_size: int, device: torch.device | str = "cpu"
-) -> TensorDict:
+def _mock_data_grpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict:
     from transformers import AutoTokenizer
 
     device = torch.device(device)
@@ -175,11 +173,17 @@ def _mock_data_grpo(
 
 
 class TestLosses:
-    def test_grpo(self, mock_transformer_model):
+    @pytest.mark.parametrize("dapo", [True, False], ids=["dapo", "symmetric"])
+    def test_grpo(self, mock_transformer_model, dapo):
         """Test GRPO loss computation with mock models."""
         vocab_size = 1024
         device = torch.device("cpu")
-
+        if dapo:
+            eps_low = 0.20
+            eps_high = 0.28
+            eps = (eps_low, eps_high)
+        else:
+            eps = 0.20
         # Create mock model and wrap it
         model = mock_transformer_model(vocab_size=vocab_size, device=device)
         actor_network = TransformersWrapper(
@@ -190,7 +194,7 @@ def test_grpo(self, mock_transformer_model):
         )
 
         # Create loss module
-        loss_fn = GRPOLoss(actor_network)
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=eps)
 
         # Create fake data
         data = _mock_data_grpo(vocab_size=vocab_size, device=device)
diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 98581aac01c..2cdab05be9f 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -78,8 +78,10 @@ class GRPOLoss(LossModule):
         The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.
 
     Keyword Args:
-        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
-            default: 0.2
+        clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate.
+            - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
+            - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher;
+              recommended defaults from the DAPO paper: (0.20, 0.28), see Eq. (10).
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -115,6 +117,9 @@ class GRPOLoss(LossModule):
         .. note:: Parameters and buffers from the policy / critic will not be cast to that device to
             ensure that the storages match the ones that are passed to other components, such as
             data collectors.
+
+    .. note:: For asymmetric clipping thresholds, see the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
+
     """
 
     actor_network: LLMWrapperBase
@@ -136,7 +141,7 @@ def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
         *,
-        clip_epsilon: float = 0.2,
+        clip_epsilon: float | tuple[float, float] = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -165,7 +170,28 @@ def __init__(
             device = getattr(
                 torch, "get_default_device", lambda: torch.device("cpu")
             )()
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+        # Accept a symmetric (float) or asymmetric (eps_low, eps_high) threshold
+        if isinstance(clip_epsilon, (tuple, list)):
+            if len(clip_epsilon) != 2:
+                raise ValueError(
+                    f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
+                )
+            eps_low, eps_high = clip_epsilon
+        else:
+            eps_low = float(clip_epsilon)
+            eps_high = float(clip_epsilon)
+        # Basic validation
+        if eps_low < 0 or eps_high < 0:
+            raise ValueError(
+                f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
+            )
+        if eps_low >= 1.0:
+            raise ValueError(
+                f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
+            )
+        # Register buffers for the lower and upper clipping thresholds
+        self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
+        self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))
         self.masking_strategy = masking_strategy
 
         # Defaults for keys
@@ -178,7 +204,11 @@ def __init__(
 
     @property
    def _clip_bounds(self):
-        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+        # Returns (log(1 - eps_low), log(1 + eps_high)) used to clamp the log-weight
+        return (
+            (-self.clip_epsilon_low).log1p(),
+            self.clip_epsilon_high.log1p(),
+        )
 
     def _set_in_keys(self):
         keys = []
@@ -325,6 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage
 
+        # Token-level objective: take the minimum of the clipped and unclipped gains per token
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         td_out = TensorDict({"loss_objective": -gain})
         td_out.set("clip_fraction", clip_fraction)
@@ -514,6 +545,24 @@ def _log_weight(
         return log_weight, dist, kl_approx
 
 
+class DAPO(GRPOLoss):
+    """DAPO (Clip-Higher over GRPO).
+
+    Same objective as :class:`GRPOLoss`, but with asymmetric clip thresholds; recommended
+    (0.20, 0.28), see Eq. (10) in the `DAPO paper <https://arxiv.org/html/2503.14476>`_.
+    """
+
+    def __init__(
+        self,
+        actor_network: LLMWrapperBase | None = None,
+        *,
+        clip_epsilon: tuple[float, float] = (0.20, 0.28),
+        **kwargs,
+    ):
+        # The (eps_low, eps_high) pair is validated in GRPOLoss.__init__
+        super().__init__(actor_network, clip_epsilon=clip_epsilon, **kwargs)
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
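
Reviewer note: below is a minimal, standalone sketch of the asymmetric (Clip-Higher) surrogate that GRPOLoss.forward computes with the new thresholds, depending only on torch. The helper name clipped_surrogate, the tensor shapes, and the final mean reduction are made up for illustration and are not part of this patch; the loss module applies its own masking and reduction.

import math

import torch


def clipped_surrogate(log_weight, advantage, eps_low=0.20, eps_high=0.28):
    # Unclipped gain: exp(log_weight) * advantage, i.e. the importance-weighted advantage
    gain1 = log_weight.exp() * advantage
    # Clamp the log-weight to [log(1 - eps_low), log(1 + eps_high)], mirroring _clip_bounds
    log_weight_clip = log_weight.clamp(math.log1p(-eps_low), math.log1p(eps_high))
    gain2 = log_weight_clip.exp() * advantage
    # Token-level minimum of the clipped and unclipped gains, as in GRPOLoss.forward
    return torch.stack([gain1, gain2], -1).min(dim=-1).values


log_weight = 0.1 * torch.randn(2, 8)  # per-token log(pi_new / pi_old)
advantage = torch.randn(2, 8)         # group-relative advantages
loss_objective = -clipped_surrogate(log_weight, advantage).mean()

Only the upper bound is raised (1 + eps_high > 1 + eps_low), so low-probability tokens can gain more probability mass before being clipped, which is the behaviour the (0.20, 0.28) defaults from the DAPO paper are meant to encourage.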