Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,13 @@ def get_policy_version(self) -> str | int | None:
"""
return self.policy_version

def getattr_policy(self, attr):
# send command to policy to return the attr
return getattr(self.policy, attr)

def getattr_env(self, attr):
# send command to env to return the attr
return getattr(self.env, attr)


class _MultiDataCollector(DataCollectorBase):
Expand Down Expand Up @@ -2782,6 +2789,57 @@ def get_policy_version(self) -> str | int | None:
"""
return self.policy_version

def getattr_policy(self, attr):
"""Get an attribute from the policy of the first worker.

Args:
attr (str): The attribute name to retrieve from the policy.

Returns:
The attribute value from the policy of the first worker.

Raises:
AttributeError: If the attribute doesn't exist on the policy.
"""
_check_for_faulty_process(self.procs)

# Send command to first worker (index 0)
self.pipes[0].send((attr, "getattr_policy"))
result, msg = self.pipes[0].recv()
if msg != "getattr_policy":
raise RuntimeError(f"Expected msg='getattr_policy', got {msg}")

# If the worker returned an AttributeError, re-raise it
if isinstance(result, AttributeError):
raise result

return result

def getattr_env(self, attr):
"""Get an attribute from the environment of the first worker.

Args:
attr (str): The attribute name to retrieve from the environment.

Returns:
The attribute value from the environment of the first worker.

Raises:
AttributeError: If the attribute doesn't exist on the environment.
"""
_check_for_faulty_process(self.procs)

# Send command to first worker (index 0)
self.pipes[0].send((attr, "getattr_env"))
result, msg = self.pipes[0].recv()
if msg != "getattr_env":
raise RuntimeError(f"Expected msg='getattr_env', got {msg}")

# If the worker returned an AttributeError, re-raise it
if isinstance(result, AttributeError):
raise result

return result


@accept_remote_rref_udf_invocation
Expand Down Expand Up @@ -3947,6 +4005,25 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
has_timed_out = False
continue

elif msg == "getattr_policy":
attr_name = data_in
try:
result = getattr(inner_collector.policy, attr_name)
pipe_child.send((result, "getattr_policy"))
except AttributeError as e:
pipe_child.send((e, "getattr_policy"))
has_timed_out = False
continue

elif msg == "getattr_env":
attr_name = data_in
try:
result = getattr(inner_collector.env, attr_name)
pipe_child.send((result, "getattr_env"))
except AttributeError as e:
pipe_child.send((e, "getattr_env"))
has_timed_out = False
continue

elif msg == "close":
del collected_tensordict, data, next_data, data_in
Expand Down
Loading