Implement ReBRAC
takuseno committed Jun 2, 2024
1 parent 5f810eb commit 949ed5c
Showing 17 changed files with 354 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -89,6 +89,7 @@ $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash
| [TD3+BC](https://arxiv.org/abs/2106.06860) | :no_entry: | :white_check_mark: |
| [Implicit Q-Learning (IQL)](https://arxiv.org/abs/2110.06169) | :no_entry: | :white_check_mark: |
| [Calibrated Q-Learning (Cal-QL)](https://arxiv.org/abs/2303.05479) | :no_entry: | :white_check_mark: |
| [ReBRAC](https://arxiv.org/abs/2305.09836) | :no_entry: | :white_check_mark: |
| [Decision Transformer](https://arxiv.org/abs/2106.01345) | :white_check_mark: | :white_check_mark: |
| [Gato](https://arxiv.org/abs/2205.06175) | :construction: | :construction: |

1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/__init__.py
@@ -13,6 +13,7 @@
from .nfq import *
from .plas import *
from .random_policy import *
from .rebrac import *
from .sac import *
from .td3 import *
from .td3_plus_bc import *
170 changes: 170 additions & 0 deletions d3rlpy/algos/qlearning/rebrac.py
@@ -0,0 +1,170 @@
import dataclasses

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
    create_continuous_q_function,
    create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.rebrac_impl import ReBRACImpl

__all__ = ["ReBRACConfig", "ReBRAC"]


@dataclasses.dataclass()
class ReBRACConfig(LearnableConfig):
    r"""Config of ReBRAC algorithm.

    ReBRAC is an extension of TD3+BC with additional optimizations:

    #. Deeper networks (2 -> 3 hidden layers)
    #. LayerNorm
    #. Larger batches (256 -> 1024)
    #. Increased discount factor (0.99 -> 0.999)
    #. Decoupled actor and critic penalties

    .. math::

        J(\phi) = \mathbb{E}_{s,a \sim D}
            [\lambda Q(s, \pi(s)) - \beta_1 \cdot (a - \pi(s))^2]

    .. math::

        L(\theta) = \mathbb{E}_{s,a,r,s',\hat{a'} \sim D, a' \sim \pi(s')}
            [(Q_\theta (s, a) - (r + \gamma Q_\theta (s', a')
            - \beta_2 \cdot (a' - \hat{a'})^2))^2]

    References:
        * `Tarasov et al., Revisiting the Minimalist Approach to Offline
          Reinforcement Learning. <https://arxiv.org/abs/2305.09836>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        actor_learning_rate (float): Learning rate for the policy function.
        critic_learning_rate (float): Learning rate for the Q functions.
        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the actor.
        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the critic.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the critic.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
            Q function factory.
        batch_size (int): Mini-batch size.
        gamma (float): Discount factor.
        tau (float): Target network synchronization coefficient.
        n_critics (int): Number of Q functions for ensemble.
        target_smoothing_sigma (float): Standard deviation for target noise.
        target_smoothing_clip (float): Clipping range for target noise.
        actor_beta (float): :math:`\beta_1` value.
        critic_beta (float): :math:`\beta_2` value.
        update_actor_interval (int): Interval to update the policy function,
            described as `delayed policy update` in the paper.
    """

    actor_learning_rate: float = 1e-3
    critic_learning_rate: float = 1e-3
    actor_optim_factory: OptimizerFactory = make_optimizer_field()
    critic_optim_factory: OptimizerFactory = make_optimizer_field()
    actor_encoder_factory: EncoderFactory = make_encoder_field()
    critic_encoder_factory: EncoderFactory = make_encoder_field()
    q_func_factory: QFunctionFactory = make_q_func_field()
    batch_size: int = 1024
    gamma: float = 0.99
    tau: float = 0.005
    n_critics: int = 2
    target_smoothing_sigma: float = 0.2
    target_smoothing_clip: float = 0.5
    actor_beta: float = 0.001
    critic_beta: float = 0.01
    update_actor_interval: int = 2

    def create(self, device: DeviceArg = False) -> "ReBRAC":
        return ReBRAC(self, device)

    @staticmethod
    def get_type() -> str:
        return "rebrac"


class ReBRAC(QLearningAlgoBase[ReBRACImpl, ReBRACConfig]):
    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        policy = create_deterministic_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        targ_policy = create_deterministic_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )

        actor_optim = self._config.actor_optim_factory.create(
            policy.named_modules(), lr=self._config.actor_learning_rate
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(), lr=self._config.critic_learning_rate
        )

        modules = DDPGModules(
            policy=policy,
            targ_policy=targ_policy,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            actor_optim=actor_optim,
            critic_optim=critic_optim,
        )

        self._impl = ReBRACImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=self._config.gamma,
            tau=self._config.tau,
            target_smoothing_sigma=self._config.target_smoothing_sigma,
            target_smoothing_clip=self._config.target_smoothing_clip,
            actor_beta=self._config.actor_beta,
            critic_beta=self._config.critic_beta,
            update_actor_interval=self._config.update_actor_interval,
            device=self._device,
        )

    def get_action_type(self) -> ActionSpace:
        return ActionSpace.CONTINUOUS


register_learnable(ReBRACConfig)
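
For orientation, here is a minimal quick-start sketch of how the new algorithm would be used through d3rlpy's standard Q-learning API. The dataset helper, training budget, and evaluator are illustrative assumptions, not part of this commit; the config parameters shown match the defaults added above.

import d3rlpy

# A hedged end-to-end sketch; get_pendulum() and the fit() budget are
# illustrative choices, not prescribed by this commit.
dataset, env = d3rlpy.datasets.get_pendulum()

rebrac = d3rlpy.algos.ReBRACConfig(
    actor_beta=0.001,  # beta_1: BC penalty on the actor loss
    critic_beta=0.01,  # beta_2: penalty on target actions in the critic target
).create(device=False)  # False runs on CPU; pass e.g. "cuda:0" for GPU

rebrac.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
)
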
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/torch/__init__.py
@@ -9,6 +9,7 @@
from .dqn_impl import *
from .iql_impl import *
from .plas_impl import *
from .rebrac_impl import *
from .sac_impl import *
from .td3_impl import *
from .td3_plus_bc_impl import *
83 changes: 83 additions & 0 deletions d3rlpy/algos/qlearning/torch/rebrac_impl.py
@@ -0,0 +1,83 @@
# pylint: disable=too-many-ancestors
import torch

from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .ddpg_impl import DDPGModules
from .td3_impl import TD3Impl
from .td3_plus_bc_impl import TD3PlusBCActorLoss

__all__ = ["ReBRACImpl"]


class ReBRACImpl(TD3Impl):
    _actor_beta: float
    _critic_beta: float

    def __init__(
        self,
        observation_shape: Shape,
        action_size: int,
        modules: DDPGModules,
        q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
        targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
        gamma: float,
        tau: float,
        target_smoothing_sigma: float,
        target_smoothing_clip: float,
        actor_beta: float,
        critic_beta: float,
        update_actor_interval: int,
        device: str,
    ):
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=gamma,
            tau=tau,
            target_smoothing_sigma=target_smoothing_sigma,
            target_smoothing_clip=target_smoothing_clip,
            update_actor_interval=update_actor_interval,
            device=device,
        )
        self._actor_beta = actor_beta
        self._critic_beta = critic_beta

    def compute_actor_loss(
        self, batch: TorchMiniBatch, action: ActionOutput
    ) -> TD3PlusBCActorLoss:
        q_t = self._q_func_forwarder.compute_expected_q(
            batch.observations, action.squashed_mu, "none"
        )[0]
        lam = 1 / (q_t.abs().mean()).detach()
        bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
        return TD3PlusBCActorLoss(
            actor_loss=lam * -q_t.mean() + self._actor_beta * bc_loss,
            bc_loss=bc_loss,
        )

    def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
        with torch.no_grad():
            action = self._modules.targ_policy(batch.next_observations)
            # target smoothing
            noise = torch.randn(action.mu.shape, device=batch.device)
            scaled_noise = self._target_smoothing_sigma * noise
            clipped_noise = scaled_noise.clamp(
                -self._target_smoothing_clip, self._target_smoothing_clip
            )
            smoothed_action = action.squashed_mu + clipped_noise
            clipped_action = smoothed_action.clamp(-1.0, 1.0)
            next_q = self._targ_q_func_forwarder.compute_target(
                batch.next_observations,
                clipped_action,
                reduction="min",
            )

            # BRAC regularization
            bc_loss = (clipped_action - batch.next_actions) ** 2

            return next_q - self._critic_beta * bc_loss.sum(dim=1, keepdim=True)
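
To connect the implementation above with the two penalty terms in the ReBRACConfig docstring, here is a small self-contained PyTorch sketch on toy tensors (no d3rlpy classes; all names are local to this example). It mirrors the lambda-normalized Q term with the beta_1 BC penalty on the actor side and the beta_2-penalized TD target on the critic side.

import torch

torch.manual_seed(0)
batch_size, action_size = 4, 2
actor_beta, critic_beta, gamma = 0.001, 0.01, 0.99

# Actor side: maximize normalized Q while staying close to dataset actions.
q_t = torch.randn(batch_size, 1)                       # Q(s, pi(s))
dataset_actions = torch.rand(batch_size, action_size) * 2 - 1
policy_actions = torch.rand(batch_size, action_size) * 2 - 1
lam = 1.0 / q_t.abs().mean().detach()                  # normalization term
bc_loss = ((dataset_actions - policy_actions) ** 2).mean()
actor_loss = lam * -q_t.mean() + actor_beta * bc_loss

# Critic side: TD target penalized by the squared distance between the
# target policy's next action and the dataset's next action.
rewards = torch.rand(batch_size, 1)
next_q = torch.randn(batch_size, 1)                    # min over target Q ensemble
next_policy_actions = torch.rand(batch_size, action_size) * 2 - 1
next_dataset_actions = torch.rand(batch_size, action_size) * 2 - 1
penalty = ((next_policy_actions - next_dataset_actions) ** 2).sum(dim=1, keepdim=True)
td_target = rewards + gamma * (next_q - critic_beta * penalty)
print(actor_loss.item(), td_target.shape)  # scalar loss, (4, 1) target
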
3 changes: 3 additions & 0 deletions d3rlpy/dataset/components.py
@@ -64,6 +64,8 @@ class Transition:
        reward: Reward. This could be a multi-step discounted return.
        next_observation: Observation at the next timestep. This could be
            the observation multiple steps ahead.
        next_action: Action at the next timestep. This could be the action
            multiple steps ahead.
        terminal: Flag of environment termination.
        interval: Timesteps between ``observation`` and ``next_observation``.
        rewards_to_go: Remaining rewards till the end of an episode, which is
@@ -74,6 +76,7 @@
    observation: Observation  # (...)
    action: NDArray  # (...)
    reward: Float32NDArray  # (1,)
    next_observation: Observation  # (...)
    next_action: NDArray  # (...)
    terminal: float
    interval: int
    rewards_to_go: Float32NDArray  # (L, 1)
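
Since the diff above only shows the changed fields, here is a hedged sketch of constructing a Transition with the new next_action field, assuming the fields visible in this hunk are the complete set and that Transition is exported from d3rlpy.dataset (toy shapes chosen for illustration).

import numpy as np
from d3rlpy.dataset import Transition

# Toy transition: 3-dim observation, 2-dim action.
transition = Transition(
    observation=np.random.rand(3).astype(np.float32),
    action=np.random.rand(2).astype(np.float32),
    reward=np.array([1.0], dtype=np.float32),
    next_observation=np.random.rand(3).astype(np.float32),
    next_action=np.random.rand(2).astype(np.float32),  # field added in this commit
    terminal=0.0,
    interval=1,
    rewards_to_go=np.array([[1.0]], dtype=np.float32),
)
print(transition.next_action.shape)  # (2,)
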
7 changes: 7 additions & 0 deletions d3rlpy/dataset/mini_batch.py
@@ -38,6 +38,7 @@ class TransitionMiniBatch:
    next_observations: Union[
        Float32NDArray, Sequence[Float32NDArray]
    ]  # (B, ...)
    next_actions: Float32NDArray  # (B, ...)
    terminals: Float32NDArray  # (B, 1)
    intervals: Float32NDArray  # (B, 1)
    transitions: Sequence[Transition]
@@ -47,6 +48,8 @@ def __post_init__(self) -> None:
        assert check_dtype(self.observations, np.float32)
        assert check_non_1d_array(self.actions)
        assert check_dtype(self.actions, np.float32)
        assert check_non_1d_array(self.next_actions)
        assert check_dtype(self.next_actions, np.float32)
        assert check_non_1d_array(self.rewards)
        assert check_dtype(self.rewards, np.float32)
        assert check_non_1d_array(self.next_observations)
@@ -80,6 +83,9 @@ def from_transitions(
        next_observations = stack_observations(
            [transition.next_observation for transition in transitions]
        )
        next_actions = np.stack(
            [transition.next_action for transition in transitions], axis=0
        )
        terminals = np.reshape(
            np.array([transition.terminal for transition in transitions]),
            [-1, 1],
@@ -93,6 +99,7 @@
            actions=cast_recursively(actions, np.float32),
            rewards=cast_recursively(rewards, np.float32),
            next_observations=cast_recursively(next_observations, np.float32),
            next_actions=cast_recursively(next_actions, np.float32),
            terminals=cast_recursively(terminals, np.float32),
            intervals=cast_recursively(intervals, np.float32),
            transitions=transitions,
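
Building on the toy transition sketched above, a quick way to see the new field flow through batching (again an assumption-laden sketch; TransitionMiniBatch is expected to be importable from d3rlpy.dataset):

from d3rlpy.dataset import TransitionMiniBatch

# Stack 32 copies of the toy transition into a mini-batch.
batch = TransitionMiniBatch.from_transitions([transition] * 32)
print(batch.actions.shape)       # (32, 2)
print(batch.next_actions.shape)  # (32, 2), stacked from transition.next_action
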
10 changes: 10 additions & 0 deletions d3rlpy/dataset/transition_pickers.py
@@ -53,16 +53,19 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        is_terminal = episode.terminated and index == episode.size() - 1
        if is_terminal:
            next_observation = create_zero_observation(observation)
            next_action = np.zeros_like(episode.actions[index])
        else:
            next_observation = retrieve_observation(
                episode.observations, index + 1
            )
            next_action = episode.actions[index + 1]

        return Transition(
            observation=observation,
            action=episode.actions[index],
            reward=episode.rewards[index],
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=1,
            rewards_to_go=episode.rewards[index:],
@@ -143,16 +146,19 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        is_terminal = episode.terminated and index == episode.size() - 1
        if is_terminal:
            next_observation = create_zero_observation(observation)
            next_action = np.zeros_like(episode.actions[index])
        else:
            next_observation = stack_recent_observations(
                episode.observations, index + 1, self._n_frames
            )
            next_action = episode.actions[index + 1]

        return Transition(
            observation=observation,
            action=episode.actions[index],
            reward=episode.rewards[index],
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=1,
            rewards_to_go=episode.rewards[index:],
@@ -189,16 +195,19 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
            is_terminal = next_index == episode.size()
            if is_terminal:
                next_observation = create_zero_observation(observation)
                next_action = np.zeros_like(episode.actions[index])
            else:
                next_observation = retrieve_observation(
                    episode.observations, next_index
                )
                next_action = episode.actions[next_index]
        else:
            is_terminal = False
            next_index = min(index + self._n_steps, episode.size() - 1)
            next_observation = retrieve_observation(
                episode.observations, next_index
            )
            next_action = episode.actions[next_index]

        # compute multi-step return
        interval = next_index - index
@@ -210,6 +219,7 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
            action=episode.actions[index],
            reward=ret,
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=interval,
            rewards_to_go=episode.rewards[index:],
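
To illustrate the picker behavior added here, a hedged sketch with a tiny terminated episode: at the last index the picker zero-pads next_action, otherwise it returns the dataset action one step ahead. Episode and BasicTransitionPicker are assumed to be importable from d3rlpy.dataset as in released versions.

import numpy as np
from d3rlpy.dataset import BasicTransitionPicker, Episode

# Tiny terminated episode: 3 steps, 2-dim observations, 1-dim actions.
episode = Episode(
    observations=np.random.rand(3, 2).astype(np.float32),
    actions=np.random.rand(3, 1).astype(np.float32),
    rewards=np.random.rand(3, 1).astype(np.float32),
    terminated=True,
)

picker = BasicTransitionPicker()
mid = picker(episode, 1)
last = picker(episode, 2)
print(np.allclose(mid.next_action, episode.actions[2]))  # True
print(np.allclose(last.next_action, np.zeros(1)))        # True: zero-padded at terminal
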