Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using TLS with PostgreSQL (#260) #266

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion refinery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ edition = "2018"
default = []
rusqlite-bundled = ["refinery-core/rusqlite-bundled"]
rusqlite = ["refinery-core/rusqlite"]
postgres = ["refinery-core/postgres"]
postgres = ["refinery-core/postgres", "refinery-core/postgres-openssl", "refinery-core/openssl"]
mysql = ["refinery-core/mysql", "refinery-core/flate2"]
tokio-postgres = ["refinery-core/tokio-postgres"]
mysql_async = ["refinery-core/mysql_async"]
Expand Down
2 changes: 1 addition & 1 deletion refinery_cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ path = "src/main.rs"

[features]
default = ["mysql", "postgresql", "sqlite-bundled", "mssql"]
postgresql = ["refinery-core/postgres"]
postgresql = ["refinery-core/postgres", "refinery-core/postgres-openssl", "refinery-core/openssl"]
mysql = ["refinery-core/mysql", "refinery-core/flate2"]
sqlite = ["refinery-core/rusqlite"]
sqlite-bundled = ["sqlite", "refinery-core/rusqlite-bundled"]
Expand Down
2 changes: 2 additions & 0 deletions refinery_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ walkdir = "2.3.1"
# allow multiple versions of the same dependency if API is similar
rusqlite = { version = ">= 0.23, <= 0.28", optional = true }
postgres = { version = "0.19", optional = true }
postgres-openssl = { version = "0.5", optional = true }
openssl = { version = "0.10", optional = true }
tokio-postgres = { version = "0.7", optional = true }
mysql = { version = ">= 21.0.0, <= 23", optional = true, default-features = false}
mysql_async = { version = ">= 0.28, <= 0.30", optional = true }
Expand Down
50 changes: 45 additions & 5 deletions refinery_core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::convert::TryFrom;
use std::fs;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::{borrow::Cow, collections::HashMap};
use url::Url;

// refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros
Expand Down Expand Up @@ -34,6 +35,7 @@ impl Config {
db_user: None,
db_pass: None,
db_name: None,
use_tls: None,
#[cfg(feature = "tiberius-config")]
trust_cert: false,
},
Expand Down Expand Up @@ -138,6 +140,10 @@ impl Config {
self.main.db_port.as_deref()
}

pub fn use_tls(&self) -> Option<bool> {
self.main.use_tls
}

pub fn set_db_user(self, db_user: &str) -> Config {
Config {
main: Main {
Expand Down Expand Up @@ -202,13 +208,12 @@ impl TryFrom<Url> for Config {
}
};

let query_params = url
.query_pairs()
.collect::<HashMap<Cow<'_, str>, Cow<'_, str>>>();

cfg_if::cfg_if! {
if #[cfg(feature = "tiberius-config")] {
use std::{borrow::Cow, collections::HashMap};
let query_params = url
.query_pairs()
.collect::<HashMap< Cow<'_, str>, Cow<'_, str>>>();

let trust_cert = query_params.
get("trust_cert")
.unwrap_or(&Cow::Borrowed("false"))
Expand All @@ -222,6 +227,21 @@ impl TryFrom<Url> for Config {
}
}

let use_tls = match query_params
.get("sslmode")
.unwrap_or(&Cow::Borrowed("disable"))
{
&Cow::Borrowed("disable") => Ok(false),
&Cow::Borrowed("require") => Ok(true),
_ => Err(()),
}
.map_err(|_| {
Error::new(
Kind::ConfigError("Invalid sslmode value, please use disable/require".into()),
None,
)
})?;
Comment on lines +230 to +243
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let use_tls = match query_params
.get("sslmode")
.unwrap_or(&Cow::Borrowed("disable"))
{
&Cow::Borrowed("disable") => Ok(false),
&Cow::Borrowed("require") => Ok(true),
_ => Err(()),
}
.map_err(|_| {
Error::new(
Kind::ConfigError("Invalid sslmode value, please use disable/require".into()),
None,
)
})?;
let use_tls = match query_params
.get("sslmode")
.unwrap_or(&Cow::Borrowed("disable"))
{
&Cow::Borrowed("disable") => false,
&Cow::Borrowed("require") => true,
_ => return Error::new(
Kind::ConfigError("Invalid sslmode value, please use disable/require".into()),
None,
),
};


Ok(Self {
main: Main {
db_type,
Expand All @@ -237,6 +257,7 @@ impl TryFrom<Url> for Config {
db_user: Some(url.username().to_string()),
db_pass: url.password().map(|r| r.to_string()),
db_name: Some(url.path().trim_start_matches('/').to_string()),
use_tls: Some(use_tls),
#[cfg(feature = "tiberius-config")]
trust_cert,
},
Expand Down Expand Up @@ -268,6 +289,7 @@ struct Main {
db_user: Option<String>,
db_pass: Option<String>,
db_name: Option<String>,
use_tls: Option<bool>,
#[cfg(feature = "tiberius-config")]
#[serde(default)]
trust_cert: bool,
Expand Down Expand Up @@ -451,6 +473,24 @@ mod tests {
);
}

#[test]
fn build_no_tls_conn_from_str() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do it all on one test? both the use not use and the invalid value?

let config =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable")
.unwrap();
assert!(config.use_tls().is_some());
assert!(!config.use_tls().unwrap());
}

#[test]
fn build_tls_conn_from_str() {
let config =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require")
.unwrap();
assert!(config.use_tls().is_some());
assert!(config.use_tls().unwrap());
}

#[test]
fn builds_db_env_var_failure() {
std::env::set_var("DATABASE_URL", "this_is_not_a_url");
Expand Down
11 changes: 10 additions & 1 deletion refinery_core/src/drivers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,16 @@ macro_rules! with_connection {
cfg_if::cfg_if! {
if #[cfg(feature = "postgres")] {
let path = build_db_url("postgresql", &$config);
let conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;

let conn;
if $config.use_tls().is_some() && $config.use_tls().unwrap() {
let builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap();
let connector = postgres_openssl::MakeTlsConnector::new(builder.build());
conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?;
} else {
conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
}

$op(conn)
} else {
panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!");
Expand Down