From 6c69a683bf3421b31cc559bb7651af7b648cff40 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Wed, 2 Apr 2025 11:49:30 +0300 Subject: [PATCH] libsql: WAL sync baton handling --- libsql/src/sync.rs | 43 ++++++++++++++++++++++++++++++----------- libsql/src/sync/test.rs | 33 ++++++++++++++++--------------- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index a8fd090df5..a28b11028e 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -44,6 +44,8 @@ pub enum SyncError { JsonEncode(serde_json::Error), #[error("failed to push frame: status={0}, error={1}")] PushFrame(StatusCode, String), + #[error("no baton from WAL push operation")] + NoBatonFromPush, #[error("failed to verify metadata file version: expected={0}, got={1}")] VerifyVersion(u32, u32), #[error("failed to verify metadata file hash: expected={0}, got={1}")] @@ -79,7 +81,7 @@ pub struct PushResult { } pub enum PushStatus { - Ok, + Ok { baton: String }, Conflict, } @@ -174,12 +176,13 @@ impl SyncContext { #[tracing::instrument(skip(self, frames))] pub(crate) async fn push_frames( &mut self, + baton: Option, frames: Bytes, generation: u32, frame_no: u32, frames_count: u32, - ) -> Result { - let uri = format!( + ) -> Result<(Option, u32)> { + let mut uri = format!( "{}/sync/{}/{}/{}", self.sync_url, generation, @@ -187,15 +190,19 @@ impl SyncContext { frame_no + frames_count ); tracing::debug!("pushing frame(frame_no={}, count={}, generation={})", frame_no, frames_count, generation); + if let Some(baton) = baton { + uri += &format!("/{}", baton); + } - let result = self.push_with_retry(uri, frames, self.max_retries).await?; - match result.status { + let result = self.push_with_retry(uri, frames, self.max_retries).await?; + + let baton = match result.status { PushStatus::Conflict => { return Err(SyncError::InvalidPushFrameConflict(frame_no, result.max_frame_no).into()); } - _ => {} - } + PushStatus::Ok { baton } => baton, + }; let generation = result.generation; let durable_frame_num = result.max_frame_no; @@ -230,7 +237,7 @@ impl SyncContext { self.durable_generation = generation; self.durable_frame_num = durable_frame_num; - Ok(durable_frame_num) + Ok((Some(baton), durable_frame_num)) } async fn push_with_retry(&self, mut uri: String, body: Bytes, max_retries: usize) -> Result { @@ -263,6 +270,11 @@ impl SyncContext { let resp = serde_json::from_slice::(&res_body[..]) .map_err(SyncError::JsonDecode)?; + let baton: Option = resp + .get("baton") + .map(|v| v.as_str().map(String::from)) + .flatten(); + let status = resp .get("status") .ok_or_else(|| SyncError::JsonValue(resp.clone()))?; @@ -288,7 +300,13 @@ impl SyncContext { .ok_or_else(|| SyncError::JsonValue(max_frame_no.clone()))?; let status = match status { - "ok" => PushStatus::Ok, + "ok" => { + if let Some(baton) = baton { + PushStatus::Ok { baton } + } else { + return Err(SyncError::NoBatonFromPush.into()); + } + }, "conflict" => PushStatus::Conflict, _ => return Err(SyncError::JsonValue(resp.clone()).into()), }; @@ -729,6 +747,7 @@ async fn try_push( }); } + let mut baton: Option = None; let generation = sync_ctx.durable_generation(); let start_frame_no = sync_ctx.durable_frame_num() + 1; let end_frame_no = max_frame_no; @@ -748,10 +767,12 @@ async fn try_push( // The server returns its maximum frame number. To avoid resending // frames the server already knows about, we need to update the // frame number to the one returned by the server. - let max_frame_no = sync_ctx - .push_frames(frames.freeze(), generation, frame_no, batch_size) + let (new_baton, max_frame_no) = sync_ctx + .push_frames(baton.clone(), frames.freeze(), generation, frame_no, batch_size) .await?; + baton = new_baton; + if max_frame_no > frame_no { frame_no = max_frame_no + 1; } else { diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs index 2417fa8158..cbe150da92 100644 --- a/libsql/src/sync/test.rs +++ b/libsql/src/sync/test.rs @@ -28,9 +28,9 @@ async fn test_sync_context_push_frame() { let mut sync_ctx = sync_ctx; // Push a frame and verify the response - let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); + let durable_frame = sync_ctx.push_frames(None, frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); - assert_eq!(durable_frame, 0); // First frame should return max_frame_no = 0 + assert_eq!(durable_frame.1, 0); // First frame should return max_frame_no = 0 // Verify internal state was updated assert_eq!(sync_ctx.durable_frame_num(), 0); @@ -56,9 +56,9 @@ async fn test_sync_context_with_auth() { let frame = Bytes::from("test frame with auth"); let mut sync_ctx = sync_ctx; - let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); + let durable_frame = sync_ctx.push_frames(None, frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); - assert_eq!(durable_frame, 0); + assert_eq!(durable_frame.1, 0); assert_eq!(server.frame_count(), 1); } @@ -82,9 +82,9 @@ async fn test_sync_context_multiple_frames() { // Push multiple frames and verify incrementing frame numbers for i in 0..3 { let frame = Bytes::from(format!("frame data {}", i)); - let durable_frame = sync_ctx.push_frames(frame, 1, i, 1).await.unwrap(); + let durable_frame = sync_ctx.push_frames(None, frame, 1, i, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); - assert_eq!(durable_frame, i); + assert_eq!(durable_frame.1, i); assert_eq!(sync_ctx.durable_frame_num(), i); assert_eq!(server.frame_count(), i + 1); } @@ -108,9 +108,9 @@ async fn test_sync_context_corrupted_metadata() { let mut sync_ctx = sync_ctx; let frame = Bytes::from("test frame data"); - let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); + let durable_frame = sync_ctx.push_frames(None, frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); - assert_eq!(durable_frame, 0); + assert_eq!(durable_frame.1, 0); assert_eq!(server.frame_count(), 1); // Update metadata path to use -info instead of .meta @@ -152,9 +152,9 @@ async fn test_sync_restarts_with_lower_max_frame_no() { let mut sync_ctx = sync_ctx; let frame = Bytes::from("test frame data"); - let durable_frame = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await.unwrap(); + let durable_frame = sync_ctx.push_frames(None, frame.clone(), 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); - assert_eq!(durable_frame, 0); + assert_eq!(durable_frame.1, 0); assert_eq!(server.frame_count(), 1); // Bump the durable frame num so that the next time we call the @@ -180,14 +180,14 @@ async fn test_sync_restarts_with_lower_max_frame_no() { // This push should fail because we are ahead of the server and thus should get an invalid // frame no error. sync_ctx - .push_frames(frame.clone(), 1, frame_no, 1) + .push_frames(None, frame.clone(), 1, frame_no, 1) .await .unwrap_err(); let frame_no = sync_ctx.durable_frame_num() + 1; // This then should work because when the last one failed it updated our state of the server // durable_frame_num and we should then start writing from there. - sync_ctx.push_frames(frame, 1, frame_no, 1).await.unwrap(); + sync_ctx.push_frames(None, frame, 1, frame_no, 1).await.unwrap(); } #[tokio::test] @@ -215,7 +215,7 @@ async fn test_sync_context_retry_on_error() { server.return_error.store(true, Ordering::SeqCst); // First attempt should fail but retry - let result = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await; + let result = sync_ctx.push_frames(None, frame.clone(), 1, 0, 1).await; assert!(result.is_err()); // Advance time to trigger retries faster @@ -228,9 +228,9 @@ async fn test_sync_context_retry_on_error() { server.return_error.store(false, Ordering::SeqCst); // Next attempt should succeed - let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); + let durable_frame = sync_ctx.push_frames(None, frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); - assert_eq!(durable_frame, 0); + assert_eq!(durable_frame.1, 0); assert_eq!(server.frame_count(), 1); } @@ -378,7 +378,8 @@ impl MockServer { let response = serde_json::json!({ "status": "ok", "generation": 1, - "max_frame_no": current_count + "max_frame_no": current_count, + "baton": "test_baton" }); Ok::<_, hyper::Error>(