diff --git a/api/src/routes/pipelines.rs b/api/src/routes/pipelines.rs index 4b8753fae..cde65ee9f 100644 --- a/api/src/routes/pipelines.rs +++ b/api/src/routes/pipelines.rs @@ -8,6 +8,7 @@ use config::shared::{ DestinationConfig, PgConnectionConfig, PipelineConfig as SharedPipelineConfig, ReplicatorConfig, SupabaseConfig, TlsConfig, }; +use postgres::schema::TableId; use serde::{Deserialize, Serialize}; use sqlx::{PgPool, PgTransaction}; use std::ops::DerefMut; @@ -718,7 +719,8 @@ pub async fn get_pipeline_replication_status( let mut tables: Vec = Vec::new(); for row in state_rows { let table_id = row.table_id.0; - let table_name = get_table_name_from_oid(&source_pool, table_id).await?; + let table_name = + get_table_name_from_oid(&source_pool, TableId::new(row.table_id.0)).await?; tables.push(TableReplicationStatus { table_id, table_name: table_name.to_string(), diff --git a/etl/benches/table_copies.rs b/etl/benches/table_copies.rs index fcd118c91..46a34b0cc 100644 --- a/etl/benches/table_copies.rs +++ b/etl/benches/table_copies.rs @@ -387,7 +387,10 @@ async fn start_pipeline(args: RunArgs) -> Result<(), Box> { let mut table_copied_notifications = vec![]; for table_id in &args.table_ids { let table_copied = state_store - .notify_on_table_state(*table_id, TableReplicationPhaseType::FinishedCopy) + .notify_on_table_state( + TableId::new(*table_id), + TableReplicationPhaseType::FinishedCopy, + ) .await; table_copied_notifications.push(table_copied); } diff --git a/etl/src/conversions/event.rs b/etl/src/conversions/event.rs index 9293fa53e..a2175e897 100644 --- a/etl/src/conversions/event.rs +++ b/etl/src/conversions/event.rs @@ -112,7 +112,11 @@ impl RelationEvent { .iter() .map(Self::build_column_schema) .collect::, _>>()?; - let table_schema = TableSchema::new(relation_body.rel_id(), table_name, column_schemas); + let table_schema = TableSchema::new( + TableId::new(relation_body.rel_id()), + table_name, + column_schemas, + ); Ok(Self { start_lsn, @@ -319,7 +323,7 @@ async fn convert_insert_to_event( insert_body: &protocol::InsertBody, ) -> Result { let table_id = insert_body.rel_id(); - let table_schema = get_table_schema(schema_cache, table_id).await?; + let table_schema = get_table_schema(schema_cache, TableId::new(table_id)).await?; let table_row = convert_tuple_to_row( &table_schema.column_schemas, @@ -329,7 +333,7 @@ async fn convert_insert_to_event( Ok(InsertEvent { start_lsn, commit_lsn, - table_id, + table_id: TableId::new(table_id), table_row, }) } @@ -341,7 +345,7 @@ async fn convert_update_to_event( update_body: &protocol::UpdateBody, ) -> Result { let table_id = update_body.rel_id(); - let table_schema = get_table_schema(schema_cache, table_id).await?; + let table_schema = get_table_schema(schema_cache, TableId::new(table_id)).await?; let table_row = convert_tuple_to_row( &table_schema.column_schemas, @@ -364,7 +368,7 @@ async fn convert_update_to_event( Ok(UpdateEvent { start_lsn, commit_lsn, - table_id, + table_id: TableId::new(table_id), table_row, old_table_row, }) @@ -377,7 +381,7 @@ async fn convert_delete_to_event( delete_body: &protocol::DeleteBody, ) -> Result { let table_id = delete_body.rel_id(); - let table_schema = get_table_schema(schema_cache, table_id).await?; + let table_schema = get_table_schema(schema_cache, TableId::new(table_id)).await?; // We try to extract the old tuple by either taking the entire old tuple or the key of the old // tuple. @@ -395,7 +399,7 @@ async fn convert_delete_to_event( Ok(DeleteEvent { start_lsn, commit_lsn, - table_id, + table_id: TableId::new(table_id), old_table_row, }) } diff --git a/etl/src/destination/bigquery.rs b/etl/src/destination/bigquery.rs index 0c9741bfe..b7d3ba4b8 100644 --- a/etl/src/destination/bigquery.rs +++ b/etl/src/destination/bigquery.rs @@ -394,7 +394,7 @@ impl BigQueryDestination { let mut result = Vec::new(); for (table_id, (table_name, column_schemas)) in table_schemas { - let table_schema = TableSchema::new(table_id, table_name, column_schemas); + let table_schema = TableSchema::new(TableId::new(table_id), table_name, column_schemas); result.push(table_schema); } @@ -557,7 +557,9 @@ impl BigQueryDestination { .lock_inner() .await; - if let Some(table_schema) = schema_cache.get_table_schema_ref(&table_id) { + if let Some(table_schema) = + schema_cache.get_table_schema_ref(&TableId::new(table_id)) + { inner .client .truncate_table( @@ -582,7 +584,7 @@ impl BigQueryDestination { /// Extracts table ID, schema name, and table name for storage in `etl_table_schemas`. fn table_schema_to_table_row(table_schema: &TableSchema) -> TableRow { let columns = vec![ - Cell::U32(table_schema.id), + Cell::U32(table_schema.id.into()), Cell::String(table_schema.name.schema.clone()), Cell::String(table_schema.name.name.clone()), ]; @@ -598,7 +600,7 @@ impl BigQueryDestination { for (column_order, column_schema) in table_schema.column_schemas.iter().enumerate() { let columns = vec![ - Cell::U32(table_schema.id), + Cell::U32(table_schema.id.into()), Cell::String(column_schema.name.clone()), Cell::String(Self::postgres_type_to_string(&column_schema.typ)), Cell::I32(column_schema.modifier), @@ -924,7 +926,7 @@ mod tests { ColumnSchema::new("data".to_string(), Type::JSONB, -1, true, false), ColumnSchema::new("active".to_string(), Type::BOOL, -1, false, false), ]; - let table_schema = TableSchema::new(456, table_name, columns); + let table_schema = TableSchema::new(TableId::new(456), table_name, columns); let schema_row = BigQueryDestination::table_schema_to_table_row(&table_schema); assert_eq!(schema_row.values[0], Cell::U32(456)); diff --git a/etl/src/replication/apply.rs b/etl/src/replication/apply.rs index dea5e3109..a693f443a 100644 --- a/etl/src/replication/apply.rs +++ b/etl/src/replication/apply.rs @@ -746,7 +746,7 @@ where }; if !hook - .should_apply_changes(message.rel_id(), remote_final_lsn) + .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) .await? { return Ok(HandleMessageResult::default()); @@ -756,8 +756,12 @@ where // dealt with differently based on the worker type. // TODO: explore how to deal with applying relation messages to the schema (creating it if missing). let schema_cache = schema_cache.lock_inner().await; - let Some(existing_table_schema) = schema_cache.get_table_schema_ref(&message.rel_id()) else { - return Err(ApplyLoopError::MissingTableSchema(message.rel_id())); + let Some(existing_table_schema) = + schema_cache.get_table_schema_ref(&TableId::new(message.rel_id())) + else { + return Err(ApplyLoopError::MissingTableSchema(TableId::new( + message.rel_id(), + ))); }; // We compare the table schema from the relation message with the existing schema (if any). @@ -766,7 +770,7 @@ where if !existing_table_schema.partial_eq(&event.table_schema) { return Ok(HandleMessageResult { end_batch: Some(EndBatch::Exclusive), - skip_table: Some(message.rel_id()), + skip_table: Some(TableId::new(message.rel_id())), ..Default::default() }); } @@ -803,7 +807,7 @@ where }; if !hook - .should_apply_changes(message.rel_id(), remote_final_lsn) + .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) .await? { return Ok(HandleMessageResult::default()); @@ -841,7 +845,7 @@ where }; if !hook - .should_apply_changes(message.rel_id(), remote_final_lsn) + .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) .await? { return Ok(HandleMessageResult::default()); @@ -879,7 +883,7 @@ where }; if !hook - .should_apply_changes(message.rel_id(), remote_final_lsn) + .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) .await? { return Ok(HandleMessageResult::default()); @@ -919,7 +923,7 @@ where let mut rel_ids = Vec::with_capacity(message.rel_ids().len()); for &table_id in message.rel_ids().iter() { if hook - .should_apply_changes(table_id, remote_final_lsn) + .should_apply_changes(TableId::new(table_id), remote_final_lsn) .await? { rel_ids.push(table_id) diff --git a/etl/src/replication/slot.rs b/etl/src/replication/slot.rs index 2ce3e7e86..4fd9b3183 100644 --- a/etl/src/replication/slot.rs +++ b/etl/src/replication/slot.rs @@ -43,6 +43,7 @@ pub fn get_slot_name( #[cfg(test)] mod tests { use super::*; + use postgres::schema::TableId; #[test] fn test_apply_worker_slot_name() { @@ -55,7 +56,13 @@ mod tests { #[test] fn test_table_sync_slot_name() { let pipeline_id = 1; - let result = get_slot_name(pipeline_id, WorkerType::TableSync { table_id: 123 }).unwrap(); + let result = get_slot_name( + pipeline_id, + WorkerType::TableSync { + table_id: TableId::new(123), + }, + ) + .unwrap(); assert!(result.starts_with(TABLE_SYNC_PREFIX)); assert!(result.len() <= MAX_SLOT_NAME_LENGTH); } diff --git a/etl/src/state/store/postgres.rs b/etl/src/state/store/postgres.rs index 1b13169ed..f1fff441e 100644 --- a/etl/src/state/store/postgres.rs +++ b/etl/src/state/store/postgres.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use config::shared::PgConnectionConfig; use postgres::replication::{ TableReplicationState, TableReplicationStateRow, connect_to_source_database, - update_replication_state, + get_table_replication_state_rows, update_replication_state, }; use postgres::schema::TableId; use sqlx::PgPool; @@ -96,7 +96,7 @@ impl PostgresStateStore { pool: &PgPool, pipeline_id: PipelineId, ) -> sqlx::Result> { - postgres::replication::get_table_replication_state_rows(pool, pipeline_id as i64).await + get_table_replication_state_rows(pool, pipeline_id as i64).await } async fn update_replication_state( @@ -165,7 +165,7 @@ impl StateStore for PostgresStateStore { let phase = self .replication_phase_from_state(&row.state, row.sync_done_lsn) .await?; - table_states.insert(row.table_id.0, phase); + table_states.insert(TableId::new(row.table_id.0), phase); } let mut inner = self.inner.lock().await; inner.table_states = table_states.clone(); @@ -188,6 +188,7 @@ impl StateStore for PostgresStateStore { .await?; let mut inner = self.inner.lock().await; inner.table_states.insert(table_id, state); + Ok(()) } } diff --git a/etl/src/test_utils/test_schema.rs b/etl/src/test_utils/test_schema.rs index d84eafb60..3e50bd31e 100644 --- a/etl/src/test_utils/test_schema.rs +++ b/etl/src/test_utils/test_schema.rs @@ -1,4 +1,4 @@ -use postgres::schema::{ColumnSchema, Oid, TableName, TableSchema}; +use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema}; use postgres::tokio::test_utils::{PgDatabase, id_column_schema}; use std::ops::RangeInclusive; use tokio_postgres::types::{PgLsn, Type}; @@ -188,7 +188,7 @@ pub async fn insert_mock_data( pub async fn get_users_age_sum_from_rows( destination: &TestDestinationWrapper, - table_id: Oid, + table_id: TableId, ) -> i32 { let mut actual_sum = 0; @@ -250,7 +250,7 @@ pub fn events_equal_excluding_fields(left: &Event, right: &Event) -> bool { pub fn build_expected_users_inserts( mut starting_id: i64, - users_table_id: Oid, + users_table_id: TableId, expected_rows: Vec<(&str, i32)>, ) -> Vec { let mut events = Vec::new(); @@ -277,7 +277,7 @@ pub fn build_expected_users_inserts( pub fn build_expected_orders_inserts( mut starting_id: i64, - orders_table_id: Oid, + orders_table_id: TableId, expected_rows: Vec<&str>, ) -> Vec { let mut events = Vec::new(); diff --git a/etl/src/workers/table_sync.rs b/etl/src/workers/table_sync.rs index 8167c3e04..f46337260 100644 --- a/etl/src/workers/table_sync.rs +++ b/etl/src/workers/table_sync.rs @@ -325,7 +325,7 @@ where "table_sync_worker", pipeline_id = self.pipeline_id, publication_name = self.config.publication_name, - table_id = self.table_id, + table_id = %self.table_id, ); let table_sync_worker = async move { debug!( diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index dd8f75add..4505d14d6 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -5,6 +5,8 @@ edition = "2024" [dependencies] config = { workspace = true } + +bytes = { workspace = true } pg_escape = { workspace = true } rustls = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/postgres/src/replication/db.rs b/postgres/src/replication/db.rs index f0593151a..1f85c51a0 100644 --- a/postgres/src/replication/db.rs +++ b/postgres/src/replication/db.rs @@ -47,7 +47,7 @@ pub async fn get_table_name_from_oid( "; let row = sqlx::query(query) - .bind(table_id as i64) + .bind(table_id.into_inner() as i64) .fetch_optional(pool) .await?; diff --git a/postgres/src/replication/state.rs b/postgres/src/replication/state.rs index 42e912016..cf3c47657 100644 --- a/postgres/src/replication/state.rs +++ b/postgres/src/replication/state.rs @@ -60,7 +60,7 @@ pub async fn update_replication_state( "#, ) .bind(pipeline_id as i64) - .bind(SqlxTableId(table_id)) + .bind(SqlxTableId(table_id.into_inner())) .bind(state) .bind(sync_done_lsn) .execute(pool) diff --git a/postgres/src/schema.rs b/postgres/src/schema.rs index 6009ea421..5831ab276 100644 --- a/postgres/src/schema.rs +++ b/postgres/src/schema.rs @@ -1,8 +1,9 @@ use std::cmp::Ordering; use std::fmt; +use std::str::FromStr; use pg_escape::quote_identifier; -use tokio_postgres::types::Type; +use tokio_postgres::types::{FromSql, ToSql, Type}; /// An object identifier in PostgreSQL. pub type Oid = u32; @@ -107,10 +108,87 @@ impl ColumnSchema { } } -/// A type alias for PostgreSQL table OIDs. +/// A type-safe wrapper for PostgreSQL table OIDs. /// /// Table OIDs are unique identifiers assigned to tables in PostgreSQL. -pub type TableId = Oid; +/// +/// This newtype provides type safety by preventing accidental use of raw [`Oid`] values +/// where a table identifier is expected. +#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct TableId(pub Oid); + +impl TableId { + /// Creates a new [`TableId`] from an [`Oid`]. + pub fn new(oid: Oid) -> Self { + Self(oid) + } + + /// Returns the underlying [`Oid`] value. + pub fn into_inner(self) -> Oid { + self.0 + } +} + +impl From for TableId { + fn from(oid: Oid) -> Self { + Self(oid) + } +} + +impl From for Oid { + fn from(table_id: TableId) -> Self { + table_id.0 + } +} + +impl fmt::Display for TableId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromStr for TableId { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + s.parse::().map(TableId::new) + } +} + +impl<'a> FromSql<'a> for TableId { + fn from_sql( + ty: &Type, + raw: &'a [u8], + ) -> Result> { + Ok(TableId::new(Oid::from_sql(ty, raw)?)) + } + + fn accepts(ty: &Type) -> bool { + ::accepts(ty) + } +} + +impl ToSql for TableId { + fn to_sql( + &self, + ty: &Type, + w: &mut bytes::BytesMut, + ) -> Result> + where + Self: Sized, + { + self.0.to_sql(ty, w) + } + + fn accepts(ty: &Type) -> bool + where + Self: Sized, + { + ::accepts(ty) + } + + tokio_postgres::types::to_sql_checked!(); +} /// Represents the complete schema of a PostgreSQL table. ///