test/test_tensor_spec.py (6 changes: 3 additions & 3 deletions)
@@ -972,13 +972,13 @@ def test_discrete_action_spec_reconstruct(self, action_spec_cls):
        actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
        actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
        assert all(
-            [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
+            (a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)
        )

        actions_numpy = [int(np.random.randint(0, 10, (1,))) for a in actions_tensors]
        actions_tensors = [action_spec.encode(a) for a in actions_numpy]
        actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
-        assert all([(a1 == a2) for a1, a2 in zip(actions_numpy, actions_numpy_2)])
+        assert all((a1 == a2) for a1, a2 in zip(actions_numpy, actions_numpy_2))

    def test_mult_discrete_action_spec_reconstruct(self):
        torch.manual_seed(0)
@@ -999,7 +999,7 @@ def test_mult_discrete_action_spec_reconstruct(self):
        ]
        actions_tensors = [action_spec.encode(a) for a in actions_numpy]
        actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
-        assert all([(a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2)])
+        assert all((a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2))

    def test_one_hot_discrete_action_spec_rand(self):
        torch.manual_seed(0)
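The change is the same in every hunk: all() receives a generator expression instead of a list comprehension, so no intermediate list is built and evaluation can stop at the first falsy element. A minimal sketch of the difference, using a hypothetical check function that is not part of the PR:

# Hypothetical helper used only to make the evaluation order visible.
def check(i):
    print(f"evaluating {i}")
    return i < 2

# List comprehension: the whole list is built first, so every element is
# evaluated before all() inspects any of them.
all([check(i) for i in range(5)])  # prints "evaluating 0" through "evaluating 4"

# Generator expression: all() pulls items lazily and short-circuits at the
# first falsy result, so evaluation stops once check(2) returns False.
all(check(i) for i in range(5))    # prints "evaluating 0", "1", "2" only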
test/test_trainer.py (4 changes: 2 additions & 2 deletions)
@@ -456,8 +456,8 @@ def make_storage():
        ) # trainer.app_state["state"]["replay_buffer.replay_buffer._storage._storage"]
        td2 = trainer2._modules["replay_buffer"].replay_buffer._storage._storage
        if storage_type == "list":
-            assert all([(_td1 == _td2).all() for _td1, _td2 in zip(td1, td2)])
-            assert all([(_td1 is not _td2) for _td1, _td2 in zip(td1, td2)])
+            assert all((_td1 == _td2).all() for _td1, _td2 in zip(td1, td2))
+            assert all((_td1 is not _td2) for _td1, _td2 in zip(td1, td2))
            assert storage2._storage is td2
        else:
            assert (td1 == td2).all()
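The list branch checks two things: the reloaded storage matches the original element-wise, and the elements are distinct objects, i.e. equal values with no shared references. A small self-contained sketch of that pattern, using plain tensors rather than the trainer's actual storage objects:

import torch

original = [torch.zeros(3), torch.ones(3)]
reloaded = [t.clone() for t in original]  # clones: same values, new objects

# Element-wise equality, expressed with generator expressions as in the diff.
assert all((a == b).all() for a, b in zip(original, reloaded))
# Identity check: no element of the reloaded list aliases the original.
assert all(a is not b for a, b in zip(original, reloaded))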
torchrl/envs/transforms/transforms.py (4 changes: 2 additions & 2 deletions)
@@ -1792,7 +1792,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
            self.in_keys = self._find_in_keys()
            self._initialized = True

-        if all([key in tensordict.keys(include_nested=True) for key in self.in_keys]):
+        if all(key in tensordict.keys(include_nested=True) for key in self.in_keys):
            values = [tensordict.get(key) for key in self.in_keys]
            if self.unsqueeze_if_oor:
                pos_idx = self.dim > 0
@@ -2560,7 +2560,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
        if isinstance(reward_spec, CompositeSpec):

            # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
-            if not all([k in reward_spec.keys() for k in self.in_keys]):
+            if not all(k in reward_spec.keys() for k in self.in_keys):
                raise KeyError("Not all in_keys are present in ´reward_spec´")

            # Define episode specs for all out_keys
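Both transform hunks use the same guard pattern: confirm that every expected key is present before reading the values, now without materializing a list of booleans. A minimal stand-in, with a plain dict playing the role of the TensorDict or CompositeSpec and illustrative key names not taken from torchrl:

# Plain-dict stand-in for the key-presence guard; data and in_keys are
# hypothetical examples.
data = {"obs": 1.0, "reward": 0.5}
in_keys = ["obs", "reward"]

if all(key in data for key in in_keys):
    values = [data[key] for key in in_keys]
else:
    raise KeyError("Not all in_keys are present")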