2 changes: 1 addition & 1 deletion test/test_cost.py
@@ -6100,7 +6100,7 @@ def zero_param(p):
if isinstance(p, nn.Parameter):
p.data.zero_()

-        params.apply(zero_param)
+        params.apply(zero_param, filter_empty=True)

# assert len(list(floss_fn.parameters())) == 0
with params.to_module(loss_fn):
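For context: tensordict's `apply` builds a new tensordict from the callback's return values. When the callback returns `None` for every leaf (a side-effect-only traversal like `zero_param` above), `filter_empty=True` prunes the resulting empty structure instead of materializing it. A minimal sketch of the idea, with made-up keys and shapes:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.ones(3), "nested": {"b": torch.full((3,), 2.0)}},
    batch_size=[3],
)

def zero_(x):
    # Side effect only: zero the leaf tensor in place, return None.
    x.zero_()

# With filter_empty=True, leaves for which the callback returned None are
# dropped and empty sub-tensordicts are pruned; a side-effect-only traversal
# like this one is expected to produce no output tensordict at all (None).
out = td.apply(zero_, filter_empty=True)
assert (td["a"] == 0).all() and (td["nested", "b"] == 0).all()
```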
22 changes: 15 additions & 7 deletions torchrl/collectors/collectors.py
@@ -801,22 +801,30 @@ def check_exclusive(val):
"Consider using a placeholder for missing keys."
)

-policy_output._fast_apply(check_exclusive, call_on_nested=True)
+policy_output._fast_apply(
+    check_exclusive, call_on_nested=True, filter_empty=True
+)

# Use apply, because it works well with lazy stacks
# Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
# or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
# changed them here).
# This will cause a failure to update entries when policy and env device mismatch and
# casting is necessary.
+def filter_policy(value_output, value_input, value_input_clone):
+    if (
+        (value_input is None)
+        or (value_output is not value_input)
+        or ~torch.isclose(value_output, value_input_clone).any()
+    ):
+        return value_output
+
filtered_policy_output = policy_output.apply(
-    lambda value_output, value_input, value_input_clone: value_output
-    if (value_input is None)
-    or (value_output is not value_input)
-    or ~torch.isclose(value_output, value_input_clone).any()
-    else None,
+    filter_policy,
    policy_input_copy,
    policy_input_clone,
+    default=None,
+    filter_empty=True,
)
self._policy_output_keys = list(
self._policy_output_keys.union(
@@ -933,7 +941,7 @@ def cuda_check(tensor: torch.Tensor):
if tensor.is_cuda:
cuda_devices.add(tensor.device)

-self._final_rollout.apply(cuda_check)
+self._final_rollout.apply(cuda_check, filter_empty=True)
for device in cuda_devices:
streams.append(torch.cuda.Stream(device, priority=-1))
events.append(streams[-1].record_event())
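The `filter_policy` refactor relies on two `apply` features worth spelling out: extra tensordicts passed positionally are unpacked leaf-wise into the callback (with `default=None` substituting for keys missing on their side), and a callback that returns `None` drops that entry from the result once `filter_empty=True` is set. A hedged, self-contained sketch of the same pattern, with illustrative names and values:

```python
import torch
from tensordict import TensorDict

hidden = torch.zeros(3)
# "hidden" is the very same tensor object on both sides; "action" only
# exists in the output, standing in for a key the policy created.
policy_output = TensorDict({"action": torch.ones(3), "hidden": hidden}, batch_size=[3])
policy_input = TensorDict({"hidden": hidden}, batch_size=[3])
policy_input_clone = policy_input.clone()

def filter_policy(value_output, value_input, value_input_clone):
    # Keep entries the policy created (value_input arrives as the default,
    # None) or visibly modified; returning None for the rest removes them
    # from the filtered output.
    if (
        (value_input is None)
        or (value_output is not value_input)
        or ~torch.isclose(value_output, value_input_clone).any()
    ):
        return value_output

filtered = policy_output.apply(
    filter_policy,
    policy_input,
    policy_input_clone,
    default=None,
    filter_empty=True,
)
# "hidden" is the same unmodified tensor, so it is filtered out.
assert set(filtered.keys()) == {"action"}
```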
31 changes: 9 additions & 22 deletions torchrl/envs/batched_envs.py
@@ -419,12 +419,8 @@ def _check_for_empty_spec(specs: CompositeSpec):
def map_device(key, value, device_map=device_map):
return value.to(device_map[key])

-# self._env_tensordict.named_apply(
-#     map_device, nested_keys=True, filter_empty=True
-# )
self._env_tensordict.named_apply(
-    map_device,
-    nested_keys=True,
+    map_device, nested_keys=True, filter_empty=True
)

self._batch_locked = meta_data.batch_locked
@@ -792,16 +788,11 @@ def select_and_clone(name, tensor):
if name in selected_output_keys:
return tensor.clone()

-# out = self.shared_tensordict_parent.named_apply(
-#     select_and_clone,
-#     nested_keys=True,
-#     filter_empty=True,
-# )
out = self.shared_tensordict_parent.named_apply(
    select_and_clone,
    nested_keys=True,
+    filter_empty=True,
)
-del out["next"]

if out.device != device:
if device is None:
@@ -842,8 +833,7 @@ def select_and_clone(name, tensor):
if name in self._selected_step_keys:
return tensor.clone()

-# out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
-out = next_td.named_apply(select_and_clone, nested_keys=True)
+out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)

if out.device != device:
if device is None:
@@ -1059,8 +1049,7 @@ def _start_workers(self) -> None:
def look_for_cuda(tensor, has_cuda=has_cuda):
has_cuda[0] = has_cuda[0] or tensor.is_cuda

-# self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
-self.shared_tensordict_parent.apply(look_for_cuda)
+self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
has_cuda = has_cuda[0]
if has_cuda:
self.event = torch.cuda.Event()
@@ -1182,14 +1171,14 @@ def step_and_maybe_reset(
if x.device != device
else x.clone(),
device=device,
-# filter_empty=True,
+filter_empty=True,
)
tensordict_ = tensordict_._fast_apply(
lambda x: x.to(device, non_blocking=True)
if x.device != device
else x.clone(),
device=device,
-# filter_empty=True,
+filter_empty=True,
)
else:
next_td = next_td.clone().clear_device_()
@@ -1244,7 +1233,7 @@ def select_and_clone(name, tensor):
out = next_td.named_apply(
select_and_clone,
nested_keys=True,
-# filter_empty=True,
+filter_empty=True,
)
if out.device != device:
if device is None:
@@ -1314,9 +1303,8 @@ def select_and_clone(name, tensor):
out = self.shared_tensordict_parent.named_apply(
select_and_clone,
nested_keys=True,
-# filter_empty=True,
+filter_empty=True,
)
-del out["next"]

if out.device != device:
if device is None:
@@ -1452,8 +1440,7 @@ def _run_worker_pipe_shared_mem(
def look_for_cuda(tensor, has_cuda=has_cuda):
has_cuda[0] = has_cuda[0] or tensor.is_cuda

-# shared_tensordict.apply(look_for_cuda, filter_empty=True)
-shared_tensordict.apply(look_for_cuda)
+shared_tensordict.apply(look_for_cuda, filter_empty=True)
has_cuda = has_cuda[0]
else:
has_cuda = device.type == "cuda"
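Several hunks above replace a clone-then-delete dance with `named_apply` plus `filter_empty=True`: the callback clones only the selected keys, everything else implicitly returns `None`, and the pruning removes the now-empty `"next"` subtree, which is why the trailing `del out["next"]` lines could be dropped. A sketch of that pattern with illustrative keys, assuming (as the call sites above do) that `named_apply` passes plain string keys for top-level leaves and tuples for nested ones:

```python
import torch
from tensordict import TensorDict

shared = TensorDict(
    {
        "observation": torch.zeros(2),
        "next": {"observation": torch.ones(2)},
    },
    batch_size=[2],
)
selected_step_keys = {"observation"}

def select_and_clone(name, tensor):
    # Only selected keys produce an output; every other leaf returns None
    # implicitly and is filtered out of the result.
    if name in selected_step_keys:
        return tensor.clone()

out = shared.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
# The "next" subtree produced no leaves, so filter_empty=True prunes it
# entirely -- no `del out["next"]` needed afterwards.
assert "next" not in out.keys()
```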
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
@@ -107,7 +107,7 @@ def metadata_from_env(env) -> EnvMetaData:
def fill_device_map(name, val, device_map=device_map):
device_map[name] = val.device

-tensordict.named_apply(fill_device_map, nested_keys=True)
+tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
return EnvMetaData(
tensordict, specs, batch_size, env_str, device, batch_locked, device_map
)
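The same flag also covers name-aware side-effect traversals such as `fill_device_map` above: nothing is returned, so without `filter_empty=True` the call would pointlessly assemble an empty output tensordict. A sketch with assumed keys:

```python
import torch
from tensordict import TensorDict

tensordict = TensorDict(
    {"observation": torch.zeros(1), "next": {"reward": torch.zeros(1)}},
    batch_size=[],
)
device_map = {}

def fill_device_map(name, val, device_map=device_map):
    # Side effect only: record each leaf's device under its (possibly
    # nested) key; the traversal itself produces no output.
    device_map[name] = val.device

tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
# device_map now maps "observation" and ("next", "reward") to their devices.
```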
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
@@ -3063,7 +3063,7 @@ def __init__(self):
super().__init__(in_keys=[])

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
-tensordict.apply(check_finite)
+tensordict.apply(check_finite, filter_empty=True)
return tensordict

def _reset(
28 changes: 19 additions & 9 deletions torchrl/objectives/cql.py
@@ -577,9 +577,14 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
def _get_policy_actions(self, data, actor_params, num_actions=10):
batch_size = data.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * num_actions]
-tensordict = data.select(*self.actor_network.in_keys).apply(
-    lambda x: x.repeat_interleave(num_actions, dim=data.ndim - 1),
-    batch_size=batch_size,
+in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
+
+def filter_and_repeat(name, x):
+    if name in in_keys:
+        return x.repeat_interleave(num_actions, dim=data.ndim - 1)
+
+tensordict = data.named_apply(
+    filter_and_repeat, batch_size=batch_size, filter_empty=True
)
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
@@ -731,13 +736,18 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor:

batch_size = tensordict_q_random.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
-tensordict_q_random = tensordict_q_random.select(
-    *self.actor_network.in_keys
-).apply(
-    lambda x: x.repeat_interleave(
-        self.num_random, dim=tensordict_q_random.ndim - 1
-    ),
+in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
+
+def filter_and_repeat(name, x):
+    if name in in_keys:
+        return x.repeat_interleave(
+            self.num_random, dim=tensordict_q_random.ndim - 1
+        )
+
+tensordict_q_random = tensordict_q_random.named_apply(
+    filter_and_repeat,
    batch_size=batch_size,
+    filter_empty=True,
)
tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
cql_tensordict = torch.cat(
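The CQL change folds `select(...)` followed by `apply(...)` into a single `named_apply`: the callback keeps (and repeats) only the actor's input keys, returns `None` for everything else, and `batch_size=` resizes the output to match the repeated dimension. A standalone sketch with illustrative keys and `num_actions`; the real code first normalizes the keys with `unravel_key`, which a plain string list stands in for here:

```python
import torch
from tensordict import TensorDict

data = TensorDict(
    {"observation": torch.randn(4, 8), "reward": torch.randn(4, 1)},
    batch_size=[4],
)
num_actions = 10
in_keys = ["observation"]  # stand-in for the unravel_key-normalized in_keys
batch_size = list(data.batch_size[:-1]) + [data.batch_size[-1] * num_actions]

def filter_and_repeat(name, x):
    # Repeat the selected entries along the last batch dim; entries outside
    # in_keys return None and are filtered from the output.
    if name in in_keys:
        return x.repeat_interleave(num_actions, dim=data.ndim - 1)

tensordict = data.named_apply(
    filter_and_repeat, batch_size=batch_size, filter_empty=True
)
assert tensordict["observation"].shape == (40, 8)
assert tensordict.batch_size == torch.Size([40])
```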