     _as_context_manager,
     _CloudpickleWrapper,
     _convert_list_to_stack,
-    _DTYPE2STRDTYPE,
+    _DTYPE_TO_STR_DTYPE,
     _GENERIC_NESTED_ERR,
     _is_dataclass as is_dataclass,
     _is_list_tensor_compatible,
@@ -79,6 +79,7 @@
     _set_max_batch_size,
     _shape,
     _split_tensordict,
+    _STR_DTYPE_TO_DTYPE,
     _td_fields,
     _unravel_key_to_tuple,
     _zip_strict,
@@ -5061,7 +5062,7 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
     stop = sum([start, flat_size[-1]])
     if requires_metadata:
         metadata_dict["leaves"][key] = (
-            _DTYPE2STRDTYPE[dtype],
+            _DTYPE_TO_STR_DTYPE[dtype],
             list(shape),
             # _DEVICE2STRDEVICE[device],
             start,
@@ -7953,11 +7954,161 @@ def _recv(

         return _tag

+    def init_remote(
+        self,
+        dst: int,
+        group: "ProcessGroup" | None = None,  # noqa: F821
+        device: torch.device | None = None,
+    ):
+        """Initializes a remote tensordict by sending its metadata and content.
+
+        This method sends the metadata (shape, dtype, etc.) of the current tensordict to the specified destination rank (`dst`).
+
+        It then asynchronously sends the actual tensordict content.
+
+        Args:
+            dst (int): The rank of the destination process.
+            group ("ProcessGroup", optional): The process group to use for communication. Defaults to None.
+            device (torch.device, optional): The device to use for tensor operations. Defaults to None.
+
+        .. seealso::
+            The receiving process should call `~.from_remote_init` or an equivalent method to receive and initialize a new tensordict based on the sent metadata.
+
+        Examples:
+            >>> import os
+            >>> import torch
+            >>> import torch.distributed as dist
+            >>> from tensordict import TensorDict, MemoryMappedTensor
+            >>> import multiprocessing as mp
+            >>>
+            >>> def server(queue):
+            ...     # Set environment variables for distributed communication
+            ...     os.environ["MASTER_ADDR"] = "localhost"
+            ...     os.environ["MASTER_PORT"] = "29505"
+            ...
+            ...     # Initialize the distributed backend
+            ...     dist.init_process_group("gloo", rank=0, world_size=2)
+            ...
+            ...     # Create a sample tensordict
+            ...     td = (
+            ...         TensorDict(
+            ...             {
+            ...                 ("a", "b"): torch.ones(2),
+            ...                 "c": torch.ones(2),
+            ...                 ("d", "e", "f"): MemoryMappedTensor.from_tensor(torch.ones(2, 2)),
+            ...             },
+            ...             [2],
+            ...         )
+            ...         .expand(1, 2)
+            ...         .contiguous()
+            ...     )
+            ...
+            ...     # Send the tensordict metadata and content to the client
+            ...     td.init_remote(dst=1)
+            ...
+            >>> def client(queue):
+            ...     # Set environment variables for distributed communication
+            ...     os.environ["MASTER_ADDR"] = "localhost"
+            ...     os.environ["MASTER_PORT"] = "29505"
+            ...
+            ...     # Initialize the distributed backend
+            ...     dist.init_process_group("gloo", rank=1, world_size=2)
+            ...
+            ...     # Receive the tensordict metadata and content from the server
+            ...     received_td = TensorDict.from_remote_init(src=0)
+            ...
+            ...     # Verify that the received tensordict matches the expected structure and values
+            ...     assert set(received_td.keys()) == {"a", "c", "d"}
+            ...     assert (received_td == 1).all()
+            ...
+            ...     # Signal that the test has completed successfully
+            ...     queue.put("yuppie")
+            >>>
+            >>> if __name__ == "__main__":
+            ...     queue = mp.Queue(1)
+            ...
+            ...     # Create and start the server and client processes
+            ...     main_worker = mp.Process(target=server, args=(queue,))
+            ...     secondary_worker = mp.Process(target=client, args=(queue,))
+            ...
+            ...     main_worker.start()
+            ...     secondary_worker.start()
+            ...
+            ...     try:
+            ...         out = queue.get(timeout=10)  # Wait for the signal with a timeout
+            ...         print(out)  # Should print "yuppie"
+            ...     finally:
+            ...         queue.close()
+            ...         main_worker.join(timeout=10)
+            ...         secondary_worker.join(timeout=10)
+        """
+        # Gather per-leaf specs (key -> (shape, dtype, device)), plus batch size, device and lock state
+        data = [
+            {
+                k: (tuple(val.shape), str(val.dtype), str(val.device))
+                for k, val in self.items(True, True)
+            },
+            self.batch_size,
+            self.device,
+            self.is_locked,
+        ]
+        torch.distributed.send_object_list(
+            data,
+            dst=dst,
+            group=group,
+            device=device,
+        )
+        self.isend(dst, group=group)
+
+    @classmethod
+    def from_remote_init(
+        cls: T,
+        src: int,
+        group: "ProcessGroup" | None = None,  # noqa: F821
+        device: torch.device | None = None,
+    ) -> T:
+        """Creates a new tensordict instance initialized from remotely sent metadata.
+
+        This class method receives the metadata sent by `init_remote`, creates a new tensordict with matching shape and dtype,
+        and then asynchronously receives the actual tensordict content.
+
+        Args:
+            src (int): The rank of the source process that sent the metadata.
+            group ("ProcessGroup", optional): The process group to use for communication. Defaults to None.
+            device (torch.device, optional): The device to use for tensor operations. Defaults to None.
+
+        Returns:
+            TensorDict: A new tensordict instance initialized with the received metadata and content.
+
+        .. seealso::
+            The sending process should have called `~.init_remote` to send the metadata and content.
+        """
+        data = [None, None, None, None]
+        torch.distributed.recv_object_list(
+            data,
+            src=src,
+            group=group,
+            device=device,
+        )
+        metadata = data[0]
+        td = cls(
+            {
+                k: torch.empty(v[0], dtype=_STR_DTYPE_TO_DTYPE[v[1]], device=v[2])
+                for k, v in metadata.items()
+            },
+            batch_size=data[1],
+            device=data[2],
+        )
+        if data[3]:
+            td.lock_()
+        td.irecv(src=src, group=group)
+        return td
+
     def isend(
         self,
         dst: int,
         *,
-        group: "torch.distributed.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,  # noqa: F821
         init_tag: int = 0,
         pseudo_rand: bool = False,
         return_early: bool = False,
@@ -8048,7 +8199,13 @@ def isend(
             ...     secondary_worker.join()

         """
-        return self._isend(dst, _tag=init_tag - 1, pseudo_rand=pseudo_rand, group=group, return_early=return_early)
+        return self._isend(
+            dst,
+            _tag=init_tag - 1,
+            pseudo_rand=pseudo_rand,
+            group=group,
+            return_early=return_early,
+        )

     def _isend(
         self,
@@ -8057,7 +8214,7 @@ def _isend(
         _futures: list[torch.Future] | None = None,
         pseudo_rand: bool = False,
         group: "torch.distributed.ProcessGroup" | None = None,
-        return_early: bool = False,
+        return_early: bool = False,
     ) -> int:
         from torch import distributed as dist

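For quick reference, a minimal sketch of how the two new methods pair up across ranks, condensed from the `init_remote` docstring above. It assumes a two-rank "gloo" process group has already been initialized on each side with `dist.init_process_group`, and the tensor values are purely illustrative:

    import torch
    from tensordict import TensorDict

    # rank 0: send the metadata (per-leaf shape/dtype/device, batch size, device,
    # lock state), then the tensor content itself via isend
    td = TensorDict({"a": torch.ones(2), "c": torch.zeros(2)}, batch_size=[2])
    td.init_remote(dst=1)

    # rank 1 (in the other process): allocate an empty tensordict from the received
    # metadata, then fill it in place via irecv
    received = TensorDict.from_remote_init(src=0)
    assert received.batch_size == torch.Size([2])
    assert (received["a"] == 1).all()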