From 6db755d70f6e9fe8df0010877eb6767d2d2e4f87 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Sat, 31 Dec 2022 23:02:19 +0000
Subject: [PATCH] init

---
 torchrl/collectors/collectors.py | 37 +++++++++++++++++++++++++++-----
 torchrl/envs/common.py           |  9 +++++++-
 2 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index aa3d79e1499..47a7cec7f19 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import _pickle
 import abc
 import inspect
 import os
@@ -708,7 +709,14 @@ def shutdown(self) -> None:
         del self.env
 
     def __del__(self):
-        self.shutdown()  # make sure env is closed
+        try:
+            self.shutdown()
+        except Exception:
+            # An AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
 
     def state_dict(self) -> OrderedDict:
         """Returns the local state_dict of the data collector (environment and policy).
@@ -1016,7 +1024,19 @@ def _run_processes(self) -> None:
             }
             proc = mp.Process(target=_main_async_collector, kwargs=kwargs)
             # proc.daemon can't be set as daemonic processes may be launched by the process itself
-            proc.start()
+            try:
+                proc.start()
+            except _pickle.PicklingError as err:
+                if "<lambda>" in str(err):
+                    raise RuntimeError(
+                        """Can't open a process with a doubly cloud-pickled lambda function.
+This error is likely due to an attempt to use a ParallelEnv in a
+multiprocessed data collector. To do this, consider wrapping your
+lambda function in a `torchrl.envs.EnvCreator` wrapper as follows:
+`env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
+This will not only ensure that your lambda function is cloud-pickled once, but
+also that the state dict is synchronised across processes if needed."""
+                    )
             pipe_child.close()
             self.procs.append(proc)
             self.pipes.append(pipe_parent)
@@ -1027,7 +1047,14 @@ def _run_processes(self) -> None:
         self.closed = False
 
     def __del__(self):
-        self.shutdown()
+        try:
+            self.shutdown()
+        except Exception:
+            # An AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
 
     def shutdown(self) -> None:
         """Shuts down all processes. This operation is irreversible."""
@@ -1624,8 +1651,8 @@ def _main_async_collector(
                     f"without receiving a command from main. Consider increasing the maximum idle count "
                     f"if this is expected via the environment variable MAX_IDLE_COUNT "
                     f"(current value is {_MAX_IDLE_COUNT})."
-                    f"\nIf this occurs at the end of a function, it means that your collector has not been "
-                    f"collected, consider calling `collector.shutdown()` or `del collector` at the end of the function."
+                    f"\nIf this occurs at the end of a function or program, it means that your collector has not been "
+                    f"collected; consider calling `collector.shutdown()` or `del collector` before ending the program."
                 )
            continue
        if msg in ("continue", "continue_random"):
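For reference, here is a minimal usage sketch of the pattern that the new
error message recommends. The environment name ("Pendulum-v1"), the choice of
MultiaSyncDataCollector and its arguments are illustrative assumptions, not
taken from this patch:

    from torchrl.collectors import MultiaSyncDataCollector
    from torchrl.envs import EnvCreator, ParallelEnv
    from torchrl.envs.libs.gym import GymEnv

    if __name__ == "__main__":
        # EnvCreator cloud-pickles the lambda once up front, so nesting a
        # ParallelEnv inside a multiprocessed collector no longer triggers
        # the _pickle.PicklingError handled in _run_processes above.
        make_env = EnvCreator(lambda: GymEnv("Pendulum-v1"))
        env = ParallelEnv(2, make_env)
        collector = MultiaSyncDataCollector(
            [env],
            policy=None,  # a random policy is used when none is provided
            frames_per_batch=64,
            total_frames=128,
        )
        for batch in collector:
            print(batch.batch_size)
        # Shutting down explicitly avoids relying on __del__ at interpreter
        # exit, where the try/except blocks added above deliberately swallow
        # any error raised during closure.
        collector.shutdown()
        del collector
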
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 406013d2480..aec32eff395 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -677,7 +677,14 @@ def __del__(self):
         # if del occurs before env has been set up, we don't want a recursion
         # error
         if "is_closed" in self.__dict__ and not self.is_closed:
-            self.close()
+            try:
+                self.close()
+            except Exception:
+                # A TypeError will typically be raised if the env is deleted when the program ends.
+                # In the future, insignificant changes to the close method may change the error type.
+                # We explicitly assume that any error raised during closure in
+                # __del__ will not affect the program.
+                pass
 
     def to(self, device: DEVICE_TYPING) -> EnvBase:
         device = torch.device(device)
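The guarded __del__ in both files mirrors CPython's own handling of
finalizers: an exception raised inside __del__ is never propagated to the
caller anyway; the interpreter only prints an "Exception ignored in: ..."
warning to stderr. A standalone illustration of the failure mode being
silenced (hypothetical class, not torchrl code):

    class Resource:
        def close(self):
            # During interpreter shutdown, module globals may already have
            # been cleared, so lookups inside close() can raise errors such
            # as AttributeError or TypeError.
            raise AttributeError("module state already torn down")

        def __del__(self):
            try:
                self.close()
            except Exception:
                # Mirrors the patch: any error raised while closing during
                # teardown is deliberately ignored.
                pass

    r = Resource()
    del r  # silent, instead of printing "Exception ignored in: ..."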