Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow multiple tls features #237

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,48 @@ pub struct Config {
pub(crate) encryption: EncryptionLevel,
pub(crate) trust: TrustConfig,
pub(crate) auth: AuthMethod,
pub(crate) tls_choice: TlsChoice,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum TlsChoice {
#[cfg(not(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
)))]
None,
#[cfg(feature = "rustls")]
Rustls,
#[cfg(feature = "native-tls")]
NativeTls,
#[cfg(feature = "vendored-openssl")]
Openssl,
}

impl Default for TlsChoice {
#[allow(unreachable_code, clippy::needless_return)]
fn default() -> TlsChoice {
#[cfg(feature = "rustls")]
{
return TlsChoice::Rustls;
}
#[cfg(feature = "native-tls")]
{
return TlsChoice::NativeTls;
}
#[cfg(feature = "vendored-openssl")]
{
return TlsChoice::Openssl;
}

#[cfg(not(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
)))]
TlsChoice::None
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -62,6 +104,7 @@ impl Default for Config {
encryption: EncryptionLevel::NotSupported,
trust: TrustConfig::Default,
auth: AuthMethod::None,
tls_choice: TlsChoice::default(),
}
}
}
Expand Down Expand Up @@ -120,6 +163,11 @@ impl Config {
self.encryption = encryption;
}

/// Set the choice of Tls
pub fn tls_choice(&mut self, tls_choice: TlsChoice) {
self.tls_choice = tls_choice;
}

/// If set, the server certificate will not be validated and it is accepted
/// as-is.
///
Expand Down
23 changes: 19 additions & 4 deletions src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
feature = "native-tls",
feature = "vendored-openssl"
))]
use crate::client::{tls::TlsPreloginWrapper, tls_stream::create_tls_stream};
use crate::client::{config::TlsChoice, tls::TlsPreloginWrapper, tls_stream};
use crate::{
client::{tls::MaybeTlsStream, AuthMethod, Config},
tds::{
Expand Down Expand Up @@ -442,10 +442,25 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
let Self {
transport, context, ..
} = self;
let mut stream = match transport.into_inner() {
MaybeTlsStream::Raw(tcp) => {
create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await?

let mut stream = match (transport.into_inner(), config.tls_choice) {
#[cfg(feature = "rustls")]
(MaybeTlsStream::Raw(tcp), TlsChoice::Rustls) => {
tls_stream::create_tls_stream_rustls(config, TlsPreloginWrapper::new(tcp))
.await?
}
#[cfg(feature = "vendored-openssl")]
(MaybeTlsStream::Raw(tcp), TlsChoice::Openssl) => {
tls_stream::create_tls_stream_openssl(config, TlsPreloginWrapper::new(tcp))
.await?
}
#[cfg(feature = "native-tls")]
(MaybeTlsStream::Raw(tcp), TlsChoice::NativeTls) => {
tls_stream::create_tls_stream_native_tls(config, TlsPreloginWrapper::new(tcp))
.await?
}
// this should still be fine as the relevant TlsChoices are only
// enabled when the equivalent tls crate is enabled
_ => unreachable!(),
};

Expand Down
121 changes: 108 additions & 13 deletions src/client/tls_stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use crate::Config;
use futures::{AsyncRead, AsyncWrite};

use std::{
io,
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "native-tls")]
mod native_tls_stream;

Expand All @@ -10,35 +14,126 @@ mod rustls_tls_stream;
#[cfg(feature = "vendored-openssl")]
mod opentls_tls_stream;

#[cfg(feature = "native-tls")]
pub(crate) use native_tls_stream::TlsStream;
// #[cfg(feature = "native-tls")]
// pub(crate) use native_tls_stream::TlsStream as NativeTlsStream;

#[cfg(feature = "rustls")]
pub(crate) use rustls_tls_stream::TlsStream;
// #[cfg(feature = "rustls")]
// pub(crate) use rustls_tls_stream::TlsStream as RustlsTlsStream;

#[cfg(feature = "vendored-openssl")]
pub(crate) use opentls_tls_stream::TlsStream;
// #[cfg(feature = "vendored-openssl")]
// pub(crate) use opentls_tls_stream::TlsStream as OptenSslTlsStream;

pub(crate) enum TlsStream<S: AsyncRead + AsyncWrite + Unpin + Send> {
#[cfg(feature = "vendored-openssl")]
Openssl(opentls_tls_stream::TlsStream<S>),
#[cfg(feature = "rustls")]
Rustls(rustls_tls_stream::TlsStream<S>),
#[cfg(feature = "native-tls")]
NativeTls(native_tls_stream::TlsStream<S>),
}

impl<S> TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
pub(crate) fn get_mut(&mut self) -> &mut S {
match self {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => s.get_mut(),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => s.get_mut(),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => s.get_mut(),
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_read(cx, buf),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_write(cx, buf),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_flush(cx),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_close(cx),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_close(cx),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_close(cx),
}
}
}

#[cfg(feature = "rustls")]
pub(crate) async fn create_tls_stream<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn create_tls_stream_rustls<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &Config,
stream: S,
) -> crate::Result<TlsStream<S>> {
TlsStream::new(config, stream).await
rustls_tls_stream::TlsStream::new(config, stream)
.await
.map(TlsStream::Rustls)
}

#[cfg(feature = "native-tls")]
pub(crate) async fn create_tls_stream<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn create_tls_stream_native_tls<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &Config,
stream: S,
) -> crate::Result<TlsStream<S>> {
native_tls_stream::create_tls_stream(config, stream).await
native_tls_stream::create_tls_stream(config, stream)
.await
.map(TlsStream::NativeTls)
}

#[cfg(feature = "vendored-openssl")]
pub(crate) async fn create_tls_stream<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn create_tls_stream_openssl<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &Config,
stream: S,
) -> crate::Result<TlsStream<S>> {
opentls_tls_stream::create_tls_stream(config, stream).await
opentls_tls_stream::create_tls_stream(config, stream)
.await
.map(TlsStream::Openssl)
}
2 changes: 1 addition & 1 deletion src/client/tls_stream/rustls_tls_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl From<tokio_rustls::webpki::Error> for Error {
}

pub(crate) struct TlsStream<S: AsyncRead + AsyncWrite + Unpin + Send>(
Compat<tokio_rustls::client::TlsStream<Compat<S>>>,
pub(super) Compat<tokio_rustls::client::TlsStream<Compat<S>>>,
);

struct NoCertVerifier;
Expand Down