Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ members = [


[dependencies]
tokio-native-tls = "0.3.0"
tokio-rustls = "0.24.1"
rustls-pemfile = "1.0.3"
rabbitmq-stream-protocol = { version = "0.2", path = "protocol" }
tokio = { version = "1.29.1", features = ["full"] }
tokio-util = { version = "0.7.3", features = ["codec"] }
Expand Down
68 changes: 68 additions & 0 deletions examples/tls_producer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use tracing::info;
use tracing_subscriber::FmtSubscriber;

use rabbitmq_stream_client::{types::Message, Environment, NoDedup, Producer, TlsConfiguration};

const BATCH_SIZE: usize = 100;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let stream_name = String::from("tls_test_stream");
let subscriber = FmtSubscriber::builder().finish();

tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");

let tls_configuration: TlsConfiguration = TlsConfiguration::builder()
.add_root_certificate(String::from("/path/to/your/certificate-ca.pem"))
.build();

let environment = Environment::builder()
.host("localhost")
.port(5551)
.tls(tls_configuration)
.build()
.await?;

start_publisher(environment.clone(), &stream_name)
.await
.expect("error in publisher");

Ok(())
}

async fn start_publisher(
env: Environment,
stream: &String,
) -> Result<(), Box<dyn std::error::Error>> {
let _ = env.stream_creator().create(&stream).await;

let producer = env.producer().batch_size(BATCH_SIZE).build(&stream).await?;

let is_batch_send = true;
tokio::task::spawn(async move {
info!(
"Starting producer with batch size {} and batch send {}",
BATCH_SIZE, is_batch_send
);
info!("Sending {} simple messages", BATCH_SIZE);
batch_send_simple(&producer).await;
})
.await?;
Ok(())
}

async fn batch_send_simple(producer: &Producer<NoDedup>) {
let mut msg = Vec::with_capacity(BATCH_SIZE);
for i in 0..BATCH_SIZE {
msg.push(
Message::builder()
.body(format!("rust message{}", i))
.build(),
);
}

producer
.batch_send(msg, move |_| async move {})
.await
.unwrap();
}
31 changes: 23 additions & 8 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::TryFrom;

use std::{
collections::HashMap,
io,
Expand All @@ -13,12 +15,16 @@ use futures::{
Stream, StreamExt, TryFutureExt,
};
use pin_project::pin_project;
use rustls::ServerName;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
use tokio::{net::TcpStream, sync::Notify};
use tokio::{sync::RwLock, task::JoinHandle};
use tokio_native_tls::TlsStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::{rustls, TlsConnector};

use tokio_util::codec::Framed;
use tracing::trace;

Expand Down Expand Up @@ -414,17 +420,26 @@ impl Client {
> {
let stream = if broker.tls.enabled() {
let stream = TcpStream::connect((broker.host.as_str(), broker.port)).await?;
let mut roots = rustls::RootCertStore::empty();
let cert = broker.tls.get_root_certificates();
let cert_bytes = std::fs::read(cert);

let root_cert_store = rustls_pemfile::certs(&mut cert_bytes.unwrap().as_ref()).unwrap();

let mut tls_builder = tokio_native_tls::native_tls::TlsConnector::builder();
tls_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
root_cert_store
.iter()
.for_each(|cert| roots.add(&rustls::Certificate(cert.to_vec())).unwrap());

let conn = tokio_native_tls::TlsConnector::from(tls_builder.build()?);
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();

let stream = conn.connect(broker.host.as_str(), stream).await?;
let connector = TlsConnector::from(Arc::new(config));
let domain = ServerName::try_from(broker.host.as_str()).unwrap();
let conn = connector.connect(domain, stream).await?;

GenericTcpStream::SecureTcp(stream)
GenericTcpStream::SecureTcp(conn)
} else {
let stream = TcpStream::connect((broker.host.as_str(), broker.port)).await?;
GenericTcpStream::Tcp(stream)
Expand Down
5 changes: 2 additions & 3 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,15 @@ impl Default for ClientOptions {
collector: Arc::new(NopMetricsCollector {}),
tls: TlsConfiguration {
enabled: false,
hostname_verification: false,
trust_everything: false,
certificate_path: String::from(""),
},
}
}
}

impl ClientOptions {
pub fn get_tls(&self) -> TlsConfiguration {
self.tls
self.tls.clone()
}

pub fn enable_tls(&mut self) {
Expand Down
52 changes: 12 additions & 40 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
stream_creator::StreamCreator,
RabbitMQStreamResult,
};

/// Main access point to a node
#[derive(Clone)]
pub struct Environment {
Expand Down Expand Up @@ -108,18 +109,7 @@ impl EnvironmentBuilder {
}

pub fn tls(mut self, tls_configuration: TlsConfiguration) -> EnvironmentBuilder {
self.0
.client_options
.tls
.trust_everything(tls_configuration.trust_everything_enabled());
self.0
.client_options
.tls
.hostname_verification_enable(tls_configuration.hostname_verification_enabled());
self.0
.client_options
.tls
.enable(tls_configuration.enabled());
self.0.client_options.tls = tls_configuration;

self
}
Expand All @@ -142,28 +132,22 @@ pub struct EnvironmentOptions {
}

/** Helper for tls configuration */
#[derive(Clone, Copy)]
#[derive(Clone)]
pub struct TlsConfiguration {
pub(crate) enabled: bool,
pub(crate) hostname_verification: bool,
pub(crate) trust_everything: bool,
pub(crate) certificate_path: String,
}

impl Default for TlsConfiguration {
fn default() -> TlsConfiguration {
TlsConfiguration {
enabled: true,
trust_everything: false,
hostname_verification: true,
certificate_path: String::from(""),
}
}
}

impl TlsConfiguration {
pub fn trust_everything(&mut self, trust_everything: bool) {
self.trust_everything = trust_everything
}

pub fn enable(&mut self, enabled: bool) {
self.enabled = enabled
}
Expand All @@ -172,37 +156,25 @@ impl TlsConfiguration {
self.enabled
}

pub fn hostname_verification_enable(&mut self, hostname_verification: bool) {
self.hostname_verification = hostname_verification
pub fn get_root_certificates(&self) -> String {
self.certificate_path.clone()
}

pub fn hostname_verification_enabled(&self) -> bool {
self.hostname_verification
}

pub fn trust_everything_enabled(&self) -> bool {
self.trust_everything
//
pub fn add_root_certificate(&mut self, certificate_path: String) {
self.certificate_path = certificate_path
}
}

pub struct TlsConfigurationBuilder(TlsConfiguration);

impl TlsConfigurationBuilder {
pub fn trust_everything(mut self, trust_everything: bool) -> TlsConfigurationBuilder {
self.0.trust_everything = trust_everything;
self
}

pub fn enable(mut self, enable: bool) -> TlsConfigurationBuilder {
self.0.enabled = enable;
self
}

pub fn hostname_verification_enable(
mut self,
hostname_verification: bool,
) -> TlsConfigurationBuilder {
self.0.hostname_verification = hostname_verification;
pub fn add_root_certificate(mut self, certificate_path: String) -> TlsConfigurationBuilder {
self.0.certificate_path = certificate_path;
self
}

Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub enum ClientError {
#[error("Client already closed")]
AlreadyClosed,
#[error(transparent)]
Tls(#[from] tokio_native_tls::native_tls::Error),
Tls(#[from] tokio_rustls::rustls::Error),
#[error("Request error: {0:?}")]
RequestError(ResponseCode),
}
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/environment_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ async fn environment_create_streams_with_parameters() {
assert_eq!(delete_response.is_ok(), true);
}

/*
#[tokio::test(flavor = "multi_thread")]
async fn environment_fail_tls_connection() {
// in this test we try to connect to a server that does not support tls
Expand All @@ -144,8 +145,9 @@ async fn environment_fail_tls_connection() {
.tls(TlsConfiguration::default())
.build()
.await;

assert!(matches!(
env,
Err(rabbitmq_stream_client::error::ClientError::Tls { .. })
));
}
}*/