Skip to content

Commit

Permalink
Merge pull request #196 Support custom CA certificates for TLS from u…
Browse files Browse the repository at this point in the history
…slon/master
  • Loading branch information
rekby committed May 17, 2024
2 parents 0a149cc + fc3f9d2 commit 379549e
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 17 deletions.
43 changes: 36 additions & 7 deletions ydb/src/client_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ static PARAM_HANDLERS: Lazy<Mutex<HashMap<String, ParamHandler>>> = Lazy::new(||
m.insert("token_cmd".to_string(), token_cmd);
m.insert("token_metadata".to_string(), token_metadata);
m.insert("token_static_password".to_string(), token_static_password);
m.insert("ca_certificate".to_string(), ca_certificate);
m
})
});
Expand Down Expand Up @@ -138,15 +139,40 @@ fn token_static_password(uri: &str, mut client_builder: ClientBuilder) -> YdbRes
if client_builder.database.is_empty() {
client_builder = database(uri, client_builder)?;
}
if client_builder.cert_path.is_none() {
client_builder = ca_certificate(uri, client_builder)?;
}

let endpoint: Uri = Uri::from_str(client_builder.endpoint.as_str())?;

client_builder.credentials = credencials_ref(StaticCredentials::new(
username,
password,
endpoint,
client_builder.database.clone(),
));
let creds = match client_builder.cert_path.as_ref() {
Some(path) => StaticCredentials::new_with_ca(
username,
password,
endpoint,
client_builder.database.clone(),
path.clone(),
),
None => StaticCredentials::new(
username,
password,
endpoint,
client_builder.database.clone(),
)
};
client_builder.credentials = credencials_ref(creds);

Ok(client_builder)
}

fn ca_certificate(uri: &str, mut client_builder: ClientBuilder) -> YdbResult<ClientBuilder> {
for (key, value) in url::Url::parse(uri)?.query_pairs() {
if key != "ca_certificate" {
continue;
};
client_builder.cert_path = Some(value.as_ref().to_string());
break;
}

Ok(client_builder)
}
Expand All @@ -157,6 +183,7 @@ pub struct ClientBuilder {
discovery_interval: Duration,
pub(crate) endpoint: String,
discovery: Option<Box<dyn Discovery>>,
pub cert_path: Option<String>,
}

impl ClientBuilder {
Expand Down Expand Up @@ -192,6 +219,7 @@ impl ClientBuilder {
SharedLoadBalancer::new_with_balancer(Box::new(static_balancer)),
db_cred.database.clone(),
interceptor.clone(),
self.cert_path.clone(),
);

let discovery = match self.discovery {
Expand All @@ -211,7 +239,7 @@ impl ClientBuilder {

let load_balancer = SharedLoadBalancer::new(discovery.as_ref().as_ref());
let connection_manager =
GrpcConnectionManager::new(load_balancer, db_cred.database.clone(), interceptor);
GrpcConnectionManager::new(load_balancer, db_cred.database.clone(), interceptor, self.cert_path);

Client::new(db_cred, discovery, connection_manager)
}
Expand Down Expand Up @@ -255,6 +283,7 @@ impl ClientBuilder {
discovery_interval: Duration::from_secs(60),
endpoint: "grpc://localhost:2135".to_string(),
discovery: None,
cert_path: None,
}
}

Expand Down
26 changes: 22 additions & 4 deletions ydb/src/connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
use crate::YdbResult;
use http::Uri;
use tracing::trace;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use http::uri::Scheme;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint};

#[derive(Clone)]
pub(crate) struct ConnectionPool {
state: Arc<Mutex<ConnectionPoolState>>,
tls_config: Arc<Option<ClientTlsConfig>>,
}

impl ConnectionPool {
pub(crate) fn new() -> Self {
Self {
state: Arc::new(Mutex::new(ConnectionPoolState::new())),
tls_config: None.into(),
}
}

pub(crate) fn load_certificate(self, path: String) -> Self {
let pem = std::fs::read_to_string(path).unwrap();
trace!("loaded cert: {}", pem);
let ca = Certificate::from_pem(pem);
let config = ClientTlsConfig::new().ca_certificate(ca);
Self {
tls_config: Some(config).into(),
..self
}
}

Expand All @@ -27,7 +41,7 @@ impl ConnectionPool {
};

// TODO: replace lazy connection to real, without global block
let channel = connect_lazy(uri.clone())?;
let channel = connect_lazy(uri.clone(), &self.tls_config)?;
let ci = ConnectionInfo {
last_usage: now,
channel: channel.clone(),
Expand All @@ -54,7 +68,7 @@ struct ConnectionInfo {
channel: Channel,
}

fn connect_lazy(uri: Uri) -> YdbResult<Channel> {
fn connect_lazy(uri: Uri, tls_config: &Option<ClientTlsConfig>) -> YdbResult<Channel> {
let mut parts = uri.into_parts();
if parts.scheme.as_ref().unwrap_or(&Scheme::HTTP).as_str() == "grpc" {
parts.scheme = Some(Scheme::HTTP)
Expand All @@ -65,10 +79,14 @@ fn connect_lazy(uri: Uri) -> YdbResult<Channel> {
let uri = Uri::from_parts(parts)?;

let tls = uri.scheme() == Some(&Scheme::HTTPS);
trace!("scheme is {}", uri.scheme().unwrap());

let mut endpoint = Endpoint::from(uri);
if tls {
endpoint = endpoint.tls_config(ClientTlsConfig::new())?
endpoint = match tls_config {
Some(config) => endpoint.tls_config(config.clone())?,
None => endpoint.tls_config(ClientTlsConfig::new())?,
};
};
endpoint = endpoint.tcp_keepalive(Some(Duration::from_secs(15))); // tcp keepalive similar to default in golang lib

Expand Down
19 changes: 18 additions & 1 deletion ydb/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ pub struct StaticCredentials {
password: SecretString,
database: String,
endpoint: Uri,
cert_path: Option<String>,
}

impl StaticCredentials {
Expand All @@ -539,6 +540,7 @@ impl StaticCredentials {
SharedLoadBalancer::new_with_balancer(Box::new(static_balancer)),
self.database.clone(),
MultiInterceptor::new(),
self.cert_path.clone(),
);

let mut auth_client = empty_connection_manager
Expand All @@ -557,12 +559,27 @@ impl StaticCredentials {
Ok(raw_response.token)
}

pub fn new(username: String, password: String, endpoint: Uri, database: String) -> Self {
pub fn new(username: String,
password: String,
endpoint: Uri, database: String) -> Self {
Self {
username,
password: SecretString::new(password),
database,
endpoint,
cert_path: None,
}
}

pub fn new_with_ca(username: String,
password: String,
endpoint: Uri, database: String, cert_path: String) -> Self {
Self {
username,
password: SecretString::new(password),
database,
endpoint,
cert_path: Some(cert_path),
}
}
}
Expand Down
30 changes: 30 additions & 0 deletions ydb/src/custom_ca_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use tracing_test::traced_test;
use crate::{
test_integration_helper::create_custom_ca_client,
Query,
Transaction,
YdbResult,
};

#[tokio::test]
#[traced_test]
#[ignore] // YDB access is necessary
async fn custom_ca_test() -> YdbResult<()> {
return Ok(());

#[allow(unreachable_code)]
{
let client = create_custom_ca_client().await?;
let two: i32 = client
.table_client() // create table client
.retry_transaction(|mut t: Box<dyn Transaction>| async move {
let res = t.query(Query::from("SELECT 2")).await?;
let field_val: i32 = res.into_only_row()?.remove_field(0)?.try_into()?;
Ok(field_val)
})
.await?;

assert_eq!(two, 2);
Ok(())
}
}
2 changes: 1 addition & 1 deletion ydb/src/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ mod test {
MultiInterceptor::new().with_interceptor(AuthGrpcInterceptor::new(cred.clone())?);

let connection_manager =
GrpcConnectionManager::new(load_balancer, cred.database, interceptor);
GrpcConnectionManager::new(load_balancer, cred.database, interceptor, None);

let discovery_shared =
DiscoverySharedState::new(connection_manager, test_client_builder().endpoint.as_str())?;
Expand Down
14 changes: 10 additions & 4 deletions ydb/src/grpc_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ impl<TBalancer: LoadBalancer> GrpcConnectionManagerGeneric<TBalancer> {
balancer: TBalancer,
database: String,
interceptor: MultiInterceptor,
cert_path: Option<String>
) -> Self {
GrpcConnectionManagerGeneric {
state: State::new(balancer, database, interceptor),
state: State::new(balancer, database, interceptor, cert_path),
}
}

Expand Down Expand Up @@ -61,14 +62,19 @@ struct State<TBalancer: LoadBalancer> {
balancer: TBalancer,
connections_pool: ConnectionPool,
interceptor: MultiInterceptor,
database: String,
database: String
}

impl<TBalancer: LoadBalancer> State<TBalancer> {
fn new(balancer: TBalancer, database: String, interceptor: MultiInterceptor) -> Self {
fn new(balancer: TBalancer, database: String, interceptor: MultiInterceptor, cert_path: Option<String>) -> Self {
let mut cp = ConnectionPool::new();
if cert_path.is_some() {
cp = cp.load_certificate(cert_path.unwrap());
}

State {
balancer,
connections_pool: ConnectionPool::new(),
connections_pool: cp,
interceptor,
database,
}
Expand Down
3 changes: 3 additions & 0 deletions ydb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ mod sugar;
#[cfg(test)]
pub(crate) mod auth_test;

#[cfg(test)]
pub(crate) mod custom_ca_test;

#[cfg(test)]
mod test_helpers;

Expand Down
25 changes: 25 additions & 0 deletions ydb/src/test_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::ClientBuilder;
use once_cell::sync::Lazy;
use tracing::trace;
use url::Url;

pub(crate) static CONNECTION_STRING: Lazy<String> = Lazy::new(|| {
Expand All @@ -9,6 +10,13 @@ pub(crate) static CONNECTION_STRING: Lazy<String> = Lazy::new(|| {
.unwrap()
});

pub(crate) static TLS_CONNECTION_STRING: Lazy<String> = Lazy::new(|| {
std::env::var("YDB_CONNECTION_STRING")
.unwrap_or_else(|_| "grpcs://localhost:2135/local".to_string())
.parse()
.unwrap()
});

pub(crate) fn test_client_builder() -> ClientBuilder {
CONNECTION_STRING.as_str().parse().unwrap()
}
Expand All @@ -23,6 +31,23 @@ pub(crate) fn get_passworded_connection_string() -> String {
.to_string()
}

pub(crate) fn get_custom_ca_connection_string() -> String {
trace!("forge ca connection string");
Url::parse_with_params(
&TLS_CONNECTION_STRING,
&[
("ca_certificate", "./../ydb_certs/ca.pem"),
],
)
.unwrap()
.as_str()
.to_string()
}

pub(crate) fn test_with_password_builder() -> ClientBuilder {
ClientBuilder::new_from_connection_string(get_passworded_connection_string()).unwrap()
}

pub(crate) fn test_custom_ca_client_builder() -> ClientBuilder {
ClientBuilder::new_from_connection_string(get_custom_ca_connection_string()).unwrap()
}
11 changes: 11 additions & 0 deletions ydb/src/test_integration_helper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::client::Client;
use crate::client::TimeoutSettings;
use crate::errors::YdbResult;
use crate::test_helpers::test_custom_ca_client_builder;
use crate::test_helpers::{test_client_builder, test_with_password_builder};
use async_once::AsyncOnce;
use lazy_static::lazy_static;
Expand Down Expand Up @@ -43,3 +44,13 @@ pub(crate) async fn create_password_client() -> YdbResult<Arc<Client>> {
client.wait().await.unwrap();
Ok(Arc::new(client))
}

#[tracing::instrument]
pub(crate) async fn create_custom_ca_client() -> YdbResult<Arc<Client>> {
let client = test_custom_ca_client_builder().client().unwrap().with_timeouts(TimeoutSettings {
operation_timeout: std::time::Duration::from_secs(60),
});
trace!("start wait");
client.wait().await.unwrap();
Ok(Arc::new(client))
}

0 comments on commit 379549e

Please sign in to comment.