diff --git a/README.md b/README.md
index aa67f59..61e37d7 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ Easy Per Step Fault Tolerance for PyTorch
   | <a href="https://pytorch.org/torchft/"><b>Documentation</b></a>
   | <a href="https://github.com/pytorch-labs/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a>
   | <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a>
-  | 
+  |
 </p>
 <p align="center">
   <a href="https://pypi.org/project/torchft-nightly/"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/torchft-nightly"></a>
@@ -98,7 +98,7 @@ when using synchronous training.
 You can start a lighthouse server by running:
 
 ```sh
-$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000
+$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
 ```
 
 ### Example Training Loop (DDP)
@@ -108,7 +108,7 @@ See [train_ddp.py](./train_ddp.py) for the full example.
 Invoke with:
 
 ```sh
-$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
+$ TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
 ```
 
 train.py:
diff --git a/proto/torchft.proto b/proto/torchft.proto
index 67a42c0..7e248e5 100644
--- a/proto/torchft.proto
+++ b/proto/torchft.proto
@@ -72,30 +72,32 @@ service LighthouseService {
 message ManagerQuorumRequest {
     int64 rank = 1;
     int64 step = 2;
-    string checkpoint_server_addr = 3;
+    string checkpoint_metadata = 3;
     bool shrink_only = 4;
 }
 
 message ManagerQuorumResponse {
     int64 quorum_id = 1;
-    string address = 2;
-    string store_address = 3;
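+    // Recovery assignment fields (see compute_quorum_results in src/manager.rs):
+    // if recover_src_rank is set, this replica should heal by fetching a checkpoint
+    // from that rank via recover_src_manager_address, while recover_dst_ranks lists
+    // the ranks expected to recover from this replica.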
+    string recover_src_manager_address = 2;
+    optional int64 recover_src_rank = 3;
+    repeated int64 recover_dst_ranks = 4;
+    string store_address = 5;
     // These are information for the replicas which are at the max step.
-    int64 max_step = 4;
-    optional int64 max_rank = 5;
-    int64 max_world_size = 6;
+    int64 max_step = 6;
+    optional int64 max_rank = 7;
+    int64 max_world_size = 8;
     // These are information for all replicas including behind replicas.
-    int64 replica_rank = 7;
-    int64 replica_world_size = 8;
-    bool heal = 9;
+    int64 replica_rank = 9;
+    int64 replica_world_size = 10;
+    bool heal = 11;
 }
 
-message CheckpointAddressRequest {
+message CheckpointMetadataRequest {
     int64 rank = 1;
 }
 
-message CheckpointAddressResponse {
-    string checkpoint_server_address = 1;
+message CheckpointMetadataResponse {
+    string checkpoint_metadata = 1;
 }
 
 message ShouldCommitRequest {
@@ -114,7 +116,7 @@ message KillResponse {}
 
 service ManagerService {
     rpc Quorum (ManagerQuorumRequest) returns (ManagerQuorumResponse);
-    rpc CheckpointAddress(CheckpointAddressRequest) returns (CheckpointAddressResponse);
+    rpc CheckpointMetadata(CheckpointMetadataRequest) returns (CheckpointMetadataResponse);
     rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse);
     rpc Kill(KillRequest) returns (KillResponse);
 }
diff --git a/src/lib.rs b/src/lib.rs
index d9a124b..529532d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -27,7 +27,7 @@ pub mod torchftpb {
 }
 
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
-use crate::torchftpb::{CheckpointAddressRequest, ManagerQuorumRequest, ShouldCommitRequest};
+use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
 use pyo3::prelude::*;
 
 #[pyclass]
@@ -113,15 +113,15 @@ impl ManagerClient {
         py: Python<'_>,
         rank: i64,
         step: i64,
-        checkpoint_server_addr: String,
+        checkpoint_metadata: String,
         shrink_only: bool,
         timeout: Duration,
-    ) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
+    ) -> Result<QuorumResult, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ManagerQuorumRequest {
                 rank: rank,
                 step: step,
-                checkpoint_server_addr: checkpoint_server_addr,
+                checkpoint_metadata: checkpoint_metadata,
                 shrink_only: shrink_only,
             });
 
@@ -131,28 +131,30 @@ impl ManagerClient {
 
             let response = self.runtime.block_on(self.client.clone().quorum(request))?;
             let resp = response.into_inner();
-            Ok((
-                resp.quorum_id,
-                resp.replica_rank,
-                resp.replica_world_size,
-                resp.address,
-                resp.store_address,
-                resp.max_step,
-                resp.max_rank,
-                resp.max_world_size,
-                resp.heal,
-            ))
+            Ok(QuorumResult {
+                quorum_id: resp.quorum_id,
+                replica_rank: resp.replica_rank,
+                replica_world_size: resp.replica_world_size,
+                recover_src_manager_address: resp.recover_src_manager_address,
+                recover_src_rank: resp.recover_src_rank,
+                recover_dst_ranks: resp.recover_dst_ranks,
+                store_address: resp.store_address,
+                max_step: resp.max_step,
+                max_rank: resp.max_rank,
+                max_world_size: resp.max_world_size,
+                heal: resp.heal,
+            })
         })
     }
 
-    fn checkpoint_address(
+    fn checkpoint_metadata(
         &self,
         py: Python<'_>,
         rank: i64,
         timeout: Duration,
     ) -> Result<String, StatusError> {
         py.allow_threads(move || {
-            let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
+            let mut request = tonic::Request::new(CheckpointMetadataRequest { rank: rank });
 
             // This timeout is processed on the server side so we also enable
             // keep alives to detect server health.
@@ -160,9 +162,9 @@ impl ManagerClient {
 
             let response = self
                 .runtime
-                .block_on(self.client.clone().checkpoint_address(request))?;
+                .block_on(self.client.clone().checkpoint_metadata(request))?;
             let resp = response.into_inner();
-            Ok(resp.checkpoint_server_address)
+            Ok(resp.checkpoint_metadata)
         })
     }
 
@@ -194,6 +196,41 @@ impl ManagerClient {
     }
 }
 
+#[pyclass(get_all, set_all)]
+struct QuorumResult {
+    quorum_id: i64,
+    replica_rank: i64,
+    replica_world_size: i64,
+    recover_src_manager_address: String,
+    recover_src_rank: Option<i64>,
+    recover_dst_ranks: Vec<i64>,
+    store_address: String,
+    max_step: i64,
+    max_rank: Option<i64>,
+    max_world_size: i64,
+    heal: bool,
+}
+
+#[pymethods]
+impl QuorumResult {
+    #[new]
+    fn new() -> Self {
+        Self {
+            quorum_id: 0,
+            replica_rank: 0,
+            replica_world_size: 1,
+            recover_src_manager_address: "".to_string(),
+            recover_src_rank: None,
+            recover_dst_ranks: Vec::new(),
+            store_address: "".to_string(),
+            max_step: 0,
+            max_rank: None,
+            max_world_size: 1,
+            heal: false,
+        }
+    }
+}
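+
+// From Python, `ManagerClient.quorum(...)` returns this `QuorumResult` as an
+// object with named attributes (e.g. `result.recover_dst_ranks`) rather than
+// the positional tuple it returned previously.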
+
 fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
     // clear python signal handlers
     // signal.signal(signal.SIGINT, signal.SIG_DFL)
@@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<Manager>()?;
     m.add_class::<ManagerClient>()?;
     m.add_class::<Lighthouse>()?;
+    m.add_class::<QuorumResult>()?;
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
 
     Ok(())
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index e6be595..643aef1 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -77,7 +77,7 @@ pub struct LighthouseOpt {
     #[structopt(
         long = "join_timeout_ms",
         default_value = "60000",
-        help = "How long to wait for new replicas to join before considering a quorum"
+        help = "How long to wait for heartbeating stragglers to join before issuing quorum"
     )]
     pub join_timeout_ms: u64,
 
@@ -90,14 +90,14 @@ pub struct LighthouseOpt {
     #[structopt(
         long = "quorum_tick_ms",
         default_value = "100",
-        help = "How frequently to check for quorum when waiting for workers."
+        help = "How frequently to check for quorum when waiting for stragglers."
     )]
     pub quorum_tick_ms: u64,
 
     #[structopt(
         long = "heartbeat_timeout_ms",
         default_value = "5000",
-        help = "how long to wait for a heartbeat before considering a replica dead."
+        help = "How long to wait for a heartbeat before considering a replica dead."
     )]
     pub heartbeat_timeout_ms: u64,
 }
@@ -146,9 +146,10 @@ fn quorum_compute(
         .any(|(_, details)| details.member.shrink_only);
 
     let metadata = format!(
-        "[{}/{} participants healthy][shrink_only={}]",
+        "[{}/{} participants healthy][{} heartbeating][shrink_only={}]",
         healthy_participants.len(),
         state.participants.len(),
+        healthy_replicas.len(),
         shrink_only,
     );
 
@@ -190,7 +191,7 @@ fn quorum_compute(
         return (
             None,
             format!(
-                "No quorum, only have {} participants, need min_replicas {} {}",
+                "New quorum not ready, only have {} participants, need min_replicas {} {}",
                 healthy_participants.len(),
                 opt.min_replicas,
                 metadata
@@ -203,7 +204,7 @@ fn quorum_compute(
         return (
             None,
             format!(
-                "No quorum, only have {} participants, need at least half of {} healthy workers {}",
+                "New quorum not ready, only have {} participants, need at least half of {} healthy workers {}",
                 healthy_participants.len(),
                 healthy_replicas.len(),
                 metadata
@@ -261,7 +262,7 @@ impl Lighthouse {
 
     fn _quorum_tick(self: Arc<Self>, state: &mut State) -> Result<()> {
         let (quorum_met, reason) = quorum_compute(Instant::now(), state, &self.opt);
-        info!("{}", reason);
+        info!("Next quorum status: {}", reason);
 
         if quorum_met.is_some() {
             let participants = quorum_met.unwrap();
@@ -600,7 +601,9 @@ mod tests {
 
         let now = Instant::now();
 
-        assert!(!quorum_compute(now, &state, &opt).0.is_some());
+        let (quorum_met, reason) = quorum_compute(now, &state, &opt);
+        assert!(quorum_met.is_none(), "{}", reason);
+        assert!(reason.contains("New quorum not ready, only have 0 participants, need min_replicas 1 [0/0 participants healthy]"), "{}", reason);
 
         state.participants.insert(
             "a".to_string(),
@@ -689,7 +692,13 @@ mod tests {
         );
         state.heartbeats.insert("a".to_string(), now);
 
-        assert!(quorum_compute(now, &state, &opt).0.is_some());
+        let (quorum_met, reason) = quorum_compute(now, &state, &opt);
+        assert!(quorum_met.is_some(), "{}", reason);
+        assert!(
+            reason.contains("[1/1 participants healthy][1 heartbeating]"),
+            "{}",
+            reason
+        );
 
         // expired heartbeat
         state
@@ -698,6 +707,11 @@ mod tests {
 
         let (quorum_met, reason) = quorum_compute(now, &state, &opt);
         assert!(quorum_met.is_none(), "{}", reason);
+        assert!(
+            reason.contains("[0/1 participants healthy][0 heartbeating]"),
+            "{}",
+            reason
+        );
 
         // 1 healthy, 1 expired
         state.participants.insert(
@@ -886,6 +900,7 @@ mod tests {
 
         let (quorum_met, reason) = quorum_compute(now, &state, &opt);
         assert!(quorum_met.is_some(), "{}", reason);
+        assert!(reason.contains("[shrink_only=true]",), "{}", reason);
 
         let quorum = quorum_met.unwrap();
         assert!(quorum.len() == 1);
@@ -982,7 +997,7 @@ mod tests {
         state.heartbeats.insert("b".to_string(), now);
         let (quorum_met, reason) = quorum_compute(now, &state, &opt);
         assert!(quorum_met.is_none(), "{}", reason);
-        assert!(reason.contains("at least half"), "{}", reason);
+        assert!(reason.contains("New quorum not ready, only have 1 participants, need at least half of 2 healthy workers [1/1 participants healthy][2 heartbeating]"), "{}", reason);
 
         Ok(())
     }
diff --git a/src/manager.rs b/src/manager.rs
index 982500a..931e995 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -26,7 +26,7 @@ use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{
     manager_service_server::{ManagerService, ManagerServiceServer},
-    CheckpointAddressRequest, CheckpointAddressResponse, KillRequest, KillResponse,
+    CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse,
     LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest,
     ManagerQuorumResponse, Quorum, QuorumMember, ShouldCommitRequest, ShouldCommitResponse,
 };
@@ -38,7 +38,7 @@ use log::{info, warn};
 use std::{println as info, println as warn};
 
 struct ManagerState {
-    checkpoint_servers: HashMap<i64, String>,
+    checkpoint_metadata: HashMap<i64, String>,
     channel: broadcast::Sender<Quorum>,
     participants: HashSet<i64>,
 
@@ -104,7 +104,7 @@ impl Manager {
             world_size: world_size,
             heartbeat_interval: heartbeat_interval,
             state: Mutex::new(ManagerState {
-                checkpoint_servers: HashMap::new(),
+                checkpoint_metadata: HashMap::new(),
                 channel: tx,
                 participants: HashSet::new(),
 
@@ -237,8 +237,8 @@ impl ManagerService for Arc<Manager> {
             // save checkpoint server info for healing process
             // TODO: make separate call to set?
             state
-                .checkpoint_servers
-                .insert(req.rank, req.checkpoint_server_addr.clone());
+                .checkpoint_metadata
+                .insert(req.rank, req.checkpoint_metadata.clone());
 
             // TODO check step
             state.participants.insert(rank);
@@ -266,81 +266,28 @@ impl ManagerService for Arc<Manager> {
             .await
             .map_err(|e| Status::internal(e.to_string()))?;
 
-        let mut participants = quorum.participants.clone();
-        participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
-
-        let replica_rank = participants.iter().enumerate().find_map(|(i, p)| {
-            if p.replica_id == self.replica_id {
-                Some(i)
-            } else {
-                None
-            }
-        });
-        if replica_rank.is_none() {
-            return Err(Status::not_found(format!(
-                "replica {} not participating in returned quorum",
-                self.replica_id
-            )));
-        }
-
-        let max_step = participants.iter().map(|p| p.step).max().unwrap();
-        let max_participants: Vec<&QuorumMember> =
-            participants.iter().filter(|p| p.step == max_step).collect();
-
-        let primary = max_participants[rank as usize % max_participants.len()];
-
-        let mut max_rank = None;
-        for (i, p) in max_participants.iter().enumerate() {
-            if p.replica_id == self.replica_id {
-                max_rank = Some(i as i64);
-                break;
-            }
-        }
-
-        // Decide whether we should be healing:
-        // 1. if we're not at the max step
-        // 2. if everyone is at the first step and we're not the primary
-        let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id;
-        if heal {
-            info!(
-                "healing is required step={}, max_step={}",
-                req.step, max_step
-            );
-        }
-
-        let reply = ManagerQuorumResponse {
-            quorum_id: quorum.quorum_id,
-            // address is used for looking up the checkpoint server address.
-            address: primary.address.clone(),
-            store_address: primary.store_address.clone(),
-            max_step: max_step,
-            max_rank: max_rank,
-            max_world_size: max_participants.len() as i64,
-            replica_rank: replica_rank.unwrap() as i64,
-            replica_world_size: participants.len() as i64,
-            heal: heal,
-        };
-
         info!("returning quorum for rank {}", rank);
 
+        let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
+
         Ok(Response::new(reply))
     }
 
-    async fn checkpoint_address(
+    async fn checkpoint_metadata(
         &self,
-        request: Request<CheckpointAddressRequest>,
-    ) -> Result<Response<CheckpointAddressResponse>, Status> {
+        request: Request<CheckpointMetadataRequest>,
+    ) -> Result<Response<CheckpointMetadataResponse>, Status> {
         let state = self.state.lock().await;
 
         let req = request.into_inner();
 
-        let address = state
-            .checkpoint_servers
+        let metadata = state
+            .checkpoint_metadata
             .get(&req.rank)
             .ok_or_else(|| Status::invalid_argument("rank not found"))?;
 
-        let reply = CheckpointAddressResponse {
-            checkpoint_server_address: address.clone(),
+        let reply = CheckpointMetadataResponse {
+            checkpoint_metadata: metadata.clone(),
         };
         Ok(Response::new(reply))
     }
@@ -407,6 +354,131 @@ impl ManagerService for Arc<Manager> {
     }
 }
 
+fn compute_quorum_results(
+    replica_id: &str,
+    rank: i64,
+    quorum: &Quorum,
+) -> Result<ManagerQuorumResponse, Status> {
+    let mut participants = quorum.participants.clone();
+    participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
+
+    // Compute the rank of the replica in the returned quorum.
+    let replica_rank = participants
+        .iter()
+        .enumerate()
+        .find_map(|(i, p)| {
+            if p.replica_id == replica_id {
+                Some(i)
+            } else {
+                None
+            }
+        })
+        .ok_or_else(|| {
+            Status::not_found(format!(
+                "replica {} not participating in returned quorum",
+                replica_id
+            ))
+        })?;
+
+    let step = participants[replica_rank].step;
+
+    // Compute the details for workers at max step.
+    let max_step = participants.iter().map(|p| p.step).max().unwrap();
+    let max_participants: Vec<&QuorumMember> =
+        participants.iter().filter(|p| p.step == max_step).collect();
+    let max_rank = max_participants.iter().enumerate().find_map(|(i, p)| {
+        if p.replica_id == replica_id {
+            Some(i as i64)
+        } else {
+            None
+        }
+    });
+
+    // The primary TCPStore to use for this rank.
+    let primary_rank = rank as usize % max_participants.len();
+    let primary = max_participants[primary_rank];
+
+    // Compute recovery assignments
+
+    // Nodes are recovering if:
+    // 1. not at the max step
+    // 2. max_step == 0 and not the primary replica
+    let all_recover_dst_ranks: Vec<usize> = participants
+        .iter()
+        .enumerate()
+        .filter_map(|(i, p)| {
+            if p.step != max_step || (max_step == 0 && primary.replica_id != p.replica_id) {
+                Some(i)
+            } else {
+                None
+            }
+        })
+        .collect();
+    let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
+    let up_to_date_ranks: Vec<usize> = participants
+        .iter()
+        .enumerate()
+        .filter_map(|(i, _p)| {
+            if !all_recover_dst_ranks_set.contains(&i) {
+                Some(i)
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    // This is a map of rank to the ranks that are recovering from that node.
+    let mut recovery_assignments: HashMap<usize, Vec<i64>> = HashMap::new();
+    // The rank of the node that this rank is recovering from.
+    let mut recover_src_rank: Option<i64> = None;
+    for (i, recovering_rank) in all_recover_dst_ranks.iter().enumerate() {
+        let up_to_date_idx = (i + rank as usize) % up_to_date_ranks.len();
+        let recovering_recover_src_rank = up_to_date_ranks[up_to_date_idx];
+        recovery_assignments
+            .entry(recovering_recover_src_rank)
+            .or_default()
+            .push(*recovering_rank as i64);
+        if *recovering_rank == replica_rank {
+            recover_src_rank = Some(recovering_recover_src_rank as i64);
+        }
+    }
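+    // Worked example (matches test_compute_quorum_results_recovery below): with
+    // participant steps [0, 1, 0, 1, 0] and max_step = 1, the recovering ranks are
+    // [0, 2, 4] and the up-to-date ranks are [1, 3]. For group rank 0 the
+    // round-robin assignment is 0 -> 1, 2 -> 3, 4 -> 1, so rank 1 serves [0, 4]
+    // and rank 3 serves [2].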
+
+    let heal = recover_src_rank.is_some();
+    if heal {
+        info!(
+            "healing is required step={}, max_step={}, recover_src_rank={}",
+            step,
+            max_step,
+            recover_src_rank.unwrap()
+        );
+    }
+
+    let recover_src_manager_address = match recover_src_rank {
+        Some(r) => participants[r as usize].address.clone(),
+        None => "".to_string(),
+    };
+
+    Ok(ManagerQuorumResponse {
+        quorum_id: quorum.quorum_id,
+        // recover_src_manager_address is used by the healing replica to look up
+        // the checkpoint metadata of its recovery source.
+        recover_src_manager_address: recover_src_manager_address,
+        recover_src_rank: recover_src_rank,
+        recover_dst_ranks: recovery_assignments
+            .get(&replica_rank)
+            .map_or_else(Vec::new, |v| v.clone()),
+        store_address: primary.store_address.clone(),
+        max_step: max_step,
+        max_rank: max_rank,
+        max_world_size: max_participants.len() as i64,
+        replica_rank: replica_rank as i64,
+        replica_world_size: participants.len() as i64,
+        heal: heal,
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -506,7 +578,7 @@ mod tests {
         let mut request = tonic::Request::new(ManagerQuorumRequest {
             rank: 0,
             step: 123,
-            checkpoint_server_addr: "addr".to_string(),
+            checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
         });
         request.set_timeout(Duration::from_secs(10));
@@ -516,7 +588,7 @@ mod tests {
         lighthouse_fut.abort();
 
         assert_eq!(resp.quorum_id, 1);
-        assert_eq!(resp.address, manager.address());
+        assert_eq!(resp.recover_src_manager_address, "".to_string());
         assert_eq!(resp.store_address, "store_addr".to_string());
         assert_eq!(resp.max_step, 123);
         assert_eq!(resp.max_rank, Some(0));
@@ -565,7 +637,7 @@ mod tests {
                 let mut request = tonic::Request::new(ManagerQuorumRequest {
                     rank: 0,
                     step: 0,
-                    checkpoint_server_addr: "addr".to_string(),
+                    checkpoint_metadata: "addr".to_string(),
                     shrink_only: false,
                 });
                 request.set_timeout(Duration::from_secs(10));
@@ -597,4 +669,183 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_checkpoint_metadata() -> Result<()> {
+        let lighthouse = Lighthouse::new(LighthouseOpt {
+            bind: "[::]:0".to_string(),
+            join_timeout_ms: 100,
+            min_replicas: 1,
+            quorum_tick_ms: 100,
+            heartbeat_timeout_ms: 5000,
+        })
+        .await?;
+        let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
+
+        let manager = Manager::new(
+            "rep_id".to_string(),
+            lighthouse.address(),
+            "localhost".to_string(),
+            "[::]:0".to_string(),
+            "store_addr".to_string(),
+            1,                          // world size
+            Duration::from_millis(100), // heartbeat interval
+            Duration::from_secs(10),    // connect timeout
+        )
+        .await?;
+        let manager_fut = tokio::spawn(manager.clone().run());
+
+        let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;
+
+        let request = tonic::Request::new(CheckpointMetadataRequest { rank: 0 });
+        let resp = client.checkpoint_metadata(request).await;
+        assert!(resp.err().unwrap().to_string().contains("rank not found"));
+
+        {
+            let mut state = manager.state.lock().await;
+
+            state.checkpoint_metadata.insert(0, "addr".to_string());
+        }
+
+        let request = tonic::Request::new(CheckpointMetadataRequest { rank: 0 });
+        let resp = client.checkpoint_metadata(request).await?.into_inner();
+        assert_eq!(resp.checkpoint_metadata, "addr".to_string());
+
+        manager_fut.abort();
+        lighthouse_fut.abort();
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_compute_quorum_results_first_step() -> Result<()> {
+        let quorum = Quorum {
+            quorum_id: 1,
+            participants: vec![
+                QuorumMember {
+                    replica_id: "replica_0".to_string(),
+                    address: "addr_0".to_string(),
+                    store_address: "store_addr_0".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_1".to_string(),
+                    address: "addr_1".to_string(),
+                    store_address: "store_addr_1".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+            ],
+            created: None,
+        };
+
+        // rank 0
+
+        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        assert!(!results.heal);
+        assert_eq!(results.replica_rank, 0);
+        assert_eq!(results.recover_src_rank, None);
+        assert_eq!(results.recover_dst_ranks, vec![1]);
+
+        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        assert!(results.heal);
+        assert_eq!(results.replica_rank, 1);
+        assert_eq!(results.recover_src_rank, Some(0));
+        assert_eq!(results.recover_dst_ranks, Vec::<i64>::new());
+
+        // rank 1 assignments should be offset from the rank 0 case above, and the primary differs as well
+
+        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        assert!(!results.heal);
+        assert_eq!(results.replica_rank, 1);
+        assert_eq!(results.recover_src_rank, None);
+        assert_eq!(results.recover_dst_ranks, vec![0]);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_compute_quorum_results_recovery() -> Result<()> {
+        let quorum = Quorum {
+            quorum_id: 1,
+            participants: vec![
+                QuorumMember {
+                    replica_id: "replica_0".to_string(),
+                    address: "addr_0".to_string(),
+                    store_address: "store_addr_0".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_1".to_string(),
+                    address: "addr_1".to_string(),
+                    store_address: "store_addr_1".to_string(),
+                    step: 1,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_2".to_string(),
+                    address: "addr_2".to_string(),
+                    store_address: "store_addr_2".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_3".to_string(),
+                    address: "addr_3".to_string(),
+                    store_address: "store_addr_3".to_string(),
+                    step: 1,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_4".to_string(),
+                    address: "addr_4".to_string(),
+                    store_address: "store_addr_4".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+            ],
+            created: None,
+        };
+
+        // rank 0
+
+        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        assert!(results.heal);
+        assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
+        assert_eq!(results.replica_rank, 0);
+        assert_eq!(results.recover_src_rank, Some(1));
+        assert!(results.recover_dst_ranks.is_empty());
+
+        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        assert!(!results.heal);
+        assert_eq!(results.recover_src_manager_address, "".to_string());
+        assert_eq!(results.replica_rank, 1);
+        assert_eq!(results.recover_src_rank, None);
+        assert_eq!(results.recover_dst_ranks, vec![0, 4]);
+
+        let results = compute_quorum_results("replica_3", 0, &quorum)?;
+        assert!(!results.heal);
+        assert_eq!(results.replica_rank, 3);
+        assert_eq!(results.recover_src_rank, None);
+        assert_eq!(results.recover_dst_ranks, vec![2]);
+
+        // rank 1 assignments should be offset from rank 0 above
+
+        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        assert!(!results.heal);
+        assert_eq!(results.replica_rank, 1);
+        assert_eq!(results.recover_src_rank, None);
+        assert_eq!(results.recover_dst_ranks, vec![2]);
+
+        Ok(())
+    }
 }
diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py
index aaad843..c3168b2 100644
--- a/torchft/checkpointing.py
+++ b/torchft/checkpointing.py
@@ -16,9 +16,11 @@
 import socket
 import threading
 import urllib.request
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from datetime import timedelta
 from http.server import BaseHTTPRequestHandler
-from typing import Callable, Generic, TypeVar
+from typing import Generator, Generic, List, Optional, TypeVar
 
 import torch
 
@@ -29,7 +31,83 @@
 T = TypeVar("T")
 
 
-class CheckpointServer(Generic[T]):
+class CheckpointTransport(Generic[T], ABC):
+    @abstractmethod
+    def metadata(self) -> str:
+        """
+        Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
+        """
+        ...
+
+    @abstractmethod
+    def send_checkpoint(
+        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+    ) -> None:
+        """
+        Sends the checkpoint; only called when there is a rank that is behind.
+
+        This may be async.
+
+        Args:
+            dst_ranks: the ranks to send to
+            step: the step number to send
+            state_dict: the state dict to send
+            timeout: the timeout to wait for the checkpoint to be sent
+        """
+        ...
+
+    def disallow_checkpoint(self) -> None:
+        """
+        Called after send_checkpoint to wait for the checkpoint to be sent.
+
+        Once this returns, the state_dict may be mutated, so no further data should be sent.
+        """
+        ...
+
+    @abstractmethod
+    def recv_checkpoint(
+        self, src_rank: int, metadata: str, step: int, timeout: timedelta
+    ) -> T:
+        """
+        Receives the checkpoint from the given rank.
+
+        Args:
+            src_rank: the rank to receive the checkpoint from
+            metadata: the metadata returned by the remote CheckpointTransport
+            step: the step number to receive
+            timeout: the timeout to wait for the checkpoint
+        """
+        ...
+
+    def shutdown(self, wait: bool = True) -> None:
+        """
+        Called to shutdown the checkpoint transport.
+
+        Args:
+            wait: whether to wait for the transport to shutdown
+        """
+
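+# A minimal usage sketch of the transport API (illustrative; mirrors
+# checkpointing_test.py and assumes the HTTP-based CheckpointServer below):
+#
+#   transport = CheckpointServer(timeout=timedelta(seconds=10))
+#   transport.send_checkpoint(dst_ranks=[1], step=10, state_dict=sd, timeout=timeout)
+#   metadata = transport.metadata()
+#   # metadata is shared with the recovering rank via the manager
+#   sd = transport.recv_checkpoint(src_rank=0, metadata=metadata, step=10, timeout=timeout)
+#   transport.disallow_checkpoint()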
+
+@contextmanager
+def _timed_acquire(
+    lock: threading.Lock, timeout: timedelta
+) -> Generator[None, None, None]:
+    """
+    Acquire a lock with a timeout.
+
+    Args:
+        lock: the lock to acquire
+        timeout: the timeout to acquire the lock
+    """
+    if not lock.acquire(timeout=timeout.total_seconds()):
+        raise TimeoutError(f"timed out acquiring lock after {timeout}")
+    try:
+        yield
+    finally:
+        lock.release()
+
+
+class CheckpointServer(CheckpointTransport[T]):
     """
     This is an HTTP server that can be used to transfer checkpoints
     between workers.
@@ -41,11 +119,16 @@ class CheckpointServer(Generic[T]):
         state_dict: a callable that returns the state dict to be transferred
     """
 
-    def __init__(self, state_dict: Callable[[], T], timeout: timedelta) -> None:
+    def __init__(self, timeout: timedelta) -> None:
         self._checkpoint_lock = threading.Lock()
         self._disallowed = False
         self._step = -1
         self._timeout = timeout
+        self._state_dict: Optional[T] = None
+
+        # We don't allow checkpoints until the first send_checkpoint to avoid
+        # serving the default step=-1 invalid checkpoint.
+        self.disallow_checkpoint()
 
         ckpt_server = self
 
@@ -58,7 +141,9 @@ def do_GET(self):
                     # validate socket timeout is actually set
                     assert self.connection.gettimeout() == self.timeout
 
-                    with ckpt_server._checkpoint_lock:
+                    with _timed_acquire(
+                        ckpt_server._checkpoint_lock, ckpt_server._timeout
+                    ):
                         step = ckpt_server._step
 
                         if self.path != f"/checkpoint/{step}":
@@ -74,9 +159,9 @@ def do_GET(self):
                         self.send_header("Content-type", "application/octet-stream")
                         self.end_headers()
 
-                        sd = state_dict()
+                        state_dict = ckpt_server._state_dict
 
-                        torch.save(sd, self.wfile)
+                        torch.save(state_dict, self.wfile)
                 except Exception as e:
                     logger.exception(
                         f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -113,11 +198,13 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
             data = f.read()
 
         reader = io.BytesIO(data)
-        return torch.load(reader, weights_only=True)
+        # We have to set weights_only to False as there are some non-tensor
+        # states like lr_scheduler.
+        return torch.load(reader, weights_only=False)
 
     def address(self) -> str:
         """
-        Returns the HTTP address to fetch a checkpoint from this server at the current step.
+        Returns the HTTP address to fetch a checkpoint from this server. The step number must be appended to the end of the address.
 
         Format: http://host:port/checkpoint/1234
 
@@ -125,7 +212,7 @@ def address(self) -> str:
             an HTTP address
         """
         port = self._server.socket.getsockname()[1]
-        return f"http://{socket.gethostname()}:{port}/checkpoint/{self._step}"
+        return f"http://{socket.gethostname()}:{port}/checkpoint/"
 
     def _serve(self) -> None:
         try:
@@ -156,8 +243,28 @@ def allow_checkpoint(self, step: int) -> None:
             self._disallowed = False
             self._checkpoint_lock.release()
 
-    def shutdown(self) -> None:
+    def shutdown(self, wait: bool = True) -> None:
         """
         Shutdown the server.
         """
-        self._server.shutdown()
+        if not wait:
+            # hack for nonblocking shutdown of socketserver threads
+            # pyre-fixme[16]: no attribute `__shutdown_request`.
+            self._server.__shutdown_request = True
+        if wait:
+            self._server.shutdown()
+            self._thread.join()
+
+    def metadata(self) -> str:
+        return self.address()
+
+    def send_checkpoint(
+        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+    ) -> None:
+        self._state_dict = state_dict
+        self.allow_checkpoint(step)
+
+    def recv_checkpoint(
+        self, src_rank: int, metadata: str, step: int, timeout: timedelta
+    ) -> T:
+        return self.load_from_address(f"{metadata}{step}", timeout)
diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py
index 983c429..31658b4 100644
--- a/torchft/checkpointing_test.py
+++ b/torchft/checkpointing_test.py
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import threading
 import urllib.error
 from datetime import timedelta
 from unittest import TestCase
 from unittest.mock import MagicMock
 
-from torchft.checkpointing import CheckpointServer
+from torchft.checkpointing import CheckpointServer, _timed_acquire
 
 
 class TestCheckpointing(TestCase):
@@ -18,26 +19,87 @@ def test_checkpoint_server(self) -> None:
         state_dict_fn = MagicMock()
         state_dict_fn.return_value = expected
         server = CheckpointServer(
-            state_dict=state_dict_fn,
             timeout=timedelta(seconds=10),
         )
 
-        server.disallow_checkpoint()
-        server.allow_checkpoint(1234)
+        server.send_checkpoint(
+            dst_ranks=[],
+            step=1234,
+            state_dict=expected,
+            timeout=timedelta(seconds=10),
+        )
 
-        addr = server.address()
+        metadata = server.metadata()
 
-        out = CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10))
+        out = server.recv_checkpoint(
+            src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
+        )
         self.assertEqual(out, expected)
 
         # test timeout
         with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"):
-            CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=0.0))
+            server.recv_checkpoint(
+                src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=0.0)
+            )
 
         # test mismatch case
-        server.allow_checkpoint(2345)
+        server.send_checkpoint(
+            dst_ranks=[],
+            step=2345,
+            state_dict=expected,
+            timeout=timedelta(seconds=10),
+        )
 
         with self.assertRaisesRegex(urllib.error.HTTPError, r"Error 400"):
-            CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10))
+            server.recv_checkpoint(
+                src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
+            )
 
         server.shutdown()
+
+    def test_checkpoint_server_locking(self) -> None:
+        server = CheckpointServer(
+            timeout=timedelta(seconds=10),
+        )
+
+        # The server should start up in a disallowed state; this will block incoming
+        # requests until allow_checkpoint is called.
+        self.assertTrue(server._checkpoint_lock.locked())
+        self.assertTrue(server._disallowed)
+        self.assertEqual(server._step, -1)
+
+        # allow requests
+        server.allow_checkpoint(1)
+
+        self.assertFalse(server._checkpoint_lock.locked())
+        self.assertFalse(server._disallowed)
+        self.assertEqual(server._step, 1)
+
+        # duplicate allow/disallow is fine
+        server.allow_checkpoint(2)
+        self.assertEqual(server._step, 2)
+
+        server.disallow_checkpoint()
+        server.disallow_checkpoint()
+        self.assertTrue(server._checkpoint_lock.locked())
+        self.assertTrue(server._disallowed)
+
+        server.shutdown()
+
+    def test_timed_acquire(self) -> None:
+        lock = threading.Lock()
+
+        with _timed_acquire(lock, timedelta(seconds=10)):
+            self.assertTrue(lock.locked())
+
+        self.assertFalse(lock.locked())
+
+        lock.acquire()
+
+        with self.assertRaisesRegex(
+            TimeoutError, r"timed out acquiring lock after 0.0"
+        ):
+            with _timed_acquire(lock, timedelta(seconds=0.0)):
+                pass
+
+        self.assertTrue(lock.locked())
diff --git a/torchft/device_mesh_test.py b/torchft/device_mesh_test.py
new file mode 100644
index 0000000..ee78c6d
--- /dev/null
+++ b/torchft/device_mesh_test.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io
+import os
+from concurrent.futures import ProcessPoolExecutor
+from typing import cast
+from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.distributed as dist
+
+from torchft.manager import Manager
+from torchft.process_group import (
+    ManagedProcessGroup,
+    ProcessGroupGloo,
+    ft_init_device_mesh,
+)
+
+
+class DeviceMeshTest(TestCase):
+    @staticmethod
+    def _test_init_device_mesh(world_size: int, rank: int) -> None:
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(4)
+
+        testcase = TestCase()
+
+        manager = Mock(spec=Manager)
+        manager._pg = ProcessGroupGloo()
+        # Even though we only have 4 workers, we can still initialize a (2, 4) mesh.
+        # That's because the replicate group is NOT physically created in the
+        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
+        device_mesh = ft_init_device_mesh(
+            device_type="cpu",
+            mesh_shape=(2, world_size),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+
+        testcase.assertTrue(
+            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
+        )
+        testcase.assertTrue(
+            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
+        )
+        replicate_group = device_mesh.get_group("dp_replicate")
+        testcase.assertEqual(
+            cast(ManagedProcessGroup, replicate_group)._manager, manager
+        )
+        replicate_mesh = device_mesh["dp_replicate"]
+        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
+
+        flatten_mesh = device_mesh._flatten("dp")
+        manager.num_participants.return_value = 0
+        testcase.assertEqual(flatten_mesh.size(), world_size)
+        manager.num_participants.return_value = 1
+        testcase.assertEqual(flatten_mesh.size(), world_size)
+        manager.num_participants.return_value = 2
+        testcase.assertEqual(flatten_mesh.size(), world_size * 2)
+
+        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
+
+        device_mesh.get_coordinate()
+        buffer = io.BytesIO()
+        torch.save(device_mesh, buffer)
+        buffer.seek(0)
+        torch.load(buffer, weights_only=False)
+
+    def test_init_device_mesh(self) -> None:
+        with ProcessPoolExecutor(max_workers=4) as executor:
+            futures = []
+            for i in range(4):
+                future = executor.submit(self._test_init_device_mesh, 4, i)
+                futures.append(future)
+            for f in futures:
+                f.result()
diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 1458f07..ec3cf82 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -3,25 +3,29 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 """
 LocalSGD
 =========
-
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
 
 import torch
 from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle
 
 from torchft.manager import Manager
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
-class LocalSGD(nn.Module):
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767
 
     This will synchronize the model parameters periodically in a fault tolerant
@@ -68,18 +72,14 @@ def __init__(
             pin_memory: Whether to pin the memory used for the backup of the model parameters.
         """
         super().__init__()
-
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
         device = backup_device or torch.device("cpu")
-
         self._backup_parameters: Dict[str, torch.Tensor] = {}
-
         for name, p in self._model.named_parameters():
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
             if (
@@ -90,95 +90,150 @@ def __init__(
                 t = t.pin_memory()
             self._backup_parameters[name] = t
 
+        self._hooks: List[RemovableHandle] = []
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Clean up and handle any exception raised inside the context
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
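+
+    # Typical usage (a sketch mirroring local_sgd_test.py; `dataloader` is illustrative):
+    #
+    #   with LocalSGD(manager, model, optimizer, sync_every=2):
+    #       for batch in dataloader:
+    #           loss = model(batch).mean()
+    #           loss.backward()
+    #           optimizer.step()  # sync() runs automatically every `sync_every` steps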
 
     def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)
 
     def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._backup_parameters[name], non_blocking=False)
 
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
     ) -> None:
         """
-        Loads the state dict to the model and the backup parameters.
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
 
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+    def sync(self) -> None:
         """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0
 
-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
         """
-        Run the model parameters.
+        self._average()
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()
 
-        This should be called before the optimizer step.
+    def _average(self) -> None:
+        # TODO: do we need to broadcast buffers like DDP does?
 
-        This will start the quorum and save the parameters if this is the first step.
-        """
-        if self._local_step == 0:
-            self._manager.start_quorum()
+        works = []
+
+        for p in self._model.parameters():
+            # TODO: bucketize parameters
+            works.append(self._manager.allreduce(p.data.detach()))
 
-        self._started_step = True
+        for work in works:
+            work.wait()
 
-        return self._model.forward(*args, **kwargs)
 
-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
-        """
-        This hook is registered on the optimizer and is called after the optimizer step.
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between
+    the previous global weights and the current local weights).
 
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
+    DiLoCo: https://arxiv.org/pdf/2311.08105
+    """
 
-        ``forward`` must be called before this function.
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        if manager._use_async_quorum:
+            raise ValueError(
+                "Using DiLoCo require synchronous quorum to be enabled. "
+                "Ensure that the manager is initialized with use_async_quorum=False"
+            )
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to calculate the pseudogradients, average them
+        across the manager group, and step with the outer optimizer.
         """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
 
-        self._local_step += 1
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
 
-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        self._average_grads()
+        # Restore the parameters back to the previous state
+        self._restore_parameters()
 
-            if self._manager.should_commit():
-                # save the parameters so we can restore from them later if necessary.
-                self._save_parameters()
-            else:
-                # commit failed, restore from the backup parameters
-                self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+        if self._manager.should_commit():
+            # Use the outer optimizer to update the model parameters
+            self._outer_optimizer.step()
+            self._save_parameters()
+        self._outer_optimizer.zero_grad()
 
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
         works = []
-
         for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Perform allreduce on the pseudogradients
+            assert p.grad is not None
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
         for work in works:
             work.wait()
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index d2b73cd..05f88b7 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -11,7 +11,7 @@
 import torch
 from torch import nn, optim
 
-from torchft.local_sgd import LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 
 
@@ -40,57 +40,107 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
 
 class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
-        base_m = SimpleModel()
-        optimizer = optim.SGD(base_m.parameters())
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
         manager = create_autospec(Manager)
-
-        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
-        self.assertEqual(m._local_step, 0)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-
-        inp = torch.rand(2, 3)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        self.assertEqual(m._local_step, 1)
-        self.assertEqual(manager.start_quorum.call_count, 1)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        manager.should_commit.return_value = True
-        self.assertEqual(m._local_step, 0)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-        self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, 4)
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            self.assertEqual(local_sgd._local_step, 0)
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            self.assertEqual(local_sgd._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 0)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+            self.assertEqual(manager.start_quorum.call_count, 1)
+
+            manager.should_commit.return_value = True
+            self.assertEqual(local_sgd._local_step, 0)
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, 4)
 
     def test_local_sgd_recovery(self) -> None:
-        base_m = SimpleModel()
-        optimizer = optim.SGD(base_m.parameters())
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
         manager = create_autospec(Manager)
 
-        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-        og_state_dict = _copy_state_dict(base_m.state_dict())
-
-        inp = torch.rand(2, 3)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        self.assertEqual(m._local_step, 1)
-
-        state_dict = m.state_dict()
-        torch.testing.assert_close(state_dict, m._backup_parameters)
-        torch.testing.assert_close(state_dict, og_state_dict)
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            og_state_dict = _copy_state_dict(model.state_dict())
+
+            inp = torch.rand(2, 3)
+
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            # Check that the model's state dict has been updated
+            for name, param in model.state_dict().items():
+                # Ensure the parameter has changed
+                self.assertFalse(
+                    torch.equal(og_state_dict[name], param),
+                    f"Parameter {name} did not change.",
+                )
+            self.assertEqual(local_sgd._local_step, 1)
+
+            local_sgd._restore_parameters()
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+
+
+class DiLoCoTest(TestCase):
+    def test_diloco_healthy(self) -> None:
+        model = SimpleModel()
+
+        # Setup optimizers
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
 
-        m.load_state_dict(state_dict)
-        torch.testing.assert_close(_params_dict(base_m), state_dict)
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        with DiLoCo(
+            manager, model, inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            parameter_count = len(list(model.parameters()))
+            initial_outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(initial_outer_opt_state["state"], {})
+
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 0)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+            self.assertEqual(manager.start_quorum.call_count, 1)
+
+            manager.should_commit.return_value = True
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            self.assertEqual(manager.should_commit.call_count, 1)
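+            # DiLoCo performs one allreduce per parameter (on the pseudogradients)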
+            self.assertEqual(manager.allreduce.call_count, parameter_count)
+
+            outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(len(outer_opt_state["state"]), parameter_count)
diff --git a/torchft/manager.py b/torchft/manager.py
index d9ff366..5ca25cb 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -38,7 +38,7 @@
 import torch
 from torch.distributed import ReduceOp, TCPStore
 
-from torchft.checkpointing import CheckpointServer
+from torchft.checkpointing import CheckpointServer, CheckpointTransport
 from torchft.futures import future_timeout
 from torchft.torchft import Manager as _Manager, ManagerClient
 
@@ -87,8 +87,8 @@ class Manager:
     def __init__(
         self,
         pg: "ProcessGroup",
-        load_state_dict: Callable[[T], None],
-        state_dict: Callable[[], T],
+        load_state_dict: Optional[Callable[[T], None]],
+        state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
         use_async_quorum: bool = True,
         timeout: timedelta = timedelta(seconds=60),
@@ -104,6 +104,7 @@ def __init__(
         port: Optional[int] = None,
         hostname: str = socket.gethostname(),
         heartbeat_interval: timedelta = timedelta(milliseconds=100),
+        checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
     ) -> None:
         """
         Args:
@@ -139,9 +140,11 @@ def __init__(
             lighthouse_addr: if rank==0, the address of the lighthouse server
             replica_id: if rank==0, the replica_id for this group
             hostname: if rank==0, the hostname to advertise to the lighthouse server
+            checkpoint_transport: the checkpoint transport to use for
+                transferring checkpoints to recovering replicas
         """
         self._load_state_dict = load_state_dict
-        self._state_dict = state_dict
+        self._user_state_dict = state_dict
         self._pending_state_dict: Optional[Dict[str, object]] = None
         self._use_async_quorum = use_async_quorum
         self._timeout = timeout
@@ -156,15 +159,13 @@ def __init__(
         world_size = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
 
-        def _manager_state_dict() -> Dict[str, T]:
-            return {
-                "user": state_dict(),
-                "torchft": cast(T, self.state_dict()),
-            }
+        if checkpoint_transport is None:
+            checkpoint_transport = CheckpointServer[Dict[str, T]](
+                timeout=timeout,
+            )
 
-        self._ckpt_server = CheckpointServer[Dict[str, T]](
-            _manager_state_dict,
-            timeout=timeout,
+        self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = (
+            checkpoint_transport
         )
         self._executor = ThreadPoolExecutor(
             max_workers=1, thread_name_prefix="async_quorum"
@@ -223,14 +224,20 @@ def _manager_state_dict() -> Dict[str, T]:
         self._participating_rank: Optional[int] = None
         self._participating_world_size: int = 0
 
-    def shutdown(self) -> None:
+    def set_state_dict_fns(
+        self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
+    ) -> None:
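+        """
+        Sets the functions used to save and load the user managed state dict.
+        """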
+        self._load_state_dict = load_state_dict
+        self._user_state_dict = state_dict
+
+    def shutdown(self, wait: bool = True) -> None:
         """
         Shutdown the manager and checkpoint server.
         """
-        self._ckpt_server.shutdown()
+        self._checkpoint_transport.shutdown(wait=wait)
         if self._manager is not None:
             self._manager.shutdown()
-        self._executor.shutdown()
+        self._executor.shutdown(wait=wait)
 
     def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
         """
@@ -386,7 +393,6 @@ def start_quorum(
 
         self._errored = None
         self._healing = False
-        self._ckpt_server.allow_checkpoint(self._step)
 
         # TODO: we should really be wrapping this whole section in a try-except
         # block to allow gracefully recovering from issues in PG setup and quorum.
@@ -422,24 +428,24 @@ def wait_quorum(self) -> None:
     def _async_quorum(
         self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
     ) -> None:
-        (
-            quorum_id,
-            replica_rank,
-            replica_world_size,
-            address,
-            store_address,
-            max_step,
-            max_rank,
-            max_world_size,
-            heal,
-        ) = self._client.quorum(
+        quorum = self._client.quorum(
             rank=self._rank,
             step=self._step,
-            checkpoint_server_addr=self._ckpt_server.address(),
+            checkpoint_metadata=self._checkpoint_transport.metadata(),
             shrink_only=shrink_only,
             timeout=quorum_timeout,
         )
 
+        quorum_id = quorum.quorum_id
+        replica_rank = quorum.replica_rank
+        replica_world_size = quorum.replica_world_size
+        recover_src_manager_address = quorum.recover_src_manager_address
+        store_address = quorum.store_address
+        max_step = quorum.max_step
+        max_rank = quorum.max_rank
+        max_world_size = quorum.max_world_size
+        heal = quorum.heal
+
         # When using async quorum we need to take the recovered workers.
         # When not using async quorum we need to take the max world size as all
         # workers will be healthy.
@@ -470,29 +476,54 @@ def _async_quorum(
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
 
-        # See manager.rs for healing conditions
-        if heal and allow_heal:
-            self._healing = True
-            self._logger.info(
-                f"healing required, fetching checkpoint server address from {address=} {max_step=}"
-            )
-            primary_client = ManagerClient(
-                address, connect_timeout=self._connect_timeout
-            )
-            checkpoint_server_address = primary_client.checkpoint_address(
-                self._rank, timeout=self._timeout
-            )
+        if allow_heal:
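+            # If the quorum assigned us recovering peers, serve them the current
+            # state dict via the checkpoint transport.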
+            if quorum.recover_dst_ranks:
+                self._logger.info(
+                    f"peers need recovery from us {quorum.recover_dst_ranks}"
+                )
+                self._checkpoint_transport.send_checkpoint(
+                    dst_ranks=quorum.recover_dst_ranks,
+                    step=max_step,
+                    state_dict=self._manager_state_dict(),
+                    timeout=self._timeout,
+                )
 
-            self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
-            self._pending_state_dict = CheckpointServer.load_from_address(
-                checkpoint_server_address, timeout=self._timeout
-            )
-            self.load_state_dict(self._pending_state_dict["torchft"])
-            # we apply the user state dict only when safe from the main thread
+            # See manager.rs for healing conditions
+            if heal:
+                self._healing = True
+                self._logger.info(
+                    f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
+                )
+                primary_client = ManagerClient(
+                    recover_src_manager_address, connect_timeout=self._connect_timeout
+                )
+                checkpoint_metadata = primary_client.checkpoint_metadata(
+                    self._rank, timeout=self._timeout
+                )
+                recover_src_rank = quorum.recover_src_rank
+                assert (
+                    recover_src_rank is not None
+                ), "must have a recover rank when healing"
+
+                self._logger.info(
+                    f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                )
 
-            # This isn't strictly needed as loading the state_dict above should
-            # restore the correct step but it makes writing tests simpler.
-            self._step = max_step
+                # we apply the user state dict only when safe from the main thread
+                # save it for now
+                self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
+                    src_rank=recover_src_rank,
+                    metadata=checkpoint_metadata,
+                    step=max_step,
+                    timeout=self._timeout,
+                )
+
+                # pyre-fixme[6]: got object
+                self.load_state_dict(self._pending_state_dict["torchft"])
+
+                # This isn't strictly needed as loading the state_dict above should
+                # restore the correct step but it makes writing tests simpler.
+                self._step = max_step
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
@@ -504,8 +535,12 @@ def _apply_pending_state_dict(self) -> None:
         self._logger.info("applying pending state dict")
 
         assert self._pending_state_dict is not None, "checkpoint was not staged"
+        assert (
+            self._load_state_dict is not None
+        ), "user load_state_dict is not initialized."
         self._load_state_dict(self._pending_state_dict["user"])
         self._pending_state_dict = None
+        self._logger.info("Loaded state dict.")
 
     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         """
@@ -553,7 +588,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
         )
 
-        self._ckpt_server.disallow_checkpoint()
+        self._checkpoint_transport.disallow_checkpoint()
 
         # decide whether we're in a healthy state to increase the step count
         if should_commit:
@@ -574,6 +609,13 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._step = state_dict["step"]
         self._batches_committed = state_dict["batches_committed"]
 
+    def _manager_state_dict(self) -> Dict[str, object]:
+        assert self._user_state_dict is not None, "user state_dict is not initialized."
+        return {
+            "user": self._user_state_dict(),
+            "torchft": self.state_dict(),
+        }
+
     def state_dict(self) -> Dict[str, int]:
         """
         Get the state dict for this manager.
@@ -610,15 +652,35 @@ def batches_committed(self) -> int:
         """
         return self._batches_committed
 
+    def participating_rank(self) -> Optional[int]:
+        """
+        Get the replica group rank of the current quorum. This will be the same on all
+        ranks within the replica group.
+
+        If this replica group is not participating in the current quorum, this will be None.
+
+        This will block on the async quorum if it is not yet ready.
+
+        Returns:
+            the rank of the current quorum
+        """
+        self.wait_quorum()
+
+        return self._participating_rank
+
     def num_participants(self) -> int:
         """
         Get the number of participants in the current quorum.
 
         This is the number of replicas participating in the current step.
 
+        This will block on the async quorum if it is not yet ready.
+
         Returns:
             the number of participants in the current quorum
         """
+        self.wait_quorum()
+
         assert self._participating_world_size >= 0, "internal error"
         return self._participating_world_size
 
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index d6e7bde..8c7c45d 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 import threading
 import time
@@ -5,7 +6,7 @@
 from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Dict, Generator, List, Protocol, Set, Tuple
+from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
 from unittest import TestCase
 
 import torch
@@ -14,7 +15,7 @@
 from torch import nn, optim
 
 from torchft.ddp import DistributedDataParallel
-from torchft.local_sgd import LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
 from torchft.process_group import ProcessGroupGloo
@@ -76,6 +77,7 @@ class Runner:
     world_size: int = 1
     attempts: int = 3
     manager_args: Dict[str, object] = field(default_factory=dict)
+    train_loop_args: Dict[str, Any] = field(default_factory=dict)
 
     def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
         store = dist.TCPStore(
@@ -103,7 +105,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
                 try:
                     fut.result()
                 except Exception as e:
-                    logger.exception(f"worker threw exception: {e}")
+                    logger.exception(f"worker {self.replica_id=} threw exception: {e}")
                     raise
 
             return [fut.result() for fut in futures]
@@ -159,7 +161,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
             # pyre-fixme[6]: Incompatible parameter type
             **runner.manager_args,
         )
-        stack.callback(manager.shutdown)
+        stack.callback(lambda: manager.shutdown(wait=False))
 
         m: nn.Module = DistributedDataParallel(manager, MyModel())
         optimizer: optim.Optimizer = OptimizerWrapper(
@@ -223,34 +225,118 @@ def state_dict() -> Dict[str, Dict[str, object]]:
             # pyre-fixme[6]: Incompatible parameter type
             **runner.manager_args,
         )
-        stack.callback(manager.shutdown)
+        stack.callback(lambda: manager.shutdown(wait=False))
 
         m: nn.Module = MyModel()
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
-        m = LocalSGD(manager, m, optimizer, sync_every=2)
         criterion = nn.CrossEntropyLoss()
 
-        while True:
-            inputs = torch.rand(2, 3)
-            labels = torch.randint(4, (2,))
+        with LocalSGD(manager, m, optimizer, sync_every=2):
+            while True:
+                inputs = torch.rand(2, 3)
+                labels = torch.randint(4, (2,))
 
-            optimizer.zero_grad()
-            out = m(inputs)
-            loss = criterion(out, labels)
+                optimizer.zero_grad()
+                out = m(inputs)
+                loss = criterion(out, labels)
 
-            loss.backward()
+                loss.backward()
 
-            optimizer.step()
+                optimizer.step()
 
-            if manager.current_step() >= 4:
-                break
+                if manager.current_step() >= 4:
+                    break
 
-            runner.failure_injector.check(rank, manager.current_step())
+                runner.failure_injector.check(rank, manager.current_step())
 
         # return state_dict so we can check consistency
         return state_dict()
 
 
+def diloco_train_loop(
+    rank: int,
+    store_port: int,
+    runner: Runner,
+) -> Dict[str, Dict[str, object]]:
+    with ExitStack() as stack:
+        # Declare the model and optimizers
+        m: nn.Module = MyModel()
+        model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
+        m.load_state_dict(model_state_dict)
+
+        # Setup optimizers
+        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
+            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer: optim.Optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        # pyre-ignore[53]
+        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
+            m.load_state_dict(state_dict["model"])
+            # TODO: make this cleaner so we don't have to save this
+            diloco._backup_parameters = state_dict["backup_params"]
+            inner_optimizer.load_state_dict(state_dict["inner_optim"])
+            outer_optimizer.load_state_dict(state_dict["outer_optim"])
+
+        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
+            return {
+                "model": m.state_dict(),
+                "backup_params": copy.deepcopy(diloco._backup_parameters),
+                "inner_optim": inner_optimizer.state_dict(),
+                "outer_optim": outer_optimizer.state_dict(),
+            }
+
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=load_state_dict,
+            state_dict=state_dict,
+            replica_id=str(runner.replica_id),
+            store_addr="localhost",
+            store_port=store_port,
+            rank=rank,
+            world_size=runner.world_size,
+            lighthouse_addr=runner.lighthouse_address,
+            port=19530 + runner.replica_id,
+            # pyre-fixme[6]: Incompatible parameter type
+            **runner.manager_args,
+        )
+        stack.callback(manager.shutdown)
+
+        criterion = nn.CrossEntropyLoss()
+        all_state_dicts = {}
+        with DiLoCo(
+            manager, m, inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            while True:
+                inputs = torch.rand(2, 3)
+                labels = torch.randint(4, (2,))
+
+                out = m(inputs)
+                loss = criterion(out, labels)
+
+                inner_optimizer.zero_grad()
+                loss.backward()
+                inner_optimizer.step()
+                manager_step_str = str(manager.current_step())
+                all_state_dicts[manager_step_str] = state_dict()
+
+                # break after 4 model updates
+                if manager.current_step() >= 4:
+                    break
+
+                runner.failure_injector.check(rank, manager.current_step())
+
+        # return the collected state dicts so we can check consistency across replicas
+        return all_state_dicts
+
+
 class ManagerIntegTest(TestCase):
     @contextmanager
     def assertElapsedLessThan(
@@ -431,6 +517,108 @@ def test_local_sgd_recovery(self) -> None:
 
         self.assertEqual(failure_injectors[1].count, 1)
 
+    def test_diloco_healthy(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MyModel()
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                runner = Runner(
+                    replica_id=replica_id,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=diloco_train_loop,
+                    train_loop_args={
+                        "model_state_dict": m.state_dict(),
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            state_dicts.append(fut.result()[0])
+
+        lighthouse.shutdown()
+
+        for replica_group in state_dicts:
+            for step, state_dict in replica_group.items():
+                # inner optimizer will be different, outer optimizer and model should be the same
+                torch.testing.assert_close(
+                    state_dict["backup_params"],
+                    state_dicts[0][str(step)]["backup_params"],
+                )
+                torch.testing.assert_close(
+                    state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"]
+                )
+
+    def test_diloco_recovery(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(0, 2),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MyModel()
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                runner = Runner(
+                    replica_id=replica_id,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=diloco_train_loop,
+                    train_loop_args={
+                        "model_state_dict": m.state_dict(),
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result()[0])
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+        for replica_group in state_dicts:
+            for step, state_dict in replica_group.items():
+                str_step = str(step)
+                if str_step in state_dicts[0]:
+                    # inner optimizer will be different, outer optimizer and model should be the same
+                    torch.testing.assert_close(
+                        state_dict["backup_params"],
+                        state_dicts[0][str_step]["backup_params"],
+                    )
+                    torch.testing.assert_close(
+                        state_dict["outer_optim"],
+                        state_dicts[0][str_step]["outer_optim"],
+                    )
+
+        self.assertEqual(failure_injectors[1].count, 1)
+
     def test_quorum_timeout(self) -> None:
         with ExitStack() as stack:
             lighthouse = Lighthouse(
@@ -460,7 +648,7 @@ def test_quorum_timeout(self) -> None:
                 port=19530,
                 use_async_quorum=False,
             )
-            stack.callback(manager.shutdown)
+            stack.callback(lambda: manager.shutdown(wait=False))
 
             with self.assertElapsedLessThan(1.0):
                 with self.assertRaisesRegex(
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 97b891c..3c1b662 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import concurrent
 from datetime import timedelta
+from typing import Optional
 from unittest import TestCase
 from unittest.mock import MagicMock, create_autospec, patch
 
@@ -13,7 +15,7 @@
 
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup, _DummyWork
-from torchft.torchft import ManagerClient
+from torchft.torchft import QuorumResult
 
 
 def mock_should_commit(
@@ -25,13 +27,19 @@ def mock_should_commit(
 class TestManager(TestCase):
     store: TCPStore  # pyre-fixme[13]: never initialized
     load_state_dict: MagicMock  # pyre-fixme[13]: never initialized
+    manager: Optional[Manager]  # pyre-fixme[13]: never initialized
+
+    def tearDown(self) -> None:
+        manager = self.manager
+        if manager is not None:
+            manager.shutdown(wait=False)
 
     def _create_manager(
         self,
         use_async_quorum: bool = True,
         min_replica_size: int = 2,
         world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
-        timeout: timedelta = timedelta(seconds=60),
+        timeout: timedelta = timedelta(seconds=10),
     ) -> Manager:
         pg = create_autospec(ProcessGroup)
         self.store = TCPStore(
@@ -58,6 +66,7 @@ def _create_manager(
                 world_size_mode=world_size_mode,
                 timeout=timeout,
             )
+            self.manager = manager
         return manager
 
     @patch("torchft.manager.ManagerClient", autospec=True)
@@ -87,22 +96,54 @@ def test_state_dict(self, client_mock: MagicMock) -> None:
         self.assertEqual(manager.current_step(), 1234)
         self.assertEqual(manager.batches_committed(), 2345)
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_user_state_dict(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager()
+
+        self.assertEqual(
+            manager._manager_state_dict(),
+            {
+                "user": {},
+                "torchft": {
+                    "step": 0,
+                    "batches_committed": 0,
+                },
+            },
+        )
+
+        manager.set_state_dict_fns(
+            self.load_state_dict,
+            lambda: {"new_state": 1},
+        )
+
+        self.assertEqual(
+            manager._manager_state_dict(),
+            {
+                "user": {"new_state": 1},
+                "torchft": {
+                    "step": 0,
+                    "batches_committed": 0,
+                },
+            },
+        )
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_happy(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
         client_mock().should_commit = mock_should_commit
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            1,  # max_step
-            1,  # max_rank
-            2,  # max_world_size
-            False,  # heal
-        )
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock().quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -127,21 +168,30 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(use_async_quorum=False)
         client_mock().should_commit = mock_should_commit
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            20,  # max_step
-            None,  # max_rank
-            2,  # max_world_size
-            True,  # heal
-        )
-        # forcible increment checkpoint server to compute correct address
-        manager._ckpt_server.allow_checkpoint(manager.current_step())
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+        quorum.heal = True
+
+        client_mock().quorum.return_value = quorum
 
-        client_mock().checkpoint_address.return_value = manager._ckpt_server.address()
+        # stage a checkpoint on the transport so it can serve valid metadata
+        manager._checkpoint_transport.send_checkpoint(
+            dst_ranks=[],
+            step=quorum.max_step,
+            state_dict=manager._manager_state_dict(),
+            timeout=timedelta(seconds=10),
+        )
+        client_mock().checkpoint_metadata.return_value = (
+            manager._checkpoint_transport.metadata()
+        )
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -169,21 +219,30 @@ def test_quorum_heal_async_not_enough_participants(
         manager = self._create_manager(use_async_quorum=True, min_replica_size=2)
         client_mock().should_commit = mock_should_commit
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            20,  # max_step
-            None,  # max_rank
-            1,  # max_world_size
-            True,  # heal
-        )
-        # forcible increment checkpoint server to compute correct address
-        manager._ckpt_server.allow_checkpoint(manager.current_step())
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 1
+        quorum.heal = True
+
+        client_mock().quorum.return_value = quorum
 
-        client_mock().checkpoint_address.return_value = manager._ckpt_server.address()
+        # stage a checkpoint on the transport so it can serve valid metadata
+        manager._checkpoint_transport.send_checkpoint(
+            dst_ranks=[],
+            step=quorum.max_step,
+            state_dict=manager._manager_state_dict(),
+            timeout=timedelta(seconds=10),
+        )
+        client_mock().checkpoint_metadata.return_value = (
+            manager._checkpoint_transport.metadata()
+        )
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -212,6 +271,7 @@ def test_quorum_heal_async_not_enough_participants(
         self.assertEqual(self.load_state_dict.call_count, 1)
 
         # failed to commit so no step
+        quorum.heal = False
         manager.start_quorum()
         self.assertEqual(manager.current_step(), 20)
         self.assertEqual(manager.batches_committed(), 0)
@@ -221,21 +281,30 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(use_async_quorum=True, min_replica_size=1)
         client_mock().should_commit = mock_should_commit
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            20,  # max_step
-            None,  # max_rank
-            1,  # max_world_size
-            True,  # heal
-        )
-        # forceable increment checkpoint server to compute correct address
-        manager._ckpt_server.allow_checkpoint(manager.current_step())
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 1
+        quorum.heal = True
+
+        client_mock().quorum.return_value = quorum
 
-        client_mock().checkpoint_address.return_value = manager._ckpt_server.address()
+        # stage a checkpoint on the transport so it can serve valid metadata
+        manager._checkpoint_transport.send_checkpoint(
+            dst_ranks=[],
+            step=quorum.max_step,
+            state_dict=manager._manager_state_dict(),
+            timeout=timedelta(seconds=10),
+        )
+        client_mock().checkpoint_metadata.return_value = (
+            manager._checkpoint_transport.metadata()
+        )
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -261,6 +330,8 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
 
         self.assertEqual(self.load_state_dict.call_count, 1)
 
+        # healed
+        quorum.heal = False
         manager.start_quorum()
         self.assertEqual(manager.current_step(), 21)
         self.assertEqual(manager.batches_committed(), 1)
@@ -270,17 +341,18 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
         client_mock().should_commit = mock_should_commit
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            1,  # max_step
-            1,  # max_rank
-            2,  # max_world_size
-            False,  # heal
-        )
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock().quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -308,17 +380,8 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.side_effect = None
 
         # inject failure when worked waited
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            2,  # max_step
-            1,  # max_rank
-            2,  # max_world_size
-            False,  # heal
-        )
+        quorum.max_step = 2
+
         manager.start_quorum()
 
         self.assertFalse(manager._errored)
@@ -336,17 +399,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.reset_mock(return_value=True)
 
         # recover on next step
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world_size
-            "manager address",
-            f"localhost:{self.store.port}",
-            3,  # max_step
-            1,  # max_rank
-            2,  # max_world_size
-            False,  # heal
-        )
+        quorum.max_step = 3
 
         manager.start_quorum()
         manager.allreduce(torch.tensor([1.0])).wait()
@@ -362,17 +415,18 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
             )
             client_mock().should_commit = mock_should_commit
 
-            client_mock().quorum.return_value = (
-                123,  # quorum_id
-                rank,  # replica_rank
-                3,  # replica_world
-                "manager address",
-                f"localhost:{self.store.port}",
-                1,  # max_step
-                rank,  # max_rank
-                3,  # max_world_size
-                False,  # heal
-            )
+            quorum = QuorumResult()
+            quorum.quorum_id = 123
+            quorum.replica_rank = rank
+            quorum.replica_world_size = 3
+            quorum.recover_src_manager_address = "manager address"
+            quorum.store_address = f"localhost:{self.store.port}"
+            quorum.max_step = 1
+            quorum.max_rank = rank
+            quorum.max_world_size = 3
+            quorum.heal = False
+
+            client_mock().quorum.return_value = quorum
 
             self.assertEqual(manager._quorum_id, -1)
             self.assertEqual(manager.current_step(), 0)
@@ -395,17 +449,18 @@ def test_quorum_no_healing(self, client_mock: MagicMock) -> None:
         )
         client_mock().should_commit = mock_should_commit
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            0,  # replica_rank
-            3,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            1,  # max_step
-            None,  # max_rank
-            2,  # max_world_size
-            True,  # heal
-        )
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 0
+        quorum.replica_world_size = 3
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 1
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+        quorum.heal = True
+        client_mock().quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -467,10 +522,16 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
     def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
-        manager._quorum_future = MagicMock()
+        manager._quorum_future = quorum_future = MagicMock(
+            spec=concurrent.futures.Future
+        )
         manager._participating_rank = 1
         manager._participating_world_size = 5
         self.assertEqual(manager.num_participants(), 5)
+        self.assertEqual(quorum_future.result.call_count, 1)
+        self.assertEqual(manager.participating_rank(), 1)
+        self.assertEqual(quorum_future.result.call_count, 2)
+
         # pyre-ignore[16]: _pg is mocked
         manager._pg.allreduce.return_value = _DummyWork(None)
 
@@ -492,17 +553,18 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
     def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(use_async_quorum=False)
 
-        client_mock().quorum.return_value = (
-            123,  # quorum_id
-            1,  # replica_rank
-            2,  # replica_world
-            "manager address",
-            f"localhost:{self.store.port}",
-            1,  # max_step
-            1,  # max_rank
-            2,  # max_world_size
-            False,  # heal
-        )
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock().quorum.return_value = quorum
 
         manager.start_quorum(timeout=timedelta(seconds=12))
         self.assertEqual(
diff --git a/torchft/optim.py b/torchft/optim.py
index ce24823..0583d94 100644
--- a/torchft/optim.py
+++ b/torchft/optim.py
@@ -12,8 +12,9 @@
 
 """
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
 
+import torch
 from torch.optim import Optimizer
 
 if TYPE_CHECKING:
@@ -52,3 +53,11 @@ def step(self, closure: Optional[object] = None) -> None:
         assert closure is None, "optimizers that use closures are not supported"
         if self.manager.should_commit():
             self.optim.step()
+
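+    # Expose the wrapped optimizer's param_groups and state so the wrapper can
+    # be used wherever a plain torch.optim.Optimizer is expected.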
+    @property
+    def param_groups(self) -> List[Dict[str, Any]]:
+        return self.optim.param_groups
+
+    @property
+    def state(self) -> Mapping[torch.Tensor, Any]:  # pyre-fixme[3]
+        return self.optim.state
diff --git a/torchft/optim_test.py b/torchft/optim_test.py
index 50412d8..5dd6964 100644
--- a/torchft/optim_test.py
+++ b/torchft/optim_test.py
@@ -7,6 +7,7 @@
 from unittest import TestCase
 from unittest.mock import MagicMock, create_autospec
 
+import torch
 from torch.nn import Linear
 from torch.optim import AdamW
 
@@ -34,9 +35,16 @@ def test_optimizer_wrapper(self) -> None:
         optim.zero_grad()
         self.assertEqual(manager.start_quorum.call_count, 1)
 
+        b = torch.rand(3)
+        m(b).sum().backward()
+
         manager.should_commit.return_value = True
         optim.step()
         manager.should_commit.return_value = False
         optim.step()
+        self.assertEqual(len(optim.param_groups), 2)
+        self.assertEqual(optim.param_groups[1]["lr"], 1e-4)
+        self.assertEqual(optim.param_groups[1]["params"], [])
+        self.assertEqual(len(optim.state), len(list(m.parameters())))
 
         self.assertEqual(manager.should_commit.call_count, 2)
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 34797c7..4790352 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -19,9 +19,22 @@
 import logging
 import queue
 import threading
-from abc import ABC
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
 import torch.distributed as dist
@@ -30,7 +43,6 @@
 # pyre-fixme[21]: no attribute ProcessGroupNCCL
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
-    BroadcastOptions,
     DeviceMesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
@@ -41,8 +53,15 @@
     get_rank,
     init_device_mesh,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    Work,
+)
 from torch.futures import Future
+from torch.utils._pytree import tree_any
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -55,6 +74,9 @@
 _FUTURE_EXCEPTION = "fut_exception"
 
 
+T = TypeVar("T")
+
+
 def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
     """
     Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -123,7 +145,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
+    ) -> Work:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
@@ -131,12 +155,24 @@ def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
+        """
+        Gathers tensors from the whole group in a list.
+
+        See torch.distributed.all_gather for more details.
+        """
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+    def broadcast(
+        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
+    ) -> Work:
+        """
+        Broadcasts the tensor to the whole group.
+
+        See torch.distributed.broadcast for more details.
+        """
         raise NotImplementedError("not implemented")
 
     def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
@@ -507,6 +543,10 @@ def __init__(self, manager: "Manager") -> None:
         self._manager = manager
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        # Ensure we have a valid quorum and are configured before trying to do
+        # any work.
+        self._manager.wait_quorum()
+
         if self._manager.errored() is not None:
             return _DummyWork(tensors)
 
@@ -547,26 +587,52 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
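+        # For CUDA ops the worker returns an event recorded on its stream;
+        # waiting on it synchronizes the current stream without blocking the CPU.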
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
+    def __del__(self) -> None:
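+        # Tell the worker to drop its reference to this op's Work object.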
+        self._tx.put(("del", self._op_id), timeout=self._timeout)
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        return True
+    Supports lists, tuples, dicts, and tensors.
+    """
+    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)
+
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
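+        # Run the body on the stream this op was issued on, if any.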
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
@@ -575,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     shareable and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -609,8 +672,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         if self._rx is not None:
             self._rx.close()
         if self._future_queue is not None:
+            # wait for the future thread to exit and then close the queue
             self._future_queue.put(_QUEUE_CLOSE)
-            assert self._future_queue is not None
+            assert self._future_thread is not None
+            self._future_thread.join(timeout=10.0)
+            # pyre-ignore[16]: optional value is checked above
+            if self._future_thread.is_alive():
+                raise RuntimeError("future thread did not exit")
+            # pyre-ignore[16]: optional value is checked above
             self._future_queue.close()
 
         ctx = mp.get_context("spawn")
@@ -631,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -668,23 +744,73 @@ def _worker(
                 return
             tx.put(None)
 
-            work = {}
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0
 
             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait, not the CPU,
+                        # when no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
+                elif cmd == "del":
+                    op_id: int = op[1]
                     del work[op_id]
-                    tx.put(op_id)
                 elif cmd == "future":
                     op_id: int = op[1]
 
@@ -695,29 +821,17 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
+                elif cmd == "num_active_work":
+                    tx.put(len(work))
                 else:
                     raise ValueError(f"unknown cmd: {cmd}")
 
         except Exception as e:
             logger.exception("worker errored")
             tx.put(e)
+            raise
 
     def _future_handler(self, future_queue: mp.Queue) -> None:
         try:
@@ -739,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -751,21 +867,58 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
-        tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
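+        # Record a CUDA event on the caller's current stream so the worker can
+        # order this op after any work already queued on that stream.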
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
+        tx.put(
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
+            timeout=self._timeout,
+        )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
 
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def _assert_alive(self) -> None:
+        """
+        Assert that the child worker process is alive. This ensures that
+        operations are not issued to a dead process group and that any errors
+        are surfaced promptly.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
+
+    def allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        opts: Union[dist.AllreduceOptions, dist.ReduceOp],
+    ) -> Work:
         assert isinstance(tensors, list), "input must be list"
 
         for tensor in tensors:
@@ -774,9 +927,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
 
         return self._run_func("allreduce", tensors, opts)
 
+    def allgather(
+        self,
+        output_tensors: List[List[torch.Tensor]],
+        input_tensor: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "output_tensors must be a list"
+        assert isinstance(input_tensor, list), "input_tensor must be a list"
+
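+        # Tensors must live in shared memory so the worker subprocess can
+        # access and fill them in place.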
+        for tensor_list in output_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+
+        for tensor in input_tensor:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("allgather", output_tensors, input_tensor, opts)
+
+    def broadcast(
+        self,
+        tensor_list: List[torch.Tensor],
+        opts: BroadcastOptions,
+    ) -> Work:
+        assert isinstance(tensor_list, list), "input must be list"
+
+        for tensor in tensor_list:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("broadcast", tensor_list, opts)
+
     def size(self) -> int:
         return self._world_size
 
+    def num_active_work(self) -> int:
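+        """Return the number of operations still tracked by the worker process."""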
+        assert self._tx is not None
+        self._tx.put(("num_active_work",), timeout=self._timeout)
+
+        assert self._rx is not None
+        return cast(int, _get(self._rx, self._timeout))
+
+
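+# c10d options objects (e.g. AllreduceOptions) are not reliably picklable, so
+# they are flattened into their constructor plus a dict of public fields before
+# being sent to the worker and rebuilt there via to_torch().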
+@dataclass
+class _PickleSafeOptions:
+    func: Callable[[], object]
+    fields: Dict[str, object]
+
+    @classmethod
+    def safe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.safe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.safe_args(arg) for arg in args]
+        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+            return cls.from_torch(args)
+        else:
+            return args
+
+    @classmethod
+    def unsafe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.unsafe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.unsafe_args(arg) for arg in args]
+        elif isinstance(args, cls):
+            return args.to_torch()
+        else:
+            return args
+
+    @classmethod
+    def from_torch(cls, opts: object) -> "_PickleSafeOptions":
+        return cls(
+            func=opts.__class__,
+            fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
+        )
+
+    def to_torch(self) -> object:
+        opts = self.func()
+        for k, v in self.fields.items():
+            setattr(opts, k, v)
+        return opts
+
 
 class ProcessGroupBabyGloo(ProcessGroupBaby):
     """
@@ -811,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
@@ -852,6 +1084,8 @@ def extend_device_mesh(
 
 
 class ManagedDeviceMesh(DeviceMesh):
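+    # Class-level handle used to re-attach the replicate process group after
+    # deserialization (see __setstate__); set by ft_init_device_mesh().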
+    replicate_pg_singleton: Optional["ManagedProcessGroup"] = None
+
     def __init__(
         self,
         mesh: Optional[DeviceMesh],
@@ -880,6 +1114,16 @@ def __init__(
         self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
         self._thread_id: Optional[int] = None
 
+    def __getstate__(self) -> Dict[str, Any]:
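+        # The wrapped process group is not picklable, so drop it here and
+        # restore it from replicate_pg_singleton in __setstate__.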
+        state = self.__dict__.copy()
+        state["replicate_pg"] = None
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        assert self.replicate_pg_singleton is not None
+        self.replicate_pg = self.replicate_pg_singleton
+
     def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
         if isinstance(mesh_dim_names, str):
             if mesh_dim_names == self.replicate_dim_name:
@@ -897,13 +1141,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
                 return self.mesh[mesh_dim_names]
         else:
             assert isinstance(mesh_dim_names, tuple)
-            if self.replicate_dim_name in mesh_dim_names:
+            if self.replicate_dim_name not in mesh_dim_names:
                 assert self.mesh is not None
                 return self.mesh[mesh_dim_names]
             else:
+                mesh_dim_names_wo_replicate = tuple(
+                    n for n in mesh_dim_names if n != self.replicate_dim_name
+                )
                 assert self.mesh is not None
                 return ManagedDeviceMesh(
-                    self.mesh[mesh_dim_names],
+                    self.mesh[mesh_dim_names_wo_replicate],
                     mesh_dim_names,
                     self.replicate_pg,
                     mesh_dim_names.index(self.replicate_dim_name),
@@ -938,14 +1185,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
         return flatten_mesh
 
     def size(self, mesh_dim: Optional[int] = None) -> int:
+        replicate_pg_size = self.replicate_pg.size()
+        # We have to lie to the users if there are zero participants.
+        # This is possible during the initialization stage of training.
+        replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
         if mesh_dim is None:
             if self.mesh is None:
-                return self.replicate_pg.size()
+                return replicate_pg_size
             else:
                 assert self.mesh is not None
-                return self.mesh.size() * self.replicate_pg.size()
+                return self.mesh.size() * replicate_pg_size
         elif mesh_dim == self.replicate_dim:
-            return self.replicate_pg.size()
+            return replicate_pg_size
         else:
             assert self.mesh is not None
             return self.mesh.size(self._real_mesh_dim(mesh_dim))
@@ -995,7 +1246,16 @@ def get_coordinate(self) -> Optional[List[int]]:
         dimensions of the mesh. If this rank is not part of the mesh, return None.
         """
         assert self.mesh is not None
-        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+        coordinate = (
+            self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+        )
+        if not coordinate:
+            return coordinate
+
+        # We need to copy because we are going to modify the coordinate.
+        coordinate = coordinate.copy()
+        coordinate.insert(get_rank(self.replicate_pg), self.replicate_dim)
+        return coordinate
 
     def get_all_groups(self) -> List[BaseProcessGroup]:
         raise NotImplementedError
@@ -1057,19 +1317,11 @@ def ft_init_device_mesh(
         mesh_dim_names=tuple(_mesh_dim_names),
     )
 
-    if device_type == "cpu":
-        pg = ProcessGroupGloo()
-    elif device_type == "cuda":
-        pg = ProcessGroupNCCL()
-    else:
-        raise ValueError()
-
-    manager._pg = pg
     replicate_pg = ManagedProcessGroup(manager)
-    # We have to use MultiProcessTestCase, otherwise c10d will complain
-    # the same backend has been registered.
     replicate_pg.register(mesh_dim_names[replicate_dim])
 
+    ManagedDeviceMesh.replicate_pg_singleton = replicate_pg
+
     return ManagedDeviceMesh(
         mesh=mesh,
         mesh_dim_names=mesh_dim_names,
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index e86c7e0..f765625 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import gc
+import io
 import multiprocessing
 import os
 import unittest
@@ -86,14 +88,15 @@ def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
                 check_tensors(item)
 
     # Test collectives
-    collectives = {
-        "allreduce": ([input_tensor], AllreduceOptions()),
-        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
-        "broadcast": (tensor_list, BroadcastOptions()),
-        "broadcast_one": (input_tensor, 0),
-    }
+    collectives = [
+        ("allreduce", ([input_tensor], AllreduceOptions())),
+        ("allreduce", ([input_tensor], ReduceOp.SUM)),
+        ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
+        ("broadcast", (tensor_list, BroadcastOptions())),
+        ("broadcast_one", (input_tensor, 0)),
+    ]
     works: Dict[str, dist._Work] = {}
-    for coll_str, args in collectives.items():
+    for coll_str, args in collectives:
         coll = getattr(pg, coll_str)
         work = coll(*args)
         works[coll_str] = work
@@ -210,6 +213,84 @@ def test_baby_gloo_timeout(self) -> None:
         with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
             a.configure(store_addr, 0, 2)
 
+    def test_reconfigure_baby_process_group(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyGloo()
+        a.configure(store_addr, 0, 1)
+        future_thread_1 = a._future_thread
+        future_queue_1 = a._future_queue
+        p_1 = a._p
+
+        store_addr = f"localhost:{store.port}/prefix2"
+        a.configure(store_addr, 0, 1)
+        future_thread_2 = a._future_thread
+        future_queue_2 = a._future_queue
+        p_2 = a._p
+
+        self.assertNotEqual(future_thread_1, future_thread_2)
+        self.assertNotEqual(future_queue_1, future_queue_2)
+        self.assertNotEqual(p_1, p_2)
+
+        assert future_thread_1 is not None
+        self.assertFalse(future_thread_1.is_alive())
+        assert future_queue_1 is not None
+        self.assertTrue(future_queue_1._closed)  # pyre-ignore[16]: no attribute _closed
+        assert p_1 is not None
+        self.assertFalse(p_1.is_alive())
+
+        assert future_thread_2 is not None
+        self.assertTrue(future_thread_2.is_alive())
+        assert future_queue_2 is not None
+        self.assertFalse(future_queue_2._closed)
+        assert p_2 is not None
+        self.assertTrue(p_2.is_alive())
+
+    def test_baby_gloo_apis(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a)
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        # use device 1 if there are >= 2 GPUs, otherwise device 0
+        device_id = 1 % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -226,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
         store_addr: str = f"localhost:{store.port}/prefix"
 
         def run(rank: int) -> Tuple[torch.Tensor, Work]:
-            a = ProcessGroupBabyNCCL()
+            a = ProcessGroupBabyNCCL(
+                timeout=timedelta(seconds=10.0),
+            )
             a.configure(store_addr, rank, 2)
-
             self.assertEqual(a.size(), 2)
 
-            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
+            # We test using set_device to ensure stream device is correct.
+            torch.cuda.set_device(rank)
+            at = torch.tensor([rank + 1], device="cuda")
 
             a_work = a.allreduce([at], ReduceOp.SUM)
             return at, a_work
@@ -331,7 +415,8 @@ def test_managed_process_group(self) -> None:
         self.assertIsInstance(list(works.values())[0], _ManagedWork)
 
         self.assertEqual(manager.report_error.call_count, 0)
-        self.assertEqual(manager.wrap_future.call_count, 1)
+        self.assertEqual(manager.wrap_future.call_count, 2)
+        self.assertEqual(manager.wait_quorum.call_count, 2)
 
 
 class DeviceMeshTest(TestCase):
diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi
index fbd0293..2c8c6cd 100644
--- a/torchft/torchft.pyi
+++ b/torchft/torchft.pyi
@@ -1,5 +1,5 @@
 from datetime import timedelta
-from typing import Optional, Tuple
+from typing import List, Optional
 
 class ManagerClient:
     def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
@@ -7,11 +7,11 @@ class ManagerClient:
         self,
         rank: int,
         step: int,
-        checkpoint_server_addr: str,
+        checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
-    ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ...
-    def checkpoint_address(self, rank: int, timeout: timedelta) -> str: ...
+    ) -> QuorumResult: ...
+    def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
         self,
         rank: int,
@@ -20,6 +20,19 @@ class ManagerClient:
         timeout: timedelta,
     ) -> bool: ...
 
+class QuorumResult:
+    quorum_id: int
+    replica_rank: int
+    replica_world_size: int
+    recover_src_manager_address: str
+    recover_src_rank: Optional[int]
+    recover_dst_ranks: List[int]
+    store_address: str
+    max_step: int
+    max_rank: Optional[int]
+    max_world_size: int
+    heal: bool
+
 class Manager:
     def __init__(
         self,
diff --git a/train_ddp.py b/train_ddp.py
index 9ad9cc8..4bcc029 100644
--- a/train_ddp.py
+++ b/train_ddp.py
@@ -7,6 +7,7 @@
 import logging
 import os
 import sys
+from datetime import timedelta
 
 import torch
 import torch.nn.functional as F
@@ -70,7 +71,13 @@ def state_dict():
         }
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()
+    pg = (
+        ProcessGroupBabyNCCL(
+            timeout=timedelta(seconds=5),
+        )
+        if torch.cuda.is_available()
+        else ProcessGroupGloo(timeout=timedelta(seconds=5))
+    )
 
     manager = Manager(
         pg=pg,
@@ -78,6 +85,7 @@ def state_dict():
         load_state_dict=load_state_dict,
         state_dict=state_dict,
         replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
+        timeout=timedelta(seconds=10),
     )
 
     class Net(nn.Module):
diff --git a/train_fsdp.py b/train_fsdp.py
new file mode 100644
index 0000000..7a5d07c
--- /dev/null
+++ b/train_fsdp.py
@@ -0,0 +1,157 @@
+import os
+from datasets import load_dataset
+
+import torch
+from transformers import LlamaForCausalLM, AutoTokenizer
+from torch.distributed._composable.fsdp import fully_shard
+import torch.distributed as dist
+from tqdm import tqdm
+from transformers.data import DataCollatorForSeq2Seq
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from torchft import (
+    DistributedSampler,
+    Manager,
+    Optimizer,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
+from torchft.process_group import ft_init_device_mesh
+
+def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None, manager=None):
+    if replica_group_size is None or sharding_group_size is None:
+        raise ValueError(
+            "Both replica_group_size and sharding_group_size must be provided."
+        )
+
+    device = device or "cuda"
+
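+    # Build a 2D HSDP mesh: dim 0 ("dp_replicate") is the fault tolerant
+    # replica dimension managed by torchft, dim 1 ("dp_shard") is the FSDP
+    # sharding dimension.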
+    device_mesh = ft_init_device_mesh(
+        device_type=device,
+        mesh_shape=(replica_group_size, sharding_group_size),
+        mesh_dim_names=("dp_replicate", "dp_shard"),
+        replicate_dim=0,
+        manager=manager,
+    )
+    if device_mesh is None:
+        raise RuntimeError("Failed to create a valid device mesh.")
+
+    return device_mesh
+
+def parallelize_llama(model, mesh):
+    sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]
+
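+    # Shard each decoder layer individually, then wrap the root module so any
+    # remaining parameters (embeddings, lm_head, norms) are sharded as well.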
+    for m in reversed(list(model.modules())):
+        if any(c(m) for c in sharding_conditions):
+            # fully_shard(m, mesh=mesh, reshard_after_forward=True)
+            fully_shard(m, mesh=mesh)
+    # fully_shard([model.model.embed_tokens, model.lm_head], mesh=mesh)
+    fully_shard(model, mesh=mesh)
+
+def main():
+    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
+    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    NUM_REPLICAS = int(os.environ.get("NUM_REPLICAS", 2))
+
+    rank = int(os.environ.get("RANK", 0))
+
+    model_name = "Meta-Llama/Llama-3.2-1B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = LlamaForCausalLM.from_pretrained(model_name)
+
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    # The tokenizer vocab size must match the model's embedding matrix;
+    # otherwise the embedding matrix would need to be resized before training.
+    assert len(tokenizer) == model.get_input_embeddings().weight.shape[0]
+
+    train_data = load_dataset("samsum", split="train")
+
+    class SAMSumDataset(torch.utils.data.Dataset):
+        def __init__(self, data, tokenizer):
+            self.data = data
+            self.tokenizer = tokenizer
+        def __getitem__(self, idx):
+            text = self.data[idx]
+            prompt = self.tokenizer.encode(
+                self.tokenizer.bos_token
+                + f"Summarize this dialog: {text['dialogue']}\n---\nSummary: ",
+                add_special_tokens=False,
+            )
+            summary = self.tokenizer.encode(
+                text["summary"] + self.tokenizer.eos_token,
+                add_special_tokens=False,
+            )
+            input_ids = prompt + summary
+            labels = len(prompt) * [-100] + summary
+            return {"input_ids": input_ids, "labels": labels}
+        def __len__(self):
+            return len(self.data)
+
+    train_dataset = SAMSumDataset(train_data, tokenizer)
+
+    batch_size = 8
+
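+    # Shard the dataset across replica groups as well as across the ranks
+    # within this replica group.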
+    sampler = DistributedSampler(
+        train_dataset,
+        replica_group=REPLICA_GROUP_ID,
+        num_replica_groups=NUM_REPLICA_GROUPS,
+        rank=rank,
+        shuffle=True,
+        num_replicas=NUM_REPLICAS,
+    )
+
+    train_dataloader = StatefulDataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        collate_fn=DataCollatorForSeq2Seq(tokenizer),
+        sampler=sampler,
+    )
+
+    def load_state_dict(state_dict):
+        set_state_dict(
+            model,
+            optimizer.optim,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=state_dict["optim"],
+        )
+
+    def state_dict():
+        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer.optim)
+        return {
+            "model": model_state_dict,
+            "optim": optimizer_state_dict,
+        }
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()
+
+    manager = Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=load_state_dict,
+        state_dict=state_dict,
+        replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
+        use_async_quorum=False,
+    )
+
+    mesh = hsdp_device_mesh(
+        NUM_REPLICA_GROUPS,
+        NUM_REPLICAS,
+        "cuda" if torch.cuda.is_available() else "cpu",
+        manager=manager,
+    )
+
+    parallelize_llama(model, mesh)
+
+    model.to(device)
+    
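+    # Wrap the optimizer with torchft's Optimizer so steps are coordinated
+    # through the Manager.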
+    optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))
+
+    while manager.current_step() < 500:
+        model.train()
+        for batch in tqdm(train_dataloader):
+            input_ids = batch["input_ids"].to(device)
+            labels = batch["labels"].to(device)
+            optimizer.zero_grad()
+
+            outputs = model(input_ids, labels=labels)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+
+            if manager.current_step() % 100 == 0:
+                print(f"[{manager.current_step()}] loss = {loss.item()}")
+
+
+if __name__ == "__main__":
+    main()