Skip to content

Commit

Permalink
Upgrade to tokio 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
e00E committed Jan 6, 2021
1 parent 05e445d commit 81400d7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 67 deletions.
22 changes: 11 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,29 @@ license = "MIT"
edition = "2018"

[dependencies]
tokio = { version = "0.2.22", features = ["io-std"] }
hyper = { version = "0.13" }
tokio = { version = "1.0", features = ["io-std"] }
hyper = { version = "0.14" }

tower-service = "0.3"
http = "0.2"
futures = "0.3"
bytes = "0.5"
hyper-tls = { optional = true, version = "0.4.3" }
tokio-tls = { version = "0.3", optional=true }
bytes = "1.0"
hyper-tls = { version = "0.5.0", optional = true }
tokio-native-tls = { version = "0.3.0", optional=true }
native-tls = { version = "0.2", optional=true }
tokio-rustls = { version = "0.14", optional=true }
hyper-rustls = { version="0.21", optional=true }
tokio-rustls = { version = "0.22", optional=true }
hyper-rustls = { version = "0.22", optional = true }

webpki = { version = "0.21", optional = true }
rustls-native-certs = { version = "0.4.0", optional = true }
webpki-roots = { version = "0.20.0", optional = true }
rustls-native-certs = { version = "0.5.0", optional = true }
webpki-roots = { version = "0.21.0", optional = true }
typed-headers = "0.2"

[dev-dependencies]
tokio = { version = "0.2.22", features = ["full"] }
tokio = { version = "1.0", features = ["full"] }

[features]
tls = ["tokio-tls", "hyper-tls", "native-tls"]
tls = ["tokio-native-tls", "hyper-tls", "native-tls"]
# note that `rustls-base` is not a valid feature on its own - it will configure rustls without root
# certificates!
rustls-base = ["tokio-rustls", "hyper-rustls", "webpki"]
Expand Down
26 changes: 14 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
//!
//! # Example
//! ```rust,no_run
//! use hyper::{Client, Request, Uri};
//! use hyper::{Client, Request, Uri, body::HttpBody};
//! use hyper::client::HttpConnector;
//! use futures::{TryFutureExt, TryStreamExt};
//! use hyper_proxy::{Proxy, ProxyConnector, Intercept};
//! use typed_headers::Credentials;
//! use std::error::Error;
//! use tokio::io::{stdout, AsyncWriteExt as _};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn Error>> {
Expand All @@ -33,18 +34,19 @@
//! }
//!
//! let client = Client::builder().build(proxy);
//! let fut_http = client.request(req)
//! .and_then(|res| res.into_body().map_ok(|x|x.to_vec()).try_concat())
//! .map_ok(move |body| ::std::str::from_utf8(&body).unwrap().to_string());
//! let mut resp = client.request(req).await?;
//! println!("Response: {}", resp.status());
//! while let Some(chunk) = resp.body_mut().data().await {
//! stdout().write_all(&chunk?).await?;
//! }
//!
//! // Connecting to an https uri is straightforward (uses 'CONNECT' method underneath)
//! let uri = "https://my-remote-websitei-secured.com".parse().unwrap();
//! let fut_https = client.get(uri)
//! .and_then(|res| res.into_body().map_ok(|x|x.to_vec()).try_concat())
//! .map_ok(move |body| ::std::str::from_utf8(&body).unwrap().to_string());
//!
//! let (http_res, https_res) = futures::future::join(fut_http, fut_https).await;
//! let (_, _) = (http_res?, https_res?);
//! let mut resp = client.get(uri).await?;
//! println!("Response: {}", resp.status());
//! while let Some(chunk) = resp.body_mut().data().await {
//! stdout().write_all(&chunk?).await?;
//! }
//!
//! Ok(())
//! }
Expand All @@ -71,10 +73,10 @@ use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use native_tls::TlsConnector as NativeTlsConnector;

#[cfg(feature = "tls")]
use tokio_native_tls::TlsConnector;
#[cfg(feature = "rustls-base")]
use tokio_rustls::TlsConnector;
#[cfg(feature = "tls")]
use tokio_tls::TlsConnector;
use typed_headers::{Authorization, Credentials, HeaderMapExt, ProxyAuthorization};
#[cfg(feature = "rustls-base")]
use webpki::DNSNameRef;
Expand Down
42 changes: 5 additions & 37 deletions src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use bytes::{Buf, BufMut};
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[cfg(feature = "rustls-base")]
use tokio_rustls::client::TlsStream as RustlsStream;

#[cfg(feature = "tls")]
use tokio_tls::TlsStream;
use tokio_native_tls::TlsStream;

use hyper::client::connect::{Connected, Connection};

Expand Down Expand Up @@ -45,35 +43,13 @@ macro_rules! match_fn_pinned {
}

impl<R: AsyncRead + AsyncWrite + Unpin> AsyncRead for ProxyStream<R> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
match *self {
ProxyStream::NoProxy(ref s) => s.prepare_uninitialized_buffer(buf),

ProxyStream::Regular(ref s) => s.prepare_uninitialized_buffer(buf),

#[cfg(any(feature = "tls", feature = "rustls-base"))]
ProxyStream::Secured(ref s) => s.prepare_uninitialized_buffer(buf),
}
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match_fn_pinned!(self, poll_read, cx, buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
match_fn_pinned!(self, poll_read_buf, cx, buf)
}
}

impl<R: AsyncRead + AsyncWrite + Unpin> AsyncWrite for ProxyStream<R> {
Expand All @@ -92,14 +68,6 @@ impl<R: AsyncRead + AsyncWrite + Unpin> AsyncWrite for ProxyStream<R> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match_fn_pinned!(self, poll_shutdown, cx)
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
match_fn_pinned!(self, poll_write_buf, cx, buf)
}
}

impl<R: AsyncRead + AsyncWrite + Connection + Unpin> Connection for ProxyStream<R> {
Expand All @@ -109,7 +77,7 @@ impl<R: AsyncRead + AsyncWrite + Connection + Unpin> Connection for ProxyStream<

ProxyStream::Regular(s) => s.connected().proxy(true),
#[cfg(feature = "tls")]
ProxyStream::Secured(s) => s.get_ref().connected().proxy(true),
ProxyStream::Secured(s) => s.get_ref().get_ref().get_ref().connected().proxy(true),

#[cfg(feature = "rustls-base")]
ProxyStream::Secured(s) => s.get_ref().0.connected().proxy(true),
Expand Down
14 changes: 7 additions & 7 deletions src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

macro_rules! try_ready {
($x:expr) => {
Expand Down Expand Up @@ -88,9 +88,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Future for Tunnel<S> {

loop {
if let TunnelState::Writing = &this.state {
let n = try_ready!(
Pin::new(this.stream.as_mut().unwrap()).poll_write_buf(ctx, &mut this.buf)
);
let fut = this.stream.as_mut().unwrap().write_buf(&mut this.buf);
futures::pin_mut!(fut);
let n = try_ready!(fut.poll(ctx));

if !this.buf.has_remaining() {
this.state = TunnelState::Reading;
Expand All @@ -99,9 +99,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Future for Tunnel<S> {
return Poll::Ready(Err(io_err("unexpected EOF while tunnel writing")));
}
} else {
let n = try_ready!(
Pin::new(this.stream.as_mut().unwrap()).poll_read_buf(ctx, &mut this.buf)
);
let fut = this.stream.as_mut().unwrap().read_buf(&mut this.buf);
futures::pin_mut!(fut);
let n = try_ready!(fut.poll(ctx));

if n == 0 {
return Poll::Ready(Err(io_err("unexpected EOF while tunnel reading")));
Expand Down

0 comments on commit 81400d7

Please sign in to comment.