diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs index 45a32185f8..b62ebe1531 100644 --- a/sqlx-core/src/odbc/connection/executor.rs +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -5,8 +5,7 @@ use crate::odbc::{Odbc, OdbcConnection, OdbcQueryResult, OdbcRow, OdbcStatement, use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use futures_util::TryStreamExt; -use std::borrow::Cow; +use futures_util::{future, FutureExt, StreamExt}; // run method removed; fetch_many implements streaming directly @@ -21,15 +20,8 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { 'c: 'e, E: Execute<'q, Self::Database> + 'q, { - let sql = query.sql().to_string(); let args = query.take_arguments(); - Box::pin(try_stream! { - let rx = self.worker.execute_stream(&sql, args).await?; - while let Ok(item) = rx.recv_async().await { - r#yield!(item?); - } - Ok(()) - }) + Box::pin(self.execute_stream(query.sql(), args).into_stream()) } fn fetch_optional<'e, 'q: 'e, E>( @@ -40,15 +32,12 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { 'c: 'e, E: Execute<'q, Self::Database> + 'q, { - let mut s = self.fetch_many(query); - Box::pin(async move { - while let Some(v) = s.try_next().await? { - if let Either::Right(r) = v { - return Ok(Some(r)); - } - } - Ok(None) - }) + Box::pin(self.fetch_many(query).into_future().then(|(v, _)| match v { + Some(Ok(Either::Right(r))) => future::ok(Some(r)), + Some(Ok(Either::Left(_))) => future::ok(None), + Some(Err(e)) => future::err(e), + None => future::ok(None), + })) } fn prepare_with<'e, 'q: 'e>( @@ -59,14 +48,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { where 'c: 'e, { - Box::pin(async move { - let (_, columns, parameters) = self.worker.prepare(sql).await?; - Ok(OdbcStatement { - sql: Cow::Borrowed(sql), - columns, - parameters, - }) - }) + Box::pin(async move { self.prepare(sql).await }) } #[doc(hidden)] diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index fc9751bae0..84d5572c84 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,53 +1,212 @@ -use crate::connection::{Connection, LogSettings}; +use crate::connection::Connection; use crate::error::Error; -use crate::odbc::{Odbc, OdbcConnectOptions}; +use crate::odbc::{ + Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo, +}; use crate::transaction::Transaction; +use either::Either; +use sqlx_rt::spawn_blocking; +mod odbc_bridge; +use crate::odbc::{OdbcStatement, OdbcStatementMetadata}; use futures_core::future::BoxFuture; use futures_util::future; +use odbc_api::ConnectionTransitions; +use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection}; +use odbc_bridge::{establish_connection, execute_sql}; +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; mod executor; -mod worker; -pub(crate) use worker::ConnectionWorker; +type PreparedStatement = Prepared>>; +type SharedPreparedStatement = Arc>; + +fn collect_columns(prepared: &mut PreparedStatement) -> Vec { + let count = prepared.num_result_cols().unwrap_or(0); + (1..=count) + .map(|i| create_column(prepared, i as u16)) + .collect() +} + +fn create_column(stmt: &mut PreparedStatement, index: u16) -> OdbcColumn { + let mut cd = odbc_api::ColumnDescription::default(); + let _ = stmt.describe_col(index, &mut cd); + + OdbcColumn { + name: decode_column_name(cd.name, index), + type_info: OdbcTypeInfo::new(cd.data_type), + ordinal: usize::from(index.checked_sub(1).unwrap()), + } +} + +fn decode_column_name(name_bytes: Vec, index: u16) -> String { + String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) +} /// A connection to an ODBC-accessible database. /// -/// ODBC uses a blocking C API, so we run all calls on a dedicated background thread -/// and communicate over channels to provide async access. -#[derive(Debug)] +/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking +/// thread-pool via `spawn_blocking` and synchronize access with a mutex. pub struct OdbcConnection { - pub(crate) worker: ConnectionWorker, - pub(crate) log_settings: LogSettings, + pub(crate) conn: SharedConnection<'static>, + pub(crate) stmt_cache: HashMap, SharedPreparedStatement>, +} + +impl std::fmt::Debug for OdbcConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OdbcConnection") + .field("conn", &self.conn) + .finish() + } } impl OdbcConnection { + pub(crate) async fn with_conn(&mut self, operation: S, f: F) -> Result + where + R: Send + 'static, + F: FnOnce(&mut odbc_api::Connection<'static>) -> Result + Send + 'static, + S: std::fmt::Display + Send + 'static, + { + let conn = Arc::clone(&self.conn); + spawn_blocking(move || { + let mut conn_guard = conn.lock().map_err(|_| { + Error::Protocol(format!("ODBC {}: failed to lock connection", operation)) + })?; + f(&mut conn_guard) + }) + .await + } + pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { - let worker = ConnectionWorker::establish(options.clone()).await?; + let shared_conn = spawn_blocking({ + let options = options.clone(); + move || { + let conn = establish_connection(&options)?; + let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn)); + Ok::<_, Error>(shared_conn) + } + }) + .await?; + Ok(Self { - worker, - log_settings: LogSettings::default(), + conn: shared_conn, + stmt_cache: HashMap::new(), }) } /// Returns the name of the actual Database Management System (DBMS) this /// connection is talking to as reported by the ODBC driver. - /// - /// This calls the underlying ODBC API `SQL_DBMS_NAME` via - /// `odbc_api::Connection::database_management_system_name`. - /// - /// See: https://docs.rs/odbc-api/19.0.1/odbc_api/struct.Connection.html#method.database_management_system_name pub async fn dbms_name(&mut self) -> Result { - self.worker.get_dbms_name().await + self.with_conn("dbms_name", move |conn| { + Ok(conn.database_management_system_name()?) + }) + .await + } + + pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { + self.with_conn("ping", move |conn| { + conn.execute("SELECT 1", (), None)?; + Ok(()) + }) + .await + } + + pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { + self.with_conn("begin", move |conn| { + conn.set_autocommit(false)?; + Ok(()) + }) + .await + } + + pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { + self.with_conn("commit", move |conn| { + conn.commit()?; + conn.set_autocommit(true)?; + Ok(()) + }) + .await + } + + pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { + self.with_conn("rollback", move |conn| { + conn.rollback()?; + conn.set_autocommit(true)?; + Ok(()) + }) + .await + } + + /// Launches a background task to execute the SQL statement and send the results to the returned channel. + pub(crate) fn execute_stream( + &mut self, + sql: &str, + args: Option, + ) -> flume::Receiver, Error>> { + let (tx, rx) = flume::bounded(64); + + let maybe_prepared = if let Some(prepared) = self.stmt_cache.get(sql) { + MaybePrepared::Prepared(Arc::clone(prepared)) + } else { + MaybePrepared::NotPrepared(sql.to_string()) + }; + + let conn = Arc::clone(&self.conn); + sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || { + let mut conn = conn.lock().expect("failed to lock connection"); + if let Err(e) = execute_sql(&mut conn, maybe_prepared, args, &tx) { + let _ = tx.send(Err(e)); + } + })); + + rx + } + + pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> { + // Clear the statement metadata cache + self.stmt_cache.clear(); + Ok(()) + } + + pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result, Error> { + let conn = Arc::clone(&self.conn); + let sql_arc = Arc::from(sql.to_string()); + let sql_clone = Arc::clone(&sql_arc); + let (prepared, metadata) = spawn_blocking(move || { + let mut prepared = conn.into_prepared(&sql_clone)?; + let metadata = OdbcStatementMetadata { + columns: collect_columns(&mut prepared), + parameters: usize::from(prepared.num_params().unwrap_or(0)), + }; + Ok::<_, Error>((prepared, metadata)) + }) + .await?; + self.stmt_cache + .insert(Arc::clone(&sql_arc), Arc::new(Mutex::new(prepared))); + Ok(OdbcStatement { + sql: Cow::Borrowed(sql), + metadata, + }) } } +pub(crate) enum MaybePrepared { + Prepared(SharedPreparedStatement), + NotPrepared(String), +} + impl Connection for OdbcConnection { type Database = Odbc; type Options = OdbcConnectOptions; - fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { - Box::pin(async move { self.worker.shutdown().await }) + fn close(self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + // Drop connection by moving Arc and letting it fall out of scope. + drop(self); + Ok(()) + }) } fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> { @@ -55,7 +214,7 @@ impl Connection for OdbcConnection { } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(self.worker.ping()) + Box::pin(self.ping_blocking()) } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> @@ -74,4 +233,10 @@ impl Connection for OdbcConnection { fn should_flush(&self) -> bool { false } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(self.clear_cached_statements()) + } } + +// moved helpers to connection/inner.rs diff --git a/sqlx-core/src/odbc/connection/odbc_bridge.rs b/sqlx-core/src/odbc/connection/odbc_bridge.rs new file mode 100644 index 0000000000..d0e20262e3 --- /dev/null +++ b/sqlx-core/src/odbc/connection/odbc_bridge.rs @@ -0,0 +1,327 @@ +use crate::error::Error; +use crate::odbc::{ + connection::MaybePrepared, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, + OdbcRow, OdbcTypeInfo, +}; +use either::Either; +use flume::{SendError, Sender}; +use odbc_api::handles::{AsStatementRef, Statement}; +use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, ResultSetMetadata}; + +pub type ExecuteResult = Result, Error>; +pub type ExecuteSender = Sender; + +pub fn establish_connection( + options: &crate::odbc::OdbcConnectOptions, +) -> Result, Error> { + let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; + let conn = env + .connect_with_connection_string(options.connection_string(), Default::default()) + .map_err(|e| Error::Configuration(e.to_string().into()))?; + Ok(conn) +} + +pub fn execute_sql( + conn: &mut odbc_api::Connection<'static>, + maybe_prepared: MaybePrepared, + args: Option, + tx: &ExecuteSender, +) -> Result<(), Error> { + let params = prepare_parameters(args); + + let affected = match maybe_prepared { + MaybePrepared::Prepared(prepared) => { + let mut prepared = prepared.lock().expect("prepared statement lock"); + if let Some(mut cursor) = prepared.execute(¶ms[..])? { + handle_cursor(&mut cursor, tx); + } + extract_rows_affected(&mut *prepared) + } + MaybePrepared::NotPrepared(sql) => { + let mut preallocated = conn.preallocate().map_err(Error::from)?; + if let Some(mut cursor) = preallocated.execute(&sql, ¶ms[..])? { + handle_cursor(&mut cursor, tx); + } + extract_rows_affected(&mut preallocated) + } + }; + + let _ = send_done(tx, affected); + Ok(()) +} + +fn extract_rows_affected(stmt: &mut S) -> u64 { + let mut stmt_ref = stmt.as_stmt_ref(); + let count = match stmt_ref.row_count().into_result(&stmt_ref) { + Ok(count) => count, + Err(e) => { + log::warn!("Failed to get row count: {}", e); + return 0; + } + }; + + match u64::try_from(count) { + Ok(count) => count, + Err(e) => { + log::warn!("Failed to get row count: {}", e); + 0 + } + } +} + +fn prepare_parameters( + args: Option, +) -> Vec> { + let args = args.map(|a| a.values).unwrap_or_default(); + args.into_iter().map(to_param).collect() +} + +fn to_param(arg: OdbcArgumentValue) -> Box { + match arg { + OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), + OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), + OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), + OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), + OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), + } +} + +fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) +where + C: Cursor + ResultSetMetadata, +{ + let columns = collect_columns(cursor); + + match stream_rows(cursor, &columns, tx) { + Ok(true) => { + let _ = send_done(tx, 0); + } + Ok(false) => {} + Err(e) => { + send_error(tx, e); + } + } +} + +fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { + tx.send(Ok(Either::Left(OdbcQueryResult { rows_affected }))) +} + +fn send_error(tx: &ExecuteSender, error: Error) { + let _ = tx.send(Err(error)); +} + +fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { + tx.send(Ok(Either::Right(row))) +} + +fn collect_columns(cursor: &mut C) -> Vec +where + C: ResultSetMetadata, +{ + let count = cursor.num_result_cols().unwrap_or(0); + (1..=count) + .map(|i| create_column(cursor, i as u16)) + .collect() +} + +fn create_column(cursor: &mut C, index: u16) -> OdbcColumn +where + C: ResultSetMetadata, +{ + let mut cd = odbc_api::ColumnDescription::default(); + let _ = cursor.describe_col(index, &mut cd); + + OdbcColumn { + name: decode_column_name(cd.name, index), + type_info: OdbcTypeInfo::new(cd.data_type), + ordinal: usize::from(index.checked_sub(1).unwrap()), + } +} + +fn decode_column_name(name_bytes: Vec, index: u16) -> String { + String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) +} + +fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result +where + C: Cursor, +{ + let mut receiver_open = true; + + while let Some(mut row) = cursor.next_row()? { + let values = collect_row_values(&mut row, columns)?; + let row_data = OdbcRow { + columns: columns.to_vec(), + values: values.into_iter().map(|(_, value)| value).collect(), + }; + + if send_row(tx, row_data).is_err() { + receiver_open = false; + break; + } + } + Ok(receiver_open) +} + +fn collect_row_values( + row: &mut CursorRow<'_>, + columns: &[OdbcColumn], +) -> Result, Error> { + columns + .iter() + .enumerate() + .map(|(i, column)| collect_column_value(row, i, column)) + .collect() +} + +fn collect_column_value( + row: &mut CursorRow<'_>, + index: usize, + column: &OdbcColumn, +) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { + use odbc_api::DataType; + + let col_idx = (index + 1) as u16; + let type_info = column.type_info.clone(); + let data_type = type_info.data_type(); + + let value = match data_type { + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Bit => extract_int(row, col_idx, &type_info)?, + + DataType::Real => extract_float::(row, col_idx, &type_info)?, + DataType::Float { .. } | DataType::Double => { + extract_float::(row, col_idx, &type_info)? + } + + DataType::Char { .. } + | DataType::Varchar { .. } + | DataType::LongVarchar { .. } + | DataType::WChar { .. } + | DataType::WVarchar { .. } + | DataType::WLongVarchar { .. } + | DataType::Date + | DataType::Time { .. } + | DataType::Timestamp { .. } + | DataType::Decimal { .. } + | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, + + DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { + extract_binary(row, col_idx, &type_info)? + } + + DataType::Unknown | DataType::Other { .. } => { + match extract_text(row, col_idx, &type_info) { + Ok(v) => v, + Err(_) => extract_binary(row, col_idx, &type_info)?, + } + } + }; + + Ok((type_info, value)) +} + +fn extract_int( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, int) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v)), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int, + float: None, + }) +} + +fn extract_float( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result +where + T: Into + Default, + odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, +{ + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, float) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v.into())), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int: None, + float, + }) +} + +fn extract_text( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_text(col_idx, &mut buf)?; + + let (is_null, text) = if !is_some { + (true, None) + } else { + match String::from_utf8(buf) { + Ok(s) => (false, Some(s)), + Err(e) => return Err(Error::Decode(e.into())), + } + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text, + blob: None, + int: None, + float: None, + }) +} + +fn extract_binary( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_binary(col_idx, &mut buf)?; + + let (is_null, blob) = if !is_some { + (true, None) + } else { + (false, Some(buf)) + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob, + int: None, + float: None, + }) +} diff --git a/sqlx-core/src/odbc/connection/worker.rs b/sqlx-core/src/odbc/connection/worker.rs deleted file mode 100644 index 2ed7367946..0000000000 --- a/sqlx-core/src/odbc/connection/worker.rs +++ /dev/null @@ -1,781 +0,0 @@ -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::thread; - -use flume::{SendError, TrySendError}; -use futures_channel::oneshot; - -use crate::error::Error; -use crate::odbc::{ - OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, - OdbcTypeInfo, -}; -#[allow(unused_imports)] -use crate::row::Row as SqlxRow; -use either::Either; -#[allow(unused_imports)] -use odbc_api::handles::Statement as OdbcStatementTrait; -use odbc_api::handles::StatementImpl; -use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; - -// Type aliases for commonly used types -type OdbcConnection = odbc_api::Connection<'static>; -type TransactionResult = Result<(), Error>; -type TransactionSender = oneshot::Sender; -type ExecuteResult = Result, Error>; -type ExecuteSender = flume::Sender; -type PrepareResult = Result<(u64, Vec, usize), Error>; -type PrepareSender = oneshot::Sender; - -#[derive(Debug)] -pub(crate) struct ConnectionWorker { - command_tx: flume::Sender, - join_handle: Option>, -} - -#[derive(Debug)] -enum Command { - Ping { - tx: oneshot::Sender<()>, - }, - Shutdown { - tx: oneshot::Sender<()>, - }, - Begin { - tx: TransactionSender, - }, - Commit { - tx: TransactionSender, - }, - Rollback { - tx: TransactionSender, - }, - Execute { - sql: Box, - args: Option, - tx: ExecuteSender, - }, - Prepare { - sql: Box, - tx: PrepareSender, - }, - GetDbmsName { - tx: oneshot::Sender>, - }, -} - -impl Drop for ConnectionWorker { - fn drop(&mut self) { - self.shutdown_sync(); - } -} - -impl ConnectionWorker { - pub async fn establish(options: OdbcConnectOptions) -> Result { - let (command_tx, command_rx) = flume::bounded(64); - let (conn_tx, conn_rx) = oneshot::channel(); - let thread = thread::Builder::new() - .name("sqlx-odbc-conn".into()) - .spawn(move || worker_thread_main(options, command_rx, conn_tx))?; - - conn_rx.await.map_err(|_| Error::WorkerCrashed)??; - Ok(ConnectionWorker { - command_tx, - join_handle: Some(thread), - }) - } - - pub(crate) async fn ping(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_command_and_await(&self.command_tx, Command::Ping { tx }, rx).await - } - - pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_command_and_await(&self.command_tx, Command::Shutdown { tx }, rx).await - } - - pub(crate) fn shutdown_sync(&mut self) { - // Send shutdown command without waiting for response - // Use try_send to avoid any potential blocking in Drop - - if let Some(join_handle) = self.join_handle.take() { - let (mut tx, _rx) = oneshot::channel(); - while let Err(TrySendError::Full(Command::Shutdown { tx: t })) = - self.command_tx.try_send(Command::Shutdown { tx }) - { - tx = t; - log::warn!("odbc worker thread queue is full, retrying..."); - thread::sleep(std::time::Duration::from_millis(10)); - } - if let Err(e) = join_handle.join() { - let err = e.downcast_ref::(); - log::error!( - "failed to join worker thread while shutting down: {:?}", - err - ); - } - } - } - - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_transaction_command(&self.command_tx, Command::Begin { tx }, rx).await - } - - pub(crate) async fn commit(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_transaction_command(&self.command_tx, Command::Commit { tx }, rx).await - } - - pub(crate) async fn rollback(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_transaction_command(&self.command_tx, Command::Rollback { tx }, rx).await - } - - pub(crate) async fn execute_stream( - &mut self, - sql: &str, - args: Option, - ) -> Result, Error>>, Error> { - let (tx, rx) = flume::bounded(64); - self.command_tx - .send_async(Command::Execute { - sql: sql.into(), - args, - tx, - }) - .await - .map_err(|_| Error::WorkerCrashed)?; - Ok(rx) - } - - pub(crate) async fn prepare( - &mut self, - sql: &str, - ) -> Result<(u64, Vec, usize), Error> { - let (tx, rx) = oneshot::channel(); - send_command_and_await( - &self.command_tx, - Command::Prepare { - sql: sql.into(), - tx, - }, - rx, - ) - .await? - } - - pub(crate) async fn get_dbms_name(&mut self) -> Result { - let (tx, rx) = oneshot::channel(); - send_command_and_await(&self.command_tx, Command::GetDbmsName { tx }, rx).await? - } -} - -// Worker thread implementation -fn worker_thread_main( - options: OdbcConnectOptions, - command_rx: flume::Receiver, - conn_tx: oneshot::Sender>, -) { - // Establish connection - let conn = match establish_connection(&options) { - Ok(conn) => { - log::debug!("ODBC connection established successfully"); - let _ = conn_tx.send(Ok(())); - conn - } - Err(e) => { - let _ = conn_tx.send(Err(e)); - return; - } - }; - - let mut stmt_manager = StatementManager::new(&conn); - - // Process commands - while let Ok(cmd) = command_rx.recv() { - log::trace!("Processing command: {:?}", cmd); - match process_command(cmd, &conn, &mut stmt_manager) { - Ok(CommandControlFlow::Continue) => {} - Ok(CommandControlFlow::Shutdown(shutdown_tx)) => { - log::debug!("Shutting down ODBC worker thread"); - drop(stmt_manager); - drop(conn); - send_oneshot(shutdown_tx, (), "shutdown"); - break; - } - Err(()) => { - log::error!("ODBC worker error while processing command"); - } - } - } - // Channel disconnected or shutdown command received, worker thread exits -} - -fn establish_connection(options: &OdbcConnectOptions) -> Result { - // Get or create the shared ODBC environment - // This ensures thread-safe initialization and prevents concurrent environment creation issues - let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; - - let conn = env - .connect_with_connection_string(options.connection_string(), Default::default()) - .map_err(|e| Error::Configuration(e.to_string().into()))?; - - Ok(conn) -} - -/// Statement manager to handle preallocated statements -struct StatementManager<'conn> { - conn: &'conn OdbcConnection, - // Reusable statement for direct execution - direct_stmt: Option>>, - // Cache of prepared statements by SQL hash - prepared_cache: HashMap>>, -} - -impl<'conn> StatementManager<'conn> { - fn new(conn: &'conn OdbcConnection) -> Self { - log::debug!("Creating new statement manager"); - Self { - conn, - direct_stmt: None, - prepared_cache: HashMap::new(), - } - } - - fn get_or_create_direct_stmt( - &mut self, - ) -> Result<&mut Preallocated>, Error> { - if self.direct_stmt.is_none() { - log::debug!("Preallocating ODBC direct statement"); - self.direct_stmt = Some(self.conn.preallocate().map_err(Error::from)?); - } - Ok(self.direct_stmt.as_mut().unwrap()) - } - - fn get_or_create_prepared( - &mut self, - sql: &str, - ) -> Result<&mut odbc_api::Prepared>, Error> { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - sql.hash(&mut hasher); - let sql_hash = hasher.finish(); - - match self.prepared_cache.entry(sql_hash) { - Entry::Vacant(e) => { - log::trace!("Preparing statement for SQL: {}", sql); - let prepared = self.conn.prepare(sql)?; - Ok(e.insert(prepared)) - } - Entry::Occupied(e) => { - log::trace!("Using prepared statement for SQL: {}", sql); - Ok(e.into_mut()) - } - } - } -} -// Helper function to send results through oneshot channels with consistent error handling -fn send_oneshot(tx: oneshot::Sender, result: T, operation: &str) { - if tx.send(result).is_err() { - log::warn!("Failed to send {} result: receiver dropped", operation); - } -} - -fn send_stream_result( - tx: &ExecuteSender, - result: ExecuteResult, -) -> Result<(), SendError> { - tx.send(result) -} - -async fn send_command_and_await( - command_tx: &flume::Sender, - cmd: Command, - rx: oneshot::Receiver, -) -> Result { - command_tx - .send_async(cmd) - .await - .map_err(|_| Error::WorkerCrashed)?; - rx.await.map_err(|_| Error::WorkerCrashed) -} - -async fn send_transaction_command( - command_tx: &flume::Sender, - cmd: Command, - rx: oneshot::Receiver, -) -> Result<(), Error> { - send_command_and_await(command_tx, cmd, rx).await??; - Ok(()) -} - -// Utility functions for transaction operations -fn execute_transaction_operation( - conn: &OdbcConnection, - operation: F, - operation_name: &str, -) -> TransactionResult -where - F: FnOnce(&OdbcConnection) -> Result<(), odbc_api::Error>, -{ - log::trace!("{} odbc transaction", operation_name); - operation(conn) - .map_err(|e| Error::Protocol(format!("Failed to {} transaction: {}", operation_name, e))) -} - -#[derive(Debug)] -enum CommandControlFlow { - Shutdown(oneshot::Sender<()>), - Continue, -} - -type CommandResult = Result; - -// Returns a shutdown tx if the command is a shutdown command -fn process_command<'conn>( - cmd: Command, - conn: &'conn OdbcConnection, - stmt_manager: &mut StatementManager<'conn>, -) -> CommandResult { - match cmd { - Command::Ping { tx } => handle_ping(conn, tx), - Command::Begin { tx } => handle_begin(conn, tx), - Command::Commit { tx } => handle_commit(conn, tx), - Command::Rollback { tx } => handle_rollback(conn, tx), - Command::Shutdown { tx } => Ok(CommandControlFlow::Shutdown(tx)), - Command::Execute { sql, args, tx } => handle_execute(stmt_manager, sql, args, tx), - Command::Prepare { sql, tx } => handle_prepare(stmt_manager, sql, tx), - Command::GetDbmsName { tx } => handle_get_dbms_name(conn, tx), - } -} - -// Command handlers -fn handle_ping(conn: &OdbcConnection, tx: oneshot::Sender<()>) -> CommandResult { - match conn.execute("SELECT 1", (), None) { - Ok(_) => send_oneshot(tx, (), "ping"), - Err(e) => log::error!("Ping failed: {}", e), - } - Ok(CommandControlFlow::Continue) -} - -fn handle_begin(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { - let result = execute_transaction_operation(conn, |c| c.set_autocommit(false), "begin"); - send_oneshot(tx, result, "begin transaction"); - Ok(CommandControlFlow::Continue) -} - -fn handle_commit(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { - let result = execute_transaction_operation( - conn, - |c| c.commit().and_then(|_| c.set_autocommit(true)), - "commit", - ); - send_oneshot(tx, result, "commit transaction"); - Ok(CommandControlFlow::Continue) -} - -fn handle_rollback(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { - let result = execute_transaction_operation( - conn, - |c| c.rollback().and_then(|_| c.set_autocommit(true)), - "rollback", - ); - send_oneshot(tx, result, "rollback transaction"); - Ok(CommandControlFlow::Continue) -} -fn handle_prepare<'conn>( - stmt_manager: &mut StatementManager<'conn>, - sql: Box, - tx: PrepareSender, -) -> CommandResult { - let result = do_prepare(stmt_manager, sql); - send_oneshot(tx, result, "prepare"); - Ok(CommandControlFlow::Continue) -} - -fn do_prepare<'conn>(stmt_manager: &mut StatementManager<'conn>, sql: Box) -> PrepareResult { - log::trace!("Preparing statement: {}", sql); - // Use the statement manager to get or create the prepared statement - let prepared = stmt_manager.get_or_create_prepared(&sql)?; - let columns = collect_columns(prepared); - let params = usize::from(prepared.num_params().unwrap_or(0)); - log::debug!( - "Prepared statement with {} columns and {} parameters", - columns.len(), - params - ); - Ok((0, columns, params)) -} - -fn handle_get_dbms_name( - conn: &OdbcConnection, - tx: oneshot::Sender>, -) -> CommandResult { - log::debug!("Getting DBMS name"); - let result = conn - .database_management_system_name() - .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))); - send_oneshot(tx, result, "DBMS name"); - Ok(CommandControlFlow::Continue) -} - -fn handle_execute<'conn>( - stmt_manager: &mut StatementManager<'conn>, - sql: Box, - args: Option, - tx: ExecuteSender, -) -> CommandResult { - log::debug!( - "Executing SQL: {}", - sql.chars().take(100).collect::() - ); - - let result = execute_sql(stmt_manager, &sql, args, &tx); - with_result_send_error(result, &tx, |_| {}); - Ok(CommandControlFlow::Continue) -} - -// SQL execution functions -fn execute_sql<'conn>( - stmt_manager: &mut StatementManager<'conn>, - sql: &str, - args: Option, - tx: &ExecuteSender, -) -> Result<(), Error> { - let params = prepare_parameters(args); - let stmt = stmt_manager.get_or_create_direct_stmt()?; - log::trace!("Starting execution of SQL: {}", sql); - - // Execute and handle result immediately to avoid borrowing conflicts - if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { - handle_cursor(&mut cursor, tx); - return Ok(()); - } - - // Execution completed without result set, get affected rows - let affected = extract_rows_affected(stmt); - let _ = send_done(tx, affected); - Ok(()) -} - -fn extract_rows_affected(stmt: &mut Preallocated>) -> u64 { - let count_opt = match stmt.row_count() { - Ok(count_opt) => count_opt, - Err(e) => { - log::warn!("Failed to get ODBC row count: {}", e); - return 0; - } - }; - - let count = match count_opt { - Some(count) => count, - None => { - log::debug!("ODBC row count is not available"); - return 0; - } - }; - - let affected = match u64::try_from(count) { - Ok(count) => count, - Err(e) => { - log::warn!("Failed to convert ODBC row count to u64: {}", e); - return 0; - } - }; - log::trace!("ODBC statement affected {} rows", affected); - affected -} - -fn prepare_parameters( - args: Option, -) -> Vec> { - let args = args.map(|a| a.values).unwrap_or_default(); - args.into_iter().map(to_param).collect() -} - -fn to_param(arg: OdbcArgumentValue) -> Box { - match arg { - OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), - OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), - OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), - OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), - OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), - } -} - -fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) -where - C: Cursor + ResultSetMetadata, -{ - let columns = collect_columns(cursor); - log::trace!("Processing ODBC result set with {} columns", columns.len()); - - match stream_rows(cursor, &columns, tx) { - Ok(true) => { - log::trace!("Successfully streamed all rows"); - let _ = send_done(tx, 0); - } - Ok(false) => { - log::trace!("Row streaming stopped early (receiver closed)"); - } - Err(e) => { - send_error(tx, e); - } - } -} - -// Unified result sending functions -fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { - send_stream_result(tx, Ok(Either::Left(OdbcQueryResult { rows_affected }))) -} - -fn with_result_send_error( - result: Result, - tx: &ExecuteSender, - handler: impl FnOnce(T), -) { - match result { - Ok(result) => handler(result), - Err(error) => send_error(tx, error), - } -} - -fn send_error(tx: &ExecuteSender, error: Error) { - if let Err(e) = send_stream_result(tx, Err(error)) { - log::error!("Failed to send error from ODBC worker thread: {}", e); - } -} - -fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { - send_stream_result(tx, Ok(Either::Right(row))) -} - -// Metadata and row processing -fn collect_columns(cursor: &mut C) -> Vec -where - C: ResultSetMetadata, -{ - let count = cursor.num_result_cols().unwrap_or(0); - - (1..=count) - .map(|i| create_column(cursor, i as u16)) - .collect() -} - -fn create_column(cursor: &mut C, index: u16) -> OdbcColumn -where - C: ResultSetMetadata, -{ - let mut cd = odbc_api::ColumnDescription::default(); - let _ = cursor.describe_col(index, &mut cd); - - OdbcColumn { - name: decode_column_name(cd.name, index), - type_info: OdbcTypeInfo::new(cd.data_type), - ordinal: usize::from(index.checked_sub(1).unwrap()), - } -} - -fn decode_column_name(name_bytes: Vec, index: u16) -> String { - String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) -} - -fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result -where - C: Cursor, -{ - let mut receiver_open = true; - let mut row_count = 0; - - while let Some(mut row) = cursor.next_row()? { - let values = collect_row_values(&mut row, columns)?; - let row_data = OdbcRow { - columns: columns.to_vec(), - values: values.into_iter().map(|(_, value)| value).collect(), - }; - - if send_row(tx, row_data).is_err() { - log::debug!("Receiver closed after {} rows", row_count); - receiver_open = false; - break; - } - row_count += 1; - } - - if receiver_open { - log::debug!("Streamed {} rows successfully", row_count); - } - Ok(receiver_open) -} - -fn collect_row_values( - row: &mut CursorRow<'_>, - columns: &[OdbcColumn], -) -> Result, Error> { - columns - .iter() - .enumerate() - .map(|(i, column)| collect_column_value(row, i, column)) - .collect() -} - -fn collect_column_value( - row: &mut CursorRow<'_>, - index: usize, - column: &OdbcColumn, -) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { - use odbc_api::DataType; - - let col_idx = (index + 1) as u16; - let type_info = column.type_info.clone(); - let data_type = type_info.data_type(); - - // Extract value based on data type - let value = match data_type { - // Integer types - DataType::TinyInt - | DataType::SmallInt - | DataType::Integer - | DataType::BigInt - | DataType::Bit => extract_int(row, col_idx, &type_info)?, - - // Floating point types - DataType::Real => extract_float::(row, col_idx, &type_info)?, - DataType::Float { .. } | DataType::Double => { - extract_float::(row, col_idx, &type_info)? - } - - // String types - DataType::Char { .. } - | DataType::Varchar { .. } - | DataType::LongVarchar { .. } - | DataType::WChar { .. } - | DataType::WVarchar { .. } - | DataType::WLongVarchar { .. } - | DataType::Date - | DataType::Time { .. } - | DataType::Timestamp { .. } - | DataType::Decimal { .. } - | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, - - // Binary types - DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { - extract_binary(row, col_idx, &type_info)? - } - - // Unknown types - try text first, fall back to binary - DataType::Unknown | DataType::Other { .. } => { - match extract_text(row, col_idx, &type_info) { - Ok(v) => v, - Err(_) => extract_binary(row, col_idx, &type_info)?, - } - } - }; - - Ok((type_info, value)) -} - -fn extract_int( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; - - let (is_null, int) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v)), - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int, - float: None, - }) -} - -fn extract_float( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result -where - T: Into + Default, - odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, -{ - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; - - let (is_null, float) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v.into())), - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int: None, - float, - }) -} - -fn extract_text( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_text(col_idx, &mut buf)?; - - let (is_null, text) = if !is_some { - (true, None) - } else { - match String::from_utf8(buf) { - Ok(s) => (false, Some(s)), - Err(e) => return Err(Error::Decode(e.into())), - } - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text, - blob: None, - int: None, - float: None, - }) -} - -fn extract_binary( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_binary(col_idx, &mut buf)?; - - let (is_null, blob) = if !is_some { - (true, None) - } else { - (false, Some(buf)) - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob, - int: None, - float: None, - }) -} diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs index da41adb1e9..492cc370b6 100644 --- a/sqlx-core/src/odbc/mod.rs +++ b/sqlx-core/src/odbc/mod.rs @@ -31,7 +31,7 @@ mod error; mod options; mod query_result; mod row; -mod statement; +pub mod statement; mod transaction; mod type_info; pub mod types; @@ -44,7 +44,7 @@ pub use database::Odbc; pub use options::OdbcConnectOptions; pub use query_result::OdbcQueryResult; pub use row::OdbcRow; -pub use statement::OdbcStatement; +pub use statement::{OdbcStatement, OdbcStatementMetadata}; pub use transaction::OdbcTransactionManager; pub use type_info::{DataTypeExt, OdbcTypeInfo}; pub use value::{OdbcValue, OdbcValueRef}; diff --git a/sqlx-core/src/odbc/statement.rs b/sqlx-core/src/odbc/statement.rs index beeef9807a..d7345deecc 100644 --- a/sqlx-core/src/odbc/statement.rs +++ b/sqlx-core/src/odbc/statement.rs @@ -8,8 +8,13 @@ use std::borrow::Cow; #[derive(Debug, Clone)] pub struct OdbcStatement<'q> { pub(crate) sql: Cow<'q, str>, - pub(crate) columns: Vec, - pub(crate) parameters: usize, + pub(crate) metadata: OdbcStatementMetadata, +} + +#[derive(Debug, Clone)] +pub struct OdbcStatementMetadata { + pub columns: Vec, + pub parameters: usize, } impl<'q> Statement<'q> for OdbcStatement<'q> { @@ -18,8 +23,7 @@ impl<'q> Statement<'q> for OdbcStatement<'q> { fn to_owned(&self) -> OdbcStatement<'static> { OdbcStatement { sql: Cow::Owned(self.sql.to_string()), - columns: self.columns.clone(), - parameters: self.parameters, + metadata: self.metadata.clone(), } } @@ -27,10 +31,10 @@ impl<'q> Statement<'q> for OdbcStatement<'q> { &self.sql } fn parameters(&self) -> Option> { - Some(Either::Right(self.parameters)) + Some(Either::Right(self.metadata.parameters)) } fn columns(&self) -> &[OdbcColumn] { - &self.columns + &self.metadata.columns } // ODBC arguments placeholder @@ -40,6 +44,7 @@ impl<'q> Statement<'q> for OdbcStatement<'q> { impl ColumnIndex> for &'_ str { fn index(&self, statement: &OdbcStatement<'_>) -> Result { statement + .metadata .columns .iter() .position(|c| c.name == *self) @@ -54,21 +59,22 @@ impl<'q> From> for crate::any::AnyStatement<'q> { // First build the columns and collect names let columns: Vec<_> = stmt + .metadata .columns - .into_iter() + .iter() .enumerate() .map(|(index, col)| { column_names.insert(crate::ext::ustr::UStr::new(&col.name), index); crate::any::AnyColumn { kind: crate::any::column::AnyColumnKind::Odbc(col.clone()), - type_info: crate::any::AnyTypeInfo::from(col.type_info), + type_info: crate::any::AnyTypeInfo::from(col.type_info.clone()), } }) .collect(); crate::any::AnyStatement { sql: stmt.sql, - parameters: Some(either::Either::Right(stmt.parameters)), + parameters: Some(either::Either::Right(stmt.metadata.parameters)), columns, column_names: std::sync::Arc::new(column_names), } diff --git a/sqlx-core/src/odbc/transaction.rs b/sqlx-core/src/odbc/transaction.rs index 2556c16784..aa57d73859 100644 --- a/sqlx-core/src/odbc/transaction.rs +++ b/sqlx-core/src/odbc/transaction.rs @@ -11,19 +11,19 @@ impl TransactionManager for OdbcTransactionManager { fn begin( conn: &mut ::Connection, ) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { conn.worker.begin().await }) + Box::pin(async move { conn.begin_blocking().await }) } fn commit( conn: &mut ::Connection, ) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { conn.worker.commit().await }) + Box::pin(async move { conn.commit_blocking().await }) } fn rollback( conn: &mut ::Connection, ) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { conn.worker.rollback().await }) + Box::pin(async move { conn.rollback_blocking().await }) } fn start_rollback(_conn: &mut ::Connection) { diff --git a/sqlx-rt/src/rt_async_std.rs b/sqlx-rt/src/rt_async_std.rs index aeca8541ab..e8ccb49849 100644 --- a/sqlx-rt/src/rt_async_std.rs +++ b/sqlx-rt/src/rt_async_std.rs @@ -1,7 +1,8 @@ pub use async_std::{ self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite, - net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now, + net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::spawn_blocking, + task::yield_now, }; #[cfg(unix)] diff --git a/sqlx-rt/src/rt_tokio.rs b/sqlx-rt/src/rt_tokio.rs index b1d3bc8149..72b2cbb27b 100644 --- a/sqlx-rt/src/rt_tokio.rs +++ b/sqlx-rt/src/rt_tokio.rs @@ -45,3 +45,14 @@ pub fn test_block_on(future: F) -> F::Output { .expect("failed to initialize Tokio test runtime") .block_on(future) } + +/// Spawn a blocking task. Panics if the task panics. +pub async fn spawn_blocking(f: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + tokio::task::spawn_blocking(f) + .await + .expect("blocking task panicked") +} diff --git a/tests/odbc/odbc.rs b/tests/odbc/odbc.rs index f92b73d881..c22dbc054e 100644 --- a/tests/odbc/odbc.rs +++ b/tests/odbc/odbc.rs @@ -149,7 +149,7 @@ async fn it_fetch_optional_some_and_none() -> anyhow::Result<()> { #[tokio::test] async fn it_can_prepare_then_query_without_params() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT 7 AS seven").await?; + let stmt = conn.prepare("SELECT 7 AS seven").await?; let row = stmt.query().fetch_one(&mut conn).await?; let col_name = row.column(0).name(); assert!( @@ -166,7 +166,7 @@ async fn it_can_prepare_then_query_without_params() -> anyhow::Result<()> { async fn it_can_prepare_then_query_with_params_integer_float_text() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ? AS i, ? AS f, ? AS t").await?; + let stmt = conn.prepare("SELECT ? AS i, ? AS f, ? AS t").await?; let row = stmt .query() @@ -217,7 +217,7 @@ async fn it_can_bind_many_params_dynamically() -> anyhow::Result<()> { sql.push('?'); } - let stmt = (&mut conn).prepare(&sql).await?; + let stmt = conn.prepare(&sql).await?; let values: Vec = (1..=count as i32).collect(); let mut q = stmt.query(); @@ -237,7 +237,7 @@ async fn it_can_bind_many_params_dynamically() -> anyhow::Result<()> { async fn it_can_bind_heterogeneous_params() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ?, ?, ?, ?, ?").await?; + let stmt = conn.prepare("SELECT ?, ?, ?, ?, ?").await?; let row = stmt .query() @@ -266,7 +266,7 @@ async fn it_can_bind_heterogeneous_params() -> anyhow::Result<()> { #[tokio::test] async fn it_binds_null_string_parameter() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ?, ?").await?; + let stmt = conn.prepare("SELECT ?, ?").await?; let row = stmt .query() .bind("abc") @@ -424,7 +424,7 @@ async fn it_handles_large_strings() -> anyhow::Result<()> { // Test a moderately large string let large_string = "a".repeat(1000); - let stmt = (&mut conn).prepare("SELECT ? AS large_str").await?; + let stmt = conn.prepare("SELECT ? AS large_str").await?; let row = stmt .query() .bind(&large_string) @@ -443,7 +443,7 @@ async fn it_handles_binary_data() -> anyhow::Result<()> { // Test binary data - use UTF-8 safe bytes for PostgreSQL compatibility let binary_data = b"ABCDE"; - let stmt = (&mut conn).prepare("SELECT ? AS binary_data").await?; + let stmt = conn.prepare("SELECT ? AS binary_data").await?; let row = stmt .query() .bind(&binary_data[..]) @@ -459,7 +459,7 @@ async fn it_handles_binary_data() -> anyhow::Result<()> { async fn it_handles_mixed_null_and_values() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ?, ?, ?, ?").await?; + let stmt = conn.prepare("SELECT ?, ?, ?, ?").await?; let row = stmt .query() .bind(42_i32) @@ -505,7 +505,7 @@ async fn it_handles_slice_types() -> anyhow::Result<()> { // Test slice types let test_data = b"Hello, ODBC!"; - let stmt = (&mut conn).prepare("SELECT ? AS slice_data").await?; + let stmt = conn.prepare("SELECT ? AS slice_data").await?; let row = stmt .query() .bind(&test_data[..]) @@ -528,7 +528,7 @@ async fn it_handles_uuid() -> anyhow::Result<()> { let uuid_str = test_uuid.to_string(); // Test UUID as string - let stmt = (&mut conn).prepare("SELECT ? AS uuid_data").await?; + let stmt = conn.prepare("SELECT ? AS uuid_data").await?; let row = stmt.query().bind(&uuid_str).fetch_one(&mut conn).await?; let result = row.try_get_raw(0)?.to_owned().decode::(); @@ -536,7 +536,7 @@ async fn it_handles_uuid() -> anyhow::Result<()> { // Test with a specific UUID string let specific_uuid_str = "550e8400-e29b-41d4-a716-446655440000"; - let stmt = (&mut conn).prepare("SELECT ? AS uuid_data").await?; + let stmt = conn.prepare("SELECT ? AS uuid_data").await?; let row = stmt .query() .bind(specific_uuid_str) @@ -563,7 +563,7 @@ async fn it_handles_json() -> anyhow::Result<()> { }); let json_str = test_json.to_string(); - let stmt = (&mut conn).prepare("SELECT ? AS json_data").await?; + let stmt = conn.prepare("SELECT ? AS json_data").await?; let row = stmt.query().bind(&json_str).fetch_one(&mut conn).await?; let result: Value = row.try_get_raw(0)?.to_owned().decode(); @@ -581,7 +581,7 @@ async fn it_handles_bigdecimal() -> anyhow::Result<()> { let test_decimal = BigDecimal::from_str("123.456789")?; let decimal_str = test_decimal.to_string(); - let stmt = (&mut conn).prepare("SELECT ? AS decimal_data").await?; + let stmt = conn.prepare("SELECT ? AS decimal_data").await?; let row = stmt.query().bind(&decimal_str).fetch_one(&mut conn).await?; let result = row.try_get_raw(0)?.to_owned().decode::(); @@ -598,7 +598,7 @@ async fn it_handles_rust_decimal() -> anyhow::Result<()> { let test_decimal = "123.456789".parse::()?; let decimal_str = test_decimal.to_string(); - let stmt = (&mut conn).prepare("SELECT ? AS decimal_data").await?; + let stmt = conn.prepare("SELECT ? AS decimal_data").await?; let row = stmt.query().bind(&decimal_str).fetch_one(&mut conn).await?; let result = row.try_get_raw(0)?.to_owned().decode::(); @@ -621,7 +621,7 @@ async fn it_handles_chrono_datetime() -> anyhow::Result<()> { let test_datetime = NaiveDateTime::new(test_date, test_time); // Test that we can encode chrono types (by storing them as strings) - let stmt = (&mut conn).prepare("SELECT ? AS date_data").await?; + let stmt = conn.prepare("SELECT ? AS date_data").await?; let row = stmt.query().bind(test_date).fetch_one(&mut conn).await?; // Decode as string and verify format @@ -629,14 +629,14 @@ async fn it_handles_chrono_datetime() -> anyhow::Result<()> { assert_eq!(result_str, "2023-12-25"); // Test time encoding - let stmt = (&mut conn).prepare("SELECT ? AS time_data").await?; + let stmt = conn.prepare("SELECT ? AS time_data").await?; let row = stmt.query().bind(test_time).fetch_one(&mut conn).await?; let result_str = row.try_get_raw(0)?.to_owned().decode::(); assert_eq!(result_str, "14:30:00"); // Test datetime encoding - let stmt = (&mut conn).prepare("SELECT ? AS datetime_data").await?; + let stmt = conn.prepare("SELECT ? AS datetime_data").await?; let row = stmt .query() .bind(test_datetime) @@ -764,7 +764,7 @@ async fn it_handles_prepare_statement_errors() -> anyhow::Result<()> { // So we test that execution fails even if preparation succeeds // Test executing prepared invalid SQL - if let Ok(stmt) = (&mut conn).prepare("INVALID PREPARE STATEMENT").await { + if let Ok(stmt) = conn.prepare("INVALID PREPARE STATEMENT").await { let result = stmt.query().fetch_one(&mut conn).await; let err = result.expect_err("should be an error"); assert!( @@ -775,7 +775,7 @@ async fn it_handles_prepare_statement_errors() -> anyhow::Result<()> { } // Test executing prepared SQL with syntax errors - match (&mut conn) + match conn .prepare("SELECT idonotexist FROM idonotexist WHERE idonotexist") .await { @@ -811,9 +811,7 @@ async fn it_handles_parameter_binding_errors() -> anyhow::Result<()> { let mut conn = new::().await?; // Test with completely missing parameters - this should more reliably fail - let stmt = (&mut conn) - .prepare("SELECT ? AS param1, ? AS param2") - .await?; + let stmt = conn.prepare("SELECT ? AS param1, ? AS param2").await?; // Test with no parameters when some are expected let result = stmt.query().fetch_one(&mut conn).await; @@ -824,7 +822,7 @@ async fn it_handles_parameter_binding_errors() -> anyhow::Result<()> { // Test that we can handle parameter binding gracefully // Even if the driver is permissive, the system should be robust - let stmt2 = (&mut conn).prepare("SELECT ? AS single_param").await?; + let stmt2 = conn.prepare("SELECT ? AS single_param").await?; // Bind correct number of parameters - this should work let result = stmt2.query().bind(42i32).fetch_one(&mut conn).await; @@ -842,7 +840,7 @@ async fn it_handles_parameter_execution_errors() -> anyhow::Result<()> { let mut conn = new::().await?; // Test parameter binding with incompatible operations that should fail at execution - let stmt = (&mut conn).prepare("SELECT ? / 0 AS div_by_zero").await?; + let stmt = conn.prepare("SELECT ? / 0 AS div_by_zero").await?; // This should execute but may produce a runtime error (division by zero) let result = stmt.query().bind(42i32).fetch_one(&mut conn).await; @@ -850,7 +848,7 @@ async fn it_handles_parameter_execution_errors() -> anyhow::Result<()> { let _ = result; // Test with a parameter in an invalid context that should fail - if let Ok(stmt) = (&mut conn).prepare("SELECT * FROM ?").await { + if let Ok(stmt) = conn.prepare("SELECT * FROM ?").await { // Using parameter as table name should fail at execution let result = stmt .query() @@ -1019,7 +1017,7 @@ async fn it_handles_prepared_statement_with_wrong_parameters() -> anyhow::Result let mut conn = new::().await?; // Prepare a statement expecting specific parameter types - let stmt = (&mut conn).prepare("SELECT ? + ? AS sum").await?; + let stmt = conn.prepare("SELECT ? + ? AS sum").await?; // Test binding incompatible types (if the database is strict about types) // Some databases/drivers are permissive, others are strict