From ac85c3fd3370e20e2664200e7b2ac5347ade60f1 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Fri, 1 Aug 2025 12:43:55 -0700 Subject: [PATCH 1/2] remove recovery form regression test Summary: - we currently do some validation on the training in the regression test - the force recovery on first step interferes with this because it makes the test non determinstic, particularly because after the recovery, replica takes non deterministic number of steps that makes the gradients non determinstic - to fix this, perform a quorum inside fake training loop for the regression test before doing any training - we also need to increase manager step count by 2, so we do 2 should_commit, because we have 2 fragments and we're testing numerics as if we started from step 0 -- starting from step 2 gives us the same sync schedule for fragments as starting from step 0 --- ...test_diloco_mocked_failure_recovery_0.json | 192 ++++++++++++++++++ torchft/diloco_regression_test.py | 18 +- 2 files changed, 209 insertions(+), 1 deletion(-) diff --git a/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json b/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json index 1111b06..40b6732 100644 --- a/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json +++ b/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json @@ -193,6 +193,78 @@ -53.0 ] ] + }, + "16": { + "layers.0.weight": [ + [ + -61.0 + ] + ], + "layers.1.weight": [ + [ + -55.0 + ] + ] + }, + "17": { + "layers.0.weight": [ + [ + -63.0 + ] + ], + "layers.1.weight": [ + [ + -57.0 + ] + ] + }, + "18": { + "layers.0.weight": [ + [ + -65.0 + ] + ], + "layers.1.weight": [ + [ + -71.0 + ] + ] + }, + "19": { + "layers.0.weight": [ + [ + -67.0 + ] + ], + "layers.1.weight": [ + [ + -73.0 + ] + ] + }, + "20": { + "layers.0.weight": [ + [ + -69.0 + ] + ], + "layers.1.weight": [ + [ + -75.0 + ] + ] + }, + "21": { + "layers.0.weight": [ + [ + -83.0 + ] + ], + "layers.1.weight": [ + [ + -77.0 + ] + ] } }, "global_parameter_history": { @@ -255,6 +327,30 @@ -47.0 ] ] + }, + "15": { + "layers.0.weight": [ + [ + -59.0 + ] + ], + "layers.1.weight": [ + [ + -47.0 + ] + ] + }, + "18": { + "layers.0.weight": [ + [ + -59.0 + ] + ], + "layers.1.weight": [ + [ + -71.0 + ] + ] } } } @@ -381,6 +477,78 @@ -53.0 ] ] + }, + "10": { + "layers.0.weight": [ + [ + -61.0 + ] + ], + "layers.1.weight": [ + [ + -55.0 + ] + ] + }, + "11": { + "layers.0.weight": [ + [ + -63.0 + ] + ], + "layers.1.weight": [ + [ + -57.0 + ] + ] + }, + "12": { + "layers.0.weight": [ + [ + -65.0 + ] + ], + "layers.1.weight": [ + [ + -71.0 + ] + ] + }, + "13": { + "layers.0.weight": [ + [ + -67.0 + ] + ], + "layers.1.weight": [ + [ + -73.0 + ] + ] + }, + "14": { + "layers.0.weight": [ + [ + -69.0 + ] + ], + "layers.1.weight": [ + [ + -75.0 + ] + ] + }, + "15": { + "layers.0.weight": [ + [ + -83.0 + ] + ], + "layers.1.weight": [ + [ + -77.0 + ] + ] } }, "global_parameter_history": { @@ -419,6 +587,30 @@ -47.0 ] ] + }, + "9": { + "layers.0.weight": [ + [ + -59.0 + ] + ], + "layers.1.weight": [ + [ + -47.0 + ] + ] + }, + "12": { + "layers.0.weight": [ + [ + -59.0 + ] + ], + "layers.1.weight": [ + [ + -71.0 + ] + ] } } } diff --git a/torchft/diloco_regression_test.py b/torchft/diloco_regression_test.py index 99ee931..a1cb7c9 100644 --- a/torchft/diloco_regression_test.py +++ b/torchft/diloco_regression_test.py @@ -3,6 +3,7 @@ import json import logging import os +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import ExitStack from datetime import timedelta @@ -141,6 +142,7 @@ def __init__( diloco_args: dict[str, Any], inner_lr: float = 1, outer_lr: float = 2, + quorum_barrier: Optional[threading.Barrier] = None, ) -> None: self.inner_lr = inner_lr self.outer_lr = outer_lr @@ -150,6 +152,8 @@ def __init__( rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args ) + self.quorum_barrier = quorum_barrier + def setup_model(self) -> MockModel: """Set up the mock model and move it to the device.""" model = MockModel(in_dim=1, out_dim=1, n_layers=self.n_fragments) @@ -186,6 +190,14 @@ def train_loop(self) -> Dict[str, Any]: backup_device=self.device, **self.diloco_args, ) as self.diloco: + if self.quorum_barrier is not None: + self.manager.start_quorum() + self.manager.wait_quorum() + assert self.quorum_barrier is not None + self.quorum_barrier.wait() + assert self.manager.should_commit() + assert self.manager.should_commit() + local_step = 0 manager_steps = set() while True: @@ -197,7 +209,7 @@ def train_loop(self) -> Dict[str, Any]: manager_curr_step = self.manager.current_step() - if manager_curr_step == 5: + if manager_curr_step == 7: break if manager_curr_step not in manager_steps: @@ -248,6 +260,7 @@ def mock_diloco_train_loop( model_state_dict = train_loop_args.get("model_state_dict", {}) n_fragments = train_loop_args.get("n_fragments", 1) diloco_args = train_loop_args.get("diloco_args", {}) + quorum_barrier = train_loop_args.get("quorum_barrier", None) with ExitStack() as stack: trainer = MockDiLoCoTrainer( @@ -258,6 +271,7 @@ def mock_diloco_train_loop( model_state_dict, n_fragments, diloco_args, + quorum_barrier=quorum_barrier, ) stack.callback(trainer.manager.shutdown) return trainer.train_loop() @@ -304,6 +318,7 @@ def test_diloco_mocked_updates( # Create a proper state_dict for the model to avoid load_state_dict errors temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments) model_state_dict = temp_model.state_dict() + quorum_barrier = threading.Barrier(num_replicas) with ThreadPoolExecutor(max_workers=num_replicas) as executor: for replica_id in range(num_replicas): @@ -316,6 +331,7 @@ def test_diloco_mocked_updates( train_loop=mock_diloco_train_loop, use_cuda=use_cuda, train_loop_args={ + "quorum_barrier": quorum_barrier, "n_fragments": n_fragments, "model_state_dict": model_state_dict, "diloco_args": { From 0fbae9ef68c4d1a0cfa06d4da6a3182eca18e2d4 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Fri, 1 Aug 2025 12:55:40 -0700 Subject: [PATCH 2/2] fix compute/communication overlap for gloo Summary: - we current wait for pg work's future when preparing for a fragment - if we use gloo, this blocks the cpu - move the wait call to when we perform the actual sync of the fragment --- torchft/local_sgd.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index 957680e..69f7130 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -400,13 +400,6 @@ def prepare_sync(self) -> None: ): self._average_grads() - for work in self._allreduce_work: - work.wait() - - if self._stream is not None: - self._stop_event = torch.cuda.Event() - self._stop_event.record() - @torch.profiler.record_function("torchft::local_sgd::perform_sync") def perform_sync(self) -> bool: """ @@ -416,6 +409,18 @@ def perform_sync(self) -> bool: # Waiting for an allreduce before it has been sent is currently not supported. assert len(self._allreduce_work) > 0 + with ( + torch.cuda.stream(self._stream) + if self._stream is not None + else nullcontext() + ): + for work in self._allreduce_work: + work.wait() + + if self._stream is not None: + self._stop_event = torch.cuda.Event() + self._stop_event.record() + self.wait() # save the parameters so they can be used for merging