Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions sqlx-core/src/odbc/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use crate::odbc::{Odbc, OdbcConnection, OdbcQueryResult, OdbcRow, OdbcStatement,
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::TryStreamExt;
use std::borrow::Cow;
use futures_util::{future, FutureExt, StreamExt};

// run method removed; fetch_many implements streaming directly

Expand All @@ -21,15 +20,8 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection {
'c: 'e,
E: Execute<'q, Self::Database> + 'q,
{
let sql = query.sql().to_string();
let args = query.take_arguments();
Box::pin(try_stream! {
let rx = self.worker.execute_stream(&sql, args).await?;
while let Ok(item) = rx.recv_async().await {
r#yield!(item?);
}
Ok(())
})
Box::pin(self.execute_stream(query.sql(), args).into_stream())
}

fn fetch_optional<'e, 'q: 'e, E>(
Expand All @@ -40,15 +32,12 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection {
'c: 'e,
E: Execute<'q, Self::Database> + 'q,
{
let mut s = self.fetch_many(query);
Box::pin(async move {
while let Some(v) = s.try_next().await? {
if let Either::Right(r) = v {
return Ok(Some(r));
}
}
Ok(None)
})
Box::pin(self.fetch_many(query).into_future().then(|(v, _)| match v {
Some(Ok(Either::Right(r))) => future::ok(Some(r)),
Some(Ok(Either::Left(_))) => future::ok(None),
Some(Err(e)) => future::err(e),
None => future::ok(None),
}))
}

fn prepare_with<'e, 'q: 'e>(
Expand All @@ -59,14 +48,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection {
where
'c: 'e,
{
Box::pin(async move {
let (_, columns, parameters) = self.worker.prepare(sql).await?;
Ok(OdbcStatement {
sql: Cow::Borrowed(sql),
columns,
parameters,
})
})
Box::pin(async move { self.prepare(sql).await })
}

#[doc(hidden)]
Expand Down
207 changes: 186 additions & 21 deletions sqlx-core/src/odbc/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,220 @@
use crate::connection::{Connection, LogSettings};
use crate::connection::Connection;
use crate::error::Error;
use crate::odbc::{Odbc, OdbcConnectOptions};
use crate::odbc::{
Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo,
};
use crate::transaction::Transaction;
use either::Either;
use sqlx_rt::spawn_blocking;
mod odbc_bridge;
use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
use futures_core::future::BoxFuture;
use futures_util::future;
use odbc_api::ConnectionTransitions;
use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection};
use odbc_bridge::{establish_connection, execute_sql};
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

mod executor;
mod worker;

pub(crate) use worker::ConnectionWorker;
type PreparedStatement = Prepared<StatementConnection<SharedConnection<'static>>>;
type SharedPreparedStatement = Arc<Mutex<PreparedStatement>>;

fn collect_columns(prepared: &mut PreparedStatement) -> Vec<OdbcColumn> {
let count = prepared.num_result_cols().unwrap_or(0);
(1..=count)
.map(|i| create_column(prepared, i as u16))
.collect()
}

fn create_column(stmt: &mut PreparedStatement, index: u16) -> OdbcColumn {
let mut cd = odbc_api::ColumnDescription::default();
let _ = stmt.describe_col(index, &mut cd);

OdbcColumn {
name: decode_column_name(cd.name, index),
type_info: OdbcTypeInfo::new(cd.data_type),
ordinal: usize::from(index.checked_sub(1).unwrap()),
}
}

fn decode_column_name(name_bytes: Vec<u8>, index: u16) -> String {
String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1))
}

/// A connection to an ODBC-accessible database.
///
/// ODBC uses a blocking C API, so we run all calls on a dedicated background thread
/// and communicate over channels to provide async access.
#[derive(Debug)]
/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking
/// thread-pool via `spawn_blocking` and synchronize access with a mutex.
pub struct OdbcConnection {
pub(crate) worker: ConnectionWorker,
pub(crate) log_settings: LogSettings,
pub(crate) conn: SharedConnection<'static>,
pub(crate) stmt_cache: HashMap<Arc<str>, SharedPreparedStatement>,
}

impl std::fmt::Debug for OdbcConnection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OdbcConnection")
.field("conn", &self.conn)
.finish()
}
}

impl OdbcConnection {
pub(crate) async fn with_conn<R, F, S>(&mut self, operation: S, f: F) -> Result<R, Error>
where
R: Send + 'static,
F: FnOnce(&mut odbc_api::Connection<'static>) -> Result<R, Error> + Send + 'static,
S: std::fmt::Display + Send + 'static,
{
let conn = Arc::clone(&self.conn);
spawn_blocking(move || {
let mut conn_guard = conn.lock().map_err(|_| {
Error::Protocol(format!("ODBC {}: failed to lock connection", operation))
})?;
f(&mut conn_guard)
})
.await
}

pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result<Self, Error> {
let worker = ConnectionWorker::establish(options.clone()).await?;
let shared_conn = spawn_blocking({
let options = options.clone();
move || {
let conn = establish_connection(&options)?;
let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn));
Ok::<_, Error>(shared_conn)
}
})
.await?;

Ok(Self {
worker,
log_settings: LogSettings::default(),
conn: shared_conn,
stmt_cache: HashMap::new(),
})
}

/// Returns the name of the actual Database Management System (DBMS) this
/// connection is talking to as reported by the ODBC driver.
///
/// This calls the underlying ODBC API `SQL_DBMS_NAME` via
/// `odbc_api::Connection::database_management_system_name`.
///
/// See: https://docs.rs/odbc-api/19.0.1/odbc_api/struct.Connection.html#method.database_management_system_name
pub async fn dbms_name(&mut self) -> Result<String, Error> {
self.worker.get_dbms_name().await
self.with_conn("dbms_name", move |conn| {
Ok(conn.database_management_system_name()?)
})
.await
}

pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> {
self.with_conn("ping", move |conn| {
conn.execute("SELECT 1", (), None)?;
Ok(())
})
.await
}

pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> {
self.with_conn("begin", move |conn| {
conn.set_autocommit(false)?;
Ok(())
})
.await
}

pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> {
self.with_conn("commit", move |conn| {
conn.commit()?;
conn.set_autocommit(true)?;
Ok(())
})
.await
}

pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> {
self.with_conn("rollback", move |conn| {
conn.rollback()?;
conn.set_autocommit(true)?;
Ok(())
})
.await
}

/// Launches a background task to execute the SQL statement and send the results to the returned channel.
pub(crate) fn execute_stream(
&mut self,
sql: &str,
args: Option<OdbcArguments>,
) -> flume::Receiver<Result<Either<OdbcQueryResult, OdbcRow>, Error>> {
let (tx, rx) = flume::bounded(64);

let maybe_prepared = if let Some(prepared) = self.stmt_cache.get(sql) {
MaybePrepared::Prepared(Arc::clone(prepared))
} else {
MaybePrepared::NotPrepared(sql.to_string())
};

let conn = Arc::clone(&self.conn);
sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || {
let mut conn = conn.lock().expect("failed to lock connection");
if let Err(e) = execute_sql(&mut conn, maybe_prepared, args, &tx) {
let _ = tx.send(Err(e));
}
}));

rx
}

pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
// Clear the statement metadata cache
self.stmt_cache.clear();
Ok(())
}

pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result<OdbcStatement<'a>, Error> {
let conn = Arc::clone(&self.conn);
let sql_arc = Arc::from(sql.to_string());
let sql_clone = Arc::clone(&sql_arc);
let (prepared, metadata) = spawn_blocking(move || {
let mut prepared = conn.into_prepared(&sql_clone)?;
let metadata = OdbcStatementMetadata {
columns: collect_columns(&mut prepared),
parameters: usize::from(prepared.num_params().unwrap_or(0)),
};
Ok::<_, Error>((prepared, metadata))
})
.await?;
self.stmt_cache
.insert(Arc::clone(&sql_arc), Arc::new(Mutex::new(prepared)));
Ok(OdbcStatement {
sql: Cow::Borrowed(sql),
metadata,
})
}
}

pub(crate) enum MaybePrepared {
Prepared(SharedPreparedStatement),
NotPrepared(String),
}

impl Connection for OdbcConnection {
type Database = Odbc;

type Options = OdbcConnectOptions;

fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move { self.worker.shutdown().await })
fn close(self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move {
// Drop connection by moving Arc and letting it fall out of scope.
drop(self);
Ok(())
})
}

fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move { Ok(()) })
}

fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(self.worker.ping())
Box::pin(self.ping_blocking())
}

fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
Expand All @@ -74,4 +233,10 @@ impl Connection for OdbcConnection {
fn should_flush(&self) -> bool {
false
}

fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(self.clear_cached_statements())
}
}

// moved helpers to connection/inner.rs
Loading
Loading