From e51edbfb1e7d9dd3f17c5db57a65bdec74645eb0 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Tue, 6 Feb 2024 13:31:42 +0000
Subject: [PATCH] init

---
 test/test_cost.py                     |  2 +-
 torchrl/collectors/collectors.py      | 22 +++++++++++++------
 torchrl/envs/batched_envs.py          | 31 ++++++++------------------
 torchrl/envs/common.py                |  2 +-
 torchrl/envs/transforms/transforms.py |  2 +-
 torchrl/objectives/cql.py             | 28 ++++++++++++++++--------
 6 files changed, 46 insertions(+), 41 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index c6eb27172ee..dae1fa5f70c 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -6100,7 +6100,7 @@ def zero_param(p):
             if isinstance(p, nn.Parameter):
                 p.data.zero_()
 
-        params.apply(zero_param)
+        params.apply(zero_param, filter_empty=True)
 
         # assert len(list(floss_fn.parameters())) == 0
         with params.to_module(loss_fn):
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index bea46bb6cd4..202dcc9ead8 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -801,22 +801,30 @@ def check_exclusive(val):
                         "Consider using a placeholder for missing keys."
                     )
 
-            policy_output._fast_apply(check_exclusive, call_on_nested=True)
+            policy_output._fast_apply(
+                check_exclusive, call_on_nested=True, filter_empty=True
+            )
+
             # Use apply, because it works well with lazy stacks
             # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
             # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
             # changed them here).
             # This will cause a failure to update entries when policy and env device mismatch and
             # casting is necessary.
+            def filter_policy(value_output, value_input, value_input_clone):
+                if (
+                    (value_input is None)
+                    or (value_output is not value_input)
+                    or ~torch.isclose(value_output, value_input_clone).any()
+                ):
+                    return value_output
+
             filtered_policy_output = policy_output.apply(
-                lambda value_output, value_input, value_input_clone: value_output
-                if (value_input is None)
-                or (value_output is not value_input)
-                or ~torch.isclose(value_output, value_input_clone).any()
-                else None,
+                filter_policy,
                 policy_input_copy,
                 policy_input_clone,
                 default=None,
+                filter_empty=True,
             )
             self._policy_output_keys = list(
                 self._policy_output_keys.union(
@@ -933,7 +941,7 @@ def cuda_check(tensor: torch.Tensor):
                 if tensor.is_cuda:
                     cuda_devices.add(tensor.device)
 
-            self._final_rollout.apply(cuda_check)
+            self._final_rollout.apply(cuda_check, filter_empty=True)
             for device in cuda_devices:
                 streams.append(torch.cuda.Stream(device, priority=-1))
                 events.append(streams[-1].record_event())
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 9669963cb33..cfb977d4bb2 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -419,12 +419,8 @@ def _check_for_empty_spec(specs: CompositeSpec):
         def map_device(key, value, device_map=device_map):
             return value.to(device_map[key])
 
-        # self._env_tensordict.named_apply(
-        #     map_device, nested_keys=True, filter_empty=True
-        # )
         self._env_tensordict.named_apply(
-            map_device,
-            nested_keys=True,
+            map_device, nested_keys=True, filter_empty=True
         )
 
         self._batch_locked = meta_data.batch_locked
@@ -792,16 +788,11 @@ def select_and_clone(name, tensor):
             if name in selected_output_keys:
                 return tensor.clone()
 
-        # out = self.shared_tensordict_parent.named_apply(
-        #     select_and_clone,
-        #     nested_keys=True,
-        #     filter_empty=True,
-        # )
         out = self.shared_tensordict_parent.named_apply(
             select_and_clone,
             nested_keys=True,
+            filter_empty=True,
         )
-        del out["next"]
 
         if out.device != device:
             if device is None:
@@ -842,8 +833,7 @@ def select_and_clone(name, tensor):
             if name in self._selected_step_keys:
                 return tensor.clone()
 
-        # out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
-        out = next_td.named_apply(select_and_clone, nested_keys=True)
+        out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
 
         if out.device != device:
             if device is None:
@@ -1059,8 +1049,7 @@ def _start_workers(self) -> None:
             def look_for_cuda(tensor, has_cuda=has_cuda):
                 has_cuda[0] = has_cuda[0] or tensor.is_cuda
 
-            # self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
-            self.shared_tensordict_parent.apply(look_for_cuda)
+            self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
             has_cuda = has_cuda[0]
         if has_cuda:
             self.event = torch.cuda.Event()
@@ -1182,14 +1171,14 @@ def step_and_maybe_reset(
                 if x.device != device
                 else x.clone(),
                 device=device,
-                # filter_empty=True,
+                filter_empty=True,
             )
             tensordict_ = tensordict_._fast_apply(
                 lambda x: x.to(device, non_blocking=True)
                 if x.device != device
                 else x.clone(),
                 device=device,
-                # filter_empty=True,
+                filter_empty=True,
             )
         else:
             next_td = next_td.clone().clear_device_()
@@ -1244,7 +1233,7 @@ def select_and_clone(name, tensor):
         out = next_td.named_apply(
             select_and_clone,
             nested_keys=True,
-            # filter_empty=True,
+            filter_empty=True,
         )
         if out.device != device:
             if device is None:
@@ -1314,9 +1303,8 @@ def select_and_clone(name, tensor):
         out = self.shared_tensordict_parent.named_apply(
             select_and_clone,
             nested_keys=True,
-            # filter_empty=True,
+            filter_empty=True,
         )
-        del out["next"]
 
         if out.device != device:
             if device is None:
@@ -1452,8 +1440,7 @@ def _run_worker_pipe_shared_mem(
         def look_for_cuda(tensor, has_cuda=has_cuda):
             has_cuda[0] = has_cuda[0] or tensor.is_cuda
 
-        # shared_tensordict.apply(look_for_cuda, filter_empty=True)
-        shared_tensordict.apply(look_for_cuda)
+        shared_tensordict.apply(look_for_cuda, filter_empty=True)
         has_cuda = has_cuda[0]
     else:
         has_cuda = device.type == "cuda"
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 61cd211b6ae..746cc60f142 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -107,7 +107,7 @@ def metadata_from_env(env) -> EnvMetaData:
     def fill_device_map(name, val, device_map=device_map):
         device_map[name] = val.device
 
-    tensordict.named_apply(fill_device_map, nested_keys=True)
+    tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
     return EnvMetaData(
         tensordict, specs, batch_size, env_str, device, batch_locked, device_map
     )
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index efa59e25c26..b71c6fcffc3 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3063,7 +3063,7 @@ def __init__(self):
         super().__init__(in_keys=[])
 
     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict.apply(check_finite)
+        tensordict.apply(check_finite, filter_empty=True)
         return tensordict
 
     def _reset(
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index f963f0e0b52..69a30c7f484 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -577,9 +577,14 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
     def _get_policy_actions(self, data, actor_params, num_actions=10):
         batch_size = data.batch_size
         batch_size = list(batch_size[:-1]) + [batch_size[-1] * num_actions]
-        tensordict = data.select(*self.actor_network.in_keys).apply(
-            lambda x: x.repeat_interleave(num_actions, dim=data.ndim - 1),
-            batch_size=batch_size,
+        in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
+
+        def filter_and_repeat(name, x):
+            if name in in_keys:
+                return x.repeat_interleave(num_actions, dim=data.ndim - 1)
+
+        tensordict = data.named_apply(
+            filter_and_repeat, batch_size=batch_size, filter_empty=True
         )
         with torch.no_grad():
             with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
@@ -731,13 +736,18 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor:
 
         batch_size = tensordict_q_random.batch_size
         batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
-        tensordict_q_random = tensordict_q_random.select(
-            *self.actor_network.in_keys
-        ).apply(
-            lambda x: x.repeat_interleave(
-                self.num_random, dim=tensordict_q_random.ndim - 1
-            ),
+        in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
+
+        def filter_and_repeat(name, x):
+            if name in in_keys:
+                return x.repeat_interleave(
+                    self.num_random, dim=tensordict_q_random.ndim - 1
+                )
+
+        tensordict_q_random = tensordict_q_random.named_apply(
+            filter_and_repeat,
             batch_size=batch_size,
+            filter_empty=True,
         )
         tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
         cql_tensordict = torch.cat(
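Note (illustration, not part of the patch): the hunks above converge on two tensordict idioms. First, apply/named_apply calls that are run purely for their side effects, or that return a value for only a subset of keys, now pass filter_empty=True so no empty result TensorDict is materialized. Second, the select(...).apply(lambda ...) chains in cql.py are replaced by a named helper handed to named_apply. The sketch below assumes tensordict >= 0.4 (where apply and named_apply accept filter_empty); the keys, shapes and the repeat factor num=2 are invented for the example.

import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(4, 3), "nested": {"reward": torch.zeros(4, 1)}},
    batch_size=[4],
)

# Side-effect-only traversal: the callable returns None for every leaf, so with
# filter_empty=True no empty output TensorDict needs to be built (the pattern used
# for cuda_check, look_for_cuda, fill_device_map and check_finite above).
devices = set()
td.apply(lambda x: devices.add(x.device), filter_empty=True)

# Filter-and-transform: return a tensor only for the keys of interest; entries for
# which the helper returns None are skipped, and filter_empty=True also removes the
# sub-TensorDicts left empty as a result (this mirrors the filter_and_repeat
# helpers added in cql.py).
keep = {"obs"}

def filter_and_repeat(name, x, num=2):
    # `name` is the full (possibly nested) key because nested_keys=True is passed.
    if name in keep:
        return x.repeat_interleave(num, dim=0)

out = td.named_apply(
    filter_and_repeat, nested_keys=True, batch_size=[8], filter_empty=True
)
print(list(out.keys()), out.batch_size)  # ['obs'] torch.Size([8])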
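A second sketch, this one for the collectors.py hunk: TensorDict.apply also accepts extra TensorDicts whose matching leaves are passed to the callable, with default=None standing in for keys missing from those extras, which is what lets filter_policy keep only the entries the policy actually produced or rewrote. Same tensordict version assumption as above; the keys and shapes are again made up.

import torch
from tensordict import TensorDict

hidden = torch.zeros(4, 8)
policy_input = TensorDict({"hidden": hidden}, batch_size=[4])
policy_input_clone = policy_input.clone()
# The policy echoed "hidden" untouched (same tensor object) and added "action".
policy_output = TensorDict(
    {"action": torch.ones(4, 2), "hidden": hidden}, batch_size=[4]
)

def filter_policy(value_output, value_input, value_input_clone):
    # Keep an entry if it was absent from the input, is a different tensor object,
    # or no element is close to the pre-call clone; otherwise return None so the
    # entry is dropped from the result.
    if (
        (value_input is None)
        or (value_output is not value_input)
        or ~torch.isclose(value_output, value_input_clone).any()
    ):
        return value_output

filtered = policy_output.apply(
    filter_policy, policy_input, policy_input_clone, default=None, filter_empty=True
)
print(list(filtered.keys()))  # ['action']: "hidden" was only passed through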