Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torchrl.modules.tensordict_module.exploration import (
_OrnsteinUhlenbeckProcess,
OrnsteinUhlenbeckProcessWrapper,
AdditiveGaussianWrapper,
)


Expand Down Expand Up @@ -90,6 +91,98 @@ def test_ou_wrapper(device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0):
assert (out.get("action") >= -1.0).all(), out.get("action").max()


@pytest.mark.parametrize("device", get_available_devices())
class TestAdditiveGaussian:
def test_additivegaussian_sd(
self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
action_spec = NdBoundedTensorSpec(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
device=device,
)
policy = ProbabilisticActor(
spec=action_spec,
module=module,
dist_param_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_mode="random",
).to(device)
exploratory_policy = AdditiveGaussianWrapper(policy).to(device)

sigma_init = (
action_spec.project(
torch.randn(1000000, action_spec.shape[-1], device=device)
).std()
* exploratory_policy.sigma_init
)
sigma_end = (
action_spec.project(
torch.randn(1000000, action_spec.shape[-1], device=device)
).std()
* exploratory_policy.sigma_end
)
noisy_action = exploratory_policy._add_noise(
action_spec.rand((100000,)).zero_()
)
assert abs(noisy_action.std() - sigma_init) < 1e-1

for _ in range(exploratory_policy.annealing_num_steps):
exploratory_policy.step(1)
noisy_action = exploratory_policy._add_noise(
action_spec.rand((100000,)).zero_()
)
assert abs(noisy_action.std() - sigma_end) < 1e-1

def test_additivegaussian_wrapper(
self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
action_spec = NdBoundedTensorSpec(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
device=device,
)
policy = ProbabilisticActor(
spec=action_spec,
module=module,
dist_param_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_mode="random",
).to(device)
exploratory_policy = AdditiveGaussianWrapper(policy).to(device)

tensordict = TensorDict(
batch_size=[batch],
source={"observation": torch.randn(batch, d_obs, device=device)},
device=device,
)
out_noexp = []
out = []
for i in range(n_steps):
tensordict_noexp = policy(tensordict.select("observation"))
tensordict = exploratory_policy(tensordict)
out.append(tensordict.clone())
out_noexp.append(tensordict_noexp.clone())
tensordict.set_("observation", torch.randn(batch, d_obs, device=device))
out = torch.stack(out, 0)
out_noexp = torch.stack(out_noexp, 0)
assert (out_noexp.get("action") != out.get("action")).all()
assert (out.get("action") <= 1.0).all(), out.get("action").min()
assert (out.get("action") >= -1.0).all(), out.get("action").max()


@pytest.mark.parametrize("state_dim", [7])
@pytest.mark.parametrize("action_dim", [5, 11])
@pytest.mark.parametrize("gSDE", [True, False])
Expand Down
77 changes: 76 additions & 1 deletion torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
TensorDictModuleWrapper,
)

__all__ = ["EGreedyWrapper", "OrnsteinUhlenbeckProcessWrapper"]
__all__ = [
"EGreedyWrapper",
"AdditiveGaussianWrapper",
"OrnsteinUhlenbeckProcessWrapper",
]

from torchrl.data.tensordict.tensordict import TensorDictBase

Expand Down Expand Up @@ -112,6 +116,77 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict


class AdditiveGaussianWrapper(TensorDictModuleWrapper):
"""
Additive Gaussian PO wrapper.

Args:
policy (TensorDictModule): a policy.
sigma_init (scalar, optional): initial epsilon value.
default: 1.0
sigma_end (scalar, optional): final epsilon value.
default: 0.1
annealing_num_steps (int, optional): number of steps it will take for
sigma to reach the `sigma_end` value.
action_key (str, optional): if the policy module has more than one output key,
its output spec will be of type CompositeSpec. One needs to know where to
find the action spec.
Default is "action".

"""

def __init__(
self,
policy: TensorDictModule,
sigma_init: float = 1.0,
sigma_end: float = 0.1,
annealing_num_steps: int = 1000,
action_key: str = "action",
):
super().__init__(policy)
self.register_buffer("sigma_init", torch.tensor([sigma_init]))
self.register_buffer("sigma_end", torch.tensor([sigma_end]))
if self.sigma_end > self.sigma_init:
raise RuntimeError("sigma should decrease over time or be constant")
self.annealing_num_steps = annealing_num_steps
self.register_buffer("sigma", torch.tensor([sigma_init]))
self.action_key = action_key

def step(self, frames: int = 1) -> None:
"""A step of sigma decay.
After self.annealing_num_steps, this function is a no-op.

Args:
frames (int): number of frames since last step.

"""
for _ in range(frames):
self.sigma.data[0] = max(
self.sigma_end.item(),
(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
).item(),
)

def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
sigma = self.sigma.item()
noise = torch.randn(action.shape, device=action.device) * sigma
spec = self.td_module.spec
if isinstance(spec, CompositeSpec):
spec = spec[self.action_key]
action = action + noise
action = spec.project(action)
return action

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = self.td_module.forward(tensordict)
if exploration_mode() == "random" or exploration_mode() is None:
out = tensordict.get(self.action_key)
tensordict.set(self.action_key, out)
return tensordict


class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
"""
Ornstein-Uhlenbeck exploration policy wrapper as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
Expand Down