Commit 4362e67

Merge with main

2 parents: 9be78d1 + 778f2f4
3 files changed: +3 additions, -9 deletions

test/test_tensor_spec.py
Lines changed: 0 additions & 1 deletion

@@ -259,7 +259,6 @@ def test_multi_discrete(shape, ns):
     np.random.seed(0)
     ts = MultiDiscreteTensorSpec(ns)
     _real_shape = shape if shape is not None else []
-    _len_ns = [len(ns)] if len(ns) > 1 else []
     nvec_shape = torch.tensor(ns).size()
     if nvec_shape == torch.Size([1]):
         nvec_shape = []

torchrl/data/tensor_specs.py
Lines changed: 1 addition & 3 deletions

@@ -1087,12 +1087,11 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
     def is_in(self, val: torch.Tensor) -> bool:
         if val.ndim < 1:
             val = val.unsqueeze(0)
-        val_is_too_small = len(self.shape) > val.ndim
         val_have_wrong_dim = (
             self.shape != torch.Size([1])
             and val.shape[-len(self.shape) :] != self.shape
         )
-        if self.dtype != val.dtype or val_is_too_small or val_have_wrong_dim:
+        if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim:
             return False

         for permutation in itertools.product(*[range(d) for d in self.shape]):
@@ -1107,7 +1106,6 @@ def to_onehot(self) -> MultOneHotDiscreteTensorSpec:
                 f"DiscreteTensorSpec with shape != tensor.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got "
                 f"shape={self.shape}."
             )
-
         return MultOneHotDiscreteTensorSpec(
             [_space.n for _space in self.space], self.device, self.dtype
         )
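
For context, is_in rejects a candidate value unless its dtype matches the spec and its trailing dimensions match the spec's shape; the patch simply inlines the val_is_too_small temporary into the condition, with no behavior change. A minimal standalone sketch of the same gate follows; the helper name passes_is_in_gate and its arguments are illustrative, not torchrl API:

import torch

def passes_is_in_gate(val: torch.Tensor, spec_shape: torch.Size, dtype: torch.dtype) -> bool:
    # Illustrative re-implementation of the checks above, not torchrl code.
    if val.ndim < 1:
        val = val.unsqueeze(0)
    wrong_dim = (
        spec_shape != torch.Size([1])
        and val.shape[-len(spec_shape):] != spec_shape
    )
    # The "value too small" check, inlined as in the patched line:
    return dtype == val.dtype and len(spec_shape) <= val.ndim and not wrong_dim

print(passes_is_in_gate(torch.zeros(4, 3, dtype=torch.long), torch.Size([3]), torch.long))  # True
print(passes_is_in_gate(torch.zeros(4, 2, dtype=torch.long), torch.Size([3]), torch.long))  # False: wrong trailing dim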

torchrl/objectives/ppo.py
Lines changed: 2 additions & 5 deletions

@@ -233,13 +233,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 f"and {log_weight.shape})"
             )
         gain1 = log_weight.exp() * advantage
-        log_weight_clip = torch.empty_like(log_weight)
-        # log_weight_clip.data.clamp_(*self._clip_bounds)
-        idx_pos = advantage >= 0
-        log_weight_clip[idx_pos] = log_weight[idx_pos].clamp_max(self._clip_bounds[1])
-        log_weight_clip[~idx_pos] = log_weight[~idx_pos].clamp_min(self._clip_bounds[0])

+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
         gain2 = log_weight_clip.exp() * advantage
+
         gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
         td_out = TensorDict({"loss_objective": -gain.mean()}, [])

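The ppo.py change replaces sign-dependent clipping of the log-ratio with a plain two-sided clamp. This works because the clipped-surrogate loss takes the elementwise minimum of the unclipped and clipped gains: clamping the log-ratio on the "wrong" side relative to the advantage's sign can only raise gain2, so it never changes that minimum. A quick numerical check of the equivalence; the bounds lo and hi are illustrative stand-ins for self._clip_bounds, and torch.min is the elementwise equivalent of the stack/min in forward:

import torch

torch.manual_seed(0)
log_weight = torch.randn(1000)
advantage = torch.randn(1000)
lo, hi = -0.2, 0.2  # illustrative clip bounds

gain1 = log_weight.exp() * advantage

# Old, sign-dependent clipping:
old_clip = torch.empty_like(log_weight)
idx_pos = advantage >= 0
old_clip[idx_pos] = log_weight[idx_pos].clamp_max(hi)
old_clip[~idx_pos] = log_weight[~idx_pos].clamp_min(lo)
gain2_old = old_clip.exp() * advantage

# New, two-sided clamp:
gain2_new = log_weight.clamp(lo, hi).exp() * advantage

# The objective only uses the elementwise minimum with gain1; the extra
# one-sided clamp can only raise gain2, so the minima agree.
assert torch.allclose(torch.min(gain1, gain2_old), torch.min(gain1, gain2_new))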