Skip to content

Commit 2959863

Browse files
author
Vincent Moens
committed
[Feature] broadcast tensordicts
ghstack-source-id: 3b142c9 Pull-Request-resolved: #1307
1 parent 4012767 commit 2959863

File tree

6 files changed

+246
-12
lines changed

6 files changed

+246
-12
lines changed

tensordict/_reductions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tensordict._td import TensorDict
1313

1414
from tensordict.tensorclass import NonTensorData, NonTensorStack
15-
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE
15+
from tensordict.utils import _is_tensorclass, _STR_DTYPE_TO_DTYPE
1616

1717
CLS_MAP = {
1818
"TensorDict": TensorDict,
@@ -99,7 +99,7 @@ def from_metadata(metadata=metadata, prefix=None):
9999
for (key, (data, batch_size, device)) in non_tensor.items()
100100
}
101101
for key, (dtype, local_shape, start, stop, pad) in leaves.items():
102-
dtype = _STRDTYPE2DTYPE[dtype]
102+
dtype = _STR_DTYPE_TO_DTYPE[dtype]
103103
# device = torch.device(device)
104104
local_shape = torch.Size(local_shape)
105105
value = storage[start:stop].view(dtype)

tensordict/_td.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
_set_item,
6868
_set_max_batch_size,
6969
_shape,
70-
_STRDTYPE2DTYPE,
70+
_STR_DTYPE_TO_DTYPE,
7171
_StringKeys,
7272
_StringOnlyDict,
7373
_sub_index,
@@ -2836,7 +2836,7 @@ def _load_memmap(
28362836
else:
28372837
shape = torch.Size(shape)
28382838
tensor = MemoryMappedTensor.from_filename(
2839-
dtype=_STRDTYPE2DTYPE[dtype],
2839+
dtype=_STR_DTYPE_TO_DTYPE[dtype],
28402840
shape=shape,
28412841
filename=str(prefix / f"{key}.memmap"),
28422842
)
@@ -2846,7 +2846,7 @@ def _load_memmap(
28462846
tensor = torch.zeros(
28472847
torch.Size(shape),
28482848
device=device,
2849-
dtype=_STRDTYPE2DTYPE[dtype],
2849+
dtype=_STR_DTYPE_TO_DTYPE[dtype],
28502850
)
28512851
result._set_str(
28522852
key,

tensordict/base.py

Lines changed: 162 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
_as_context_manager,
5757
_CloudpickleWrapper,
5858
_convert_list_to_stack,
59-
_DTYPE2STRDTYPE,
59+
_DTYPE_TO_STR_DTYPE,
6060
_GENERIC_NESTED_ERR,
6161
_is_dataclass as is_dataclass,
6262
_is_list_tensor_compatible,
@@ -79,6 +79,7 @@
7979
_set_max_batch_size,
8080
_shape,
8181
_split_tensordict,
82+
_STR_DTYPE_TO_DTYPE,
8283
_td_fields,
8384
_unravel_key_to_tuple,
8485
_zip_strict,
@@ -5061,7 +5062,7 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
50615062
stop = sum([start, flat_size[-1]])
50625063
if requires_metadata:
50635064
metadata_dict["leaves"][key] = (
5064-
_DTYPE2STRDTYPE[dtype],
5065+
_DTYPE_TO_STR_DTYPE[dtype],
50655066
list(shape),
50665067
# _DEVICE2STRDEVICE[device],
50675068
start,
@@ -7953,11 +7954,161 @@ def _recv(
79537954

79547955
return _tag
79557956

7957+
def init_remote(
    self,
    dst: int,
    group: "ProcessGroup" | None = None,  # noqa: F821
    device: torch.device | None = None,
):
    """Bootstraps a tensordict on a remote rank by sending its layout, then its content.

    A description of every leaf (shape, dtype and device, the latter two as strings),
    together with the batch size, the device and the lock state of this tensordict,
    is first sent to rank ``dst`` as a list of Python objects through
    :func:`torch.distributed.send_object_list`. The content itself is then sent
    with :meth:`~.isend`.

    Args:
        dst (int): The rank of the destination process.
        group ("ProcessGroup", optional): The process group to use for communication.
            Forwarded to both the metadata and the content transfer. Defaults to ``None``.
        device (torch.device, optional): Forwarded as the ``device`` argument of
            :func:`torch.distributed.send_object_list`. Defaults to ``None``.

    .. seealso::
        The receiving process should call :meth:`~.from_remote_init` (or an equivalent
        method) to build a matching tensordict from the metadata and receive the content.

    Examples:
        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> from tensordict import TensorDict, MemoryMappedTensor
        >>> import multiprocessing as mp
        >>>
        >>> def server(queue):
        ...     # Set environment variables for distributed communication
        ...     os.environ["MASTER_ADDR"] = "localhost"
        ...     os.environ["MASTER_PORT"] = "29505"
        ...
        ...     # Initialize the distributed backend
        ...     dist.init_process_group("gloo", rank=0, world_size=2)
        ...
        ...     # Create a sample tensordict
        ...     td = (
        ...         TensorDict(
        ...             {
        ...                 ("a", "b"): torch.ones(2),
        ...                 "c": torch.ones(2),
        ...                 ("d", "e", "f"): MemoryMappedTensor.from_tensor(torch.ones(2, 2)),
        ...             },
        ...             [2],
        ...         )
        ...         .expand(1, 2)
        ...         .contiguous()
        ...     )
        ...
        ...     # Send the tensordict metadata and content to the client
        ...     td.init_remote(dst=1)
        ...
        >>> def client(queue):
        ...     # Set environment variables for distributed communication
        ...     os.environ["MASTER_ADDR"] = "localhost"
        ...     os.environ["MASTER_PORT"] = "29505"
        ...
        ...     # Initialize the distributed backend
        ...     dist.init_process_group("gloo", rank=1, world_size=2)
        ...
        ...     # Receive the tensordict metadata and content from the server
        ...     received_td = TensorDict.from_remote_init(src=0)
        ...
        ...     # Verify that the received tensordict matches the expected structure and values
        ...     assert set(received_td.keys()) == {"a", "c", "d"}
        ...     assert (received_td == 1).all()
        ...
        ...     # Signal that the test has completed successfully
        ...     queue.put("yuppie")
        >>>
        >>> if __name__ == "__main__":
        ...     queue = mp.Queue(1)
        ...
        ...     # Create and start the server and client processes
        ...     main_worker = mp.Process(target=server, args=(queue,))
        ...     secondary_worker = mp.Process(target=client, args=(queue,))
        ...
        ...     main_worker.start()
        ...     secondary_worker.start()
        ...
        ...     try:
        ...         out = queue.get(timeout=10)  # Wait for the signal with a timeout
        ...         print(out)  # Should print "yuppie"
        ...     finally:
        ...         queue.close()
        ...         main_worker.join(timeout=10)
        ...         secondary_worker.join(timeout=10)
    """
    # One (shape, dtype, device) spec per entry. dtype and device travel as
    # strings so the payload only contains plain Python objects.
    # NOTE(review): the two positional flags are presumably
    # (include_nested, leaves_only) -- confirm against ``items``.
    leaf_specs = {}
    for key, leaf in self.items(True, True):
        leaf_specs[key] = (tuple(leaf.shape), str(leaf.dtype), str(leaf.device))

    # The slot order is a wire contract: the receiver reads these four
    # objects back by position.
    payload = [leaf_specs, self.batch_size, self.device, self.is_locked]
    torch.distributed.send_object_list(
        payload,
        dst=dst,
        group=group,
        device=device,
    )

    # Layout is on its way; now ship the actual tensors.
    self.isend(dst, group=group)
8062+
8063+
@classmethod
def from_remote_init(
    cls: T,
    src: int,
    group: "ProcessGroup" | None = None,  # noqa: F821
    device: torch.device | None = None,
) -> T:
    """Builds a new tensordict from metadata sent by a remote rank, then receives its content.

    Four objects are received from rank ``src`` through
    :func:`torch.distributed.recv_object_list`: a mapping from each key to a
    ``(shape, dtype, device)`` spec, the batch size, the device of the tensordict
    and a flag telling whether it must be locked. Uninitialized tensors matching
    the specs are allocated, wrapped in a new instance of ``cls``, and filled
    with :meth:`~.irecv`.

    Args:
        src (int): The rank of the source process that sent the metadata.
        group ("ProcessGroup", optional): The process group to use for communication.
            Forwarded to both the metadata and the content transfer. Defaults to ``None``.
        device (torch.device, optional): Forwarded as the ``device`` argument of
            :func:`torch.distributed.recv_object_list`. Defaults to ``None``.

    Returns:
        TensorDict: A new tensordict instance initialized with the received metadata and content.

    .. seealso::
        The sending process should have called :meth:`~.init_remote` to send the
        metadata and content.
    """
    # recv_object_list fills the list in place, one slot per expected object.
    payload = [None] * 4
    torch.distributed.recv_object_list(
        payload,
        src=src,
        group=group,
        device=device,
    )
    leaf_specs, batch_size, td_device, locked = payload

    # Allocate an uninitialized buffer per entry; the values are overwritten
    # by the content transfer below.
    placeholders = {}
    for key, (shape, dtype_name, leaf_device) in leaf_specs.items():
        placeholders[key] = torch.empty(
            shape,
            dtype=_STR_DTYPE_TO_DTYPE[dtype_name],
            device=leaf_device,
        )

    td = cls(placeholders, batch_size=batch_size, device=td_device)
    if locked:
        td.lock_()
    td.irecv(src=src, group=group)
    return td
8106+
79568107
def isend(
79578108
self,
79588109
dst: int,
79598110
*,
7960-
group: "torch.distributed.ProcessGroup" | None = None,
8111+
group: "torch.distributed.ProcessGroup" | None = None, # noqa: F821
79618112
init_tag: int = 0,
79628113
pseudo_rand: bool = False,
79638114
return_early: bool = False,
@@ -8048,7 +8199,13 @@ def isend(
80488199
... secondary_worker.join()
80498200

80508201
"""
8051-
return self._isend(dst, _tag=init_tag - 1, pseudo_rand=pseudo_rand, group=group, return_early=return_early)
8202+
return self._isend(
8203+
dst,
8204+
_tag=init_tag - 1,
8205+
pseudo_rand=pseudo_rand,
8206+
group=group,
8207+
return_early=return_early,
8208+
)
80528209

80538210
def _isend(
80548211
self,
@@ -8057,7 +8214,7 @@ def _isend(
80578214
_futures: list[torch.Future] | None = None,
80588215
pseudo_rand: bool = False,
80598216
group: "torch.distributed.ProcessGroup" | None = None,
8060-
return_early: bool = False,
8217+
return_early: bool = False,
80618218
) -> int:
80628219
from torch import distributed as dist
80638220

tensordict/tensorclass.pyi

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,22 @@ class TensorClass:
704704
init_tag: int = 0,
705705
pseudo_rand: bool = False,
706706
) -> int: ...
707+
# Stub for the ``init_remote`` method defined in tensordict/base.py: sends the
# metadata and then the content of the object to rank ``dst``. The name and
# parameters mirror that implementation (``dst``, ``group``, ``device``).
def init_remote(
    self,
    dst: int,
    group: "ProcessGroup" | None = None,
    device: torch.device | None = None,
): ...
715+
# Stub for the ``from_remote_init`` classmethod defined in tensordict/base.py:
# builds a new instance from metadata sent by rank ``src`` and receives its
# content. The name and parameters mirror that implementation.
@classmethod
def from_remote_init(
    cls,
    src: int,
    group: "ProcessGroup" | None = None,
    device: torch.device | None = None,
): ...
707723
def isend(
708724
self,
709725
dst: int,

tensordict/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ def dims(self, *args, **kwargs):
147147
_TORCH_DTYPES = _TORCH_DTYPES + (torch.uint32,)
148148
if hasattr(torch, "uint64"):
149149
_TORCH_DTYPES = _TORCH_DTYPES + (torch.uint64,)
150-
_STRDTYPE2DTYPE = {str(dtype): dtype for dtype in _TORCH_DTYPES}
151-
_DTYPE2STRDTYPE = {dtype: str_dtype for str_dtype, dtype in _STRDTYPE2DTYPE.items()}
150+
_STR_DTYPE_TO_DTYPE = {str(dtype): dtype for dtype in _TORCH_DTYPES}
151+
_DTYPE_TO_STR_DTYPE = {
152+
dtype: str_dtype for str_dtype, dtype in _STR_DTYPE_TO_DTYPE.items()
153+
}
152154

153155
IndexType = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
154156
DeviceType = Union[torch.device, str, int]

test/test_distributed.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,65 @@ def make_td(ones):
761761
return td
762762

763763

764+
class TestInitRemote:
    """Round-trip test for ``init_remote`` / ``from_remote_init`` over gloo.

    Two processes are spawned: rank 0 (``server``) builds a tensordict and
    sends it with ``init_remote``; rank 1 (``client``) rebuilds it with
    ``from_remote_init``, checks it, and reports success through a queue.
    """

    # Rendezvous port shared by both ranks (exported as MASTER_PORT).
    port = "29505"

    @classmethod
    def client(cls, queue, rank):
        """Receiving side: rebuild the tensordict from rank 0 and validate it."""
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = cls.port
        dist.init_process_group(
            "gloo",
            rank=rank,
            world_size=2,
        )

        td = TensorDict.from_remote_init(src=0)
        assert set(td.keys()) == {"a", "c", "d"}
        assert (td == 1).all()
        # Only reached when both checks passed; the parent waits on this token.
        queue.put("yuppie")

    @classmethod
    def server(cls, queue):
        """Sending side: build a nested tensordict of ones and send it to rank 1."""
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = cls.port
        dist.init_process_group(
            "gloo",
            rank=0,
            world_size=2,
        )

        td = (
            TensorDict(
                {
                    ("a", "b"): torch.ones(2),
                    "c": torch.ones(2),
                    ("d", "e", "f"): MemoryMappedTensor.from_tensor(torch.ones(2, 2)),
                },
                [2],
            )
            .expand(1, 2)
            .contiguous()
        )
        td.init_remote(dst=1)

    def test_init_remote(self, set_context, tmp_path):
        """Run server and client in separate processes and wait for the success token."""
        queue = mp.Queue(1)
        main_worker = mp.Process(target=type(self).server, args=(queue,))
        secondary_worker = mp.Process(target=type(self).client, args=(queue, 1))
        workers = (main_worker, secondary_worker)

        for worker in workers:
            worker.start()
        out = None
        try:
            out = queue.get(timeout=TIMEOUT)
        finally:
            queue.close()
            for worker in workers:
                worker.join(timeout=TIMEOUT)
                # join(timeout) returns silently when the child is still
                # running. A hung worker would keep holding the rendezvous
                # port and, being non-daemonic, block interpreter shutdown,
                # so stop it explicitly.
                if worker.is_alive():
                    worker.terminate()
                    worker.join()
        assert out == "yuppie"
821+
822+
764823
if __name__ == "__main__":
765824
args, unknown = argparse.ArgumentParser().parse_known_args()
766825
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)