16 changes: 10 additions & 6 deletions test/llm/test_objectives.py
@@ -86,9 +86,7 @@ def make_silly_trajectory(n_steps=None):
# Mock infrastructure moved to conftest.py


def _mock_data_grpo(
vocab_size: int, device: torch.device | str = "cpu"
) -> TensorDict:
def _mock_data_grpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict:
from transformers import AutoTokenizer

device = torch.device(device)
@@ -175,11 +173,17 @@ def _mock_data_grpo(


class TestLosses:
def test_grpo(self, mock_transformer_model):
@pytest.mark.parametrize("dapo", [True, False], ids=["dapo", "symmetric"])
def test_grpo(self, mock_transformer_model, dapo):
"""Test GRPO loss computation with mock models."""
vocab_size = 1024
device = torch.device("cpu")

if dapo:
eps_low = 0.20
eps_high = 0.28
eps = (eps_low, eps_high)
else:
eps = 0.20
# Create mock model and wrap it
model = mock_transformer_model(vocab_size=vocab_size, device=device)
actor_network = TransformersWrapper(
@@ -190,7 +194,7 @@ def test_grpo(self, mock_transformer_model):
)

# Create loss module
loss_fn = GRPOLoss(actor_network)
        loss_fn = GRPOLoss(actor_network, clip_epsilon=eps)

# Create fake data
data = _mock_data_grpo(vocab_size=vocab_size, device=device)
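For reference, a hedged sketch of the two configurations the parametrized test now covers (it runs as ``test_grpo[dapo]`` and ``test_grpo[symmetric]``; ``actor_network`` and the ``GRPOLoss`` import are taken from the test module above):

loss_symmetric = GRPOLoss(actor_network, clip_epsilon=0.2)              # PPO-style: ratio clipped to [0.8, 1.2]
loss_clip_higher = GRPOLoss(actor_network, clip_epsilon=(0.20, 0.28))   # DAPO Clip-Higher: ratio clipped to [0.8, 1.28]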
90 changes: 85 additions & 5 deletions torchrl/objectives/llm/grpo.py
@@ -78,8 +78,10 @@ class GRPOLoss(LossModule):
The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.

Keyword Args:
clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
default: 0.2
clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate.
- float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
- tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -115,6 +117,9 @@ class GRPOLoss(LossModule):

.. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
the storages match the ones that are passed to other components, such as data collectors.

.. note:: For non-symmetric clipping thresholds, see the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
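        With ``clip_epsilon=(eps_low, eps_high)``, the token-level surrogate (Eq. (10) there) is
        ``min(r * A, clamp(r, 1 - eps_low, 1 + eps_high) * A)`` with ``r`` the importance ratio;
        a larger ``eps_high`` leaves low-probability tokens more room to grow before clipping.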

"""

actor_network: LLMWrapperBase
@@ -136,7 +141,7 @@ def __init__(
self,
actor_network: LLMWrapperBase | None = None,
*,
clip_epsilon: float = 0.2,
clip_epsilon: float | tuple[float, float] = 0.2,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float = 0.01,
@@ -165,7 +170,28 @@ def __init__(
device = getattr(
torch, "get_default_device", lambda: torch.device("cpu")
)()
self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
# Accept symmetric or asymmetric thresholds
if isinstance(clip_epsilon, (tuple, list)):
if len(clip_epsilon) != 2:
raise ValueError(
f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
)
eps_low, eps_high = clip_epsilon
else:
eps_low = float(clip_epsilon)
eps_high = float(clip_epsilon)
# Basic validation
if eps_low < 0 or eps_high < 0:
raise ValueError(
f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
)
if eps_low >= 1.0:
raise ValueError(
f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
)
# Register buffers
self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))

self.masking_strategy = masking_strategy
# Defaults for keys
@@ -178,7 +204,11 @@ def __init__(

@property
def _clip_bounds(self):
return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
# Returns (log(1 - eps_low), log(1 + eps_high)) for clamping log-weight
return (
(-self.clip_epsilon_low).log1p(),
self.clip_epsilon_high.log1p(),
)
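A quick standalone check (a sketch, not part of the diff) that clamping the log-weight at these bounds matches clamping the ratio itself, using the DAPO defaults:

import torch

lw = torch.tensor([-0.7, 0.5])  # log importance weights
lo, hi = torch.tensor(-0.20).log1p(), torch.tensor(0.28).log1p()
# exp is monotonic, so clamp-then-exp equals exp-then-clamp at (0.80, 1.28)
assert torch.allclose(lw.clamp(lo, hi).exp(), lw.exp().clamp(0.80, 1.28))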

def _set_in_keys(self):
keys = []
@@ -325,6 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
ratio = log_weight_clip.exp()
gain2 = ratio * advantage

        # Token-level objective: elementwise min of the unclipped and clipped gains
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
td_out = TensorDict({"loss_objective": -gain})
td_out.set("clip_fraction", clip_fraction)
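To see the Clip-Higher effect numerically, a standalone sketch (plain PyTorch, DAPO default thresholds):

import torch

ratio = torch.tensor([0.5, 1.5])      # unclipped importance ratios
advantage = torch.tensor([1.0, 1.0])  # positive advantages
clipped = ratio.clamp(1 - 0.20, 1 + 0.28)
gain = torch.min(ratio * advantage, clipped * advantage)
# gain -> tensor([0.5000, 1.2800]); a symmetric eps=0.2 would cap the second entry at 1.2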
@@ -514,6 +545,55 @@ def _log_weight(
return log_weight, dist, kl_approx


class DAPO(GRPOLoss):
    """DAPO (Clip-Higher over GRPO).

    Enforces asymmetric clip thresholds; recommended defaults are (0.20, 0.28),
    see Eq. (10) in the `DAPO paper <https://arxiv.org/html/2503.14476>`_.
    """

    def __init__(
        self,
        actor_network: LLMWrapperBase | None = None,
        *,
        clip_epsilon: tuple[float, float] = (0.20, 0.28),
        **kwargs,
    ):
        if not isinstance(clip_epsilon, (tuple, list)) or len(clip_epsilon) != 2:
            raise ValueError(
                f"DAPO expects clip_epsilon as a (eps_low, eps_high) pair, got {clip_epsilon}."
            )
        super().__init__(actor_network, clip_epsilon=clip_epsilon, **kwargs)

    def _kl_to_ref(
        self,
        tensordict: TensorDictBase,
        key: NestedKey = ("next", "ref_log_prob"),
        ref_log_prob: torch.Tensor | None = None,
        coeff: float | None = None,
        mask: torch.Tensor | None = None,
        dist: d.Distribution | None = None,
    ):
        if coeff is None:
            coeff = self.kl_to_ref_coeff
        # TODO: customize this
        if ref_log_prob is None:
            ref_log_prob = tensordict.get(
                key,
                as_padded_tensor=True,
                padding_side="left",
                padding_value=0.0,
            )
            if ref_log_prob is None:
                raise KeyError(
                    f"Couldn't find the ref log-prob {key} in the input data ({tensordict.keys(True)=})."
                )
            ref_log_prob = ref_log_prob.squeeze(-1)
        cur_log_prob = tensordict.get("_cur_log_prob")
        # TODO: remove this
        if cur_log_prob.shape != ref_log_prob.shape:
            raise ValueError(
                f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}"
            )
        if mask is not None:
            ref_log_prob = torch.where(
                expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
            )
            cur_log_prob = torch.where(
                expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
            )
        diff = ref_log_prob - cur_log_prob
        kl_penalty = (diff.expm1() - diff).mean()
        return coeff * kl_penalty, kl_penalty
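
A minimal construction sketch (hedged; assumes ``DAPO`` is exported from ``torchrl.objectives.llm.grpo`` alongside ``GRPOLoss``, and reuses a wrapped actor as in the test file):

from torchrl.objectives.llm.grpo import DAPO

loss_fn = DAPO(actor_network)  # defaults to clip_epsilon=(0.20, 0.28)
# equivalent to GRPOLoss(actor_network, clip_epsilon=(0.20, 0.28))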


class MCAdvantage(Transform):
"""Monte-Carlo advantage computation engine.
