
Commit 28f8c50

Committed Feb 25, 2025
Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current issues**

This does not work yet: groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~ Fixed with pytorch/torchft#83.

**Issue 2:** ~~Group 0 and group 1 pass `should_commit`, but group 0 is wrongly reported as needing healing, and the healing process causes another hang.~~ Fixed with pytorch/torchft#83.

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 keeps printing~~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step. ***How to reproduce?*** Use the reproduce steps below to run 2 groups, then add another group by modifying the command. This seems to be fixed but needs more testing.

**Issue 5:** A hang occurs when using functional collectives. ***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82.
2. Start the lighthouse.
3. Run the following command in one terminal:

   ```
   TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
   ```

4. Wait 10 seconds, then run the following command in another terminal:

   ```
   TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
   ```

ghstack-source-id: 440da0f8d30d8466c22e1d8e1d738366b2d58bea
Pull Request resolved: #834
1 parent: 0c86fdd

8 files changed: +277 -42 lines changed

run_train.sh

+4

@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides

torchtitan/components/checkpoint.py

+101 -27

@@ -219,6 +219,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -229,16 +230,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -250,6 +276,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -260,7 +293,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -334,35 +367,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -379,6 +421,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -462,10 +507,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -486,6 +557,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -570,6 +643,7 @@ def _cpu_staging(self, checkpoint_id: Optional[str]) -> None:
     def _purge_stale_checkpoints(self):
         if (
             self.keep_latest_k > 0
+            and self.ft_manager.participating_rank() == 0
            and dist.get_rank() == 0
            and os.path.isdir(self.folder)
        ):
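For context, the two closures registered with `ft_manager.set_state_dict_fns` above simply round-trip the `state_dict()`/`load_state_dict()` of the tracked components, so a recovering replica can be healed from a peer's snapshot. Below is a minimal standalone sketch of that callback pattern with no torchft dependency; `Counter` and the `states` dict are illustrative stand-ins for the MODEL/OPTIMIZER/LR_SCHEDULER/TRAIN_STATE entries, not torchtitan code.

```python
from typing import Any, Dict


class Counter:
    """Illustrative stand-in for a Stateful component (model, optimizer, ...)."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.value = state_dict["value"]


# Stand-ins for the MODEL / OPTIMIZER / LR_SCHEDULER / TRAIN_STATE entries.
states = {"model": Counter(1), "optimizer": Counter(2)}


def state_dict() -> Dict[str, Any]:
    # Snapshot every tracked component, mirroring the closure the checkpoint
    # manager hands to set_state_dict_fns(load_state_dict, state_dict).
    return {k: v.state_dict() for k, v in states.items()}


def load_state_dict(snapshot: Dict[str, Any]) -> None:
    # Restore every tracked component from a peer replica's snapshot.
    assert snapshot is not None
    for k, v in snapshot.items():
        states[k].load_state_dict(v)


snapshot = state_dict()       # what a healthy replica would send
states["model"].value = 99    # simulate divergence on a recovering replica
load_state_dict(snapshot)     # "healing" restores the tracked state
assert states["model"].value == 1
```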

torchtitan/components/optimizer.py

+51 -7

@@ -6,7 +6,7 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 import torch
 import torch.nn as nn
@@ -177,8 +177,49 @@ def zero_grad(self) -> None:
         pass
 
 
+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_kwargs: Dict[str, Any],
+        name: str,
+        ft_manager: Optional["ft.Manager"],
+    ) -> None:
+        import torchft as ft
+
+        super().__init__(model_parts, optimizer_kwargs, name)
+
+        # Force to initialize the optimizer state so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.optimizers = [
+            ft.Optimizer(ft_manager, optim) for optim in self.optimizers
+        ]
+        self.cache_state_dict: Dict[str, Any] = {}
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional["ft.Manager"] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
 
@@ -219,11 +260,14 @@ def build_optimizers(
         "foreach": foreach,
     }
 
-    return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
-    )
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+    elif ft_manager:
+        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
+    else:
+        return OptimizersContainer(model_parts, optimizer_kwargs, name)
 
 
 class LRSchedulersContainer(Stateful):
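The caching in `FTOptimizersContainer` avoids recomputing the optimizer state dict each time TorchFT asks for it, and the cache is dropped on `load_state_dict()` because loading assigns new objects rather than copying into the cached ones. Here is a toy sketch of that invalidate-then-rebuild pattern, assuming nothing beyond plain Python; `_expensive_snapshot` is a hypothetical stand-in for the real state-dict computation, not a torchtitan or torchft API.

```python
from typing import Any, Dict


class CachedStateDict:
    def __init__(self) -> None:
        self._state: Dict[str, Any] = {"step": 0}
        self.cache_state_dict: Dict[str, Any] = {}

    def _expensive_snapshot(self) -> Dict[str, Any]:
        # Stand-in for a costly state-dict computation we don't want per step.
        return dict(self._state)

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = self._expensive_snapshot()

    def state_dict(self) -> Dict[str, Any]:
        # Served from the cache, so frequent queries stay cheap.
        return self.cache_state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Loading assigns new objects instead of copying into cached ones,
        # so the stale cache must be dropped and rebuilt afterwards.
        self.cache_state_dict = {}
        self._state = dict(state_dict)
        self.init_cache_state_dict()


c = CachedStateDict()
c.init_cache_state_dict()
assert c.state_dict() == {"step": 0}
c.load_state_dict({"step": 5})
assert c.state_dict() == {"step": 5}
```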

torchtitan/config_manager.py

+24

@@ -661,6 +661,30 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_id",
+            type=int,
+            default=0,
+            help="The TorchFT replica ID of this run.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_group_size",
+            type=int,
+            default=1,
+            help="""
+                The number of TorchFT replicate groups. This number will be used for
+                dataloader to split the dataset across the replicate groups and FSDP
+                dimension.
+            """,
+        )
+
     def to_dict(self):
         return self.args_dict
 
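Because these new options use dotted names, argparse stores them under dests such as `experimental.enable_torchft`, which are read back with `getattr`/`vars` rather than plain attribute access. A small standalone sketch of how such flags parse (plain argparse, not torchtitan's config machinery):

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the three options added above.
parser.add_argument("--experimental.enable_torchft", action="store_true",
                    help="Enable TorchFT integration.")
parser.add_argument("--experimental.ft_replica_id", type=int, default=0,
                    help="The TorchFT replica ID of this run.")
parser.add_argument("--experimental.ft_group_size", type=int, default=1,
                    help="Number of TorchFT replicate groups.")

args = parser.parse_args(
    ["--experimental.enable_torchft", "--experimental.ft_replica_id", "1"]
)
# Dotted dests are not valid attribute names, so read them via getattr/vars.
assert getattr(args, "experimental.enable_torchft") is True
assert getattr(args, "experimental.ft_replica_id") == 1
assert vars(args)["experimental.ft_group_size"] == 1
```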

torchtitan/distributed/parallel_dims.py

+17 -4

@@ -6,6 +6,7 @@
 
 from dataclasses import dataclass
 from functools import cached_property
+from typing import Any, Optional
 
 from torch.distributed.device_mesh import init_device_mesh
 
@@ -24,6 +25,7 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    ft_manager: Optional["ft.Manager"]
 
     def __post_init__(self):
         self._validate()
@@ -56,13 +58,24 @@ def build_mesh(self, device_type):
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            if d > 1 or (name == "dp_replicate" and self.ft_manager is not None):
                 dims.append(d)
                 names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        if self.ft_manager is None:
+            mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        else:
+            from torchft.process_group import ft_init_device_mesh
+
+            mesh = ft_init_device_mesh(
+                device_type=device_type,
+                mesh_shape=dims,
+                mesh_dim_names=names,
+                replicate_dim=names.index("dp_replicate"),
+                manager=self.ft_manager,
+            )
 
         # Create all the submesh here to ensure all required process groups are
         # initialized:
@@ -73,7 +86,7 @@ def build_mesh(self, device_type):
         # Mesh for loss all-reduce
         dp_cp_mesh_dim_names = []
 
-        if self.dp_replicate_enabled:
+        if self.dp_replicate_enabled or self.ft_manager is not None:
             dp_mesh_dim_names.append("dp_replicate")
             dp_cp_mesh_dim_names.append("dp_replicate")
         if self.dp_shard_enabled:
@@ -101,7 +114,7 @@ def dp_enabled(self):
 
     @property
     def dp_replicate_enabled(self):
-        return self.dp_replicate > 1
+        return self.dp_replicate > 1 or self.ft_manager is not None
 
     @property
     def dp_shard_enabled(self):
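The key change in `build_mesh` is that the `dp_replicate` dimension is kept in the mesh even when its size is 1, so TorchFT always has a replicate dimension to manage. A pure-Python sketch of just that selection logic follows; the `ft_manager` argument here is only a truthy placeholder, not a real `ft.Manager`.

```python
from typing import List, Optional, Tuple


def select_mesh_dims(
    pp: int, dp_replicate: int, dp_shard: int, cp: int, tp: int,
    ft_manager: Optional[object] = None,
) -> Tuple[List[int], List[str]]:
    dims: List[int] = []
    names: List[str] = []
    for d, name in zip(
        [pp, dp_replicate, dp_shard, cp, tp],
        ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
    ):
        # Keep dp_replicate whenever TorchFT is enabled, even at size 1, so the
        # fault-tolerant replicate dimension always exists in the mesh.
        if d > 1 or (name == "dp_replicate" and ft_manager is not None):
            dims.append(d)
            names.append(name)
    return dims, names


# Without TorchFT: a pure FSDP run collapses to a 1-D ("dp_shard",) mesh.
assert select_mesh_dims(1, 1, 2, 1, 1) == ([2], ["dp_shard"])
# With TorchFT: dp_replicate (size 1) is preserved, giving a 2-D mesh.
assert select_mesh_dims(1, 1, 2, 1, 1, ft_manager=object()) == (
    [1, 2],
    ["dp_replicate", "dp_shard"],
)
```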
