Skip to content

Commit 768d014

Browse files
committed
[WIP][RFC] Required changes for integration with TorchTitan
Summary: We are not going to land this PR, this PR may be further divided into several PRs. Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent bed29d2 commit 768d014

File tree

4 files changed

+50
-13
lines changed

4 files changed

+50
-13
lines changed

torchft/checkpointing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ def do_GET(self):
7575
self.end_headers()
7676

7777
sd = state_dict()
78-
78+
logger.warning("After state_dict ===================.")
7979
torch.save(sd, self.wfile)
80+
logger.warning("After torch.save ===================.")
81+
8082
except Exception as e:
8183
logger.exception(
8284
f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -113,7 +115,7 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
113115
data = f.read()
114116

115117
reader = io.BytesIO(data)
116-
return torch.load(reader, weights_only=True)
118+
return torch.load(reader, weights_only=False)
117119

118120
def address(self) -> str:
119121
"""

torchft/manager.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ class Manager:
8787
def __init__(
8888
self,
8989
pg: "ProcessGroup",
90-
load_state_dict: Callable[[T], None],
91-
state_dict: Callable[[], T],
90+
load_state_dict: Optional[Callable[[T], None]],
91+
state_dict: Optional[Callable[[], T]],
9292
min_replica_size: int,
9393
use_async_quorum: bool = True,
9494
timeout: timedelta = timedelta(seconds=60),
@@ -158,7 +158,7 @@ def __init__(
158158

159159
def _manager_state_dict() -> Dict[str, T]:
160160
return {
161-
"user": state_dict(),
161+
"user": self._state_dict(),
162162
"torchft": cast(T, self.state_dict()),
163163
}
164164

@@ -223,6 +223,12 @@ def _manager_state_dict() -> Dict[str, T]:
223223
self._participating_rank: Optional[int] = None
224224
self._participating_world_size: int = 0
225225

226+
def set_state_dict_fns(
227+
self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
228+
) -> None:
229+
self._load_state_dict = load_state_dict
230+
self._state_dict = state_dict
231+
226232
def shutdown(self) -> None:
227233
"""
228234
Shutdown the manager and checkpoint server.
@@ -506,6 +512,7 @@ def _apply_pending_state_dict(self) -> None:
506512
assert self._pending_state_dict is not None, "checkpoint was not staged"
507513
self._load_state_dict(self._pending_state_dict["user"])
508514
self._pending_state_dict = None
515+
self._logger.info("Loaded state dict.")
509516

510517
def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
511518
"""

torchft/optim.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
1313
"""
1414

15-
from typing import TYPE_CHECKING, Optional
15+
from typing import Any, TYPE_CHECKING, Optional
1616

1717
from torch.optim import Optimizer
1818

@@ -52,3 +52,11 @@ def step(self, closure: Optional[object] = None) -> None:
5252
assert closure is None, "optimizers that use closures are not supported"
5353
if self.manager.should_commit():
5454
self.optim.step()
55+
56+
@property
57+
def param_groups(self) -> Any:
58+
return self.optim.param_groups
59+
60+
@property
61+
def state(self) -> Any:
62+
return self.optim.state

torchft/process_group.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import threading
2222
from abc import ABC
2323
from datetime import timedelta
24-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
24+
from typing import Any, TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
2525

2626
import torch
2727
import torch.distributed as dist
@@ -852,6 +852,8 @@ def extend_device_mesh(
852852

853853

854854
class ManagedDeviceMesh(DeviceMesh):
855+
replicate_pg_singleton: Optional["ManagedProcessGroup"]
856+
855857
def __init__(
856858
self,
857859
mesh: Optional[DeviceMesh],
@@ -880,6 +882,15 @@ def __init__(
880882
self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
881883
self._thread_id: Optional[int] = None
882884

885+
def __getstate__(self) -> Dict[str, Any]:
886+
state = self.__dict__.copy()
887+
state["replicate_pg"] = None
888+
return state
889+
890+
def __setstate__(self, state: Dict[str, Any]) -> None:
891+
self.__dict__.update(state)
892+
self.replicate_pg = self.replicate_pg_singleton
893+
883894
def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
884895
if isinstance(mesh_dim_names, str):
885896
if mesh_dim_names == self.replicate_dim_name:
@@ -897,13 +908,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
897908
return self.mesh[mesh_dim_names]
898909
else:
899910
assert isinstance(mesh_dim_names, tuple)
900-
if self.replicate_dim_name in mesh_dim_names:
911+
if self.replicate_dim_name not in mesh_dim_names:
901912
assert self.mesh is not None
902913
return self.mesh[mesh_dim_names]
903914
else:
904915
assert self.mesh is not None
916+
mesh_dim_names_wo_replicate = tuple(n for n in mesh_dim_names if n != self.replicate_dim_name)
905917
return ManagedDeviceMesh(
906-
self.mesh[mesh_dim_names],
918+
self.mesh[mesh_dim_names_wo_replicate],
907919
mesh_dim_names,
908920
self.replicate_pg,
909921
mesh_dim_names.index(self.replicate_dim_name),
@@ -938,14 +950,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
938950
return flatten_mesh
939951

940952
def size(self, mesh_dim: Optional[int] = None) -> int:
953+
replicate_pg_size = self.replicate_pg.size()
954+
replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
941955
if mesh_dim is None:
942956
if self.mesh is None:
943-
return self.replicate_pg.size()
957+
return replicate_pg_size
944958
else:
945959
assert self.mesh is not None
946-
return self.mesh.size() * self.replicate_pg.size()
960+
return self.mesh.size() * replicate_pg_size
947961
elif mesh_dim == self.replicate_dim:
948-
return self.replicate_pg.size()
962+
return replicate_pg_size
949963
else:
950964
assert self.mesh is not None
951965
return self.mesh.size(self._real_mesh_dim(mesh_dim))
@@ -995,7 +1009,11 @@ def get_coordinate(self) -> Optional[List[int]]:
9951009
dimensions of the mesh. If this rank is not part of the mesh, return None.
9961010
"""
9971011
assert self.mesh is not None
998-
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
1012+
ret = self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
1013+
if ret:
1014+
ret = ret.copy()
1015+
ret.insert(self.replicate_dim, get_rank(self.replicate_pg))
1016+
return ret
9991017

10001018
def get_all_groups(self) -> List[BaseProcessGroup]:
10011019
raise NotImplementedError
@@ -1070,6 +1088,8 @@ def ft_init_device_mesh(
10701088
# the same backend has been registered.
10711089
replicate_pg.register(mesh_dim_names[replicate_dim])
10721090

1091+
ManagedDeviceMesh.replicate_pg_singleton = replicate_pg
1092+
10731093
return ManagedDeviceMesh(
10741094
mesh=mesh,
10751095
mesh_dim_names=mesh_dim_names,

0 commit comments

Comments
 (0)