From 8d8b1a213f8ed4d7bf1e9297db2374f7a332e198 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 22 Sep 2025 09:07:25 +0100
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/collectors/collectors.py | 77 ++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 57396e3c089..cb7d66a42e7 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -1817,6 +1817,13 @@ def get_policy_version(self) -> str | int | None:
         """
         return self.policy_version
 
+    def getattr_policy(self, attr):
+        # Direct lookup on this collector's own policy; nothing is sent anywhere.
+        return getattr(self.policy, attr)
+
+    def getattr_env(self, attr):
+        # Direct lookup on this collector's own env; nothing is sent anywhere.
+        return getattr(self.env, attr)
 
 
 class _MultiDataCollector(DataCollectorBase):
@@ -2782,6 +2789,57 @@ def get_policy_version(self) -> str | int | None:
         """
         return self.policy_version
 
+    def getattr_policy(self, attr):
+        """Get an attribute from the policy of the first worker.
+
+        Args:
+            attr (str): The attribute name to retrieve from the policy.
+
+        Returns:
+            The attribute value from the policy of the first worker.
+
+        Raises:
+            AttributeError: If the attribute doesn't exist on the policy.
+        """
+        _check_for_faulty_process(self.procs)
+
+        # Ask worker 0 only; a reply tagged differently raises RuntimeError below.
+        self.pipes[0].send((attr, "getattr_policy"))
+        result, msg = self.pipes[0].recv()
+        if msg != "getattr_policy":
+            raise RuntimeError(f"Expected msg='getattr_policy', got {msg}")
+
+        # The worker sends back a caught AttributeError as the payload; re-raise it.
+        if isinstance(result, AttributeError):
+            raise result
+
+        return result
+
+    def getattr_env(self, attr):
+        """Get an attribute from the environment of the first worker.
+
+        Args:
+            attr (str): The attribute name to retrieve from the environment.
+
+        Returns:
+            The attribute value from the environment of the first worker.
+
+        Raises:
+            AttributeError: If the attribute doesn't exist on the environment.
+        """
+        _check_for_faulty_process(self.procs)
+
+        # Ask worker 0 only; a reply tagged differently raises RuntimeError below.
+        self.pipes[0].send((attr, "getattr_env"))
+        result, msg = self.pipes[0].recv()
+        if msg != "getattr_env":
+            raise RuntimeError(f"Expected msg='getattr_env', got {msg}")
+
+        # The worker sends back a caught AttributeError as the payload; re-raise it.
+        if isinstance(result, AttributeError):
+            raise result
+
+        return result
 
 
 @accept_remote_rref_udf_invocation
@@ -3947,6 +4005,25 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
             has_timed_out = False
             continue
 
+        elif msg == "getattr_policy":
+            attr_name = data_in
+            try:
+                result = getattr(inner_collector.policy, attr_name)
+                pipe_child.send((result, "getattr_policy"))
+            except AttributeError as e:
+                pipe_child.send((e, "getattr_policy"))
+            has_timed_out = False
+            continue
+
+        elif msg == "getattr_env":
+            attr_name = data_in
+            try:
+                result = getattr(inner_collector.env, attr_name)
+                pipe_child.send((result, "getattr_env"))
+            except AttributeError as e:
+                pipe_child.send((e, "getattr_env"))
+            has_timed_out = False
+            continue
 
         elif msg == "close":
             del collected_tensordict, data, next_data, data_in