Skip to content

Commit

Permalink
fix(pytorch): resolve critical action rescaling bug (#403)
Browse files Browse the repository at this point in the history
This commit addresses a critical bug related to action rescaling in
PyTorch, which was preventing the agent from training effectively in
specific environments.
  • Loading branch information
rickstaa committed Feb 7, 2024
1 parent 0892e09 commit 71d4f64
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions stable_learning_control/algos/pytorch/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,17 @@ def rescale(data, min_bound, max_bound):
the desired range.
Returns:
torch.Tensor: Array which has it values scaled between the min and max
boundaries.
Union[Torch.Tensor, numpy.ndarray]: Array which has it values scaled between
the min and max boundaries.
"""
data = torch.tensor(data) if not isinstance(data, torch.Tensor) else data
min_bound = (
torch.tensor(min_bound, device=data.device)
if not isinstance(min_bound, torch.Tensor)
else min_bound.to(data.device)
)
max_bound = (
torch.tensor(max_bound, device=data.device)
if not isinstance(max_bound, torch.Tensor)
else max_bound.to(data.device)
)
was_numpy = isinstance(data, np.ndarray)
data = torch.as_tensor(data)
min_bound = torch.as_tensor(min_bound, device=data.device)
max_bound = torch.as_tensor(max_bound, device=data.device)

# Return rescaled data in the same format as the input data.
data_rescaled = (data + 1.0) * (max_bound - min_bound) / 2 + min_bound
return data_rescaled.astype(data.dtype) if isinstance(data, np.ndarray) else data
return data_rescaled.cpu().numpy() if was_numpy else data_rescaled


def np_to_torch(input_object, dtype=None, device=None):
Expand Down

0 comments on commit 71d4f64

Please sign in to comment.