48 changes: 41 additions & 7 deletions torchrl/objectives/llm/grpo.py
@@ -86,6 +86,10 @@ class GRPOLoss(LossModule):
When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out of the loss.
This stabilizes updates by skipping tokens that have drifted too far from the reference distribution,
effectively enforcing a per-token trust region.
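A minimal, self-contained sketch of the masking rule described above; the tensor names and values (`log_prob_theta`, `log_prob_ref`) are illustrative placeholders, not the module's internals:

```python
import torch

# Illustrative per-token log-probs under the current and reference policies
# (made-up values; inside GRPOLoss these come from the sampled tensordict).
log_prob_theta = torch.tensor([-0.9, -2.5, -0.7])
log_prob_ref = torch.tensor([-1.0, -1.0, -0.8])
kl_mask_threshold = 0.5

log_ratio = log_prob_theta - log_prob_ref   # log(pi_theta / pi_ref) per token
kl_proxy = 0.5 * log_ratio.pow(2)           # 0.5 * (log(pi_theta/pi_ref))^2
keep = kl_proxy <= kl_mask_threshold        # False -> token is dropped from the loss
print(keep)  # tensor([ True, False,  True])
```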
aggregation (str, optional): loss aggregation strategy for the policy objective.
- "token_mean": global masked token mean (weights long sequences more). Default.
- "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
- "none": return per-token loss (mask applied, no aggregation). Useful for downstream custom reductions.
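A hedged sketch of how the new keyword arguments would be passed at construction time; `policy` is a placeholder for a TorchRL-wrapped LLM actor and is not defined here, so this illustrates the signature rather than a runnable recipe:

```python
from torchrl.objectives.llm.grpo import GRPOLoss

# `policy` is a placeholder for a TorchRL-wrapped LLM actor (not shown here).
loss_fn = GRPOLoss(
    policy,
    clip_epsilon=0.2,
    kl_mask_threshold=0.5,      # per-token trust region (see above)
    aggregation="prompt_mean",  # equal weight per sample, regardless of length
)
```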
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -147,6 +151,7 @@ def __init__(
*,
clip_epsilon: float | tuple[float, float] = 0.2,
kl_mask_threshold: float | None = None,
aggregation: str | None = "token_mean",
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float = 0.01,
@@ -167,6 +172,7 @@
self.entropy_coeff = entropy_coeff
self.reduction = reduction if reduction is not None else "mean"
self.kl_mask_threshold = kl_mask_threshold
self.aggregation = aggregation or "token_mean"

# Determine device and register clip epsilon as buffer
if device is None:
@@ -397,13 +403,13 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
td_out.set("loss_entropy", -self.entropy_coeff * entropy)

td_out.set("ESS", _reduce(ess / batch, self.reduction))
td_out = td_out.named_apply(
lambda name, value: _reduce(
value, reduction=self.reduction, mask=mask
).squeeze(-1)
if name.startswith("loss_")
else value,
)
# Aggregate loss terms according to aggregation strategy
for key in list(td_out.keys()):
if isinstance(key, tuple) or not isinstance(key, str):
continue
if key.startswith("loss_"):
val = td_out.get(key)
td_out.set(key, self._aggregate_loss_value(val, mask))
if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0:
# FIXME: parameterize this
loss_kl, kl_penalty = self._kl_to_ref(
@@ -447,6 +453,34 @@ def _compute_policy_objective(
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
return -gain, clip_fraction
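For context on the line above, `gain1`/`gain2` follow the usual clipped-surrogate pattern; the snippet below reconstructs that pattern with standard PPO-style clipping on toy values and is an assumption, not a copy of the module's code:

```python
import torch

ratio = torch.tensor([0.8, 1.5, 1.0])        # pi_theta / pi_old per token (toy values)
advantage = torch.tensor([1.0, 1.0, -1.0])
clip_epsilon = 0.2

gain1 = ratio * advantage
gain2 = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantage
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
loss = -gain                                  # the method returns the negated gain
print(gain)  # tensor([ 0.8000,  1.2000, -1.0000])
```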

def _aggregate_loss_value(
self, value: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Aggregate a per-token loss tensor using the configured strategy.

Supports:
- token_mean: masked mean across all tokens (default)
- prompt_mean: per-sample masked mean over tokens, then mean across batch
- none: return per-token loss with masked-out tokens set to 0

The input `value` is expected to have shape [..., T, 1] where T is the token dimension,
and `mask` has shape [..., T].
"""
if self.aggregation == "none" or self.reduction == "none":
mask_exp = expand_as_right(mask, value)
return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value))

if self.aggregation == "prompt_mean":
# Mean over valid tokens per sample, then mean across batch
mask_exp = expand_as_right(mask, value).to(value.dtype)
token_sum = (value * mask_exp).sum(dim=-2, keepdim=False)
token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0)
sample_mean = token_sum / token_count
return sample_mean.mean(dim=0, keepdim=False)

# token_mean (global masked mean)
return _reduce(value, reduction="mean", mask=mask).squeeze(-1)
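A quick, runnable illustration of how the two mean modes differ, using plain torch ops on made-up per-token losses with the `[B, T, 1]` / `[B, T]` shapes from the docstring:

```python
import torch

value = torch.tensor([[[1.0], [1.0], [1.0]],
                      [[4.0], [0.0], [0.0]]])    # per-token losses, shape [2, 3, 1]
mask = torch.tensor([[True, True, True],
                     [True, False, False]])      # second sample has one valid token

# token_mean: one global mean over the 4 valid tokens -> (1 + 1 + 1 + 4) / 4 = 1.75
token_mean = value[mask.unsqueeze(-1).expand_as(value)].mean()

# prompt_mean: per-sample means (1.0 and 4.0), then mean across samples -> 2.5
per_sample = (value.squeeze(-1) * mask).sum(-1) / mask.sum(-1).clamp_min(1)
prompt_mean = per_sample.mean()

print(token_mean.item(), prompt_mean.item())  # 1.75 2.5
```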

def _get_entropy(
self, dist: d.Distribution, adv_shape: torch.Size
) -> torch.Tensor | TensorDict: