Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration/rust/tests/integration/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod prepared;
pub mod reload;
pub mod rewrite;
pub mod savepoint;
pub mod set_in_transaction;
pub mod set_sharding_key;
pub mod shard_consistency;
pub mod stddev;
Expand Down
189 changes: 189 additions & 0 deletions integration/rust/tests/integration/set_in_transaction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use rust::setup::{admin_sqlx, connections_sqlx};
use serial_test::serial;
use sqlx::Executor;

#[tokio::test]
#[serial]
async fn test_set_in_transaction_reset_after_commit() {
let admin = admin_sqlx().await;
admin
.execute("SET cross_shard_disabled TO true")
.await
.unwrap();

let pools = connections_sqlx().await;
let sharded = &pools[1];

let mut conn = sharded.acquire().await.unwrap();

// Get the original lock_timeout before any transaction
let original_timeout: String = sqlx::query_scalar("SHOW lock_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();

// Make sure we set it to something different
let new_timeout = if original_timeout == "45s" {
"30s"
} else {
"45s"
};

// Start a transaction and change lock_timeout
conn.execute("BEGIN").await.unwrap();
conn.execute(format!("SET lock_timeout TO '{}'", new_timeout).as_str())
.await
.unwrap();

// Verify lock_timeout is set inside transaction
let timeout_in_tx: String = sqlx::query_scalar("SHOW lock_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(
timeout_in_tx, new_timeout,
"lock_timeout should be {} inside transaction",
new_timeout
);

conn.execute("COMMIT").await.unwrap();

// Verify lock_timeout is reset to original after commit
let timeout_after_commit: String = sqlx::query_scalar("SHOW lock_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(
timeout_after_commit, original_timeout,
"lock_timeout should be reset to original after commit"
);

admin
.execute("SET cross_shard_disabled TO false")
.await
.unwrap();
}

#[tokio::test]
#[serial]
async fn test_set_in_transaction_reset_after_rollback() {
let admin = admin_sqlx().await;
admin
.execute("SET cross_shard_disabled TO true")
.await
.unwrap();

let pools = connections_sqlx().await;
let sharded = &pools[1];

let mut conn = sharded.acquire().await.unwrap();

// Get the original statement_timeout before any transaction
let original_timeout: String = sqlx::query_scalar("SHOW statement_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();

// Make sure we set it to something different
let new_timeout = if original_timeout == "30s" {
"45s"
} else {
"30s"
};

// Start a transaction and change statement_timeout
conn.execute("BEGIN").await.unwrap();
conn.execute(format!("SET statement_timeout TO '{}'", new_timeout).as_str())
.await
.unwrap();

// Verify statement_timeout is set inside transaction
let timeout_in_tx: String = sqlx::query_scalar("SHOW statement_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(
timeout_in_tx, new_timeout,
"statement_timeout should be {} inside transaction",
new_timeout
);

conn.execute("ROLLBACK").await.unwrap();

// Verify statement_timeout is back to original after rollback
let timeout_after_rollback: String = sqlx::query_scalar("SHOW statement_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(
timeout_after_rollback, original_timeout,
"statement_timeout should be reset to original after rollback"
);

admin
.execute("SET cross_shard_disabled TO false")
.await
.unwrap();
}

#[tokio::test]
#[serial]
async fn test_set_local_in_transaction_reset_after_commit() {
let admin = admin_sqlx().await;
admin
.execute("SET cross_shard_disabled TO true")
.await
.unwrap();

let pools = connections_sqlx().await;
let sharded = &pools[1];

let mut conn = sharded.acquire().await.unwrap();

// Get the original work_mem before any transaction
let original_work_mem: String = sqlx::query_scalar("SHOW work_mem")
.fetch_one(&mut *conn)
.await
.unwrap();

// Make sure we set it to something different
let new_work_mem = if original_work_mem == "8MB" {
"16MB"
} else {
"8MB"
};

// Start a transaction and change work_mem using SET LOCAL
conn.execute("BEGIN").await.unwrap();
conn.execute(format!("SET LOCAL work_mem TO '{}'", new_work_mem).as_str())
.await
.unwrap();

// Verify work_mem is set inside transaction
let work_mem_in_tx: String = sqlx::query_scalar("SHOW work_mem")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(
work_mem_in_tx, new_work_mem,
"work_mem should be {} inside transaction",
new_work_mem
);

conn.execute("COMMIT").await.unwrap();

// Verify work_mem is reset to original after commit (SET LOCAL is transaction-scoped)
let work_mem_after_commit: String = sqlx::query_scalar("SHOW work_mem")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(
work_mem_after_commit, original_work_mem,
"work_mem should be reset to original after commit (SET LOCAL is transaction-scoped)"
);

admin
.execute("SET cross_shard_disabled TO false")
.await
.unwrap();
}
32 changes: 27 additions & 5 deletions pgdog-config/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tracing::{info, warn};
use crate::sharding::ShardedSchema;
use crate::{
EnumeratedDatabase, Memory, OmnishardedTable, PassthoughAuth, PreparedStatements, RewriteMode,
Role,
};

use super::database::Database;
Expand Down Expand Up @@ -312,18 +313,39 @@ impl Config {
}
}

// Check pooler mode.
let mut pooler_mode = HashMap::<String, Option<PoolerMode>>::new();
struct Check {
pooler_mode: Option<PoolerMode>,
role: Role,
role_warned: bool,
}

// Check identical configs.
let mut checks = HashMap::<String, Check>::new();
for database in &self.databases {
if let Some(mode) = pooler_mode.get(&database.name) {
if mode != &database.pooler_mode {
if let Some(existing) = checks.get_mut(&database.name) {
if existing.pooler_mode != database.pooler_mode {
warn!(
"database \"{}\" (shard={}, role={}) has a different \"pooler_mode\" setting, ignoring",
database.name, database.shard, database.role,
);
}
let auto = existing.role == Role::Auto || database.role == Role::Auto;
if auto && existing.role != database.role && !existing.role_warned {
warn!(
r#"database "{}" has a mix of auto and specific roles, automatic role detection will be disabled"#,
database.name
);
existing.role_warned = true;
}
} else {
pooler_mode.insert(database.name.clone(), database.pooler_mode.clone());
checks.insert(
database.name.clone(),
Check {
pooler_mode: database.pooler_mode.clone(),
role: database.role,
role_warned: false,
},
);
}
}

Expand Down
Loading
Loading