Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres TLS support #2018

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions contrib/sync_db_pools/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ diesel_postgres_pool = ["diesel/postgres", "diesel/r2d2"]
diesel_mysql_pool = ["diesel/mysql", "diesel/r2d2"]
sqlite_pool = ["rusqlite", "r2d2_sqlite"]
postgres_pool = ["postgres", "r2d2_postgres"]
postgres_pool_tls = ["postgres_pool", "postgres-native-tls", "native-tls"]
memcache_pool = ["memcache", "r2d2-memcache"]

[dependencies]
Expand All @@ -26,6 +27,8 @@ diesel = { version = "1.0", default-features = false, optional = true }

postgres = { version = "0.19", optional = true }
r2d2_postgres = { version = "0.18", optional = true }
postgres-native-tls = { version = "0.5", optional = true }
native-tls = { version = "0.2", features = ["vendored"], optional = true }

rusqlite = { version = "0.25", optional = true }
r2d2_sqlite = { version = "0.18", optional = true }
Expand Down
33 changes: 28 additions & 5 deletions contrib/sync_db_pools/lib/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
use rocket::{Rocket, Build};
use rocket::figment::{self, Figment, providers::Serialized};
use rocket::figment::{self, providers::Serialized, Figment};
use rocket::{Build, Rocket};

use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};

#[cfg(feature = "postgres_pool_tls")]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TlsConfig {
pub accept_invalid_certs: Option<bool>,
pub accept_invalid_hostnames: Option<bool>,
pub cert_path: Option<std::path::PathBuf>,
}

#[cfg(feature = "postgres_pool_tls")]
impl Default for TlsConfig {
fn default() -> Self {
Self {
accept_invalid_certs: None,
accept_invalid_hostnames: None,
cert_path: None,
}
}
}

/// A base `Config` for any `Poolable` type.
///
Expand Down Expand Up @@ -39,6 +58,9 @@ pub struct Config {
/// Defaults to `5`.
// FIXME: Use `time`.
pub timeout: u8,
/// Postgres TLS Config
#[cfg(feature = "postgres_pool_tls")]
pub tls: Option<TlsConfig>,
}

impl Config {
Expand Down Expand Up @@ -102,7 +124,8 @@ impl Config {
/// ```
pub fn figment(db_name: &str, rocket: &Rocket<Build>) -> Figment {
let db_key = format!("databases.{}", db_name);
let default_pool_size = rocket.figment()
let default_pool_size = rocket
.figment()
.extract_inner::<u32>(rocket::Config::WORKERS)
.map(|workers| workers * 4)
.ok();
Expand All @@ -113,7 +136,7 @@ impl Config {

match default_pool_size {
Some(pool_size) => figment.join(Serialized::default("pool_size", pool_size)),
None => figment
None => figment,
}
}
}
77 changes: 50 additions & 27 deletions contrib/sync_db_pools/lib/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::marker::PhantomData;
use std::sync::Arc;

use rocket::{Phase, Rocket, Ignite, Sentinel};
use rocket::fairing::{AdHoc, Fairing};
use rocket::request::{Request, Outcome, FromRequest};
use rocket::outcome::IntoOutcome;
use rocket::http::Status;
use rocket::outcome::IntoOutcome;
use rocket::request::{FromRequest, Outcome, Request};
use rocket::{Ignite, Phase, Rocket, Sentinel};

use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex};
use rocket::tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
use rocket::tokio::time::timeout;

use crate::{Config, Poolable, Error};
use crate::{Config, Error, Poolable};

/// Unstable internal details of generated code for the #[database] attribute.
///
Expand All @@ -31,7 +31,7 @@ impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
config: self.config.clone(),
pool: self.pool.clone(),
semaphore: self.semaphore.clone(),
_marker: PhantomData
_marker: PhantomData,
}
}
}
Expand All @@ -49,23 +49,28 @@ pub struct Connection<K, C: Poolable> {

// A wrapper around spawn_blocking that propagates panics to the calling code.
async fn run_blocking<F, R>(job: F) -> R
where F: FnOnce() -> R + Send + 'static, R: Send + 'static,
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
match tokio::task::spawn_blocking(job).await {
Ok(ret) => ret,
Err(e) => match e.try_into_panic() {
Ok(panic) => std::panic::resume_unwind(panic),
Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
}
},
}
}

macro_rules! dberr {
($msg:literal, $db_name:expr, $efmt:literal, $error:expr, $rocket:expr) => ({
rocket::error!(concat!("database ", $msg, " error for pool named `{}`"), $db_name);
($msg:literal, $db_name:expr, $efmt:literal, $error:expr, $rocket:expr) => {{
rocket::error!(
concat!("database ", $msg, " error for pool named `{}`"),
$db_name
);
error_!($efmt, $error);
return Err($rocket);
});
}};
}

impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
Expand All @@ -88,8 +93,10 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket),
Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket),
Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket),
Err(Error::Io(e)) => dberr!("io", db, "{:?}", e, rocket),
}
}).await
})
.await
})
}

Expand All @@ -103,7 +110,10 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
}
};

let pool = self.pool.as_ref().cloned()
let pool = self
.pool
.as_ref()
.cloned()
.expect("internal invariant broken: self.pool is Some");

match run_blocking(move || pool.get_timeout(duration)).await {
Expand All @@ -125,12 +135,18 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
Some(pool) => match pool.get().await.ok() {
Some(conn) => Some(conn),
None => {
error_!("no connections available for `{}`", std::any::type_name::<K>());
error_!(
"no connections available for `{}`",
std::any::type_name::<K>()
);
None
}
},
None => {
error_!("missing database fairing for `{}`", std::any::type_name::<K>());
error_!(
"missing database fairing for `{}`",
std::any::type_name::<K>()
);
None
}
}
Expand All @@ -145,8 +161,9 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
impl<K: 'static, C: Poolable> Connection<K, C> {
#[inline]
pub async fn run<F, R>(&self, f: F) -> R
where F: FnOnce(&mut C) -> R + Send + 'static,
R: Send + 'static,
where
F: FnOnce(&mut C) -> R + Send + 'static,
R: Send + 'static,
{
// It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
// derived from it) never be a variable on the stack at an await point,
Expand All @@ -160,14 +177,15 @@ impl<K: 'static, C: Poolable> Connection<K, C> {
run_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in
// a blocking fashion.
let mut connection = tokio::runtime::Handle::current().block_on(async {
connection.lock_owned().await
});
let mut connection =
tokio::runtime::Handle::current().block_on(async { connection.lock_owned().await });

let conn = connection.as_mut()
let conn = connection
.as_mut()
.expect("internal invariant broken: self.connection is Some");
f(conn)
}).await
})
.await
}
}

Expand All @@ -178,9 +196,8 @@ impl<K, C: Poolable> Drop for Connection<K, C> {

// See same motivation above for this arrangement of spawn_blocking/block_on
tokio::task::spawn_blocking(move || {
let mut connection = tokio::runtime::Handle::current().block_on(async {
connection.lock_owned().await
});
let mut connection =
tokio::runtime::Handle::current().block_on(async { connection.lock_owned().await });

if let Some(conn) = connection.take() {
drop(conn);
Expand Down Expand Up @@ -209,7 +226,10 @@ impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
match request.rocket().state::<ConnectionPool<K, C>>() {
Some(c) => c.get().await.into_outcome(Status::ServiceUnavailable),
None => {
error_!("Missing database fairing for `{}`", std::any::type_name::<K>());
error_!(
"Missing database fairing for `{}`",
std::any::type_name::<K>()
);
Outcome::Failure((Status::InternalServerError, ()))
}
}
Expand All @@ -223,7 +243,10 @@ impl<K: 'static, C: Poolable> Sentinel for Connection<K, C> {
if rocket.state::<ConnectionPool<K, C>>().is_none() {
let conn = Paint::default(std::any::type_name::<K>()).bold();
let fairing = Paint::default(format!("{}::fairing()", conn)).wrap().bold();
error!("requesting `{}` DB connection without attaching `{}`.", conn, fairing);
error!(
"requesting `{}` DB connection without attaching `{}`.",
conn, fairing
);
info_!("Attach `{}` to use database connection pooling.", fairing);
return true;
}
Expand Down
8 changes: 8 additions & 0 deletions contrib/sync_db_pools/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pub enum Error<T> {
Pool(r2d2::Error),
/// An error occurred while extracting a `figment` configuration.
Config(figment::Error),
/// An IO error
Io(std::io::Error),
}

impl<T> From<figment::Error> for Error<T> {
Expand All @@ -27,3 +29,9 @@ impl<T> From<r2d2::Error> for Error<T> {
Error::Pool(error)
}
}

impl<T> From<std::io::Error> for Error<T> {
fn from(error: std::io::Error) -> Self {
Error::Io(error)
}
}
70 changes: 59 additions & 11 deletions contrib/sync_db_pools/lib/src/poolable.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::time::Duration;

use r2d2::ManageConnection;
use rocket::{Rocket, Build};
use rocket::{Build, Rocket};

#[allow(unused_imports)]
use crate::{Config, Error};
Expand Down Expand Up @@ -98,7 +98,7 @@ use crate::{Config, Error};
/// [`Poolable`].
pub trait Poolable: Send + Sized + 'static {
/// The associated connection manager for the given connection type.
type Manager: ManageConnection<Connection=Self>;
type Manager: ManageConnection<Connection = Self>;

/// The associated error type in the event that constructing the connection
/// manager and/or the connection pool fails.
Expand All @@ -119,19 +119,22 @@ impl Poolable for diesel::SqliteConnection {
type Error = std::convert::Infallible;

fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
use diesel::{SqliteConnection, connection::SimpleConnection};
use diesel::r2d2::{CustomizeConnection, ConnectionManager, Error, Pool};
use diesel::r2d2::{ConnectionManager, CustomizeConnection, Error, Pool};
use diesel::{connection::SimpleConnection, SqliteConnection};

#[derive(Debug)]
struct Customizer;

impl CustomizeConnection<SqliteConnection, Error> for Customizer {
fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> {
conn.batch_execute("\
conn.batch_execute(
"\
PRAGMA journal_mode = WAL;\
PRAGMA busy_timeout = 1000;\
PRAGMA foreign_keys = ON;\
").map_err(Error::QueryError)?;
",
)
.map_err(Error::QueryError)?;

Ok(())
}
Expand Down Expand Up @@ -183,8 +186,7 @@ impl Poolable for diesel::MysqlConnection {
}
}

// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
#[cfg(all(feature = "postgres_pool", not(feature = "postgres_pool_tls")))]
impl Poolable for postgres::Client {
type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
type Error = postgres::Error;
Expand All @@ -202,6 +204,53 @@ impl Poolable for postgres::Client {
}
}

#[cfg(feature = "postgres_pool_tls")]
impl Poolable for postgres::Client {
type Manager = r2d2_postgres::PostgresConnectionManager<postgres_native_tls::MakeTlsConnector>;
type Error = postgres::Error;

fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
let config = Config::from(db_name, rocket)?;
let tls_config = config.tls.unwrap_or_default();
let mut tls_connector = native_tls::TlsConnector::builder();

if let Some(accept_invalid_certs) = tls_config.accept_invalid_certs {
tls_connector.danger_accept_invalid_certs(accept_invalid_certs);
if accept_invalid_certs {
rocket::warn!("Accepting invalid certificates");
}
}

if let Some(accept_invalid_hostnames) = tls_config.accept_invalid_hostnames {
tls_connector.danger_accept_invalid_hostnames(accept_invalid_hostnames);
if accept_invalid_hostnames {
rocket::warn!("Accepting invalid certificate hostnames");
}
}

if let Some(cert_path) = tls_config.cert_path {
let cert_bytes = std::fs::read(cert_path).map_err(Error::Io)?;
let cert = native_tls::Certificate::from_pem(&cert_bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
tls_connector.add_root_certificate(cert);
}

let tls_connector = tls_connector
.build()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
let db_connector = postgres_native_tls::MakeTlsConnector::new(tls_connector);

let url = config.url.parse().map_err(Error::Custom)?;
let manager = r2d2_postgres::PostgresConnectionManager::new(url, db_connector);
let pool = r2d2::Pool::builder()
.max_size(config.pool_size)
.connection_timeout(Duration::from_secs(config.timeout as u64))
.build(manager)?;

Ok(pool)
}
}

#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
type Manager = r2d2_sqlite::SqliteConnectionManager;
Expand Down Expand Up @@ -247,10 +296,9 @@ impl Poolable for rusqlite::Connection {
};

flags.insert(sql_flag)
};
}

let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
.with_flags(flags);
let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url).with_flags(flags);

let pool = r2d2::Pool::builder()
.max_size(config.pool_size)
Expand Down