Commit: Cache tensors at preprocessors

takuseno committed Jun 17, 2023
1 parent 12512e4 commit 489169a
Showing 2 changed files with 64 additions and 44 deletions.
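The change is the same in both files: each scaler used to rebuild `torch.Tensor` copies of its NumPy statistics (`minimum`/`maximum`, or `mean`/`std`) on every `transform` and `reverse_transform` call. This commit caches those tensors on first use via a new `_set_torch_value` helper, so the NumPy-to-tensor conversion (and, on GPU, the host-to-device copy) happens once. A minimal, self-contained sketch of the pattern, with the d3rlpy plumbing (`add_leading_dims`, the `built` flag, the dataclass machinery) omitted and plain broadcasting standing in for the leading-dimension handling:

    from typing import Optional

    import numpy as np
    import torch


    class MinMaxScalerSketch:
        """Minimal sketch of the lazy tensor cache introduced by this commit."""

        def __init__(self, minimum: np.ndarray, maximum: np.ndarray) -> None:
            self.minimum = np.asarray(minimum)
            self.maximum = np.asarray(maximum)
            # Caches start empty and are filled on the first transform call.
            self._torch_minimum: Optional[torch.Tensor] = None
            self._torch_maximum: Optional[torch.Tensor] = None

        def _set_torch_value(self, device: torch.device) -> None:
            # Convert the NumPy statistics once, on the device of the inputs.
            self._torch_minimum = torch.tensor(
                self.minimum, dtype=torch.float32, device=device
            )
            self._torch_maximum = torch.tensor(
                self.maximum, dtype=torch.float32, device=device
            )

        def transform(self, x: torch.Tensor) -> torch.Tensor:
            if self._torch_minimum is None or self._torch_maximum is None:
                self._set_torch_value(x.device)  # only the first call pays this cost
            assert self._torch_minimum is not None and self._torch_maximum is not None
            # Map values from [minimum, maximum] into [-1.0, 1.0].
            return (x - self._torch_minimum) / (self._torch_maximum - self._torch_minimum) * 2.0 - 1.0

One consequence visible in the diff: `_set_torch_value` runs only while the caches are `None`, so the tensors stay on whatever device the first batch arrived on.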
36 changes: 22 additions & 14 deletions d3rlpy/preprocessing/action_scalers.py
@@ -65,6 +65,8 @@ def __post_init__(self) -> None:
             self.minimum = np.asarray(self.minimum)
         if self.maximum is not None:
             self.maximum = np.asarray(self.maximum)
+        self._torch_minimum: Optional[torch.Tensor] = None
+        self._torch_maximum: Optional[torch.Tensor] = None
 
     def fit_with_transition_picker(
         self,
@@ -120,27 +122,25 @@ def fit_with_env(self, env: gym.Env[Any, Any]) -> None:
 
     def transform(self, x: torch.Tensor) -> torch.Tensor:
         assert self.built
-        minimum = add_leading_dims(
-            torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
-            target=x,
-        )
-        maximum = add_leading_dims(
-            torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
-            target=x,
+        if self._torch_minimum is None or self._torch_maximum is None:
+            self._set_torch_value(x.device)
+        assert (
+            self._torch_minimum is not None and self._torch_maximum is not None
         )
+        minimum = add_leading_dims(self._torch_minimum, target=x)
+        maximum = add_leading_dims(self._torch_maximum, target=x)
         # transform action into [-1.0, 1.0]
         return ((x - minimum) / (maximum - minimum)) * 2.0 - 1.0
 
     def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
         assert self.built
-        minimum = add_leading_dims(
-            torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
-            target=x,
-        )
-        maximum = add_leading_dims(
-            torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
-            target=x,
+        if self._torch_minimum is None or self._torch_maximum is None:
+            self._set_torch_value(x.device)
+        assert (
+            self._torch_minimum is not None and self._torch_maximum is not None
         )
+        minimum = add_leading_dims(self._torch_minimum, target=x)
+        maximum = add_leading_dims(self._torch_maximum, target=x)
         # transform action from [-1.0, 1.0]
         return ((maximum - minimum) * ((x + 1.0) / 2.0)) + minimum
 
@@ -160,6 +160,14 @@ def reverse_transform_numpy(self, x: np.ndarray) -> np.ndarray:
         # transform action from [-1.0, 1.0]
         return ((maximum - minimum) * ((x + 1.0) / 2.0)) + minimum
 
+    def _set_torch_value(self, device: torch.device) -> None:
+        self._torch_minimum = torch.tensor(
+            self.minimum, dtype=torch.float32, device=device
+        )
+        self._torch_maximum = torch.tensor(
+            self.maximum, dtype=torch.float32, device=device
+        )
+
     @staticmethod
     def get_type() -> str:
         return "min_max"
72 changes: 42 additions & 30 deletions d3rlpy/preprocessing/observation_scalers.py
@@ -124,6 +124,8 @@ def __post_init__(self) -> None:
             self.minimum = np.asarray(self.minimum)
         if self.maximum is not None:
             self.maximum = np.asarray(self.maximum)
+        self._torch_minimum: Optional[torch.Tensor] = None
+        self._torch_maximum: Optional[torch.Tensor] = None
 
     def fit_with_transition_picker(
         self,
@@ -180,26 +182,24 @@ def fit_with_env(self, env: gym.Env[Any, Any]) -> None:
 
     def transform(self, x: torch.Tensor) -> torch.Tensor:
         assert self.built
-        minimum = add_leading_dims(
-            torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
-            target=x,
-        )
-        maximum = add_leading_dims(
-            torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
-            target=x,
+        if self._torch_maximum is None or self._torch_minimum is None:
+            self._set_torch_value(x.device)
+        assert (
+            self._torch_minimum is not None and self._torch_maximum is not None
         )
+        minimum = add_leading_dims(self._torch_minimum, target=x)
+        maximum = add_leading_dims(self._torch_maximum, target=x)
         return (x - minimum) / (maximum - minimum) * 2.0 - 1.0
 
     def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
         assert self.built
-        minimum = add_leading_dims(
-            torch.tensor(self.minimum, dtype=torch.float32, device=x.device),
-            target=x,
-        )
-        maximum = add_leading_dims(
-            torch.tensor(self.maximum, dtype=torch.float32, device=x.device),
-            target=x,
+        if self._torch_maximum is None or self._torch_minimum is None:
+            self._set_torch_value(x.device)
+        assert (
+            self._torch_minimum is not None and self._torch_maximum is not None
         )
+        minimum = add_leading_dims(self._torch_minimum, target=x)
+        maximum = add_leading_dims(self._torch_maximum, target=x)
         return ((maximum - minimum) * (x + 1.0) / 2.0) + minimum
 
     def transform_numpy(self, x: np.ndarray) -> np.ndarray:
@@ -216,6 +216,14 @@ def reverse_transform_numpy(self, x: np.ndarray) -> np.ndarray:
         maximum = add_leading_dims_numpy(self.maximum, target=x)
         return ((maximum - minimum) * (x + 1.0) / 2.0) + minimum
 
+    def _set_torch_value(self, device: torch.device) -> None:
+        self._torch_minimum = torch.tensor(
+            self.minimum, dtype=torch.float32, device=device
+        )
+        self._torch_maximum = torch.tensor(
+            self.maximum, dtype=torch.float32, device=device
+        )
+
     @staticmethod
     def get_type() -> str:
         return "min_max"
@@ -262,6 +270,8 @@ def __post_init__(self) -> None:
             self.mean = np.asarray(self.mean)
         if self.std is not None:
             self.std = np.asarray(self.std)
+        self._torch_mean: Optional[torch.Tensor] = None
+        self._torch_std: Optional[torch.Tensor] = None
 
     def fit_with_transition_picker(
         self,
@@ -328,26 +338,20 @@ def fit_with_env(self, env: gym.Env[Any, Any]) -> None:
 
     def transform(self, x: torch.Tensor) -> torch.Tensor:
         assert self.built
-        mean = add_leading_dims(
-            torch.tensor(self.mean, dtype=torch.float32, device=x.device),
-            target=x,
-        )
-        std = add_leading_dims(
-            torch.tensor(self.std, dtype=torch.float32, device=x.device),
-            target=x,
-        )
+        if self._torch_mean is None or self._torch_std is None:
+            self._set_torch_value(x.device)
+        assert self._torch_mean is not None and self._torch_std is not None
+        mean = add_leading_dims(self._torch_mean, target=x)
+        std = add_leading_dims(self._torch_std, target=x)
         return (x - mean) / (std + self.eps)
 
     def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
         assert self.built
-        mean = add_leading_dims(
-            torch.tensor(self.mean, dtype=torch.float32, device=x.device),
-            target=x,
-        )
-        std = add_leading_dims(
-            torch.tensor(self.std, dtype=torch.float32, device=x.device),
-            target=x,
-        )
+        if self._torch_mean is None or self._torch_std is None:
+            self._set_torch_value(x.device)
+        assert self._torch_mean is not None and self._torch_std is not None
+        mean = add_leading_dims(self._torch_mean, target=x)
+        std = add_leading_dims(self._torch_std, target=x)
         return ((std + self.eps) * x) + mean
 
     def transform_numpy(self, x: np.ndarray) -> np.ndarray:
@@ -364,6 +368,14 @@ def reverse_transform_numpy(self, x: np.ndarray) -> np.ndarray:
         std = add_leading_dims_numpy(self.std, target=x)
         return ((std + self.eps) * x) + mean
 
+    def _set_torch_value(self, device: torch.device) -> None:
+        self._torch_mean = torch.tensor(
+            self.mean, dtype=torch.float32, device=device
+        )
+        self._torch_std = torch.tensor(
+            self.std, dtype=torch.float32, device=device
+        )
+
     @staticmethod
     def get_type() -> str:
         return "standard"