Skip to content

Commit

Permalink
Implement support for diesel::Instrumentation for all provided connec…
Browse files Browse the repository at this point in the history
…tion types

This commit implements the necessary methods to support the diesel
Instrumentation interface for logging and other connection
instrumentation functionality. It also adds tests for this new functionality.
  • Loading branch information
weiznich committed Jul 5, 2024
1 parent 74867bd commit fbf0336
Show file tree
Hide file tree
Showing 13 changed files with 581 additions and 126 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ jobs:
find ~/.cargo/registry -iname "*clippy.toml" -delete
- name: Run clippy
run: cargo +stable clippy --all
run: cargo +stable clippy --all --all-features

- name: Check formating
run: cargo +stable fmt --all -- --check
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
## [Unreleased]

* Added type `diesel_async::pooled_connection::mobc::PooledConnection`
* MySQL/MariaDB now use `CLIENT_FOUND_ROWS` capability to allow consistent behavior with PostgreSQL regarding return value of UPDATe commands.
* MySQL/MariaDB now use `CLIENT_FOUND_ROWS` capability to allow consistent behaviour with PostgreSQL regarding return value of UPDATe commands.
* The minimal supported rust version is now 1.78.0
* Add a `SyncConnectionWrapper` type that turns a sync connection into an async one. This enables SQLite support for diesel-async
* Add support for `diesel::connection::Instrumentation` to support logging and other instrumentation for any of the provided connection impls.

## [0.4.1] - 2023-09-01

Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cfg-if = "1"
chrono = "0.4"
diesel = { version = "2.2.0", default-features = false, features = ["chrono"] }
diesel_migrations = "2.2.0"
assert_matches = "1.0.1"

[features]
default = []
Expand Down
16 changes: 5 additions & 11 deletions src/async_connection_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ mod implementation {
pub struct AsyncConnectionWrapper<C, B> {
inner: C,
runtime: B,
instrumentation: Option<Box<dyn Instrumentation>>,
}

impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
Expand All @@ -119,7 +118,6 @@ mod implementation {
Self {
inner,
runtime: B::get_runtime(),
instrumentation: None,
}
}
}
Expand Down Expand Up @@ -150,11 +148,7 @@ mod implementation {
let runtime = B::get_runtime();
let f = C::establish(database_url);
let inner = runtime.block_on(f)?;
Ok(Self {
inner,
runtime,
instrumentation: None,
})
Ok(Self { inner, runtime })
}

fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
Expand All @@ -165,18 +159,18 @@ mod implementation {
self.runtime.block_on(f)
}

fn transaction_state(
&mut self,
fn transaction_state(
&mut self,
) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
self.inner.transaction_state()
}

fn instrumentation(&mut self) -> &mut dyn Instrumentation {
&mut self.instrumentation
self.inner.instrumentation()
}

fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
self.instrumentation = Some(Box::new(instrumentation));
self.inner.set_instrumentation(instrumentation);
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#![warn(missing_docs)]

use diesel::backend::Backend;
use diesel::connection::Instrumentation;
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
use diesel::result::Error;
use diesel::row::Row;
Expand Down Expand Up @@ -347,4 +348,10 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
fn _silence_lint_on_execute_future(_: Self::ExecuteFuture<'_, '_>) {}
#[doc(hidden)]
fn _silence_lint_on_load_future(_: Self::LoadFuture<'_, '_>) {}

#[doc(hidden)]
fn instrumentation(&mut self) -> &mut dyn Instrumentation;

/// Set a specific [`Instrumentation`] implementation for this connection
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation);
}
148 changes: 111 additions & 37 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::stmt_cache::{PrepareCallback, StmtCache};
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::StrQueryHelper;
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
use diesel::query_builder::QueryBuilder;
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
Expand All @@ -26,12 +29,32 @@ pub struct AsyncMysqlConnection {
conn: mysql_async::Conn,
stmt_cache: StmtCache<Mysql, Statement>,
transaction_manager: AnsiTransactionManager,
instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
}

#[async_trait::async_trait]
impl SimpleAsyncConnection for AsyncMysqlConnection {
async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
Ok(self.conn.query_drop(query).await.map_err(ErrorHelper)?)
self.instrumentation
.lock()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
query,
)));
let result = self
.conn
.query_drop(query)
.await
.map_err(ErrorHelper)
.map_err(Into::into);
self.instrumentation
.lock()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(query),
result.as_ref().err(),
));
result
}
}

Expand All @@ -53,20 +76,18 @@ impl AsyncConnection for AsyncMysqlConnection {
type TransactionManager = AnsiTransactionManager;

async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
let opts = Opts::from_url(database_url)
.map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
let builder = OptsBuilder::from_opts(opts)
.init(CONNECTION_SETUP_QUERIES.to_vec())
.stmt_cache_size(0) // We have our own cache
.client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)

let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;

Ok(AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
transaction_manager: AnsiTransactionManager::default(),
})
let mut instrumentation = diesel::connection::get_default_instrumentation();
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
database_url,
));
let r = Self::establish_connection_inner(database_url).await;
instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
database_url,
r.as_ref().err(),
));
let mut conn = r?;
conn.instrumentation = std::sync::Mutex::new(instrumentation);
Ok(conn)
}

fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
Expand All @@ -80,7 +101,10 @@ impl AsyncConnection for AsyncMysqlConnection {
let stmt_for_exec = match stmt {
MaybeCached::Cached(ref s) => (*s).clone(),
MaybeCached::CannotCache(ref s) => s.clone(),
_ => todo!(),
_ => unreachable!(
"Diesel has only two variants here at the time of writing.\n\
If you ever see this error message please open in issue in the diesel-async issue tracker"
),
};

let (tx, rx) = futures_channel::mpsc::channel(0);
Expand Down Expand Up @@ -152,6 +176,19 @@ impl AsyncConnection for AsyncMysqlConnection {
fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
&mut self.transaction_manager
}

fn instrumentation(&mut self) -> &mut dyn Instrumentation {
self.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
}

fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
*self
.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
}
}

#[inline(always)]
Expand Down Expand Up @@ -195,6 +232,7 @@ impl AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(None),
};

for stmt in CONNECTION_SETUP_QUERIES {
Expand All @@ -219,6 +257,12 @@ impl AsyncMysqlConnection {
T: QueryFragment<Mysql> + QueryId,
F: Future<Output = QueryResult<R>> + Send,
{
self.instrumentation
.lock()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
&query,
)));
let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
let bind_collector = query
.collect_binds(&mut bind_collector, &mut (), &Mysql)
Expand All @@ -228,6 +272,7 @@ impl AsyncMysqlConnection {
ref mut conn,
ref mut stmt_cache,
ref mut transaction_manager,
ref instrumentation,
..
} = self;

Expand All @@ -242,28 +287,37 @@ impl AsyncMysqlConnection {
} = bind_collector?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
let sql = sql?;
let cache_key = if let Some(query_id) = query_id {
StatementCacheKey::Type(query_id)
} else {
StatementCacheKey::Sql {
sql: sql.clone(),
bind_types: metadata.clone(),
}
let inner = async {
let cache_key = if let Some(query_id) = query_id {
StatementCacheKey::Type(query_id)
} else {
StatementCacheKey::Sql {
sql: sql.clone(),
bind_types: metadata.clone(),
}
};

let (stmt, conn) = stmt_cache
.cached_prepared_statement(
cache_key,
sql.clone(),
is_safe_to_cache_prepared,
&metadata,
conn,
instrumentation,
)
.await?;
callback(conn, stmt, ToSqlHelper { metadata, binds }).await
};

let (stmt, conn) = stmt_cache
.cached_prepared_statement(
cache_key,
sql,
is_safe_to_cache_prepared,
&metadata,
conn,
)
.await?;
update_transaction_manager_status(
callback(conn, stmt, ToSqlHelper { metadata, binds }).await,
transaction_manager,
)
let r = update_transaction_manager_status(inner.await, transaction_manager);
instrumentation
.lock()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(&sql),
r.as_ref().err(),
));
r
}
.boxed()
}
Expand Down Expand Up @@ -300,6 +354,26 @@ impl AsyncMysqlConnection {

Ok(())
}

async fn establish_connection_inner(
database_url: &str,
) -> Result<AsyncMysqlConnection, ConnectionError> {
let opts = Opts::from_url(database_url)
.map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
let builder = OptsBuilder::from_opts(opts)
.init(CONNECTION_SETUP_QUERIES.to_vec())
.stmt_cache_size(0) // We have our own cache
.client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)

let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;

Ok(AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(None),
})
}
}

#[cfg(any(
Expand Down
Loading

0 comments on commit fbf0336

Please sign in to comment.