diff --git a/src/commands/init.rs b/src/commands/init.rs index a9ec11b..d28aa55 100644 --- a/src/commands/init.rs +++ b/src/commands/init.rs @@ -330,7 +330,10 @@ pub async fn init( .with_context(|| format!("Invalid database name: '{}'", db_info.name))?; // Try to create database atomically (avoids TOCTOU vulnerability) - let create_query = format!("CREATE DATABASE \"{}\"", db_info.name); + let create_query = format!( + "CREATE DATABASE {}", + crate::utils::quote_ident(&db_info.name) + ); match target_client.execute(&create_query, &[]).await { Ok(_) => { tracing::info!(" Created database '{}'", db_info.name); @@ -372,8 +375,10 @@ pub async fn init( drop_database_if_exists(&target_client, &db_info.name).await?; // Recreate the database - let create_query = - format!("CREATE DATABASE \"{}\"", db_info.name); + let create_query = format!( + "CREATE DATABASE {}", + crate::utils::quote_ident(&db_info.name) + ); target_client .execute(&create_query, &[]) .await @@ -666,7 +671,10 @@ async fn drop_database_if_exists(target_conn: &Client, db_name: &str) -> Result< target_conn.execute(terminate_query, &[&db_name]).await?; // Drop the database - let drop_query = format!("DROP DATABASE IF EXISTS \"{}\"", db_name); + let drop_query = format!( + "DROP DATABASE IF EXISTS {}", + crate::utils::quote_ident(db_name) + ); target_conn .execute(&drop_query, &[]) .await diff --git a/src/main.rs b/src/main.rs index 26d7199..5185c4f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,13 @@ use database_replicator::commands; #[command(about = "Universal database-to-PostgreSQL replication CLI", long_about = None)] #[command(version)] struct Cli { + /// Allow self-signed TLS certificates (insecure - use only for testing) + #[arg( + long = "allow-self-signed-certs", + global = true, + default_value_t = false + )] + allow_self_signed_certs: bool, #[command(subcommand)] command: Commands, } @@ -181,6 +188,9 @@ async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); + // Initialize TLS policy using thread-safe OnceLock + database_replicator::postgres::connection::init_tls_policy(cli.allow_self_signed_certs); + match cli.command { Commands::Validate { source, @@ -402,7 +412,7 @@ async fn init_remote( drop_existing: bool, no_sync: bool, seren_api: String, - _job_timeout: u64, + job_timeout: u64, ) -> anyhow::Result<()> { use database_replicator::migration; use database_replicator::postgres; @@ -466,6 +476,8 @@ async fn init_remote( } else { Some(FilterSpec { include_databases, + exclude_databases, + include_tables, exclude_tables, }) }; @@ -481,7 +493,12 @@ async fn init_remote( "estimated_size_bytes".to_string(), serde_json::Value::Number(serde_json::Number::from(estimated_size_bytes)), ); - // Note: "yes" and "job_timeout" are client-side only options, not sent to server + // Optional timeout hint for remote orchestrator + options.insert( + "job_timeout_seconds".to_string(), + serde_json::Value::Number(serde_json::Number::from(job_timeout as i64)), + ); + // Note: "yes" is client-side only, not sent to server let job_spec = JobSpec { version: "1.0".to_string(), diff --git a/src/migration/dump.rs b/src/migration/dump.rs index 6fc1393..9520c17 100644 --- a/src/migration/dump.rs +++ b/src/migration/dump.rs @@ -65,6 +65,7 @@ pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> { Duration::from_secs(1), // Start with 1 second delay "pg_dumpall (dump globals)", ) + .await .context( "pg_dumpall failed to dump global objects.\n\ \n\ @@ -172,6 +173,7 @@ pub async fn dump_schema( Duration::from_secs(1), // Start with 1 second delay "pg_dump (dump schema)", ) + .await .with_context(|| { format!( "pg_dump failed to dump schema for database '{}'.\n\ @@ -299,6 +301,7 @@ pub async fn dump_data( Duration::from_secs(1), // Start with 1 second delay "pg_dump (dump data)", ) + .await .with_context(|| { format!( "pg_dump failed to dump data for database '{}'.\n\ diff --git a/src/migration/restore.rs b/src/migration/restore.rs index 54ffaaf..f10d4c3 100644 --- a/src/migration/restore.rs +++ b/src/migration/restore.rs @@ -62,7 +62,8 @@ pub async fn restore_globals(target_url: &str, input_path: &str) -> Result<()> { 3, // Max 3 retries Duration::from_secs(1), // Start with 1 second delay "psql (restore globals)", - ); + ) + .await; // Handle result - don't fail on warnings for global objects match result { @@ -136,6 +137,7 @@ pub async fn restore_schema(target_url: &str, input_path: &str) -> Result<()> { Duration::from_secs(1), // Start with 1 second delay "psql (restore schema)", ) + .await .context( "Schema restoration failed.\n\ \n\ @@ -228,6 +230,7 @@ pub async fn restore_data(target_url: &str, input_path: &str) -> Result<()> { Duration::from_secs(1), // Start with 1 second delay "pg_restore (restore data)", ) + .await .context( "Data restoration failed.\n\ \n\ diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index 3dbe93a..8a762dd 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -44,11 +44,7 @@ pub fn validate_mysql_url(connection_string: &str) -> Result { } if !connection_string.starts_with("mysql://") { - bail!( - "Invalid MySQL connection string '{}'. \ - Must start with 'mysql://'", - connection_string - ); + bail!("Invalid MySQL connection string. Must start with 'mysql://'"); } tracing::debug!("Validated MySQL connection string"); diff --git a/src/mysql/reader.rs b/src/mysql/reader.rs index fe23d3c..6d4d754 100644 --- a/src/mysql/reader.rs +++ b/src/mysql/reader.rs @@ -86,7 +86,11 @@ pub async fn get_table_row_count( tracing::debug!("Getting row count for table '{}.{}'", db_name, table_name); // Use backticks for identifiers to allow reserved words - let query = format!("SELECT COUNT(*) FROM `{}`.`{}`", db_name, table_name); + let query = format!( + "SELECT COUNT(*) FROM {}.{}", + crate::utils::quote_mysql_ident(db_name), + crate::utils::quote_mysql_ident(table_name) + ); let count: Option = conn .query_first(&query) @@ -137,7 +141,11 @@ pub async fn read_table_data(conn: &mut Conn, db_name: &str, table_name: &str) - tracing::info!("Reading all rows from table '{}.{}'", db_name, table_name); // Use backticks for identifiers - let query = format!("SELECT * FROM `{}`.`{}`", db_name, table_name); + let query = format!( + "SELECT * FROM {}.{}", + crate::utils::quote_mysql_ident(db_name), + crate::utils::quote_mysql_ident(table_name) + ); let rows: Vec = conn .query(&query) diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs index 77668e3..e6dd351 100644 --- a/src/postgres/connection.rs +++ b/src/postgres/connection.rs @@ -5,9 +5,28 @@ use crate::utils; use anyhow::{Context, Result}; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; +use std::sync::OnceLock; use std::time::Duration; use tokio_postgres::Client; +/// Thread-safe storage for TLS configuration set at startup +static ALLOW_SELF_SIGNED_CERTS: OnceLock = OnceLock::new(); + +/// Initialize the TLS certificate policy (call once at startup) +/// +/// This must be called before any database connections are made. +/// It is thread-safe and will only set the value once. +/// +/// # Arguments +/// +/// * `allow` - If true, accept self-signed/invalid TLS certificates (insecure) +pub fn init_tls_policy(allow: bool) { + let _ = ALLOW_SELF_SIGNED_CERTS.set(allow); + if allow { + tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)"); + } +} + /// Add TCP keepalive parameters to a PostgreSQL connection string /// /// Automatically adds keepalive parameters to prevent idle connection timeouts @@ -130,10 +149,15 @@ pub async fn connect(connection_string: &str) -> Result { )?; // Set up TLS connector for cloud connections - // TEMPORARY: Accept invalid certs to debug TLS issues - // TODO: Remove this once we identify the certificate validation issue - let tls_connector = TlsConnector::builder() - .danger_accept_invalid_certs(true) + // By default, require valid certificates. Allow opt-in via init_tls_policy() called at startup. + let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false); + + let mut tls_builder = TlsConnector::builder(); + if allow_self_signed { + tls_builder.danger_accept_invalid_certs(true); + } + + let tls_connector = tls_builder .build() .context("Failed to build TLS connector")?; let tls = MakeTlsConnector::new(tls_connector); diff --git a/src/remote/models.rs b/src/remote/models.rs index 9a960d8..93aa864 100644 --- a/src/remote/models.rs +++ b/src/remote/models.rs @@ -19,6 +19,8 @@ pub struct JobSpec { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FilterSpec { pub include_databases: Option>, + pub exclude_databases: Option>, + pub include_tables: Option>, pub exclude_tables: Option>, } diff --git a/src/replication/monitor.rs b/src/replication/monitor.rs index b4a719b..c1aa82b 100644 --- a/src/replication/monitor.rs +++ b/src/replication/monitor.rs @@ -34,41 +34,47 @@ pub async fn get_replication_lag( client: &Client, subscription_name: Option<&str>, ) -> Result> { - let query = if let Some(sub_name) = subscription_name { - format!( - "SELECT - application_name, - state, - sent_lsn::text, - write_lsn::text, - flush_lsn::text, - replay_lsn::text, - EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms, - EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms, - EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms - FROM pg_stat_replication - WHERE application_name = '{}'", - sub_name - ) - } else { - "SELECT - application_name, - state, - sent_lsn::text, - write_lsn::text, - flush_lsn::text, - replay_lsn::text, - EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms, - EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms, - EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms - FROM pg_stat_replication" - .to_string() - }; + if let Some(name) = subscription_name { + crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?; + } - let rows = client - .query(&query, &[]) - .await - .context("Failed to query replication statistics")?; + let rows = if let Some(sub_name) = subscription_name { + client + .query( + "SELECT + application_name, + state, + sent_lsn::text, + write_lsn::text, + flush_lsn::text, + replay_lsn::text, + EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms, + EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms, + EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms + FROM pg_stat_replication + WHERE application_name = $1", + &[&sub_name], + ) + .await + } else { + client + .query( + "SELECT + application_name, + state, + sent_lsn::text, + write_lsn::text, + flush_lsn::text, + replay_lsn::text, + EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms, + EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms, + EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms + FROM pg_stat_replication", + &[], + ) + .await + } + .context("Failed to query replication statistics")?; let mut stats = Vec::new(); for row in rows { @@ -94,33 +100,39 @@ pub async fn get_subscription_status( client: &Client, subscription_name: Option<&str>, ) -> Result> { - let query = if let Some(sub_name) = subscription_name { - format!( - "SELECT - subname, - pid, - received_lsn::text, - latest_end_lsn::text, - srsubstate - FROM pg_stat_subscription - WHERE subname = '{}'", - sub_name - ) - } else { - "SELECT - subname, - pid, - received_lsn::text, - latest_end_lsn::text, - srsubstate - FROM pg_stat_subscription" - .to_string() - }; + if let Some(name) = subscription_name { + crate::utils::validate_postgres_identifier(name).context("Invalid subscription name")?; + } - let rows = client - .query(&query, &[]) - .await - .context("Failed to query subscription statistics")?; + let rows = if let Some(sub_name) = subscription_name { + client + .query( + "SELECT + subname, + pid, + received_lsn::text, + latest_end_lsn::text, + srsubstate + FROM pg_stat_subscription + WHERE subname = $1", + &[&sub_name], + ) + .await + } else { + client + .query( + "SELECT + subname, + pid, + received_lsn::text, + latest_end_lsn::text, + srsubstate + FROM pg_stat_subscription", + &[], + ) + .await + } + .context("Failed to query subscription statistics")?; let mut stats = Vec::new(); for row in rows { diff --git a/src/replication/publication.rs b/src/replication/publication.rs index daf8c85..1586b2d 100644 --- a/src/replication/publication.rs +++ b/src/replication/publication.rs @@ -39,7 +39,10 @@ pub async fn create_publication( tracing::info!("Creating publication '{}'...", publication_name); if filter.is_empty() { - let query = format!("CREATE PUBLICATION \"{}\" FOR ALL TABLES", publication_name); + let query = format!( + "CREATE PUBLICATION {} FOR ALL TABLES", + crate::utils::quote_ident(publication_name) + ); return execute_publication_query(client, publication_name, &query).await; } @@ -121,8 +124,8 @@ pub async fn create_publication( ); let query = format!( - "CREATE PUBLICATION \"{}\" FOR TABLE {}", - publication_name, + "CREATE PUBLICATION {} FOR TABLE {}", + crate::utils::quote_ident(publication_name), clauses.join(", ") ); @@ -218,7 +221,10 @@ pub async fn drop_publication(client: &Client, publication_name: &str) -> Result tracing::info!("Dropping publication '{}'...", publication_name); - let query = format!("DROP PUBLICATION IF EXISTS \"{}\"", publication_name); + let query = format!( + "DROP PUBLICATION IF EXISTS {}", + crate::utils::quote_ident(publication_name) + ); client .execute(&query, &[]) diff --git a/src/replication/subscription.rs b/src/replication/subscription.rs index 6faed27..d56972f 100644 --- a/src/replication/subscription.rs +++ b/src/replication/subscription.rs @@ -51,8 +51,10 @@ pub async fn create_subscription( ); let query = format!( - "CREATE SUBSCRIPTION \"{}\" CONNECTION '{}' PUBLICATION \"{}\"", - subscription_name, source_connection_string, publication_name + "CREATE SUBSCRIPTION {} CONNECTION {} PUBLICATION {}", + crate::utils::quote_ident(subscription_name), + crate::utils::quote_literal(source_connection_string), + crate::utils::quote_ident(publication_name) ); match client.execute(&query, &[]).await { @@ -155,7 +157,10 @@ pub async fn drop_subscription(client: &Client, subscription_name: &str) -> Resu tracing::info!("Dropping subscription '{}'...", subscription_name); - let query = format!("DROP SUBSCRIPTION IF EXISTS \"{}\"", subscription_name); + let query = format!( + "DROP SUBSCRIPTION IF EXISTS {}", + crate::utils::quote_ident(subscription_name) + ); client.execute(&query, &[]).await.context(format!( "Failed to drop subscription '{}'", diff --git a/src/sqlite/reader.rs b/src/sqlite/reader.rs index 6148a9d..fbaed74 100644 --- a/src/sqlite/reader.rs +++ b/src/sqlite/reader.rs @@ -95,7 +95,7 @@ pub fn get_table_row_count(conn: &Connection, table: &str) -> Result { tracing::debug!("Getting row count for table '{}'", table); // Note: table name is validated above, so it's safe to use in SQL - let query = format!("SELECT COUNT(*) FROM \"{}\"", table); + let query = format!("SELECT COUNT(*) FROM {}", crate::utils::quote_ident(table)); let count: i64 = conn .query_row(&query, [], |row| row.get(0)) @@ -151,7 +151,7 @@ pub fn read_table_data( tracing::info!("Reading all data from table '{}'", table); // Note: table name is validated above - let query = format!("SELECT * FROM \"{}\"", table); + let query = format!("SELECT * FROM {}", crate::utils::quote_ident(table)); let mut stmt = conn .prepare(&query) diff --git a/src/utils.rs b/src/utils.rs index f7d9ef3..d4084cd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -261,7 +261,7 @@ where /// # use std::time::Duration; /// # use std::process::Command; /// # use database_replicator::utils::retry_subprocess_with_backoff; -/// # fn example() -> Result<()> { +/// # async fn example() -> Result<()> { /// retry_subprocess_with_backoff( /// || { /// let mut cmd = Command::new("psql"); @@ -271,11 +271,11 @@ where /// 3, // Try up to 3 times /// Duration::from_secs(1), // Start with 1s delay /// "psql" -/// )?; +/// ).await?; /// # Ok(()) /// # } /// ``` -pub fn retry_subprocess_with_backoff( +pub async fn retry_subprocess_with_backoff( mut operation: F, max_retries: u32, initial_delay: Duration, @@ -311,7 +311,7 @@ where max_retries + 1, delay ); - std::thread::sleep(delay); + tokio::time::sleep(delay).await; delay *= 2; // Exponential backoff } } @@ -328,7 +328,7 @@ where last_error.as_ref().unwrap(), delay ); - std::thread::sleep(delay); + tokio::time::sleep(delay).await; delay *= 2; // Exponential backoff } } @@ -490,6 +490,57 @@ pub fn quote_ident(identifier: &str) -> String { quoted } +/// Quote a SQL string literal (for use in SQL statements) +/// +/// Escapes single quotes by doubling them and wraps the string in single quotes. +/// Use this for string values in SQL, not for identifiers. +/// +/// # Examples +/// +/// ``` +/// use database_replicator::utils::quote_literal; +/// assert_eq!(quote_literal("hello"), "'hello'"); +/// assert_eq!(quote_literal("it's"), "'it''s'"); +/// assert_eq!(quote_literal(""), "''"); +/// ``` +pub fn quote_literal(value: &str) -> String { + let mut quoted = String::with_capacity(value.len() + 2); + quoted.push('\''); + for ch in value.chars() { + if ch == '\'' { + quoted.push('\''); + } + quoted.push(ch); + } + quoted.push('\''); + quoted +} + +/// Quote a MySQL identifier (database, table, column) +/// +/// MySQL uses backticks for identifier quoting. Escapes embedded backticks +/// by doubling them. +/// +/// # Examples +/// +/// ``` +/// use database_replicator::utils::quote_mysql_ident; +/// assert_eq!(quote_mysql_ident("users"), "`users`"); +/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`"); +/// ``` +pub fn quote_mysql_ident(identifier: &str) -> String { + let mut quoted = String::with_capacity(identifier.len() + 2); + quoted.push('`'); + for ch in identifier.chars() { + if ch == '`' { + quoted.push('`'); + } + quoted.push(ch); + } + quoted.push('`'); + quoted +} + /// Validate that source and target URLs are different to prevent accidental data loss /// /// Compares two PostgreSQL connection URLs to ensure they point to different databases. @@ -561,7 +612,18 @@ pub fn validate_source_target_different(source_url: &str, target_url: &str) -> R && source_parts.user == target_parts.user { bail!( - "Source and target URLs point to the same database!\\n\\\n \\n\\\n This would cause DATA LOSS - the target would overwrite the source.\\n\\\n \\n\\\n Source: {}@{}:{}/{}\\n\\\n Target: {}@{}:{}/{}\\n\\\n \\n\\\n Please ensure source and target are different databases.\\n\\\n Common causes:\\n\\\n - Copy-paste error in connection strings\\n\\\n - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\\n\\\n - Typo in database name or host", + "Source and target URLs point to the same database!\n\ + \n\ + This would cause DATA LOSS - the target would overwrite the source.\n\ + \n\ + Source: {}@{}:{}/{}\n\ + Target: {}@{}:{}/{}\n\ + \n\ + Please ensure source and target are different databases.\n\ + Common causes:\n\ + - Copy-paste error in connection strings\n\ + - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\ + - Typo in database name or host", source_parts.user.as_deref().unwrap_or("(no user)"), source_parts.host, source_parts.port, diff --git a/tests/integration_remote_test.rs b/tests/integration_remote_test.rs index bccb9f6..ea083e1 100644 --- a/tests/integration_remote_test.rs +++ b/tests/integration_remote_test.rs @@ -266,6 +266,8 @@ async fn test_remote_job_submission_with_filters() { // Create a job spec with database filters let filter = database_replicator::remote::FilterSpec { include_databases: Some(vec!["postgres".to_string()]), + exclude_databases: None, + include_tables: None, exclude_tables: None, }; diff --git a/tests/security_test.rs b/tests/security_test.rs index f8a1c6a..a18f63f 100644 --- a/tests/security_test.rs +++ b/tests/security_test.rs @@ -787,10 +787,7 @@ fn test_mysql_url_with_special_chars_in_password() { fn test_mysql_error_messages_dont_leak_credentials() { use database_replicator::mysql; - // SECURITY NOTE: Current implementation includes full URL in error messages - // This test documents the current behavior - ideally this should be fixed - // to sanitize URLs before including in error messages - + // SECURITY: Error messages should NOT leak passwords or full URLs let url_with_password = "not-mysql://admin:secretpass@host:3306/db"; let result = mysql::validate_mysql_url(url_with_password); @@ -798,22 +795,23 @@ fn test_mysql_error_messages_dont_leak_credentials() { let error_msg = result.unwrap_err().to_string(); - // KNOWN ISSUE: Error message currently contains the full URL including password - // This test verifies current behavior, but this should be improved + // Verify password is NOT leaked in error message assert!( - error_msg.contains("secretpass") || error_msg.contains("not-mysql://"), - "Error message currently includes full URL (known issue)" + !error_msg.contains("secretpass"), + "Error message should not contain password: {error_msg}" + ); + + // Verify full URL is NOT leaked in error message + assert!( + !error_msg.contains("not-mysql://"), + "Error message should not contain full malformed URL: {error_msg}" ); // Verify it does explain the validation failure assert!( error_msg.contains("mysql://") || error_msg.contains("Invalid"), - "Error should explain validation requirement" + "Error should explain validation requirement: {error_msg}" ); - - // TODO: Enhance validate_mysql_url to sanitize URLs in error messages - // Expected: "Invalid MySQL connection string. Must start with 'mysql://'" - // (without exposing the actual malformed URL) } // ----------------------------------------------------------------------------