From 788e2e917722ff832048bc6c3871808b2cfa33b2 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 19:17:33 -0800 Subject: [PATCH 01/10] fix: role detection; fix: set inside transactions --- pgdog/src/backend/databases.rs | 346 +----------------- pgdog/src/backend/mod.rs | 2 +- pgdog/src/backend/pool/cluster.rs | 12 - pgdog/src/backend/pool/connection/binding.rs | 11 +- .../src/backend/pool/{replicas => lb}/ban.rs | 0 .../src/backend/pool/{replicas => lb}/mod.rs | 172 ++++----- .../backend/pool/{replicas => lb}/monitor.rs | 24 +- .../pool/{replicas => lb}/target_health.rs | 0 .../src/backend/pool/{replicas => lb}/test.rs | 116 +++--- pgdog/src/backend/pool/mod.rs | 4 +- pgdog/src/backend/pool/pool_impl.rs | 10 +- .../backend/pool/replicas/detected_role.rs | 27 -- pgdog/src/backend/pool/shard/mod.rs | 81 +--- pgdog/src/backend/pool/shard/monitor.rs | 18 +- pgdog/src/backend/pool/shard/role_detector.rs | 37 +- pgdog/src/backend/server.rs | 2 +- .../frontend/client/query_engine/connect.rs | 14 + pgdog/src/frontend/client/query_engine/mod.rs | 10 +- pgdog/src/frontend/client/query_engine/set.rs | 9 +- pgdog/src/frontend/router/parser/command.rs | 1 + pgdog/src/frontend/router/parser/query/set.rs | 70 ++-- .../src/frontend/router/parser/query/test.rs | 34 +- pgdog/src/net/parameter.rs | 11 +- 23 files changed, 268 insertions(+), 743 deletions(-) rename pgdog/src/backend/pool/{replicas => lb}/ban.rs (100%) rename pgdog/src/backend/pool/{replicas => lb}/mod.rs (66%) rename pgdog/src/backend/pool/{replicas => lb}/monitor.rs (80%) rename pgdog/src/backend/pool/{replicas => lb}/target_health.rs (100%) rename pgdog/src/backend/pool/{replicas => lb}/test.rs (87%) delete mode 100644 pgdog/src/backend/pool/replicas/detected_role.rs diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index afacd902c..8ff7c2d83 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -374,11 +374,6 @@ pub(crate) fn new_pool( user: &crate::config::User, config: &crate::config::Config, ) -> Option<(User, Cluster)> { - let existing_roles = databases() - .cluster(user) - .ok() - .map(|cluster| cluster.redetect_roles()); - let sharded_tables = config.sharded_tables(); let omnisharded_tables = config.omnisharded_tables(); let sharded_mappings = config.sharded_mappings(); @@ -386,27 +381,10 @@ pub(crate) fn new_pool( let general = &config.general; let databases = config.databases(); - let mut shards = databases.get(&user.database).cloned()?; + let shards = databases.get(&user.database).cloned()?; let mut shard_configs = vec![]; - for (shard_number, user_databases) in shards.iter_mut().enumerate() { - let role_detector = user_databases.iter().any(|d| d.role == Role::Auto); - - if let Some(ref shard_roles) = existing_roles - .as_ref() - .and_then(|existing_roles| existing_roles.get(shard_number).cloned()) - .flatten() - { - for user_database in user_databases.iter_mut() { - // Override role with automatically detected one. 
- if let Some(role) = shard_roles.get(&user_database.number) { - if user_database.role == Role::Auto { - user_database.role = role.role; - } - } - } - } - + for user_databases in shards { let has_single_replica = user_databases.len() == 1; let primary = user_databases .iter() @@ -424,11 +402,7 @@ pub(crate) fn new_pool( }) .collect::>(); - shard_configs.push(ClusterShardConfig { - primary, - replicas, - role_detector, - }); + shard_configs.push(ClusterShardConfig { primary, replicas }); } let mut sharded_tables = sharded_tables @@ -1182,318 +1156,4 @@ mod tests { "Mirror config should not be precomputed when source has no users" ); } - - #[test] - fn test_new_pool_fetches_existing_roles_on_reload() { - use crate::backend::pool::lsn_monitor::LsnStats; - use crate::backend::replication::publisher::Lsn; - use std::sync::Arc; - use tokio::time::Instant; - - let _lock = lock(); - - let mut config = Config::default(); - config.databases = vec![ - Database { - name: "testdb".to_string(), - host: "127.0.0.1".to_string(), - port: 5432, - role: Role::Auto, - shard: 0, - ..Default::default() - }, - Database { - name: "testdb".to_string(), - host: "127.0.0.1".to_string(), - port: 5433, - role: Role::Auto, - shard: 0, - ..Default::default() - }, - ]; - - let users = crate::config::Users { - users: vec![crate::config::User { - name: "testuser".to_string(), - database: "testdb".to_string(), - password: Some("pass".to_string()), - ..Default::default() - }], - ..Default::default() - }; - - let config_and_users = ConfigAndUsers { - config: config.clone(), - users: users.clone(), - config_path: std::path::PathBuf::new(), - users_path: std::path::PathBuf::new(), - }; - - let initial_databases = from_config(&config_and_users); - let cluster = initial_databases.cluster(("testuser", "testdb")).unwrap(); - - for pool in cluster.shards()[0].pools() { - let lsn_stats = LsnStats { - replica: pool.addr().database_number == 1, - lsn: Lsn::from_i64(1000), - offset_bytes: 1000, - timestamp: Default::default(), - fetched: Instant::now(), - aurora: false, - }; - pool.set_lsn_stats(lsn_stats); - } - - DATABASES.store(Arc::new(initial_databases)); - - let (_, new_cluster) = new_pool(&users.users[0], &config).unwrap(); - - let roles = new_cluster.shards()[0].current_roles(); - - assert_eq!(roles.len(), 2); - - assert_eq!( - roles.get(&0).unwrap().role, - Role::Primary, - "database_number 0 should be assigned Primary (replica=false)" - ); - assert_eq!( - roles.get(&1).unwrap().role, - Role::Replica, - "database_number 1 should be assigned Replica (replica=true)" - ); - - DATABASES.store(Arc::new(Databases::default())); - } - - #[test] - fn test_new_pool_only_assigns_roles_to_auto() { - use crate::backend::pool::lsn_monitor::LsnStats; - use crate::backend::replication::publisher::Lsn; - use std::sync::Arc; - use tokio::time::Instant; - - let _lock = lock(); - - let mut config = Config::default(); - config.databases = vec![ - Database { - name: "testdb".to_string(), - host: "127.0.0.1".to_string(), - port: 5432, - role: Role::Primary, - shard: 0, - ..Default::default() - }, - Database { - name: "testdb".to_string(), - host: "127.0.0.1".to_string(), - port: 5433, - role: Role::Replica, - shard: 0, - ..Default::default() - }, - Database { - name: "testdb".to_string(), - host: "127.0.0.1".to_string(), - port: 5434, - role: Role::Auto, - shard: 0, - ..Default::default() - }, - ]; - - let users = crate::config::Users { - users: vec![crate::config::User { - name: "testuser".to_string(), - database: "testdb".to_string(), - password: 
Some("pass".to_string()), - ..Default::default() - }], - ..Default::default() - }; - - let config_and_users = ConfigAndUsers { - config: config.clone(), - users: users.clone(), - config_path: std::path::PathBuf::new(), - users_path: std::path::PathBuf::new(), - }; - - let initial_databases = from_config(&config_and_users); - let cluster = initial_databases.cluster(("testuser", "testdb")).unwrap(); - - for pool in cluster.shards()[0].pools() { - let db_num = pool.addr().database_number; - let lsn_stats = LsnStats { - replica: db_num != 0, - lsn: Lsn::from_i64(1000), - offset_bytes: 1000, - timestamp: Default::default(), - fetched: Instant::now(), - aurora: false, - }; - pool.set_lsn_stats(lsn_stats); - } - - DATABASES.store(Arc::new(initial_databases)); - - let (_, new_cluster) = new_pool(&users.users[0], &config).unwrap(); - - let roles = new_cluster.shards()[0].current_roles(); - - assert_eq!(roles.len(), 3); - - assert_eq!( - roles.get(&0).unwrap().role, - Role::Primary, - "Explicit Primary should remain Primary (LSN says replica=false)" - ); - assert_eq!( - roles.get(&1).unwrap().role, - Role::Replica, - "Explicit Replica should remain Replica (even though LSN says replica=true which would suggest Replica, the explicit config is preserved)" - ); - assert_eq!( - roles.get(&2).unwrap().role, - Role::Replica, - "Auto role should be assigned Replica based on LSN replica=true" - ); - - DATABASES.store(Arc::new(Databases::default())); - } - - #[test] - fn test_new_pool_matches_roles_by_database_number() { - use crate::backend::pool::lsn_monitor::LsnStats; - use crate::backend::replication::publisher::Lsn; - use std::sync::Arc; - use tokio::time::Instant; - - let _lock = lock(); - - let mut config = Config::default(); - config.databases = vec![ - Database { - name: "db1".to_string(), - host: "127.0.0.1".to_string(), - port: 5432, - role: Role::Auto, - shard: 0, - ..Default::default() - }, - Database { - name: "db2".to_string(), - host: "127.0.0.1".to_string(), - port: 5433, - role: Role::Auto, - shard: 0, - ..Default::default() - }, - Database { - name: "db1".to_string(), - host: "127.0.0.1".to_string(), - port: 5434, - role: Role::Auto, - shard: 0, - ..Default::default() - }, - ]; - - let users = crate::config::Users { - users: vec![ - crate::config::User { - name: "user1".to_string(), - database: "db1".to_string(), - password: Some("pass".to_string()), - ..Default::default() - }, - crate::config::User { - name: "user2".to_string(), - database: "db2".to_string(), - password: Some("pass".to_string()), - ..Default::default() - }, - ], - ..Default::default() - }; - - let config_and_users = ConfigAndUsers { - config: config.clone(), - users: users.clone(), - config_path: std::path::PathBuf::new(), - users_path: std::path::PathBuf::new(), - }; - - let initial_databases = from_config(&config_and_users); - - let db1_cluster = initial_databases.cluster(("user1", "db1")).unwrap(); - for pool in db1_cluster.shards()[0].pools() { - let db_num = pool.addr().database_number; - let lsn_stats = LsnStats { - replica: db_num == 2, - lsn: Lsn::from_i64(1000), - offset_bytes: 1000, - timestamp: Default::default(), - fetched: Instant::now(), - aurora: false, - }; - pool.set_lsn_stats(lsn_stats); - } - - let db2_cluster = initial_databases.cluster(("user2", "db2")).unwrap(); - for pool in db2_cluster.shards()[0].pools() { - let lsn_stats = LsnStats { - replica: false, - lsn: Lsn::from_i64(1000), - offset_bytes: 1000, - timestamp: Default::default(), - fetched: Instant::now(), - aurora: false, - }; - 
pool.set_lsn_stats(lsn_stats); - } - - DATABASES.store(Arc::new(initial_databases)); - - let (_, new_db1_cluster) = new_pool(&users.users[0], &config).unwrap(); - let db1_roles = new_db1_cluster.shards()[0].current_roles(); - - assert_eq!(db1_roles.len(), 2); - assert!( - db1_roles.get(&0).is_some(), - "db1 should have database_number 0" - ); - assert!( - db1_roles.get(&2).is_some(), - "db1 should have database_number 2" - ); - - assert_eq!( - db1_roles.get(&0).unwrap().role, - Role::Primary, - "database_number 0 should be Primary (replica=false)" - ); - assert_eq!( - db1_roles.get(&2).unwrap().role, - Role::Replica, - "database_number 2 should be Replica (replica=true)" - ); - - let (_, new_db2_cluster) = new_pool(&users.users[1], &config).unwrap(); - let db2_roles = new_db2_cluster.shards()[0].current_roles(); - - assert_eq!(db2_roles.len(), 1); - assert!( - db2_roles.get(&1).is_some(), - "db2 should have database_number 1" - ); - assert_eq!( - db2_roles.get(&1).unwrap().role, - Role::Primary, - "database_number 1 should be Primary (replica=false)" - ); - - DATABASES.store(Arc::new(Databases::default())); - } } diff --git a/pgdog/src/backend/mod.rs b/pgdog/src/backend/mod.rs index 7bb4db713..8498d7600 100644 --- a/pgdog/src/backend/mod.rs +++ b/pgdog/src/backend/mod.rs @@ -15,7 +15,7 @@ pub mod server_options; pub mod stats; pub use error::Error; -pub use pool::{Cluster, ClusterShardConfig, Pool, Replicas, Shard, ShardingSchema}; +pub use pool::{Cluster, ClusterShardConfig, LoadBalancer, Pool, Shard, ShardingSchema}; pub use prepared_statements::PreparedStatements; pub use protocol::*; pub use pub_sub::{PubSubClient, PubSubListener}; diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 1b5445530..4b022bbd7 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -15,7 +15,6 @@ use tracing::{error, info}; use crate::{ backend::{ databases::{databases, User as DatabaseUser}, - pool::shard::role_detector::DetectedRoles, replication::{ReplicationConfig, ShardedColumn, ShardedSchemas}, Schema, ShardedTables, }, @@ -88,7 +87,6 @@ impl ShardingSchema { pub struct ClusterShardConfig { pub primary: Option, pub replicas: Vec, - pub role_detector: bool, } impl ClusterShardConfig { @@ -232,7 +230,6 @@ impl Cluster { rw_split, identifier: identifier.clone(), lsn_check_interval, - role_detector: config.role_detector, }) }) .collect(), @@ -526,14 +523,6 @@ impl Cluster { Ok(()) } - - /// Re-detect primary/replica roles, with shard numbers. - pub fn redetect_roles(&self) -> Vec> { - self.shards - .iter() - .map(|shard| shard.redetect_roles()) - .collect() - } } #[cfg(test)] @@ -582,7 +571,6 @@ mod test { rw_split: ReadWriteSplit::IncludePrimary, identifier: identifier.clone(), lsn_check_interval: Duration::MAX, - role_detector: false, }) }) .collect::>(); diff --git a/pgdog/src/backend/pool/connection/binding.rs b/pgdog/src/backend/pool/connection/binding.rs index 5a077768e..8392425c4 100644 --- a/pgdog/src/backend/pool/connection/binding.rs +++ b/pgdog/src/backend/pool/connection/binding.rs @@ -2,7 +2,7 @@ use crate::{ frontend::{client::query_engine::TwoPcPhase, ClientRequest}, - net::{parameter::Parameters, BackendKeyData, ProtocolMessage}, + net::{parameter::Parameters, BackendKeyData, ProtocolMessage, Query}, state::State, }; @@ -252,7 +252,10 @@ impl Binding { } /// Execute a query on all servers. 
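+    ///
+    /// On a direct (single server) binding, the query runs once. On a
+    /// multi-shard binding, the query is cloned and broadcast to every
+    /// server, and the results are collected in order.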
-    pub async fn execute(&mut self, query: &str) -> Result<Vec<Message>, Error> {
+    pub async fn execute(
+        &mut self,
+        query: impl Into<Query> + Clone,
+    ) -> Result<Vec<Message>, Error> {
         let mut result = vec![];
         match self {
             Binding::Direct(Some(ref mut server)) => {
@@ -260,7 +263,9 @@
             }
 
             Binding::MultiShard(ref mut servers, _) => {
-                let futures = servers.iter_mut().map(|server| server.execute(query));
+                let futures = servers
+                    .iter_mut()
+                    .map(|server| server.execute(query.clone()));
                 let results = join_all(futures).await;
 
                 for server_result in results {
diff --git a/pgdog/src/backend/pool/replicas/ban.rs b/pgdog/src/backend/pool/lb/ban.rs
similarity index 100%
rename from pgdog/src/backend/pool/replicas/ban.rs
rename to pgdog/src/backend/pool/lb/ban.rs
diff --git a/pgdog/src/backend/pool/replicas/mod.rs b/pgdog/src/backend/pool/lb/mod.rs
similarity index 66%
rename from pgdog/src/backend/pool/replicas/mod.rs
rename to pgdog/src/backend/pool/lb/mod.rs
index f66571859..9742969b1 100644
--- a/pgdog/src/backend/pool/replicas/mod.rs
+++ b/pgdog/src/backend/pool/lb/mod.rs
@@ -1,9 +1,8 @@
 //! Replicas pool.
 
 use std::{
-    collections::HashMap,
     sync::{
-        atomic::{AtomicUsize, Ordering},
+        atomic::{AtomicBool, AtomicUsize, Ordering},
         Arc,
     },
     time::Duration,
@@ -15,21 +14,16 @@
     time::{timeout, Instant},
 };
 
+use crate::config::{LoadBalancingStrategy, ReadWriteSplit, Role};
 use crate::net::messages::BackendKeyData;
-use crate::{
-    backend::pool::shard::role_detector::DetectedRoles,
-    config::{LoadBalancingStrategy, ReadWriteSplit, Role},
-};
 
 use super::{Error, Guard, Pool, PoolConfig, Request};
 
 pub mod ban;
-pub mod detected_role;
 pub mod monitor;
 pub mod target_health;
 
 use ban::Ban;
-pub use detected_role::*;
 use monitor::*;
 pub use target_health::*;
 
 mod test;
 
 pub struct ReadTarget {
     pub pool: Pool,
     pub ban: Ban,
-    pub role: Role,
+    replica: Arc<AtomicBool>,
     pub health: TargetHealth,
 }
 
@@ -50,20 +44,27 @@ impl ReadTarget {
         let ban = Ban::new(&pool);
         Self {
             ban,
-            role,
+            replica: Arc::new(AtomicBool::new(role == Role::Replica)),
             health: pool.inner().health.clone(),
             pool,
         }
     }
+
+    /// Get role.
+    pub(super) fn role(&self) -> Role {
+        if self.replica.load(Ordering::Relaxed) {
+            Role::Replica
+        } else {
+            Role::Primary
+        }
+    }
 }
 
 /// Replicas pools.
 #[derive(Clone, Default, Debug)]
-pub struct Replicas {
-    /// Replica targets (pools with ban state).
-    pub(super) replicas: Vec<ReadTarget>,
-    /// Primary target (pool with ban state).
-    pub(super) primary: Option<ReadTarget>,
+pub struct LoadBalancer {
+    /// Read/write targets.
+    pub(super) targets: Vec<ReadTarget>,
     /// Checkout timeout.
     pub(super) checkout_timeout: Duration,
     /// Round robin atomic counter.
@@ -76,20 +77,20 @@
 pub(super) rw_split: ReadWriteSplit,
 }
 
-impl Replicas {
+impl LoadBalancer {
     /// Create new replicas pools.
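+    ///
+    /// The primary, when present, is appended as the last read target;
+    /// `primary_target()` relies on this ordering to find it quickly.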
     pub fn new(
         primary: &Option<Pool>,
         addrs: &[PoolConfig],
         lb_strategy: LoadBalancingStrategy,
         rw_split: ReadWriteSplit,
-    ) -> Replicas {
-        let checkout_timeout = addrs
+    ) -> LoadBalancer {
+        let mut checkout_timeout = addrs
             .iter()
             .map(|c| c.config.checkout_timeout)
             .sum::<Duration>();
 
-        let replicas: Vec<_> = addrs
+        let mut targets: Vec<_> = addrs
             .iter()
             .map(|config| ReadTarget::new(Pool::new(config), Role::Replica))
             .collect();
 
         let primary_target = primary
             .as_ref()
             .map(|pool| ReadTarget::new(pool.clone(), Role::Primary));
 
+        if let Some(primary) = primary_target {
+            checkout_timeout += primary.pool.config().checkout_timeout;
+            targets.push(primary);
+        }
+
         Self {
-            primary: primary_target,
-            replicas,
+            targets,
             checkout_timeout,
             round_robin: Arc::new(AtomicUsize::new(0)),
             lb_strategy,
@@ -109,30 +114,35 @@
         }
     }
 
-    /// Get current database roles.
-    pub fn current_roles(&self) -> DetectedRoles {
-        let mut roles = self
-            .replicas
+    /// Get the primary pool, if configured.
+    pub fn primary(&self) -> Option<&Pool> {
+        self.primary_target().map(|target| &target.pool)
+    }
+
+    /// Get the primary read target containing the pool, ban state, and health.
+    ///
+    /// Unlike [`primary()`], this returns the full target struct which allows
+    /// access to ban and health state for monitoring and testing purposes.
+    pub fn primary_target(&self) -> Option<&ReadTarget> {
+        self.targets
             .iter()
-            .map(|replica| {
-                let role = DetectedRole::from_read_target(replica);
-                (role.database_number, role)
-            })
-            .collect::<HashMap<_, _>>();
-
-        if let Some(ref primary) = self.primary {
-            let role = DetectedRole::from_read_target(primary);
-            roles.insert(role.database_number, role);
-        }
+            .rev() // If there is a primary, it's likely to be last.
+            .find(|target| target.role() == Role::Primary)
+    }
 
-        roles.into()
+    pub fn write_only(&self) -> bool {
+        self.targets
+            .iter()
+            .all(|target| target.role() == Role::Primary)
     }
 
     /// Detect database roles from pg_is_in_recovery() and
     /// return new primary (if any), and replicas.
-    pub fn redetect_roles(&self) -> Option<DetectedRoles> {
+    pub fn redetect_roles(&self) -> bool {
+        let mut changed = false;
+
         let mut targets = self
-            .replicas
+            .targets
             .clone()
             .into_iter()
             .map(|target| (target.pool.lsn_stats(), target))
             .collect::<Vec<_>>();
 
         // Only detect roles if the LSN detector is running.
         if !targets.iter().all(|target| target.0.valid()) {
-            return None;
-        }
-
-        if let Some(primary) = self.primary.clone() {
-            targets.push((primary.pool.lsn_stats(), primary));
+            return false;
         }
 
         // Pick primary by latest data: the first target with valid LSN
         // stats that reports it is not in recovery.
         let primary = targets
             .iter()
             .find(|target| target.0.valid() && !target.0.replica);
+
+        if let Some(primary) = primary {
+            if primary.1.role() != Role::Primary {
+                changed = true;
+                primary.1.replica.store(false, Ordering::Relaxed);
+            }
+        }
 
         let replicas = targets
             .iter()
             .filter(|target| target.0.replica)
             .collect::<Vec<_>>();
 
-        let mut numbers: HashMap<_, _> = replicas
-            .iter()
-            .map(|target| {
-                let database_number = target.1.pool.addr().database_number;
-                (
-                    database_number,
-                    DetectedRole {
-                        role: Role::Replica,
-                        as_of: target.0.fetched,
-                        database_number,
-                    },
-                )
-            })
-            .collect();
 
-        if let Some(primary) = primary {
-            let database_number = primary.1.pool.addr().database_number;
-
-            numbers.insert(
-                database_number,
-                DetectedRole {
-                    role: Role::Primary,
-                    as_of: primary.0.fetched,
-                    database_number,
-                },
-            );
+        for replica in replicas {
+            if replica.1.role() != Role::Replica {
+                changed = true;
+                replica.1.replica.store(true, Ordering::Relaxed);
+            }
         }
 
-        Some(numbers.into())
+        changed
     }
 
     /// Launch replica pools and start the monitor.
     pub fn launch(&self) {
-        self.replicas.iter().for_each(|target| target.pool.launch());
+        self.targets.iter().for_each(|target| target.pool.launch());
 
         Monitor::spawn(self);
     }
@@ -211,27 +203,27 @@
     }
 
     /// Move connections from this replica set to another.
-    pub fn move_conns_to(&self, destination: &Replicas) {
-        assert_eq!(self.replicas.len(), destination.replicas.len());
+    pub fn move_conns_to(&self, destination: &LoadBalancer) {
+        assert_eq!(self.targets.len(), destination.targets.len());
 
-        for (from, to) in self.replicas.iter().zip(destination.replicas.iter()) {
+        for (from, to) in self.targets.iter().zip(destination.targets.iter()) {
             from.pool.move_conns_to(&to.pool);
         }
     }
 
     /// The two replica sets are referring to the same databases.
-    pub fn can_move_conns_to(&self, destination: &Replicas) -> bool {
-        self.replicas.len() == destination.replicas.len()
+    pub fn can_move_conns_to(&self, destination: &LoadBalancer) -> bool {
+        self.targets.len() == destination.targets.len()
             && self
-                .replicas
+                .targets
                 .iter()
-                .zip(destination.replicas.iter())
+                .zip(destination.targets.iter())
                 .all(|(a, b)| a.pool.can_move_conns_to(&b.pool))
     }
 
     /// How many replicas we are connected to.
     pub fn len(&self) -> usize {
-        self.replicas.len()
+        self.targets.len()
     }
 
     /// There are no replicas.
@@ -241,7 +233,7 @@
 
     /// Cancel a query if one is running.
     pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> {
-        for target in &self.replicas {
+        for target in &self.targets {
             target.pool.cancel(id).await?;
         }
 
@@ -250,21 +242,17 @@
 
     /// Replica pools handle.
     pub fn pools(&self) -> Vec<&Pool> {
-        self.replicas.iter().map(|target| &target.pool).collect()
+        self.targets.iter().map(|target| &target.pool).collect()
     }
 
     /// Collect all connection pools used for read queries.
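+    /// Includes the primary target, when configured, along with its
+    /// currently detected role.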
pub fn pools_with_roles_and_bans(&self) -> Vec<(Role, Ban, Pool)> { - let mut result: Vec<_> = self - .replicas + let result: Vec<_> = self + .targets .iter() - .map(|target| (Role::Replica, target.ban.clone(), target.pool.clone())) + .map(|target| (target.role(), target.ban.clone(), target.pool.clone())) .collect(); - if let Some(ref primary) = self.primary { - result.push((Role::Primary, primary.ban.clone(), primary.pool.clone())); - } - result } @@ -272,7 +260,7 @@ impl Replicas { use LoadBalancingStrategy::*; use ReadWriteSplit::*; - let mut candidates: Vec<&ReadTarget> = self.replicas.iter().collect(); + let mut candidates: Vec<&ReadTarget> = self.targets.iter().collect(); let primary_reads = match self.rw_split { IncludePrimary => true, @@ -280,10 +268,8 @@ impl Replicas { ExcludePrimary => false, }; - if primary_reads { - if let Some(ref primary) = self.primary { - candidates.push(primary); - } + if !primary_reads { + candidates.retain(|target| target.role() == Role::Replica); } match self.lb_strategy { @@ -330,7 +316,7 @@ impl Replicas { /// /// N.B. The primary pool is managed by `super::Shard`. pub fn shutdown(&self) { - for target in &self.replicas { + for target in &self.targets { target.pool.shutdown(); } diff --git a/pgdog/src/backend/pool/replicas/monitor.rs b/pgdog/src/backend/pool/lb/monitor.rs similarity index 80% rename from pgdog/src/backend/pool/replicas/monitor.rs rename to pgdog/src/backend/pool/lb/monitor.rs index 748e5e855..de9e02535 100644 --- a/pgdog/src/backend/pool/replicas/monitor.rs +++ b/pgdog/src/backend/pool/lb/monitor.rs @@ -9,12 +9,12 @@ static MAINTENANCE: Duration = Duration::from_millis(333); #[derive(Clone, Debug)] pub(super) struct Monitor { - replicas: Replicas, + replicas: LoadBalancer, } impl Monitor { /// Create new replica targets monitor. 
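+    ///
+    /// The monitor wakes up every `MAINTENANCE` interval (333ms) to
+    /// expire bans and to ban targets that report unhealthy.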
- pub(super) fn spawn(replicas: &Replicas) -> JoinHandle<()> { + pub(super) fn spawn(replicas: &LoadBalancer) -> JoinHandle<()> { let monitor = Self { replicas: replicas.clone(), }; @@ -27,24 +27,10 @@ impl Monitor { async fn run(&self) { let mut interval = interval(MAINTENANCE); - let mut targets: Vec<_> = self.replicas.replicas.clone(); - if let Some(primary) = self.replicas.primary.clone() { - targets.push(primary); - } - - let mut bans: Vec = self - .replicas - .replicas - .iter() - .map(|target| target.ban.clone()) - .collect(); - - if let Some(ref primary) = self.replicas.primary { - bans.push(primary.ban.clone()); - } - debug!("replicas monitor running"); + let targets = &self.replicas.targets; + loop { let mut check_offline = false; let mut ban_targets = Vec::new(); @@ -59,7 +45,7 @@ impl Monitor { if check_offline { let offline = self .replicas - .replicas + .targets .iter() .all(|target| !target.pool.lock().online); diff --git a/pgdog/src/backend/pool/replicas/target_health.rs b/pgdog/src/backend/pool/lb/target_health.rs similarity index 100% rename from pgdog/src/backend/pool/replicas/target_health.rs rename to pgdog/src/backend/pool/lb/target_health.rs diff --git a/pgdog/src/backend/pool/replicas/test.rs b/pgdog/src/backend/pool/lb/test.rs similarity index 87% rename from pgdog/src/backend/pool/replicas/test.rs rename to pgdog/src/backend/pool/lb/test.rs index 94f5f03e3..9c72fb44b 100644 --- a/pgdog/src/backend/pool/replicas/test.rs +++ b/pgdog/src/backend/pool/lb/test.rs @@ -28,11 +28,11 @@ fn create_test_pool_config(host: &str, port: u16) -> PoolConfig { } } -fn setup_test_replicas() -> Replicas { +fn setup_test_replicas() -> LoadBalancer { let pool_config1 = create_test_pool_config("127.0.0.1", 5432); let pool_config2 = create_test_pool_config("localhost", 5432); - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &None, &[pool_config1, pool_config2], LoadBalancingStrategy::Random, @@ -47,7 +47,7 @@ async fn test_replica_ban_recovery_after_timeout() { let replicas = setup_test_replicas(); // Ban the first replica with very short timeout - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; ban.ban(Error::ServerError, Duration::from_millis(50)); assert!(ban.banned()); @@ -70,7 +70,7 @@ async fn test_replica_manual_unban() { let replicas = setup_test_replicas(); // Ban the first replica - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; ban.ban(Error::ServerError, Duration::from_millis(1000)); assert!(ban.banned()); @@ -87,7 +87,7 @@ async fn test_replica_manual_unban() { async fn test_replica_ban_error_retrieval() { let replicas = setup_test_replicas(); - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; // No error initially assert!(ban.error().is_none()); @@ -108,7 +108,7 @@ async fn test_multiple_replica_banning() { // Ban both replicas for i in 0..2 { - let ban = &replicas.replicas[i].ban; + let ban = &replicas.targets[i].ban; ban.ban(Error::ServerError, Duration::from_millis(100)); assert!(ban.banned()); @@ -116,7 +116,7 @@ async fn test_multiple_replica_banning() { // Both should be banned assert_eq!( - replicas.replicas.iter().filter(|r| r.ban.banned()).count(), + replicas.targets.iter().filter(|r| r.ban.banned()).count(), 2 ); @@ -127,7 +127,7 @@ async fn test_multiple_replica_banning() { async fn test_replica_ban_idempotency() { let replicas = setup_test_replicas(); - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; // First ban should succeed let 
first_ban = ban.ban(Error::ServerError, Duration::from_millis(100)); @@ -170,7 +170,7 @@ async fn test_primary_pool_banning() { let replica_configs = [create_test_pool_config("localhost", 5432)]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -179,9 +179,9 @@ async fn test_primary_pool_banning() { replicas.launch(); // Test primary ban exists - assert!(replicas.primary.is_some()); + assert!(replicas.primary_target().is_some()); - let primary_ban = &replicas.primary.as_ref().unwrap().ban; + let primary_ban = &replicas.primary_target().unwrap().ban; // Ban primary for reads primary_ban.ban(Error::ServerError, Duration::from_millis(100)); @@ -198,7 +198,6 @@ async fn test_primary_pool_banning() { assert!(has_primary); // Shutdown both primary and replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -206,7 +205,7 @@ async fn test_primary_pool_banning() { async fn test_ban_timeout_not_expired() { let replicas = setup_test_replicas(); - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; ban.ban(Error::ServerError, Duration::from_millis(1000)); // Long timeout assert!(ban.banned()); @@ -225,8 +224,8 @@ async fn test_ban_timeout_not_expired() { async fn test_unban_if_expired_checks_pool_health() { let replicas = setup_test_replicas(); - let ban = &replicas.replicas[0].ban; - let pool = &replicas.replicas[0].pool; + let ban = &replicas.targets[0].ban; + let pool = &replicas.targets[0].pool; ban.ban(Error::ServerError, Duration::from_millis(50)); assert!(ban.banned()); @@ -271,7 +270,7 @@ async fn test_replica_ban_clears_idle_connections() { idle_before ); - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; // Ban should trigger dump_idle() on the pool ban.ban(Error::ServerError, Duration::from_millis(100)); @@ -294,7 +293,7 @@ async fn test_monitor_automatic_ban_expiration() { let replicas = setup_test_replicas(); // Ban the first replica with very short timeout - let ban = &replicas.replicas[0].ban; + let ban = &replicas.targets[0].ban; ban.ban(Error::ServerError, Duration::from_millis(100)); assert!(ban.banned()); @@ -327,7 +326,7 @@ async fn test_read_write_split_exclude_primary() { create_test_pool_config("127.0.0.1", 5432), ]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -348,11 +347,10 @@ async fn test_read_write_split_exclude_primary() { assert_eq!(replica_ids.len(), 2); // Verify primary pool ID is not in the set of used pools - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.primary().unwrap().id(); assert!(!replica_ids.contains(&primary_id)); // Shutdown both primary and replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -364,7 +362,7 @@ async fn test_read_write_split_include_primary() { let replica_configs = [create_test_pool_config("localhost", 5432)]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -385,11 +383,10 @@ async fn test_read_write_split_include_primary() { assert_eq!(used_pool_ids.len(), 2); // Verify primary pool ID is in the set of used pools - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.primary().unwrap().id(); assert!(used_pool_ids.contains(&primary_id)); // Shutdown both primary and 
replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -401,7 +398,7 @@ async fn test_read_write_split_exclude_primary_no_primary() { create_test_pool_config("127.0.0.1", 5432), ]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &None, &replica_configs, LoadBalancingStrategy::Random, @@ -431,7 +428,7 @@ async fn test_read_write_split_include_primary_no_primary() { create_test_pool_config("127.0.0.1", 5432), ]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &None, &replica_configs, LoadBalancingStrategy::Random, @@ -462,7 +459,7 @@ async fn test_read_write_split_with_banned_primary() { let replica_configs = [create_test_pool_config("localhost", 5432)]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -471,7 +468,7 @@ async fn test_read_write_split_with_banned_primary() { replicas.launch(); // Ban the primary - let primary_ban = &replicas.primary.as_ref().unwrap().ban; + let primary_ban = &replicas.targets.last().unwrap().ban; primary_ban.ban(Error::ServerError, Duration::from_millis(1000)); let request = Request::default(); @@ -487,11 +484,10 @@ async fn test_read_write_split_with_banned_primary() { assert_eq!(used_pool_ids.len(), 1); // Verify primary pool ID is not in the set of used pools - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.targets.last().unwrap().pool.id(); assert!(!used_pool_ids.contains(&primary_id)); // Shutdown both primary and replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -503,7 +499,7 @@ async fn test_read_write_split_with_banned_replicas() { let replica_configs = [create_test_pool_config("localhost", 5432)]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -512,7 +508,7 @@ async fn test_read_write_split_with_banned_replicas() { replicas.launch(); // Ban the replica - let replica_ban = &replicas.replicas[0].ban; + let replica_ban = &replicas.targets[0].ban; replica_ban.ban(Error::ServerError, Duration::from_millis(1000)); let request = Request::default(); @@ -528,11 +524,10 @@ async fn test_read_write_split_with_banned_replicas() { assert_eq!(used_pool_ids.len(), 1); // Verify primary pool ID is in the set of used pools - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.targets.last().unwrap().pool.id(); assert!(used_pool_ids.contains(&primary_id)); // Shutdown both primary and replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -547,7 +542,7 @@ async fn test_read_write_split_exclude_primary_with_round_robin() { create_test_pool_config("127.0.0.1", 5432), ]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::RoundRobin, @@ -569,7 +564,7 @@ async fn test_read_write_split_exclude_primary_with_round_robin() { assert_eq!(unique_ids.len(), 2); // Verify primary is never used - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.targets.last().unwrap().pool.id(); assert!(!pool_sequence.contains(&primary_id)); // Verify round-robin pattern: each pool should be different from the previous one @@ -584,7 +579,6 @@ async fn test_read_write_split_exclude_primary_with_round_robin() { } // Shutdown both primary and replicas - 
replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -593,7 +587,7 @@ async fn test_monitor_shuts_down_on_notify() { let pool_config1 = create_test_pool_config("127.0.0.1", 5432); let pool_config2 = create_test_pool_config("localhost", 5432); - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &None, &[pool_config1, pool_config2], LoadBalancingStrategy::Random, @@ -601,7 +595,7 @@ async fn test_monitor_shuts_down_on_notify() { ); replicas - .replicas + .targets .iter() .for_each(|target| target.pool.launch()); let monitor_handle = Monitor::spawn(&replicas); @@ -627,11 +621,11 @@ async fn test_monitor_shuts_down_on_notify() { async fn test_monitor_bans_unhealthy_target() { let replicas = setup_test_replicas(); - replicas.replicas[0].health.toggle(false); + replicas.targets[0].health.toggle(false); sleep(Duration::from_millis(400)).await; - assert!(replicas.replicas[0].ban.banned()); + assert!(replicas.targets[0].ban.banned()); replicas.shutdown(); } @@ -640,13 +634,13 @@ async fn test_monitor_bans_unhealthy_target() { async fn test_monitor_clears_expired_bans() { let replicas = setup_test_replicas(); - replicas.replicas[0] + replicas.targets[0] .ban .ban(Error::ServerError, Duration::from_millis(50)); sleep(Duration::from_millis(400)).await; - assert!(!replicas.replicas[0].ban.banned()); + assert!(!replicas.targets[0].ban.banned()); replicas.shutdown(); } @@ -655,7 +649,7 @@ async fn test_monitor_clears_expired_bans() { async fn test_monitor_does_not_ban_single_target() { let pool_config = create_test_pool_config("127.0.0.1", 5432); - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &None, &[pool_config], LoadBalancingStrategy::Random, @@ -663,11 +657,11 @@ async fn test_monitor_does_not_ban_single_target() { ); replicas.launch(); - replicas.replicas[0].health.toggle(false); + replicas.targets[0].health.toggle(false); sleep(Duration::from_millis(400)).await; - assert!(!replicas.replicas[0].ban.banned()); + assert!(!replicas.targets[0].ban.banned()); replicas.shutdown(); } @@ -676,13 +670,13 @@ async fn test_monitor_does_not_ban_single_target() { async fn test_monitor_unbans_all_when_all_unhealthy() { let replicas = setup_test_replicas(); - replicas.replicas[0].health.toggle(false); - replicas.replicas[1].health.toggle(false); + replicas.targets[0].health.toggle(false); + replicas.targets[1].health.toggle(false); sleep(Duration::from_millis(400)).await; - assert!(!replicas.replicas[0].ban.banned()); - assert!(!replicas.replicas[1].ban.banned()); + assert!(!replicas.targets[0].ban.banned()); + assert!(!replicas.targets[1].ban.banned()); replicas.shutdown(); } @@ -725,7 +719,7 @@ async fn test_monitor_does_not_ban_with_zero_ban_timeout() { ..Default::default() }; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &None, &[pool_config1, pool_config2], LoadBalancingStrategy::Random, @@ -733,11 +727,11 @@ async fn test_monitor_does_not_ban_with_zero_ban_timeout() { ); replicas.launch(); - replicas.replicas[0].health.toggle(false); + replicas.targets[0].health.toggle(false); sleep(Duration::from_millis(400)).await; - assert!(!replicas.replicas[0].ban.banned()); + assert!(!replicas.targets[0].ban.banned()); replicas.shutdown(); } @@ -747,7 +741,7 @@ async fn test_monitor_health_state_race() { use tokio::spawn; let replicas = setup_test_replicas(); - let target = replicas.replicas[0].clone(); + let target = replicas.targets[0].clone(); let toggle_task = spawn(async move { for _ in 0..50 { @@ -762,8 +756,8 @@ async fn 
test_monitor_health_state_race() { toggle_task.await.unwrap(); - let banned = replicas.replicas[0].ban.banned(); - let healthy = replicas.replicas[0].health.healthy(); + let banned = replicas.targets[0].ban.banned(); + let healthy = replicas.targets[0].health.healthy(); assert!( !banned || !healthy, @@ -781,7 +775,7 @@ async fn test_include_primary_if_replica_banned_no_bans() { let replica_configs = [create_test_pool_config("localhost", 5432)]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -802,11 +796,10 @@ async fn test_include_primary_if_replica_banned_no_bans() { assert_eq!(used_pool_ids.len(), 1); // Verify primary pool ID is not in the set of used pools - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.primary().unwrap().id(); assert!(!used_pool_ids.contains(&primary_id)); // Shutdown both primary and replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } @@ -818,7 +811,7 @@ async fn test_include_primary_if_replica_banned_with_ban() { let replica_configs = [create_test_pool_config("localhost", 5432)]; - let replicas = Replicas::new( + let replicas = LoadBalancer::new( &Some(primary_pool), &replica_configs, LoadBalancingStrategy::Random, @@ -827,7 +820,7 @@ async fn test_include_primary_if_replica_banned_with_ban() { replicas.launch(); // Ban the replica - let replica_ban = &replicas.replicas[0].ban; + let replica_ban = &replicas.targets[0].ban; replica_ban.ban(Error::ServerError, Duration::from_millis(1000)); let request = Request::default(); @@ -843,10 +836,9 @@ async fn test_include_primary_if_replica_banned_with_ban() { assert_eq!(used_pool_ids.len(), 1); // Verify primary pool ID is in the set of used pools - let primary_id = replicas.primary.as_ref().unwrap().pool.id(); + let primary_id = replicas.primary().unwrap().id(); assert!(used_pool_ids.contains(&primary_id)); // Shutdown both primary and replicas - replicas.primary.as_ref().unwrap().pool.shutdown(); replicas.shutdown(); } diff --git a/pgdog/src/backend/pool/mod.rs b/pgdog/src/backend/pool/mod.rs index b4f54d2cd..2f2d79b37 100644 --- a/pgdog/src/backend/pool/mod.rs +++ b/pgdog/src/backend/pool/mod.rs @@ -11,13 +11,13 @@ pub mod error; pub mod guard; pub mod healthcheck; pub mod inner; +pub mod lb; pub mod lsn_monitor; pub mod mapping; pub mod mirror_stats; pub mod monitor; pub mod oids; pub mod pool_impl; -pub mod replicas; pub mod request; pub mod shard; pub mod state; @@ -32,12 +32,12 @@ pub use connection::Connection; pub use error::Error; pub use guard::Guard; pub use healthcheck::Healtcheck; +pub use lb::LoadBalancer; pub use lsn_monitor::LsnStats; pub use mirror_stats::MirrorStats; pub use monitor::Monitor; pub use oids::Oids; pub use pool_impl::Pool; -pub use replicas::Replicas; pub use request::Request; pub use shard::Shard; pub use state::State; diff --git a/pgdog/src/backend/pool/pool_impl.rs b/pgdog/src/backend/pool/pool_impl.rs index 5694b96ad..7b353926d 100644 --- a/pgdog/src/backend/pool/pool_impl.rs +++ b/pgdog/src/backend/pool/pool_impl.rs @@ -18,8 +18,8 @@ use crate::net::{Parameter, Parameters}; use super::inner::CheckInResult; use super::{ - lsn_monitor::LsnMonitor, replicas::TargetHealth, Address, Comms, Config, Error, Guard, - Healtcheck, Inner, Monitor, Oids, PoolConfig, Request, State, Waiting, + lb::TargetHealth, lsn_monitor::LsnMonitor, Address, Comms, Config, Error, Guard, Healtcheck, + Inner, Monitor, Oids, PoolConfig, Request, 
State, Waiting, }; static ID_COUNTER: Lazy> = Lazy::new(|| Arc::new(AtomicU64::new(0))); @@ -422,12 +422,6 @@ impl Pool { self.lock().config = config; } - /// Set LSN stats for testing. - #[cfg(test)] - pub(crate) fn set_lsn_stats(&self, stats: LsnStats) { - *self.inner().lsn_stats.write() = stats; - } - /// Fetch OIDs for user-defined data types. pub fn oids(&self) -> Option { self.lock().oids diff --git a/pgdog/src/backend/pool/replicas/detected_role.rs b/pgdog/src/backend/pool/replicas/detected_role.rs deleted file mode 100644 index 9ed6c7533..000000000 --- a/pgdog/src/backend/pool/replicas/detected_role.rs +++ /dev/null @@ -1,27 +0,0 @@ -use pgdog_config::Role; -use tokio::time::Instant; - -use super::ReadTarget; - -#[derive(Debug, Clone, Copy, Eq)] -pub struct DetectedRole { - pub role: Role, - pub as_of: Instant, - pub database_number: usize, -} - -impl DetectedRole { - pub fn from_read_target(target: &ReadTarget) -> Self { - Self { - role: target.role, - as_of: Instant::now(), - database_number: target.pool.addr().database_number, - } - } -} - -impl PartialEq for DetectedRole { - fn eq(&self, other: &Self) -> bool { - self.role == other.role && self.database_number == other.database_number - } -} diff --git a/pgdog/src/backend/pool/shard/mod.rs b/pgdog/src/backend/pool/shard/mod.rs index 20a9c6c4c..4d7ae7bd9 100644 --- a/pgdog/src/backend/pool/shard/mod.rs +++ b/pgdog/src/backend/pool/shard/mod.rs @@ -8,13 +8,13 @@ use tokio::{select, spawn, sync::Notify}; use tracing::debug; use crate::backend::databases::User; -use crate::backend::pool::replicas::ban::Ban; +use crate::backend::pool::lb::ban::Ban; use crate::backend::PubSubListener; use crate::config::{config, LoadBalancingStrategy, ReadWriteSplit, Role}; use crate::net::messages::BackendKeyData; use crate::net::NotificationResponse; -use super::{Error, Guard, Pool, PoolConfig, Replicas, Request}; +use super::{Error, Guard, LoadBalancer, Pool, PoolConfig, Request}; pub mod monitor; pub mod role_detector; @@ -37,8 +37,6 @@ pub(super) struct ShardConfig<'a> { pub(super) identifier: Arc, /// LSN check interval pub(super) lsn_check_interval: Duration, - /// Role detector is enabled. - pub(super) role_detector: bool, } /// Connection pools for a single database shard. @@ -67,8 +65,8 @@ impl Shard { /// Get connection to the primary database. pub async fn primary(&self, request: &Request) -> Result { - self.primary - .as_ref() + self.replicas + .primary() .ok_or(Error::NoPrimary)? .get(request) .await @@ -77,24 +75,12 @@ impl Shard { /// Get connection to one of the replica databases, using the configured /// load balancing algorithm. pub async fn replica(&self, request: &Request) -> Result { - if self.replicas.is_empty() { - self.primary - .as_ref() - .ok_or(Error::NoDatabases)? - .get(request) - .await - } else { - self.replicas.get(request).await - } + self.replicas.get(request).await } /// Get connection to primary if configured, otherwise replica. pub async fn primary_or_replica(&self, request: &Request) -> Result { - if self.primary.is_some() { - self.primary(request).await - } else { - self.replica(request).await - } + self.replica(request).await } /// Move connections from this shard to another shard, preserving them. @@ -102,28 +88,12 @@ impl Shard { /// This is done during configuration reloading, if no significant changes are made to /// the configuration. 
pub fn move_conns_to(&self, destination: &Shard) { - if let Some(ref primary) = self.primary { - if let Some(ref other) = destination.primary { - primary.move_conns_to(other); - } - } - self.replicas.move_conns_to(&destination.replicas); } /// Checks if the connection pools from this shard are compatible /// with the other shard. If yes, they can be moved without closing them. pub(crate) fn can_move_conns_to(&self, other: &Shard) -> bool { - if let Some(ref primary) = self.primary { - if let Some(ref other) = other.primary { - if !primary.can_move_conns_to(other) { - return false; - } - } else { - return false; - } - } - self.replicas.can_move_conns_to(&other.replicas) } @@ -150,9 +120,6 @@ impl Shard { /// Bring every pool online. pub fn launch(&self) { - if let Some(ref primary) = self.primary { - primary.launch(); - } self.replicas.launch(); ShardMonitor::run(self); if let Some(ref listener) = self.pub_sub { @@ -162,7 +129,7 @@ impl Shard { /// Returns true if the shard has a primary database. pub fn has_primary(&self) -> bool { - self.primary.is_some() + self.replicas.primary().is_some() } /// Returns true if the shard has any replica databases. @@ -180,9 +147,6 @@ impl Shard { /// If these connection pools aren't running the query sent by this client, this is a no-op. /// pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> { - if let Some(ref primary) = self.primary { - primary.cancel(id).await?; - } self.replicas.cancel(id).await?; Ok(()) @@ -199,15 +163,12 @@ impl Shard { /// Get all connection pools along with their roles (i.e., primary or replica). pub fn pools_with_roles(&self) -> Vec<(Role, Pool)> { let mut pools = vec![]; - if let Some(primary) = self.primary.clone() { - pools.push((Role::Primary, primary)); - } pools.extend( self.replicas - .pools() - .into_iter() - .map(|p| (Role::Replica, p.clone())), + .targets + .iter() + .map(|target| (target.role(), target.pool.clone())), ); pools @@ -221,9 +182,6 @@ impl Shard { /// Shutdown every pool and maintenance task in this shard. pub fn shutdown(&self) { self.comms.shutdown.notify_waiters(); - if let Some(pool) = self.primary.as_ref() { - pool.shutdown() - } if let Some(ref listener) = self.pub_sub { listener.shutdown(); } @@ -244,14 +202,9 @@ impl Shard { /// Re-detect primary/replica roles and re-build /// the shard routing logic. - pub fn redetect_roles(&self) -> Option { + pub fn redetect_roles(&self) -> bool { self.replicas.redetect_roles() } - - /// Get current roles. 
- pub fn current_roles(&self) -> DetectedRoles { - self.replicas.current_roles() - } } impl Deref for Shard { @@ -267,8 +220,7 @@ impl Deref for Shard { #[derive(Default, Debug, Clone)] pub struct ShardInner { number: usize, - primary: Option, - replicas: Replicas, + replicas: LoadBalancer, comms: Arc, pub_sub: Option, identifier: Arc, @@ -284,14 +236,12 @@ impl ShardInner { rw_split, identifier, lsn_check_interval, - role_detector, } = shard; let primary = primary.as_ref().map(Pool::new); - let replicas = Replicas::new(&primary, replicas, lb_strategy, rw_split); + let replicas = LoadBalancer::new(&primary, replicas, lb_strategy, rw_split); let comms = Arc::new(ShardComms { shutdown: Notify::new(), lsn_check_interval, - role_detector, }); let pub_sub = if config().pub_sub_enabled() { primary.as_ref().map(PubSubListener::new) @@ -301,7 +251,6 @@ impl ShardInner { Self { number, - primary, replicas, comms, pub_sub, @@ -343,12 +292,11 @@ mod test { database: "pgdog".into(), }), lsn_check_interval: Duration::MAX, - role_detector: false, }); shard.launch(); for _ in 0..25 { - let replica_id = shard.replicas.replicas[0].pool.id(); + let replica_id = shard.replicas.targets[0].pool.id(); let conn = shard.replica(&Request::default()).await.unwrap(); assert_eq!(conn.pool.id(), replica_id); @@ -382,7 +330,6 @@ mod test { database: "pgdog".into(), }), lsn_check_interval: Duration::MAX, - role_detector: false, }); shard.launch(); let mut ids = BTreeSet::new(); diff --git a/pgdog/src/backend/pool/shard/monitor.rs b/pgdog/src/backend/pool/shard/monitor.rs index 096de67bd..0f1219927 100644 --- a/pgdog/src/backend/pool/shard/monitor.rs +++ b/pgdog/src/backend/pool/shard/monitor.rs @@ -1,16 +1,13 @@ -use crate::backend::databases; - use super::*; use tokio::time::interval; -use tracing::{info, warn}; +use tracing::warn; /// Shard communication primitives. #[derive(Debug)] pub(super) struct ShardComms { pub(super) shutdown: Notify, pub(super) lsn_check_interval: Duration, - pub(super) role_detector: bool, } impl Default for ShardComms { @@ -18,7 +15,6 @@ impl Default for ShardComms { Self { shutdown: Notify::new(), lsn_check_interval: Duration::MAX, - role_detector: false, } } } @@ -55,15 +51,6 @@ impl ShardMonitor { ); let mut detector = RoleDetector::new(&self.shard); - let detector_enabled = self.shard.comms().role_detector; - - if detector_enabled { - info!( - "failover enabled for shard {} [{}]", - self.shard.number(), - self.shard.identifier() - ); - } loop { select! 
{ @@ -73,13 +60,12 @@ impl ShardMonitor { }, } - if detector_enabled && detector.changed() { + if detector.changed() { warn!( "database role changed in shard {} [{}]", self.shard.number(), self.shard.identifier() ); - databases::reload_from_existing(); break; } diff --git a/pgdog/src/backend/pool/shard/role_detector.rs b/pgdog/src/backend/pool/shard/role_detector.rs index bdf75e482..515dca0bc 100644 --- a/pgdog/src/backend/pool/shard/role_detector.rs +++ b/pgdog/src/backend/pool/shard/role_detector.rs @@ -1,28 +1,4 @@ -use std::{collections::HashMap, ops::Deref}; - use super::Shard; -use crate::backend::pool::replicas::DetectedRole; - -pub type DatabaseNumber = usize; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DetectedRoles { - roles: HashMap, -} - -impl Deref for DetectedRoles { - type Target = HashMap; - - fn deref(&self) -> &Self::Target { - &self.roles - } -} - -impl From> for DetectedRoles { - fn from(value: HashMap) -> Self { - Self { roles: value } - } -} #[derive(Debug, Clone, PartialEq, Eq)] pub enum RoleChangeEvent { @@ -32,7 +8,6 @@ pub enum RoleChangeEvent { } pub(super) struct RoleDetector { - current: DetectedRoles, // Database number <> Role shard: Shard, } @@ -40,22 +15,12 @@ impl RoleDetector { /// Create new role change detector. pub(super) fn new(shard: &Shard) -> Self { Self { - current: shard.current_roles(), shard: shard.clone(), } } /// Detect role change in the shard. pub(super) fn changed(&mut self) -> bool { - let latest = self.shard.redetect_roles(); - let mut changed = false; - if let Some(latest) = latest { - if self.current != latest { - changed = true; - self.current = latest; - } - } - - changed + self.shard.redetect_roles() } } diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index de0daff84..329be5803 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -469,7 +469,7 @@ impl Server { // Combine both to create a new, fresh session state // on this connection. - queries.extend(tracked.set_queries()); + queries.extend(tracked.set_queries(false)); // Set state on the connection only if // there are any params to change. diff --git a/pgdog/src/frontend/client/query_engine/connect.rs b/pgdog/src/frontend/client/query_engine/connect.rs index 818f3218f..61767345a 100644 --- a/pgdog/src/frontend/client/query_engine/connect.rs +++ b/pgdog/src/frontend/client/query_engine/connect.rs @@ -46,6 +46,7 @@ impl QueryEngine { } let query_timeout = context.timeouts.query_timeout(&self.stats.state); + // We may need to sync params with the server and that reads from the socket. timeout( query_timeout, @@ -53,6 +54,19 @@ impl QueryEngine { ) .await??; + // Sync transaction parameters. These will only + // be captured inside an explicit transaction + // so we don't have to track them. 
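+        //
+        // For example, a `SET statement_timeout TO 3000` issued between
+        // BEGIN and the first query is replayed here with `SET LOCAL`,
+        // scoping it to the server-side transaction.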
+ let set_queries = self.transaction_params.set_queries(true); + for query in set_queries { + timeout(query_timeout, self.backend.execute(query)).await??; + } + debug!( + "synced {} in-transaction parameters", + self.transaction_params.len() + ); + self.transaction_params.clear(); + true } diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 9ff2c4381..02ebbb6de 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -55,6 +55,7 @@ pub struct QueryEngine { notify_buffer: NotifyBuffer, pending_explain: Option, hooks: QueryEngineHooks, + transaction_params: Parameters, } impl QueryEngine { @@ -210,11 +211,16 @@ impl QueryEngine { .await? } Command::Unlisten(channel) => self.unlisten(context, &channel.clone()).await?, - Command::Set { name, value } => { + Command::Set { + name, + value, + in_transaction, + } => { if self.backend.connected() { self.execute(context, &route).await? } else { - self.set(context, name.clone(), value.clone()).await? + self.set(context, name.clone(), value.clone(), *in_transaction) + .await? } } Command::SetRoute(route) => { diff --git a/pgdog/src/frontend/client/query_engine/set.rs b/pgdog/src/frontend/client/query_engine/set.rs index d6ba8041e..35152be51 100644 --- a/pgdog/src/frontend/client/query_engine/set.rs +++ b/pgdog/src/frontend/client/query_engine/set.rs @@ -8,9 +8,14 @@ impl QueryEngine { context: &mut QueryEngineContext<'_>, name: String, value: ParameterValue, + in_transaction: bool, ) -> Result<(), Error> { - context.params.insert(name, value); - self.comms.update_params(context.params); + if in_transaction { + self.transaction_params.insert(name, value); + } else { + context.params.insert(name, value); + self.comms.update_params(context.params); + } let bytes_sent = context .stream diff --git a/pgdog/src/frontend/router/parser/command.rs b/pgdog/src/frontend/router/parser/command.rs index f924e8f29..c91bcb66e 100644 --- a/pgdog/src/frontend/router/parser/command.rs +++ b/pgdog/src/frontend/router/parser/command.rs @@ -26,6 +26,7 @@ pub enum Command { Set { name: String, value: ParameterValue, + in_transaction: bool, }, PreparedStatement(Prepare), Rewrite(Vec), diff --git a/pgdog/src/frontend/router/parser/query/set.rs b/pgdog/src/frontend/router/parser/query/set.rs index 0b469ef2a..74fe083e0 100644 --- a/pgdog/src/frontend/router/parser/query/set.rs +++ b/pgdog/src/frontend/router/parser/query/set.rs @@ -75,47 +75,53 @@ impl QueryParser { // TODO: Handle SET commands for updating client // params without touching the server. name => { - if !self.in_transaction { - let mut value = vec![]; + if let Shard::Direct(shard) = self.shard { + return Ok(Command::Query( + Route::write(shard).set_read(context.read_only), + )); + } - for node in &stmt.args { - if let Some(NodeEnum::AConst(AConst { val: Some(val), .. })) = &node.node { - match val { - Val::Sval(String { sval }) => { - value.push(sval.to_string()); - } + let mut value = vec![]; - Val::Ival(Integer { ival }) => { - value.push(ival.to_string()); - } + for node in &stmt.args { + if let Some(NodeEnum::AConst(AConst { val: Some(val), .. 
})) = &node.node { + match val { + Val::Sval(String { sval }) => { + value.push(sval.to_string()); + } - Val::Fval(Float { fval }) => { - value.push(fval.to_string()); - } + Val::Ival(Integer { ival }) => { + value.push(ival.to_string()); + } - Val::Boolval(Boolean { boolval }) => { - value.push(boolval.to_string()); - } + Val::Fval(Float { fval }) => { + value.push(fval.to_string()); + } - _ => (), + Val::Boolval(Boolean { boolval }) => { + value.push(boolval.to_string()); } + + _ => (), } } + } - match value.len() { - 0 => (), - 1 => { - return Ok(Command::Set { - name: name.to_string(), - value: ParameterValue::String(value.pop().unwrap()), - }) - } - _ => { - return Ok(Command::Set { - name: name.to_string(), - value: ParameterValue::Tuple(value), - }) - } + match value.len() { + 0 => (), + 1 => { + return Ok(Command::Set { + name: name.to_string(), + value: ParameterValue::String(value.pop().unwrap()), + in_transaction: self.in_transaction, + }) + } + _ => { + return Ok(Command::Set { + name: name.to_string(), + value: ParameterValue::Tuple(value), + in_transaction: self.in_transaction, + }) } } } diff --git a/pgdog/src/frontend/router/parser/query/test.rs b/pgdog/src/frontend/router/parser/query/test.rs index 3a8d02e8c..3a0098344 100644 --- a/pgdog/src/frontend/router/parser/query/test.rs +++ b/pgdog/src/frontend/router/parser/query/test.rs @@ -434,7 +434,7 @@ fn test_set() { command!("SET TIME ZONE 'UTC'"), ] { match command { - Command::Set { name, value } => { + Command::Set { name, value, .. } => { assert_eq!(name, "timezone"); assert_eq!(value, ParameterValue::from("UTC")); } @@ -445,7 +445,7 @@ fn test_set() { let (command, qp) = command!("SET statement_timeout TO 3000"); match command { - Command::Set { name, value } => { + Command::Set { name, value, .. } => { assert_eq!(name, "statement_timeout"); assert_eq!(value, ParameterValue::from("3000")); } @@ -457,7 +457,7 @@ fn test_set() { // The server will report an error on synchronization. let (command, qp) = command!("SET is_superuser TO true"); match command { - Command::Set { name, value } => { + Command::Set { name, value, .. } => { assert_eq!(name, "is_superuser"); assert_eq!(value, ParameterValue::from("true")); } @@ -468,14 +468,14 @@ fn test_set() { let (_, mut qp) = command!("BEGIN"); assert!(qp.write_override); let command = query_parser!(qp, Query::new(r#"SET statement_timeout TO 3000"#), true); - match command { - Command::Query(q) => assert!(q.is_write()), - _ => panic!("set should trigger binding"), - } + assert!( + matches!(command, Command::Set { .. }), + "set must be intercepted inside transactions" + ); let (command, _) = command!("SET search_path TO \"$user\", public, \"APPLES\""); match command { - Command::Set { name, value } => { + Command::Set { name, value, .. } => { assert_eq!(name, "search_path"); assert_eq!( value, @@ -502,10 +502,8 @@ fn test_set() { let route = qp.query(&mut context).unwrap(); match route { - Command::Query(route) => { - assert_eq!(route.is_read(), read_only); - } - cmd => panic!("not a query: {:?}", cmd), + Command::Set { .. 
} => {} + _ => panic!("set must be intercepted"), } } } @@ -546,8 +544,14 @@ fn test_transaction() { cluster.clone() ); match route { - Command::Query(q) => { - assert!(q.is_write()); + Command::Set { + name, + value, + in_transaction, + } => { + assert!(in_transaction); + assert_eq!(name, "application_name"); + assert_eq!(value.as_str().unwrap(), "test"); assert!(!cluster.read_only()); } @@ -962,5 +966,5 @@ fn test_set_comments() { Query::new("SET statement_timeout TO 1"), true ); - assert_eq!(command.route().shard(), &Shard::All); + assert!(matches!(command, Command::Set { .. })); } diff --git a/pgdog/src/net/parameter.rs b/pgdog/src/net/parameter.rs index 398e7f59c..94b177c0e 100644 --- a/pgdog/src/net/parameter.rs +++ b/pgdog/src/net/parameter.rs @@ -176,10 +176,17 @@ impl Parameters { self.hash == other.hash } - pub fn set_queries(&self) -> Vec { + pub fn set_queries(&self, local: bool) -> Vec { self.params .iter() - .map(|(name, value)| Query::new(format!(r#"SET "{}" TO {}"#, name, value))) + .map(|(name, value)| { + Query::new(format!( + r#"{} "{}" TO {}"#, + if local { "SET LOCAL" } else { "SET" }, + name, + value + )) + }) .collect() } From fcd5561b58443d9461fd77df35ccd76aacd5d5b0 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 19:26:51 -0800 Subject: [PATCH 02/10] set local in the right place --- pgdog/src/backend/server.rs | 2 +- pgdog/src/frontend/client/query_engine/connect.rs | 13 ------------- pgdog/src/frontend/client/query_engine/query.rs | 13 +++++++++++++ pgdog/src/net/parameter.rs | 11 ++--------- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index 329be5803..de0daff84 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -469,7 +469,7 @@ impl Server { // Combine both to create a new, fresh session state // on this connection. - queries.extend(tracked.set_queries(false)); + queries.extend(tracked.set_queries()); // Set state on the connection only if // there are any params to change. diff --git a/pgdog/src/frontend/client/query_engine/connect.rs b/pgdog/src/frontend/client/query_engine/connect.rs index 61767345a..d56b60caa 100644 --- a/pgdog/src/frontend/client/query_engine/connect.rs +++ b/pgdog/src/frontend/client/query_engine/connect.rs @@ -54,19 +54,6 @@ impl QueryEngine { ) .await??; - // Sync transaction parameters. These will only - // be captured inside an explicit transaction - // so we don't have to track them. - let set_queries = self.transaction_params.set_queries(true); - for query in set_queries { - timeout(query_timeout, self.backend.execute(query)).await??; - } - debug!( - "synced {} in-transaction parameters", - self.transaction_params.len() - ); - self.transaction_params.clear(); - true } diff --git a/pgdog/src/frontend/client/query_engine/query.rs b/pgdog/src/frontend/client/query_engine/query.rs index b60561c35..ec549bd10 100644 --- a/pgdog/src/frontend/client/query_engine/query.rs +++ b/pgdog/src/frontend/client/query_engine/query.rs @@ -38,6 +38,19 @@ impl QueryEngine { } self.backend.execute(begin_stmt.query()).await?; + + // Sync transaction parameters. These will only + // be captured inside an explicit transaction + // so we don't have to track them. 
+ let query_timeout = context.timeouts.query_timeout(&self.stats.state); + for query in self.transaction_params.set_queries() { + timeout(query_timeout, self.backend.execute(query)).await??; + } + debug!( + "synced {} in-transaction parameters", + self.transaction_params.len() + ); + self.transaction_params.clear(); } else if !self.connect(context, route).await? { return Ok(()); } diff --git a/pgdog/src/net/parameter.rs b/pgdog/src/net/parameter.rs index 94b177c0e..398e7f59c 100644 --- a/pgdog/src/net/parameter.rs +++ b/pgdog/src/net/parameter.rs @@ -176,17 +176,10 @@ impl Parameters { self.hash == other.hash } - pub fn set_queries(&self, local: bool) -> Vec { + pub fn set_queries(&self) -> Vec { self.params .iter() - .map(|(name, value)| { - Query::new(format!( - r#"{} "{}" TO {}"#, - if local { "SET LOCAL" } else { "SET" }, - name, - value - )) - }) + .map(|(name, value)| Query::new(format!(r#"SET "{}" TO {}"#, name, value))) .collect() } From e89fdb31caa6cbd93abe93b96b485e579e8022be Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 19:44:33 -0800 Subject: [PATCH 03/10] mirror fix --- pgdog/src/backend/pool/connection/mirror/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pgdog/src/backend/pool/connection/mirror/mod.rs b/pgdog/src/backend/pool/connection/mirror/mod.rs index 28363dea4..c9ce80ab6 100644 --- a/pgdog/src/backend/pool/connection/mirror/mod.rs +++ b/pgdog/src/backend/pool/connection/mirror/mod.rs @@ -95,6 +95,11 @@ impl Mirror { // Same query engine as the client, except with a potentially different database config. let mut query_engine = QueryEngine::new(¶ms, &comms(), false, &None)?; + // Mirror must read server responses to keep the connection synchronized, + // so disable test_mode which skips reading responses. + #[cfg(test)] + query_engine.set_test_mode(false); + // Mirror traffic handler. let mut mirror = Self::new(¶ms, &config); From 4feddb0cec2afa16ad02588484f224e07ee33500 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 19:59:05 -0800 Subject: [PATCH 04/10] Better names --- pgdog/src/backend/pool/lb/mod.rs | 17 +++++++------ pgdog/src/backend/pool/shard/mod.rs | 38 ++++++++++++++++------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/pgdog/src/backend/pool/lb/mod.rs b/pgdog/src/backend/pool/lb/mod.rs index 9742969b1..7044f0c69 100644 --- a/pgdog/src/backend/pool/lb/mod.rs +++ b/pgdog/src/backend/pool/lb/mod.rs @@ -1,4 +1,4 @@ -//! Replicas pool. +//! Load balanced connection pool. use std::{ sync::{ @@ -60,7 +60,7 @@ impl ReadTarget { } } -/// Replicas pools. +/// Load balancer. #[derive(Clone, Default, Debug)] pub struct LoadBalancer { /// Read/write targets. 
@@ -85,10 +85,14 @@ impl LoadBalancer { lb_strategy: LoadBalancingStrategy, rw_split: ReadWriteSplit, ) -> LoadBalancer { - let mut checkout_timeout = addrs - .iter() - .map(|c| c.config.checkout_timeout) - .sum::(); + let checkout_timeout = primary + .as_ref() + .map(|primary| primary.config().checkout_timeout) + .unwrap_or(Duration::ZERO) + + addrs + .iter() + .map(|c| c.config.checkout_timeout) + .sum::(); let mut targets: Vec<_> = addrs .iter() @@ -100,7 +104,6 @@ impl LoadBalancer { .map(|pool| ReadTarget::new(pool.clone(), Role::Primary)); if let Some(primary) = primary_target { - checkout_timeout += primary.pool.config().checkout_timeout; targets.push(primary); } diff --git a/pgdog/src/backend/pool/shard/mod.rs b/pgdog/src/backend/pool/shard/mod.rs index 4d7ae7bd9..aacbe5aec 100644 --- a/pgdog/src/backend/pool/shard/mod.rs +++ b/pgdog/src/backend/pool/shard/mod.rs @@ -65,7 +65,7 @@ impl Shard { /// Get connection to the primary database. pub async fn primary(&self, request: &Request) -> Result { - self.replicas + self.lb .primary() .ok_or(Error::NoPrimary)? .get(request) @@ -75,12 +75,16 @@ impl Shard { /// Get connection to one of the replica databases, using the configured /// load balancing algorithm. pub async fn replica(&self, request: &Request) -> Result { - self.replicas.get(request).await + self.lb.get(request).await } /// Get connection to primary if configured, otherwise replica. pub async fn primary_or_replica(&self, request: &Request) -> Result { - self.replica(request).await + if let Ok(primary) = self.primary(request).await { + Ok(primary) + } else { + self.replica(request).await + } } /// Move connections from this shard to another shard, preserving them. @@ -88,13 +92,13 @@ impl Shard { /// This is done during configuration reloading, if no significant changes are made to /// the configuration. pub fn move_conns_to(&self, destination: &Shard) { - self.replicas.move_conns_to(&destination.replicas); + self.lb.move_conns_to(&destination.lb); } /// Checks if the connection pools from this shard are compatible /// with the other shard. If yes, they can be moved without closing them. pub(crate) fn can_move_conns_to(&self, other: &Shard) -> bool { - self.replicas.can_move_conns_to(&other.replicas) + self.lb.can_move_conns_to(&other.lb) } /// Listen for notifications on channel. @@ -120,7 +124,7 @@ impl Shard { /// Bring every pool online. pub fn launch(&self) { - self.replicas.launch(); + self.lb.launch(); ShardMonitor::run(self); if let Some(ref listener) = self.pub_sub { listener.launch(); @@ -129,12 +133,12 @@ impl Shard { /// Returns true if the shard has a primary database. pub fn has_primary(&self) -> bool { - self.replicas.primary().is_some() + self.lb.primary().is_some() } /// Returns true if the shard has any replica databases. pub fn has_replicas(&self) -> bool { - !self.replicas.is_empty() + !self.lb.is_empty() } /// Request a query to be cancelled on any of the servers in the connection pools @@ -147,7 +151,7 @@ impl Shard { /// If these connection pools aren't running the query sent by this client, this is a no-op. /// pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> { - self.replicas.cancel(id).await?; + self.lb.cancel(id).await?; Ok(()) } @@ -165,7 +169,7 @@ impl Shard { let mut pools = vec![]; pools.extend( - self.replicas + self.lb .targets .iter() .map(|target| (target.role(), target.pool.clone())), @@ -176,7 +180,7 @@ impl Shard { /// Get all connection pools with bans and their role in the shard. 
pub fn pools_with_roles_and_bans(&self) -> Vec<(Role, Ban, Pool)> { - self.replicas.pools_with_roles_and_bans() + self.lb.pools_with_roles_and_bans() } /// Shutdown every pool and maintenance task in this shard. @@ -185,7 +189,7 @@ impl Shard { if let Some(ref listener) = self.pub_sub { listener.shutdown(); } - self.replicas.shutdown(); + self.lb.shutdown(); } fn comms(&self) -> &ShardComms { @@ -203,7 +207,7 @@ impl Shard { /// Re-detect primary/replica roles and re-build /// the shard routing logic. pub fn redetect_roles(&self) -> bool { - self.replicas.redetect_roles() + self.lb.redetect_roles() } } @@ -220,7 +224,7 @@ impl Deref for Shard { #[derive(Default, Debug, Clone)] pub struct ShardInner { number: usize, - replicas: LoadBalancer, + lb: LoadBalancer, comms: Arc, pub_sub: Option, identifier: Arc, @@ -238,7 +242,7 @@ impl ShardInner { lsn_check_interval, } = shard; let primary = primary.as_ref().map(Pool::new); - let replicas = LoadBalancer::new(&primary, replicas, lb_strategy, rw_split); + let lb = LoadBalancer::new(&primary, replicas, lb_strategy, rw_split); let comms = Arc::new(ShardComms { shutdown: Notify::new(), lsn_check_interval, @@ -251,7 +255,7 @@ impl ShardInner { Self { number, - replicas, + lb, comms, pub_sub, identifier, @@ -296,7 +300,7 @@ mod test { shard.launch(); for _ in 0..25 { - let replica_id = shard.replicas.targets[0].pool.id(); + let replica_id = shard.lb.targets[0].pool.id(); let conn = shard.replica(&Request::default()).await.unwrap(); assert_eq!(conn.pool.id(), replica_id); From 51ebbc497be2b87c3f76314e5804f37215eab7f7 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 21:53:42 -0800 Subject: [PATCH 05/10] Fix promoter --- pgdog/src/backend/pool/lb/mod.rs | 62 ++++----- pgdog/src/backend/pool/lb/test.rs | 61 +++++++++ pgdog/src/backend/pool/shard/mod.rs | 2 +- pgdog/src/backend/pool/shard/monitor.rs | 2 +- pgdog/src/backend/pool/shard/role_detector.rs | 129 +++++++++++++++++- 5 files changed, 213 insertions(+), 43 deletions(-) diff --git a/pgdog/src/backend/pool/lb/mod.rs b/pgdog/src/backend/pool/lb/mod.rs index 7044f0c69..43250af55 100644 --- a/pgdog/src/backend/pool/lb/mod.rs +++ b/pgdog/src/backend/pool/lb/mod.rs @@ -13,6 +13,7 @@ use tokio::{ sync::Notify, time::{timeout, Instant}, }; +use tracing::warn; use crate::config::{LoadBalancingStrategy, ReadWriteSplit, Role}; use crate::net::messages::BackendKeyData; @@ -58,6 +59,13 @@ impl ReadTarget { Role::Primary } } + + /// Set role. + pub(super) fn set_role(&self, role: Role) -> bool { + let value = role == Role::Replica; + let old = self.replica.swap(value, Ordering::Relaxed); + value != old + } } /// Load balancer. @@ -133,16 +141,10 @@ impl LoadBalancer { .find(|target| target.role() == Role::Primary) } - pub fn write_only(&self) -> bool { - self.targets - .iter() - .all(|target| target.role() == Role::Primary) - } - /// Detect database roles from pg_is_in_recovery() and /// return new primary (if any), and replicas. pub fn redetect_roles(&self) -> bool { - let mut changed = false; + let mut promoted = false; let mut targets = self .targets @@ -151,11 +153,6 @@ impl LoadBalancer { .map(|target| (target.pool.lsn_stats(), target)) .collect::>(); - // Only detect roles if the LSN detector is running. - if !targets.iter().all(|target| target.0.valid()) { - return false; - } - // Pick primary by latest data. The one with the most // up-to-date lsn number and pg_is_in_recovery() = false // is the new primary. 
@@ -168,26 +165,26 @@ impl LoadBalancer { let primary = targets .iter() - .find(|target| target.0.valid() && !target.0.replica); + .position(|target| !target.0.replica && target.0.valid()); if let Some(primary) = primary { - if primary.1.role() != Role::Primary { - changed = true; - primary.1.replica.store(false, Ordering::Relaxed); - } - } - let replicas = targets - .iter() - .filter(|target| target.0.replica) - .collect::>(); - - for replica in replicas { - if replica.1.role() != Role::Replica { - replica.1.replica.store(true, Ordering::Relaxed); + promoted = targets[primary].1.set_role(Role::Primary); + + if promoted { + warn!("new primary chosen: {}", targets[primary].1.pool.addr()); + + // Demote everyone else to replicas. + targets + .iter() + .enumerate() + .filter(|(i, _)| *i != primary) + .for_each(|(_, target)| { + target.1.set_role(Role::Replica); + }); } } - changed + promoted } /// Launch replica pools and start the monitor. @@ -224,14 +221,11 @@ impl LoadBalancer { .all(|(a, b)| a.pool.can_move_conns_to(&b.pool)) } - /// How many replicas we are connected to. - pub fn len(&self) -> usize { - self.targets.len() - } - /// There are no replicas. - pub fn is_empty(&self) -> bool { - self.len() == 0 + pub fn has_replicas(&self) -> bool { + self.targets + .iter() + .any(|target| target.role() == Role::Replica) } /// Cancel a query if one is running. diff --git a/pgdog/src/backend/pool/lb/test.rs b/pgdog/src/backend/pool/lb/test.rs index 9c72fb44b..cd3454854 100644 --- a/pgdog/src/backend/pool/lb/test.rs +++ b/pgdog/src/backend/pool/lb/test.rs @@ -842,3 +842,64 @@ async fn test_include_primary_if_replica_banned_with_ban() { // Shutdown both primary and replicas replicas.shutdown(); } + +#[tokio::test] +async fn test_has_replicas_with_replicas() { + let replicas = setup_test_replicas(); + + assert!(replicas.has_replicas()); + + replicas.shutdown(); +} + +#[tokio::test] +async fn test_has_replicas_with_primary_and_replicas() { + let primary_config = create_test_pool_config("127.0.0.1", 5432); + let primary_pool = Pool::new(&primary_config); + primary_pool.launch(); + + let replica_configs = [create_test_pool_config("localhost", 5432)]; + + let lb = LoadBalancer::new( + &Some(primary_pool), + &replica_configs, + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + lb.launch(); + + assert!(lb.has_replicas()); + + lb.shutdown(); +} + +#[tokio::test] +async fn test_has_replicas_primary_only() { + let primary_config = create_test_pool_config("127.0.0.1", 5432); + let primary_pool = Pool::new(&primary_config); + primary_pool.launch(); + + let lb = LoadBalancer::new( + &Some(primary_pool), + &[], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + lb.launch(); + + assert!(!lb.has_replicas()); + + lb.shutdown(); +} + +#[tokio::test] +async fn test_has_replicas_empty() { + let lb = LoadBalancer::new( + &None, + &[], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + assert!(!lb.has_replicas()); +} diff --git a/pgdog/src/backend/pool/shard/mod.rs b/pgdog/src/backend/pool/shard/mod.rs index aacbe5aec..0fbe8d4b4 100644 --- a/pgdog/src/backend/pool/shard/mod.rs +++ b/pgdog/src/backend/pool/shard/mod.rs @@ -138,7 +138,7 @@ impl Shard { /// Returns true if the shard has any replica databases. 
pub fn has_replicas(&self) -> bool { - !self.lb.is_empty() + self.lb.has_replicas() } /// Request a query to be cancelled on any of the servers in the connection pools diff --git a/pgdog/src/backend/pool/shard/monitor.rs b/pgdog/src/backend/pool/shard/monitor.rs index 0f1219927..6fb811137 100644 --- a/pgdog/src/backend/pool/shard/monitor.rs +++ b/pgdog/src/backend/pool/shard/monitor.rs @@ -56,6 +56,7 @@ impl ShardMonitor { select! { _ = maintenance.tick() => {}, _ = self.shard.comms().shutdown.notified() => { + println!("Shutting down"); break; }, } @@ -66,7 +67,6 @@ impl ShardMonitor { self.shard.number(), self.shard.identifier() ); - break; } let pool_with_stats = self diff --git a/pgdog/src/backend/pool/shard/role_detector.rs b/pgdog/src/backend/pool/shard/role_detector.rs index 515dca0bc..f2508bba6 100644 --- a/pgdog/src/backend/pool/shard/role_detector.rs +++ b/pgdog/src/backend/pool/shard/role_detector.rs @@ -1,12 +1,5 @@ use super::Shard; -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum RoleChangeEvent { - Failover, - Initial, - NoChange, -} - pub(super) struct RoleDetector { shard: Shard, } @@ -24,3 +17,125 @@ impl RoleDetector { self.shard.redetect_roles() } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + use std::time::Duration; + + use tokio::time::Instant; + + use crate::backend::databases::User; + use crate::backend::pool::lsn_monitor::LsnStats; + use crate::backend::pool::{Address, PoolConfig}; + use crate::backend::replication::publisher::Lsn; + use crate::config::{LoadBalancingStrategy, ReadWriteSplit}; + + use super::super::ShardConfig; + use super::*; + + fn create_test_pool_config(host: &str, port: u16) -> PoolConfig { + PoolConfig { + address: Address { + host: host.into(), + port, + user: "pgdog".into(), + password: "pgdog".into(), + database_name: "pgdog".into(), + ..Default::default() + }, + ..Default::default() + } + } + + fn create_test_shard(primary: &Option, replicas: &[PoolConfig]) -> Shard { + Shard::new(ShardConfig { + number: 0, + primary, + replicas, + lb_strategy: LoadBalancingStrategy::Random, + rw_split: ReadWriteSplit::ExcludePrimary, + identifier: Arc::new(User { + user: "pgdog".into(), + database: "pgdog".into(), + }), + lsn_check_interval: Duration::MAX, + }) + } + + fn set_lsn_stats(shard: &Shard, index: usize, replica: bool, lsn: i64) { + let pools = shard.pools(); + let stats = LsnStats { + replica, + lsn: Lsn::from_i64(lsn), + offset_bytes: lsn, + fetched: Instant::now(), + ..Default::default() + }; + *pools[index].inner().lsn_stats.write() = stats; + } + + #[test] + fn test_changed_returns_false_when_lsn_stats_invalid() { + let primary = Some(create_test_pool_config("127.0.0.1", 5432)); + let replicas = [create_test_pool_config("localhost", 5432)]; + let shard = create_test_shard(&primary, &replicas); + + let mut detector = RoleDetector::new(&shard); + + assert!(!detector.changed()); + } + + #[test] + fn test_changed_returns_false_when_roles_unchanged() { + let primary = Some(create_test_pool_config("127.0.0.1", 5432)); + let replicas = [create_test_pool_config("localhost", 5432)]; + let shard = create_test_shard(&primary, &replicas); + + set_lsn_stats(&shard, 0, true, 100); + set_lsn_stats(&shard, 1, false, 200); + + let mut detector = RoleDetector::new(&shard); + + assert!(!detector.changed()); + } + + #[test] + fn test_changed_returns_true_on_failover() { + let primary = Some(create_test_pool_config("127.0.0.1", 5432)); + let replicas = [create_test_pool_config("localhost", 5432)]; + let shard = create_test_shard(&primary, &replicas); + 
+ set_lsn_stats(&shard, 0, true, 100); + set_lsn_stats(&shard, 1, false, 200); + + let mut detector = RoleDetector::new(&shard); + + assert!(!detector.changed()); + + set_lsn_stats(&shard, 0, false, 300); + set_lsn_stats(&shard, 1, true, 200); + + assert!(detector.changed()); + } + + #[test] + fn test_changed_returns_false_after_roles_stabilize() { + let primary = Some(create_test_pool_config("127.0.0.1", 5432)); + let replicas = [create_test_pool_config("localhost", 5432)]; + let shard = create_test_shard(&primary, &replicas); + + set_lsn_stats(&shard, 0, true, 100); + set_lsn_stats(&shard, 1, false, 200); + + let mut detector = RoleDetector::new(&shard); + assert!(!detector.changed()); + + set_lsn_stats(&shard, 0, false, 300); + set_lsn_stats(&shard, 1, true, 200); + + assert!(detector.changed()); + + assert!(!detector.changed()); + } +} From f830903e9b91f8e3cf95f0171d7417eee358304d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 Dec 2025 21:55:57 -0800 Subject: [PATCH 06/10] clippy --- pgdog/src/backend/pool/shard/monitor.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/pgdog/src/backend/pool/shard/monitor.rs b/pgdog/src/backend/pool/shard/monitor.rs index 6fb811137..606089614 100644 --- a/pgdog/src/backend/pool/shard/monitor.rs +++ b/pgdog/src/backend/pool/shard/monitor.rs @@ -56,7 +56,6 @@ impl ShardMonitor { select! { _ = maintenance.tick() => {}, _ = self.shard.comms().shutdown.notified() => { - println!("Shutting down"); break; }, } From 3fb2b6bd46f7ad2c8cf89e062e619ff8eca6ff16 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 4 Dec 2025 09:37:02 -0800 Subject: [PATCH 07/10] Add set in transaction integration test --- integration/rust/tests/integration/mod.rs | 1 + .../tests/integration/set_in_transaction.rs | 103 ++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 integration/rust/tests/integration/set_in_transaction.rs diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 4c9229a39..5357d5987 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -12,6 +12,7 @@ pub mod prepared; pub mod reload; pub mod rewrite; pub mod savepoint; +pub mod set_in_transaction; pub mod set_sharding_key; pub mod shard_consistency; pub mod stddev; diff --git a/integration/rust/tests/integration/set_in_transaction.rs b/integration/rust/tests/integration/set_in_transaction.rs new file mode 100644 index 000000000..6dea4aa9f --- /dev/null +++ b/integration/rust/tests/integration/set_in_transaction.rs @@ -0,0 +1,103 @@ +use rust::setup::{admin_sqlx, connections_sqlx}; +use serial_test::serial; +use sqlx::Executor; + +#[tokio::test] +#[serial] +async fn test_set_in_transaction_preserved_after_commit() { + let admin = admin_sqlx().await; + admin + .execute("SET cross_shard_disabled TO true") + .await + .unwrap(); + + let pools = connections_sqlx().await; + let sharded = &pools[1]; + + let mut conn = sharded.acquire().await.unwrap(); + + // Start a transaction and change lock_timeout + conn.execute("BEGIN").await.unwrap(); + conn.execute("SET lock_timeout TO '45s'").await.unwrap(); + + // Verify lock_timeout is set inside transaction + let timeout_in_tx: String = sqlx::query_scalar("SHOW lock_timeout") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!( + timeout_in_tx, "45s", + "lock_timeout should be 45s inside transaction" + ); + + conn.execute("COMMIT").await.unwrap(); + + // Verify lock_timeout is preserved after commit + let timeout_after_commit: String = 
sqlx::query_scalar("SHOW lock_timeout") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!( + timeout_after_commit, "45s", + "lock_timeout should be preserved after commit" + ); + + admin + .execute("SET cross_shard_disabled TO false") + .await + .unwrap(); +} + +#[tokio::test] +#[serial] +async fn test_set_in_transaction_discarded_after_rollback() { + let admin = admin_sqlx().await; + admin + .execute("SET cross_shard_disabled TO true") + .await + .unwrap(); + + let pools = connections_sqlx().await; + let sharded = &pools[1]; + + let mut conn = sharded.acquire().await.unwrap(); + + // Get the default statement_timeout before any transaction + let default_timeout: String = sqlx::query_scalar("SHOW statement_timeout") + .fetch_one(&mut *conn) + .await + .unwrap(); + + // Start a transaction and change statement_timeout + conn.execute("BEGIN").await.unwrap(); + conn.execute("SET statement_timeout TO '30s'") + .await + .unwrap(); + + // Verify statement_timeout is set inside transaction + let timeout_in_tx: String = sqlx::query_scalar("SHOW statement_timeout") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!( + timeout_in_tx, "30s", + "statement_timeout should be 30s inside transaction" + ); + + conn.execute("ROLLBACK").await.unwrap(); + + // Verify statement_timeout is back to default after rollback + let timeout_after_rollback: String = sqlx::query_scalar("SHOW statement_timeout") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!( + timeout_after_rollback, default_timeout, + "statement_timeout should be reset to default after rollback" + ); + + admin + .execute("SET cross_shard_disabled TO false") + .await + .unwrap(); +} From 7caf9cb8ea9833092eb05a4109ba22afd037c184 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 4 Dec 2025 10:12:46 -0800 Subject: [PATCH 08/10] More tests --- .../tests/integration/set_in_transaction.rs | 56 ++++++++--- pgdog/src/backend/pool/lb/test.rs | 98 +++++++++++++++++++ pgdog/src/backend/server.rs | 2 +- .../src/frontend/client/query_engine/query.rs | 2 +- pgdog/src/frontend/client/test/mod.rs | 39 +++++++- pgdog/src/net/parameter.rs | 12 ++- 6 files changed, 186 insertions(+), 23 deletions(-) diff --git a/integration/rust/tests/integration/set_in_transaction.rs b/integration/rust/tests/integration/set_in_transaction.rs index 6dea4aa9f..a897144ba 100644 --- a/integration/rust/tests/integration/set_in_transaction.rs +++ b/integration/rust/tests/integration/set_in_transaction.rs @@ -4,7 +4,7 @@ use sqlx::Executor; #[tokio::test] #[serial] -async fn test_set_in_transaction_preserved_after_commit() { +async fn test_set_in_transaction_reset_after_commit() { let admin = admin_sqlx().await; admin .execute("SET cross_shard_disabled TO true") @@ -16,9 +16,24 @@ async fn test_set_in_transaction_preserved_after_commit() { let mut conn = sharded.acquire().await.unwrap(); + // Get the original lock_timeout before any transaction + let original_timeout: String = sqlx::query_scalar("SHOW lock_timeout") + .fetch_one(&mut *conn) + .await + .unwrap(); + + // Make sure we set it to something different + let new_timeout = if original_timeout == "45s" { + "30s" + } else { + "45s" + }; + // Start a transaction and change lock_timeout conn.execute("BEGIN").await.unwrap(); - conn.execute("SET lock_timeout TO '45s'").await.unwrap(); + conn.execute(format!("SET lock_timeout TO '{}'", new_timeout).as_str()) + .await + .unwrap(); // Verify lock_timeout is set inside transaction let timeout_in_tx: String = sqlx::query_scalar("SHOW lock_timeout") @@ -26,20 
+41,21 @@ async fn test_set_in_transaction_preserved_after_commit() { .await .unwrap(); assert_eq!( - timeout_in_tx, "45s", - "lock_timeout should be 45s inside transaction" + timeout_in_tx, new_timeout, + "lock_timeout should be {} inside transaction", + new_timeout ); conn.execute("COMMIT").await.unwrap(); - // Verify lock_timeout is preserved after commit + // Verify lock_timeout is reset to original after commit let timeout_after_commit: String = sqlx::query_scalar("SHOW lock_timeout") .fetch_one(&mut *conn) .await .unwrap(); assert_eq!( - timeout_after_commit, "45s", - "lock_timeout should be preserved after commit" + timeout_after_commit, original_timeout, + "lock_timeout should be reset to original after commit" ); admin @@ -50,7 +66,7 @@ async fn test_set_in_transaction_preserved_after_commit() { #[tokio::test] #[serial] -async fn test_set_in_transaction_discarded_after_rollback() { +async fn test_set_in_transaction_reset_after_rollback() { let admin = admin_sqlx().await; admin .execute("SET cross_shard_disabled TO true") @@ -62,15 +78,22 @@ async fn test_set_in_transaction_discarded_after_rollback() { let mut conn = sharded.acquire().await.unwrap(); - // Get the default statement_timeout before any transaction - let default_timeout: String = sqlx::query_scalar("SHOW statement_timeout") + // Get the original statement_timeout before any transaction + let original_timeout: String = sqlx::query_scalar("SHOW statement_timeout") .fetch_one(&mut *conn) .await .unwrap(); + // Make sure we set it to something different + let new_timeout = if original_timeout == "30s" { + "45s" + } else { + "30s" + }; + // Start a transaction and change statement_timeout conn.execute("BEGIN").await.unwrap(); - conn.execute("SET statement_timeout TO '30s'") + conn.execute(format!("SET statement_timeout TO '{}'", new_timeout).as_str()) .await .unwrap(); @@ -80,20 +103,21 @@ async fn test_set_in_transaction_discarded_after_rollback() { .await .unwrap(); assert_eq!( - timeout_in_tx, "30s", - "statement_timeout should be 30s inside transaction" + timeout_in_tx, new_timeout, + "statement_timeout should be {} inside transaction", + new_timeout ); conn.execute("ROLLBACK").await.unwrap(); - // Verify statement_timeout is back to default after rollback + // Verify statement_timeout is back to original after rollback let timeout_after_rollback: String = sqlx::query_scalar("SHOW statement_timeout") .fetch_one(&mut *conn) .await .unwrap(); assert_eq!( - timeout_after_rollback, default_timeout, - "statement_timeout should be reset to default after rollback" + timeout_after_rollback, original_timeout, + "statement_timeout should be reset to original after rollback" ); admin diff --git a/pgdog/src/backend/pool/lb/test.rs b/pgdog/src/backend/pool/lb/test.rs index cd3454854..327a7f0ed 100644 --- a/pgdog/src/backend/pool/lb/test.rs +++ b/pgdog/src/backend/pool/lb/test.rs @@ -903,3 +903,101 @@ async fn test_has_replicas_empty() { assert!(!lb.has_replicas()); } + +#[tokio::test] +async fn test_set_role() { + let replicas = setup_test_replicas(); + + // Initially all targets are replicas + assert_eq!(replicas.targets[0].role(), Role::Replica); + assert_eq!(replicas.targets[1].role(), Role::Replica); + + // Setting replica to replica returns false (no change) + let changed = replicas.targets[0].set_role(Role::Replica); + assert!(!changed); + assert_eq!(replicas.targets[0].role(), Role::Replica); + + // Setting replica to primary returns true (changed) + let changed = replicas.targets[0].set_role(Role::Primary); + 
assert!(changed); + assert_eq!(replicas.targets[0].role(), Role::Primary); + + // Setting primary to primary returns false (no change) + let changed = replicas.targets[0].set_role(Role::Primary); + assert!(!changed); + assert_eq!(replicas.targets[0].role(), Role::Primary); + + // Setting primary to replica returns true (changed) + let changed = replicas.targets[0].set_role(Role::Replica); + assert!(changed); + assert_eq!(replicas.targets[0].role(), Role::Replica); + + replicas.shutdown(); +} + +#[tokio::test] +async fn test_can_move_conns_to_same_config() { + let pool_config1 = create_test_pool_config("127.0.0.1", 5432); + let pool_config2 = create_test_pool_config("localhost", 5432); + + let lb1 = LoadBalancer::new( + &None, + &[pool_config1.clone(), pool_config2.clone()], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + let lb2 = LoadBalancer::new( + &None, + &[pool_config1, pool_config2], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + assert!(lb1.can_move_conns_to(&lb2)); +} + +#[tokio::test] +async fn test_can_move_conns_to_different_count() { + let pool_config1 = create_test_pool_config("127.0.0.1", 5432); + let pool_config2 = create_test_pool_config("localhost", 5432); + + let lb1 = LoadBalancer::new( + &None, + &[pool_config1.clone(), pool_config2], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + let lb2 = LoadBalancer::new( + &None, + &[pool_config1], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + assert!(!lb1.can_move_conns_to(&lb2)); +} + +#[tokio::test] +async fn test_can_move_conns_to_different_addresses() { + let pool_config1 = create_test_pool_config("127.0.0.1", 5432); + let pool_config2 = create_test_pool_config("localhost", 5432); + let pool_config3 = create_test_pool_config("127.0.0.1", 5433); + + let lb1 = LoadBalancer::new( + &None, + &[pool_config1, pool_config2], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + let lb2 = LoadBalancer::new( + &None, + &[pool_config3.clone(), pool_config3], + LoadBalancingStrategy::Random, + ReadWriteSplit::IncludePrimary, + ); + + assert!(!lb1.can_move_conns_to(&lb2)); +} diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index de0daff84..329be5803 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -469,7 +469,7 @@ impl Server { // Combine both to create a new, fresh session state // on this connection. - queries.extend(tracked.set_queries()); + queries.extend(tracked.set_queries(false)); // Set state on the connection only if // there are any params to change. diff --git a/pgdog/src/frontend/client/query_engine/query.rs b/pgdog/src/frontend/client/query_engine/query.rs index ec549bd10..42af6ae5e 100644 --- a/pgdog/src/frontend/client/query_engine/query.rs +++ b/pgdog/src/frontend/client/query_engine/query.rs @@ -43,7 +43,7 @@ impl QueryEngine { // be captured inside an explicit transaction // so we don't have to track them. 
let query_timeout = context.timeouts.query_timeout(&self.stats.state);
-        for query in self.transaction_params.set_queries() {
+        for query in self.transaction_params.set_queries(true) {
             timeout(query_timeout, self.backend.execute(query)).await??;
         }
         debug!(
diff --git a/pgdog/src/frontend/client/test/mod.rs b/pgdog/src/frontend/client/test/mod.rs
index a73984583..32527fd9c 100644
--- a/pgdog/src/frontend/client/test/mod.rs
+++ b/pgdog/src/frontend/client/test/mod.rs
@@ -2,7 +2,7 @@ use std::time::{Duration, Instant};
 use pgdog_config::PoolerMode;
 use tokio::{
-    io::{AsyncReadExt, AsyncWriteExt},
+    io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
     net::{TcpListener, TcpStream},
     time::timeout,
 };
@@ -21,8 +21,8 @@ use crate::{
     },
     net::{
         bind::Parameter, Bind, Close, CommandComplete, DataRow, Describe, ErrorResponse, Execute,
-        Field, Flush, Format, FromBytes, Parse, Protocol, Query, ReadyForQuery, RowDescription,
-        Sync, Terminate, ToBytes,
+        Field, Flush, Format, FromBytes, Message, Parse, Protocol, Query, ReadyForQuery,
+        RowDescription, Sync, Terminate, ToBytes,
     },
     state::State,
 };
@@ -74,6 +74,39 @@ macro_rules! new_client {
     }};
 }
+pub fn buffer(messages: &[impl ToBytes]) -> BytesMut {
+    let mut buf = BytesMut::new();
+    for message in messages {
+        buf.put(message.to_bytes().unwrap());
+    }
+    buf
+}
+
+/// Read a series of messages from the stream and make sure
+/// they arrive in the right order.
+pub async fn read(stream: &mut (impl AsyncRead + Unpin), codes: &[char]) -> Vec<Message> {
+    let mut result = vec![];
+
+    for code in codes {
+        let c = stream.read_u8().await.unwrap();
+
+        assert_eq!(c as char, *code, "unexpected message received");
+
+        let len = stream.read_i32().await.unwrap();
+        let mut data = vec![0u8; len as usize - 4];
+        stream.read_exact(&mut data).await.unwrap();
+
+        let mut message = BytesMut::new();
+        message.put_u8(c);
+        message.put_i32(len);
+        message.put_slice(&data);
+
+        result.push(Message::new(message.freeze()))
+    }
+
+    result
+}
+
 macro_rules! buffer {
     ( $( $msg:block ),* ) => {{
         let mut buf = BytesMut::new();
diff --git a/pgdog/src/net/parameter.rs b/pgdog/src/net/parameter.rs
index 398e7f59c..8de5afe29 100644
--- a/pgdog/src/net/parameter.rs
+++ b/pgdog/src/net/parameter.rs
@@ -176,10 +176,18 @@ impl Parameters {
         self.hash == other.hash
     }
-    pub fn set_queries(&self) -> Vec<Query> {
+    /// Generate SET queries to change server state.
+    ///
+    /// # Arguments
+    ///
+    /// * `local`: Generate `SET LOCAL` queries, whose settings are
+    ///   automatically reset when the transaction ends.
+    ///
+    pub fn set_queries(&self, local: bool) -> Vec<Query> {
+        let set = if local { "SET LOCAL" } else { "SET" };
         self.params
             .iter()
-            .map(|(name, value)| Query::new(format!(r#"SET "{}" TO {}"#, name, value)))
+            .map(|(name, value)| Query::new(format!(r#"{} "{}" TO {}"#, set, name, value)))
             .collect()
     }

From f161eb30ca9643b37c78eee5a6df0d0da6e265e3 Mon Sep 17 00:00:00 2001
From: Lev Kokotov
Date: Thu, 4 Dec 2025 10:41:08 -0800
Subject: [PATCH 09/10] more tests

---
 pgdog-config/src/core.rs | 32 +++++++--
 pgdog/src/backend/databases.rs | 2 +-
 pgdog/src/backend/pool/config.rs | 54 ++++++++++++++-
 pgdog/src/backend/pool/lb/mod.rs | 14 ++--
 pgdog/src/backend/pool/shard/monitor.rs | 10 ++-
 pgdog/src/backend/pool/shard/role_detector.rs | 65 +++++++++++++++----
 6 files changed, 150 insertions(+), 27 deletions(-)

diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs
index a43c34437..b3ad87e1a 100644
--- a/pgdog-config/src/core.rs
+++ b/pgdog-config/src/core.rs
@@ -7,6 +7,7 @@ use tracing::{info, warn};
 use crate::sharding::ShardedSchema;
 use crate::{
     EnumeratedDatabase, Memory, OmnishardedTable, PassthoughAuth, PreparedStatements, RewriteMode,
+    Role,
 };
 use super::database::Database;
@@ -312,18 +313,39 @@ impl Config {
             }
         }
-        // Check pooler mode.
-        let mut pooler_mode = HashMap::<String, Option<PoolerMode>>::new();
+        struct Check {
+            pooler_mode: Option<PoolerMode>,
+            role: Role,
+            role_warned: bool,
+        }
+
+        // Check identical configs.
+        let mut checks = HashMap::<String, Check>::new();
         for database in &self.databases {
-            if let Some(mode) = pooler_mode.get(&database.name) {
-                if mode != &database.pooler_mode {
+            if let Some(existing) = checks.get_mut(&database.name) {
+                if existing.pooler_mode != database.pooler_mode {
                     warn!(
                         "database \"{}\" (shard={}, role={}) has a different \"pooler_mode\" setting, ignoring",
                         database.name, database.shard, database.role,
                     );
                 }
+                let auto = existing.role == Role::Auto || database.role == Role::Auto;
+                if auto && existing.role != database.role && !existing.role_warned {
+                    warn!(
+                        r#"database "{}" has a mix of auto and specific roles, automatic role detection will be disabled"#,
+                        database.name
+                    );
+                    existing.role_warned = true;
+                }
             } else {
-                pooler_mode.insert(database.name.clone(), database.pooler_mode.clone());
+                checks.insert(
+                    database.name.clone(),
+                    Check {
+                        pooler_mode: database.pooler_mode.clone(),
+                        role: database.role,
+                        role_warned: false,
+                    },
+                );
             }
         }
diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs
index 8ff7c2d83..d0f724aba 100644
--- a/pgdog/src/backend/databases.rs
+++ b/pgdog/src/backend/databases.rs
@@ -395,7 +395,7 @@ pub(crate) fn new_pool(
             });
         let replicas = user_databases
             .iter()
-            .filter(|d| matches!(d.role, Role::Replica | Role::Auto))
+            .filter(|d| matches!(d.role, Role::Replica | Role::Auto)) // Auto role is assumed read-only until proven otherwise.
             .map(|replica| PoolConfig {
                 address: Address::new(replica, user, replica.number),
                 config: Config::new(general, replica, user, has_single_replica),
diff --git a/pgdog/src/backend/pool/config.rs b/pgdog/src/backend/pool/config.rs
index cd626d77a..7287b4ad9 100644
--- a/pgdog/src/backend/pool/config.rs
+++ b/pgdog/src/backend/pool/config.rs
@@ -2,7 +2,7 @@
 use std::time::Duration;
-use pgdog_config::pooling::ConnectionRecovery;
+use pgdog_config::{pooling::ConnectionRecovery, Role};
 use serde::{Deserialize, Serialize};
 use crate::config::{Database, General, PoolerMode, User};
@@ -68,6 +68,8 @@ pub struct Config {
     pub lsn_check_timeout: Duration,
     /// LSN check delay.
pub lsn_check_delay: Duration, + /// Automatic role detection enabled. + pub role_detection: bool, } impl Config { @@ -199,6 +201,7 @@ impl Config { lsn_check_interval: Duration::from_millis(general.lsn_check_interval), lsn_check_timeout: Duration::from_millis(general.lsn_check_timeout), lsn_check_delay: Duration::from_millis(general.lsn_check_delay), + role_detection: database.role == Role::Auto, ..Default::default() } } @@ -236,6 +239,55 @@ impl Default for Config { lsn_check_interval: Duration::from_millis(5_000), lsn_check_timeout: Duration::from_millis(5_000), lsn_check_delay: Duration::from_millis(5_000), + role_detection: false, } } } + +#[cfg(test)] +mod test { + use super::*; + + fn create_database(role: Role) -> Database { + Database { + name: "test".into(), + role, + host: "localhost".into(), + port: 5432, + ..Default::default() + } + } + + #[test] + fn test_role_auto_enables_role_detection() { + let general = General::default(); + let user = User::default(); + let database = create_database(Role::Auto); + + let config = Config::new(&general, &database, &user, false); + + assert!(config.role_detection); + } + + #[test] + fn test_role_primary_disables_role_detection() { + let general = General::default(); + let user = User::default(); + let database = create_database(Role::Primary); + + let config = Config::new(&general, &database, &user, false); + + assert!(!config.role_detection); + } + + #[test] + fn test_role_replica_disables_role_detection() { + let general = General::default(); + let user = User::default(); + let database = create_database(Role::Replica); + + let config = Config::new(&general, &database, &user, false); + + assert!(!config.role_detection); + } +} diff --git a/pgdog/src/backend/pool/lb/mod.rs b/pgdog/src/backend/pool/lb/mod.rs index 43250af55..072b00090 100644 --- a/pgdog/src/backend/pool/lb/mod.rs +++ b/pgdog/src/backend/pool/lb/mod.rs @@ -33,14 +33,14 @@ mod test; /// Read query load balancer target. #[derive(Clone, Debug)] -pub struct ReadTarget { +pub struct Target { pub pool: Pool, pub ban: Ban, replica: Arc, pub health: TargetHealth, } -impl ReadTarget { +impl Target { pub(super) fn new(pool: Pool, role: Role) -> Self { let ban = Ban::new(&pool); Self { @@ -72,7 +72,7 @@ impl ReadTarget { #[derive(Clone, Default, Debug)] pub struct LoadBalancer { /// Read/write targets. - pub(super) targets: Vec, + pub(super) targets: Vec, /// Checkout timeout. pub(super) checkout_timeout: Duration, /// Round robin atomic counter. @@ -104,12 +104,12 @@ impl LoadBalancer { let mut targets: Vec<_> = addrs .iter() - .map(|config| ReadTarget::new(Pool::new(config), Role::Replica)) + .map(|config| Target::new(Pool::new(config), Role::Replica)) .collect(); let primary_target = primary .as_ref() - .map(|pool| ReadTarget::new(pool.clone(), Role::Primary)); + .map(|pool| Target::new(pool.clone(), Role::Primary)); if let Some(primary) = primary_target { targets.push(primary); @@ -134,7 +134,7 @@ impl LoadBalancer { /// /// Unlike [`primary()`], this returns the full target struct which allows /// access to ban and health state for monitoring and testing purposes. - pub fn primary_target(&self) -> Option<&ReadTarget> { + pub fn primary_target(&self) -> Option<&Target> { self.targets .iter() .rev() // If there is a primary, it's likely to be last. 
@@ -257,7 +257,7 @@ impl LoadBalancer { use LoadBalancingStrategy::*; use ReadWriteSplit::*; - let mut candidates: Vec<&ReadTarget> = self.targets.iter().collect(); + let mut candidates: Vec<&Target> = self.targets.iter().collect(); let primary_reads = match self.rw_split { IncludePrimary => true, diff --git a/pgdog/src/backend/pool/shard/monitor.rs b/pgdog/src/backend/pool/shard/monitor.rs index 606089614..99cbab83a 100644 --- a/pgdog/src/backend/pool/shard/monitor.rs +++ b/pgdog/src/backend/pool/shard/monitor.rs @@ -1,7 +1,7 @@ use super::*; use tokio::time::interval; -use tracing::warn; +use tracing::{info, warn}; /// Shard communication primitives. #[derive(Debug)] @@ -52,6 +52,14 @@ impl ShardMonitor { let mut detector = RoleDetector::new(&self.shard); + if detector.enabled() { + info!( + "automatic database role detection is enabled for shard {} [{}]", + self.shard.number(), + self.shard.identifier() + ); + } + loop { select! { _ = maintenance.tick() => {}, diff --git a/pgdog/src/backend/pool/shard/role_detector.rs b/pgdog/src/backend/pool/shard/role_detector.rs index f2508bba6..849fd57d6 100644 --- a/pgdog/src/backend/pool/shard/role_detector.rs +++ b/pgdog/src/backend/pool/shard/role_detector.rs @@ -1,6 +1,7 @@ use super::Shard; pub(super) struct RoleDetector { + enabled: bool, shard: Shard, } @@ -8,13 +9,26 @@ impl RoleDetector { /// Create new role change detector. pub(super) fn new(shard: &Shard) -> Self { Self { + enabled: shard + .pools() + .iter() + .all(|pool| pool.config().role_detection), shard: shard.clone(), } } /// Detect role change in the shard. pub(super) fn changed(&mut self) -> bool { - self.shard.redetect_roles() + if self.enabled() { + self.shard.redetect_roles() + } else { + false + } + } + + /// Role detector is enabled. 
+ pub(super) fn enabled(&self) -> bool { + self.enabled } } @@ -27,14 +41,14 @@ mod test { use crate::backend::databases::User; use crate::backend::pool::lsn_monitor::LsnStats; - use crate::backend::pool::{Address, PoolConfig}; + use crate::backend::pool::{Address, Config, PoolConfig}; use crate::backend::replication::publisher::Lsn; use crate::config::{LoadBalancingStrategy, ReadWriteSplit}; use super::super::ShardConfig; use super::*; - fn create_test_pool_config(host: &str, port: u16) -> PoolConfig { + fn create_test_pool_config(host: &str, port: u16, role_detection: bool) -> PoolConfig { PoolConfig { address: Address { host: host.into(), @@ -44,7 +58,10 @@ mod test { database_name: "pgdog".into(), ..Default::default() }, - ..Default::default() + config: Config { + role_detection, + ..Default::default() + }, } } @@ -77,19 +94,20 @@ mod test { #[test] fn test_changed_returns_false_when_lsn_stats_invalid() { - let primary = Some(create_test_pool_config("127.0.0.1", 5432)); - let replicas = [create_test_pool_config("localhost", 5432)]; + let primary = Some(create_test_pool_config("127.0.0.1", 5432, true)); + let replicas = [create_test_pool_config("localhost", 5432, true)]; let shard = create_test_shard(&primary, &replicas); let mut detector = RoleDetector::new(&shard); + assert!(detector.enabled()); assert!(!detector.changed()); } #[test] fn test_changed_returns_false_when_roles_unchanged() { - let primary = Some(create_test_pool_config("127.0.0.1", 5432)); - let replicas = [create_test_pool_config("localhost", 5432)]; + let primary = Some(create_test_pool_config("127.0.0.1", 5432, true)); + let replicas = [create_test_pool_config("localhost", 5432, true)]; let shard = create_test_shard(&primary, &replicas); set_lsn_stats(&shard, 0, true, 100); @@ -97,13 +115,14 @@ mod test { let mut detector = RoleDetector::new(&shard); + assert!(detector.enabled()); assert!(!detector.changed()); } #[test] fn test_changed_returns_true_on_failover() { - let primary = Some(create_test_pool_config("127.0.0.1", 5432)); - let replicas = [create_test_pool_config("localhost", 5432)]; + let primary = Some(create_test_pool_config("127.0.0.1", 5432, true)); + let replicas = [create_test_pool_config("localhost", 5432, true)]; let shard = create_test_shard(&primary, &replicas); set_lsn_stats(&shard, 0, true, 100); @@ -111,6 +130,7 @@ mod test { let mut detector = RoleDetector::new(&shard); + assert!(detector.enabled()); assert!(!detector.changed()); set_lsn_stats(&shard, 0, false, 300); @@ -121,14 +141,15 @@ mod test { #[test] fn test_changed_returns_false_after_roles_stabilize() { - let primary = Some(create_test_pool_config("127.0.0.1", 5432)); - let replicas = [create_test_pool_config("localhost", 5432)]; + let primary = Some(create_test_pool_config("127.0.0.1", 5432, true)); + let replicas = [create_test_pool_config("localhost", 5432, true)]; let shard = create_test_shard(&primary, &replicas); set_lsn_stats(&shard, 0, true, 100); set_lsn_stats(&shard, 1, false, 200); let mut detector = RoleDetector::new(&shard); + assert!(detector.enabled()); assert!(!detector.changed()); set_lsn_stats(&shard, 0, false, 300); @@ -138,4 +159,24 @@ mod test { assert!(!detector.changed()); } + + #[test] + fn test_disabled_when_not_all_roles_auto() { + let primary = Some(create_test_pool_config("127.0.0.1", 5432, false)); + let replicas = [create_test_pool_config("localhost", 5432, true)]; + let shard = create_test_shard(&primary, &replicas); + + set_lsn_stats(&shard, 0, true, 100); + set_lsn_stats(&shard, 1, false, 200); + + let 
mut detector = RoleDetector::new(&shard); + + assert!(!detector.enabled()); + assert!(!detector.changed()); + + set_lsn_stats(&shard, 0, false, 300); + set_lsn_stats(&shard, 1, true, 200); + + assert!(!detector.changed()); + } } From 4d26f550d1d4fc3698940176bc4df1fd359edb35 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 4 Dec 2025 10:49:53 -0800 Subject: [PATCH 10/10] Paranoid --- .../tests/integration/set_in_transaction.rs | 62 +++++++++++++++++++ pgdog/src/backend/pool/connection/binding.rs | 12 ++++ .../src/frontend/client/query_engine/query.rs | 1 + 3 files changed, 75 insertions(+) diff --git a/integration/rust/tests/integration/set_in_transaction.rs b/integration/rust/tests/integration/set_in_transaction.rs index a897144ba..1d4342f8c 100644 --- a/integration/rust/tests/integration/set_in_transaction.rs +++ b/integration/rust/tests/integration/set_in_transaction.rs @@ -125,3 +125,65 @@ async fn test_set_in_transaction_reset_after_rollback() { .await .unwrap(); } + +#[tokio::test] +#[serial] +async fn test_set_local_in_transaction_reset_after_commit() { + let admin = admin_sqlx().await; + admin + .execute("SET cross_shard_disabled TO true") + .await + .unwrap(); + + let pools = connections_sqlx().await; + let sharded = &pools[1]; + + let mut conn = sharded.acquire().await.unwrap(); + + // Get the original work_mem before any transaction + let original_work_mem: String = sqlx::query_scalar("SHOW work_mem") + .fetch_one(&mut *conn) + .await + .unwrap(); + + // Make sure we set it to something different + let new_work_mem = if original_work_mem == "8MB" { + "16MB" + } else { + "8MB" + }; + + // Start a transaction and change work_mem using SET LOCAL + conn.execute("BEGIN").await.unwrap(); + conn.execute(format!("SET LOCAL work_mem TO '{}'", new_work_mem).as_str()) + .await + .unwrap(); + + // Verify work_mem is set inside transaction + let work_mem_in_tx: String = sqlx::query_scalar("SHOW work_mem") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!( + work_mem_in_tx, new_work_mem, + "work_mem should be {} inside transaction", + new_work_mem + ); + + conn.execute("COMMIT").await.unwrap(); + + // Verify work_mem is reset to original after commit (SET LOCAL is transaction-scoped) + let work_mem_after_commit: String = sqlx::query_scalar("SHOW work_mem") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!( + work_mem_after_commit, original_work_mem, + "work_mem should be reset to original after commit (SET LOCAL is transaction-scoped)" + ); + + admin + .execute("SET cross_shard_disabled TO false") + .await + .unwrap(); +} diff --git a/pgdog/src/backend/pool/connection/binding.rs b/pgdog/src/backend/pool/connection/binding.rs index 8392425c4..a2cc1c4fd 100644 --- a/pgdog/src/backend/pool/connection/binding.rs +++ b/pgdog/src/backend/pool/connection/binding.rs @@ -372,6 +372,18 @@ impl Binding { } } + /// Reset changed params on all servers, disabling parameter tracking + /// for this request. 
+ pub fn reset_changed_params(&mut self) { + match self { + Binding::Direct(Some(ref mut server)) => server.reset_changed_params(), + Binding::MultiShard(ref mut servers, _) => servers + .iter_mut() + .for_each(|server| server.reset_changed_params()), + _ => (), + } + } + pub(super) fn dirty(&mut self) { match self { Binding::Direct(Some(ref mut server)) => server.mark_dirty(true), diff --git a/pgdog/src/frontend/client/query_engine/query.rs b/pgdog/src/frontend/client/query_engine/query.rs index 42af6ae5e..560a1c778 100644 --- a/pgdog/src/frontend/client/query_engine/query.rs +++ b/pgdog/src/frontend/client/query_engine/query.rs @@ -51,6 +51,7 @@ impl QueryEngine { self.transaction_params.len() ); self.transaction_params.clear(); + self.backend.reset_changed_params(); } else if !self.connect(context, route).await? { return Ok(()); }
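Taken together, the series routes `SET` statements issued inside an explicit transaction into `transaction_params` and replays them as `SET LOCAL` right after `BEGIN`, so Postgres itself reverts them on COMMIT or ROLLBACK. The sketch below illustrates that query-generation strategy in isolation: `Params` is a simplified, hypothetical stand-in for pgdog's `Parameters` (a plain list of pre-quoted name/value pairs), and only the formatting logic mirrors `set_queries(local: bool)` from the patches above.

```rust
/// Hypothetical stand-in for pgdog's `Parameters`: pre-quoted (name, value) pairs.
struct Params(Vec<(String, String)>);

impl Params {
    /// Build SET statements that sync client parameter state to a server.
    ///
    /// With `local = true` this emits `SET LOCAL`, which Postgres scopes to
    /// the current transaction: the setting reverts on COMMIT or ROLLBACK,
    /// so the pooler never has to track it past the transaction boundary.
    fn set_queries(&self, local: bool) -> Vec<String> {
        let set = if local { "SET LOCAL" } else { "SET" };
        self.0
            .iter()
            .map(|(name, value)| format!(r#"{} "{}" TO {}"#, set, name, value))
            .collect()
    }
}

fn main() {
    let params = Params(vec![("statement_timeout".into(), "'30s'".into())]);

    // Inside an explicit transaction: transaction-scoped, reset automatically.
    assert_eq!(
        params.set_queries(true),
        vec![r#"SET LOCAL "statement_timeout" TO '30s'"#]
    );

    // Session-level sync, e.g. restoring state on a fresh server connection.
    assert_eq!(
        params.set_queries(false),
        vec![r#"SET "statement_timeout" TO '30s'"#]
    );
}
```

Postgres treats `SET LOCAL` outside a transaction block as a no-op with a warning, which is why the local form is only emitted after `BEGIN` has been executed on the server connection.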