Skip to content

Commit

Permalink
Refactor shape of scalers
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed May 15, 2023
1 parent 1cae35f commit 9c718e8
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 57 deletions.
43 changes: 23 additions & 20 deletions d3rlpy/preprocessing/action_scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
generate_optional_config_generation,
make_optional_numpy_field,
)
from .base import Scaler
from .base import Scaler, add_leading_dims, add_leading_dims_numpy

__all__ = [
"ActionScaler",
Expand Down Expand Up @@ -99,8 +99,8 @@ def fit_with_transition_picker(
else:
minimum = np.minimum(minimum, transition.action)
maximum = np.maximum(maximum, transition.action)
self.minimum = minimum.reshape((1,) + minimum.shape)
self.maximum = maximum.reshape((1,) + maximum.shape)
self.minimum = minimum
self.maximum = maximum

def fit_with_trajectory_slicer(
self,
Expand All @@ -123,53 +123,56 @@ def fit_with_trajectory_slicer(
else:
minimum = np.minimum(minimum, min_action)
maximum = np.maximum(maximum, max_action)

self.minimum = minimum.reshape((1,) + minimum.shape)
self.maximum = maximum.reshape((1,) + maximum.shape)
self.minimum = minimum
self.maximum = maximum

def fit_with_env(self, env: gym.Env[Any, Any]) -> None:
assert not self.built
assert isinstance(env.action_space, gym.spaces.Box)
shape = env.action_space.shape
low = np.asarray(env.action_space.low)
high = np.asarray(env.action_space.high)
assert shape
self.minimum = low.reshape((1, *shape))
self.maximum = high.reshape((1, *shape))
self.minimum = low
self.maximum = high

def transform(self, x: torch.Tensor) -> torch.Tensor:
assert self.built
minimum = torch.tensor(
self.minimum, dtype=torch.float32, device=x.device
minimum = add_leading_dims(
torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
target=x,
)
maximum = torch.tensor(
self.maximum, dtype=torch.float32, device=x.device
maximum = add_leading_dims(
torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
target=x,
)
# transform action into [-1.0, 1.0]
return ((x - minimum) / (maximum - minimum)) * 2.0 - 1.0

def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
assert self.built
minimum = torch.tensor(
self.minimum, dtype=torch.float32, device=x.device
minimum = add_leading_dims(
torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
target=x,
)
maximum = torch.tensor(
self.maximum, dtype=torch.float32, device=x.device
maximum = add_leading_dims(
torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
target=x,
)
# transform action from [-1.0, 1.0]
return ((maximum - minimum) * ((x + 1.0) / 2.0)) + minimum

def transform_numpy(self, x: np.ndarray) -> np.ndarray:
assert self.built
assert self.maximum is not None and self.minimum is not None
minimum, maximum = self.minimum, self.maximum
minimum = add_leading_dims_numpy(self.minimum, target=x)
maximum = add_leading_dims_numpy(self.maximum, target=x)
# transform action into [-1.0, 1.0]
return ((x - minimum) / (maximum - minimum)) * 2.0 - 1.0

def reverse_transform_numpy(self, x: np.ndarray) -> np.ndarray:
assert self.built
assert self.maximum is not None and self.minimum is not None
minimum, maximum = self.minimum, self.maximum
minimum = add_leading_dims_numpy(self.minimum, target=x)
maximum = add_leading_dims_numpy(self.maximum, target=x)
# transform action from [-1.0, 1.0]
return ((maximum - minimum) * ((x + 1.0) / 2.0)) + minimum

Expand Down
16 changes: 15 additions & 1 deletion d3rlpy/preprocessing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from ..serializable_config import DynamicConfig

__all__ = ["Scaler"]
__all__ = ["Scaler", "add_leading_dims", "add_leading_dims_numpy"]


class Scaler(DynamicConfig, metaclass=ABCMeta):
Expand Down Expand Up @@ -118,3 +118,17 @@ def built(self) -> bool:
"""
raise NotImplementedError


def add_leading_dims(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Prepend singleton dimensions to ``x`` until it has ``target.ndim`` dims.

    ``x`` must match the trailing dimensions of ``target`` exactly. For
    example, ``x`` of shape ``(3,)`` with ``target`` of shape ``(1, 2, 3)``
    gives a result of shape ``(1, 1, 3)``.

    Args:
        x: Tensor to reshape.
        target: Tensor whose number of dimensions the result should have.

    Returns:
        ``x`` reshaped to ``(1,) * (target.ndim - x.ndim) + x.shape``.

    Raises:
        AssertionError: If ``x`` has more dimensions than ``target`` or its
            shape differs from the trailing dimensions of ``target``.
    """
    assert (
        x.ndim <= target.ndim
    ), f"x has more dimensions than target: {x.ndim} > {target.ndim}"
    dim_diff = target.ndim - x.ndim
    # Require an exact trailing-shape match so that a mismatched statistic
    # is reported here instead of being silently broadcast by the caller.
    assert x.shape == target.shape[dim_diff:], (
        f"x shape {tuple(x.shape)} does not match trailing target shape "
        f"{tuple(target.shape[dim_diff:])}"
    )
    return torch.reshape(x, [1] * dim_diff + list(x.shape))


def add_leading_dims_numpy(x: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Prepend singleton dimensions to ``x`` until it has ``target.ndim`` dims.

    ``x`` must match the trailing dimensions of ``target`` exactly. For
    example, ``x`` of shape ``(3,)`` with ``target`` of shape ``(1, 2, 3)``
    gives a result of shape ``(1, 1, 3)``.

    Args:
        x: Array to reshape.
        target: Array whose number of dimensions the result should have.

    Returns:
        ``x`` reshaped to ``(1,) * (target.ndim - x.ndim) + x.shape``.

    Raises:
        AssertionError: If ``x`` has more dimensions than ``target`` or its
            shape differs from the trailing dimensions of ``target``.
    """
    assert (
        x.ndim <= target.ndim
    ), f"x has more dimensions than target: {x.ndim} > {target.ndim}"
    dim_diff = target.ndim - x.ndim
    # Require an exact trailing-shape match so that a mismatched statistic
    # is reported here instead of being silently broadcast by the caller.
    assert x.shape == target.shape[dim_diff:], (
        f"x shape {tuple(x.shape)} does not match trailing target shape "
        f"{tuple(target.shape[dim_diff:])}"
    )
    return np.reshape(x, [1] * dim_diff + list(x.shape))
79 changes: 51 additions & 28 deletions d3rlpy/preprocessing/observation_scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
generate_optional_config_generation,
make_optional_numpy_field,
)
from .base import Scaler
from .base import Scaler, add_leading_dims, add_leading_dims_numpy

__all__ = [
"ObservationScaler",
Expand Down Expand Up @@ -159,8 +159,8 @@ def fit_with_transition_picker(
else:
minimum = np.minimum(minimum, observation)
maximum = np.maximum(maximum, observation)
self.minimum = minimum.reshape((1,) + minimum.shape)
self.maximum = maximum.reshape((1,) + maximum.shape)
self.minimum = minimum
self.maximum = maximum

def fit_with_trajectory_slicer(
self,
Expand All @@ -183,47 +183,54 @@ def fit_with_trajectory_slicer(
else:
minimum = np.minimum(minimum, min_observation)
maximum = np.maximum(maximum, max_observation)
self.minimum = minimum.reshape((1,) + minimum.shape)
self.maximum = maximum.reshape((1,) + maximum.shape)
self.minimum = minimum
self.maximum = maximum

def fit_with_env(self, env: gym.Env[Any, Any]) -> None:
assert not self.built
assert isinstance(env.observation_space, gym.spaces.Box)
shape = env.observation_space.shape
low = np.asarray(env.observation_space.low)
high = np.asarray(env.observation_space.high)
self.minimum = low.reshape((1, *shape))
self.maximum = high.reshape((1, *shape))
self.minimum = low
self.maximum = high

def transform(self, x: torch.Tensor) -> torch.Tensor:
assert self.built
minimum = torch.tensor(
self.minimum, dtype=torch.float32, device=x.device
minimum = add_leading_dims(
torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
target=x,
)
maximum = torch.tensor(
self.maximum, dtype=torch.float32, device=x.device
maximum = add_leading_dims(
torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
target=x,
)
return (x - minimum) / (maximum - minimum) * 2.0 - 1.0

def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
assert self.built
minimum = torch.tensor(
self.minimum, dtype=torch.float32, device=x.device
minimum = add_leading_dims(
torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
target=x,
)
maximum = torch.tensor(
self.maximum, dtype=torch.float32, device=x.device
maximum = add_leading_dims(
torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
target=x,
)
return ((maximum - minimum) * (x + 1.0) / 2.0) + minimum

def transform_numpy(self, x: np.ndarray) -> np.ndarray:
assert self.built
assert self.minimum is not None and self.maximum is not None
return (x - self.minimum) / (self.maximum - self.minimum) * 2.0 - 1.0
minimum = add_leading_dims_numpy(self.minimum, target=x)
maximum = add_leading_dims_numpy(self.maximum, target=x)
return (x - minimum) / (maximum - minimum) * 2.0 - 1.0

def reverse_transform_numpy(self, x: np.ndarray) -> np.ndarray:
assert self.built
assert self.minimum is not None and self.maximum is not None
return ((self.maximum - self.minimum) * (x + 1.0) / 2.0) + self.minimum
minimum = add_leading_dims_numpy(self.minimum, target=x)
maximum = add_leading_dims_numpy(self.maximum, target=x)
return ((maximum - minimum) * (x + 1.0) / 2.0) + minimum

@staticmethod
def get_type() -> str:
Expand Down Expand Up @@ -312,8 +319,8 @@ def fit_with_transition_picker(
total_sqsum += (transition.observation - mean) ** 2
std = np.sqrt(total_sqsum / total_count)

self.mean = mean.reshape((1,) + mean.shape)
self.std = std.reshape((1,) + std.shape)
self.mean = mean
self.std = std

def fit_with_trajectory_slicer(
self,
Expand Down Expand Up @@ -343,8 +350,8 @@ def fit_with_trajectory_slicer(
total_sqsum += np.sum((observations - expanded_mean) ** 2, axis=0)
std = np.sqrt(total_sqsum / total_count)

self.mean = mean.reshape((1,) + mean.shape)
self.std = std.reshape((1,) + std.shape)
self.mean = mean
self.std = std

def fit_with_env(self, env: gym.Env[Any, Any]) -> None:
raise NotImplementedError(
Expand All @@ -353,25 +360,41 @@ def fit_with_env(self, env: gym.Env[Any, Any]) -> None:

def transform(self, x: torch.Tensor) -> torch.Tensor:
assert self.built
mean = torch.tensor(self.mean, dtype=torch.float32, device=x.device)
std = torch.tensor(self.std, dtype=torch.float32, device=x.device)
mean = add_leading_dims(
torch.tensor(self.mean, dtype=torch.float32, device=x.device),
target=x,
)
std = add_leading_dims(
torch.tensor(self.std, dtype=torch.float32, device=x.device),
target=x,
)
return (x - mean) / (std + self.eps)

def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
assert self.built
mean = torch.tensor(self.mean, dtype=torch.float32, device=x.device)
std = torch.tensor(self.std, dtype=torch.float32, device=x.device)
mean = add_leading_dims(
torch.tensor(self.mean, dtype=torch.float32, device=x.device),
target=x,
)
std = add_leading_dims(
torch.tensor(self.std, dtype=torch.float32, device=x.device),
target=x,
)
return ((std + self.eps) * x) + mean

def transform_numpy(self, x: np.ndarray) -> np.ndarray:
assert self.built
assert self.mean is not None and self.std is not None
return (x - self.mean) / (self.std + self.eps)
mean = add_leading_dims_numpy(self.mean, target=x)
std = add_leading_dims_numpy(self.std, target=x)
return (x - mean) / (std + self.eps)

def reverse_transform_numpy(self, x: np.ndarray) -> np.ndarray:
assert self.built
assert self.mean is not None and self.std is not None
return ((self.std + self.eps) * x) + self.mean
mean = add_leading_dims_numpy(self.mean, target=x)
std = add_leading_dims_numpy(self.std, target=x)
return ((std + self.eps) * x) + mean

@staticmethod
def get_type() -> str:
Expand Down
16 changes: 16 additions & 0 deletions tests/preprocessing/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
import torch

from d3rlpy.preprocessing.base import add_leading_dims, add_leading_dims_numpy


def test_add_leading_dims() -> None:
    """Check that add_leading_dims only prepends singleton dimensions."""
    x = torch.rand(3)
    target = torch.rand(1, 2, 3)
    y = add_leading_dims(x, target)

    # leading singleton dims are added up to the rank of the target
    assert y.shape == (1, 1, 3)
    # the values themselves must be untouched by the reshape
    assert torch.equal(y.reshape(-1), x)
    # the result must broadcast against the target
    assert (target - y).shape == target.shape

    # nothing is added when the ranks already match
    assert add_leading_dims(x, x).shape == x.shape


def test_add_leading_dims_numpy() -> None:
    """Check that add_leading_dims_numpy only prepends singleton dimensions."""
    x = np.random.random(3)
    target = np.random.random((1, 2, 3))
    y = add_leading_dims_numpy(x, target)

    # leading singleton dims are added up to the rank of the target
    assert y.shape == (1, 1, 3)
    # the values themselves must be untouched by the reshape
    assert np.array_equal(y.reshape(-1), x)
    # the result must broadcast against the target
    assert (target - y).shape == target.shape

    # nothing is added when the ranks already match
    assert add_leading_dims_numpy(x, x).shape == x.shape
16 changes: 8 additions & 8 deletions tests/preprocessing/test_observation_scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def test_min_max_observation_scaler_with_transition_picker(
scaler.fit_with_transition_picker(episodes, BasicTransitionPicker())
assert scaler.built
assert scaler.minimum is not None and scaler.maximum is not None
assert np.allclose(scaler.minimum[0], minimum)
assert np.allclose(scaler.maximum[0], maximum)
assert np.allclose(scaler.minimum, minimum)
assert np.allclose(scaler.maximum, maximum)


@pytest.mark.parametrize("observation_shape", [(100,)])
Expand Down Expand Up @@ -137,8 +137,8 @@ def test_min_max_observation_scaler_with_trajectory_slicer(
scaler.fit_with_trajectory_slicer(episodes, BasicTrajectorySlicer())
assert scaler.built
assert scaler.minimum is not None and scaler.maximum is not None
assert np.allclose(scaler.minimum[0], minimum)
assert np.allclose(scaler.maximum[0], maximum)
assert np.allclose(scaler.minimum, minimum)
assert np.allclose(scaler.maximum, maximum)


def test_min_max_observation_scaler_with_env() -> None:
Expand Down Expand Up @@ -221,8 +221,8 @@ def test_standard_observation_scaler_with_transition_picker(
scaler.fit_with_transition_picker(episodes, BasicTransitionPicker())
assert scaler.built
assert scaler.mean is not None and scaler.std is not None
assert np.allclose(scaler.mean[0], mean)
assert np.allclose(scaler.std[0], std)
assert np.allclose(scaler.mean, mean)
assert np.allclose(scaler.std, std)


@pytest.mark.parametrize("observation_shape", [(100,)])
Expand Down Expand Up @@ -253,5 +253,5 @@ def test_standard_observation_scaler_with_trajectory_slicer(
scaler.fit_with_trajectory_slicer(episodes, BasicTrajectorySlicer())
assert scaler.built
assert scaler.mean is not None and scaler.std is not None
assert np.allclose(scaler.mean[0], mean)
assert np.allclose(scaler.std[0], std)
assert np.allclose(scaler.mean, mean)
assert np.allclose(scaler.std, std)

0 comments on commit 9c718e8

Please sign in to comment.