diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 51158a419a0..add8e7b14d8 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -82,6 +82,10 @@ class GRPOLoss(LossModule):
             - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
             - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher;
               recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
+        kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
+            When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out of the loss.
+            This stabilizes updates by skipping tokens that have drifted too far from the reference distribution,
+            i.e. it enforces a per-token trust region.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -142,6 +146,7 @@ def __init__(
         actor_network: LLMWrapperBase | None = None,
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
+        kl_mask_threshold: float | None = None,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -161,6 +166,7 @@ def __init__(
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
+        self.kl_mask_threshold = kl_mask_threshold

         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -335,6 +341,32 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             tensordict, adv_shape=advantage.shape[:-1]
         )
         mask = dist.mask
+
+        # Optional per-token trust-region filtering (KL-Mask) vs the reference policy
+        if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
+            try:
+                ref_log_prob = tensordict.get(
+                    self.tensor_keys.ref_log_probs,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=0.0,
+                )
+            except KeyError:
+                ref_log_prob = None
+            cur_log_prob = tensordict.get("_cur_log_prob", None)
+            if (ref_log_prob is not None) and (cur_log_prob is not None):
+                # Align to valid tokens only (safety)
+                cur_log_prob_masked = torch.where(
+                    expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+                )
+                ref_log_prob_masked = torch.where(
+                    expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+                )
+                log_is_ref = cur_log_prob_masked - ref_log_prob_masked
+                kl_token = 0.5 * (log_is_ref**2)
+                tr_mask = kl_token <= self.kl_mask_threshold
+                # Combine with the attention mask
+                mask = mask & tr_mask
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
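
Note: a minimal, self-contained sketch of the masking rule the last hunk implements, on toy tensors (tensor names, shapes, and the threshold value are illustrative, not taken from the codebase):

```python
import torch

threshold = 0.1                                  # plays the role of kl_mask_threshold
cur_log_prob = torch.randn(2, 5)                 # log pi_theta per token, [batch, seq]
ref_log_prob = torch.randn(2, 5)                 # log pi_ref per token
attn_mask = torch.ones(2, 5, dtype=torch.bool)   # valid-token mask

log_is_ref = cur_log_prob - ref_log_prob         # log(pi_theta / pi_ref)
kl_token = 0.5 * log_is_ref.pow(2)               # per-token second-order KL proxy
tr_mask = kl_token <= threshold                  # tokens still inside the trust region
mask = attn_mask & tr_mask                       # drifted tokens drop out of the loss
```

With the new argument, the behavior is switched on at construction time, e.g. `GRPOLoss(actor_network, kl_mask_threshold=0.1)` (0.1 is an illustrative value; the feature is off by default since `kl_mask_threshold=None`).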