From 337772d12dcdb6662314e53e12f75d0b2104292a Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 5 May 2026 07:18:50 -0700 Subject: [PATCH] fix(sqlite): stop actors on head fence mismatch --- Cargo.lock | 3 + .../errors/depot.head_fence_mismatch.json | 5 + .../packages/depot-client-embedded/Cargo.toml | 1 + .../packages/depot-client-embedded/src/lib.rs | 53 +- engine/packages/depot-client-types/src/lib.rs | 7 + engine/packages/depot-client/Cargo.toml | 1 + engine/packages/depot-client/src/database.rs | 42 +- engine/packages/depot-client/src/vfs.rs | 179 +++-- engine/packages/depot-client/src/worker.rs | 25 + .../packages/depot-client/tests/inline/vfs.rs | 163 ++++- .../depot-client/tests/inline/vfs_support.rs | 77 ++- engine/packages/depot/Cargo.toml | 1 + .../depot/src/conveyer/commit/apply.rs | 46 +- engine/packages/depot/src/conveyer/error.rs | 21 + engine/packages/depot/src/conveyer/read.rs | 44 +- .../depot/src/conveyer/types/pages.rs | 23 + .../packages/depot/tests/conveyer_commit.rs | 109 ++- engine/packages/depot/tests/conveyer_error.rs | 14 +- .../pegboard-envoy/src/ws_to_tunnel_task.rs | 47 +- engine/sdks/rust/envoy-protocol/src/lib.rs | 2 +- .../sdks/rust/envoy-protocol/src/versioned.rs | 534 ++++++++++++--- .../envoy-protocol/tests/remote_sql_compat.rs | 59 +- .../tests/stateless_sqlite_v3.rs | 16 +- engine/sdks/schemas/envoy-protocol/v5.bare | 642 ++++++++++++++++++ .../typescript/envoy-protocol/src/index.ts | 28 +- .../rivetkit-core/src/actor/sqlite.rs | 51 +- .../packages/rivetkit-core/tests/sqlite.rs | 138 ++++ 27 files changed, 2081 insertions(+), 250 deletions(-) create mode 100644 engine/artifacts/errors/depot.head_fence_mismatch.json create mode 100644 engine/sdks/schemas/envoy-protocol/v5.bare diff --git a/Cargo.lock b/Cargo.lock index dc61275d08..e4cabba8bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1810,6 +1810,7 @@ dependencies = [ "aws-config", "aws-sdk-s3", "base64 0.22.1", + "depot-client-types", "futures-util", "gasoline", "lazy_static", @@ -1856,6 +1857,7 @@ dependencies = [ "parking_lot", "rivet-config", "rivet-envoy-protocol", + "rivet-error", "rivet-pools", "rivet-test-deps", "scc", @@ -1876,6 +1878,7 @@ dependencies = [ "depot", "depot-client", "rivet-envoy-protocol", + "rivet-error", "tokio", ] diff --git a/engine/artifacts/errors/depot.head_fence_mismatch.json b/engine/artifacts/errors/depot.head_fence_mismatch.json new file mode 100644 index 0000000000..2a7a1032d0 --- /dev/null +++ b/engine/artifacts/errors/depot.head_fence_mismatch.json @@ -0,0 +1,5 @@ +{ + "code": "head_fence_mismatch", + "group": "depot", + "message": "SQLite head fence mismatch." +} \ No newline at end of file diff --git a/engine/packages/depot-client-embedded/Cargo.toml b/engine/packages/depot-client-embedded/Cargo.toml index 5fe50d156f..eb1355fb33 100644 --- a/engine/packages/depot-client-embedded/Cargo.toml +++ b/engine/packages/depot-client-embedded/Cargo.toml @@ -16,4 +16,5 @@ async-trait.workspace = true depot.workspace = true depot-client.workspace = true rivet-envoy-protocol.workspace = true +rivet-error.workspace = true tokio.workspace = true diff --git a/engine/packages/depot-client-embedded/src/lib.rs b/engine/packages/depot-client-embedded/src/lib.rs index 9fa56460b6..6b369af9e5 100644 --- a/engine/packages/depot-client-embedded/src/lib.rs +++ b/engine/packages/depot-client-embedded/src/lib.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; +use depot::error::SqliteStorageError; use depot_client::{ database::{NativeDatabaseHandle, open_database_from_transport}, vfs::{SqliteTransport, SqliteVfsMetrics}, @@ -48,22 +49,31 @@ impl SqliteTransport for EmbeddedDepotSqliteTransport { &self, request: protocol::SqliteGetPagesRequest, ) -> Result { - match self.db.get_pages(request.pgnos).await { - Ok(pages) => Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk( + match self + .db + .get_pages_with_options( + request.pgnos, + depot::types::GetPagesOptions { + expected_head_txid: request.expected_head_txid, + }, + ) + .await + { + Ok(result) => Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk( protocol::SqliteGetPagesOk { - pages: pages + pages: result + .pages .into_iter() .map(|page| protocol::SqliteFetchedPage { pgno: page.pgno, bytes: page.bytes, }) .collect(), + head_txid: Some(result.head_txid), }, )), Err(err) => Ok(protocol::SqliteGetPagesResponse::SqliteErrorResponse( - protocol::SqliteErrorResponse { - message: sqlite_error_reason(&err), - }, + sqlite_error_response(&err), )), } } @@ -74,7 +84,7 @@ impl SqliteTransport for EmbeddedDepotSqliteTransport { ) -> Result { match self .db - .commit( + .commit_with_options( request .dirty_pages .into_iter() @@ -85,15 +95,20 @@ impl SqliteTransport for EmbeddedDepotSqliteTransport { .collect(), request.db_size_pages, request.now_ms, + depot::types::CommitOptions { + expected_head_txid: request.expected_head_txid, + }, ) .await { - Ok(()) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk), - Err(err) => Ok(protocol::SqliteCommitResponse::SqliteErrorResponse( - protocol::SqliteErrorResponse { - message: sqlite_error_reason(&err), + Ok(result) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk( + protocol::SqliteCommitOk { + head_txid: Some(result.head_txid), }, )), + Err(err) => Ok(protocol::SqliteCommitResponse::SqliteErrorResponse( + sqlite_error_response(&err), + )), } } } @@ -104,3 +119,19 @@ fn sqlite_error_reason(err: &anyhow::Error) -> String { .collect::>() .join(": ") } + +fn sqlite_error_response(err: &anyhow::Error) -> protocol::SqliteErrorResponse { + let structured = depot_error(err) + .map(|err| rivet_error::RivetError::extract(&err.clone().build())) + .unwrap_or_else(|| rivet_error::RivetError::extract(err)); + protocol::SqliteErrorResponse { + group: structured.group().to_string(), + code: structured.code().to_string(), + message: sqlite_error_reason(err), + } +} + +fn depot_error(err: &anyhow::Error) -> Option<&SqliteStorageError> { + err.chain() + .find_map(|source| source.downcast_ref::()) +} diff --git a/engine/packages/depot-client-types/src/lib.rs b/engine/packages/depot-client-types/src/lib.rs index c5e3dab64e..33ad5de390 100644 --- a/engine/packages/depot-client-types/src/lib.rs +++ b/engine/packages/depot-client-types/src/lib.rs @@ -1,5 +1,12 @@ //! Shared SQLite execution types for local and remote depot client backends. +pub const HEAD_FENCE_MISMATCH_GROUP: &str = "depot"; +pub const HEAD_FENCE_MISMATCH_CODE: &str = "head_fence_mismatch"; + +pub fn is_head_fence_mismatch(group: &str, code: &str) -> bool { + group == HEAD_FENCE_MISMATCH_GROUP && code == HEAD_FENCE_MISMATCH_CODE +} + #[derive(Clone, Debug, PartialEq)] pub enum BindParam { Null, diff --git a/engine/packages/depot-client/Cargo.toml b/engine/packages/depot-client/Cargo.toml index 20001d259f..807f6bbed9 100644 --- a/engine/packages/depot-client/Cargo.toml +++ b/engine/packages/depot-client/Cargo.toml @@ -29,6 +29,7 @@ depot = { workspace = true, features = ["test-faults"] } futures-util.workspace = true gas.workspace = true rivet-config.workspace = true +rivet-error.workspace = true rivet-pools.workspace = true rivet-test-deps.workspace = true sha2.workspace = true diff --git a/engine/packages/depot-client/src/database.rs b/engine/packages/depot-client/src/database.rs index dd42f9a553..996784b6c6 100644 --- a/engine/packages/depot-client/src/database.rs +++ b/engine/packages/depot-client/src/database.rs @@ -10,7 +10,7 @@ use crate::{ SqliteVfsMetricsSnapshot, VfsConfig, VfsPreloadHintSnapshot, fetch_initial_pages_for_registration, }, - worker::SqliteWorkerHandle, + worker::{SqliteWorkerFatalError, SqliteWorkerHandle}, }; #[derive(Clone)] @@ -70,7 +70,8 @@ impl NativeDatabaseHandle { } pub async fn exec(&self, sql: String) -> Result { - self.worker.exec(sql).await + self.check_fatal_error()?; + self.map_worker_result(self.worker.exec(sql).await) } pub async fn query(&self, sql: String, params: Option>) -> Result { @@ -91,11 +92,15 @@ impl NativeDatabaseHandle { sql: String, params: Option>, ) -> Result { - self.worker.execute(sql, params).await + self.check_fatal_error()?; + self.map_worker_result(self.worker.execute(sql, params).await) } pub async fn close(&self) -> Result<()> { - self.worker.close().await + match self.worker.close().await { + Ok(()) => Ok(()), + Err(error) => Err(self.fatal_error().unwrap_or(error)), + } } pub async fn wait_for_worker_failure(&self) -> bool { @@ -106,6 +111,10 @@ impl NativeDatabaseHandle { self.vfs.take_last_error() } + pub fn clone_fatal_error(&self) -> Option { + self.vfs.clone_fatal_error() + } + pub fn snapshot_preload_hints(&self) -> VfsPreloadHintSnapshot { self.vfs.snapshot_preload_hints() } @@ -130,7 +139,30 @@ impl NativeDatabaseHandle { } async fn initialize(&self) -> Result<()> { - self.worker.wait_ready().await + self.map_worker_result(self.worker.wait_ready().await) + } + + fn check_fatal_error(&self) -> Result<()> { + if let Some(error) = self.fatal_error() { + return Err(error); + } + + Ok(()) + } + + fn map_worker_result(&self, result: Result) -> Result { + match result { + Ok(value) => { + self.check_fatal_error()?; + Ok(value) + } + Err(error) => Err(self.fatal_error().unwrap_or(error)), + } + } + + fn fatal_error(&self) -> Option { + self.clone_fatal_error() + .map(|message| SqliteWorkerFatalError::new(message).into()) } } diff --git a/engine/packages/depot-client/src/vfs.rs b/engine/packages/depot-client/src/vfs.rs index 31ae2d051a..a9ccc62ec2 100644 --- a/engine/packages/depot-client/src/vfs.rs +++ b/engine/packages/depot-client/src/vfs.rs @@ -12,6 +12,7 @@ use std::time::{Duration, Instant}; use anyhow::Result; use async_trait::async_trait; +use depot_client_types::is_head_fence_mismatch; use libsqlite3_sys::*; use moka::sync::Cache; use parking_lot::{Mutex, RwLock}; @@ -119,7 +120,6 @@ pub trait SqliteTransport: Send + Sync { } pub type SqliteTransportHandle = Arc; - fn sqlite_now_ms() -> Result { use std::time::{SystemTime, UNIX_EPOCH}; @@ -221,6 +221,21 @@ pub struct VfsPreloadHintSnapshot { pub ranges: Vec, } +#[derive(Debug, Clone, Default)] +pub(crate) struct InitialPages { + pub pages: Vec<(u32, Vec)>, + pub head_txid: Option, +} + +impl From)>> for InitialPages { + fn from(pages: Vec<(u32, Vec)>) -> Self { + Self { + pages, + head_txid: None, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum CommitPath { Fast, @@ -232,12 +247,14 @@ pub struct BufferedCommitRequest { pub actor_id: String, pub new_db_size_pages: u32, pub dirty_pages: Vec, + pub expected_head_txid: Option, } #[derive(Debug, Clone)] pub struct BufferedCommitOutcome { pub path: CommitPath, pub db_size_pages: u32, + pub head_txid: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -321,6 +338,7 @@ pub struct VfsContext { state: RwLock, aux_files: RwLock>>, last_error: Mutex>, + fatal_error: RwLock>, #[cfg(test)] fail_next_aux_open: Mutex>, #[cfg(test)] @@ -345,6 +363,7 @@ pub struct VfsContext { #[derive(Debug, Clone)] struct VfsState { db_size_pages: u32, + head_txid: Option, page_size: usize, page_cache: Cache>, committed_page_cache: Cache>, @@ -413,6 +432,7 @@ struct RecentPageAccess { #[derive(Debug)] enum GetPagesError { + Fatal(String), Other(String), } @@ -798,6 +818,7 @@ impl VfsState { let committed_page_cache = build_page_cache(config); let mut state = Self { db_size_pages: 1, + head_txid: None, page_size: DEFAULT_PAGE_SIZE, page_cache, committed_page_cache, @@ -939,11 +960,13 @@ impl VfsContext { transport: SqliteTransportHandle, config: VfsConfig, io_methods: sqlite3_io_methods, - initial_pages: Vec<(u32, Vec)>, + initial_pages: impl Into, metrics: Option>, ) -> std::result::Result { let mut state = VfsState::new(&config); - for (pgno, page) in initial_pages { + let initial_pages = initial_pages.into(); + state.head_txid = initial_pages.head_txid; + for (pgno, page) in initial_pages.pages { state.seed_page(&config, PageCacheInsertKind::Startup, pgno, page); } @@ -955,6 +978,7 @@ impl VfsContext { state: RwLock::new(state), aux_files: RwLock::new(BTreeMap::new()), last_error: Mutex::new(None), + fatal_error: RwLock::new(None), #[cfg(test)] fail_next_aux_open: Mutex::new(None), #[cfg(test)] @@ -988,6 +1012,10 @@ impl VfsContext { self.last_error.lock().clone() } + fn clone_fatal_error(&self) -> Option { + self.fatal_error.read().clone() + } + pub(crate) fn take_last_error(&self) -> Option { self.last_error.lock().take() } @@ -1118,6 +1146,14 @@ impl VfsContext { self.state.write().dead = true; } + fn mark_fatal(&self, message: String) { + self.mark_dead(message.clone()); + let mut fatal_error = self.fatal_error.write(); + if fatal_error.is_none() { + *fatal_error = Some(message); + } + } + pub(crate) fn snapshot_preload_hints(&self) -> VfsPreloadHintSnapshot { if !self.config.recent_page_hints { return VfsPreloadHintSnapshot::default(); @@ -1218,6 +1254,7 @@ impl VfsContext { prediction_budget, predicted_pgnos, db_size_pages, + expected_head_txid, ) = { let mut state = self.state.write(); for pgno in target_pgnos.iter().copied() { @@ -1260,6 +1297,7 @@ impl VfsContext { prediction_budget, predicted_pgnos, state.db_size_pages, + state.head_txid, ) }; @@ -1303,7 +1341,7 @@ impl VfsContext { actor_id: self.actor_id.clone(), pgnos: to_fetch.clone(), expected_generation: None, - expected_head_txid: None, + expected_head_txid, })) .map_err(|err| GetPagesError::Other(err.to_string()))?; if let Some(metrics) = &self.metrics { @@ -1312,6 +1350,10 @@ impl VfsContext { match response { protocol::SqliteGetPagesResponse::SqliteGetPagesOk(ok) => { + let response_head_txid = ok.head_txid; + if let Some(head_txid) = response_head_txid { + self.state.write().head_txid = Some(head_txid); + } let missing_pages = missing.iter().copied().collect::>(); let (page_cache, protected_page_cache) = { let state = self.state.read(); @@ -1339,6 +1381,7 @@ impl VfsContext { && missing_pages.contains(&fetched.pgno) && fetched.pgno == 1 { + self.state.write().head_txid = Some(0); Some(empty_db_page()) } else { fetched.bytes @@ -1409,6 +1452,9 @@ impl VfsContext { } return Ok(resolved); } + if is_head_fence_mismatch_response(&error) { + return Err(GetPagesError::Fatal(error.message)); + } Err(GetPagesError::Other(error.message)) } } @@ -1445,6 +1491,7 @@ impl VfsContext { BufferedCommitRequest { actor_id: self.actor_id.clone(), new_db_size_pages: state.db_size_pages, + expected_head_txid: state.head_txid, dirty_pages: state .write_buffer .dirty @@ -1473,7 +1520,7 @@ impl VfsContext { ?err, "sqlite flush commit failed" ); - mark_dead_for_non_fence_commit_error(self, &err); + handle_non_finalize_commit_error(self, &err); return Err(err); } }; @@ -1507,7 +1554,10 @@ impl VfsContext { } let state_update_start = Instant::now(); let mut state = self.state.write(); - state.db_size_pages = request.new_db_size_pages; + state.db_size_pages = outcome.db_size_pages; + state.head_txid = outcome + .head_txid + .or_else(|| state.head_txid.map(|head_txid| head_txid.saturating_add(1))); for dirty_page in &request.dirty_pages { state.cache_committed_page(&self.config, dirty_page.pgno, dirty_page.bytes.clone()); } @@ -1555,6 +1605,7 @@ impl VfsContext { BufferedCommitRequest { actor_id: self.actor_id.clone(), new_db_size_pages: state.db_size_pages, + expected_head_txid: state.head_txid, dirty_pages: state .write_buffer .dirty @@ -1580,7 +1631,7 @@ impl VfsContext { ?err, "sqlite atomic commit failed" ); - mark_dead_for_non_fence_commit_error(self, &err); + handle_non_finalize_commit_error(self, &err); return Err(err); } }; @@ -1615,7 +1666,10 @@ impl VfsContext { self.clear_last_error(); let state_update_start = Instant::now(); let mut state = self.state.write(); - state.db_size_pages = request.new_db_size_pages; + state.db_size_pages = outcome.db_size_pages; + state.head_txid = outcome + .head_txid + .or_else(|| state.head_txid.map(|head_txid| head_txid.saturating_add(1))); for dirty_page in &request.dirty_pages { state.cache_committed_page(&self.config, dirty_page.pgno, dirty_page.bytes.clone()); } @@ -1696,9 +1750,9 @@ fn assert_batch_atomic_probe(db: *mut sqlite3, vfs: &SqliteVfs) -> std::result:: Ok(()) } -fn mark_dead_for_non_fence_commit_error(ctx: &VfsContext, err: &CommitBufferError) { +fn handle_non_finalize_commit_error(ctx: &VfsContext, err: &CommitBufferError) { match err { - CommitBufferError::FenceMismatch(_) => {} + CommitBufferError::FenceMismatch(message) => ctx.mark_fatal(message.clone()), CommitBufferError::StageNotFound(stage_id) => { ctx.mark_dead(format!( "sqlite stage {stage_id} missing during commit finalize" @@ -1708,31 +1762,38 @@ fn mark_dead_for_non_fence_commit_error(ctx: &VfsContext, err: &CommitBufferErro } } -fn mark_dead_from_fence_commit_error(ctx: &VfsContext, err: &CommitBufferError) { +fn handle_finalize_fence_error(ctx: &VfsContext, err: &CommitBufferError) { if let CommitBufferError::FenceMismatch(reason) = err { - ctx.mark_dead(reason.clone()); + ctx.mark_fatal(reason.clone()); } } +#[cfg(test)] pub(crate) async fn fetch_initial_main_page_for_registration( transport: SqliteTransportHandle, actor_id: &str, ) -> std::result::Result>, String> { - fetch_initial_main_page(transport, actor_id.to_string()).await + fetch_initial_pages(transport, actor_id.to_string(), 1) + .await + .map(|pages| { + pages + .pages + .into_iter() + .find(|(pgno, _)| *pgno == 1) + .map(|(_, bytes)| bytes) + }) } pub(crate) async fn fetch_initial_pages_for_registration( transport: SqliteTransportHandle, actor_id: &str, config: &VfsConfig, -) -> std::result::Result)>, String> { +) -> std::result::Result { if !config.startup_preload_first_pages || !config.page_cache_mode.caches_startup_preloaded_pages() || config.startup_preload_max_bytes < DEFAULT_PAGE_SIZE { - return fetch_initial_main_page_for_registration(transport, actor_id) - .await - .map(|page| page.into_iter().map(|page| (1, page)).collect()); + return fetch_initial_pages(transport, actor_id.to_string(), 1).await; } let page_count_from_bytes = config.startup_preload_max_bytes / DEFAULT_PAGE_SIZE; @@ -1743,25 +1804,11 @@ pub(crate) async fn fetch_initial_pages_for_registration( fetch_initial_pages(transport, actor_id.to_string(), page_count).await } -async fn fetch_initial_main_page( - transport: SqliteTransportHandle, - actor_id: String, -) -> std::result::Result>, String> { - fetch_initial_pages(transport, actor_id, 1) - .await - .map(|pages| { - pages - .into_iter() - .find(|(pgno, _)| *pgno == 1) - .map(|(_, bytes)| bytes) - }) -} - async fn fetch_initial_pages( transport: SqliteTransportHandle, actor_id: String, page_count: u32, -) -> std::result::Result)>, String> { +) -> std::result::Result { let request_actor_id = actor_id.clone(); let response = transport .get_pages(protocol::SqliteGetPagesRequest { @@ -1773,11 +1820,14 @@ async fn fetch_initial_pages( .await; match response { - Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk(ok)) => Ok(ok - .pages - .into_iter() - .filter_map(|page| page.bytes.map(|bytes| (page.pgno, bytes))) - .collect()), + Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk(ok)) => Ok(InitialPages { + pages: ok + .pages + .into_iter() + .filter_map(|page| page.bytes.map(|bytes| (page.pgno, bytes))) + .collect(), + head_txid: ok.head_txid, + }), Ok(protocol::SqliteGetPagesResponse::SqliteErrorResponse(error)) => { if !is_initial_main_page_missing(&error.message) { return Err(format!( @@ -1790,7 +1840,10 @@ async fn fetch_initial_pages( error = %error.message, "sqlite initial page fetch did not find persisted data" ); - Ok(Vec::new()) + Ok(InitialPages { + pages: Vec::new(), + head_txid: Some(0), + }) } Err(err) => Err(format!("sqlite initial page fetch failed: {err}")), } @@ -1824,7 +1877,7 @@ async fn commit_buffered_pages( db_size_pages: request.new_db_size_pages, now_ms: sqlite_now_ms().map_err(|err| CommitBufferError::Other(err.to_string()))?, expected_generation: None, - expected_head_txid: None, + expected_head_txid: request.expected_head_txid, }; metrics.serialize_ns += serialize_start.elapsed().as_nanos() as u64; let transport_start = Instant::now(); @@ -1833,22 +1886,31 @@ async fn commit_buffered_pages( .await .map_err(|err| CommitBufferError::Other(err.to_string()))? { - protocol::SqliteCommitResponse::SqliteCommitOk => { + protocol::SqliteCommitResponse::SqliteCommitOk(ok) => { metrics.transport_ns += transport_start.elapsed().as_nanos() as u64; Ok(( BufferedCommitOutcome { path: CommitPath::Fast, db_size_pages: request.new_db_size_pages, + head_txid: ok.head_txid, }, metrics, )) } protocol::SqliteCommitResponse::SqliteErrorResponse(error) => { - Err(CommitBufferError::Other(error.message)) + if is_head_fence_mismatch_response(&error) { + Err(CommitBufferError::FenceMismatch(error.message)) + } else { + Err(CommitBufferError::Other(error.message)) + } } } } +fn is_head_fence_mismatch_response(error: &protocol::SqliteErrorResponse) -> bool { + is_head_fence_mismatch(&error.group, &error.code) +} + unsafe fn get_file(p: *mut sqlite3_file) -> &'static mut VfsFile { &mut *(p as *mut VfsFile) } @@ -2110,7 +2172,7 @@ unsafe extern "C" fn io_close(p_file: *mut sqlite3_file) -> c_int { Ok(()) => SQLITE_OK, Err(err) => { let ctx = &*file.ctx; - mark_dead_from_fence_commit_error(ctx, &err); + handle_finalize_fence_error(ctx, &err); SQLITE_IOERR } } @@ -2174,6 +2236,16 @@ unsafe extern "C" fn io_read( let resolved = match ctx.resolve_pages(&requested_pages, true) { Ok(pages) => pages, + Err(GetPagesError::Fatal(message)) => { + tracing::error!( + actor_id = %ctx.actor_id, + requested_pages = ?requested_pages, + error = %message, + "sqlite xRead hit fatal sqlite error" + ); + ctx.mark_fatal(message); + return SQLITE_IOERR_READ; + } Err(GetPagesError::Other(message)) => { tracing::error!( actor_id = %ctx.actor_id, @@ -2306,6 +2378,10 @@ unsafe extern "C" fn io_write( } else { match ctx.resolve_pages(&pages_to_resolve, false) { Ok(pages) => pages, + Err(GetPagesError::Fatal(message)) => { + ctx.mark_fatal(message); + return SQLITE_IOERR_WRITE; + } Err(GetPagesError::Other(message)) => { ctx.mark_dead(message); return SQLITE_IOERR_WRITE; @@ -2420,7 +2496,7 @@ unsafe extern "C" fn io_sync(p_file: *mut sqlite3_file, _flags: c_int) -> c_int ?err, "sqlite sync failed" ); - mark_dead_from_fence_commit_error(ctx, &err); + handle_finalize_fence_error(ctx, &err); SQLITE_IOERR_FSYNC } } @@ -2497,7 +2573,7 @@ unsafe extern "C" fn io_file_control( ?err, "sqlite atomic write file control failed" ); - mark_dead_from_fence_commit_error(ctx, &err); + handle_finalize_fence_error(ctx, &err); SQLITE_IOERR } }, @@ -2739,6 +2815,10 @@ impl SqliteVfs { self.ctx.clone_last_error() } + pub fn clone_fatal_error(&self) -> Option { + self.ctx.clone_fatal_error() + } + pub(crate) fn snapshot_preload_hints(&self) -> VfsPreloadHintSnapshot { self.ctx.snapshot_preload_hints() } @@ -2762,7 +2842,7 @@ impl SqliteVfs { actor_id, runtime, config, - Vec::new(), + InitialPages::default(), metrics, ) } @@ -2787,7 +2867,10 @@ impl SqliteVfs { actor_id, runtime, config, - initial_pages, + InitialPages { + pages: initial_pages, + head_txid: None, + }, metrics, ) } @@ -2798,7 +2881,7 @@ impl SqliteVfs { actor_id: String, runtime: Handle, config: VfsConfig, - initial_pages: Vec<(u32, Vec)>, + initial_pages: InitialPages, metrics: Option>, ) -> std::result::Result { let mut io_methods: sqlite3_io_methods = unsafe { std::mem::zeroed() }; @@ -2942,7 +3025,7 @@ impl Drop for NativeDatabase { return; } Err(err) => { - mark_dead_for_non_fence_commit_error(ctx, &err); + handle_non_finalize_commit_error(ctx, &err); tracing::warn!(?err, "failed to flush sqlite database before close"); } } diff --git a/engine/packages/depot-client/src/worker.rs b/engine/packages/depot-client/src/worker.rs index ad1e5fb94c..b86651afc8 100644 --- a/engine/packages/depot-client/src/worker.rs +++ b/engine/packages/depot-client/src/worker.rs @@ -504,6 +504,8 @@ fn worker_error_code(error: &anyhow::Error) -> &'static str { .is_some() { "overloaded" + } else if error.downcast_ref::().is_some() { + "fatal" } else if error.downcast_ref::().is_some() { "closing" } else if error.downcast_ref::().is_some() { @@ -562,6 +564,29 @@ impl fmt::Display for SqliteWorkerCloseTimeoutError { impl Error for SqliteWorkerCloseTimeoutError {} +#[derive(Debug, Clone)] +pub struct SqliteWorkerFatalError { + message: String, +} + +impl SqliteWorkerFatalError { + pub fn new(message: String) -> Self { + Self { message } + } + + pub fn message(&self) -> &str { + &self.message + } +} + +impl fmt::Display for SqliteWorkerFatalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "sqlite worker hit fatal storage error: {}", self.message) + } +} + +impl Error for SqliteWorkerFatalError {} + fn panic_message(payload: &Box) -> String { if let Some(message) = payload.downcast_ref::<&str>() { message.to_string() diff --git a/engine/packages/depot-client/tests/inline/vfs.rs b/engine/packages/depot-client/tests/inline/vfs.rs index 29c3409eb6..d1631501f7 100644 --- a/engine/packages/depot-client/tests/inline/vfs.rs +++ b/engine/packages/depot-client/tests/inline/vfs.rs @@ -26,6 +26,7 @@ use crate::optimization_flags::{ }; use crate::query::{BindParam, ColumnValue}; use crate::vfs::SqliteVfsMetrics; +use crate::worker::SqliteWorkerFatalError; use super::*; @@ -108,6 +109,7 @@ impl SqliteTransport for RecordingInitialPagesTransport { bytes: Some(vec![pgno as u8; DEFAULT_PAGE_SIZE]), }) .collect(), + head_txid: Some(0), }, )) } @@ -130,6 +132,8 @@ impl SqliteTransport for MissingDbTransport { ) -> anyhow::Result { Ok(protocol::SqliteGetPagesResponse::SqliteErrorResponse( protocol::SqliteErrorResponse { + group: "depot".to_string(), + code: "database_not_found".to_string(), message: "sqlite database was not found in this bucket branch".to_string(), }, )) @@ -164,7 +168,11 @@ fn startup_initial_pages_do_not_require_preload_hints_on_open() { )) .expect("initial pages should load"); - let loaded_pgnos = pages.iter().map(|(pgno, _)| *pgno).collect::>(); + let loaded_pgnos = pages + .pages + .iter() + .map(|(pgno, _)| *pgno) + .collect::>(); assert_eq!(*transport.requested_pgnos.lock(), vec![1, 2, 3, 4]); assert_eq!(loaded_pgnos, vec![1, 2, 3, 4]); } @@ -387,28 +395,44 @@ fn open_worker_handle_with_metrics( harness: &DirectEngineHarness, metrics: Option>, ) -> crate::database::NativeDatabaseHandle { + open_worker_handle_with_vfs(runtime, harness, metrics).1 +} + +fn open_worker_handle_with_vfs( + runtime: &tokio::runtime::Runtime, + harness: &DirectEngineHarness, + metrics: Option>, +) -> (Arc, crate::database::NativeDatabaseHandle) { let engine = runtime.block_on(harness.open_engine()); let transport = Arc::new(DirectDepotTransport::new(engine)); - let initial_main_page = runtime - .block_on(fetch_initial_main_page_for_registration( + let config = VfsConfig::default(); + let initial_pages = runtime + .block_on(fetch_initial_pages_for_registration( transport.clone(), &harness.actor_id, + &config, )) - .expect("initial main page preload should succeed"); + .expect("initial pages preload should succeed"); let vfs = Arc::new( - SqliteVfs::register_with_transport_and_initial_page( + SqliteVfs::register_with_transport_and_initial_pages( &next_test_name("sqlite-worker-vfs"), transport, harness.actor_id.clone(), runtime.handle().clone(), - VfsConfig::default(), - initial_main_page, + config, + initial_pages, None, ) .expect("worker vfs should register"), ); - crate::database::NativeDatabaseHandle::new_with_metrics(vfs, harness.actor_id.clone(), metrics) - .expect("worker handle should start") + let db = + crate::database::NativeDatabaseHandle::new_with_metrics( + vfs.clone(), + harness.actor_id.clone(), + metrics, + ) + .expect("worker handle should start"); + (vfs, db) } #[derive(Default)] @@ -1177,6 +1201,7 @@ fn strict_direct_reopen_counts_cold_tier_get_for_cold_covered_page() { let page = runtime .block_on(engine.get_pages(&harness.actor_id, &[1])) .expect("page 1 should be fetched from depot") + .pages .into_iter() .find(|page| page.pgno == 1) .and_then(|page| page.bytes) @@ -1220,6 +1245,7 @@ fn strict_direct_warmed_shard_cache_does_not_count_as_cold_tier_evidence() { let page = runtime .block_on(engine.get_pages(&harness.actor_id, &[1])) .expect("page 1 should be fetched from depot") + .pages .into_iter() .find(|page| page.pgno == 1) .and_then(|page| page.bytes) @@ -1235,6 +1261,7 @@ fn strict_direct_warmed_shard_cache_does_not_count_as_cold_tier_evidence() { let cold_page = runtime .block_on(engine.get_pages(&harness.actor_id, &[1])) .expect("strict direct read should hit cold tier") + .pages .into_iter() .find(|page| page.pgno == 1) .and_then(|page| page.bytes) @@ -3492,6 +3519,123 @@ fn resolve_pages_surfaces_read_path_error_response() { )); } +#[test] +fn resolve_pages_sends_known_head_txid_as_read_fence() { + let runtime = direct_runtime(); + let harness = DirectEngineHarness::new(); + let engine = runtime.block_on(harness.open_engine()); + runtime + .block_on(engine.apply_commit( + &harness.actor_id, + vec![depot::types::DirtyPage { + pgno: 2, + bytes: vec![0x44; DEFAULT_PAGE_SIZE], + }], + 2, + )) + .expect("mirror seed should succeed"); + let transport = Arc::new(DirectMirrorTransport::new(engine)); + let hooks = transport.direct_hooks(); + let ctx = VfsContext::new( + harness.actor_id.clone(), + runtime.handle().clone(), + transport, + VfsConfig::default(), + unsafe { std::mem::zeroed() }, + Vec::new(), + None, + ) + .expect("vfs context should build"); + { + let mut state = ctx.state.write(); + state.db_size_pages = 2; + state.head_txid = Some(7); + } + + ctx.resolve_pages(&[2], false) + .expect("read should fetch from transport"); + let requests = hooks.get_pages_requests(); + let request = requests.last().expect("get_pages request should be recorded"); + assert_eq!(request.expected_head_txid, Some(7)); +} + +#[test] +fn fatal_sqlite_error_reports_first_message_through_worker_boundary() { + let runtime = direct_runtime(); + let harness = DirectEngineHarness::new(); + let (vfs, db) = open_worker_handle_with_vfs(&runtime, &harness, None); + let ctx = vfs.ctx(); + runtime + .block_on(db.exec("SELECT 1;".to_string())) + .expect("worker should initialize before fatal state is injected"); + + ctx.mark_fatal("first fatal sqlite error".to_string()); + ctx.mark_fatal("second fatal sqlite error".to_string()); + + assert!(ctx.is_dead()); + assert_eq!(ctx.clone_fatal_error().as_deref(), Some("first fatal sqlite error")); + + let err = runtime + .block_on(db.exec("SELECT 1;".to_string())) + .expect_err("fatal VFS state should fail before SQL execution"); + let fatal = err + .downcast_ref::() + .expect("error should be typed as fatal worker error"); + assert_eq!(fatal.message(), "first fatal sqlite error"); + + runtime + .block_on(db.close()) + .expect("worker should still close cleanly after fatal state"); +} + +#[test] +fn head_fence_ioerr_maps_to_fatal_worker_error_and_future_operations_fail_closed() { + let runtime = direct_runtime(); + let harness = DirectEngineHarness::new(); + let stale = open_worker_handle(&runtime, &harness); + runtime + .block_on(stale.exec("SELECT 1;".to_string())) + .expect("stale worker should initialize before becoming stale"); + let writer = open_worker_handle(&runtime, &harness); + + runtime + .block_on(writer.exec("CREATE TABLE writer_first (value INTEGER);".to_string())) + .expect("writer should advance depot head"); + + let err = runtime + .block_on(stale.exec("CREATE TABLE stale_writer (value INTEGER);".to_string())) + .expect_err("stale writer should hit a fatal head fence mismatch"); + let fatal = err + .downcast_ref::() + .expect("head fence mismatch should be surfaced as a fatal worker error"); + assert!( + fatal.message().contains("head fence mismatch"), + "unexpected fatal message: {}", + fatal.message() + ); + assert!( + stale + .clone_fatal_error() + .is_some_and(|message| message.contains("head fence mismatch")), + "VFS should retain the fatal head fence error" + ); + + let err = runtime + .block_on(stale.exec("SELECT 1;".to_string())) + .expect_err("future operations should fail closed from stored fatal state"); + assert!( + err.downcast_ref::().is_some(), + "future error should remain typed as fatal worker error: {err:#}" + ); + + runtime + .block_on(stale.close()) + .expect("stale worker should still close cleanly"); + runtime + .block_on(writer.close()) + .expect("writer worker should close cleanly"); +} + #[test] fn commit_buffered_pages_uses_fast_path() { let runtime = direct_runtime(); @@ -3506,6 +3650,7 @@ fn commit_buffered_pages_uses_fast_path() { BufferedCommitRequest { actor_id: harness.actor_id.clone(), new_db_size_pages: 1, + expected_head_txid: None, dirty_pages: vec![protocol::SqliteDirtyPage { pgno: 1, bytes: empty_db_page(), diff --git a/engine/packages/depot-client/tests/inline/vfs_support.rs b/engine/packages/depot-client/tests/inline/vfs_support.rs index 65182e9696..c4e5ecf522 100644 --- a/engine/packages/depot-client/tests/inline/vfs_support.rs +++ b/engine/packages/depot-client/tests/inline/vfs_support.rs @@ -10,6 +10,7 @@ use async_trait::async_trait; use depot::{ cold_tier::{ColdTier, ColdTierObjectMetadata}, conveyer::{Db, db::CompactionSignaler}, + error::SqliteStorageError, fault::DepotFaultController, keys::{ SHARD_SIZE, branch_compaction_cold_shard_key, branch_compaction_root_key, @@ -333,14 +334,24 @@ impl DirectStorage { &self, actor_id: &str, pgnos: &[u32], - ) -> anyhow::Result> { + ) -> anyhow::Result { + self.get_pages_with_options(actor_id, pgnos, depot::types::GetPagesOptions::default()) + .await + } + + pub(crate) async fn get_pages_with_options( + &self, + actor_id: &str, + pgnos: &[u32], + options: depot::types::GetPagesOptions, + ) -> anyhow::Result { if let Some(message) = self.hooks.take_get_pages_error() { return Err(anyhow::anyhow!(message)); } let actor_db = self.actor_db(actor_id.to_string()).await; self.counters.depot_get_pages.fetch_add(1, Ordering::SeqCst); - actor_db.get_pages(pgnos.to_vec()).await + actor_db.get_pages_with_options(pgnos.to_vec(), options).await } pub(crate) async fn read_mirror( @@ -411,11 +422,23 @@ impl SqliteTransport for DirectDepotTransport { &self, request: protocol::SqliteGetPagesRequest, ) -> Result { + self.storage.hooks.record_get_pages_request(request.clone()); let pgnos = request.pgnos.clone(); - match self.storage.get_pages(&request.actor_id, &pgnos).await { - Ok(pages) => Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk( + match self + .storage + .get_pages_with_options( + &request.actor_id, + &pgnos, + depot::types::GetPagesOptions { + expected_head_txid: request.expected_head_txid, + }, + ) + .await + { + Ok(result) => Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk( protocol::SqliteGetPagesOk { - pages: pages.into_iter().map(protocol_fetched_page).collect(), + pages: result.pages.into_iter().map(protocol_fetched_page).collect(), + head_txid: Some(result.head_txid), }, )), Err(err) => Ok(protocol::SqliteGetPagesResponse::SqliteErrorResponse( @@ -441,10 +464,21 @@ impl SqliteTransport for DirectDepotTransport { .collect::>(); let actor_db = self.storage.actor_db(actor_id).await; match actor_db - .commit(dirty_pages, request.db_size_pages, request.now_ms) + .commit_with_options( + dirty_pages, + request.db_size_pages, + request.now_ms, + depot::types::CommitOptions { + expected_head_txid: request.expected_head_txid, + }, + ) .await { - Ok(_) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk), + Ok(result) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk( + protocol::SqliteCommitOk { + head_txid: Some(result.head_txid), + }, + )), Err(err) => Ok(protocol::SqliteCommitResponse::SqliteErrorResponse( sqlite_error_response(&err), )), @@ -472,6 +506,7 @@ impl SqliteTransport for DirectMirrorTransport { &self, request: protocol::SqliteGetPagesRequest, ) -> Result { + self.storage.hooks.record_get_pages_request(request.clone()); if let Some(message) = self.storage.hooks.take_get_pages_error() { return Err(anyhow::anyhow!(message)); } @@ -483,6 +518,7 @@ impl SqliteTransport for DirectMirrorTransport { Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk( protocol::SqliteGetPagesOk { pages: pages.into_iter().map(protocol_fetched_page).collect(), + head_txid: None, }, )) } @@ -507,7 +543,11 @@ impl SqliteTransport for DirectMirrorTransport { .apply_commit(&actor_id, dirty_pages, request.db_size_pages) .await { - Ok(()) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk), + Ok(()) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk( + protocol::SqliteCommitOk { + head_txid: None, + }, + )), Err(err) => Ok(protocol::SqliteCommitResponse::SqliteErrorResponse( sqlite_error_response(&err), )), @@ -562,6 +602,7 @@ pub(crate) struct DirectTransportHooks { fail_next_get_pages: Mutex>, hang_next_commit: Mutex, pause_next_commit: Mutex>, + get_pages_requests: Mutex>, commit_requests: Mutex>, } @@ -584,6 +625,16 @@ impl DirectTransportHooks { self.commit_requests.lock() } + pub(crate) fn get_pages_requests( + &self, + ) -> parking_lot::MutexGuard<'_, Vec> { + self.get_pages_requests.lock() + } + + pub(crate) fn record_get_pages_request(&self, req: protocol::SqliteGetPagesRequest) { + self.get_pages_requests.lock().push(req); + } + pub(crate) fn record_commit_request(&self, req: protocol::SqliteCommitRequest) { self.commit_requests.lock().push(req); } @@ -684,7 +735,17 @@ fn sqlite_error_reason(err: &anyhow::Error) -> String { } pub(crate) fn sqlite_error_response(err: &anyhow::Error) -> protocol::SqliteErrorResponse { + let structured = depot_error(err) + .map(|err| rivet_error::RivetError::extract(&err.clone().build())) + .unwrap_or_else(|| rivet_error::RivetError::extract(err)); protocol::SqliteErrorResponse { + group: structured.group().to_string(), + code: structured.code().to_string(), message: sqlite_error_reason(err), } } + +fn depot_error(err: &anyhow::Error) -> Option<&SqliteStorageError> { + err.chain() + .find_map(|source| source.downcast_ref::()) +} diff --git a/engine/packages/depot/Cargo.toml b/engine/packages/depot/Cargo.toml index 7416aa8cc4..d1bfb96f32 100644 --- a/engine/packages/depot/Cargo.toml +++ b/engine/packages/depot/Cargo.toml @@ -16,6 +16,7 @@ async-trait.workspace = true aws-config = "1" aws-sdk-s3 = "1" base64.workspace = true +depot-client-types.workspace = true futures-util.workspace = true gas.workspace = true lazy_static.workspace = true diff --git a/engine/packages/depot/src/conveyer/commit/apply.rs b/engine/packages/depot/src/conveyer/commit/apply.rs index a4132023af..14814b0cc3 100644 --- a/engine/packages/depot/src/conveyer/commit/apply.rs +++ b/engine/packages/depot/src/conveyer/commit/apply.rs @@ -22,8 +22,9 @@ use crate::{ page_index::DeltaPageIndex, quota, types::{ - BranchState, CommitRow, DBHead, DatabaseBranchId, DirtyPage, decode_compaction_root, - decode_database_branch_record, decode_db_head, encode_commit_row, encode_db_head, + BranchState, CommitOptions, CommitResult, CommitRow, DBHead, DatabaseBranchId, + DirtyPage, decode_compaction_root, decode_database_branch_record, decode_db_head, + encode_commit_row, encode_db_head, }, udb, }, @@ -47,6 +48,23 @@ impl Db { db_size_pages: u32, now_ms: i64, ) -> Result<()> { + self.commit_with_options( + dirty_pages, + db_size_pages, + now_ms, + CommitOptions::default(), + ) + .await + .map(|_| ()) + } + + pub async fn commit_with_options( + &self, + dirty_pages: Vec, + db_size_pages: u32, + now_ms: i64, + options: CommitOptions, + ) -> Result { validate_dirty_pages(&dirty_pages)?; #[cfg(feature = "test-faults")] maybe_fire_commit_fault( @@ -82,6 +100,7 @@ impl Db { let database_id = self.database_id.clone(); let bucket_id = self.sqlite_bucket_id(); let dirty_pages_for_tx = dirty_pages.clone(); + let expected_head_txid = options.expected_head_txid; #[cfg(feature = "test-faults")] let fault_controller = self.fault_controller.clone(); @@ -91,6 +110,7 @@ impl Db { let database_id = database_id.clone(); let bucket_id = bucket_id; let dirty_pages = dirty_pages_for_tx.clone(); + let expected_head_txid = expected_head_txid; let cached_ancestry = cached_ancestry.clone(); let cached_access_bucket = cached_access_bucket; let last_deltas_available_at_ms = last_deltas_available_at_ms; @@ -160,6 +180,23 @@ impl Db { .map(|bytes| decode_db_head(bytes.as_slice())) .transpose() .context("decode current sqlite db head")?; + let actual_head_txid = previous_head.as_ref().map_or(0, |head| head.head_txid); + if let Some(expected_head_txid) = expected_head_txid { + if expected_head_txid != actual_head_txid { + tracing::error!( + %database_id, + ?branch_id, + expected_head_txid, + actual_head_txid, + "sqlite head fence mismatch; this indicates multiple actor instances are writing the same sqlite database in parallel, which is incorrect actor lifecycle behavior" + ); + return Err(SqliteStorageError::HeadFenceMismatch { + expected_head_txid, + actual_head_txid, + } + .into()); + } + } #[cfg(feature = "test-faults")] maybe_fire_commit_fault( &fault_controller, @@ -483,7 +520,10 @@ impl Db { self.publish_deltas_available_if_needed(result.deltas_available, result.branch_id) .await?; - Ok(()) + Ok(CommitResult { + head_txid: result.txid, + db_size_pages, + }) } async fn publish_deltas_available_if_needed( diff --git a/engine/packages/depot/src/conveyer/error.rs b/engine/packages/depot/src/conveyer/error.rs index 837c178265..a0e40ca2ab 100644 --- a/engine/packages/depot/src/conveyer/error.rs +++ b/engine/packages/depot/src/conveyer/error.rs @@ -2,6 +2,10 @@ use rivet_error::RivetError; use serde::Serialize; use std::fmt; +pub use depot_client_types::{ + HEAD_FENCE_MISMATCH_CODE, HEAD_FENCE_MISMATCH_GROUP, is_head_fence_mismatch, +}; + #[derive(Debug, Clone, PartialEq, Eq, RivetError)] #[error("depot")] pub enum SqliteStorageError { @@ -30,6 +34,16 @@ pub enum SqliteStorageError { max_size_bytes: u64, }, + #[error( + "head_fence_mismatch", + "SQLite head fence mismatch.", + "SQLite head fence mismatch. Expected head txid {expected_head_txid}, but current head txid is {actual_head_txid}." + )] + HeadFenceMismatch { + expected_head_txid: u64, + actual_head_txid: u64, + }, + #[error( "quota_exceeded", "Not enough space left in Depot.", @@ -135,6 +149,13 @@ impl fmt::Display for SqliteStorageError { f, "CommitTooLarge: raw dirty pages were {actual_size_bytes} bytes, limit is {max_size_bytes} bytes" ), + SqliteStorageError::HeadFenceMismatch { + expected_head_txid, + actual_head_txid, + } => write!( + f, + "sqlite head fence mismatch: expected head txid {expected_head_txid}, current head txid {actual_head_txid}" + ), SqliteStorageError::SqliteStorageQuotaExceeded { remaining_bytes, payload_size, diff --git a/engine/packages/depot/src/conveyer/read.rs b/engine/packages/depot/src/conveyer/read.rs index f27fa2da1a..10d0784e44 100644 --- a/engine/packages/depot/src/conveyer/read.rs +++ b/engine/packages/depot/src/conveyer/read.rs @@ -25,7 +25,7 @@ use crate::conveyer::{ ltx::{DecodedLtx, decode_ltx_v3}, metrics, page_index::DeltaPageIndex, - types::{DatabaseBranchId, FetchedPage}, + types::{DatabaseBranchId, FetchedPage, GetPagesOptions, GetPagesResult}, }; use self::{ @@ -38,6 +38,21 @@ use self::{ impl Db { pub async fn get_pages(&self, pgnos: Vec) -> Result> { + self.get_pages_with_metadata(pgnos) + .await + .map(|result| result.pages) + } + + pub async fn get_pages_with_metadata(&self, pgnos: Vec) -> Result { + self.get_pages_with_options(pgnos, GetPagesOptions::default()) + .await + } + + pub async fn get_pages_with_options( + &self, + pgnos: Vec, + options: GetPagesOptions, + ) -> Result { let node_id = self.node_id.to_string(); let labels = &[node_id.as_str()]; let _timer = metrics::SQLITE_PUMP_GET_PAGES_DURATION @@ -65,6 +80,7 @@ impl Db { let bucket_id = self.sqlite_bucket_id(); let pgnos_for_tx = pgnos.clone(); let now_ms = cache::now_ms()?; + let expected_head_txid = options.expected_head_txid; #[cfg(feature = "test-faults")] let fault_controller = self.fault_controller.clone(); let tx_result = self @@ -76,6 +92,7 @@ impl Db { let cached_pidx = cached_pidx.clone(); let cached_ancestry = cached_ancestry.clone(); let cached_access_bucket = cached_access_bucket; + let expected_head_txid = expected_head_txid; #[cfg(feature = "test-faults")] let fault_controller = fault_controller.clone(); @@ -114,6 +131,22 @@ impl Db { None, ) .await?; + if let Some(expected_head_txid) = expected_head_txid { + if expected_head_txid != head.head_txid { + tracing::error!( + %database_id, + branch_id = ?scope.branch_id(), + expected_head_txid, + actual_head_txid = head.head_txid, + "sqlite head fence mismatch while reading; this indicates multiple actor instances are accessing the same sqlite database in parallel, which is incorrect actor lifecycle behavior" + ); + return Err(SqliteStorageError::HeadFenceMismatch { + expected_head_txid, + actual_head_txid: head.head_txid, + } + .into()); + } + } let pgnos_in_range = pgnos .into_iter() @@ -125,6 +158,7 @@ impl Db { branch_id, branch_ancestry: scope.branch_ancestry(), access_bucket: None, + head_txid: head.head_txid, db_size_pages: head.db_size_pages, loaded_pidx_rows: None, page_sources: BTreeMap::new(), @@ -376,6 +410,7 @@ impl Db { branch_id, branch_ancestry: scope.branch_ancestry(), access_bucket, + head_txid: head.head_txid, db_size_pages: head.db_size_pages, loaded_pidx_rows, page_sources, @@ -549,7 +584,11 @@ impl Db { self.shard_cache_fill.enqueue_many(shard_cache_fill_jobs); - Ok(pages) + Ok(GetPagesResult { + pages, + head_txid: tx_result.head_txid, + db_size_pages: tx_result.db_size_pages, + }) } } @@ -557,6 +596,7 @@ struct GetPagesTxResult { branch_id: DatabaseBranchId, branch_ancestry: BranchAncestry, access_bucket: Option, + head_txid: u64, db_size_pages: u32, loaded_pidx_rows: Option>, page_sources: BTreeMap>, diff --git a/engine/packages/depot/src/conveyer/types/pages.rs b/engine/packages/depot/src/conveyer/types/pages.rs index 13599cc332..e244c53bde 100644 --- a/engine/packages/depot/src/conveyer/types/pages.rs +++ b/engine/packages/depot/src/conveyer/types/pages.rs @@ -11,3 +11,26 @@ pub struct FetchedPage { pub pgno: u32, pub bytes: Option>, } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct GetPagesResult { + pub pages: Vec, + pub head_txid: u64, + pub db_size_pages: u32, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct GetPagesOptions { + pub expected_head_txid: Option, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct CommitOptions { + pub expected_head_txid: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct CommitResult { + pub head_txid: u64, + pub db_size_pages: u32, +} diff --git a/engine/packages/depot/tests/conveyer_commit.rs b/engine/packages/depot/tests/conveyer_commit.rs index 5d396701ca..e91f223829 100644 --- a/engine/packages/depot/tests/conveyer_commit.rs +++ b/engine/packages/depot/tests/conveyer_commit.rs @@ -23,11 +23,11 @@ use depot::{ ltx::{LtxHeader, encode_ltx_v3}, quota::{self, SQLITE_MAX_STORAGE_BYTES}, types::{ - BucketId, CompactionRoot, DBHead, DatabaseBranchId, DirtyPage, FetchedPage, MetaCompact, - SqliteCmpDirty, decode_bucket_branch_record, decode_bucket_pointer, decode_commit_row, - decode_database_branch_record, decode_database_pointer, decode_db_head, - decode_sqlite_cmp_dirty, encode_compaction_root, encode_db_head, encode_meta_compact, - encode_sqlite_cmp_dirty, + BucketId, CommitOptions, CompactionRoot, DBHead, DatabaseBranchId, DirtyPage, FetchedPage, + GetPagesOptions, MetaCompact, SqliteCmpDirty, decode_bucket_branch_record, + decode_bucket_pointer, decode_commit_row, decode_database_branch_record, + decode_database_pointer, decode_db_head, decode_sqlite_cmp_dirty, encode_compaction_root, + encode_db_head, encode_meta_compact, encode_sqlite_cmp_dirty, }, workflows::compaction::DeltasAvailable, }; @@ -276,6 +276,105 @@ async fn commit_lazily_initializes_meta_on_first_write() -> Result<()> { }) } +#[tokio::test] +async fn commit_head_fence_rejects_stale_writer() -> Result<()> { + commit_matrix!("depot-commit-head-fence-stale", |ctx, db, database_db| { + let _ = &db; + let first = database_db + .commit_with_options( + vec![page(1, 0x11)], + 1, + 1_000, + CommitOptions { + expected_head_txid: Some(0), + }, + ) + .await?; + assert_eq!(first.head_txid, 1); + + let err = database_db + .commit_with_options( + vec![page(1, 0x22)], + 1, + 1_001, + CommitOptions { + expected_head_txid: Some(0), + }, + ) + .await + .expect_err("stale writer should be rejected"); + assert!( + err.chain() + .any(|source| source.to_string().contains("head fence mismatch")), + "unexpected error: {err:#}" + ); + assert_eq!( + database_db.get_pages(vec![1]).await?, + vec![fetched_page(1, 0x11)] + ); + + Ok(()) + }) +} + +#[tokio::test] +async fn get_pages_head_fence_rejects_stale_reader() -> Result<()> { + commit_matrix!("depot-get-pages-head-fence-stale", |ctx, db, database_db| { + let _ = &db; + let first = database_db + .commit_with_options( + vec![page(1, 0x11)], + 1, + 1_000, + CommitOptions { + expected_head_txid: Some(0), + }, + ) + .await?; + assert_eq!(first.head_txid, 1); + + let read = database_db + .get_pages_with_options( + vec![1], + GetPagesOptions { + expected_head_txid: Some(1), + }, + ) + .await?; + assert_eq!(read.head_txid, 1); + assert_eq!(read.pages, vec![fetched_page(1, 0x11)]); + + let second = database_db + .commit_with_options( + vec![page(1, 0x22)], + 1, + 1_001, + CommitOptions { + expected_head_txid: Some(1), + }, + ) + .await?; + assert_eq!(second.head_txid, 2); + + let err = database_db + .get_pages_with_options( + vec![1], + GetPagesOptions { + expected_head_txid: Some(1), + }, + ) + .await + .expect_err("stale reader should be rejected"); + assert!( + err.chain() + .any(|source| source.to_string().contains("head fence mismatch")), + "unexpected error: {err:#}" + ); + + Ok(()) + }) +} + #[tokio::test] async fn commit_rejects_invalid_dirty_pages_before_storage_writes() -> Result<()> { commit_matrix!("depot-commit-invalid-dirty", |ctx, db, database_db| { diff --git a/engine/packages/depot/tests/conveyer_error.rs b/engine/packages/depot/tests/conveyer_error.rs index f757b3daa3..d3224cbb41 100644 --- a/engine/packages/depot/tests/conveyer_error.rs +++ b/engine/packages/depot/tests/conveyer_error.rs @@ -1,4 +1,4 @@ -use depot::error::SqliteStorageError; +use depot::error::{HEAD_FENCE_MISMATCH_CODE, HEAD_FENCE_MISMATCH_GROUP, SqliteStorageError}; #[test] fn pitr_errors_are_typed_and_downcastable() { @@ -10,6 +10,18 @@ fn pitr_errors_are_typed_and_downcastable() { assert_eq!(storage_err, &SqliteStorageError::ForkOutOfRetention); } +#[test] +fn head_fence_mismatch_constants_match_rivet_error_schema() { + let err = SqliteStorageError::HeadFenceMismatch { + expected_head_txid: 1, + actual_head_txid: 2, + }; + let rivet_err = rivet_error::RivetError::extract(&err.build()); + + assert_eq!(rivet_err.group(), HEAD_FENCE_MISMATCH_GROUP); + assert_eq!(rivet_err.code(), HEAD_FENCE_MISMATCH_CODE); +} + #[test] fn pitr_errors_have_rivet_error_codes() { let cases = [ diff --git a/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs index 45e7bb65f3..577baa91bf 100644 --- a/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs @@ -2,7 +2,7 @@ use anyhow::{Context, bail}; use bytes::Bytes; use depot::{ conveyer::Db, - error::SqliteStorageError, + error::{is_head_fence_mismatch, SqliteStorageError}, workflows::compaction::{ DATABASE_BRANCH_ID_TAG, DbManagerInput, DeltasAvailable, database_branch_tag_value, }, @@ -693,19 +693,28 @@ async fn handle_sqlite_get_pages( validate_sqlite_actor(ctx, conn, &request.actor_id).await?; let actor_db = actor_db(ctx, conn, request.actor_id.clone()).await?; - let pages = actor_db.get_pages(request.pgnos).await?; - Ok(sqlite_get_pages_ok(pages).await?) + let result = actor_db + .get_pages_with_options( + request.pgnos, + depot::types::GetPagesOptions { + expected_head_txid: request.expected_head_txid, + }, + ) + .await?; + Ok(sqlite_get_pages_ok(result).await?) } async fn sqlite_get_pages_ok( - pages: Vec, + result: depot::types::GetPagesResult, ) -> Result { Ok(protocol::SqliteGetPagesResponse::SqliteGetPagesOk( protocol::SqliteGetPagesOk { - pages: pages + pages: result + .pages .into_iter() .map(sqlite_runtime::protocol_sqlite_conveyer_fetched_page) .collect(), + head_txid: Some(result.head_txid), }, )) } @@ -724,7 +733,7 @@ async fn handle_sqlite_commit( let actor_id = request.actor_id.clone(); let actor_db = actor_db(ctx, conn, actor_id.clone()).await?; let engine_result = actor_db - .commit( + .commit_with_options( request .dirty_pages .into_iter() @@ -732,17 +741,26 @@ async fn handle_sqlite_commit( .collect(), request.db_size_pages, request.now_ms, + depot::types::CommitOptions { + expected_head_txid: request.expected_head_txid, + }, ) .await; let response_build_start = Instant::now(); let response = match engine_result { - Ok(()) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk), + Ok(result) => Ok(protocol::SqliteCommitResponse::SqliteCommitOk( + protocol::SqliteCommitOk { + head_txid: Some(result.head_txid), + }, + )), Err(err) => match depot_error(&err) { Some(SqliteStorageError::CommitTooLarge { actual_size_bytes, max_size_bytes, }) => Ok(protocol::SqliteCommitResponse::SqliteErrorResponse( protocol::SqliteErrorResponse { + group: "depot".to_string(), + code: "commit_too_large".to_string(), message: format!( "sqlite commit too large: actual_size_bytes={actual_size_bytes}, max_size_bytes={max_size_bytes}" ), @@ -1093,7 +1111,8 @@ async fn actor_db(ctx: &StandaloneCtx, conn: &Conn, actor_id: String) -> Result< } fn depot_error(err: &anyhow::Error) -> Option<&SqliteStorageError> { - err.downcast_ref::() + err.chain() + .find_map(|source| source.downcast_ref::()) } fn sqlite_error_reason(err: &anyhow::Error) -> String { @@ -1104,7 +1123,19 @@ fn sqlite_error_reason(err: &anyhow::Error) -> String { } fn sqlite_error_response(err: &anyhow::Error) -> protocol::SqliteErrorResponse { + let structured = depot_error(err) + .map(|err| rivet_error::RivetError::extract(&err.clone().build())) + .unwrap_or_else(|| rivet_error::RivetError::extract(err)); + if is_head_fence_mismatch(structured.group(), structured.code()) { + tracing::error!( + error_group = structured.group(), + error_code = structured.code(), + "sqlite head fence mismatch from Depot; this indicates multiple actor instances are accessing the same sqlite database in parallel, which is incorrect actor lifecycle behavior" + ); + } protocol::SqliteErrorResponse { + group: structured.group().to_string(), + code: structured.code().to_string(), message: sqlite_error_reason(err), } } diff --git a/engine/sdks/rust/envoy-protocol/src/lib.rs b/engine/sdks/rust/envoy-protocol/src/lib.rs index 167dcbe173..87d6be2058 100644 --- a/engine/sdks/rust/envoy-protocol/src/lib.rs +++ b/engine/sdks/rust/envoy-protocol/src/lib.rs @@ -3,6 +3,6 @@ pub mod util; pub mod versioned; // Re-export latest -pub use generated::v4::*; +pub use generated::v5::*; pub use generated::PROTOCOL_VERSION; diff --git a/engine/sdks/rust/envoy-protocol/src/versioned.rs b/engine/sdks/rust/envoy-protocol/src/versioned.rs index 3c4952339b..3ae4119811 100644 --- a/engine/sdks/rust/envoy-protocol/src/versioned.rs +++ b/engine/sdks/rust/envoy-protocol/src/versioned.rs @@ -2,7 +2,7 @@ use anyhow::{Result, bail}; use std::{error::Error, fmt}; use vbare::OwnedVersionedData; -use crate::generated::{v1, v2, v3, v4}; +use crate::generated::{v1, v2, v3, v4, v5}; fn convert_same_bytes(message: T) -> Result where @@ -100,131 +100,156 @@ fn incompatible( } pub enum ToEnvoy { - V4(v4::ToEnvoy), + V5(v5::ToEnvoy), } impl OwnedVersionedData for ToEnvoy { - type Latest = v4::ToEnvoy; + type Latest = v5::ToEnvoy; fn wrap_latest(latest: Self::Latest) -> Self { - Self::V4(latest) + Self::V5(latest) } fn unwrap_latest(self) -> Result { match self { - Self::V4(data) => Ok(data), + Self::V5(data) => Ok(data), } } fn deserialize_version(payload: &[u8], version: u16) -> Result { - Ok(Self::V4(match version { - 1 => convert_to_envoy_v3_to_v4(convert_to_envoy_v2_to_v3(convert_to_envoy_v1_to_v2( - serde_bare::from_slice(payload)?, - )?)?)?, - 2 => convert_to_envoy_v3_to_v4(convert_to_envoy_v2_to_v3(serde_bare::from_slice( + Ok(Self::V5(match version { + 1 => convert_to_envoy_v4_to_v5(convert_to_envoy_v3_to_v4( + convert_to_envoy_v2_to_v3(convert_to_envoy_v1_to_v2(serde_bare::from_slice( + payload, + )?)?)?, + )?)?, + 2 => convert_to_envoy_v4_to_v5(convert_to_envoy_v3_to_v4( + convert_to_envoy_v2_to_v3(serde_bare::from_slice(payload)?)?, + )?)?, + 3 => convert_to_envoy_v4_to_v5(convert_to_envoy_v3_to_v4(serde_bare::from_slice( payload, )?)?)?, - 3 => convert_to_envoy_v3_to_v4(serde_bare::from_slice(payload)?)?, - 4 => serde_bare::from_slice(payload)?, + 4 => convert_to_envoy_v4_to_v5(serde_bare::from_slice(payload)?)?, + 5 => serde_bare::from_slice(payload)?, _ => bail!("invalid version: {version}"), })) } fn serialize_version(self, version: u16) -> Result> { - let Self::V4(data) = self; + let Self::V5(data) = self; match version { - 1 => serde_bare::to_vec(&convert_to_envoy_v2_to_v1(convert_to_envoy_v3_to_v2( - convert_to_envoy_v4_to_v3(data, 1)?, - )?)?) - .map_err(Into::into), - 2 => serde_bare::to_vec(&convert_to_envoy_v3_to_v2(convert_to_envoy_v4_to_v3( - data, 2, - )?)?) - .map_err(Into::into), - 3 => serde_bare::to_vec(&convert_to_envoy_v4_to_v3(data, 3)?).map_err(Into::into), - 4 => serde_bare::to_vec(&data).map_err(Into::into), + 1 => { + let data = convert_to_envoy_v5_to_v4(data)?; + serde_bare::to_vec(&convert_to_envoy_v2_to_v1(convert_to_envoy_v3_to_v2( + convert_to_envoy_v4_to_v3(data, 1)?, + )?)?) + .map_err(Into::into) + } + 2 => { + let data = convert_to_envoy_v5_to_v4(data)?; + serde_bare::to_vec(&convert_to_envoy_v3_to_v2(convert_to_envoy_v4_to_v3( + data, 2, + )?)?) + .map_err(Into::into) + } + 3 => { + let data = convert_to_envoy_v5_to_v4(data)?; + serde_bare::to_vec(&convert_to_envoy_v4_to_v3(data, 3)?).map_err(Into::into) + } + 4 => serde_bare::to_vec(&convert_to_envoy_v5_to_v4(data)?).map_err(Into::into), + 5 => serde_bare::to_vec(&data).map_err(Into::into), _ => bail!("invalid version: {version}"), } } fn deserialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } } pub enum ToRivet { - V4(v4::ToRivet), + V5(v5::ToRivet), } impl OwnedVersionedData for ToRivet { - type Latest = v4::ToRivet; + type Latest = v5::ToRivet; fn wrap_latest(latest: Self::Latest) -> Self { - Self::V4(latest) + Self::V5(latest) } fn unwrap_latest(self) -> Result { match self { - Self::V4(data) => Ok(data), + Self::V5(data) => Ok(data), } } fn deserialize_version(payload: &[u8], version: u16) -> Result { - Ok(Self::V4(match version { - 1 | 2 => convert_to_rivet_v3_to_v4(convert_to_rivet_v2_to_v3( + Ok(Self::V5(match version { + 1 | 2 => convert_to_rivet_v4_to_v5(convert_to_rivet_v3_to_v4(convert_to_rivet_v2_to_v3( serde_bare::from_slice(payload)?, - )?)?, - 3 => convert_to_rivet_v3_to_v4(serde_bare::from_slice(payload)?)?, - 4 => serde_bare::from_slice(payload)?, + )?)?)?, + 3 => convert_to_rivet_v4_to_v5(convert_to_rivet_v3_to_v4(serde_bare::from_slice( + payload, + )?)?)?, + 4 => convert_to_rivet_v4_to_v5(serde_bare::from_slice(payload)?)?, + 5 => serde_bare::from_slice(payload)?, _ => bail!("invalid version: {version}"), })) } fn serialize_version(self, version: u16) -> Result> { - let Self::V4(data) = self; + let Self::V5(data) = self; match version { - 1 | 2 => serde_bare::to_vec(&convert_to_rivet_v3_to_v2(convert_to_rivet_v4_to_v3( - data, version, - )?)?) - .map_err(Into::into), - 3 => serde_bare::to_vec(&convert_to_rivet_v4_to_v3(data, 3)?).map_err(Into::into), - 4 => serde_bare::to_vec(&data).map_err(Into::into), + 1 | 2 => { + let data = convert_to_rivet_v5_to_v4(data)?; + serde_bare::to_vec(&convert_to_rivet_v3_to_v2(convert_to_rivet_v4_to_v3( + data, version, + )?)?) + .map_err(Into::into) + } + 3 => { + let data = convert_to_rivet_v5_to_v4(data)?; + serde_bare::to_vec(&convert_to_rivet_v4_to_v3(data, 3)?).map_err(Into::into) + } + 4 => serde_bare::to_vec(&convert_to_rivet_v5_to_v4(data)?).map_err(Into::into), + 5 => serde_bare::to_vec(&data).map_err(Into::into), _ => bail!("invalid version: {version}"), } } fn deserialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } } pub enum ToEnvoyConn { - V4(v4::ToEnvoyConn), + V5(v5::ToEnvoyConn), } impl OwnedVersionedData for ToEnvoyConn { - type Latest = v4::ToEnvoyConn; + type Latest = v5::ToEnvoyConn; fn wrap_latest(latest: Self::Latest) -> Self { - Self::V4(latest) + Self::V5(latest) } fn unwrap_latest(self) -> Result { match self { - Self::V4(data) => Ok(data), + Self::V5(data) => Ok(data), } } fn deserialize_version(payload: &[u8], version: u16) -> Result { - Ok(Self::V4(match version { + Ok(Self::V5(match version { 1 => convert_same_bytes(convert_to_envoy_conn_v1_to_v3(serde_bare::from_slice( payload, )?)?)?, @@ -232,55 +257,65 @@ impl OwnedVersionedData for ToEnvoyConn { payload, )?)?)?, 3 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, - 4 => serde_bare::from_slice(payload)?, + 4 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, + 5 => serde_bare::from_slice(payload)?, _ => bail!("invalid version: {version}"), })) } fn serialize_version(self, version: u16) -> Result> { - let Self::V4(data) = self; - let data_v3 = || convert_same_bytes_ref::<_, v3::ToEnvoyConn>(&data); + let Self::V5(data) = self; + let data_v4 = || convert_same_bytes_ref::<_, v4::ToEnvoyConn>(&data); match version { 1 => { - serde_bare::to_vec(&convert_to_envoy_conn_v3_to_v1(data_v3()?)?).map_err(Into::into) + let data = data_v4()?; + let data_v3 = convert_same_bytes_ref::<_, v3::ToEnvoyConn>(&data)?; + serde_bare::to_vec(&convert_to_envoy_conn_v3_to_v1(data_v3)?).map_err(Into::into) } 2 => { - serde_bare::to_vec(&convert_to_envoy_conn_v3_to_v2(data_v3()?)?).map_err(Into::into) + let data = data_v4()?; + let data_v3 = convert_same_bytes_ref::<_, v3::ToEnvoyConn>(&data)?; + serde_bare::to_vec(&convert_to_envoy_conn_v3_to_v2(data_v3)?).map_err(Into::into) + } + 3 => { + let data = data_v4()?; + serde_bare::to_vec(&convert_same_bytes_ref::<_, v3::ToEnvoyConn>(&data)?) + .map_err(Into::into) } - 3 => serde_bare::to_vec(&data_v3()?).map_err(Into::into), - 4 => serde_bare::to_vec(&data).map_err(Into::into), + 4 => serde_bare::to_vec(&data_v4()?).map_err(Into::into), + 5 => serde_bare::to_vec(&data).map_err(Into::into), _ => bail!("invalid version: {version}"), } } fn deserialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } } pub enum ToGateway { - V4(v4::ToGateway), + V5(v5::ToGateway), } impl OwnedVersionedData for ToGateway { - type Latest = v4::ToGateway; + type Latest = v5::ToGateway; fn wrap_latest(latest: Self::Latest) -> Self { - Self::V4(latest) + Self::V5(latest) } fn unwrap_latest(self) -> Result { match self { - Self::V4(data) => Ok(data), + Self::V5(data) => Ok(data), } } fn deserialize_version(payload: &[u8], version: u16) -> Result { - Ok(Self::V4(match version { + Ok(Self::V5(match version { 1 => convert_same_bytes(convert_to_gateway_v1_to_v3(serde_bare::from_slice( payload, )?))?, @@ -288,51 +323,65 @@ impl OwnedVersionedData for ToGateway { payload, )?))?, 3 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, - 4 => serde_bare::from_slice(payload)?, + 4 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, + 5 => serde_bare::from_slice(payload)?, _ => bail!("invalid version: {version}"), })) } fn serialize_version(self, version: u16) -> Result> { - let Self::V4(data) = self; - let data_v3 = || convert_same_bytes_ref::<_, v3::ToGateway>(&data); + let Self::V5(data) = self; + let data_v4 = || convert_same_bytes_ref::<_, v4::ToGateway>(&data); match version { - 1 => serde_bare::to_vec(&convert_to_gateway_v3_to_v1(data_v3()?)).map_err(Into::into), - 2 => serde_bare::to_vec(&convert_to_gateway_v3_to_v2(data_v3()?)).map_err(Into::into), - 3 => serde_bare::to_vec(&data_v3()?).map_err(Into::into), - 4 => serde_bare::to_vec(&data).map_err(Into::into), + 1 => { + let data = data_v4()?; + let data = convert_same_bytes_ref::<_, v3::ToGateway>(&data)?; + serde_bare::to_vec(&convert_to_gateway_v3_to_v1(data)).map_err(Into::into) + } + 2 => { + let data = data_v4()?; + let data = convert_same_bytes_ref::<_, v3::ToGateway>(&data)?; + serde_bare::to_vec(&convert_to_gateway_v3_to_v2(data)).map_err(Into::into) + } + 3 => { + let data = data_v4()?; + serde_bare::to_vec(&convert_same_bytes_ref::<_, v3::ToGateway>(&data)?) + .map_err(Into::into) + } + 4 => serde_bare::to_vec(&data_v4()?).map_err(Into::into), + 5 => serde_bare::to_vec(&data).map_err(Into::into), _ => bail!("invalid version: {version}"), } } fn deserialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } } pub enum ToOutbound { - V4(v4::ToOutbound), + V5(v5::ToOutbound), } impl OwnedVersionedData for ToOutbound { - type Latest = v4::ToOutbound; + type Latest = v5::ToOutbound; fn wrap_latest(latest: Self::Latest) -> Self { - Self::V4(latest) + Self::V5(latest) } fn unwrap_latest(self) -> Result { match self { - Self::V4(data) => Ok(data), + Self::V5(data) => Ok(data), } } fn deserialize_version(payload: &[u8], version: u16) -> Result { - Ok(Self::V4(match version { + Ok(Self::V5(match version { 1 => convert_same_bytes(convert_to_outbound_v1_to_v3(serde_bare::from_slice( payload, )?))?, @@ -340,51 +389,65 @@ impl OwnedVersionedData for ToOutbound { payload, )?))?, 3 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, - 4 => serde_bare::from_slice(payload)?, + 4 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, + 5 => serde_bare::from_slice(payload)?, _ => bail!("invalid version: {version}"), })) } fn serialize_version(self, version: u16) -> Result> { - let Self::V4(data) = self; - let data_v3 = || convert_same_bytes_ref::<_, v3::ToOutbound>(&data); + let Self::V5(data) = self; + let data_v4 = || convert_same_bytes_ref::<_, v4::ToOutbound>(&data); match version { - 1 => serde_bare::to_vec(&convert_to_outbound_v3_to_v1(data_v3()?)).map_err(Into::into), - 2 => serde_bare::to_vec(&convert_to_outbound_v3_to_v2(data_v3()?)).map_err(Into::into), - 3 => serde_bare::to_vec(&data_v3()?).map_err(Into::into), - 4 => serde_bare::to_vec(&data).map_err(Into::into), + 1 => { + let data = data_v4()?; + let data = convert_same_bytes_ref::<_, v3::ToOutbound>(&data)?; + serde_bare::to_vec(&convert_to_outbound_v3_to_v1(data)).map_err(Into::into) + } + 2 => { + let data = data_v4()?; + let data = convert_same_bytes_ref::<_, v3::ToOutbound>(&data)?; + serde_bare::to_vec(&convert_to_outbound_v3_to_v2(data)).map_err(Into::into) + } + 3 => { + let data = data_v4()?; + serde_bare::to_vec(&convert_same_bytes_ref::<_, v3::ToOutbound>(&data)?) + .map_err(Into::into) + } + 4 => serde_bare::to_vec(&data_v4()?).map_err(Into::into), + 5 => serde_bare::to_vec(&data).map_err(Into::into), _ => bail!("invalid version: {version}"), } } fn deserialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } } pub enum ActorCommandKeyData { - V4(v4::ActorCommandKeyData), + V5(v5::ActorCommandKeyData), } impl OwnedVersionedData for ActorCommandKeyData { - type Latest = v4::ActorCommandKeyData; + type Latest = v5::ActorCommandKeyData; fn wrap_latest(latest: Self::Latest) -> Self { - Self::V4(latest) + Self::V5(latest) } fn unwrap_latest(self) -> Result { match self { - Self::V4(data) => Ok(data), + Self::V5(data) => Ok(data), } } fn deserialize_version(payload: &[u8], version: u16) -> Result { - Ok(Self::V4(match version { + Ok(Self::V5(match version { 1 => convert_same_bytes(convert_actor_command_key_data_v1_to_v3( serde_bare::from_slice(payload)?, ))?, @@ -392,34 +455,287 @@ impl OwnedVersionedData for ActorCommandKeyData { serde_bare::from_slice(payload)?, ))?, 3 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, - 4 => serde_bare::from_slice(payload)?, + 4 => convert_same_bytes(serde_bare::from_slice::(payload)?)?, + 5 => serde_bare::from_slice(payload)?, _ => bail!("invalid version: {version}"), })) } fn serialize_version(self, version: u16) -> Result> { - let Self::V4(data) = self; - let data_v3 = || convert_same_bytes_ref::<_, v3::ActorCommandKeyData>(&data); + let Self::V5(data) = self; + let data_v4 = || convert_same_bytes_ref::<_, v4::ActorCommandKeyData>(&data); match version { - 1 => serde_bare::to_vec(&convert_actor_command_key_data_v3_to_v1(data_v3()?)) - .map_err(Into::into), - 2 => serde_bare::to_vec(&convert_actor_command_key_data_v3_to_v2(data_v3()?)) - .map_err(Into::into), - 3 => serde_bare::to_vec(&data_v3()?).map_err(Into::into), - 4 => serde_bare::to_vec(&data).map_err(Into::into), + 1 => { + let data = data_v4()?; + let data = convert_same_bytes_ref::<_, v3::ActorCommandKeyData>(&data)?; + serde_bare::to_vec(&convert_actor_command_key_data_v3_to_v1(data)) + .map_err(Into::into) + } + 2 => { + let data = data_v4()?; + let data = convert_same_bytes_ref::<_, v3::ActorCommandKeyData>(&data)?; + serde_bare::to_vec(&convert_actor_command_key_data_v3_to_v2(data)) + .map_err(Into::into) + } + 3 => { + let data = data_v4()?; + serde_bare::to_vec(&convert_same_bytes_ref::<_, v3::ActorCommandKeyData>(&data)?) + .map_err(Into::into) + } + 4 => serde_bare::to_vec(&data_v4()?).map_err(Into::into), + 5 => serde_bare::to_vec(&data).map_err(Into::into), _ => bail!("invalid version: {version}"), } } fn deserialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] } fn serialize_converters() -> Vec Result> { - vec![Ok, Ok, Ok] + vec![Ok, Ok, Ok, Ok] + } +} + +fn convert_to_envoy_v4_to_v5(message: v4::ToEnvoy) -> Result { + Ok(match message { + v4::ToEnvoy::ToEnvoyInit(data) => v5::ToEnvoy::ToEnvoyInit(convert_same_bytes(data)?), + v4::ToEnvoy::ToEnvoyCommands(data) => { + v5::ToEnvoy::ToEnvoyCommands(convert_same_bytes(data)?) + } + v4::ToEnvoy::ToEnvoyAckEvents(data) => { + v5::ToEnvoy::ToEnvoyAckEvents(convert_same_bytes(data)?) + } + v4::ToEnvoy::ToEnvoyKvResponse(data) => { + v5::ToEnvoy::ToEnvoyKvResponse(convert_same_bytes(data)?) + } + v4::ToEnvoy::ToEnvoyTunnelMessage(data) => { + v5::ToEnvoy::ToEnvoyTunnelMessage(convert_same_bytes(data)?) + } + v4::ToEnvoy::ToEnvoyPing(data) => v5::ToEnvoy::ToEnvoyPing(convert_same_bytes(data)?), + v4::ToEnvoy::ToEnvoySqliteGetPagesResponse(data) => { + v5::ToEnvoy::ToEnvoySqliteGetPagesResponse(v5::ToEnvoySqliteGetPagesResponse { + request_id: data.request_id, + data: convert_sqlite_get_pages_response_v4_to_v5(data.data)?, + }) + } + v4::ToEnvoy::ToEnvoySqliteCommitResponse(data) => { + v5::ToEnvoy::ToEnvoySqliteCommitResponse(v5::ToEnvoySqliteCommitResponse { + request_id: data.request_id, + data: convert_sqlite_commit_response_v4_to_v5(data.data)?, + }) + } + v4::ToEnvoy::ToEnvoySqliteExecResponse(data) => { + v5::ToEnvoy::ToEnvoySqliteExecResponse(v5::ToEnvoySqliteExecResponse { + request_id: data.request_id, + data: convert_sqlite_exec_response_v4_to_v5(data.data)?, + }) + } + v4::ToEnvoy::ToEnvoySqliteExecuteResponse(data) => { + v5::ToEnvoy::ToEnvoySqliteExecuteResponse(v5::ToEnvoySqliteExecuteResponse { + request_id: data.request_id, + data: convert_sqlite_execute_response_v4_to_v5(data.data)?, + }) + } + }) +} + +fn convert_to_envoy_v5_to_v4(message: v5::ToEnvoy) -> Result { + Ok(match message { + v5::ToEnvoy::ToEnvoyInit(data) => v4::ToEnvoy::ToEnvoyInit(convert_same_bytes(data)?), + v5::ToEnvoy::ToEnvoyCommands(data) => { + v4::ToEnvoy::ToEnvoyCommands(convert_same_bytes(data)?) + } + v5::ToEnvoy::ToEnvoyAckEvents(data) => { + v4::ToEnvoy::ToEnvoyAckEvents(convert_same_bytes(data)?) + } + v5::ToEnvoy::ToEnvoyKvResponse(data) => { + v4::ToEnvoy::ToEnvoyKvResponse(convert_same_bytes(data)?) + } + v5::ToEnvoy::ToEnvoyTunnelMessage(data) => { + v4::ToEnvoy::ToEnvoyTunnelMessage(convert_same_bytes(data)?) + } + v5::ToEnvoy::ToEnvoyPing(data) => v4::ToEnvoy::ToEnvoyPing(convert_same_bytes(data)?), + v5::ToEnvoy::ToEnvoySqliteGetPagesResponse(data) => { + v4::ToEnvoy::ToEnvoySqliteGetPagesResponse(v4::ToEnvoySqliteGetPagesResponse { + request_id: data.request_id, + data: convert_sqlite_get_pages_response_v5_to_v4(data.data)?, + }) + } + v5::ToEnvoy::ToEnvoySqliteCommitResponse(data) => { + v4::ToEnvoy::ToEnvoySqliteCommitResponse(v4::ToEnvoySqliteCommitResponse { + request_id: data.request_id, + data: convert_sqlite_commit_response_v5_to_v4(data.data)?, + }) + } + v5::ToEnvoy::ToEnvoySqliteExecResponse(data) => { + v4::ToEnvoy::ToEnvoySqliteExecResponse(v4::ToEnvoySqliteExecResponse { + request_id: data.request_id, + data: convert_sqlite_exec_response_v5_to_v4(data.data)?, + }) + } + v5::ToEnvoy::ToEnvoySqliteExecuteResponse(data) => { + v4::ToEnvoy::ToEnvoySqliteExecuteResponse(v4::ToEnvoySqliteExecuteResponse { + request_id: data.request_id, + data: convert_sqlite_execute_response_v5_to_v4(data.data)?, + }) + } + }) +} + +fn convert_sqlite_error_response_v4_to_v5( + error: v4::SqliteErrorResponse, +) -> v5::SqliteErrorResponse { + v5::SqliteErrorResponse { + group: "core".to_string(), + code: "internal_error".to_string(), + message: error.message, + } +} + +fn convert_sqlite_error_response_v5_to_v4( + error: v5::SqliteErrorResponse, +) -> v4::SqliteErrorResponse { + v4::SqliteErrorResponse { + message: error.message, } } +fn convert_sqlite_get_pages_response_v4_to_v5( + message: v4::SqliteGetPagesResponse, +) -> Result { + Ok(match message { + v4::SqliteGetPagesResponse::SqliteGetPagesOk(ok) => { + v5::SqliteGetPagesResponse::SqliteGetPagesOk(v5::SqliteGetPagesOk { + pages: convert_same_bytes(ok.pages)?, + head_txid: None, + }) + } + v4::SqliteGetPagesResponse::SqliteErrorResponse(error) => { + v5::SqliteGetPagesResponse::SqliteErrorResponse( + convert_sqlite_error_response_v4_to_v5(error), + ) + } + }) +} + +fn convert_sqlite_get_pages_response_v5_to_v4( + message: v5::SqliteGetPagesResponse, +) -> Result { + Ok(match message { + v5::SqliteGetPagesResponse::SqliteGetPagesOk(ok) => { + v4::SqliteGetPagesResponse::SqliteGetPagesOk(v4::SqliteGetPagesOk { + pages: convert_same_bytes(ok.pages)?, + }) + } + v5::SqliteGetPagesResponse::SqliteErrorResponse(error) => { + v4::SqliteGetPagesResponse::SqliteErrorResponse( + convert_sqlite_error_response_v5_to_v4(error), + ) + } + }) +} + +fn convert_sqlite_commit_response_v4_to_v5( + message: v4::SqliteCommitResponse, +) -> Result { + Ok(match message { + v4::SqliteCommitResponse::SqliteCommitOk => { + v5::SqliteCommitResponse::SqliteCommitOk(v5::SqliteCommitOk { + head_txid: None, + }) + } + v4::SqliteCommitResponse::SqliteErrorResponse(error) => { + v5::SqliteCommitResponse::SqliteErrorResponse( + convert_sqlite_error_response_v4_to_v5(error), + ) + } + }) +} + +fn convert_sqlite_commit_response_v5_to_v4( + message: v5::SqliteCommitResponse, +) -> Result { + Ok(match message { + v5::SqliteCommitResponse::SqliteCommitOk(_) => { + v4::SqliteCommitResponse::SqliteCommitOk + } + v5::SqliteCommitResponse::SqliteErrorResponse(error) => { + v4::SqliteCommitResponse::SqliteErrorResponse( + convert_sqlite_error_response_v5_to_v4(error), + ) + } + }) +} + +fn convert_sqlite_exec_response_v4_to_v5( + message: v4::SqliteExecResponse, +) -> Result { + Ok(match message { + v4::SqliteExecResponse::SqliteExecOk(ok) => { + v5::SqliteExecResponse::SqliteExecOk(convert_same_bytes(ok)?) + } + v4::SqliteExecResponse::SqliteErrorResponse(error) => { + v5::SqliteExecResponse::SqliteErrorResponse( + convert_sqlite_error_response_v4_to_v5(error), + ) + } + }) +} + +fn convert_sqlite_exec_response_v5_to_v4( + message: v5::SqliteExecResponse, +) -> Result { + Ok(match message { + v5::SqliteExecResponse::SqliteExecOk(ok) => { + v4::SqliteExecResponse::SqliteExecOk(convert_same_bytes(ok)?) + } + v5::SqliteExecResponse::SqliteErrorResponse(error) => { + v4::SqliteExecResponse::SqliteErrorResponse( + convert_sqlite_error_response_v5_to_v4(error), + ) + } + }) +} + +fn convert_sqlite_execute_response_v4_to_v5( + message: v4::SqliteExecuteResponse, +) -> Result { + Ok(match message { + v4::SqliteExecuteResponse::SqliteExecuteOk(ok) => { + v5::SqliteExecuteResponse::SqliteExecuteOk(convert_same_bytes(ok)?) + } + v4::SqliteExecuteResponse::SqliteErrorResponse(error) => { + v5::SqliteExecuteResponse::SqliteErrorResponse( + convert_sqlite_error_response_v4_to_v5(error), + ) + } + }) +} + +fn convert_sqlite_execute_response_v5_to_v4( + message: v5::SqliteExecuteResponse, +) -> Result { + Ok(match message { + v5::SqliteExecuteResponse::SqliteExecuteOk(ok) => { + v4::SqliteExecuteResponse::SqliteExecuteOk(convert_same_bytes(ok)?) + } + v5::SqliteExecuteResponse::SqliteErrorResponse(error) => { + v4::SqliteExecuteResponse::SqliteErrorResponse( + convert_sqlite_error_response_v5_to_v4(error), + ) + } + }) +} + +fn convert_to_rivet_v4_to_v5(message: v4::ToRivet) -> Result { + convert_same_bytes(message) +} + +fn convert_to_rivet_v5_to_v4(message: v5::ToRivet) -> Result { + convert_same_bytes(message) +} + fn convert_to_envoy_v3_to_v4(message: v3::ToEnvoy) -> Result { convert_same_bytes(message) } @@ -1449,12 +1765,12 @@ mod tests { use super::{ActorCommandKeyData, ToEnvoy}; use crate::{ PROTOCOL_VERSION, - generated::{v1, v2, v4}, + generated::{v1, v2, v5}, }; #[test] fn protocol_version_constant_matches_schema_version() { - assert_eq!(PROTOCOL_VERSION, 4); + assert_eq!(PROTOCOL_VERSION, 5); } #[test] @@ -1479,10 +1795,10 @@ mod tests { }]))?; let decoded = ToEnvoy::deserialize_version(&payload, 1)?.unwrap_latest()?; - let v4::ToEnvoy::ToEnvoyCommands(commands) = decoded else { + let v5::ToEnvoy::ToEnvoyCommands(commands) = decoded else { panic!("expected commands"); }; - let v4::Command::CommandStartActor(start) = &commands[0].inner else { + let v5::Command::CommandStartActor(start) = &commands[0].inner else { panic!("expected start actor"); }; @@ -1509,9 +1825,9 @@ mod tests { #[test] fn actor_command_key_data_round_trips_to_v1() -> Result<()> { - let encoded = ActorCommandKeyData::wrap_latest(v4::ActorCommandKeyData::CommandStartActor( - v4::CommandStartActor { - config: v4::ActorConfig { + let encoded = ActorCommandKeyData::wrap_latest(v5::ActorCommandKeyData::CommandStartActor( + v5::CommandStartActor { + config: v5::ActorConfig { name: "demo".into(), key: None, create_ts: 7, @@ -1524,7 +1840,7 @@ mod tests { .serialize_version(1)?; let decoded = ActorCommandKeyData::deserialize_version(&encoded, 1)?.unwrap_latest()?; - let v4::ActorCommandKeyData::CommandStartActor(start) = decoded else { + let v5::ActorCommandKeyData::CommandStartActor(start) = decoded else { panic!("expected start actor"); }; assert_eq!(start.config.name, "demo"); diff --git a/engine/sdks/rust/envoy-protocol/tests/remote_sql_compat.rs b/engine/sdks/rust/envoy-protocol/tests/remote_sql_compat.rs index 17d4e7b55a..185619d6b8 100644 --- a/engine/sdks/rust/envoy-protocol/tests/remote_sql_compat.rs +++ b/engine/sdks/rust/envoy-protocol/tests/remote_sql_compat.rs @@ -1,6 +1,6 @@ use anyhow::Result; use rivet_envoy_protocol::{ - generated::v4, + generated::{v4, v5}, versioned::{ ProtocolCompatibilityDirection, ProtocolCompatibilityError, ProtocolCompatibilityFeature, ToEnvoy, ToRivet, @@ -8,10 +8,10 @@ use rivet_envoy_protocol::{ }; use vbare::OwnedVersionedData; -fn remote_sql_request_exec() -> v4::ToRivet { - v4::ToRivet::ToRivetSqliteExecRequest(v4::ToRivetSqliteExecRequest { +fn remote_sql_request_exec() -> v5::ToRivet { + v5::ToRivet::ToRivetSqliteExecRequest(v5::ToRivetSqliteExecRequest { request_id: 1, - data: v4::SqliteExecRequest { + data: v5::SqliteExecRequest { namespace_id: "namespace".into(), actor_id: "actor".into(), generation: 7, @@ -20,34 +20,38 @@ fn remote_sql_request_exec() -> v4::ToRivet { }) } -fn remote_sql_request_execute() -> v4::ToRivet { - v4::ToRivet::ToRivetSqliteExecuteRequest(v4::ToRivetSqliteExecuteRequest { +fn remote_sql_request_execute() -> v5::ToRivet { + v5::ToRivet::ToRivetSqliteExecuteRequest(v5::ToRivetSqliteExecuteRequest { request_id: 2, - data: v4::SqliteExecuteRequest { + data: v5::SqliteExecuteRequest { namespace_id: "namespace".into(), actor_id: "actor".into(), generation: 7, sql: "select ?".into(), - params: Some(vec![v4::SqliteBindParam::SqliteValueInteger( - v4::SqliteValueInteger { value: 1 }, + params: Some(vec![v5::SqliteBindParam::SqliteValueInteger( + v5::SqliteValueInteger { value: 1 }, )]), }, }) } -fn remote_sql_response_exec() -> v4::ToEnvoy { - v4::ToEnvoy::ToEnvoySqliteExecResponse(v4::ToEnvoySqliteExecResponse { +fn remote_sql_response_exec() -> v5::ToEnvoy { + v5::ToEnvoy::ToEnvoySqliteExecResponse(v5::ToEnvoySqliteExecResponse { request_id: 1, - data: v4::SqliteExecResponse::SqliteErrorResponse(v4::SqliteErrorResponse { + data: v5::SqliteExecResponse::SqliteErrorResponse(v5::SqliteErrorResponse { + group: "sqlite".into(), + code: "remote_unavailable".into(), message: "remote sql execution is unavailable".into(), }), }) } -fn remote_sql_response_execute() -> v4::ToEnvoy { - v4::ToEnvoy::ToEnvoySqliteExecuteResponse(v4::ToEnvoySqliteExecuteResponse { +fn remote_sql_response_execute() -> v5::ToEnvoy { + v5::ToEnvoy::ToEnvoySqliteExecuteResponse(v5::ToEnvoySqliteExecuteResponse { request_id: 2, - data: v4::SqliteExecuteResponse::SqliteErrorResponse(v4::SqliteErrorResponse { + data: v5::SqliteExecuteResponse::SqliteErrorResponse(v5::SqliteErrorResponse { + group: "sqlite".into(), + code: "remote_unavailable".into(), message: "remote sql execution is unavailable".into(), }), }) @@ -109,11 +113,11 @@ fn new_core_new_pegboard_envoy_allows_remote_sql_both_directions() -> Result<()> assert!(matches!( ToRivet::deserialize(&request, 4)?, - v4::ToRivet::ToRivetSqliteExecRequest(_) + v5::ToRivet::ToRivetSqliteExecRequest(_) )); assert!(matches!( ToEnvoy::deserialize(&response, 4)?, - v4::ToEnvoy::ToEnvoySqliteExecResponse(_) + v5::ToEnvoy::ToEnvoySqliteExecResponse(_) )); Ok(()) @@ -121,8 +125,25 @@ fn new_core_new_pegboard_envoy_allows_remote_sql_both_directions() -> Result<()> #[test] fn v4_remote_sql_payloads_do_not_decode_as_v3() -> Result<()> { - let request = serde_bare::to_vec(&remote_sql_request_exec())?; - let response = serde_bare::to_vec(&remote_sql_response_exec())?; + let request = serde_bare::to_vec(&v4::ToRivet::ToRivetSqliteExecRequest( + v4::ToRivetSqliteExecRequest { + request_id: 1, + data: v4::SqliteExecRequest { + namespace_id: "namespace".into(), + actor_id: "actor".into(), + generation: 7, + sql: "select 1".into(), + }, + }, + ))?; + let response = serde_bare::to_vec(&v4::ToEnvoy::ToEnvoySqliteExecResponse( + v4::ToEnvoySqliteExecResponse { + request_id: 1, + data: v4::SqliteExecResponse::SqliteErrorResponse(v4::SqliteErrorResponse { + message: "remote sql execution is unavailable".into(), + }), + }, + ))?; assert!(ToRivet::deserialize(&request, 3).is_err()); assert!(ToEnvoy::deserialize(&response, 3).is_err()); diff --git a/engine/sdks/rust/envoy-protocol/tests/stateless_sqlite_v3.rs b/engine/sdks/rust/envoy-protocol/tests/stateless_sqlite_v3.rs index 4fcfabd259..13dfdd105e 100644 --- a/engine/sdks/rust/envoy-protocol/tests/stateless_sqlite_v3.rs +++ b/engine/sdks/rust/envoy-protocol/tests/stateless_sqlite_v3.rs @@ -86,23 +86,27 @@ fn commit_response_ok_and_err_roundtrip() -> anyhow::Result<()> { let ok = roundtrip_to_envoy(protocol::ToEnvoy::ToEnvoySqliteCommitResponse( protocol::ToEnvoySqliteCommitResponse { request_id: 1, - data: protocol::SqliteCommitResponse::SqliteCommitOk, + data: protocol::SqliteCommitResponse::SqliteCommitOk(protocol::SqliteCommitOk { + head_txid: Some(7), + }), }, ))?; let protocol::ToEnvoy::ToEnvoySqliteCommitResponse(ok) = ok else { panic!("expected commit response"); }; assert_eq!(ok.request_id, 1); - assert!(matches!( - ok.data, - protocol::SqliteCommitResponse::SqliteCommitOk - )); + let protocol::SqliteCommitResponse::SqliteCommitOk(ok) = ok.data else { + panic!("expected ok response"); + }; + assert_eq!(ok.head_txid, Some(7)); let err = roundtrip_to_envoy(protocol::ToEnvoy::ToEnvoySqliteCommitResponse( protocol::ToEnvoySqliteCommitResponse { request_id: 2, data: protocol::SqliteCommitResponse::SqliteErrorResponse( protocol::SqliteErrorResponse { + group: "depot".into(), + code: "quota_exceeded".into(), message: "quota exceeded".into(), }, ), @@ -164,7 +168,7 @@ fn expected_generation_optional_present_and_absent() -> anyhow::Result<()> { #[test] fn protocol_version_constant_matches_schema_version() { - assert_eq!(PROTOCOL_VERSION, 4); + assert_eq!(PROTOCOL_VERSION, 5); } #[test] diff --git a/engine/sdks/schemas/envoy-protocol/v5.bare b/engine/sdks/schemas/envoy-protocol/v5.bare new file mode 100644 index 0000000000..baa7039cae --- /dev/null +++ b/engine/sdks/schemas/envoy-protocol/v5.bare @@ -0,0 +1,642 @@ +# MARK: Core Primitives + +type Id str +type Json str + +type GatewayId data[4] +type RequestId data[4] +type MessageIndex u16 + +# MARK: KV + +# Basic types +type KvKey data +type KvValue data +type KvMetadata struct { + version: data + updateTs: i64 +} + +# Query types +type KvListAllQuery void +type KvListRangeQuery struct { + start: KvKey + end: KvKey + exclusive: bool +} + +type KvListPrefixQuery struct { + key: KvKey +} + +type KvListQuery union { + KvListAllQuery | + KvListRangeQuery | + KvListPrefixQuery +} + +# Request types +type KvGetRequest struct { + keys: list +} + +type KvListRequest struct { + query: KvListQuery + reverse: optional + limit: optional +} + +type KvPutRequest struct { + keys: list + values: list +} + +type KvDeleteRequest struct { + keys: list +} + +type KvDeleteRangeRequest struct { + start: KvKey + end: KvKey +} + +type KvDropRequest void + +# Response types +type KvErrorResponse struct { + message: str +} + +type KvGetResponse struct { + keys: list + values: list + metadata: list +} + +type KvListResponse struct { + keys: list + values: list + metadata: list +} + +type KvPutResponse void +type KvDeleteResponse void +type KvDropResponse void + +# Request/Response unions +type KvRequestData union { + KvGetRequest | + KvListRequest | + KvPutRequest | + KvDeleteRequest | + KvDeleteRangeRequest | + KvDropRequest +} + +type KvResponseData union { + KvErrorResponse | + KvGetResponse | + KvListResponse | + KvPutResponse | + KvDeleteResponse | + KvDropResponse +} + +# MARK: SQLite + +type SqlitePgno u32 +type SqliteGeneration u64 +type SqlitePageBytes data + +type SqliteDirtyPage struct { + pgno: SqlitePgno + bytes: SqlitePageBytes +} + +type SqliteFetchedPage struct { + pgno: SqlitePgno + bytes: optional +} + +type SqliteGetPagesRequest struct { + actorId: Id + pgnos: list + expectedGeneration: optional + expectedHeadTxid: optional +} + +type SqliteGetPagesOk struct { + pages: list + headTxid: optional +} + +type SqliteErrorResponse struct { + group: str + code: str + message: str +} + +type SqliteGetPagesResponse union { + SqliteGetPagesOk | + SqliteErrorResponse +} + +type SqliteCommitRequest struct { + actorId: Id + dirtyPages: list + dbSizePages: u32 + nowMs: i64 + expectedGeneration: optional + expectedHeadTxid: optional +} + +type SqliteCommitOk struct { + headTxid: optional +} + +type SqliteCommitResponse union { + SqliteCommitOk | + SqliteErrorResponse +} + +# MARK: SQLite Remote Execution + +type SqliteValueNull void + +type SqliteValueInteger struct { + value: i64 +} + +type SqliteValueFloat struct { + value: data[8] +} + +type SqliteValueText struct { + value: str +} + +type SqliteValueBlob struct { + value: data +} + +type SqliteBindParam union { + SqliteValueNull | + SqliteValueInteger | + SqliteValueFloat | + SqliteValueText | + SqliteValueBlob +} + +type SqliteColumnValue union { + SqliteValueNull | + SqliteValueInteger | + SqliteValueFloat | + SqliteValueText | + SqliteValueBlob +} + +type SqliteQueryResult struct { + columns: list + rows: list> +} + +type SqliteExecuteResult struct { + columns: list + rows: list> + changes: i64 + lastInsertRowId: optional +} + +type SqliteExecRequest struct { + namespaceId: Id + actorId: Id + generation: SqliteGeneration + sql: str +} + +type SqliteExecuteRequest struct { + namespaceId: Id + actorId: Id + generation: SqliteGeneration + sql: str + params: optional> +} + +type SqliteExecOk struct { + result: SqliteQueryResult +} + +type SqliteExecuteOk struct { + result: SqliteExecuteResult +} + +type SqliteExecResponse union { + SqliteExecOk | + SqliteErrorResponse +} + +type SqliteExecuteResponse union { + SqliteExecuteOk | + SqliteErrorResponse +} + +# MARK: Actor + +# Core +type StopCode enum { + OK + ERROR +} + +type ActorName struct { + metadata: Json +} + +type ActorConfig struct { + name: str + key: optional + createTs: i64 + input: optional +} + +type ActorCheckpoint struct { + actorId: Id + generation: u32 + index: i64 +} + +# Intent +type ActorIntentSleep void + +type ActorIntentStop void + +type ActorIntent union { + ActorIntentSleep | + ActorIntentStop +} + +# State +type ActorStateRunning void + +type ActorStateStopped struct { + code: StopCode + message: optional +} + +type ActorState union { + ActorStateRunning | + ActorStateStopped +} + +# MARK: Events +type EventActorIntent struct { + intent: ActorIntent +} + +type EventActorStateUpdate struct { + state: ActorState +} + +type EventActorSetAlarm struct { + alarmTs: optional +} + +type Event union { + EventActorIntent | + EventActorStateUpdate | + EventActorSetAlarm +} + +type EventWrapper struct { + checkpoint: ActorCheckpoint + inner: Event +} + +# MARK: Preloaded KV + +type PreloadedKvEntry struct { + key: KvKey + value: KvValue + metadata: KvMetadata +} + +type PreloadedKv struct { + entries: list + requestedGetKeys: list + requestedPrefixes: list +} + +# MARK: Commands + +type HibernatingRequest struct { + gatewayId: GatewayId + requestId: RequestId +} + +type CommandStartActor struct { + config: ActorConfig + hibernatingRequests: list + preloadedKv: optional +} + +type StopActorReason enum { + SLEEP_INTENT + STOP_INTENT + DESTROY + GOING_AWAY + LOST +} + +type CommandStopActor struct { + reason: StopActorReason +} + +type Command union { + CommandStartActor | + CommandStopActor +} + +type CommandWrapper struct { + checkpoint: ActorCheckpoint + inner: Command +} + +# We redeclare this so its top level +type ActorCommandKeyData union { + CommandStartActor | + CommandStopActor +} + +# MARK: Tunnel + +# Message ID + +type MessageId struct { + # Globally unique ID + gatewayId: GatewayId + # Unique ID to the gateway + requestId: RequestId + # Unique ID to the request + messageIndex: MessageIndex +} + +# HTTP +type ToEnvoyRequestStart struct { + actorId: Id + method: str + path: str + headers: map + body: optional + stream: bool +} + +type ToEnvoyRequestChunk struct { + body: data + finish: bool +} + +type ToEnvoyRequestAbort void + +type ToRivetResponseStart struct { + status: u16 + headers: map + body: optional + stream: bool +} + +type ToRivetResponseChunk struct { + body: data + finish: bool +} + +type ToRivetResponseAbort void + +# WebSocket +type ToEnvoyWebSocketOpen struct { + actorId: Id + path: str + headers: map +} + +type ToEnvoyWebSocketMessage struct { + data: data + binary: bool +} + +type ToEnvoyWebSocketClose struct { + code: optional + reason: optional +} + +type ToRivetWebSocketOpen struct { + canHibernate: bool +} + +type ToRivetWebSocketMessage struct { + data: data + binary: bool +} + +type ToRivetWebSocketMessageAck struct { + index: MessageIndex +} + +type ToRivetWebSocketClose struct { + code: optional + reason: optional + hibernate: bool +} + +# To Rivet +type ToRivetTunnelMessageKind union { + # HTTP + ToRivetResponseStart | + ToRivetResponseChunk | + ToRivetResponseAbort | + + # WebSocket + ToRivetWebSocketOpen | + ToRivetWebSocketMessage | + ToRivetWebSocketMessageAck | + ToRivetWebSocketClose +} + +type ToRivetTunnelMessage struct { + messageId: MessageId + messageKind: ToRivetTunnelMessageKind +} + +# To Envoy +type ToEnvoyTunnelMessageKind union { + # HTTP + ToEnvoyRequestStart | + ToEnvoyRequestChunk | + ToEnvoyRequestAbort | + + # WebSocket + ToEnvoyWebSocketOpen | + ToEnvoyWebSocketMessage | + ToEnvoyWebSocketClose +} + +type ToEnvoyTunnelMessage struct { + messageId: MessageId + messageKind: ToEnvoyTunnelMessageKind +} + +type ToEnvoyPing struct { + ts: i64 +} + +# MARK: To Rivet +type ToRivetMetadata struct { + prepopulateActorNames: optional> + metadata: optional +} + +type ToRivetEvents list + +type ToRivetAckCommands struct { + lastCommandCheckpoints: list +} + +type ToRivetStopping void + +type ToRivetPong struct { + ts: i64 +} + +type ToRivetKvRequest struct { + actorId: Id + requestId: u32 + data: KvRequestData +} + +type ToRivetSqliteGetPagesRequest struct { + requestId: u32 + data: SqliteGetPagesRequest +} + +type ToRivetSqliteCommitRequest struct { + requestId: u32 + data: SqliteCommitRequest +} + +type ToRivetSqliteExecRequest struct { + requestId: u32 + data: SqliteExecRequest +} + +type ToRivetSqliteExecuteRequest struct { + requestId: u32 + data: SqliteExecuteRequest +} + +type ToRivet union { + ToRivetMetadata | + ToRivetEvents | + ToRivetAckCommands | + ToRivetStopping | + ToRivetPong | + ToRivetKvRequest | + ToRivetTunnelMessage | + ToRivetSqliteGetPagesRequest | + ToRivetSqliteCommitRequest | + ToRivetSqliteExecRequest | + ToRivetSqliteExecuteRequest +} + +# MARK: To Envoy +type ProtocolMetadata struct { + envoyLostThreshold: i64 + actorStopThreshold: i64 + maxResponsePayloadSize: u64 +} + +type ToEnvoyInit struct { + metadata: ProtocolMetadata +} + +type ToEnvoyCommands list + +type ToEnvoyAckEvents struct { + lastEventCheckpoints: list +} + +type ToEnvoyKvResponse struct { + requestId: u32 + data: KvResponseData +} + +type ToEnvoySqliteGetPagesResponse struct { + requestId: u32 + data: SqliteGetPagesResponse +} + +type ToEnvoySqliteCommitResponse struct { + requestId: u32 + data: SqliteCommitResponse +} + +type ToEnvoySqliteExecResponse struct { + requestId: u32 + data: SqliteExecResponse +} + +type ToEnvoySqliteExecuteResponse struct { + requestId: u32 + data: SqliteExecuteResponse +} + +type ToEnvoy union { + ToEnvoyInit | + ToEnvoyCommands | + ToEnvoyAckEvents | + ToEnvoyKvResponse | + ToEnvoyTunnelMessage | + ToEnvoyPing | + ToEnvoySqliteGetPagesResponse | + ToEnvoySqliteCommitResponse | + ToEnvoySqliteExecResponse | + ToEnvoySqliteExecuteResponse +} + +# MARK: To Envoy Conn +type ToEnvoyConnPing struct { + gatewayId: GatewayId + requestId: RequestId + ts: i64 +} + +type ToEnvoyConnClose void + +type ToEnvoyConn union { + ToEnvoyConnPing | + ToEnvoyConnClose | + ToEnvoyCommands | + ToEnvoyAckEvents | + ToEnvoyTunnelMessage +} + +# MARK: To Gateway +type ToGatewayPong struct { + requestId: RequestId + ts: i64 +} + +type ToGateway union { + ToGatewayPong | + ToRivetTunnelMessage +} + +# MARK: To Outbound +type ToOutboundActorStart struct { + namespaceId: Id + poolName: str + checkpoint: ActorCheckpoint + actorConfig: ActorConfig +} + +type ToOutbound union { + ToOutboundActorStart +} diff --git a/engine/sdks/typescript/envoy-protocol/src/index.ts b/engine/sdks/typescript/envoy-protocol/src/index.ts index 65c8863383..5f186af919 100644 --- a/engine/sdks/typescript/envoy-protocol/src/index.ts +++ b/engine/sdks/typescript/envoy-protocol/src/index.ts @@ -679,29 +679,38 @@ function write7(bc: bare.ByteCursor, x: readonly SqliteFetchedPage[]): void { export type SqliteGetPagesOk = { readonly pages: readonly SqliteFetchedPage[] + readonly headTxid: u64 | null } export function readSqliteGetPagesOk(bc: bare.ByteCursor): SqliteGetPagesOk { return { pages: read7(bc), + headTxid: read2(bc), } } export function writeSqliteGetPagesOk(bc: bare.ByteCursor, x: SqliteGetPagesOk): void { write7(bc, x.pages) + write2(bc, x.headTxid) } export type SqliteErrorResponse = { + readonly group: string + readonly code: string readonly message: string } export function readSqliteErrorResponse(bc: bare.ByteCursor): SqliteErrorResponse { return { + group: bare.readString(bc), + code: bare.readString(bc), message: bare.readString(bc), } } export function writeSqliteErrorResponse(bc: bare.ByteCursor, x: SqliteErrorResponse): void { + bare.writeString(bc, x.group) + bare.writeString(bc, x.code) bare.writeString(bc, x.message) } @@ -787,7 +796,19 @@ export function writeSqliteCommitRequest(bc: bare.ByteCursor, x: SqliteCommitReq write2(bc, x.expectedHeadTxid) } -export type SqliteCommitOk = null +export type SqliteCommitOk = { + readonly headTxid: u64 | null +} + +export function readSqliteCommitOk(bc: bare.ByteCursor): SqliteCommitOk { + return { + headTxid: read2(bc), + } +} + +export function writeSqliteCommitOk(bc: bare.ByteCursor, x: SqliteCommitOk): void { + write2(bc, x.headTxid) +} export type SqliteCommitResponse = | { readonly tag: "SqliteCommitOk"; readonly val: SqliteCommitOk } @@ -798,7 +819,7 @@ export function readSqliteCommitResponse(bc: bare.ByteCursor): SqliteCommitRespo const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "SqliteCommitOk", val: null } + return { tag: "SqliteCommitOk", val: readSqliteCommitOk(bc) } case 1: return { tag: "SqliteErrorResponse", val: readSqliteErrorResponse(bc) } default: { @@ -812,6 +833,7 @@ export function writeSqliteCommitResponse(bc: bare.ByteCursor, x: SqliteCommitRe switch (x.tag) { case "SqliteCommitOk": { bare.writeU8(bc, 0) + writeSqliteCommitOk(bc, x.val) break } case "SqliteErrorResponse": { @@ -3247,4 +3269,4 @@ function assert(condition: boolean, message?: string): asserts condition { if (!condition) throw new Error(message ?? "Assertion failed") } -export const VERSION = 4; \ No newline at end of file +export const VERSION = 5; \ No newline at end of file diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs index 0b1a9b9ce6..62b42ca287 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs @@ -1,12 +1,12 @@ use std::collections::HashSet; use std::io::Cursor; -#[cfg(feature = "sqlite-local")] use std::sync::{ Arc, atomic::{AtomicBool, Ordering}, }; use anyhow::{Context, Result}; +use depot_client_types::is_head_fence_mismatch; pub use depot_client_types::{BindParam, ColumnValue, ExecResult, ExecuteResult, QueryResult}; #[cfg(feature = "sqlite-local")] use parking_lot::Mutex; @@ -34,7 +34,7 @@ use depot_client::{ vfs::{SqliteVfsMetrics, SqliteVfsMetricsSnapshot}, worker::{ SQLITE_WORKER_QUEUE_CAPACITY, SqliteWorkerCloseTimeoutError, SqliteWorkerClosingError, - SqliteWorkerDeadError, SqliteWorkerOverloadedError, + SqliteWorkerDeadError, SqliteWorkerFatalError, SqliteWorkerOverloadedError, }, }; #[cfg(feature = "sqlite-local")] @@ -89,7 +89,6 @@ pub struct SqliteDb { open_lock: Arc>, #[cfg(feature = "sqlite-local")] worker_failure_task: Arc>>>, - #[cfg(feature = "sqlite-local")] worker_fatal_reported: Arc, #[cfg(feature = "sqlite-local")] vfs_metrics: Option>, @@ -119,7 +118,6 @@ impl SqliteDb { open_lock: Default::default(), #[cfg(feature = "sqlite-local")] worker_failure_task: Default::default(), - #[cfg(feature = "sqlite-local")] worker_fatal_reported: Default::default(), #[cfg(feature = "sqlite-local")] vfs_metrics: None, @@ -167,8 +165,9 @@ impl SqliteDb { let vfs_metrics = self.vfs_metrics.clone(); let rt_handle = tokio::runtime::Handle::try_current() .context("open sqlite database requires a tokio runtime")?; + self.worker_fatal_reported.store(false, Ordering::Release); - let native_db = open_database_from_transport( + let native_db = self.map_local_worker_result(open_database_from_transport( Arc::new(EnvoySqliteTransport::new(config.handle.clone())), config.actor_id.clone(), config @@ -177,8 +176,7 @@ impl SqliteDb { rt_handle, vfs_metrics, ) - .await?; - self.worker_fatal_reported.store(false, Ordering::Release); + .await)?; self.start_worker_failure_monitor(native_db.clone(), config); *self.db.lock() = Some(native_db); Ok(()) @@ -379,7 +377,7 @@ impl SqliteDb { report_sqlite_worker_fatal( &self.worker_fatal_reported, config, - format!("sqlite worker failed: {error}"), + sqlite_worker_fatal_message(error), ); } @@ -467,7 +465,7 @@ impl SqliteDb { Ok(query_result_from_protocol(ok.result)) } protocol::SqliteExecResponse::SqliteErrorResponse(error) => { - Err(remote_sqlite_error_response(error.message)) + Err(self.remote_sqlite_error_response(error)) } } } @@ -495,7 +493,7 @@ impl SqliteDb { Ok(execute_result_from_protocol(ok.result)) } protocol::SqliteExecuteResponse::SqliteErrorResponse(error) => { - Err(remote_sqlite_error_response(error.message)) + Err(self.remote_sqlite_error_response(error)) } } } @@ -538,9 +536,23 @@ impl SqliteDb { .clone() .ok_or_else(|| sqlite_not_configured("handle")) } + + fn remote_sqlite_error_response(&self, error: protocol::SqliteErrorResponse) -> anyhow::Error { + if is_head_fence_mismatch_response(&error) { + if let Ok(config) = self.runtime_config() { + report_sqlite_worker_fatal( + &self.worker_fatal_reported, + config, + format!("remote sqlite fatal storage error: {}", error.message), + ); + } + return SqliteRuntimeError::Closed.build(); + } + + remote_sqlite_error_response(error.message) + } } -#[cfg(feature = "sqlite-local")] fn report_sqlite_worker_fatal(reported: &AtomicBool, config: SqliteRuntimeConfig, message: String) { if reported.swap(true, Ordering::AcqRel) { return; @@ -582,12 +594,22 @@ fn select_sqlite_backend(enabled: bool, remote_sqlite: bool) -> SqliteBackend { #[cfg(feature = "sqlite-local")] fn is_fatal_worker_error(error: &anyhow::Error) -> bool { - error.downcast_ref::().is_some() + error.downcast_ref::().is_some() + || error.downcast_ref::().is_some() || error .downcast_ref::() .is_some() } +#[cfg(feature = "sqlite-local")] +fn sqlite_worker_fatal_message(error: &anyhow::Error) -> String { + if let Some(error) = error.downcast_ref::() { + return format!("sqlite fatal storage error: {}", error.message()); + } + + format!("sqlite worker failed: {error}") +} + #[cfg(feature = "sqlite-local")] fn map_local_worker_error(error: anyhow::Error) -> anyhow::Error { if error @@ -604,6 +626,7 @@ fn map_local_worker_error(error: anyhow::Error) -> anyhow::Error { if error.downcast_ref::().is_some() || error.downcast_ref::().is_some() + || error.downcast_ref::().is_some() { return SqliteRuntimeError::Closed.build(); } @@ -702,6 +725,10 @@ fn remote_sqlite_error_response(message: String) -> anyhow::Error { SqliteRuntimeError::RemoteExecutionFailed { message }.build() } + +fn is_head_fence_mismatch_response(error: &protocol::SqliteErrorResponse) -> bool { + is_head_fence_mismatch(&error.group, &error.code) +} impl std::fmt::Debug for SqliteDb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SqliteDb") diff --git a/rivetkit-rust/packages/rivetkit-core/tests/sqlite.rs b/rivetkit-rust/packages/rivetkit-core/tests/sqlite.rs index 6ba7a74e79..ec492fbd8c 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/sqlite.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/sqlite.rs @@ -1,4 +1,101 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; + use super::*; +use depot_client_types::{HEAD_FENCE_MISMATCH_CODE, HEAD_FENCE_MISMATCH_GROUP}; +use rivet_envoy_client::config::{ + BoxFuture as EnvoyBoxFuture, EnvoyCallbacks, EnvoyConfig, HttpRequest, HttpResponse, + WebSocketHandler, WebSocketSender, +}; +use rivet_envoy_client::context::{SharedContext, WsTxMessage}; +use rivet_envoy_client::envoy::ToEnvoyMessage; +use rivet_envoy_client::handle::EnvoyHandle; +use tokio::sync::{Mutex as AsyncMutex, mpsc}; + +struct IdleEnvoyCallbacks; + +impl EnvoyCallbacks for IdleEnvoyCallbacks { + fn on_actor_start( + &self, + _handle: EnvoyHandle, + _actor_id: String, + _generation: u32, + _config: protocol::ActorConfig, + _preloaded_kv: Option, + ) -> EnvoyBoxFuture> { + Box::pin(async { Ok(()) }) + } + + fn on_shutdown(&self) {} + + fn fetch( + &self, + _handle: EnvoyHandle, + _actor_id: String, + _gateway_id: protocol::GatewayId, + _request_id: protocol::RequestId, + _request: HttpRequest, + ) -> EnvoyBoxFuture> { + Box::pin(async { unreachable!("sqlite tests do not fetch") }) + } + + fn websocket( + &self, + _handle: EnvoyHandle, + _actor_id: String, + _gateway_id: protocol::GatewayId, + _request_id: protocol::RequestId, + _request: HttpRequest, + _path: String, + _headers: HashMap, + _is_hibernatable: bool, + _is_restoring_hibernatable: bool, + _sender: WebSocketSender, + ) -> EnvoyBoxFuture> { + Box::pin(async { unreachable!("sqlite tests do not open websockets") }) + } + + fn can_hibernate( + &self, + _actor_id: &str, + _gateway_id: &protocol::GatewayId, + _request_id: &protocol::RequestId, + _request: &HttpRequest, + ) -> EnvoyBoxFuture> { + Box::pin(async { Ok(false) }) + } +} + +fn test_envoy_handle() -> (EnvoyHandle, mpsc::UnboundedReceiver) { + let (envoy_tx, envoy_rx) = mpsc::unbounded_channel(); + let shared = Arc::new(SharedContext { + config: EnvoyConfig { + version: 1, + endpoint: "http://127.0.0.1:1".to_string(), + token: None, + namespace: "test".to_string(), + pool_name: "test".to_string(), + prepopulate_actor_names: HashMap::new(), + metadata: None, + not_global: true, + debug_latency_ms: None, + callbacks: Arc::new(IdleEnvoyCallbacks), + }, + envoy_key: "test-envoy".to_string(), + envoy_tx, + actors: Default::default(), + actors_notify: Arc::new(tokio::sync::Notify::new()), + live_tunnel_requests: Default::default(), + pending_hibernation_restores: Default::default(), + ws_tx: Arc::new(AsyncMutex::new(None::>)), + protocol_metadata: Arc::new(AsyncMutex::new(None)), + shutting_down: AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, + }); + + (EnvoyHandle::from_shared(shared), envoy_rx) +} #[test] fn remote_backend_requires_declared_database_and_capability() { @@ -117,3 +214,44 @@ fn remote_lost_response_errors_become_indeterminate_result() { assert_eq!(structured.group(), "sqlite"); assert_eq!(structured.code(), "remote_indeterminate_result"); } + +#[test] +fn remote_head_fence_mismatch_stops_actor_once() { + let (handle, mut envoy_rx) = test_envoy_handle(); + let db = SqliteDb::new_with_remote_sqlite(handle, "actor-a", Some(7), true, true); + + let mapped = db.remote_sqlite_error_response(protocol::SqliteErrorResponse { + group: HEAD_FENCE_MISMATCH_GROUP.to_string(), + code: HEAD_FENCE_MISMATCH_CODE.to_string(), + message: "head fence mismatch in remote sqlite".to_string(), + }); + let structured = rivet_error::RivetError::extract(&mapped); + assert_eq!(structured.group(), "sqlite"); + assert_eq!(structured.code(), "closed"); + + match envoy_rx.try_recv().expect("missing stop actor intent") { + ToEnvoyMessage::ActorIntent { + actor_id, + generation, + intent, + error, + } => { + assert_eq!(actor_id, "actor-a"); + assert_eq!(generation, Some(7)); + assert!(matches!(intent, protocol::ActorIntent::ActorIntentStop)); + assert!( + error + .expect("missing stop reason") + .contains("remote sqlite fatal storage error") + ); + } + _ => panic!("expected stop actor intent"), + } + + let _ = db.remote_sqlite_error_response(protocol::SqliteErrorResponse { + group: HEAD_FENCE_MISMATCH_GROUP.to_string(), + code: HEAD_FENCE_MISMATCH_CODE.to_string(), + message: "second head fence mismatch".to_string(), + }); + assert!(envoy_rx.try_recv().is_err()); +}