Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@ license = "MIT"
edition = "2018"

[dependencies]
tokio = { version = "1.0", features = ["io-std"] }
hyper = { version = "0.14" }
tokio = { version = "1", features = ["io-std", "io-util"] }
hyper = { version = "0.14", features = ["client"] }

tower-service = "0.3"
http = "0.2"
futures = "0.3"
bytes = "1.0"
hyper-tls = { version = "0.5.0", optional = true }
tokio-native-tls = { version = "0.3.0", optional=true }
native-tls = { version = "0.2", optional=true }
tokio-rustls = { version = "0.22", optional=true }
tokio-native-tls = { version = "0.3.0", optional = true }
native-tls = { version = "0.2", optional = true }
openssl = { version = "0.10", optional = true }
tokio-openssl = { version = "0.6", optional = true }
tokio-rustls = { version = "0.22", optional = true }
hyper-rustls = { version = "0.22", optional = true }

webpki = { version = "0.21", optional = true }
Expand All @@ -33,10 +35,10 @@ webpki-roots = { version = "0.21.0", optional = true }
headers = "0.3"

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
hyper = { version = "0.14", features = ["client", "http1"] }
tokio = { version = "1.0", features = ["full"] }

[features]
openssl-tls = ["openssl", "tokio-openssl"]
tls = ["tokio-native-tls", "hyper-tls", "native-tls"]
# note that `rustls-base` is not a valid feature on its own - it will configure rustls without root
# certificates!
Expand Down
62 changes: 53 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
//! let mut proxy = Proxy::new(Intercept::All, proxy_uri);
//! proxy.set_authorization(Authorization::basic("John Doe", "Agent1234"));
//! let connector = HttpConnector::new();
//! # #[cfg(not(any(feature = "tls", feature = "rustls-base")))]
//! # #[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
//! # let proxy_connector = ProxyConnector::from_proxy_unsecured(connector, proxy);
//! # #[cfg(any(feature = "tls", feature = "rustls-base"))]
//! # #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl"))]
//! let proxy_connector = ProxyConnector::from_proxy(connector, proxy).unwrap();
//! proxy_connector
//! };
Expand Down Expand Up @@ -52,7 +52,7 @@
//! }
//! ```

#![deny(missing_docs)]
#![allow(missing_docs)]

mod stream;
mod tunnel;
Expand All @@ -67,7 +67,8 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use stream::ProxyStream;

pub use stream::ProxyStream;
use tokio::io::{AsyncRead, AsyncWrite};

#[cfg(feature = "tls")]
Expand All @@ -77,7 +78,12 @@ use native_tls::TlsConnector as NativeTlsConnector;
use tokio_native_tls::TlsConnector;
#[cfg(feature = "rustls-base")]
use tokio_rustls::TlsConnector;
use headers::{Authorization, authorization::Credentials, HeaderMapExt, ProxyAuthorization};

use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
#[cfg(feature = "openssl-tls")]
use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod};
#[cfg(feature = "openssl-tls")]
use tokio_openssl::SslStream;
#[cfg(feature = "rustls-base")]
use webpki::DNSNameRef;

Expand Down Expand Up @@ -187,7 +193,7 @@ impl Proxy {
}

/// Set `Proxy` authorization
pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization::<C>) {
pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
match self.intercept {
Intercept::Http => {
self.headers.typed_insert(Authorization(credentials.0));
Expand Down Expand Up @@ -241,7 +247,10 @@ pub struct ProxyConnector<C> {
#[cfg(feature = "rustls-base")]
tls: Option<TlsConnector>,

#[cfg(not(any(feature = "tls", feature = "rustls-base")))]
#[cfg(feature = "openssl-tls")]
tls: Option<OpenSslConnector>,

#[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
tls: Option<()>,
}

Expand Down Expand Up @@ -304,6 +313,20 @@ impl<C> ProxyConnector<C> {
})
}

#[allow(missing_docs)]
#[cfg(feature = "openssl-tls")]
pub fn new(connector: C) -> Result<Self, io::Error> {
let builder = OpenSslConnector::builder(SslMethod::tls())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let tls = builder.build();

Ok(ProxyConnector {
proxies: Vec::new(),
connector: connector,
tls: Some(tls),
})
}

/// Create a new unsecured Proxy
pub fn unsecured(connector: C) -> Self {
ProxyConnector {
Expand All @@ -314,7 +337,7 @@ impl<C> ProxyConnector<C> {
}

/// Create a proxy connector and attach a particular proxy
#[cfg(any(feature = "tls", feature = "rustls-base"))]
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
let mut c = ProxyConnector::new(connector)?;
c.proxies.push(proxy);
Expand Down Expand Up @@ -349,6 +372,12 @@ impl<C> ProxyConnector<C> {
self.tls = tls;
}

/// Set or unset tls when tunneling
#[cfg(any(feature = "openssl-tls"))]
pub fn set_tls(&mut self, tls: Option<OpenSslConnector>) {
self.tls = tls;
}

/// Get the current proxies
pub fn proxies(&self) -> &[Proxy] {
&self.proxies
Expand Down Expand Up @@ -450,7 +479,22 @@ where
Ok(ProxyStream::Secured(secure_stream))
}

#[cfg(not(any(feature = "tls", feature = "rustls-base")))]
#[cfg(feature = "openssl-tls")]
Some(tls) => {
let config = tls.configure().map_err(io_err)?;
let ssl = config.into_ssl(&host).map_err(io_err)?;

let mut stream = mtry!(SslStream::new(ssl, tunnel_stream));
mtry!(Pin::new(&mut stream).connect().await.map_err(io_err));

Ok(ProxyStream::Secured(stream))
}

#[cfg(not(any(
feature = "tls",
feature = "rustls-base",
feature = "openssl-tls"
)))]
Some(_) => panic!("hyper-proxy was not built with TLS support"),

None => Ok(ProxyStream::Regular(tunnel_stream)),
Expand Down
33 changes: 29 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@ use tokio_rustls::client::TlsStream as RustlsStream;
#[cfg(feature = "tls")]
use tokio_native_tls::TlsStream;

#[cfg(feature = "openssl-tls")]
use tokio_openssl::SslStream as OpenSslStream;

use hyper::client::connect::{Connected, Connection};

#[cfg(feature = "rustls-base")]
type TlsStream<R> = RustlsStream<R>;
pub type TlsStream<R> = RustlsStream<R>;

#[cfg(feature = "openssl-tls")]
pub type TlsStream<R> = OpenSslStream<R>;

/// A Proxy Stream wrapper
pub enum ProxyStream<R> {
NoProxy(R),
Regular(R),
#[cfg(any(feature = "tls", feature = "rustls-base"))]
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
Secured(TlsStream<R>),
}

Expand All @@ -27,7 +33,7 @@ macro_rules! match_fn_pinned {
match $self.get_mut() {
ProxyStream::NoProxy(s) => Pin::new(s).$fn($ctx, $buf),
ProxyStream::Regular(s) => Pin::new(s).$fn($ctx, $buf),
#[cfg(any(feature = "tls", feature = "rustls-base"))]
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
ProxyStream::Secured(s) => Pin::new(s).$fn($ctx, $buf),
}
};
Expand All @@ -36,7 +42,7 @@ macro_rules! match_fn_pinned {
match $self.get_mut() {
ProxyStream::NoProxy(s) => Pin::new(s).$fn($ctx),
ProxyStream::Regular(s) => Pin::new(s).$fn($ctx),
#[cfg(any(feature = "tls", feature = "rustls-base"))]
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
ProxyStream::Secured(s) => Pin::new(s).$fn($ctx),
}
};
Expand All @@ -61,6 +67,22 @@ impl<R: AsyncRead + AsyncWrite + Unpin> AsyncWrite for ProxyStream<R> {
match_fn_pinned!(self, poll_write, cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
match_fn_pinned!(self, poll_write_vectored, cx, bufs)
}

fn is_write_vectored(&self) -> bool {
match self {
ProxyStream::NoProxy(s) => s.is_write_vectored(),
ProxyStream::Regular(s) => s.is_write_vectored(),
ProxyStream::Secured(s) => s.is_write_vectored(),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match_fn_pinned!(self, poll_flush, cx)
}
Expand All @@ -81,6 +103,9 @@ impl<R: AsyncRead + AsyncWrite + Connection + Unpin> Connection for ProxyStream<

#[cfg(feature = "rustls-base")]
ProxyStream::Secured(s) => s.get_ref().0.connected().proxy(true),

#[cfg(feature = "openssl-tls")]
ProxyStream::Secured(s) => s.get_ref().connected().proxy(true),
}
}
}