diff --git a/README.md b/README.md index aa67f59..61e37d7 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Easy Per Step Fault Tolerance for PyTorch | <a href="https://pytorch.org/torchft/"><b>Documentation</b></a> | <a href="https://github.com/pytorch-labs/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a> | <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a> - | + | </p> <p align="center"> <a href="https://pypi.org/project/torchft-nightly/"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/torchft-nightly"></a> @@ -98,7 +98,7 @@ when using synchronous training. You can start a lighthouse server by running: ```sh -$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000 +$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 ``` ### Example Training Loop (DDP) @@ -108,7 +108,7 @@ See [train_ddp.py](./train_ddp.py) for the full example. Invoke with: ```sh -$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py +$ TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py ``` train.py: diff --git a/proto/torchft.proto b/proto/torchft.proto index 67a42c0..7e248e5 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -72,30 +72,32 @@ service LighthouseService { message ManagerQuorumRequest { int64 rank = 1; int64 step = 2; - string checkpoint_server_addr = 3; + string checkpoint_metadata = 3; bool shrink_only = 4; } message ManagerQuorumResponse { int64 quorum_id = 1; - string address = 2; - string store_address = 3; + string recover_src_manager_address = 2; + optional int64 recover_src_rank = 3; + repeated int64 recover_dst_ranks = 4; + string store_address = 5; // These are information for the replicas which are at the max step. - int64 max_step = 4; - optional int64 max_rank = 5; - int64 max_world_size = 6; + int64 max_step = 6; + optional int64 max_rank = 7; + int64 max_world_size = 8; // These are information for all replicas including behind replicas. 
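The proto change above replaces the single checkpoint-server address with explicit recovery routing: the request now carries an opaque `checkpoint_metadata` string, and the response tells each replica which peer to recover from (`recover_src_rank`, `recover_src_manager_address`) and which peers will recover from it (`recover_dst_ranks`). For orientation, here is an illustrative Python mirror of the new response shape; it is not generated code and exists only to annotate the fields.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ManagerQuorumReply:
    """Illustrative mirror of ManagerQuorumResponse (not generated code)."""

    quorum_id: int
    # Manager address of the replica this node should recover from, if healing.
    recover_src_manager_address: str
    # Replica rank to pull a checkpoint from; None when this replica is up to date.
    recover_src_rank: Optional[int]
    # Replica ranks that will pull their checkpoint from this replica.
    recover_dst_ranks: List[int] = field(default_factory=list)
    # TCPStore address of the primary replica at the max step.
    store_address: str = ""
    # Information about the replicas that are at the max step.
    max_step: int = 0
    max_rank: Optional[int] = None
    max_world_size: int = 1
    # Information about all replicas, including ones that are behind.
    replica_rank: int = 0
    replica_world_size: int = 1
    heal: bool = False
```

Note that the existing fields are also renumbered (e.g. `store_address` moves from tag 3 to 5), so old and new binaries are not wire-compatible.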
- int64 replica_rank = 7; - int64 replica_world_size = 8; - bool heal = 9; + int64 replica_rank = 9; + int64 replica_world_size = 10; + bool heal = 11; } -message CheckpointAddressRequest { +message CheckpointMetadataRequest { int64 rank = 1; } -message CheckpointAddressResponse { - string checkpoint_server_address = 1; +message CheckpointMetadataResponse { + string checkpoint_metadata = 1; } message ShouldCommitRequest { @@ -114,7 +116,7 @@ message KillResponse {} service ManagerService { rpc Quorum (ManagerQuorumRequest) returns (ManagerQuorumResponse); - rpc CheckpointAddress(CheckpointAddressRequest) returns (CheckpointAddressResponse); + rpc CheckpointMetadata(CheckpointMetadataRequest) returns (CheckpointMetadataResponse); rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); rpc Kill(KillRequest) returns (KillResponse); } diff --git a/src/lib.rs b/src/lib.rs index d9a124b..529532d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub mod torchftpb { } use crate::torchftpb::manager_service_client::ManagerServiceClient; -use crate::torchftpb::{CheckpointAddressRequest, ManagerQuorumRequest, ShouldCommitRequest}; +use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest}; use pyo3::prelude::*; #[pyclass] @@ -113,15 +113,15 @@ impl ManagerClient { py: Python<'_>, rank: i64, step: i64, - checkpoint_server_addr: String, + checkpoint_metadata: String, shrink_only: bool, timeout: Duration, - ) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> { + ) -> Result<QuorumResult, StatusError> { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { rank: rank, step: step, - checkpoint_server_addr: checkpoint_server_addr, + checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, }); @@ -131,28 +131,30 @@ impl ManagerClient { let response = self.runtime.block_on(self.client.clone().quorum(request))?; let resp = response.into_inner(); - Ok(( - resp.quorum_id, - resp.replica_rank, - resp.replica_world_size, - resp.address, - resp.store_address, - resp.max_step, - resp.max_rank, - resp.max_world_size, - resp.heal, - )) + Ok(QuorumResult { + quorum_id: resp.quorum_id, + replica_rank: resp.replica_rank, + replica_world_size: resp.replica_world_size, + recover_src_manager_address: resp.recover_src_manager_address, + recover_src_rank: resp.recover_src_rank, + recover_dst_ranks: resp.recover_dst_ranks, + store_address: resp.store_address, + max_step: resp.max_step, + max_rank: resp.max_rank, + max_world_size: resp.max_world_size, + heal: resp.heal, + }) }) } - fn checkpoint_address( + fn checkpoint_metadata( &self, py: Python<'_>, rank: i64, timeout: Duration, ) -> Result<String, StatusError> { py.allow_threads(move || { - let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank }); + let mut request = tonic::Request::new(CheckpointMetadataRequest { rank: rank }); // This timeout is processed on the server side so we also enable // keep alives to detect server health. 
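On the Python binding side, `ManagerClient.quorum` now returns the new `QuorumResult` class instead of a 9-element tuple, so call sites read named attributes rather than unpacking by position. A usage sketch follows; it assumes a manager is already serving at the given address, and the address and argument values are placeholders.

```python
from datetime import timedelta

from torchft.torchft import ManagerClient  # Rust extension bindings

client = ManagerClient("localhost:29512", connect_timeout=timedelta(seconds=10))

# Before: nine positional values had to be unpacked in exactly the right order:
# quorum_id, replica_rank, ..., heal = client.quorum(...)

# After: named fields, so call sites only read what they need.
result = client.quorum(
    rank=0,
    step=10,
    checkpoint_metadata="http://host:1234/checkpoint/",  # opaque transport metadata
    shrink_only=False,
    timeout=timedelta(seconds=60),
)
if result.recover_src_rank is not None:
    print(f"healing from rank {result.recover_src_rank} at step {result.max_step}")
for dst in result.recover_dst_ranks:
    print(f"rank {dst} will pull its checkpoint from us")
```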
@@ -160,9 +162,9 @@ impl ManagerClient { let response = self .runtime - .block_on(self.client.clone().checkpoint_address(request))?; + .block_on(self.client.clone().checkpoint_metadata(request))?; let resp = response.into_inner(); - Ok(resp.checkpoint_server_address) + Ok(resp.checkpoint_metadata) }) } @@ -194,6 +196,41 @@ impl ManagerClient { } } +#[pyclass(get_all, set_all)] +struct QuorumResult { + quorum_id: i64, + replica_rank: i64, + replica_world_size: i64, + recover_src_manager_address: String, + recover_src_rank: Option<i64>, + recover_dst_ranks: Vec<i64>, + store_address: String, + max_step: i64, + max_rank: Option<i64>, + max_world_size: i64, + heal: bool, +} + +#[pymethods] +impl QuorumResult { + #[new] + fn new() -> Self { + Self { + quorum_id: 0, + replica_rank: 0, + replica_world_size: 1, + recover_src_manager_address: "".to_string(), + recover_src_rank: None, + recover_dst_ranks: Vec::new(), + store_address: "".to_string(), + max_step: 0, + max_rank: None, + max_world_size: 1, + heal: false, + } + } +} + fn reset_python_signals(py: Python<'_>) -> PyResult<()> { // clear python signal handlers // signal.signal(signal.SIGINT, signal.SIG_DFL) @@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::<Manager>()?; m.add_class::<ManagerClient>()?; m.add_class::<Lighthouse>()?; + m.add_class::<QuorumResult>()?; m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?; Ok(()) diff --git a/src/lighthouse.rs b/src/lighthouse.rs index e6be595..643aef1 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -77,7 +77,7 @@ pub struct LighthouseOpt { #[structopt( long = "join_timeout_ms", default_value = "60000", - help = "How long to wait for new replicas to join before considering a quorum" + help = "How long to wait for heartbeating stragglers to join before issuing quorum" )] pub join_timeout_ms: u64, @@ -90,14 +90,14 @@ pub struct LighthouseOpt { #[structopt( long = "quorum_tick_ms", default_value = "100", - help = "How frequently to check for quorum when waiting for workers." + help = "How frequently to check for quorum when waiting for stragglers." )] pub quorum_tick_ms: u64, #[structopt( long = "heartbeat_timeout_ms", default_value = "5000", - help = "how long to wait for a heartbeat before considering a replica dead." + help = "How long to wait for a heartbeat before considering a replica dead." 
)] pub heartbeat_timeout_ms: u64, } @@ -146,9 +146,10 @@ fn quorum_compute( .any(|(_, details)| details.member.shrink_only); let metadata = format!( - "[{}/{} participants healthy][shrink_only={}]", + "[{}/{} participants healthy][{} heartbeating][shrink_only={}]", healthy_participants.len(), state.participants.len(), + healthy_replicas.len(), shrink_only, ); @@ -190,7 +191,7 @@ fn quorum_compute( return ( None, format!( - "No quorum, only have {} participants, need min_replicas {} {}", + "New quorum not ready, only have {} participants, need min_replicas {} {}", healthy_participants.len(), opt.min_replicas, metadata @@ -203,7 +204,7 @@ fn quorum_compute( return ( None, format!( - "No quorum, only have {} participants, need at least half of {} healthy workers {}", + "New quorum not ready, only have {} participants, need at least half of {} healthy workers {}", healthy_participants.len(), healthy_replicas.len(), metadata @@ -261,7 +262,7 @@ impl Lighthouse { fn _quorum_tick(self: Arc<Self>, state: &mut State) -> Result<()> { let (quorum_met, reason) = quorum_compute(Instant::now(), state, &self.opt); - info!("{}", reason); + info!("Next quorum status: {}", reason); if quorum_met.is_some() { let participants = quorum_met.unwrap(); @@ -600,7 +601,9 @@ mod tests { let now = Instant::now(); - assert!(!quorum_compute(now, &state, &opt).0.is_some()); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_none(), "{}", reason); + assert!(reason.contains("New quorum not ready, only have 0 participants, need min_replicas 1 [0/0 participants healthy]"), "{}", reason); state.participants.insert( "a".to_string(), @@ -689,7 +692,13 @@ mod tests { ); state.heartbeats.insert("a".to_string(), now); - assert!(quorum_compute(now, &state, &opt).0.is_some()); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + assert!( + reason.contains("[1/1 participants healthy][1 heartbeating]"), + "{}", + reason + ); // expired heartbeat state @@ -698,6 +707,11 @@ mod tests { let (quorum_met, reason) = quorum_compute(now, &state, &opt); assert!(quorum_met.is_none(), "{}", reason); + assert!( + reason.contains("[0/1 participants healthy][0 heartbeating]"), + "{}", + reason + ); // 1 healthy, 1 expired state.participants.insert( @@ -886,6 +900,7 @@ mod tests { let (quorum_met, reason) = quorum_compute(now, &state, &opt); assert!(quorum_met.is_some(), "{}", reason); + assert!(reason.contains("[shrink_only=true]",), "{}", reason); let quorum = quorum_met.unwrap(); assert!(quorum.len() == 1); @@ -982,7 +997,7 @@ mod tests { state.heartbeats.insert("b".to_string(), now); let (quorum_met, reason) = quorum_compute(now, &state, &opt); assert!(quorum_met.is_none(), "{}", reason); - assert!(reason.contains("at least half"), "{}", reason); + assert!(reason.contains("New quorum not ready, only have 1 participants, need at least half of 2 healthy workers [1/1 participants healthy][2 heartbeating]"), "{}", reason); Ok(()) } diff --git a/src/manager.rs b/src/manager.rs index 982500a..931e995 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -26,7 +26,7 @@ use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{ manager_service_server::{ManagerService, ManagerServiceServer}, - CheckpointAddressRequest, CheckpointAddressResponse, KillRequest, KillResponse, + CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse, 
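The reworded lighthouse status messages above spell out the two conditions checked on every tick: a prospective quorum needs at least `min_replicas` healthy participants, and enough participants relative to the replicas that are still heartbeating. The exact Rust expressions are not part of this hunk, so the sketch below encodes the behaviour implied by the tests (one healthy participant out of two heartbeating replicas is rejected with the "at least half" message); the helper name is mine, not the lighthouse API.

```python
def quorum_ready(healthy_participants: int, heartbeating: int, min_replicas: int) -> bool:
    """Simplified restatement of the readiness checks in quorum_compute (lighthouse.rs)."""
    if healthy_participants < min_replicas:
        return False  # "need min_replicas N"
    if healthy_participants * 2 <= heartbeating:
        return False  # "need at least half of N healthy workers"
    return True


assert quorum_ready(1, heartbeating=1, min_replicas=1)      # single healthy replica
assert not quorum_ready(1, heartbeating=2, min_replicas=1)  # matches the final test case
```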
LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ManagerQuorumResponse, Quorum, QuorumMember, ShouldCommitRequest, ShouldCommitResponse, }; @@ -38,7 +38,7 @@ use log::{info, warn}; use std::{println as info, println as warn}; struct ManagerState { - checkpoint_servers: HashMap<i64, String>, + checkpoint_metadata: HashMap<i64, String>, channel: broadcast::Sender<Quorum>, participants: HashSet<i64>, @@ -104,7 +104,7 @@ impl Manager { world_size: world_size, heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { - checkpoint_servers: HashMap::new(), + checkpoint_metadata: HashMap::new(), channel: tx, participants: HashSet::new(), @@ -237,8 +237,8 @@ impl ManagerService for Arc<Manager> { // save checkpoint server info for healing process // TODO: make separate call to set? state - .checkpoint_servers - .insert(req.rank, req.checkpoint_server_addr.clone()); + .checkpoint_metadata + .insert(req.rank, req.checkpoint_metadata.clone()); // TODO check step state.participants.insert(rank); @@ -266,81 +266,28 @@ impl ManagerService for Arc<Manager> { .await .map_err(|e| Status::internal(e.to_string()))?; - let mut participants = quorum.participants.clone(); - participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); - - let replica_rank = participants.iter().enumerate().find_map(|(i, p)| { - if p.replica_id == self.replica_id { - Some(i) - } else { - None - } - }); - if replica_rank.is_none() { - return Err(Status::not_found(format!( - "replica {} not participating in returned quorum", - self.replica_id - ))); - } - - let max_step = participants.iter().map(|p| p.step).max().unwrap(); - let max_participants: Vec<&QuorumMember> = - participants.iter().filter(|p| p.step == max_step).collect(); - - let primary = max_participants[rank as usize % max_participants.len()]; - - let mut max_rank = None; - for (i, p) in max_participants.iter().enumerate() { - if p.replica_id == self.replica_id { - max_rank = Some(i as i64); - break; - } - } - - // Decide whether we should be healing: - // 1. if we're not at the max step - // 2. if everyone is at the first step and we're not the primary - let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id; - if heal { - info!( - "healing is required step={}, max_step={}", - req.step, max_step - ); - } - - let reply = ManagerQuorumResponse { - quorum_id: quorum.quorum_id, - // address is used for looking up the checkpoint server address. 
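The `ManagerState` map is renamed to reflect that it now stores opaque transport metadata rather than a checkpoint server address: each quorum call records the caller's metadata string, and the `CheckpointMetadata` RPC is a straight lookup into that map. A toy Python model of just that bookkeeping (the class and method names are illustrative, not the Rust types):

```python
from typing import Dict


class CheckpointMetadataStore:
    """Toy model of ManagerState.checkpoint_metadata and the lookup RPC."""

    def __init__(self) -> None:
        self._metadata: Dict[int, str] = {}

    def on_quorum_request(self, rank: int, checkpoint_metadata: str) -> None:
        # Saved on every quorum call so healing peers can fetch it later.
        self._metadata[rank] = checkpoint_metadata

    def checkpoint_metadata(self, rank: int) -> str:
        if rank not in self._metadata:
            raise KeyError("rank not found")  # surfaced as an invalid_argument status
        return self._metadata[rank]


store = CheckpointMetadataStore()
store.on_quorum_request(0, "http://host:1234/checkpoint/")
assert store.checkpoint_metadata(0).startswith("http://")
```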
- address: primary.address.clone(), - store_address: primary.store_address.clone(), - max_step: max_step, - max_rank: max_rank, - max_world_size: max_participants.len() as i64, - replica_rank: replica_rank.unwrap() as i64, - replica_world_size: participants.len() as i64, - heal: heal, - }; - info!("returning quorum for rank {}", rank); + let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?; + Ok(Response::new(reply)) } - async fn checkpoint_address( + async fn checkpoint_metadata( &self, - request: Request<CheckpointAddressRequest>, - ) -> Result<Response<CheckpointAddressResponse>, Status> { + request: Request<CheckpointMetadataRequest>, + ) -> Result<Response<CheckpointMetadataResponse>, Status> { let state = self.state.lock().await; let req = request.into_inner(); - let address = state - .checkpoint_servers + let metadata = state + .checkpoint_metadata .get(&req.rank) .ok_or_else(|| Status::invalid_argument("rank not found"))?; - let reply = CheckpointAddressResponse { - checkpoint_server_address: address.clone(), + let reply = CheckpointMetadataResponse { + checkpoint_metadata: metadata.clone(), }; Ok(Response::new(reply)) } @@ -407,6 +354,131 @@ impl ManagerService for Arc<Manager> { } } +fn compute_quorum_results( + replica_id: &str, + rank: i64, + quorum: &Quorum, +) -> Result<ManagerQuorumResponse, Status> { + let mut participants = quorum.participants.clone(); + participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); + + // Compute the rank of the replica in the returned quorum. + let replica_rank = participants + .iter() + .enumerate() + .find_map(|(i, p)| { + if p.replica_id == replica_id { + Some(i) + } else { + None + } + }) + .ok_or_else(|| { + Status::not_found(format!( + "replica {} not participating in returned quorum", + replica_id + )) + })?; + + let step = participants[replica_rank].step; + + // Compute the details for workers at max step. + let max_step = participants.iter().map(|p| p.step).max().unwrap(); + let max_participants: Vec<&QuorumMember> = + participants.iter().filter(|p| p.step == max_step).collect(); + let max_rank = max_participants.iter().enumerate().find_map(|(i, p)| { + if p.replica_id == replica_id { + Some(i as i64) + } else { + None + } + }); + + // The primary TCPStore to use for this rank. + let primary_rank = rank as usize % max_participants.len(); + let primary = max_participants[primary_rank]; + + // Compute recovery assignments + + // Nodes are recovering if: + // 1. not at the max step + // 2. max_step == 0 and not the primary replica + let all_recover_dst_ranks: Vec<usize> = participants + .iter() + .enumerate() + .filter_map(|(i, p)| { + if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { + Some(i) + } else { + None + } + }) + .collect(); + let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>(); + let up_to_date_ranks: Vec<usize> = participants + .iter() + .enumerate() + .filter_map(|(i, _p)| { + if !all_recover_dst_ranks_set.contains(&i) { + Some(i) + } else { + None + } + }) + .collect(); + + // This is a map of rank to the ranks that are recovering from that node. + let mut recovery_assignments: HashMap<usize, Vec<i64>> = HashMap::new(); + // The rank of the node that this rank is recovering from. 
+ let mut recover_src_rank: Option<i64> = None; + for (i, recovering_rank) in all_recover_dst_ranks.iter().enumerate() { + let up_to_date_idx = (i + rank as usize) % up_to_date_ranks.len(); + let recovering_recover_src_rank = up_to_date_ranks[up_to_date_idx]; + if !recovery_assignments.contains_key(&recovering_recover_src_rank) { + recovery_assignments.insert(recovering_recover_src_rank, Vec::new()); + } + recovery_assignments + .get_mut(&recovering_recover_src_rank) + .unwrap() + .push(*recovering_rank as i64); + if *recovering_rank == replica_rank { + recover_src_rank = Some(recovering_recover_src_rank as i64); + } + } + + let heal = recover_src_rank.is_some(); + if heal { + info!( + "healing is required step={}, max_step={}, recover_src_rank={}", + step, + max_step, + recover_src_rank.unwrap() + ); + } + + let recover_src_manager_address = match recover_src_rank { + Some(r) => participants[r as usize].address.clone(), + None => "".to_string(), + }; + + Ok(ManagerQuorumResponse { + quorum_id: quorum.quorum_id, + // address is used for looking up the checkpoint server address. + recover_src_manager_address: recover_src_manager_address, + recover_src_rank: recover_src_rank, + recover_dst_ranks: recovery_assignments + .get(&replica_rank) + .map_or_else(Vec::new, |v| v.clone()), + store_address: primary.store_address.clone(), + max_step: max_step, + max_rank: max_rank, + max_world_size: max_participants.len() as i64, + replica_rank: replica_rank as i64, + replica_world_size: participants.len() as i64, + heal: heal, + }) +} + #[cfg(test)] mod tests { use super::*; @@ -506,7 +578,7 @@ mod tests { let mut request = tonic::Request::new(ManagerQuorumRequest { rank: 0, step: 123, - checkpoint_server_addr: "addr".to_string(), + checkpoint_metadata: "addr".to_string(), shrink_only: false, }); request.set_timeout(Duration::from_secs(10)); @@ -516,7 +588,7 @@ mod tests { lighthouse_fut.abort(); assert_eq!(resp.quorum_id, 1); - assert_eq!(resp.address, manager.address()); + assert_eq!(resp.recover_src_manager_address, "".to_string()); assert_eq!(resp.store_address, "store_addr".to_string()); assert_eq!(resp.max_step, 123); assert_eq!(resp.max_rank, Some(0)); @@ -565,7 +637,7 @@ mod tests { let mut request = tonic::Request::new(ManagerQuorumRequest { rank: 0, step: 0, - checkpoint_server_addr: "addr".to_string(), + checkpoint_metadata: "addr".to_string(), shrink_only: false, }); request.set_timeout(Duration::from_secs(10)); @@ -597,4 +669,183 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_checkpoint_metadata() -> Result<()> { + let lighthouse = Lighthouse::new(LighthouseOpt { + bind: "[::]:0".to_string(), + join_timeout_ms: 100, + min_replicas: 1, + quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, + }) + .await?; + let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); + + let manager = Manager::new( + "rep_id".to_string(), + lighthouse.address(), + "localhost".to_string(), + "[::]:0".to_string(), + "store_addr".to_string(), + 1, // world size + Duration::from_millis(100), // heartbeat interval + Duration::from_secs(10), // connect timeout + ) + .await?; + let manager_fut = tokio::spawn(manager.clone().run()); + + let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?; + + let request = tonic::Request::new(CheckpointMetadataRequest { rank: 0 }); + let resp = client.checkpoint_metadata(request).await; + assert!(resp.err().unwrap().to_string().contains("rank not found")); + + { + let mut state = manager.state.lock().await; + + 
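The new `compute_quorum_results` above assigns recovering replicas to up-to-date ones round-robin, offset by the caller's group rank so that different ranks inside a replica spread their checkpoint fetches across different sources. Below is a Python sketch of just that assignment step, written to match the five-replica recovery test further down; the function and argument names are my own illustration, not the Rust API.

```python
from typing import Dict, List, Optional, Tuple


def assign_recovery(
    steps: List[int], primary: int, rank: int, my_replica_rank: int
) -> Tuple[Dict[int, List[int]], Optional[int]]:
    """Sketch of the recovery assignment in compute_quorum_results (manager.rs).

    steps[i]        -- step of the participant at replica rank i (sorted by replica_id)
    primary         -- replica rank of the primary among the max-step participants
    rank            -- this node's group rank, used as the round-robin offset
    my_replica_rank -- replica rank of the caller
    """
    max_step = max(steps)
    # A replica needs recovery if it is behind, or if everyone is at step 0
    # and it is not the primary.
    recover_dst = [
        i for i, s in enumerate(steps) if s != max_step or (max_step == 0 and i != primary)
    ]
    up_to_date = [i for i in range(len(steps)) if i not in set(recover_dst)]

    assignments: Dict[int, List[int]] = {}
    recover_src_rank: Optional[int] = None
    for i, dst in enumerate(recover_dst):
        src = up_to_date[(i + rank) % len(up_to_date)]
        assignments.setdefault(src, []).append(dst)
        if dst == my_replica_rank:
            recover_src_rank = src
    return assignments, recover_src_rank


# Matches test_compute_quorum_results_recovery: replicas 0, 2, 4 are at step 0,
# replicas 1 and 3 are at step 1, so ranks 1 and 3 serve the checkpoints.
steps = [0, 1, 0, 1, 0]
assignments, src = assign_recovery(steps, primary=1, rank=0, my_replica_rank=0)
assert assignments == {1: [0, 4], 3: [2]} and src == 1

# With group rank 1 the offset shifts, so the load lands on different sources.
# (The primary also shifts with rank, but it only matters when max_step == 0.)
assignments, src = assign_recovery(steps, primary=3, rank=1, my_replica_rank=1)
assert assignments == {3: [0, 4], 1: [2]} and src is None
```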
state.checkpoint_metadata.insert(0, "addr".to_string()); + } + + let request = tonic::Request::new(CheckpointMetadataRequest { rank: 0 }); + let resp = client.checkpoint_metadata(request).await?.into_inner(); + assert_eq!(resp.checkpoint_metadata, "addr".to_string()); + + manager_fut.abort(); + lighthouse_fut.abort(); + + Ok(()) + } + + #[tokio::test] + async fn test_compute_quorum_results_first_step() -> Result<()> { + let quorum = Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica_0".to_string(), + address: "addr_0".to_string(), + store_address: "store_addr_0".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_1".to_string(), + address: "addr_1".to_string(), + store_address: "store_addr_1".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + ], + created: None, + }; + + // rank 0 + + let results = compute_quorum_results("replica_0", 0, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 0); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![1]); + + let results = compute_quorum_results("replica_1", 0, &quorum)?; + assert!(results.heal); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, Some(0)); + assert_eq!(results.recover_dst_ranks, Vec::<i64>::new()); + + // rank 1 assignments should be offset from rank 0 above and the primary + + let results = compute_quorum_results("replica_1", 1, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![0]); + + Ok(()) + } + + #[tokio::test] + async fn test_compute_quorum_results_recovery() -> Result<()> { + let quorum = Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica_0".to_string(), + address: "addr_0".to_string(), + store_address: "store_addr_0".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_1".to_string(), + address: "addr_1".to_string(), + store_address: "store_addr_1".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_2".to_string(), + address: "addr_2".to_string(), + store_address: "store_addr_2".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_3".to_string(), + address: "addr_3".to_string(), + store_address: "store_addr_3".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_4".to_string(), + address: "addr_4".to_string(), + store_address: "store_addr_4".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + ], + created: None, + }; + + // rank 0 + + let results = compute_quorum_results("replica_0", 0, &quorum)?; + assert!(results.heal); + assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); + assert_eq!(results.replica_rank, 0); + assert_eq!(results.recover_src_rank, Some(1)); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_1", 0, &quorum)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![0, 4]); + + let results = compute_quorum_results("replica_3", 0, &quorum)?; + assert!(!results.heal); + 
assert_eq!(results.replica_rank, 3); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![2]); + + // rank 1 assignments should be offset from rank 0 above + + let results = compute_quorum_results("replica_1", 1, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![2]); + + Ok(()) + } } diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index aaad843..c3168b2 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -16,9 +16,11 @@ import socket import threading import urllib.request +from abc import ABC, abstractmethod +from contextlib import contextmanager from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Callable, Generic, TypeVar +from typing import Generator, Generic, List, Optional, TypeVar import torch @@ -29,7 +31,83 @@ T = TypeVar("T") -class CheckpointServer(Generic[T]): +class CheckpointTransport(Generic[T], ABC): + @abstractmethod + def metadata(self) -> str: + """ + Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint. + """ + ... + + @abstractmethod + def send_checkpoint( + self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta + ) -> None: + """ + Sends the checkpoint, only called when there is a rank that is behind. + + This may be async. + + Args: + dst_ranks: the ranks to send to + step: the step number to send + state_dict: the state dict to send + timeout: the timeout to wait for the checkpoint to be sent + """ + ... + + def disallow_checkpoint(self) -> None: + """ + Called after send_checkpoint to wait for the checkpoint to be sent. + + Once this returns, the state_dict may be mutated so no further data should be sent. + """ + ... + + @abstractmethod + def recv_checkpoint( + self, src_rank: int, metadata: str, step: int, timeout: timedelta + ) -> T: + """ + Receives the checkpoint from the given rank. + + Args: + src_rank: the rank to receive the checkpoint from + metadata: the metadata returned by the remote CheckpointTransport + step: the step number to receive + timeout: the timeout to wait for the checkpoint + """ + ... + + def shutdown(self, wait: bool = True) -> None: + """ + Called to shutdown the checkpoint transport. + + Args: + wait: whether to wait for the transport to shutdown + """ + + +@contextmanager +def _timed_acquire( + lock: threading.Lock, timeout: timedelta +) -> Generator[None, None, None]: + """ + Acquire a lock with a timeout. + + Args: + lock: the lock to acquire + timeout: the timeout to acquire the lock + """ + if not lock.acquire(timeout=timeout.total_seconds()): + raise TimeoutError(f"timed out acquiring lock after {timeout}") + try: + yield + finally: + lock.release() + + +class CheckpointServer(CheckpointTransport[T]): """ This is an HTTP server that can be used to transfer checkpoints between workers. @@ -41,11 +119,16 @@ class CheckpointServer(Generic[T]): state_dict: a callable that returns the state dict to be transferred """ - def __init__(self, state_dict: Callable[[], T], timeout: timedelta) -> None: + def __init__(self, timeout: timedelta) -> None: self._checkpoint_lock = threading.Lock() self._disallowed = False self._step = -1 self._timeout = timeout + self._state_dict: Optional[T] = None + + # We don't allow checkpoints until the first send_checkpoint to avoid + # serving the default step=-1 invalid checkpoint. 
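The new `CheckpointTransport` ABC above is what makes the transport pluggable: anything that can publish a metadata string, stage a state dict for sending, and fetch one back can be dropped in. As a minimal sketch of implementing it, here is a toy in-memory transport; it is my own illustration and only meaningful when send and recv happen in the same process.

```python
from datetime import timedelta
from typing import Dict, List, Tuple, TypeVar

from torchft.checkpointing import CheckpointTransport

T = TypeVar("T")


class InMemoryTransport(CheckpointTransport[T]):
    """Toy transport that keeps checkpoints in process-local memory.

    Only suitable for single-process tests where send and recv use the same
    instance; a real transport moves bytes between replicas, as the HTTP-based
    CheckpointServer in this file does.
    """

    def __init__(self) -> None:
        self._staged: Dict[Tuple[str, int], T] = {}

    def metadata(self) -> str:
        # Any opaque string the receiving side can pass back to recv_checkpoint.
        return "inmem://local"

    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        self._staged[(self.metadata(), step)] = state_dict

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        return self._staged[(metadata, step)]
```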
+ self.disallow_checkpoint() ckpt_server = self @@ -58,7 +141,9 @@ def do_GET(self): # validate socket timeout is actually set assert self.connection.gettimeout() == self.timeout - with ckpt_server._checkpoint_lock: + with _timed_acquire( + ckpt_server._checkpoint_lock, ckpt_server._timeout + ): step = ckpt_server._step if self.path != f"/checkpoint/{step}": @@ -74,9 +159,9 @@ def do_GET(self): self.send_header("Content-type", "application/octet-stream") self.end_headers() - sd = state_dict() + state_dict = ckpt_server._state_dict - torch.save(sd, self.wfile) + torch.save(state_dict, self.wfile) except Exception as e: logger.exception( f"Exception in checkpoint server when handling {self.path=}: {e}", @@ -113,11 +198,13 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T: data = f.read() reader = io.BytesIO(data) - return torch.load(reader, weights_only=True) + # We have to set weights_only to False as there are some non-tensor + # states like lr_scheduler. + return torch.load(reader, weights_only=False) def address(self) -> str: """ - Returns the HTTP address to fetch a checkpoint from this server at the current step. + Returns the HTTP address to fetch a checkpoint from this server. Step must be appended to the end of the address. Format: http://host:port/checkpoint/1234 @@ -125,7 +212,7 @@ def address(self) -> str: an HTTP address """ port = self._server.socket.getsockname()[1] - return f"http://{socket.gethostname()}:{port}/checkpoint/{self._step}" + return f"http://{socket.gethostname()}:{port}/checkpoint/" def _serve(self) -> None: try: @@ -156,8 +243,28 @@ def allow_checkpoint(self, step: int) -> None: self._disallowed = False self._checkpoint_lock.release() - def shutdown(self) -> None: + def shutdown(self, wait: bool = True) -> None: """ Shutdown the server. """ - self._server.shutdown() + if not wait: + # hack for nonblocking shutdown of socketserver threads + # pyre-fixme[16]: no attribute `__shutdown_request`. + self._server.__shutdown_request = True + if wait: + self._server.shutdown() + self._thread.join() + + def metadata(self) -> str: + return self.address() + + def send_checkpoint( + self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta + ) -> None: + self._state_dict = state_dict + self.allow_checkpoint(step) + + def recv_checkpoint( + self, src_rank: int, metadata: str, step: int, timeout: timedelta + ) -> T: + return self.load_from_address(f"{metadata}{step}", timeout) diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py index 983c429..31658b4 100644 --- a/torchft/checkpointing_test.py +++ b/torchft/checkpointing_test.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
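With this refactor the server no longer owns a `state_dict` callable: the state dict is handed to `send_checkpoint`, and `metadata()` returns the base URL that `recv_checkpoint` completes with the step. A condensed usage sketch mirroring the unit test below; both sides run in one process here for brevity, whereas in practice the manager coordinates send and recv on different replicas.

```python
from datetime import timedelta

from torchft.checkpointing import CheckpointServer

server = CheckpointServer(timeout=timedelta(seconds=10))

# Up-to-date replica: stage the state dict and open the server for this step.
server.send_checkpoint(
    dst_ranks=[1],
    step=7,
    state_dict={"weight": 1.0},
    timeout=timedelta(seconds=10),
)
metadata = server.metadata()  # e.g. "http://host:port/checkpoint/"

# Recovering replica: the manager hands over the peer's metadata plus the step;
# recv_checkpoint simply fetches f"{metadata}{step}" over HTTP.
state = server.recv_checkpoint(
    src_rank=0, metadata=metadata, step=7, timeout=timedelta(seconds=10)
)
assert state == {"weight": 1.0}

server.shutdown()
```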
+import threading import urllib.error from datetime import timedelta from unittest import TestCase from unittest.mock import MagicMock -from torchft.checkpointing import CheckpointServer +from torchft.checkpointing import CheckpointServer, _timed_acquire class TestCheckpointing(TestCase): @@ -18,26 +19,87 @@ def test_checkpoint_server(self) -> None: state_dict_fn = MagicMock() state_dict_fn.return_value = expected server = CheckpointServer( - state_dict=state_dict_fn, timeout=timedelta(seconds=10), ) - server.disallow_checkpoint() - server.allow_checkpoint(1234) + server.send_checkpoint( + dst_ranks=[], + step=1234, + state_dict=expected, + timeout=timedelta(seconds=10), + ) - addr = server.address() + metadata = server.metadata() - out = CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10)) + out = server.recv_checkpoint( + src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10) + ) self.assertEqual(out, expected) # test timeout with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"): - CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=0.0)) + server.recv_checkpoint( + src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=0.0) + ) # test mismatch case - server.allow_checkpoint(2345) + server.send_checkpoint( + dst_ranks=[], + step=2345, + state_dict=expected, + timeout=timedelta(seconds=10), + ) with self.assertRaisesRegex(urllib.error.HTTPError, r"Error 400"): - CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10)) + server.recv_checkpoint( + src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10) + ) server.shutdown() + + def test_checkpoint_server_locking(self) -> None: + server = CheckpointServer( + timeout=timedelta(seconds=10), + ) + + # server should start up in a disallowed state this will block incoming + # requests until allow_checkpoint is called + self.assertTrue(server._checkpoint_lock.locked()) + self.assertTrue(server._disallowed) + self.assertEqual(server._step, -1) + + # allow requests + server.allow_checkpoint(1) + + self.assertFalse(server._checkpoint_lock.locked()) + self.assertFalse(server._disallowed) + self.assertEqual(server._step, 1) + + # duplicate allow/disallow is fine + server.allow_checkpoint(2) + self.assertEqual(server._step, 2) + + server.disallow_checkpoint() + server.disallow_checkpoint() + self.assertTrue(server._checkpoint_lock.locked()) + self.assertTrue(server._disallowed) + + server.shutdown() + + def test_timed_acquire(self) -> None: + lock = threading.Lock() + + with _timed_acquire(lock, timedelta(seconds=10)): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + lock.acquire() + + with self.assertRaisesRegex( + TimeoutError, r"timed out acquiring lock after 0.0" + ): + with _timed_acquire(lock, timedelta(seconds=0.0)): + pass + + self.assertTrue(lock.locked()) diff --git a/torchft/device_mesh_test.py b/torchft/device_mesh_test.py new file mode 100644 index 0000000..ee78c6d --- /dev/null +++ b/torchft/device_mesh_test.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import io +import os +from concurrent.futures import ProcessPoolExecutor +from typing import cast +from unittest import TestCase +from unittest.mock import Mock + +import torch +import torch.distributed as dist + +from torchft.manager import Manager +from torchft.process_group import ( + ManagedProcessGroup, + ProcessGroupGloo, + ft_init_device_mesh, +) + + +class DeviceMeshTest(TestCase): + @staticmethod + def _test_init_device_mesh(world_size: int, rank: int) -> None: + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(12346) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(4) + + testcase = TestCase() + + manager = Mock(spec=Manager) + manager._pg = ProcessGroupGloo() + # Even though we only have 4 workers, we can still initialize (2, 4) mesh. + # That's because the replicate group is NOT phystically created in the + # real mesh but is virtually added to the mesh via ManagedDeviceMesh. + device_mesh = ft_init_device_mesh( + device_type="cpu", + mesh_shape=(2, world_size), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + + testcase.assertTrue( + isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup) + ) + testcase.assertTrue( + not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup) + ) + replicate_group = device_mesh.get_group("dp_replicate") + testcase.assertEqual( + cast(ManagedProcessGroup, replicate_group)._manager, manager + ) + replicate_mesh = device_mesh["dp_replicate"] + testcase.assertEqual(replicate_mesh.get_group(), replicate_group) + + flatten_mesh = device_mesh._flatten("dp") + manager.num_participants.return_value = 0 + testcase.assertEqual(flatten_mesh.size(), world_size) + manager.num_participants.return_value = 1 + testcase.assertEqual(flatten_mesh.size(), world_size) + manager.num_participants.return_value = 2 + testcase.assertEqual(flatten_mesh.size(), world_size * 2) + + testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank()) + + device_mesh.get_coordinate() + buffer = io.BytesIO() + torch.save(device_mesh, buffer) + buffer.seek(0) + torch.load(buffer, weights_only=False) + + def test_init_device_mesh(self) -> None: + with ProcessPoolExecutor(max_workers=4) as executor: + futures = [] + for i in range(4): + future = executor.submit(self._test_init_device_mesh, 4, i) + futures.append(future) + for f in futures: + f.result() diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index 1458f07..ec3cf82 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -3,25 +3,29 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """ LocalSGD ========= - This module implements a fault tolerant version of LocalSGD and related methods. 
""" - -from typing import Any, Dict, List, Mapping, Optional +import logging +from types import TracebackType +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type import torch from torch import nn, optim +from torch.nn.parameter import Parameter +from torch.optim.optimizer import Optimizer +from torch.utils.hooks import RemovableHandle from torchft.manager import Manager +logger: logging.Logger = logging.getLogger(__name__) + -class LocalSGD(nn.Module): +class LocalSGD: """ - LocalSGD is a model wrapper similar to DistributedDataParallel that + LocalSGD is a context manager that implements the algorithm described in https://arxiv.org/pdf/1805.09767 This will synchronize the model parameters periodically in a fault tolerant @@ -68,18 +72,14 @@ def __init__( pin_memory: Whether to pin the memory used for the backup of the model parameters. """ super().__init__() - self._manager = manager self._model = model + self._local_optimizer = optimizer self._local_step = 0 - self._started_step = False self._sync_every = sync_every assert sync_every >= 1, "sync_every must be greater than or equal to 1" - device = backup_device or torch.device("cpu") - self._backup_parameters: Dict[str, torch.Tensor] = {} - for name, p in self._model.named_parameters(): t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device) if ( @@ -90,95 +90,150 @@ def __init__( t = t.pin_memory() self._backup_parameters[name] = t + self._hooks: List[RemovableHandle] = [] # Need to copy the parameters to the host to be safe if we are on the first step. self._save_parameters() - optimizer.register_step_post_hook(self._step_post_hook) + def __enter__(self) -> "LocalSGD": + # Add optimizer hook which increments the local step counter and syncs if necessary + self._hooks.append( + self._local_optimizer.register_step_post_hook(self._step_post_hook) + ) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: + # Handle any cleanup or error handling here + if exc_type is not None: + # If an exception occurred, restore parameters + self._restore_parameters() + # Clean up hooks + for hook in self._hooks: + hook.remove() + self._hooks.clear() + + return False # Propagate exceptions def _save_parameters(self) -> None: - # TODO: consider running copy on a separate stream - for name, p in self._model.named_parameters(): - self._backup_parameters[name].copy_(p.data, non_blocking=True) + with torch.no_grad(): + # TODO: consider running copy on a separate stream + for name, p in self._model.named_parameters(): + self._backup_parameters[name].copy_(p.data, non_blocking=True) def _restore_parameters(self) -> None: - # TODO: consider running copy on a separate stream - for name, p in self._model.named_parameters(): - p.data.copy_(self._backup_parameters[name], non_blocking=True) + with torch.no_grad(): + # TODO: consider running copy on a separate stream + for name, p in self._model.named_parameters(): + p.data.copy_(self._backup_parameters[name], non_blocking=False) - # pyre-fixme[14]: support state_dict args - def state_dict(self) -> Dict[str, object]: - """ - state_dict returns the state_dict from the last time LocalSGD - synchronized and not the current weights. 
- """ - state_dict = self._model.state_dict() - for name, p in self._backup_parameters.items(): - assert name in state_dict - state_dict[name] = p - return state_dict - - def load_state_dict( - self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + def _step_post_hook( + self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object] ) -> None: """ - Loads the state dict to the model and the backup parameters. + This hook is registered on the optimizer and is called after the optimizer step. + """ + self._local_step += 1 + if self._local_step >= self._sync_every: + self.sync() - This must be called while the model weights aren't being modified to - avoid corrupting the backup weights. + def sync(self) -> None: """ - self._model.load_state_dict(state_dict, strict=strict, assign=assign) - self._save_parameters() + Synchronizes and averages the model weights across the manager. + """ + self._manager.start_quorum() + self._perform_sync() + self._local_step = 0 - def forward(self, *args: object, **kwargs: object) -> object: + def _perform_sync(self) -> None: + """ + Performs the synchronization of the model weights across the manager. + This method is intended to be overridden by subclasses to implement custom + synchronization logic. """ - Run the model parameters. + self._average() + if self._manager.should_commit(): + self._save_parameters() + else: + # commit failed, restore from the backup parameters + self._restore_parameters() - This should be called before the optimizer step. + def _average(self) -> None: + # TODO: do we need to broadcast buffers like DDP does? - This will start the quorum and save the parameters if this is the first step. - """ - if self._local_step == 0: - self._manager.start_quorum() + works = [] + + for p in self._model.parameters(): + # TODO: bucketize parameters + works.append(self._manager.allreduce(p.data.detach())) - self._started_step = True + for work in works: + work.wait() - return self._model.forward(*args, **kwargs) - def _step_post_hook( - self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object] - ) -> None: - """ - This hook is registered on the optimizer and is called after the optimizer step. +class DiLoCo(LocalSGD): + """ + DiLoCo is a subclass of LocalSGD that overrides the synchronization + mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights). - This will call the allreduce on the model weights every sync_every steps. - If any errors occur it will restore to the weights from the previous sync. + diloco: https://arxiv.org/pdf/2311.08105 + """ - ``forward`` must be called before this function. + def __init__( + self, + manager: Manager, + model: nn.Module, + inner_optimizer: optim.Optimizer, + outer_optimizer: optim.Optimizer, + sync_every: int, + backup_device: Optional[torch.device] = None, + pin_memory: bool = True, + ) -> None: + if manager._use_async_quorum: + raise ValueError( + "Using DiLoCo require synchronous quorum to be enabled. " + "Ensure that the manager is initialized with use_async_quorum=False" + ) + super().__init__( + manager, model, inner_optimizer, sync_every, backup_device, pin_memory + ) + self._outer_optimizer = outer_optimizer + + def _perform_sync(self) -> None: + """ + Overrides the sync method to calculate the pseugradient, average them across the manager group, and + step using the outer optimizer. 
""" - assert self._started_step, "forward must be called before step" - self._started_step = False - self._local_step += 1 + # Set the .grad field of each parameter to its pseudogradient + for name, p in self._model.named_parameters(): + assert name in self._backup_parameters + pseudogradient = p.data - self._backup_parameters[name] + p.grad = pseudogradient - if self._local_step >= self._sync_every: - self._local_step = 0 - self._average() + self._average_grads() + # Restore the parameters back to the previous state + self._restore_parameters() - if self._manager.should_commit(): - # save the parameters so we can restore from them later if necessary. - self._save_parameters() - else: - # commit failed, restore from the backup parameters - self._restore_parameters() - - def _average(self) -> None: - # TODO: do we need to broadcast buffers like DDP does? + if self._manager.should_commit(): + # Use the outer optimizer to update the model parameters + self._outer_optimizer.step() + self._save_parameters() + self._outer_optimizer.zero_grad() + def _average_grads(self) -> None: + """ + Average the gradients across the diloco group. + """ works = [] - for p in self._model.parameters(): - # TODO: bucketize parameters - works.append(self._manager.allreduce(p.data.detach())) - + # Perform allreduce on the pseudogradients + assert p.grad is not None + work = self._manager.allreduce(p.grad) + works.append(work) + # Wait for all allreduce operations to complete for work in works: work.wait() diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py index d2b73cd..05f88b7 100644 --- a/torchft/local_sgd_test.py +++ b/torchft/local_sgd_test.py @@ -11,7 +11,7 @@ import torch from torch import nn, optim -from torchft.local_sgd import LocalSGD +from torchft.local_sgd import DiLoCo, LocalSGD from torchft.manager import Manager @@ -40,57 +40,107 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten class LocalSGDTest(TestCase): def test_local_sgd_healthy(self) -> None: - base_m = SimpleModel() - optimizer = optim.SGD(base_m.parameters()) + model = SimpleModel() + optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) - - m = LocalSGD(manager, base_m, optimizer, sync_every=2) - self.assertEqual(m._local_step, 0) - - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) - - inp = torch.rand(2, 3) - - loss = m(inp).mean() - loss.backward() - optimizer.step() - - self.assertEqual(m._local_step, 1) - self.assertEqual(manager.start_quorum.call_count, 1) - - loss = m(inp).mean() - loss.backward() - optimizer.step() - - manager.should_commit.return_value = True - self.assertEqual(m._local_step, 0) - - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) - self.assertEqual(manager.should_commit.call_count, 1) - self.assertEqual(manager.allreduce.call_count, 4) + with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd: + self.assertEqual(local_sgd._local_step, 0) + torch.testing.assert_close( + local_sgd._backup_parameters, _params_dict(model) + ) + inp = torch.rand(2, 3) + loss = model(inp).mean() + loss.backward() + optimizer.step() + + self.assertEqual(local_sgd._local_step, 1) + self.assertEqual(manager.start_quorum.call_count, 0) + loss = model(inp).mean() + loss.backward() + optimizer.step() + self.assertEqual(manager.start_quorum.call_count, 1) + + manager.should_commit.return_value = True + self.assertEqual(local_sgd._local_step, 0) + torch.testing.assert_close( + local_sgd._backup_parameters, 
_params_dict(model) + ) + self.assertEqual(manager.should_commit.call_count, 1) + self.assertEqual(manager.allreduce.call_count, 4) def test_local_sgd_recovery(self) -> None: - base_m = SimpleModel() - optimizer = optim.SGD(base_m.parameters()) + model = SimpleModel() + optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) - m = LocalSGD(manager, base_m, optimizer, sync_every=2) - - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) - og_state_dict = _copy_state_dict(base_m.state_dict()) - - inp = torch.rand(2, 3) - - loss = m(inp).mean() - loss.backward() - optimizer.step() - - self.assertEqual(m._local_step, 1) - - state_dict = m.state_dict() - torch.testing.assert_close(state_dict, m._backup_parameters) - torch.testing.assert_close(state_dict, og_state_dict) + with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd: + torch.testing.assert_close( + local_sgd._backup_parameters, _params_dict(model) + ) + og_state_dict = _copy_state_dict(model.state_dict()) + print(og_state_dict) + + inp = torch.rand(2, 3) + + loss = model(inp).mean() + loss.backward() + optimizer.step() + + # Check that the model's state dict has been updated + for name, param in model.state_dict().items(): + # Ensure the parameter has changed + self.assertFalse( + torch.equal(og_state_dict[name], param), + f"Parameter {name} did not change.", + ) + self.assertEqual(local_sgd._local_step, 1) + + local_sgd._restore_parameters() + torch.testing.assert_close( + local_sgd._backup_parameters, _params_dict(model) + ) + + +class DiLoCoTest(TestCase): + def test_diloco_healthy(self) -> None: + model = SimpleModel() + + # Setup optimizers + inner_optimizer = torch.optim.AdamW( + model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) + ) + outer_optimizer = torch.optim.SGD( + model.parameters(), lr=0.7, momentum=0.9, nesterov=True + ) - m.load_state_dict(state_dict) - torch.testing.assert_close(_params_dict(base_m), state_dict) - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) + manager = create_autospec(Manager) + manager._use_async_quorum = False + with DiLoCo( + manager, model, inner_optimizer, outer_optimizer, sync_every=2 + ) as diloco: + parameter_count = len(list(model.parameters())) + initial_outer_opt_state = outer_optimizer.state_dict() + self.assertEqual(initial_outer_opt_state["state"], {}) + + self.assertEqual(diloco._local_step, 0) + torch.testing.assert_close(diloco._backup_parameters, _params_dict(model)) + inp = torch.rand(2, 3) + loss = model(inp).mean() + loss.backward() + inner_optimizer.step() + + self.assertEqual(diloco._local_step, 1) + self.assertEqual(manager.start_quorum.call_count, 0) + loss = model(inp).mean() + loss.backward() + inner_optimizer.step() + self.assertEqual(manager.start_quorum.call_count, 1) + + manager.should_commit.return_value = True + self.assertEqual(diloco._local_step, 0) + torch.testing.assert_close(diloco._backup_parameters, _params_dict(model)) + self.assertEqual(manager.should_commit.call_count, 1) + self.assertEqual(manager.allreduce.call_count, parameter_count) + + outer_opt_state = outer_optimizer.state_dict() + self.assertEqual(len(outer_opt_state["state"]), parameter_count) diff --git a/torchft/manager.py b/torchft/manager.py index d9ff366..5ca25cb 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -38,7 +38,7 @@ import torch from torch.distributed import ReduceOp, TCPStore -from torchft.checkpointing import CheckpointServer +from torchft.checkpointing import CheckpointServer, 
CheckpointTransport from torchft.futures import future_timeout from torchft.torchft import Manager as _Manager, ManagerClient @@ -87,8 +87,8 @@ class Manager: def __init__( self, pg: "ProcessGroup", - load_state_dict: Callable[[T], None], - state_dict: Callable[[], T], + load_state_dict: Optional[Callable[[T], None]], + state_dict: Optional[Callable[[], T]], min_replica_size: int, use_async_quorum: bool = True, timeout: timedelta = timedelta(seconds=60), @@ -104,6 +104,7 @@ def __init__( port: Optional[int] = None, hostname: str = socket.gethostname(), heartbeat_interval: timedelta = timedelta(milliseconds=100), + checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, ) -> None: """ Args: @@ -139,9 +140,11 @@ def __init__( lighthouse_addr: if rank==0, the address of the lighthouse server replica_id: if rank==0, the replica_id for this group hostname: if rank==0, the hostname to advertise to the lighthouse server + checkpoint_transport: the checkpoint transport to use for + transfering checkpoints to recovering replicas """ self._load_state_dict = load_state_dict - self._state_dict = state_dict + self._user_state_dict = state_dict self._pending_state_dict: Optional[Dict[str, object]] = None self._use_async_quorum = use_async_quorum self._timeout = timeout @@ -156,15 +159,13 @@ def __init__( world_size = world_size or int(os.environ["WORLD_SIZE"]) self._min_replica_size = min_replica_size - def _manager_state_dict() -> Dict[str, T]: - return { - "user": state_dict(), - "torchft": cast(T, self.state_dict()), - } + if checkpoint_transport is None: + checkpoint_transport = CheckpointServer[Dict[str, T]]( + timeout=timeout, + ) - self._ckpt_server = CheckpointServer[Dict[str, T]]( - _manager_state_dict, - timeout=timeout, + self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = ( + checkpoint_transport ) self._executor = ThreadPoolExecutor( max_workers=1, thread_name_prefix="async_quorum" @@ -223,14 +224,20 @@ def _manager_state_dict() -> Dict[str, T]: self._participating_rank: Optional[int] = None self._participating_world_size: int = 0 - def shutdown(self) -> None: + def set_state_dict_fns( + self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T] + ) -> None: + self._load_state_dict = load_state_dict + self._user_state_dict = state_dict + + def shutdown(self, wait: bool = True) -> None: """ Shutdown the manager and checkpoint server. """ - self._ckpt_server.shutdown() + self._checkpoint_transport.shutdown(wait=wait) if self._manager is not None: self._manager.shutdown() - self._executor.shutdown() + self._executor.shutdown(wait=wait) def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]: """ @@ -386,7 +393,6 @@ def start_quorum( self._errored = None self._healing = False - self._ckpt_server.allow_checkpoint(self._step) # TODO: we should really be wrapping this whole section in a try-except # block to allow gracefully recovering from issues in PG setup and quorum. 
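Two API additions show up in the Manager here: the state-dict callbacks can now be attached after construction via `set_state_dict_fns`, and the checkpoint transport is pluggable via the new `checkpoint_transport` argument, with the HTTP `CheckpointServer` remaining the default. A sketch of how a caller might wire this up, assuming a lighthouse is already running; the addresses, ports, and replica id are placeholders.

```python
from datetime import timedelta

from torchft.checkpointing import CheckpointServer
from torchft.manager import Manager
from torchft.process_group import ProcessGroupGloo

pg = ProcessGroupGloo()
manager = Manager(
    pg=pg,
    # State dict functions may be attached later, which helps when the object
    # that owns the state (e.g. a DiLoCo wrapper) is built after the Manager.
    load_state_dict=None,
    state_dict=None,
    min_replica_size=2,
    use_async_quorum=False,
    replica_id="replica_0",
    store_addr="localhost",
    store_port=29520,
    rank=0,
    world_size=1,
    lighthouse_addr="localhost:29510",
    # Optional: any CheckpointTransport implementation; omitting it falls back
    # to the HTTP CheckpointServer.
    checkpoint_transport=CheckpointServer(timeout=timedelta(seconds=60)),
)

state = {"step": 0}
manager.set_state_dict_fns(
    load_state_dict=state.update,
    state_dict=lambda: state,
)
```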
@@ -422,24 +428,24 @@ def wait_quorum(self) -> None: def _async_quorum( self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta ) -> None: - ( - quorum_id, - replica_rank, - replica_world_size, - address, - store_address, - max_step, - max_rank, - max_world_size, - heal, - ) = self._client.quorum( + quorum = self._client.quorum( rank=self._rank, step=self._step, - checkpoint_server_addr=self._ckpt_server.address(), + checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, timeout=quorum_timeout, ) + quorum_id = quorum.quorum_id + replica_rank = quorum.replica_rank + replica_world_size = quorum.replica_world_size + recover_src_manager_address = quorum.recover_src_manager_address + store_address = quorum.store_address + max_step = quorum.max_step + max_rank = quorum.max_rank + max_world_size = quorum.max_world_size + heal = quorum.heal + # When using async quorum we need to take the recovered workers. # When not using async quorum we need to take the max world size as all # workers will be healthy. @@ -470,29 +476,54 @@ def _async_quorum( self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size) self._quorum_id = quorum_id - # See manager.rs for healing conditions - if heal and allow_heal: - self._healing = True - self._logger.info( - f"healing required, fetching checkpoint server address from {address=} {max_step=}" - ) - primary_client = ManagerClient( - address, connect_timeout=self._connect_timeout - ) - checkpoint_server_address = primary_client.checkpoint_address( - self._rank, timeout=self._timeout - ) + if allow_heal: + if quorum.recover_dst_ranks: + self._logger.info( + f"peers need recovery from us {quorum.recover_dst_ranks}" + ) + self._checkpoint_transport.send_checkpoint( + dst_ranks=quorum.recover_dst_ranks, + step=max_step, + state_dict=self._manager_state_dict(), + timeout=self._timeout, + ) - self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}") - self._pending_state_dict = CheckpointServer.load_from_address( - checkpoint_server_address, timeout=self._timeout - ) - self.load_state_dict(self._pending_state_dict["torchft"]) - # we apply the user state dict only when safe from the main thread + # See manager.rs for healing conditions + if heal: + self._healing = True + self._logger.info( + f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}" + ) + primary_client = ManagerClient( + recover_src_manager_address, connect_timeout=self._connect_timeout + ) + checkpoint_metadata = primary_client.checkpoint_metadata( + self._rank, timeout=self._timeout + ) + recover_src_rank = quorum.recover_src_rank + assert ( + recover_src_rank is not None + ), "must have a recover rank when healing" + + self._logger.info( + f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}" + ) - # This isn't strictly needed as loading the state_dict above should - # restore the correct step but it makes writing tests simpler. - self._step = max_step + # we apply the user state dict only when safe from the main thread + # save it for now + self._pending_state_dict = self._checkpoint_transport.recv_checkpoint( + src_rank=recover_src_rank, + metadata=checkpoint_metadata, + step=max_step, + timeout=self._timeout, + ) + + # pyre-fixme[6]: got object + self.load_state_dict(self._pending_state_dict["torchft"]) + + # This isn't strictly needed as loading the state_dict above should + # restore the correct step but it makes writing tests simpler. 
+ self._step = max_step def _apply_pending_state_dict(self) -> None: assert self._healing, "must be in healing state" @@ -504,8 +535,12 @@ def _apply_pending_state_dict(self) -> None: self._logger.info("applying pending state dict") assert self._pending_state_dict is not None, "checkpoint was not staged" + assert ( + self._load_state_dict is not None + ), "user load_state_dict is not initialized." self._load_state_dict(self._pending_state_dict["user"]) self._pending_state_dict = None + self._logger.info("Loaded state dict.") def should_commit(self, timeout: Optional[timedelta] = None) -> bool: """ @@ -553,7 +588,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool: f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}" ) - self._ckpt_server.disallow_checkpoint() + self._checkpoint_transport.disallow_checkpoint() # decide whether we're in a healthy state to increase the step count if should_commit: @@ -574,6 +609,13 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: self._step = state_dict["step"] self._batches_committed = state_dict["batches_committed"] + def _manager_state_dict(self) -> Dict[str, object]: + assert self._user_state_dict is not None, "user state_dict is not initialized." + return { + "user": self._user_state_dict(), + "torchft": self.state_dict(), + } + def state_dict(self) -> Dict[str, int]: """ Get the state dict for this manager. @@ -610,15 +652,35 @@ def batches_committed(self) -> int: """ return self._batches_committed + def participating_rank(self) -> Optional[int]: + """ + Get the replica group rank of the current quorum. This will be the same on all + ranks within the replica group. + + If this replica group is not participating in the current quorum, this will be None. + + This will block on the async quorum if it is not yet ready. + + Returns: + the rank of the current quorum + """ + self.wait_quorum() + + return self._participating_rank + def num_participants(self) -> int: """ Get the number of participants in the current quorum. This is the number of replicas participating in the current step. + This will block on the async quorum if it is not yet ready. 
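`participating_rank()` and `num_participants()` now both block on `wait_quorum()`, so they can be called directly from the training loop once a quorum has been started. A hypothetical helper showing one way a data loader might use them for sharding; the helper and its names are my own, not part of torchft.

```python
from typing import Optional

from torchft.manager import Manager


def shard_bounds(manager: Manager, dataset_len: int) -> Optional[range]:
    """Return the index range this replica group should process this step."""
    rank = manager.participating_rank()  # blocks until the quorum is ready
    world = manager.num_participants()
    if rank is None:
        return None  # this replica group is not participating in the quorum
    per_rank = dataset_len // world
    return range(rank * per_rank, (rank + 1) * per_rank)
```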
+ Returns: the number of participants in the current quorum """ + self.wait_quorum() + assert self._participating_world_size >= 0, "internal error" return self._participating_world_size diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index d6e7bde..8c7c45d 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -1,3 +1,4 @@ +import copy import logging import threading import time @@ -5,7 +6,7 @@ from contextlib import ExitStack, contextmanager from dataclasses import dataclass, field from datetime import timedelta -from typing import Dict, Generator, List, Protocol, Set, Tuple +from typing import Any, Dict, Generator, List, Protocol, Set, Tuple from unittest import TestCase import torch @@ -14,7 +15,7 @@ from torch import nn, optim from torchft.ddp import DistributedDataParallel -from torchft.local_sgd import LocalSGD +from torchft.local_sgd import DiLoCo, LocalSGD from torchft.manager import Manager from torchft.optim import OptimizerWrapper from torchft.process_group import ProcessGroupGloo @@ -76,6 +77,7 @@ class Runner: world_size: int = 1 attempts: int = 3 manager_args: Dict[str, object] = field(default_factory=dict) + train_loop_args: Dict[str, Any] = field(default_factory=dict) def _replica_main(self) -> List[Dict[str, Dict[str, object]]]: store = dist.TCPStore( @@ -103,7 +105,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]: try: fut.result() except Exception as e: - logger.exception(f"worker threw exception: {e}") + logger.exception(f"worker {self.replica_id=} threw exception: {e}") raise return [fut.result() for fut in futures] @@ -159,7 +161,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-fixme[6]: Incompatible parameter type **runner.manager_args, ) - stack.callback(manager.shutdown) + stack.callback(lambda: manager.shutdown(wait=False)) m: nn.Module = DistributedDataParallel(manager, MyModel()) optimizer: optim.Optimizer = OptimizerWrapper( @@ -223,34 +225,118 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-fixme[6]: Incompatible parameter type **runner.manager_args, ) - stack.callback(manager.shutdown) + stack.callback(lambda: manager.shutdown(wait=False)) m: nn.Module = MyModel() optimizer: optim.Optimizer = optim.Adam(m.parameters()) - m = LocalSGD(manager, m, optimizer, sync_every=2) criterion = nn.CrossEntropyLoss() - while True: - inputs = torch.rand(2, 3) - labels = torch.randint(4, (2,)) + with LocalSGD(manager, m, optimizer, sync_every=2): + while True: + inputs = torch.rand(2, 3) + labels = torch.randint(4, (2,)) - optimizer.zero_grad() - out = m(inputs) - loss = criterion(out, labels) + optimizer.zero_grad() + out = m(inputs) + loss = criterion(out, labels) - loss.backward() + loss.backward() - optimizer.step() + optimizer.step() - if manager.current_step() >= 4: - break + if manager.current_step() >= 4: + break - runner.failure_injector.check(rank, manager.current_step()) + runner.failure_injector.check(rank, manager.current_step()) # return state_dict so we can check consistency return state_dict() +def diloco_train_loop( + rank: int, + store_port: int, + runner: Runner, +) -> Dict[str, Dict[str, object]]: + with ExitStack() as stack: + # Declare the model and optimizers + m: nn.Module = MyModel() + model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"] + m.load_state_dict(model_state_dict) + + # Setup optimizers + inner_optimizer: optim.Optimizer = torch.optim.AdamW( + m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) + ) + 
outer_optimizer: optim.Optimizer = torch.optim.SGD( + m.parameters(), lr=0.7, momentum=0.9, nesterov=True + ) + + # pyre-ignore[53] + def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: + m.load_state_dict(state_dict["model"]) + # TODO: make this cleaner so we don't have to save this + diloco._backup_parameters = state_dict["backup_params"] + inner_optimizer.load_state_dict(state_dict["inner_optim"]) + outer_optimizer.load_state_dict(state_dict["outer_optim"]) + + def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] + return { + "model": m.state_dict(), + "backup_params": copy.deepcopy(diloco._backup_parameters), + "inner_optim": inner_optimizer.state_dict(), + "outer_optim": outer_optimizer.state_dict(), + } + + print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting") + + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=2, + use_async_quorum=False, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=str(runner.replica_id), + store_addr="localhost", + store_port=store_port, + rank=rank, + world_size=runner.world_size, + lighthouse_addr=runner.lighthouse_address, + port=19530 + runner.replica_id, + # pyre-fixme[6]: Incompatible parameter type + **runner.manager_args, + ) + stack.callback(manager.shutdown) + + criterion = nn.CrossEntropyLoss() + all_state_dicts = {} + with DiLoCo( + manager, m, inner_optimizer, outer_optimizer, sync_every=2 + ) as diloco: + while True: + inputs = torch.rand(2, 3) + labels = torch.randint(4, (2,)) + + out = m(inputs) + loss = criterion(out, labels) + + inner_optimizer.zero_grad() + loss.backward() + inner_optimizer.step() + manager_step_str = str(manager.current_step()) + all_state_dicts[manager_step_str] = state_dict() + + # after 4 model updates then break + if manager.current_step() >= 4: + break + + runner.failure_injector.check(rank, manager.current_step()) + + # return state_dict so we can check consistency + return all_state_dicts + + class ManagerIntegTest(TestCase): @contextmanager def assertElapsedLessThan( @@ -431,6 +517,108 @@ def test_local_sgd_recovery(self) -> None: self.assertEqual(failure_injectors[1].count, 1) + def test_diloco_healthy(self) -> None: + lighthouse = Lighthouse( + bind="[::]:0", + min_replicas=2, + ) + num_replicas = 2 + futures = [] + + torch.manual_seed(42) + # Initialize the model so we can pass in the state_dict + m: nn.Module = MyModel() + + with ThreadPoolExecutor(max_workers=num_replicas) as executor: + for replica_id in range(num_replicas): + failure_injector = FailureInjector() + runner = Runner( + replica_id=replica_id, + lighthouse_address=lighthouse.address(), + failure_injector=failure_injector, + train_loop=diloco_train_loop, + train_loop_args={ + "model_state_dict": m.state_dict(), + }, + ) + futures.append(executor.submit(runner.run_replica)) + + state_dicts = [] + + for fut in as_completed(futures): + state_dicts.append(fut.result()[0]) + + lighthouse.shutdown() + + for replica_group in state_dicts: + for step, state_dict in replica_group.items(): + # inner optimizer will be different, outer optimizer and model should be the same + torch.testing.assert_close( + state_dict["backup_params"], + state_dicts[0][str(step)]["backup_params"], + ) + torch.testing.assert_close( + state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"] + ) + + def test_diloco_recovery(self) -> None: + lighthouse = Lighthouse( + bind="[::]:0", + min_replicas=2, + ) + num_replicas = 2 + futures = [] + + failure_injectors = [ + 
FailureInjector(), + FailureInjector().fail_at(0, 2), + ] + + torch.manual_seed(42) + # Initialize the model so we can pass in the state_dict + m: nn.Module = MyModel() + + with ThreadPoolExecutor(max_workers=num_replicas) as executor: + for replica_id, failure_injector in zip( + range(num_replicas), failure_injectors + ): + runner = Runner( + replica_id=replica_id, + lighthouse_address=lighthouse.address(), + failure_injector=failure_injector, + train_loop=diloco_train_loop, + train_loop_args={ + "model_state_dict": m.state_dict(), + }, + ) + futures.append(executor.submit(runner.run_replica)) + + state_dicts = [] + + for fut in as_completed(futures): + try: + state_dicts.append(fut.result()[0]) + except Exception as e: + print(e) + raise + + lighthouse.shutdown() + for replica_group in state_dicts: + for step, state_dict in replica_group.items(): + str_step = str(step) + if str_step in state_dicts[0]: + # inner optimizer will be different, outer optimizer and model should be the same + torch.testing.assert_close( + state_dict["backup_params"], + state_dicts[0][str_step]["backup_params"], + ) + torch.testing.assert_close( + state_dict["outer_optim"], + state_dicts[0][str_step]["outer_optim"], + ) + + self.assertEqual(failure_injectors[1].count, 1) + def test_quorum_timeout(self) -> None: with ExitStack() as stack: lighthouse = Lighthouse( @@ -460,7 +648,7 @@ def test_quorum_timeout(self) -> None: port=19530, use_async_quorum=False, ) - stack.callback(manager.shutdown) + stack.callback(lambda: manager.shutdown(wait=False)) with self.assertElapsedLessThan(1.0): with self.assertRaisesRegex( diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 97b891c..3c1b662 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import concurrent from datetime import timedelta +from typing import Optional from unittest import TestCase from unittest.mock import MagicMock, create_autospec, patch @@ -13,7 +15,7 @@ from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode from torchft.process_group import ProcessGroup, _DummyWork -from torchft.torchft import ManagerClient +from torchft.torchft import QuorumResult def mock_should_commit( @@ -25,13 +27,19 @@ def mock_should_commit( class TestManager(TestCase): store: TCPStore # pyre-fixme[13]: never initialized load_state_dict: MagicMock # pyre-fixme[13]: never initialized + manager: Optional[Manager] # pyre-fixme[13]: never initialized + + def tearDown(self) -> None: + manager = self.manager + if manager is not None: + manager.shutdown(wait=False) def _create_manager( self, use_async_quorum: bool = True, min_replica_size: int = 2, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, - timeout: timedelta = timedelta(seconds=60), + timeout: timedelta = timedelta(seconds=10), ) -> Manager: pg = create_autospec(ProcessGroup) self.store = TCPStore( @@ -58,6 +66,7 @@ def _create_manager( world_size_mode=world_size_mode, timeout=timeout, ) + self.manager = manager return manager @patch("torchft.manager.ManagerClient", autospec=True) @@ -87,22 +96,54 @@ def test_state_dict(self, client_mock: MagicMock) -> None: self.assertEqual(manager.current_step(), 1234) self.assertEqual(manager.batches_committed(), 2345) + @patch("torchft.manager.ManagerClient", autospec=True) + def test_user_state_dict(self, client_mock: MagicMock) -> None: + manager = self._create_manager() + + self.assertEqual( + manager._manager_state_dict(), + { + "user": {}, + "torchft": { + "step": 0, + "batches_committed": 0, + }, + }, + ) + + manager.set_state_dict_fns( + self.load_state_dict, + lambda: {"new_state": 1}, + ) + + self.assertEqual( + manager._manager_state_dict(), + { + "user": {"new_state": 1}, + "torchft": { + "step": 0, + "batches_committed": 0, + }, + }, + ) + @patch("torchft.manager.ManagerClient", autospec=True) def test_quorum_happy(self, client_mock: MagicMock) -> None: manager = self._create_manager() client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -127,21 +168,30 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None: manager = self._create_manager(use_async_quorum=False) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 20, # max_step - None, # max_rank - 2, # max_world_size - True, # heal - ) - # forcible increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(manager.current_step()) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size 
= 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 2 + quorum.heal = True + + client_mock().quorum.return_value = quorum - client_mock().checkpoint_address.return_value = manager._ckpt_server.address() + # forcible increment checkpoint server to compute correct address + manager._checkpoint_transport.send_checkpoint( + dst_ranks=[], + step=quorum.max_step, + state_dict=manager._manager_state_dict(), + timeout=timedelta(seconds=10), + ) + client_mock().checkpoint_metadata.return_value = ( + manager._checkpoint_transport.metadata() + ) self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -169,21 +219,30 @@ def test_quorum_heal_async_not_enough_participants( manager = self._create_manager(use_async_quorum=True, min_replica_size=2) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 20, # max_step - None, # max_rank - 1, # max_world_size - True, # heal - ) - # forcible increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(manager.current_step()) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 1 + quorum.heal = True + + client_mock().quorum.return_value = quorum - client_mock().checkpoint_address.return_value = manager._ckpt_server.address() + # forcible increment checkpoint server to compute correct address + manager._checkpoint_transport.send_checkpoint( + dst_ranks=[], + step=quorum.max_step, + state_dict=manager._manager_state_dict(), + timeout=timedelta(seconds=10), + ) + client_mock().checkpoint_metadata.return_value = ( + manager._checkpoint_transport.metadata() + ) self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -212,6 +271,7 @@ def test_quorum_heal_async_not_enough_participants( self.assertEqual(self.load_state_dict.call_count, 1) # failed to commit so no step + quorum.heal = False manager.start_quorum() self.assertEqual(manager.current_step(), 20) self.assertEqual(manager.batches_committed(), 0) @@ -221,21 +281,30 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None: manager = self._create_manager(use_async_quorum=True, min_replica_size=1) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 20, # max_step - None, # max_rank - 1, # max_world_size - True, # heal - ) - # forceable increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(manager.current_step()) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 1 + quorum.heal = True + + client_mock().quorum.return_value = quorum - 
client_mock().checkpoint_address.return_value = manager._ckpt_server.address() + # forceable increment checkpoint server to compute correct address + manager._checkpoint_transport.send_checkpoint( + dst_ranks=[], + step=quorum.max_step, + state_dict=manager._manager_state_dict(), + timeout=timedelta(seconds=10), + ) + client_mock().checkpoint_metadata.return_value = ( + manager._checkpoint_transport.metadata() + ) self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -261,6 +330,8 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None: self.assertEqual(self.load_state_dict.call_count, 1) + # healed + quorum.heal = False manager.start_quorum() self.assertEqual(manager.current_step(), 21) self.assertEqual(manager.batches_committed(), 1) @@ -270,17 +341,18 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: manager = self._create_manager() client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -308,17 +380,8 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: manager._pg.allreduce.side_effect = None # inject failure when worked waited - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 2, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum.max_step = 2 + manager.start_quorum() self.assertFalse(manager._errored) @@ -336,17 +399,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: manager._pg.allreduce.reset_mock(return_value=True) # recover on next step - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world_size - "manager address", - f"localhost:{self.store.port}", - 3, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum.max_step = 3 manager.start_quorum() manager.allreduce(torch.tensor([1.0])).wait() @@ -362,17 +415,18 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None: ) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - rank, # replica_rank - 3, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - rank, # max_rank - 3, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = rank + quorum.replica_world_size = 3 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = rank + quorum.max_world_size = 3 + quorum.heal = False + + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -395,17 +449,18 @@ def test_quorum_no_healing(self, client_mock: MagicMock) -> None: ) 
client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 0, # replica_rank - 3, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - None, # max_rank - 2, # max_world_size - True, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 0 + quorum.replica_world_size = 3 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 1 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = None + quorum.max_world_size = 2 + quorum.heal = True + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -467,10 +522,16 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None: def test_manager_numerics(self, client_mock: MagicMock) -> None: manager = self._create_manager() - manager._quorum_future = MagicMock() + manager._quorum_future = quorum_future = MagicMock( + spec=concurrent.futures.Future + ) manager._participating_rank = 1 manager._participating_world_size = 5 self.assertEqual(manager.num_participants(), 5) + self.assertEqual(quorum_future.result.call_count, 1) + self.assertEqual(manager.participating_rank(), 1) + self.assertEqual(quorum_future.result.call_count, 2) + # pyre-ignore[16]: _pg is mocked manager._pg.allreduce.return_value = _DummyWork(None) @@ -492,17 +553,18 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None: def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: manager = self._create_manager(use_async_quorum=False) - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock().quorum.return_value = quorum manager.start_quorum(timeout=timedelta(seconds=12)) self.assertEqual( diff --git a/torchft/optim.py b/torchft/optim.py index ce24823..0583d94 100644 --- a/torchft/optim.py +++ b/torchft/optim.py @@ -12,8 +12,9 @@ """ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional +import torch from torch.optim import Optimizer if TYPE_CHECKING: @@ -52,3 +53,11 @@ def step(self, closure: Optional[object] = None) -> None: assert closure is None, "optimizers that use closures are not supported" if self.manager.should_commit(): self.optim.step() + + @property + def param_groups(self) -> List[Dict[str, Any]]: + return self.optim.param_groups + + @property + def state(self) -> Mapping[torch.Tensor, Any]: # pyre-fixme[3] + return self.optim.state diff --git a/torchft/optim_test.py b/torchft/optim_test.py index 50412d8..5dd6964 100644 --- a/torchft/optim_test.py +++ b/torchft/optim_test.py @@ -7,6 +7,7 @@ from unittest import TestCase from unittest.mock import MagicMock, create_autospec +import torch from torch.nn import Linear from torch.optim import AdamW @@ -34,9 +35,16 @@ def test_optimizer_wrapper(self) -> None: optim.zero_grad() self.assertEqual(manager.start_quorum.call_count, 1) + b = torch.rand(3) + m(b).sum().backward() + 
manager.should_commit.return_value = True optim.step() manager.should_commit.return_value = False optim.step() + self.assertEqual(len(optim.param_groups), 2) + self.assertEqual(optim.param_groups[1]["lr"], 1e-4) + self.assertEqual(optim.param_groups[1]["params"], []) + self.assertEqual(len(optim.state), len(list(m.parameters()))) self.assertEqual(manager.should_commit.call_count, 2) diff --git a/torchft/process_group.py b/torchft/process_group.py index 34797c7..4790352 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -19,9 +19,22 @@ import logging import queue import threading -from abc import ABC +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from datetime import timedelta -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) import torch import torch.distributed as dist @@ -30,7 +43,6 @@ # pyre-fixme[21]: no attribute ProcessGroupNCCL # pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( - BroadcastOptions, DeviceMesh, PrefixStore, ProcessGroup as BaseProcessGroup, @@ -41,8 +53,15 @@ get_rank, init_device_mesh, ) -from torch.distributed.distributed_c10d import Work, _world +from torch.distributed.distributed_c10d import ( + AllgatherOptions, + AllreduceOptions, + BroadcastOptions, + ReduceOp, + Work, +) from torch.futures import Future +from torch.utils._pytree import tree_any if TYPE_CHECKING: from torchft.manager import Manager @@ -55,6 +74,9 @@ _FUTURE_EXCEPTION = "fut_exception" +T = TypeVar("T") + + def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object: """ Gets an item from a queue with a timeout. If the timeout is exceeded then @@ -123,7 +145,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override - def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + def allreduce( + self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + ) -> Work: raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override @@ -131,12 +155,24 @@ def allgather( self, output_tensors: List[List[torch.Tensor]], input_tensor: List[torch.Tensor], - opts: object, + opts: AllgatherOptions, ) -> Work: + """ + Gathers tensors from the whole group in a list. + + See torch.distributed.all_gather for more details. + """ raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override - def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: + def broadcast( + self, tensor_list: List[torch.Tensor], opts: BroadcastOptions + ) -> Work: + """ + Broadcasts the tensor to the whole group. + + See torch.distributed.broadcast for more details. + """ raise NotImplementedError("not implemented") def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work: @@ -507,6 +543,10 @@ def __init__(self, manager: "Manager") -> None: self._manager = manager def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + # Ensure we have a valid quorum and are configured before trying to do + # any work. 
+ self._manager.wait_quorum() + if self._manager.errored() is not None: return _DummyWork(tensors) @@ -547,26 +587,52 @@ def __init__( self._timeout = timeout def wait(self, timeout: Optional[timedelta] = None) -> bool: + self._pg._assert_alive() + self._tx.put(("wait", self._op_id), timeout=self._timeout) - assert _get(self._rx, self._timeout) == self._op_id + op_id, event = cast( + Tuple[int, Optional[torch.cuda.Event]], + _get(self._rx, timeout or self._timeout), + ) + assert op_id == self._op_id + if event is not None: + event.wait() return True + def synchronize(self) -> None: + # TODO: No one seems to use this and NCCL wait already only waits the + # stream and is non-blocking on the CPU side so no real need for a + # separate call. + raise NotImplementedError("not implemented") + def get_future(self) -> Future[object]: return self._pg._get_future(self._op_id) + def __del__(self) -> None: + self._tx.put(("del", self._op_id), timeout=self._timeout) -class _BabyWorkNCCL(_BabyWork): - def wait(self, timeout: Optional[timedelta] = None) -> bool: - self._tx.put(("synchronize", self._op_id), timeout=self._timeout) - # pyre-fixme[23]: unable to unpack into 2 values - op_id, event = _get(self._rx, self._timeout) - assert op_id == self._op_id - assert isinstance(event, torch.cuda.Event) - # Wait on Event makes the stream wait but not the CPU thread. - event.wait() +def _is_any_cuda(obj: object) -> bool: + """ + Returns true if any of the tensors in the object are CUDA tensors. - return True + Supports lists, tuples, dicts, and tensors. + """ + return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj) + + +@dataclass +class _OpMetadata: + work: Work + stream: Optional[torch.cuda.Stream] + + @contextmanager + def set_stream(self) -> Generator[None, None, None]: + if self.stream is not None: + with torch.cuda.stream(self.stream): + yield + else: + yield class ProcessGroupBaby(ProcessGroup): @@ -575,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup): subprocess. Since it's running in a subprocess all tensors need to be in shared memory or will be moved to shared memory. CUDA tensors are implicitly share able and don't need any changes. 
- """ - WORK_CLASS: Type[_BabyWork] = _BabyWork - def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None: super().__init__(0, 1) @@ -609,8 +672,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: if self._rx is not None: self._rx.close() if self._future_queue is not None: + # wait for the future thread to exit and then close the queue self._future_queue.put(_QUEUE_CLOSE) - assert self._future_queue is not None + assert self._future_thread is not None + self._future_thread.join(timeout=10.0) + # pyre-ignore[16]: optional value is checked above + if self._future_thread.is_alive(): + raise RuntimeError("future thread did not exit") + # pyre-ignore[16]: optional value is checked above self._future_queue.close() ctx = mp.get_context("spawn") @@ -631,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: self._p = ctx.Process( target=self._worker, - args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue), + args=( + store_addr, + rank, + world_size, + self._tx, + self._rx, + self._future_queue, + ), daemon=True, ) self._p.start() @@ -668,23 +744,73 @@ def _worker( return tx.put(None) - work = {} + streams: Dict[str, torch.cuda.Stream] = {} + work: Dict[int, _OpMetadata] = {} next_op_id: int = 0 while True: op = rx.get() cmd = op[0] if cmd == "func": - func_name, args, kwargs = op[1:] - fn = getattr(pg, func_name) - work[next_op_id] = fn(*args, **kwargs) + func_name, args, kwargs, stream_device, stream_id, event = op[1:] + + # To avoid potential deadlocks we need to preserve the + # stream/synchronization behavior of the parent process. + # We allocate one Stream per stream_id to make sure that we + # don't accidentally introduce cross stream synchronization + # points. + if stream_id is not None: + stream_key = f"{stream_device}/{stream_id}" + if stream_key not in streams: + streams[stream_key] = torch.cuda.Stream( + device=stream_device + ) + stream = streams[stream_key] + else: + stream = None + + with ( + torch.cuda.stream(stream) + if stream is not None + else nullcontext() + ): + # Make the stream wait on the cuda event to make sure we + # don't start the operation until the tensor is ready. + if event is not None: + event.wait() + + args = _PickleSafeOptions.unsafe_args(args) + fn = getattr(pg, func_name) + work[next_op_id] = _OpMetadata( + work=fn(*args, **kwargs), + stream=stream, + ) tx.put(next_op_id) next_op_id += 1 elif cmd == "wait": op_id: int = op[1] - work[op_id].wait() + + metadata = work[op_id] + + with metadata.set_stream(): + # With WorkNCCL this makes the stream wait not the CPU when + # no timeout is passed. + metadata.work.wait() + + # Register event on the stream that we can pass to the main + # process. + event = ( + torch.cuda.current_stream().record_event( + torch.cuda.Event(interprocess=True) + ) + if metadata.stream is not None + else None + ) + + tx.put((op_id, event)) + elif cmd == "del": + op_id: int = op[1] del work[op_id] - tx.put(op_id) elif cmd == "future": op_id: int = op[1] @@ -695,29 +821,17 @@ def callback(fut: Future[object]) -> None: except Exception as e: future_queue.put((op_id, _FUTURE_EXCEPTION, e)) - work[op_id].get_future().add_done_callback(callback) + work[op_id].work.get_future().add_done_callback(callback) tx.put(op_id) - elif cmd == "synchronize": - # CUDA only, use events instead of waiting on CPU - op_id = op[1] - - # With WorkNCCL this makes the stream wait not the CPU when - # no timeout is passed. 
- work[op_id].wait() - - # Register event on the stream that we can pass to the main - # process. - event = torch.cuda.Event(interprocess=True) - event.record() - - del work[op_id] - tx.put((op_id, event)) + elif cmd == "num_active_work": + tx.put(len(work)) else: raise ValueError(f"unknown cmd: {cmd}") except Exception as e: logger.exception("worker errored") tx.put(e) + raise def _future_handler(self, future_queue: mp.Queue) -> None: try: @@ -739,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None: logger.exception(f"got unexpected error in future handler: {e}") def _get_future(self, op_id: int) -> Future[object]: + self._assert_alive() + with self._futures_lock: fut = Future() # pyre-fixme[29]: is not a function self._futures[op_id] = fut @@ -751,21 +867,58 @@ def _get_future(self, op_id: int) -> Future[object]: return fut def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: + self._assert_alive() + rx = self._rx tx = self._tx assert rx is not None assert tx is not None - tx.put(("func", func, args, kwargs), timeout=self._timeout) + is_cuda = _is_any_cuda(args) + + stream_device = torch.cuda.current_stream().device if is_cuda else None + stream_id = torch.cuda.current_stream().stream_id if is_cuda else None + event = ( + torch.cuda.current_stream().record_event( + torch.cuda.Event(interprocess=True) + ) + if is_cuda + else None + ) + + tx.put( + ( + "func", + func, + _PickleSafeOptions.safe_args(args), + kwargs, + stream_device, + stream_id, + event, + ), + timeout=self._timeout, + ) op_id = _get(rx, self._timeout) assert isinstance(op_id, int), f"invalid return {op_id}" - return self.WORK_CLASS( - pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout - ) + return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout) - def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + def _assert_alive(self) -> None: + """ + Assert that the process group is alive. This is used to ensure that + operations are not performed on a dead process group and any errors are surfaced. 
+ """ + p = self._p + assert p is not None + if not p.is_alive(): + raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}") + + def allreduce( + self, + tensors: List[torch.Tensor], + opts: Union[dist.AllreduceOptions, dist.ReduceOp], + ) -> Work: assert isinstance(tensors, list), "input must be list" for tensor in tensors: @@ -774,9 +927,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: return self._run_func("allreduce", tensors, opts) + def allgather( + self, + output_tensors: List[List[torch.Tensor]], + input_tensor: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + assert isinstance(output_tensors, list), "input must be list" + assert isinstance(input_tensor, list), "input must be list" + + for tensor_list in output_tensors: + for tensor in tensor_list: + if not tensor.is_shared(): + tensor.share_memory_() + + for tensor in input_tensor: + if not tensor.is_shared(): + tensor.share_memory_() + + return self._run_func("allgather", output_tensors, input_tensor, opts) + + def broadcast( + self, + tensor_list: List[torch.Tensor], + opts: BroadcastOptions, + ) -> Work: + assert isinstance(tensor_list, list), "input must be list" + + for tensor in tensor_list: + if not tensor.is_shared(): + tensor.share_memory_() + + return self._run_func("broadcast", tensor_list, opts) + def size(self) -> int: return self._world_size + def num_active_work(self) -> int: + assert self._tx is not None + self._tx.put(("num_active_work",), timeout=self._timeout) + + assert self._rx is not None + return cast(int, _get(self._rx, self._timeout)) + + +@dataclass +class _PickleSafeOptions: + func: Callable[[], object] + fields: Dict[str, object] + + @classmethod + def safe_args(cls, args: T) -> T: + if isinstance(args, tuple): + return tuple(cls.safe_args(arg) for arg in args) + elif isinstance(args, list): + return [cls.safe_args(arg) for arg in args] + elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)): + return cls.from_torch(args) + else: + return args + + @classmethod + def unsafe_args(cls, args: T) -> T: + if isinstance(args, tuple): + return tuple(cls.unsafe_args(arg) for arg in args) + elif isinstance(args, list): + return [cls.unsafe_args(arg) for arg in args] + elif isinstance(args, cls): + return args.to_torch() + else: + return args + + @classmethod + def from_torch(cls, opts: object) -> "_PickleSafeOptions": + return cls( + func=opts.__class__, + fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")}, + ) + + def to_torch(self) -> object: + opts = self.func() + for k, v in self.fields.items(): + setattr(opts, k, v) + return opts + class ProcessGroupBabyGloo(ProcessGroupBaby): """ @@ -811,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby): tensors may leak in the current PyTorch implementation. TODO fix """ - WORK_CLASS = _BabyWorkNCCL - @classmethod def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup: # pyre-fixme[16]: no attribute ProcessGroupNCCL @@ -852,6 +1084,8 @@ def extend_device_mesh( class ManagedDeviceMesh(DeviceMesh): + replicate_pg_singleton: Optional["ManagedProcessGroup"] = None + def __init__( self, mesh: Optional[DeviceMesh], @@ -880,6 +1114,16 @@ def __init__( self._flatten_mesh_list: Tuple[DeviceMesh, ...] 
= tuple() self._thread_id: Optional[int] = None + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["replicate_pg"] = None + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + assert self.replicate_pg_singleton is not None + self.replicate_pg = self.replicate_pg_singleton + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh: if isinstance(mesh_dim_names, str): if mesh_dim_names == self.replicate_dim_name: @@ -897,13 +1141,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh return self.mesh[mesh_dim_names] else: assert isinstance(mesh_dim_names, tuple) - if self.replicate_dim_name in mesh_dim_names: + if self.replicate_dim_name not in mesh_dim_names: assert self.mesh is not None return self.mesh[mesh_dim_names] else: + mesh_dim_names_wo_replicate = tuple( + n for n in mesh_dim_names if n != self.replicate_dim_name + ) assert self.mesh is not None return ManagedDeviceMesh( - self.mesh[mesh_dim_names], + self.mesh[mesh_dim_names_wo_replicate], mesh_dim_names, self.replicate_pg, mesh_dim_names.index(self.replicate_dim_name), @@ -938,14 +1185,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh": return flatten_mesh def size(self, mesh_dim: Optional[int] = None) -> int: + replicate_pg_size = self.replicate_pg.size() + # We have to lie to the users if there are zero particpants. + # This is possible during the initialization stage of training. + replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size if mesh_dim is None: if self.mesh is None: - return self.replicate_pg.size() + return replicate_pg_size else: assert self.mesh is not None - return self.mesh.size() * self.replicate_pg.size() + return self.mesh.size() * replicate_pg_size elif mesh_dim == self.replicate_dim: - return self.replicate_pg.size() + return replicate_pg_size else: assert self.mesh is not None return self.mesh.size(self._real_mesh_dim(mesh_dim)) @@ -995,7 +1246,16 @@ def get_coordinate(self) -> Optional[List[int]]: dimensions of the mesh. If this rank is not part of the mesh, return None. """ assert self.mesh is not None - return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None + coordinate = ( + self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None + ) + if not coordinate: + return coordinate + + # We need to copy be cause we are going to modify the coordinate. + coordinate = coordinate.copy() + coordinate.insert(get_rank(self.replicate_pg), self.replicate_dim) + return coordinate def get_all_groups(self) -> List[BaseProcessGroup]: raise NotImplementedError @@ -1057,19 +1317,11 @@ def ft_init_device_mesh( mesh_dim_names=tuple(_mesh_dim_names), ) - if device_type == "cpu": - pg = ProcessGroupGloo() - elif device_type == "cuda": - pg = ProcessGroupNCCL() - else: - raise ValueError() - - manager._pg = pg replicate_pg = ManagedProcessGroup(manager) - # We have to use MultiProcessTestCase, otherwise c10d will complain - # the same backend has been registered. 
replicate_pg.register(mesh_dim_names[replicate_dim]) + ManagedDeviceMesh.replicate_pg_singleton = replicate_pg + return ManagedDeviceMesh( mesh=mesh, mesh_dim_names=mesh_dim_names, diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index e86c7e0..f765625 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc +import io import multiprocessing import os import unittest @@ -86,14 +88,15 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] check_tensors(item) # Test collectives - collectives = { - "allreduce": ([input_tensor], AllreduceOptions()), - "allgather": (output_tensors, [input_tensor], AllgatherOptions()), - "broadcast": (tensor_list, BroadcastOptions()), - "broadcast_one": (input_tensor, 0), - } + collectives = [ + ("allreduce", ([input_tensor], AllreduceOptions())), + ("allreduce", ([input_tensor], ReduceOp.SUM)), + ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), + ("broadcast", (tensor_list, BroadcastOptions())), + ("broadcast_one", (input_tensor, 0)), + ] works: Dict[str, dist._Work] = {} - for coll_str, args in collectives.items(): + for coll_str, args in collectives: coll = getattr(pg, coll_str) work = coll(*args) works[coll_str] = work @@ -210,6 +213,84 @@ def test_baby_gloo_timeout(self) -> None: with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"): a.configure(store_addr, 0, 2) + def test_reconfigure_baby_process_group(self) -> None: + store = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + store_addr = f"localhost:{store.port}/prefix" + + a = ProcessGroupBabyGloo() + a.configure(store_addr, 0, 1) + future_thread_1 = a._future_thread + future_queue_1 = a._future_queue + p_1 = a._p + + store_addr = f"localhost:{store.port}/prefix2" + a.configure(store_addr, 0, 1) + future_thread_2 = a._future_thread + future_queue_2 = a._future_queue + p_2 = a._p + + self.assertNotEqual(future_thread_1, future_thread_2) + self.assertNotEqual(future_queue_1, future_queue_2) + self.assertNotEqual(p_1, p_2) + + assert future_thread_1 is not None + self.assertFalse(future_thread_1.is_alive()) + assert future_queue_1 is not None + self.assertTrue(future_queue_1._closed) # pyre-ignore[16]: no attribute _closed + assert p_1 is not None + self.assertFalse(p_1.is_alive()) + + assert future_thread_2 is not None + self.assertTrue(future_thread_2.is_alive()) + assert future_queue_2 is not None + self.assertFalse(future_queue_2._closed) + assert p_2 is not None + self.assertTrue(p_2.is_alive()) + + def test_baby_gloo_apis(self) -> None: + store = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + + store_addr = f"localhost:{store.port}/prefix" + + a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10)) + a.configure(store_addr, 0, 1) + + _test_pg(a) + + # force collection to ensure no BabyWork objects remain + gc.collect() + + self.assertEqual(a.num_active_work(), 0) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @skipUnless(torch.cuda.is_available(), "needs CUDA") + def test_baby_nccl_apis(self) -> None: + # set to 1 if more than >=2 gpus + device_id = 1 % torch.cuda.device_count() + torch.cuda.set_device(device_id) + + store = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + + store_addr = 
f"localhost:{store.port}/prefix" + + a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10)) + a.configure(store_addr, 0, 1) + + _test_pg(a, torch.randn((2, 3), device="cuda")) + + torch.cuda.synchronize() + + # force collection to ensure no BabyWork objects remain + gc.collect() + + self.assertEqual(a.num_active_work(), 0) + def test_dummy(self) -> None: pg = ProcessGroupDummy(0, 1) m = nn.Linear(3, 4) @@ -226,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None: store_addr: str = f"localhost:{store.port}/prefix" def run(rank: int) -> Tuple[torch.Tensor, Work]: - a = ProcessGroupBabyNCCL() + a = ProcessGroupBabyNCCL( + timeout=timedelta(seconds=10.0), + ) a.configure(store_addr, rank, 2) - self.assertEqual(a.size(), 2) - at = torch.tensor([rank + 1], device=f"cuda:{rank}") + # We test using set_device to ensure stream device is correct. + torch.cuda.set_device(rank) + at = torch.tensor([rank + 1], device="cuda") a_work = a.allreduce([at], ReduceOp.SUM) return at, a_work @@ -331,7 +415,8 @@ def test_managed_process_group(self) -> None: self.assertIsInstance(list(works.values())[0], _ManagedWork) self.assertEqual(manager.report_error.call_count, 0) - self.assertEqual(manager.wrap_future.call_count, 1) + self.assertEqual(manager.wrap_future.call_count, 2) + self.assertEqual(manager.wait_quorum.call_count, 2) class DeviceMeshTest(TestCase): diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index fbd0293..2c8c6cd 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Optional, Tuple +from typing import List, Optional class ManagerClient: def __init__(self, addr: str, connect_timeout: timedelta) -> None: ... @@ -7,11 +7,11 @@ class ManagerClient: self, rank: int, step: int, - checkpoint_server_addr: str, + checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ... - def checkpoint_address(self, rank: int, timeout: timedelta) -> str: ... + ) -> QuorumResult: ... + def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... def should_commit( self, rank: int, @@ -20,6 +20,19 @@ class ManagerClient: timeout: timedelta, ) -> bool: ... 
+class QuorumResult: + quorum_id: int + replica_rank: int + replica_world_size: int + recover_src_manager_address: str + recover_src_rank: Optional[int] + recover_dst_ranks: List[int] + store_address: str + max_step: int + max_rank: Optional[int] + max_world_size: int + heal: bool + class Manager: def __init__( self, diff --git a/train_ddp.py b/train_ddp.py index 9ad9cc8..4bcc029 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -7,6 +7,7 @@ import logging import os import sys +from datetime import timedelta import torch import torch.nn.functional as F @@ -70,7 +71,13 @@ def state_dict(): } device = "cuda" if torch.cuda.is_available() else "cpu" - pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo() + pg = ( + ProcessGroupBabyNCCL( + timeout=timedelta(seconds=5), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) manager = Manager( pg=pg, @@ -78,6 +85,7 @@ def state_dict(): load_state_dict=load_state_dict, state_dict=state_dict, replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=10), ) class Net(nn.Module): diff --git a/train_fsdp.py b/train_fsdp.py new file mode 100644 index 0000000..7a5d07c --- /dev/null +++ b/train_fsdp.py @@ -0,0 +1,157 @@ +import os +from datasets import load_dataset + +import torch +from transformers import LlamaForCausalLM, AutoTokenizer +from torch.distributed._composable.fsdp import fully_shard +import torch.distributed as dist +from tqdm import tqdm +from transformers.data import DataCollatorForSeq2Seq +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict + +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchft import ( + DistributedSampler, + Manager, + Optimizer, + ProcessGroupBabyNCCL, + ProcessGroupGloo, +) +from torchft.process_group import ft_init_device_mesh + +def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None, manager=None): + + if replica_group_size is None or sharding_group_size is None: + raise ValueError("Both replica_group_size and sharding_group_size must be provided.") + + device = device or f"cuda" + + device_mesh = ft_init_device_mesh( + device_type=device, + mesh_shape=(replica_group_size, sharding_group_size), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + if device_mesh is None: + raise RuntimeError("Failed to create a valid device mesh.") + + return device_mesh + +def parallelize_llama(model, mesh): + sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)] + + for m in reversed(list(model.modules())): + if any(c(m) for c in sharding_conditions): + # fully_shard(m, mesh=mesh, reshard_after_forward=True) + fully_shard(m, mesh=mesh) + # fully_shard([model.model.embed_tokens, model.lm_head], mesh=mesh) + fully_shard(model, mesh=mesh) + +def main(): + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + NUM_REPLICAS = int(os.environ.get("NUM_REPLICAS", 2)) + + rank = int(os.environ.get("RANK", 0)) + + model_name = "Meta-Llama/Llama-3.2-1B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = LlamaForCausalLM.from_pretrained(model_name) + + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + # If there is a mismatch between tokenizer vocab size and embedding matrix, + # throw a warning and then expand the embedding matrix + 
assert len(tokenizer) == model.get_input_embeddings().weight.shape[0] + + train_data = load_dataset("samsum", split="train") + + class SAMSumDataset(torch.utils.data.Dataset): + def __init__(self, data, tokenizer): + self.data = data + self.tokenizer = tokenizer + def __getitem__(self, idx): + text = self.data[idx] + prompt = self.tokenizer.encode(tokenizer.bos_token + f"Summarize this dialog: {text['dialogue']}\n---\nSummary: ", add_special_tokens=False) + summary = self.tokenizer.encode(text["summary"] + self.tokenizer.eos_token, add_special_tokens=False) + input_ids = prompt + summary + labels = len(prompt) * [-100] + summary + return {"input_ids": input_ids, "labels": labels} + def __len__(self): + return len(self.data) + + + train_dataset = SAMSumDataset(train_data, tokenizer) + + batch_size = 8 + + sampler = DistributedSampler( + train_dataset, + replica_group=REPLICA_GROUP_ID, + num_replica_groups=NUM_REPLICA_GROUPS, + rank=rank, + shuffle=True, + num_replicas=NUM_REPLICAS, + ) + + train_dataloader = StatefulDataLoader(train_dataset, batch_size=batch_size, collate_fn=DataCollatorForSeq2Seq(tokenizer), sampler=sampler) + + def load_state_dict(state_dict): + set_state_dict( + model, + optimizer.optim, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + ) + + + def state_dict(): + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer.optim) + return { + "model": model_state_dict, + "optim": optimizer_state_dict, + } + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo() + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_fsdp_{REPLICA_GROUP_ID}", + use_async_quorum=False, + ) + + mesh = hsdp_device_mesh(NUM_REPLICA_GROUPS, NUM_REPLICAS, "cuda" if torch.cuda.is_available() else "cpu", manager=manager) + + parallelize_llama(model, mesh) + + model.to(device) + + optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5)) + + while manager.current_step() < 500: + model.train() + for batch in tqdm(train_dataloader): + input_ids = batch["input_ids"].to(device) + labels = batch["labels"].to(device) + optimizer.zero_grad() + + outputs = model(input_ids, labels=labels) + loss = outputs.loss + loss.backward() + optimizer.step() + + if manager.current_step() % 100 == 0: + print(f"[{manager.current_step()}] loss = {loss.item()}") + + +if __name__ == "__main__": + main()
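
---

A minimal sketch of how the training-loop APIs touched by this diff fit together, assuming a `Manager` configured as in `train_ddp.py` above and the `OptimizerWrapper` exercised in `optim_test.py`; `model`, `criterion`, `inputs`, and `labels` are placeholder names, not identifiers from this patch:

```python
# Minimal sketch, not a supported entry point. Assumes a configured torchft
# Manager and OptimizerWrapper; model/criterion/inputs/labels are placeholders.
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper


def train_step(manager: Manager, model, optim: OptimizerWrapper, criterion, inputs, labels):
    # OptimizerWrapper.zero_grad() starts the (possibly async) quorum for this step.
    optim.zero_grad()

    loss = criterion(model(inputs), labels)
    loss.backward()

    # OptimizerWrapper.step() only steps the wrapped optimizer when
    # manager.should_commit() reports the step is safe to commit.
    optim.step()

    # Accessors added in this diff; both block until the async quorum resolves.
    # participating_rank() is None when this replica group is not participating
    # in the current quorum.
    rank = manager.participating_rank()
    world = manager.num_participants()
    return loss, rank, world
```

The blocking behavior here mirrors what the updated `test_manager_numerics` asserts: each call to `participating_rank()` or `num_participants()` waits on the pending quorum future before returning.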