
Commit 3aa61e5

Committed Feb 28, 2025
Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~ Fixed with pytorch/torchft#83.

**Issue 2:** ~~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~~ Fixed with pytorch/torchft#83.

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 keeps printing~~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command. Seems to be fixed, but needs more tests.

**Issue 5:** A hang occurs when using functional collectives.
***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82.
2. Start the lighthouse.
3. Run the following command in one terminal:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

4. Wait 10 seconds, then run the following command in another terminal:

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: f07ae76e95c450994ba418c06b9bb064275fc974
Pull Request resolved: #834
1 parent ec82573 commit 3aa61e5

10 files changed: +385 −48 lines


‎run_train.sh

+3 lines

@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides

‎tests/unit_tests/test_checkpoint.py

+8 lines

@@ -61,10 +61,18 @@ class DummyJob:
     dump_folder: str = "dummy_folder"
 
 
+@dataclass
+class DummyExperimental:
+    ft_replica_id = 0
+    ft_group_size = 1
+
+
 @dataclass
 class DummyJobConfig:
     checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
     job: DummyJob = field(default_factory=DummyJob)
+    experimental: DummyExperimental = field(default_factory=DummyExperimental)
+    ft_manager = None
 
 
 # Dummy instances to supply as constructor arguments.

‎tests/unit_tests/test_model_converter.py

+1 line

@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=None,
     )
     return parallel_dims

‎torchtitan/components/checkpoint.py

+115 −27 lines

@@ -31,6 +31,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
 
+from torchtitan.components.ft import FTManager
 from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.tools.logging import init_logger, logger
@@ -214,6 +215,19 @@ class CheckpointManager:
     3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
        with the assumption that all lr_schedulers have the same state_dict.
 
+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.
 
     Args:
         dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +237,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -233,16 +248,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: FTManager,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -254,6 +294,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -264,7 +311,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +386,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -384,6 +440,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -467,10 +526,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -491,6 +576,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +664,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
            discovered_checkpoints = []
            for filename in os.listdir(self.folder):
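
The checkpointing split described in the docstring above comes down to two per-step decisions: every replica stages a dataloader-only checkpoint under its own `ft-replicat-{id}` folder, while only the replica with `ft_manager.participating_rank() == 0` writes the full persistent checkpoint. Below is a minimal, self-contained sketch of that decision; the `ft_checkpoint_plan` helper is hypothetical, and only the `step-{step}` and `ft-replicat-{id}` folder naming comes from the diff.

```
import os
from typing import Optional


def ft_checkpoint_plan(
    participating_rank: Optional[int],  # None when TorchFT is disabled
    replica_id: int,
    dump_folder: str,
    step: int,
) -> dict:
    """Sketch of the two-tier TorchFT checkpoint layout (illustrative only)."""
    plan = {}
    # Every replica stages its per-replica (dataloader-only) checkpoint.
    plan["per_replica"] = os.path.join(
        dump_folder, f"ft-replicat-{replica_id}", f"step-{step}"
    )
    # Only the replica whose participating rank is 0 (or a non-FT run) saves the
    # full persistent checkpoint with model/optimizer/lr_scheduler/train_state.
    if participating_rank is None or participating_rank == 0:
        plan["full"] = os.path.join(dump_folder, f"step-{step}")
    return plan


# Example: replica 1 of an FT run only writes its dataloader checkpoint.
print(ft_checkpoint_plan(participating_rank=1, replica_id=1,
                         dump_folder="outputs/checkpoint", step=100))
```

In the diff itself, the per-replica path is produced by `_ft_folder()` plus `_create_checkpoint_id()` inside `_ft_save`, and the `participating_rank() == 0` gate wraps the pre-existing body of `save()`.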

‎torchtitan/components/ft.py

+78 lines (new file)

@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+from typing import Optional
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+class FTManager:
+    def __init__(
+        self,
+        manager: Optional["ft.Manager"],
+        group_size: int = 1,
+        replica_id: int = 0,
+    ) -> None:
+        self._manager = manager
+        self.group_size = group_size
+        self.replica_id = replica_id
+
+    @property
+    def enabled(self) -> bool:
+        return self._manager is not None
+
+    @property
+    def manager(self) -> "ft.Manager":
+        assert self._manager is not None
+        return self._manager
+
+    def get_dp_rank(self, dp_degree: int, dp_rank: int) -> int:
+        return dp_degree * self.replica_id + dp_rank
+
+    def get_dp_degree(self, dp_degree: int) -> int:
+        return dp_degree * self.group_size
+
+
+def init_ft_manager(job: JobConfig) -> FTManager:
+    """Initialize the FT manager if TorchFT is enabled.
+
+    Args:
+        job (JobConfig): The job configuration.
+
+    Returns:
+        Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None.
+    """
+    if not job.experimental.enable_torchft:
+        return FTManager(None)
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    if job.experimental.ft_min_replica_size < 1:
+        raise ValueError("At least one FT replica is required.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+
+    return FTManager(
+        ft.Manager(
+            pg=pg,
+            min_replica_size=job.experimental.ft_min_replica_size,
+            load_state_dict=None,
+            state_dict=None,
+            use_async_quorum=True,
+            replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
+        ),
+        group_size=job.experimental.ft_group_size,
+        replica_id=job.experimental.ft_replica_id,
+    )
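
To make the new `FTManager` arithmetic concrete: each replica group contributes `dp_degree` data-parallel shards, so the effective degree is `dp_degree * group_size` and a local dp rank is offset by `dp_degree * replica_id`. Below is a small standalone sketch of that math; `FTManagerSketch` is a renamed copy of the two methods above so it can run without torchft installed, and it is not part of the diff.

```
# Illustrative copy of the rank/degree math added in torchtitan/components/ft.py.
# The class in the diff is the source of truth; this only shows how the numbers combine.

class FTManagerSketch:
    def __init__(self, group_size: int = 1, replica_id: int = 0) -> None:
        self.group_size = group_size
        self.replica_id = replica_id

    def get_dp_rank(self, dp_degree: int, dp_rank: int) -> int:
        # Global dp rank = this replica group's offset + local dp rank.
        return dp_degree * self.replica_id + dp_rank

    def get_dp_degree(self, dp_degree: int) -> int:
        # Effective dp degree spans all replica groups.
        return dp_degree * self.group_size


# Two replica groups, each running 2-way sharding (matching the reproduce steps):
ft = FTManagerSketch(group_size=2, replica_id=1)
assert ft.get_dp_degree(2) == 4            # 2 shards x 2 replica groups
assert ft.get_dp_rank(2, dp_rank=0) == 2   # group 1's first shard is global dp rank 2
```

With the two replica groups from the reproduce steps (each running `--training.data_parallel_shard_degree=2`), group 1's shards map to global dp ranks 2 and 3.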

0 commit comments
