50 changes: 40 additions & 10 deletions torchrl/objectives/llm/grpo.py
@@ -348,16 +348,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        gain1 = log_weight.exp() * advantage
-
-        log_weight_clip = log_weight.clamp(*self._clip_bounds)
-        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-        ratio = log_weight_clip.exp()
-        gain2 = ratio * advantage
-
-        # Token-level objective: compute min over clipped/unclipped at the token level
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        td_out = TensorDict({"loss_objective": -gain})
+        loss_objective, clip_fraction = self._compute_policy_objective(
+            log_weight, advantage
+        )
+        td_out = TensorDict({"loss_objective": loss_objective})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
Expand Down Expand Up @@ -406,6 +400,21 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)
 
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Default GRPO objective: PPO-style min between unclipped and clipped ratios.
+
+        Returns (loss_objective, clip_fraction).
+        """
+        gain1 = log_weight.exp() * advantage
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain2 = ratio * advantage
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        return -gain, clip_fraction
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
Expand Down Expand Up @@ -594,6 +603,27 @@ def __init__(
         return coeff * kl_penalty, kl_penalty
 
 
+class CISPO(GRPOLoss):
+    """CISPO (Clipped Importance Sampling Policy Optimization).
+
+    Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+    replaces the PPO-style min with a clipped-importance objective:
+        loss = - clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+    See MiniMax-M1 (CISPO) [arXiv](https://arxiv.org/html/2506.13585).
+    """
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # CISPO: use clipped importance weights directly
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain = ratio * advantage
+        return -gain, clip_fraction
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
 
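For review context: the extracted `_compute_policy_objective` is behavior-preserving; it is the same PPO-style token-level min that previously lived inline in `forward`. Below is a minimal standalone sketch on toy tensors (the toy values and epsilons are illustrative assumptions; the clipping happens in log-space, mirroring how the loss derives `_clip_bounds` from its clip epsilons):

```python
import math

import torch

# Toy shapes: (batch, tokens); log_weight = log(pi_new / pi_old).
log_weight = torch.tensor([[0.3, -0.5], [0.9, 0.9]])
advantage = torch.tensor([[1.0, -1.0], [0.5, -1.0]])
eps_low, eps_high = 0.2, 0.2  # illustrative clip epsilons (assumption)

# Clipping is done in log-space, so the bounds are log1p of the epsilons.
clip_bounds = (math.log1p(-eps_low), math.log1p(eps_high))

# Default GRPO objective, mirroring `_compute_policy_objective`:
gain1 = log_weight.exp() * advantage                       # unclipped branch
log_weight_clip = log_weight.clamp(*clip_bounds)
clip_fraction = (log_weight_clip != log_weight).float().mean()
gain2 = log_weight_clip.exp() * advantage                  # clipped branch
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values  # token-level min
loss_objective = -gain  # masking/reduction happen upstream in the loss
```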
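`CISPO._compute_policy_objective` instead uses the clipped importance weight directly, dropping the unclipped branch. Continuing the sketch above, the two objectives diverge exactly on the tokens where the min would select the unclipped gain, e.g. a weight above the upper bound paired with a negative advantage:

```python
# CISPO: clipped importance weights used directly (no PPO-style min).
ratio = log_weight.clamp(*clip_bounds).exp()
loss_cispo = -(ratio * advantage)

# Entry [1, 1] (log_weight=0.9, advantage=-1.0) shows the difference:
# GRPO keeps the unclipped gain exp(0.9) * (-1.0) ≈ -2.46 (loss ≈ 2.46),
# while CISPO bounds it at the clipped ratio 1.2 (loss = 1.2).
print(loss_objective[1, 1].item(), loss_cispo[1, 1].item())
```

Since `CISPO` only overrides `_compute_policy_objective`, construction and the rest of the pipeline (masking, ESS, entropy, optional KL penalties) are inherited unchanged from `GRPOLoss`.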