
Commit 5409a57

Committed Feb 21, 2025
Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current issues**

This does not work at the moment: groups hang when a new group joins.

**Issue 1:** ~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~ Fixed with pytorch/torchft#83.

**Issue 2:** ~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~ Fixed with pytorch/torchft#83.

**Issue 3:** ~A byproduct of issues 1 and 2: group 1 keeps printing~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command. Seems to be fixed; will need more tests.

**Issue 5:** A hang occurs when using functional collectives.
***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82.
2. Start the lighthouse.
3. Execute the following command in one terminal:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

4. Wait 10 seconds, then execute the following command in another terminal:

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 208678eaff2d5aaf87ad0c157f0e30b2afe15a9d
Pull Request resolved: #834
1 parent 1f5adfe commit 5409a57

File tree

9 files changed (+281, -42 lines changed)
‎run_llama_train.sh

+4

@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi

+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
   --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
   train.py --job.config_file ${CONFIG_FILE} $overrides
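Note that the script itself never reads `TORCHFT_LIGHTHOUSE` or `TORCHFT_MANAGER_PORT`; they are exported so torchft can resolve them from the environment (the `ft.Manager` call in `torchtitan/ft.py` below does not pass either value explicitly). A minimal sketch of that assumption, mirroring the script's defaults:

```
import os

# Assumption: torchft resolves these from the environment when they are not
# passed to Manager explicitly; the defaults here mirror run_llama_train.sh.
lighthouse = os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510")
manager_port = int(os.environ.get("TORCHFT_MANAGER_PORT", "29512"))
print(f"TorchFT lighthouse: {lighthouse}, manager port: {manager_port}")
```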

‎torchtitan/checkpoint.py

+99 -27

@@ -33,6 +33,7 @@

 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP

+from torchtitan.ft import FTManager
 from torchtitan.logging import init_logger, logger
 from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
 from torchtitan.utils import GarbageCollection
@@ -228,6 +229,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[FTManager]): The FTManager from TorchFT.
     """

     def __init__(
@@ -238,16 +240,40 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[FTManager] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        CheckpointState.MODEL,
+                        CheckpointState.OPTIMIZER,
+                        CheckpointState.LR_SCHEDULER,
+                        CheckpointState.TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.manager.set_state_dict_fns(load_state_dict, state_dict)

         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager

-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return

         self.states = states
@@ -259,6 +285,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {CheckpointState.DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -269,7 +302,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")

         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -343,35 +376,42 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """

+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return

         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()

-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info("Waiting for replica 0 to save checkpoint.")
+            time.sleep(1)

     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -388,6 +428,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """

+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False

@@ -471,10 +514,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)

+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_manager.replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")

+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.

@@ -495,6 +564,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(CheckpointState.DATALOADER)
         return states_to_load

     def _save_last_step(self, curr_step: int) -> None:
@@ -579,6 +650,7 @@ def _cpu_staging(self, checkpoint_id: Optional[str]) -> None:
     def _purge_stale_checkpoints(self):
         if (
             self.keep_latest_k > 0
+            and self.ft_manager.manager.participating_rank() == 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
         ):
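In short, when TorchFT is enabled each replica checkpoints only its dataloader state locally (under the per-replica `ft-replicat-<replica_id>` folder), while model, optimizer, LR scheduler, and train state are exchanged through the callbacks registered with `set_state_dict_fns`: a recovering replica gets them from a healthy peer instead of from disk, and only replica 0 writes the regular checkpoint. Below is a minimal sketch of that callback wiring outside of `CheckpointManager`; the `states` dict, `shared_keys`, and the usage comment are illustrative stand-ins, not part of this PR.

```
from typing import Any, Dict


def make_ft_state_dict_fns(states: Dict[str, Any], shared_keys: set):
    """Build the (load_state_dict, state_dict) pair expected by torchft's
    Manager.set_state_dict_fns, mirroring CheckpointManager.__init__ above."""

    def state_dict() -> Dict[str, Any]:
        # Share only the states that must match across replicas; per-replica
        # state such as the dataloader is deliberately left out.
        return {k: v.state_dict() for k, v in states.items() if k in shared_keys}

    def load_state_dict(sd: Dict[str, Any]) -> None:
        assert sd is not None
        for k, v in sd.items():
            states[k].load_state_dict(v)

    return load_state_dict, state_dict


# Usage, assuming `ft_manager` was created by torchtitan.ft.init_ft_manager:
# load_fn, save_fn = make_ft_state_dict_fns(
#     states, {"model", "optimizer", "lr_scheduler", "train_state"}
# )
# ft_manager.manager.set_state_dict_fns(load_fn, save_fn)
```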

‎torchtitan/config_manager.py

+24

@@ -658,6 +658,30 @@ def __init__(self):
             action="store_true",
         )

+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_id",
+            type=int,
+            default=0,
+            help="The TorchFT replica ID of this run.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_group_size",
+            type=int,
+            default=1,
+            help="""
+                The number of TorchFT replicate groups. This number will be used for
+                dataloader to split the dataset across the replicate groups and FSDP
+                dimension.
+            """,
+        )
+
     def to_dict(self):
         return self.args_dict
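The dotted flag names map onto nested attributes of `JobConfig`, which is how the rest of this PR consumes them (`torchtitan/ft.py` reads `job.experimental.enable_torchft`, `job.experimental.ft_replica_id`, and `job.experimental.ft_group_size`). A rough sketch of that round trip, assuming `JobConfig.parse_args` accepts an argument list the way torchtitan's `train.py` invokes it:

```
from torchtitan.config_manager import JobConfig

job_config = JobConfig()
job_config.parse_args(
    [
        "--experimental.enable_torchft",
        "--experimental.ft_replica_id", "1",
        "--experimental.ft_group_size", "2",
    ]
)

# Dotted CLI sections become nested attributes on the parsed config.
assert job_config.experimental.enable_torchft is True
assert job_config.experimental.ft_replica_id == 1
assert job_config.experimental.ft_group_size == 2
```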

‎torchtitan/ft.py

+57 (new file)

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+from dataclasses import dataclass
+from typing import Optional
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+@dataclass
+class FTManager:
+    manager: ft.Manager
+    replica_id: int
+    group_size: int
+
+
+def init_ft_manager(job: JobConfig) -> Optional[FTManager]:
+    """Initialize the FT manager if TorchFT is enabled.
+
+    Args:
+        job (JobConfig): The job configuration.
+
+    Returns:
+        Optional[FTManager]: The FT manager if TorchFT is enabled, otherwise None.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+    manager = ft.Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
+    )
+
+    return FTManager(
+        manager,
+        job.experimental.ft_replica_id,
+        job.experimental.ft_group_size,
+    )
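The changes to `train.py` are not shown on this page, so the following is only a hedged sketch of how `init_ft_manager` is presumably wired into the training entry point, based on the optional `ft_manager` parameters that `build_optimizers` and `CheckpointManager` gain in this PR. The function name and any `CheckpointManager` parameter names not visible in the diff above are assumptions.

```
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.ft import init_ft_manager
from torchtitan.optimizer import build_optimizers


def setup_ft(job_config: JobConfig, dataloader, model_parts, lr_schedulers, states):
    # Returns None unless --experimental.enable_torchft is set.
    ft_manager = init_ft_manager(job_config)

    # Both the optimizer container and the checkpointer accept an optional
    # ft_manager in this PR.
    optimizers = build_optimizers(model_parts, job_config, ft_manager)
    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states=states,
        job_config=job_config,
        ft_manager=ft_manager,
    )
    return ft_manager, optimizers, checkpoint
```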

‎torchtitan/optimizer.py

+51 -7

@@ -6,7 +6,7 @@

 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Iterable, List, Optional

 import torch
 import torch.nn as nn
@@ -177,8 +177,49 @@ def zero_grad(self) -> None:
         pass


+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_kwargs: Dict[str, Any],
+        name: str,
+        ft_manager: Any,
+    ) -> None:
+        import torchft as ft
+
+        super().__init__(model_parts, optimizer_kwargs, name)
+
+        # Force to initialize the optimizer state so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.optimizers = [
+            ft.Optimizer(ft_manager.manager, optim) for optim in self.optimizers
+        ]
+        self.cache_state_dict: Dict[str, Any] = {}
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional[Any] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.

@@ -213,11 +254,14 @@ def build_optimizers(
         "foreach": not fused,
     }

-    return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
-    )
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+    elif ft_manager:
+        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
+    else:
+        return OptimizersContainer(model_parts, optimizer_kwargs, name)


 class LRSchedulersContainer(Stateful):
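`FTOptimizersContainer` wraps each underlying optimizer in `ft.Optimizer`, which ties the optimizer to the TorchFT manager so that an update only commits when the replica group agrees it is safe; the cached `state_dict` lets TorchFT hand a consistent snapshot to a healing replica without triggering `optim.step()`. A hedged sketch of the intended per-step usage follows; the loop shape, `compute_loss`, and the other injected arguments are illustrative, not part of this PR, and the quorum/should_commit behavior is attributed to torchft's wrapper rather than to code shown here.

```
from typing import Any, Iterable, List

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import build_optimizers


def train_steps(
    model_parts: List[nn.Module],
    job_config: JobConfig,
    ft_manager: Any,
    dataloader: Iterable,
    compute_loss,
) -> None:
    # build_optimizers returns an FTOptimizersContainer when ft_manager is set.
    optimizers = build_optimizers(model_parts, job_config, ft_manager)

    for batch in dataloader:
        # With TorchFT enabled, the wrapped zero_grad()/step() also coordinate
        # with ft_manager.manager (quorum / should_commit) under the hood.
        optimizers.zero_grad()

        loss = compute_loss(model_parts, batch)  # illustrative forward pass
        loss.backward()

        # If the group decides the step cannot be committed (e.g. a replica is
        # still healing), the update is skipped instead of applied inconsistently.
        optimizers.step()
```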
