Showing 17 changed files with 354 additions and 0 deletions.
@@ -0,0 +1,170 @@
import dataclasses

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
    create_continuous_q_function,
    create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.rebrac_impl import ReBRACImpl

__all__ = ["ReBRACConfig", "ReBRAC"]

@dataclasses.dataclass()
class ReBRACConfig(LearnableConfig):
    r"""Config of ReBRAC algorithm.

    ReBRAC is an extension of TD3+BC with several additional optimizations:

    #. Deeper networks (2 -> 3 hidden layers)
    #. LayerNorm
    #. Larger batches (256 -> 1024)
    #. Increased discount factor (0.99 -> 0.999)
    #. Decoupled actor and critic penalties

    .. math::

        J(\phi) = \mathbb{E}_{s,a \sim D}
            [\lambda Q(s, \pi(s)) - \beta_1 \cdot (a - \pi(s))^2]

    .. math::

        L(\theta) = \mathbb{E}_{s,a,r,s',\hat{a'} \sim D, a' \sim \pi(s')}
            [(Q_\theta (s, a) - (r + \gamma Q_\theta (s', a')
            - \beta_2 \cdot (a' - \hat{a'})^2))^2]

    References:
        * `Tarasov et al., Revisiting the Minimalist Approach to Offline
          Reinforcement Learning. <https://arxiv.org/abs/2305.09836>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        actor_learning_rate (float): Learning rate for the policy function.
        critic_learning_rate (float): Learning rate for the Q functions.
        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the actor.
        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the critic.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the critic.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
            Q function factory.
        batch_size (int): Mini-batch size.
        gamma (float): Discount factor.
        tau (float): Target network synchronization coefficient.
        n_critics (int): Number of Q functions in the ensemble.
        target_smoothing_sigma (float): Standard deviation of target noise.
        target_smoothing_clip (float): Clipping range for target noise.
        actor_beta (float): :math:`\beta_1` value.
        critic_beta (float): :math:`\beta_2` value.
        update_actor_interval (int): Interval for updating the policy function,
            described as `delayed policy update` in the paper.
    """

    actor_learning_rate: float = 1e-3
    critic_learning_rate: float = 1e-3
    actor_optim_factory: OptimizerFactory = make_optimizer_field()
    critic_optim_factory: OptimizerFactory = make_optimizer_field()
    actor_encoder_factory: EncoderFactory = make_encoder_field()
    critic_encoder_factory: EncoderFactory = make_encoder_field()
    q_func_factory: QFunctionFactory = make_q_func_field()
    batch_size: int = 1024
    gamma: float = 0.99
    tau: float = 0.005
    n_critics: int = 2
    target_smoothing_sigma: float = 0.2
    target_smoothing_clip: float = 0.5
    actor_beta: float = 0.001
    critic_beta: float = 0.01
    update_actor_interval: int = 2

    def create(self, device: DeviceArg = False) -> "ReBRAC":
        return ReBRAC(self, device)

    @staticmethod
    def get_type() -> str:
        return "rebrac"

class ReBRAC(QLearningAlgoBase[ReBRACImpl, ReBRACConfig]):
    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        policy = create_deterministic_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        targ_policy = create_deterministic_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )

        actor_optim = self._config.actor_optim_factory.create(
            policy.named_modules(), lr=self._config.actor_learning_rate
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(), lr=self._config.critic_learning_rate
        )

        modules = DDPGModules(
            policy=policy,
            targ_policy=targ_policy,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            actor_optim=actor_optim,
            critic_optim=critic_optim,
        )

        self._impl = ReBRACImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=self._config.gamma,
            tau=self._config.tau,
            target_smoothing_sigma=self._config.target_smoothing_sigma,
            target_smoothing_clip=self._config.target_smoothing_clip,
            actor_beta=self._config.actor_beta,
            critic_beta=self._config.critic_beta,
            update_actor_interval=self._config.update_actor_interval,
            device=self._device,
        )

    def get_action_type(self) -> ActionSpace:
        return ActionSpace.CONTINUOUS


register_learnable(ReBRACConfig)
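
For reference, a minimal usage sketch of the newly registered algorithm. The dataset helper, device string, and step count below are illustrative assumptions based on d3rlpy's standard training API, not part of this commit:

import d3rlpy

# illustrative dataset; any d3rlpy MDPDataset works here
dataset, env = d3rlpy.datasets.get_pendulum()

# defaults follow the paper: batch size 1024 and decoupled
# actor (beta_1) / critic (beta_2) behavior-cloning penalties
rebrac = d3rlpy.algos.ReBRACConfig(
    actor_beta=0.001,
    critic_beta=0.01,
).create(device="cuda:0")

rebrac.fit(
    dataset,
    n_steps=100000,
    evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
)
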
@@ -0,0 +1,83 @@
# pylint: disable=too-many-ancestors
import torch

from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .ddpg_impl import DDPGModules
from .td3_impl import TD3Impl
from .td3_plus_bc_impl import TD3PlusBCActorLoss

__all__ = ["ReBRACImpl"]


class ReBRACImpl(TD3Impl):
    _actor_beta: float
    _critic_beta: float

    def __init__(
        self,
        observation_shape: Shape,
        action_size: int,
        modules: DDPGModules,
        q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
        targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
        gamma: float,
        tau: float,
        target_smoothing_sigma: float,
        target_smoothing_clip: float,
        actor_beta: float,
        critic_beta: float,
        update_actor_interval: int,
        device: str,
    ):
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=gamma,
            tau=tau,
            target_smoothing_sigma=target_smoothing_sigma,
            target_smoothing_clip=target_smoothing_clip,
            update_actor_interval=update_actor_interval,
            device=device,
        )
        self._actor_beta = actor_beta
        self._critic_beta = critic_beta

    def compute_actor_loss(
        self, batch: TorchMiniBatch, action: ActionOutput
    ) -> TD3PlusBCActorLoss:
        q_t = self._q_func_forwarder.compute_expected_q(
            batch.observations, action.squashed_mu, "none"
        )[0]
        # normalization term that balances the Q term against the BC penalty
        lam = 1 / (q_t.abs().mean()).detach()
        bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
        return TD3PlusBCActorLoss(
            actor_loss=lam * -q_t.mean() + self._actor_beta * bc_loss,
            bc_loss=bc_loss,
        )

    def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
        with torch.no_grad():
            action = self._modules.targ_policy(batch.next_observations)
            # target smoothing
            noise = torch.randn(action.mu.shape, device=batch.device)
            scaled_noise = self._target_smoothing_sigma * noise
            clipped_noise = scaled_noise.clamp(
                -self._target_smoothing_clip, self._target_smoothing_clip
            )
            smoothed_action = action.squashed_mu + clipped_noise
            clipped_action = smoothed_action.clamp(-1.0, 1.0)
            next_q = self._targ_q_func_forwarder.compute_target(
                batch.next_observations,
                clipped_action,
                reduction="min",
            )

            # BRAC regularization: penalize deviation from the dataset action
            bc_loss = (clipped_action - batch.next_actions) ** 2

            return next_q - self._critic_beta * bc_loss.sum(dim=1, keepdim=True)
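
To make the penalty decoupling concrete, here is a self-contained sketch of the two regularizers on toy tensors. The shapes and coefficients are illustrative and mirror the logic above rather than calling into d3rlpy:

import torch

batch_size, action_size = 4, 2
actor_beta, critic_beta = 0.001, 0.01

# actor side: lambda-normalized Q term plus a beta_1-weighted BC penalty
q_t = torch.randn(batch_size, 1)  # stand-in for Q(s, pi(s))
dataset_actions = torch.rand(batch_size, action_size) * 2 - 1
policy_actions = torch.rand(batch_size, action_size) * 2 - 1
lam = 1 / q_t.abs().mean().detach()
bc_loss = ((dataset_actions - policy_actions) ** 2).mean()
actor_loss = lam * -q_t.mean() + actor_beta * bc_loss

# critic side: a beta_2-weighted penalty subtracted from the TD target
next_q = torch.randn(batch_size, 1)  # stand-in for min over target critics
next_dataset_actions = torch.rand(batch_size, action_size) * 2 - 1
next_policy_actions = torch.rand(batch_size, action_size) * 2 - 1
penalty = ((next_policy_actions - next_dataset_actions) ** 2).sum(
    dim=1, keepdim=True
)
target = next_q - critic_beta * penalty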