torchrl/objectives/llm/grpo.py: 160 changes (122 additions, 38 deletions)
@@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

+import contextlib
+
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Literal
@@ -15,19 +17,19 @@
    TensorClass,
    TensorDict,
    TensorDictBase,
-    TensorDictParams,
)
from tensordict.nn import (
    CompositeDistribution,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
+    set_composite_lp_aggregate,
)
from tensordict.utils import expand_as_right
from torch import distributions as d
-from torchrl._utils import logger as torchrl_logger
+from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.envs.transforms.transforms import Transform
from torchrl.modules.llm import LLMWrapperBase
-from torchrl.objectives.ppo import ClipPPOLoss
+from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import _reduce, _sum_td_features


@@ -46,7 +48,7 @@ class GRPOLossOutput(TensorClass["nocast"]):
    kl_to_inference: torch.Tensor | None = None


-class GRPOLoss(ClipPPOLoss):
+class GRPOLoss(LossModule):
    """GRPO loss.

    The clipped importance weighted loss is computed as follows:
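
The docstring's formula is collapsed in this view. For orientation, a sketch of the standard clipped importance-weighted objective that GRPO shares with PPO, in notation of my own rather than the file's:

```latex
% Standard clipped surrogate; not copied from the file.
% \rho_t: token-level importance ratio, A_t: advantage, \epsilon: clip_epsilon.
\rho_t = \exp\!\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\text{old}}(a_t \mid s_t)\big),
\qquad
L_{\text{clip}} = -\,\mathbb{E}_t\!\left[\min\!\big(\rho_t A_t,\;
    \operatorname{clip}(\rho_t,\, 1-\epsilon,\, 1+\epsilon)\, A_t\big)\right]
```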
@@ -116,20 +118,18 @@ class GRPOLoss(ClipPPOLoss):
"""

actor_network: LLMWrapperBase
critic_network: TensorDictModule
actor_network_params: TensorDictParams
critic_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_critic_network_params: TensorDictParams

@dataclass
class _AcceptedKeys(ClipPPOLoss._AcceptedKeys):
class _AcceptedKeys(LossModule._AcceptedKeys):
"""Maintains default values for all configurable tensordict keys.

This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values
"""

advantage: NestedKey = "advantage"
action: NestedKey = ("tokens", "full")
sample_log_prob: NestedKey = ("log_probs", "full")
ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")

def __init__(
@@ -149,32 +149,85 @@ def __init__(
        masking_strategy: Literal["sft", "rlhf", "generic"] = "sft",
        **kwargs,
    ):
-        # Define clipping of the value loss
-        if isinstance(clip_value, bool):
-            clip_value = clip_epsilon if clip_value else None
-
-        super().__init__(
-            actor_network,
-            critic_network=None,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
-            entropy_coeff=entropy_coeff,
-            gamma=gamma,
-            separate_losses=False,
-            reduction=reduction,
-            clip_value=clip_value,
-            functional=False,
-            device=device,
-            **kwargs,
-        )
-        # We don't want to use the string action but the tokens
-        self._set_in_keys()
+        super().__init__()
+        # Core modules and hyper-parameters
+        self.actor_network = actor_network
+        self.entropy_bonus = entropy_bonus
+        self.samples_mc_entropy = samples_mc_entropy
+        self.entropy_coeff = entropy_coeff
+        self.reduction = reduction
+
+        # Determine device and register clip epsilon as buffer
+        if device is None:
+            try:
+                device = next(self.parameters()).device
+            except (AttributeError, StopIteration):
+                device = getattr(
+                    torch, "get_default_device", lambda: torch.device("cpu")
+                )()
+        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+
        self.masking_strategy = masking_strategy
-        # Always use the full tokens for the action
+        # Defaults for keys
        self.set_keys(sample_log_prob=("log_probs", "full"), action=("tokens", "full"))
-        # TODO: make this a buffer
+        # KL coefficients
        self.kl_to_ref_coeff = kl_to_ref_coeff
        self.kl_to_inference_coeff = kl_to_inference_coeff
+        # Prepare IO keys
+        self._set_in_keys()
+
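A minimal construction sketch based only on the parameters visible in this hunk; `policy` stands in for any `LLMWrapperBase`, the keyword values are illustrative rather than recommended, and the import path is assumed from the file location:

```python
# Hypothetical usage; `policy` is assumed to be an LLMWrapperBase
# (e.g. a wrapped HF model) built elsewhere.
from torchrl.objectives.llm.grpo import GRPOLoss  # import path assumed

loss_fn = GRPOLoss(
    policy,                    # actor_network: LLMWrapperBase
    clip_epsilon=0.2,          # stored as a buffer on the resolved device
    entropy_coeff=0.01,
    kl_to_ref_coeff=0.1,       # weight on the KL-to-reference penalty
    masking_strategy="sft",    # one of "sft" | "rlhf" | "generic"
)
# By default the loss reads ("tokens", "full") and ("log_probs", "full").
```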
+    @property
+    def _clip_bounds(self):
+        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
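
`_clip_bounds` keeps the clip range in log-space: clamping the log-ratio to [log(1-eps), log(1+eps)] is equivalent to clamping the ratio itself to [1-eps, 1+eps], while avoiding an explicit `exp` before clipping. A small standalone check (shapes and values are my own, not the file's):

```python
import torch

eps = torch.tensor(0.2)
log_ratio = torch.randn(8)  # log pi_new - log pi_old for 8 tokens

# Clip in log-space, as _clip_bounds does ...
lo, hi = (-eps).log1p(), eps.log1p()  # log(1 - eps), log(1 + eps)
ratio_a = log_ratio.clamp(lo, hi).exp()

# ... which matches clipping the ratio directly, since exp is monotonic.
ratio_b = log_ratio.exp().clamp(1 - eps, 1 + eps)
torch.testing.assert_close(ratio_a, ratio_b)
```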
+
+    def _set_in_keys(self):
+        keys = []
+        if getattr(self, "actor_network", None) is not None and hasattr(
+            self.actor_network, "in_keys"
+        ):
+            in_keys = self.actor_network.in_keys
+            if isinstance(in_keys, (list, tuple)):
+                keys.extend(in_keys)
+        keys.append(self.tensor_keys.action)
+        keys.append(self.tensor_keys.sample_log_prob)
+        keys.append(self.tensor_keys.advantage)
+        keys.append(self.tensor_keys.ref_log_probs)
+        self._in_keys = list(dict.fromkeys(keys))
+
+    @property
+    def in_keys(self):
+        if getattr(self, "_in_keys", None) is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @property
+    def out_keys(self):
+        if getattr(self, "_out_keys", None) is None:
+            keys = ["loss_objective", "clip_fraction", "ESS", "kl_approx"]
+            if self.entropy_bonus:
+                keys.extend(["entropy", "loss_entropy"])
+            keys.extend(
+                [
+                    "loss_kl_to_ref",
+                    "kl_to_ref",
+                    "loss_kl_to_inference",
+                    "kl_to_inference",
+                ]
+            )
+            self._out_keys = keys
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, values):
+        self._out_keys = values
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        # No value estimator in GRPO; simply refresh input keys
+        self._set_in_keys()
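
`_forward_value_estimator_keys` is the hook that `LossModule.set_keys` invokes after keys are updated; with no value estimator to re-wire, GRPO only rebuilds its cached input keys. A sketch of the observable effect, continuing from the construction example above (key names hypothetical):

```python
# Remapping a key through set_keys triggers _forward_value_estimator_keys,
# so in_keys is rebuilt to advertise the new entry.
loss_fn.set_keys(advantage=("advantages", "group_normalized"))
assert ("advantages", "group_normalized") in loss_fn.in_keys
```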

    def _get_cur_log_prob(self, tensordict):
        """Override to use LLM-specific distribution with explicit masking strategy.
@@ -281,11 +334,6 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
                entropy = _sum_td_features(entropy)
            td_out.set("entropy", entropy.detach().mean())  # for logging
            td_out.set("loss_entropy", -self.entropy_coeff * entropy)
-        if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
-            td_out.set("loss_critic", loss_critic)
-            if value_clip_fraction is not None:
-                td_out.set("value_clip_fraction", value_clip_fraction)

        td_out.set("ESS", _reduce(ess / batch, self.reduction))
        td_out = td_out.named_apply(
@@ -323,6 +371,42 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
        del tensordict["_cur_log_prob"]
        return GRPOLossOutput.from_tensordict(td_out)

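The `ESS` entry set above is the classic effective-sample-size diagnostic over the importance weights w = exp(log-ratio), ESS = (Σ w)² / Σ w², normalized by batch size. The exact lines are collapsed in this diff, so the sketch below is a reconstruction under that assumption:

```python
import torch

# Effective sample size from log importance ratios, computed in log-space
# for stability: ESS = (sum w)^2 / sum w^2.
log_w = torch.randn(256)  # stand-in per-sample log-ratios
ess = (2 * log_w.logsumexp(0) - (2 * log_w).logsumexp(0)).exp()
print(ess / log_w.numel())  # 1.0 would mean perfectly uniform weights
```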
+    def _get_entropy(
+        self, dist: d.Distribution, adv_shape: torch.Size
+    ) -> torch.Tensor | TensorDict:
+        try:
+            entropy = dist.entropy()
+            if not entropy.isfinite().all():
+                del entropy
+                if VERBOSE:
+                    torchrl_logger.info(
+                        "Entropy is not finite. Using Monte Carlo sampling."
+                    )
+                raise NotImplementedError
+        except NotImplementedError:
+            if VERBOSE:
+                torchrl_logger.warning(
+                    f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
+                )
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
+            with set_composite_lp_aggregate(False) if isinstance(
+                dist, CompositeDistribution
+            ) else contextlib.nullcontext():
+                log_prob = dist.log_prob(x)
+                if is_tensor_collection(log_prob):
+                    if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+                        log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+                    else:
+                        log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
+            entropy = -log_prob.mean(0)
+        if is_tensor_collection(entropy) and entropy.batch_size != adv_shape:
+            entropy.batch_size = adv_shape
+        return entropy.unsqueeze(-1)
+
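The Monte Carlo branch above estimates the entropy as H(p) ≈ -mean of log p(x) over `samples_mc_entropy` draws. A self-contained illustration against a distribution whose entropy has a closed form (toy values, not from the file):

```python
import torch
from torch import distributions as d

dist = d.Categorical(logits=torch.randn(32))  # toy single-step "policy"
x = dist.sample((1000,))                      # samples_mc_entropy = 1000
mc_entropy = -dist.log_prob(x).mean(0)        # Monte Carlo estimate of H
print(mc_entropy, dist.entropy())             # the two should be close
```
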
    def _kl_to_ref(
        self,
        tensordict: TensorDictBase,