diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 2cdab05be9f..51158a419a0 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -348,16 +348,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        gain1 = log_weight.exp() * advantage
-
-        log_weight_clip = log_weight.clamp(*self._clip_bounds)
-        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-        ratio = log_weight_clip.exp()
-        gain2 = ratio * advantage
-
-        # Token-level objective: compute min over clipped/unclipped at the token level
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        td_out = TensorDict({"loss_objective": -gain})
+        loss_objective, clip_fraction = self._compute_policy_objective(
+            log_weight, advantage
+        )
+        td_out = TensorDict({"loss_objective": loss_objective})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
 
@@ -406,6 +400,21 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)
 
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Default GRPO objective: PPO-style min between unclipped and clipped ratios.
+
+        Returns (loss_objective, clip_fraction).
+        """
+        gain1 = log_weight.exp() * advantage
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain2 = ratio * advantage
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        return -gain, clip_fraction
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
@@ -594,6 +603,27 @@ def __init__(
         return coeff * kl_penalty, kl_penalty
 
 
+class CISPO(GRPOLoss):
+    """CISPO (Clipped Importance Sampling Policy Optimization).
+
+    Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+    replaces the PPO-style min with a clipped-importance objective:
+        loss = - clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+    See MiniMax-M1 (CISPO) [arXiv](https://arxiv.org/html/2506.13585).
+    """
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # CISPO: use clipped importance weights directly
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain = ratio * advantage
+        return -gain, clip_fraction
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
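
For review purposes, here is a minimal standalone sketch of what the two `_compute_policy_objective` variants compute on toy tensors. It uses plain PyTorch with no torchrl dependencies; the `eps_low`/`eps_high` values, the derived `clip_bounds`, and the helper function names are illustrative assumptions, not part of the patch (in the patch the bounds come from `self._clip_bounds` on `GRPOLoss`).

```python
import math

import torch

# Illustrative clip epsilons and log-space bounds; in GRPOLoss these are
# provided by self._clip_bounds rather than computed here.
eps_low, eps_high = 0.2, 0.2
clip_bounds = (math.log1p(-eps_low), math.log1p(eps_high))


def grpo_objective(log_weight: torch.Tensor, advantage: torch.Tensor):
    """PPO-style token-level min over unclipped and clipped gains (GRPOLoss default)."""
    gain1 = log_weight.exp() * advantage
    log_weight_clip = log_weight.clamp(*clip_bounds)
    clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
    gain2 = log_weight_clip.exp() * advantage
    gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
    return -gain, clip_fraction


def cispo_objective(log_weight: torch.Tensor, advantage: torch.Tensor):
    """Clipped importance weight used directly, no min with the unclipped gain (CISPO)."""
    log_weight_clip = log_weight.clamp(*clip_bounds)
    clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
    gain = log_weight_clip.exp() * advantage
    return -gain, clip_fraction


log_weight = torch.tensor([0.5, -0.5, 0.0])  # per-token log(pi_new / pi_old)
advantage = torch.tensor([1.0, 1.0, -1.0])

print(grpo_objective(log_weight, advantage))   # min(unclipped, clipped) gain per token
print(cispo_objective(log_weight, advantage))  # clipped weight only
```

The two objectives coincide wherever the importance ratio stays inside [1 - eps_low, 1 + eps_high]; they can differ only on clipped tokens, where GRPO still takes the elementwise min with the unclipped gain while CISPO always uses the clipped weight.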