torchrl/objectives/llm/grpo.py (32 additions, 0 deletions)
@@ -82,6 +82,10 @@ class GRPOLoss(LossModule):
- float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
- tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
When set, tokens whose approximate KL to the reference policy, ``0.5 * (log(pi_theta/pi_ref))^2``, exceeds
``kl_mask_threshold`` are masked out of the loss. This stabilizes updates by skipping tokens that have
drifted too far from the reference distribution, effectively enforcing a per-token trust region.
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
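
Below is a minimal standalone sketch (not part of the diff) of the token-wise KL-Mask computation described in the ``kl_mask_threshold`` entry above; the tensor shapes, the random inputs, and the threshold value of 0.1 are illustrative assumptions.

```python
import torch

# Illustrative per-token log-probabilities for a batch of 2 sequences of 5 tokens.
cur_log_prob = torch.randn(2, 5)                       # log pi_theta(token)
ref_log_prob = cur_log_prob + 0.3 * torch.randn(2, 5)  # log pi_ref(token)
attn_mask = torch.ones(2, 5, dtype=torch.bool)         # valid-token mask

# Approximate per-token KL to the reference policy: 0.5 * (log pi_theta - log pi_ref)^2.
kl_token = 0.5 * (cur_log_prob - ref_log_prob) ** 2
# Keep only tokens inside the trust region (threshold 0.1 is illustrative).
tr_mask = kl_token <= 0.1
# Combine with the attention mask, as the forward pass in the diff does.
mask = attn_mask & tr_mask
```
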
@@ -142,6 +146,7 @@ def __init__(
actor_network: LLMWrapperBase | None = None,
*,
clip_epsilon: float | tuple[float, float] = 0.2,
kl_mask_threshold: float | None = None,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float = 0.01,
@@ -161,6 +166,7 @@
self.samples_mc_entropy = samples_mc_entropy
self.entropy_coeff = entropy_coeff
self.reduction = reduction if reduction is not None else "mean"
self.kl_mask_threshold = kl_mask_threshold

# Determine device and register clip epsilon as buffer
if device is None:
@@ -335,6 +341,32 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
tensordict, adv_shape=advantage.shape[:-1]
)
mask = dist.mask

# Optional per-token trust-region filtering (KL-Mask) vs reference policy
if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
try:
ref_log_prob = tensordict.get(
self.tensor_keys.ref_log_probs,
as_padded_tensor=True,
padding_side="left",
padding_value=0.0,
)
except KeyError:
ref_log_prob = None
cur_log_prob = tensordict.get("_cur_log_prob", None)
if (ref_log_prob is not None) and (cur_log_prob is not None):
# Align to valid tokens only (safety)
cur_log_prob_masked = torch.where(
expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
)
ref_log_prob_masked = torch.where(
expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
)
log_is_ref = cur_log_prob_masked - ref_log_prob_masked
kl_token = 0.5 * (log_is_ref**2)
tr_mask = kl_token <= self.kl_mask_threshold
# Combine with attention mask
mask = mask & tr_mask
# ESS for logging
with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
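
As a rough way to read the threshold used in the forward-pass block above: a token is masked when ``0.5 * d**2 > kl_mask_threshold`` with ``d = log(pi_theta/pi_ref)``, i.e. when the probability ratio leaves the band ``[exp(-sqrt(2*thr)), exp(sqrt(2*thr))]``. The snippet below just evaluates that band; the threshold value 0.02 is an illustrative assumption, not a recommendation from the PR.

```python
import math

thr = 0.02                  # illustrative kl_mask_threshold
bound = math.sqrt(2 * thr)  # |log(pi_theta/pi_ref)| above this gets masked
low, high = math.exp(-bound), math.exp(bound)
print(f"tokens kept while pi_theta/pi_ref stays in [{low:.2f}, {high:.2f}]")  # ~[0.82, 1.22]
```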