Refactor CQL and CalQL
takuseno committed May 6, 2024
1 parent 4dee692 commit 6223bf6
Showing 7 changed files with 181 additions and 50 deletions.
6 changes: 6 additions & 0 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -68,6 +68,7 @@ class CalQLConfig(CQLConfig):
n_action_samples (int): Number of sampled actions to compute
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
"""

def create(self, device: DeviceArg = False) -> "CalQL":
@@ -82,6 +83,10 @@ class CalQL(CQL):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."

policy = create_normal_policy(
observation_shape,
action_size,
@@ -156,6 +161,7 @@ def inner_create_impl(
conservative_weight=self._config.conservative_weight,
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
device=self._device,
)

7 changes: 7 additions & 0 deletions d3rlpy/algos/qlearning/cql.py
@@ -99,6 +99,7 @@ class CQLConfig(LearnableConfig):
n_action_samples (int): Number of sampled actions to compute
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
"""

actor_learning_rate: float = 1e-4
@@ -122,6 +123,7 @@ class CQLConfig(LearnableConfig):
conservative_weight: float = 5.0
n_action_samples: int = 10
soft_q_backup: bool = False
max_q_backup: bool = False

def create(self, device: DeviceArg = False) -> "CQL":
return CQL(self, device)
@@ -135,6 +137,10 @@ class CQL(QLearningAlgoBase[CQLImpl, CQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."

policy = create_normal_policy(
observation_shape,
action_size,
@@ -209,6 +215,7 @@ def inner_create_impl(
conservative_weight=self._config.conservative_weight,
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
device=self._device,
)

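For reference, the new flag is exposed on both CQLConfig and CalQLConfig (CalQLConfig inherits it). A minimal usage sketch, assuming only the API shown in this diff:

import d3rlpy

# enable the max-Q backup added in this commit; the default remains False
cql = d3rlpy.algos.CQLConfig(max_q_backup=True).create(device=False)

# Setting soft_q_backup=True together with max_q_backup=True now trips the
# assertion in inner_create_impl ("soft_q_backup and max_q_backup are
# mutually exclusive") once the implementation is built, e.g. by fit().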
74 changes: 33 additions & 41 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -10,7 +10,6 @@
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
Parameter,
build_squashed_gaussian_distribution,
get_parameter,
)
from ....torch_utility import (
@@ -22,6 +21,7 @@
from .ddpg_impl import DDPGBaseCriticLoss
from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules
from .sac_impl import SACImpl, SACModules
from .utility import sample_q_values_with_policy

__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss"]

@@ -44,6 +44,7 @@ class CQLImpl(SACImpl):
_conservative_weight: float
_n_action_samples: int
_soft_q_backup: bool
_max_q_backup: bool

def __init__(
self,
@@ -58,6 +59,7 @@ def __init__(
conservative_weight: float,
n_action_samples: int,
soft_q_backup: bool,
max_q_backup: bool,
device: str,
):
super().__init__(
@@ -74,6 +76,7 @@ def __init__(
self._conservative_weight = conservative_weight
self._n_action_samples = n_action_samples
self._soft_q_backup = soft_q_backup
self._max_q_backup = max_q_backup

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
@@ -88,16 +91,16 @@ def compute_critic_loss(
if self._modules.alpha_optim:
self.update_alpha(conservative_loss)
return CQLCriticLoss(
critic_loss=loss.critic_loss + conservative_loss,
conservative_loss=conservative_loss,
critic_loss=loss.critic_loss + conservative_loss.sum(),
conservative_loss=conservative_loss.sum(),
alpha=get_parameter(self._modules.log_alpha).exp(),
)

def update_alpha(self, conservative_loss: torch.Tensor) -> None:
assert self._modules.alpha_optim
self._modules.alpha_optim.zero_grad()
# the original implementation does scale the loss value
loss = -conservative_loss
loss = -conservative_loss.mean()
loss.backward(retain_graph=True)
self._modules.alpha_optim.step()

@@ -107,39 +110,14 @@ def _compute_policy_is_values(
value_obs: TorchObservation,
returns_to_go: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
dist = build_squashed_gaussian_distribution(
self._modules.policy(policy_obs)
)
policy_actions, n_log_probs = dist.sample_n_with_log_prob(
self._n_action_samples
)

# (batch, observation) -> (batch, n, observation)
repeated_obs = expand_and_repeat_recursively(
value_obs, self._n_action_samples
)
# (batch, n, observation) -> (batch * n, observation)
flat_obs = flatten_left_recursively(repeated_obs, dim=1)
# (batch, n, action) -> (batch * n, action)
flat_policy_acts = policy_actions.reshape(-1, self.action_size)

# estimate action-values for policy actions
policy_values = self._q_func_forwarder.compute_expected_q(
flat_obs, flat_policy_acts, "none"
return sample_q_values_with_policy(
policy=self._modules.policy,
q_func_forwarder=self._q_func_forwarder,
policy_observations=policy_obs,
value_observations=value_obs,
n_action_samples=self._n_action_samples,
detach_policy_output=True,
)
batch_size = (
policy_obs.shape[0]
if isinstance(policy_obs, torch.Tensor)
else policy_obs[0].shape[0]
)
policy_values = policy_values.view(
-1, batch_size, self._n_action_samples
)
log_probs = n_log_probs.view(1, -1, self._n_action_samples)

# importance sampling
return policy_values, log_probs

def _compute_random_is_values(
self, obs: TorchObservation
@@ -206,26 +184,40 @@ def _compute_conservative_loss(
obs_t, act_t, "none"
)

loss = logsumexp.mean(dim=0).mean() - data_values.mean(dim=0).mean()
scaled_loss = self._conservative_weight * loss
loss = (logsumexp - data_values).mean(dim=[1, 2])

# clip for stability
log_alpha = get_parameter(self._modules.log_alpha)
clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0]

return clipped_alpha * (scaled_loss - self._alpha_threshold)
return (
clipped_alpha
* self._conservative_weight
* (loss - self._alpha_threshold)
)

def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
if self._soft_q_backup:
target_value = super().compute_target(batch)
else:
target_value = self._compute_deterministic_target(batch)
with torch.no_grad():
target_value = self._compute_deterministic_target(batch)
return target_value

def _compute_deterministic_target(
self, batch: TorchMiniBatch
) -> torch.Tensor:
with torch.no_grad():
if self._max_q_backup:
q_values, _ = sample_q_values_with_policy(
policy=self._modules.policy,
q_func_forwarder=self._targ_q_func_forwarder,
policy_observations=batch.next_observations,
value_observations=batch.next_observations,
n_action_samples=self._n_action_samples,
detach_policy_output=True,
)
return q_values.min(dim=0).values.max(dim=1, keepdims=True).values
else:
action = self._modules.policy(batch.next_observations).squashed_mu
return self._targ_q_func_forwarder.compute_target(
batch.next_observations,
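To make the new target computation concrete, here is a standalone sketch of the max-Q backup reduction used above (pessimistic min over the Q ensemble, then max over the sampled actions), written with plain tensors rather than the d3rlpy modules; the (M, batch, n) layout matches what sample_q_values_with_policy returns:

import torch

M, batch_size, n = 2, 4, 10                # ensemble members, batch, sampled actions
q_values = torch.randn(M, batch_size, n)   # stand-in for Q(s', a_i) over sampled actions

# min over the ensemble dimension, then max over the n sampled actions
target = q_values.min(dim=0).values.max(dim=1, keepdim=True).values
print(target.shape)  # torch.Size([4, 1])

Note also that _compute_conservative_loss now returns a per-ensemble-member tensor of shape (M,), which compute_critic_loss sums into the critic loss while update_alpha averages it for the alpha update.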
57 changes: 56 additions & 1 deletion d3rlpy/algos/qlearning/torch/utility.py
@@ -1,13 +1,25 @@
from typing import Tuple

import torch
from typing_extensions import Protocol

from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
NormalPolicy,
build_squashed_gaussian_distribution,
)
from ....torch_utility import (
expand_and_repeat_recursively,
flatten_left_recursively,
)
from ....types import TorchObservation

__all__ = ["DiscreteQFunctionMixin", "ContinuousQFunctionMixin"]
__all__ = [
"DiscreteQFunctionMixin",
"ContinuousQFunctionMixin",
"sample_q_values_with_policy",
]


class _DiscreteQFunctionProtocol(Protocol):
@@ -38,3 +50,46 @@ def inner_predict_value(
return self._q_func_forwarder.compute_expected_q(
x, action, reduction="mean"
).reshape(-1)


def sample_q_values_with_policy(
policy: NormalPolicy,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
policy_observations: TorchObservation,
value_observations: TorchObservation,
n_action_samples: int,
detach_policy_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
dist = build_squashed_gaussian_distribution(policy(policy_observations))
# (batch, n, action), (batch, n)
policy_actions, n_log_probs = dist.sample_n_with_log_prob(n_action_samples)

if detach_policy_output:
policy_actions = policy_actions.detach()
n_log_probs = n_log_probs.detach()

# (batch, observation) -> (batch, n, observation)
repeated_obs = expand_and_repeat_recursively(
x=value_observations,
n=n_action_samples,
)
# (batch, n, observation) -> (batch * n, observation)
flat_obs = flatten_left_recursively(repeated_obs, dim=1)
# (batch, n, action) -> (batch * n, action)
flat_policy_acts = policy_actions.reshape(-1, policy_actions.shape[-1])

# estimate action-values for policy actions
# (M, batch * n, 1)
policy_values = q_func_forwarder.compute_expected_q(
flat_obs, flat_policy_acts, "none"
)
batch_size = (
policy_observations.shape[0]
if isinstance(policy_observations, torch.Tensor)
else policy_observations[0].shape[0]
)
policy_values = policy_values.view(-1, batch_size, n_action_samples)
log_probs = n_log_probs.view(1, -1, n_action_samples)

# (M, batch, n), (1, batch, n)
return policy_values, log_probs
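
As a usage note, one way a caller might consume the returned tensors (the (1, batch, n) log-probs broadcast against the (M, batch, n) values) is an importance-weighted logsumexp over the sampled actions; the names policy, q_func_forwarder, and obs below are placeholders, and the exact combination used by the CQL conservative term lives in cql_impl.py:

# hypothetical caller of the helper above
values, log_probs = sample_q_values_with_policy(
    policy=policy,                      # a NormalPolicy
    q_func_forwarder=q_func_forwarder,  # a ContinuousEnsembleQFunctionForwarder
    policy_observations=obs,
    value_observations=obs,
    n_action_samples=10,
    detach_policy_output=True,
)
logsumexp = (values - log_probs).logsumexp(dim=2)  # (M, batch)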
32 changes: 24 additions & 8 deletions reproductions/finetuning/cal_ql_finetune.py
@@ -1,54 +1,70 @@
import argparse
import copy

import d3rlpy


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0")
parser.add_argument(
"--dataset", type=str, default="antmaze-medium-diverse-v2"
)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--gpu", type=int)
args = parser.parse_args()

dataset, env = d3rlpy.datasets.get_minari(args.dataset)
dataset, env = d3rlpy.datasets.get_d4rl(args.dataset)

# fix seed
d3rlpy.seed(args.seed)
d3rlpy.envs.seed_env(env, args.seed)

# for antmaze datasets
reward_scaler = d3rlpy.preprocessing.ConstantShiftRewardScaler(shift=-1)
reward_scaler = d3rlpy.preprocessing.ConstantShiftRewardScaler(
shift=-5,
multiplier=10.0,
multiply_first=True,
)

encoder = d3rlpy.models.encoders.VectorEncoderFactory(
[256, 256, 256, 256],
)

cal_ql = d3rlpy.algos.CalQLConfig(
actor_learning_rate=3e-4,
actor_learning_rate=1e-4,
critic_learning_rate=3e-4,
temp_learning_rate=1e-4,
alpha_learning_rate=3e-4,
initial_alpha=2.72,
batch_size=256,
conservative_weight=5.0,
critic_encoder_factory=encoder,
alpha_threshold=0.8,
reward_scaler=reward_scaler,
max_q_backup=True,
).create(device=args.gpu)

# pretraining
cal_ql.fit(
dataset,
n_steps=1000000,
n_steps_per_epoch=100000,
n_steps_per_epoch=1000,
save_interval=10,
evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
experiment_name=f"CalQL_pretraining_{args.dataset}_{args.seed}",
)

# prepare FIFO buffer filled with dataset episodes
buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000)
buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000, env=env)

# sample half from offline dataset and the rest from online buffer
mixed_buffer = d3rlpy.dataset.MixedReplayBuffer(
primary_replay_buffer=buffer,
secondary_replay_buffer=dataset,
secondary_mix_ratio=0.5,
)

# finetuning
eval_env = copy.deepcopy(env)
_, eval_env = d3rlpy.datasets.get_d4rl(args.dataset)
d3rlpy.envs.seed_env(eval_env, args.seed)
cal_ql.fit_online(
env,
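For reference, the updated reward scaler replaces the earlier shift of -1. Assuming multiply_first=True means the multiplier is applied before the shift (an assumption about ConstantShiftRewardScaler's semantics), the sparse 0/1 antmaze reward is remapped as in this small sketch of the arithmetic:

multiplier, shift = 10.0, -5
for r in (0.0, 1.0):
    print(r, "->", r * multiplier + shift)  # 0.0 -> -5.0, 1.0 -> 5.0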
Empty file.
