diff --git a/contrib/sync_db_pools/lib/Cargo.toml b/contrib/sync_db_pools/lib/Cargo.toml index 73c4d3d1df..20c073a5c8 100644 --- a/contrib/sync_db_pools/lib/Cargo.toml +++ b/contrib/sync_db_pools/lib/Cargo.toml @@ -15,6 +15,7 @@ diesel_postgres_pool = ["diesel/postgres", "diesel/r2d2"] diesel_mysql_pool = ["diesel/mysql", "diesel/r2d2"] sqlite_pool = ["rusqlite", "r2d2_sqlite"] postgres_pool = ["postgres", "r2d2_postgres"] +postgres_pool_tls = ["postgres_pool", "postgres-native-tls", "native-tls"] memcache_pool = ["memcache", "r2d2-memcache"] [dependencies] @@ -26,6 +27,8 @@ diesel = { version = "1.0", default-features = false, optional = true } postgres = { version = "0.19", optional = true } r2d2_postgres = { version = "0.18", optional = true } +postgres-native-tls = { version = "0.5", optional = true } +native-tls = { version = "0.2", features = ["vendored"], optional = true } rusqlite = { version = "0.25", optional = true } r2d2_sqlite = { version = "0.18", optional = true } diff --git a/contrib/sync_db_pools/lib/src/config.rs b/contrib/sync_db_pools/lib/src/config.rs index a0938e27e6..6099abe14a 100644 --- a/contrib/sync_db_pools/lib/src/config.rs +++ b/contrib/sync_db_pools/lib/src/config.rs @@ -1,7 +1,26 @@ -use rocket::{Rocket, Build}; -use rocket::figment::{self, Figment, providers::Serialized}; +use rocket::figment::{self, providers::Serialized, Figment}; +use rocket::{Build, Rocket}; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "postgres_pool_tls")] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct TlsConfig { + pub accept_invalid_certs: Option, + pub accept_invalid_hostnames: Option, + pub cert_path: Option, +} + +#[cfg(feature = "postgres_pool_tls")] +impl Default for TlsConfig { + fn default() -> Self { + Self { + accept_invalid_certs: None, + accept_invalid_hostnames: None, + cert_path: None, + } + } +} /// A base `Config` for any `Poolable` type. /// @@ -39,6 +58,9 @@ pub struct Config { /// Defaults to `5`. // FIXME: Use `time`. pub timeout: u8, + /// Postgres TLS Config + #[cfg(feature = "postgres_pool_tls")] + pub tls: Option, } impl Config { @@ -102,7 +124,8 @@ impl Config { /// ``` pub fn figment(db_name: &str, rocket: &Rocket) -> Figment { let db_key = format!("databases.{}", db_name); - let default_pool_size = rocket.figment() + let default_pool_size = rocket + .figment() .extract_inner::(rocket::Config::WORKERS) .map(|workers| workers * 4) .ok(); @@ -113,7 +136,7 @@ impl Config { match default_pool_size { Some(pool_size) => figment.join(Serialized::default("pool_size", pool_size)), - None => figment + None => figment, } } } diff --git a/contrib/sync_db_pools/lib/src/connection.rs b/contrib/sync_db_pools/lib/src/connection.rs index 47e1b23e5f..b24ffbef5c 100644 --- a/contrib/sync_db_pools/lib/src/connection.rs +++ b/contrib/sync_db_pools/lib/src/connection.rs @@ -1,16 +1,16 @@ use std::marker::PhantomData; use std::sync::Arc; -use rocket::{Phase, Rocket, Ignite, Sentinel}; use rocket::fairing::{AdHoc, Fairing}; -use rocket::request::{Request, Outcome, FromRequest}; -use rocket::outcome::IntoOutcome; use rocket::http::Status; +use rocket::outcome::IntoOutcome; +use rocket::request::{FromRequest, Outcome, Request}; +use rocket::{Ignite, Phase, Rocket, Sentinel}; -use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex}; +use rocket::tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore}; use rocket::tokio::time::timeout; -use crate::{Config, Poolable, Error}; +use crate::{Config, Error, Poolable}; /// Unstable internal details of generated code for the #[database] attribute. /// @@ -31,7 +31,7 @@ impl Clone for ConnectionPool { config: self.config.clone(), pool: self.pool.clone(), semaphore: self.semaphore.clone(), - _marker: PhantomData + _marker: PhantomData, } } } @@ -49,23 +49,28 @@ pub struct Connection { // A wrapper around spawn_blocking that propagates panics to the calling code. async fn run_blocking(job: F) -> R - where F: FnOnce() -> R + Send + 'static, R: Send + 'static, +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, { match tokio::task::spawn_blocking(job).await { Ok(ret) => ret, Err(e) => match e.try_into_panic() { Ok(panic) => std::panic::resume_unwind(panic), Err(_) => unreachable!("spawn_blocking tasks are never cancelled"), - } + }, } } macro_rules! dberr { - ($msg:literal, $db_name:expr, $efmt:literal, $error:expr, $rocket:expr) => ({ - rocket::error!(concat!("database ", $msg, " error for pool named `{}`"), $db_name); + ($msg:literal, $db_name:expr, $efmt:literal, $error:expr, $rocket:expr) => {{ + rocket::error!( + concat!("database ", $msg, " error for pool named `{}`"), + $db_name + ); error_!($efmt, $error); return Err($rocket); - }); + }}; } impl ConnectionPool { @@ -88,8 +93,10 @@ impl ConnectionPool { Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket), Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket), Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket), + Err(Error::Io(e)) => dberr!("io", db, "{:?}", e, rocket), } - }).await + }) + .await }) } @@ -103,7 +110,10 @@ impl ConnectionPool { } }; - let pool = self.pool.as_ref().cloned() + let pool = self + .pool + .as_ref() + .cloned() .expect("internal invariant broken: self.pool is Some"); match run_blocking(move || pool.get_timeout(duration)).await { @@ -125,12 +135,18 @@ impl ConnectionPool { Some(pool) => match pool.get().await.ok() { Some(conn) => Some(conn), None => { - error_!("no connections available for `{}`", std::any::type_name::()); + error_!( + "no connections available for `{}`", + std::any::type_name::() + ); None } }, None => { - error_!("missing database fairing for `{}`", std::any::type_name::()); + error_!( + "missing database fairing for `{}`", + std::any::type_name::() + ); None } } @@ -145,8 +161,9 @@ impl ConnectionPool { impl Connection { #[inline] pub async fn run(&self, f: F) -> R - where F: FnOnce(&mut C) -> R + Send + 'static, - R: Send + 'static, + where + F: FnOnce(&mut C) -> R + Send + 'static, + R: Send + 'static, { // It is important that this inner Arc> (or the OwnedMutexGuard // derived from it) never be a variable on the stack at an await point, @@ -160,14 +177,15 @@ impl Connection { run_blocking(move || { // And then re-enter the runtime to wait on the async mutex, but in // a blocking fashion. - let mut connection = tokio::runtime::Handle::current().block_on(async { - connection.lock_owned().await - }); + let mut connection = + tokio::runtime::Handle::current().block_on(async { connection.lock_owned().await }); - let conn = connection.as_mut() + let conn = connection + .as_mut() .expect("internal invariant broken: self.connection is Some"); f(conn) - }).await + }) + .await } } @@ -178,9 +196,8 @@ impl Drop for Connection { // See same motivation above for this arrangement of spawn_blocking/block_on tokio::task::spawn_blocking(move || { - let mut connection = tokio::runtime::Handle::current().block_on(async { - connection.lock_owned().await - }); + let mut connection = + tokio::runtime::Handle::current().block_on(async { connection.lock_owned().await }); if let Some(conn) = connection.take() { drop(conn); @@ -209,7 +226,10 @@ impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection { match request.rocket().state::>() { Some(c) => c.get().await.into_outcome(Status::ServiceUnavailable), None => { - error_!("Missing database fairing for `{}`", std::any::type_name::()); + error_!( + "Missing database fairing for `{}`", + std::any::type_name::() + ); Outcome::Failure((Status::InternalServerError, ())) } } @@ -223,7 +243,10 @@ impl Sentinel for Connection { if rocket.state::>().is_none() { let conn = Paint::default(std::any::type_name::()).bold(); let fairing = Paint::default(format!("{}::fairing()", conn)).wrap().bold(); - error!("requesting `{}` DB connection without attaching `{}`.", conn, fairing); + error!( + "requesting `{}` DB connection without attaching `{}`.", + conn, fairing + ); info_!("Attach `{}` to use database connection pooling.", fairing); return true; } diff --git a/contrib/sync_db_pools/lib/src/error.rs b/contrib/sync_db_pools/lib/src/error.rs index fbf179e2a0..fb6eca2b35 100644 --- a/contrib/sync_db_pools/lib/src/error.rs +++ b/contrib/sync_db_pools/lib/src/error.rs @@ -14,6 +14,8 @@ pub enum Error { Pool(r2d2::Error), /// An error occurred while extracting a `figment` configuration. Config(figment::Error), + /// An IO error + Io(std::io::Error), } impl From for Error { @@ -27,3 +29,9 @@ impl From for Error { Error::Pool(error) } } + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Error::Io(error) + } +} diff --git a/contrib/sync_db_pools/lib/src/poolable.rs b/contrib/sync_db_pools/lib/src/poolable.rs index 9ba895f0c3..4d4a820346 100644 --- a/contrib/sync_db_pools/lib/src/poolable.rs +++ b/contrib/sync_db_pools/lib/src/poolable.rs @@ -1,7 +1,7 @@ use std::time::Duration; use r2d2::ManageConnection; -use rocket::{Rocket, Build}; +use rocket::{Build, Rocket}; #[allow(unused_imports)] use crate::{Config, Error}; @@ -98,7 +98,7 @@ use crate::{Config, Error}; /// [`Poolable`]. pub trait Poolable: Send + Sized + 'static { /// The associated connection manager for the given connection type. - type Manager: ManageConnection; + type Manager: ManageConnection; /// The associated error type in the event that constructing the connection /// manager and/or the connection pool fails. @@ -119,19 +119,22 @@ impl Poolable for diesel::SqliteConnection { type Error = std::convert::Infallible; fn pool(db_name: &str, rocket: &Rocket) -> PoolResult { - use diesel::{SqliteConnection, connection::SimpleConnection}; - use diesel::r2d2::{CustomizeConnection, ConnectionManager, Error, Pool}; + use diesel::r2d2::{ConnectionManager, CustomizeConnection, Error, Pool}; + use diesel::{connection::SimpleConnection, SqliteConnection}; #[derive(Debug)] struct Customizer; impl CustomizeConnection for Customizer { fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> { - conn.batch_execute("\ + conn.batch_execute( + "\ PRAGMA journal_mode = WAL;\ PRAGMA busy_timeout = 1000;\ PRAGMA foreign_keys = ON;\ - ").map_err(Error::QueryError)?; + ", + ) + .map_err(Error::QueryError)?; Ok(()) } @@ -183,8 +186,7 @@ impl Poolable for diesel::MysqlConnection { } } -// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`. -#[cfg(feature = "postgres_pool")] +#[cfg(all(feature = "postgres_pool", not(feature = "postgres_pool_tls")))] impl Poolable for postgres::Client { type Manager = r2d2_postgres::PostgresConnectionManager; type Error = postgres::Error; @@ -202,6 +204,53 @@ impl Poolable for postgres::Client { } } +#[cfg(feature = "postgres_pool_tls")] +impl Poolable for postgres::Client { + type Manager = r2d2_postgres::PostgresConnectionManager; + type Error = postgres::Error; + + fn pool(db_name: &str, rocket: &Rocket) -> PoolResult { + let config = Config::from(db_name, rocket)?; + let tls_config = config.tls.unwrap_or_default(); + let mut tls_connector = native_tls::TlsConnector::builder(); + + if let Some(accept_invalid_certs) = tls_config.accept_invalid_certs { + tls_connector.danger_accept_invalid_certs(accept_invalid_certs); + if accept_invalid_certs { + rocket::warn!("Accepting invalid certificates"); + } + } + + if let Some(accept_invalid_hostnames) = tls_config.accept_invalid_hostnames { + tls_connector.danger_accept_invalid_hostnames(accept_invalid_hostnames); + if accept_invalid_hostnames { + rocket::warn!("Accepting invalid certificate hostnames"); + } + } + + if let Some(cert_path) = tls_config.cert_path { + let cert_bytes = std::fs::read(cert_path).map_err(Error::Io)?; + let cert = native_tls::Certificate::from_pem(&cert_bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + tls_connector.add_root_certificate(cert); + } + + let tls_connector = tls_connector + .build() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + let db_connector = postgres_native_tls::MakeTlsConnector::new(tls_connector); + + let url = config.url.parse().map_err(Error::Custom)?; + let manager = r2d2_postgres::PostgresConnectionManager::new(url, db_connector); + let pool = r2d2::Pool::builder() + .max_size(config.pool_size) + .connection_timeout(Duration::from_secs(config.timeout as u64)) + .build(manager)?; + + Ok(pool) + } +} + #[cfg(feature = "sqlite_pool")] impl Poolable for rusqlite::Connection { type Manager = r2d2_sqlite::SqliteConnectionManager; @@ -247,10 +296,9 @@ impl Poolable for rusqlite::Connection { }; flags.insert(sql_flag) - }; + } - let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url) - .with_flags(flags); + let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url).with_flags(flags); let pool = r2d2::Pool::builder() .max_size(config.pool_size)