Skip to content

Commit

Permalink
Upgrade to tokio 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
e00E committed Jan 22, 2021
1 parent a3a5f3e commit 88b32b3
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 81 deletions.
23 changes: 13 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ pin-project = "1.0"
# Optional deps
## HTTP
base64 = { version = "0.13", optional = true }
hyper = { version = "0.13", optional = true, default-features = false, features = ["stream", "tcp"] }
hyper-tls = { version = "0.4", optional = true }
hyper-proxy = {version = "0.8.0", optional = true }
typed-headers = { version = "0.2.0", optional = true }
hyper = { version = "0.14", optional = true, default-features = false, features = ["client", "http1", "stream", "tcp"] }
hyper-tls = { version = "0.5", optional = true }
hyper-proxy = { git = "https://github.com/e00E/hyper-proxy.git", branch = "upgrade-tokio", optional = true }
headers = { version = "0.3", optional = true }
## WS
async-native-tls = { version = "0.3", optional = true, default-features = false }
async-native-tls = { git = "https://github.com/e00E/async-native-tls.git", branch = "tokio-upgrade", optional = true, default-features = false }
async-std = { version = "1.6", optional = true }
tokio = { version = "0.2", optional = true, features = ["full"] }
tokio-util = { version = "0.6", optional = true, features = ["compat"] }
tokio = { version = "1.0", optional = true, features = ["full"] }
tokio-stream = { version = "0.1", optional = true }
tokio-util = { version = "0.6", optional = true, features = ["compat", "io"] }
soketto = { version = "0.4.1", optional = true }
## Shared (WS, HTTP)
url = { version = "2.1", optional = true }
Expand All @@ -58,19 +59,21 @@ hex-literal = "0.3"
wasm-bindgen-test = "0.3.19"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
tokio = { version = "0.2", features = ["full"] }
hyper = { version = "0.14", default-features = false, features = ["server"] }
tokio = { version = "1.0", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"] }

[features]
default = ["http-tls", "signing", "ws-tls-tokio", "ipc-tokio"]
eip-1193 = ["js-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures-timer/wasm-bindgen", "rand", "getrandom"]
http = ["hyper", "hyper-proxy", "url", "base64", "typed-headers"]
http = ["hyper", "hyper-proxy", "url", "base64", "headers"]
http-tls = ["hyper-tls", "http"]
signing = ["secp256k1"]
ws-tokio = ["soketto", "url", "tokio", "tokio-util"]
ws-async-std = ["soketto", "url", "async-std"]
ws-tls-tokio = ["async-native-tls", "async-native-tls/runtime-tokio", "ws-tokio"]
ws-tls-async-std = ["async-native-tls", "async-native-tls/runtime-async-std", "ws-async-std"]
ipc-tokio = ["tokio"]
ipc-tokio = ["tokio", "tokio-stream", "tokio-util"]
test = []

[workspace]
4 changes: 1 addition & 3 deletions src/transports/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ impl Http {
let mut proxy = hyper_proxy::Proxy::new(hyper_proxy::Intercept::All, uri);

if username != "" {
let credentials =
typed_headers::Credentials::basic(&username, &password).map_err(|_| Error::Internal)?;

let credentials = headers::Authorization::basic(&username, &password);
proxy.set_authorization(credentials);
}

Expand Down
31 changes: 16 additions & 15 deletions src/transports/ipc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! IPC transport

use crate::{api::SubscriptionId, helpers, BatchTransport, DuplexTransport, Error, RequestId, Result, Transport};
use futures::future::{join_all, JoinAll};
use futures::{
future::{join_all, JoinAll},
stream::StreamExt,
};
use jsonrpc_core as rpc;
use std::{
collections::BTreeMap,
Expand All @@ -11,11 +14,12 @@ use std::{
task::{Context, Poll},
};
use tokio::{
io::{reader_stream, AsyncWriteExt},
io::AsyncWriteExt,
net::UnixStream,
stream::StreamExt,
sync::{mpsc, oneshot},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::io::ReaderStream;

/// Unix Domain Sockets (IPC) transport.
#[derive(Debug, Clone)]
Expand All @@ -39,7 +43,7 @@ impl Ipc {
let id = Arc::new(AtomicUsize::new(1));
let (messages_tx, messages_rx) = mpsc::unbounded_channel();

tokio::spawn(run_server(stream, messages_rx));
tokio::spawn(run_server(stream, UnboundedReceiverStream::new(messages_rx)));

Ipc { id, messages_tx }
}
Expand Down Expand Up @@ -90,12 +94,12 @@ impl BatchTransport for Ipc {
}

impl DuplexTransport for Ipc {
type NotificationStream = mpsc::UnboundedReceiver<rpc::Value>;
type NotificationStream = UnboundedReceiverStream<rpc::Value>;

fn subscribe(&self, id: SubscriptionId) -> Result<Self::NotificationStream> {
let (tx, rx) = mpsc::unbounded_channel();
self.messages_tx.send(TransportMessage::Subscribe(id, tx))?;
Ok(rx)
Ok(UnboundedReceiverStream::new(rx))
}

fn unsubscribe(&self, id: SubscriptionId) -> Result<()> {
Expand Down Expand Up @@ -150,12 +154,12 @@ enum TransportMessage {
}

#[cfg(unix)]
async fn run_server(unix_stream: UnixStream, messages_rx: mpsc::UnboundedReceiver<TransportMessage>) -> Result<()> {
async fn run_server(unix_stream: UnixStream, messages_rx: UnboundedReceiverStream<TransportMessage>) -> Result<()> {
let (socket_reader, mut socket_writer) = unix_stream.into_split();
let mut pending_response_txs = BTreeMap::default();
let mut subscription_txs = BTreeMap::default();

let mut socket_reader = reader_stream(socket_reader);
let mut socket_reader = ReaderStream::new(socket_reader);
let mut messages_rx = messages_rx.fuse();
let mut read_buffer = vec![];

Expand Down Expand Up @@ -330,10 +334,7 @@ impl From<oneshot::error::RecvError> for Error {
mod test {
use super::*;
use serde_json::json;
use tokio::{
io::{reader_stream, AsyncWriteExt},
net::UnixStream,
};
use tokio::{io::AsyncWriteExt, net::UnixStream};

#[tokio::test]
async fn works_for_single_requests() {
Expand Down Expand Up @@ -370,7 +371,7 @@ mod test {
async fn eth_node_single(stream: UnixStream) {
let (rx, mut tx) = stream.into_split();

let mut rx = reader_stream(rx);
let mut rx = ReaderStream::new(rx);
if let Some(Ok(bytes)) = rx.next().await {
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

Expand Down Expand Up @@ -434,7 +435,7 @@ mod test {
async fn eth_node_batch(stream: UnixStream) {
let (rx, mut tx) = stream.into_split();

let mut rx = reader_stream(rx);
let mut rx = ReaderStream::new(rx);
if let Some(Ok(bytes)) = rx.next().await {
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

Expand Down Expand Up @@ -489,7 +490,7 @@ mod test {
async fn eth_node_partial_batches(stream: UnixStream) {
let (rx, mut tx) = stream.into_split();
let mut buf = vec![];
let mut rx = reader_stream(rx);
let mut rx = ReaderStream::new(rx);
while let Some(Ok(bytes)) = rx.next().await {
buf.extend(bytes);

Expand Down
63 changes: 10 additions & 53 deletions src/transports/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl WsServerTask {
select! {
msg = requests.next() => match msg {
Some(TransportMessage::Request { id, request, sender: tx }) => {
if pending.insert(id.clone(), tx).is_some() {
if pending.insert(id, tx).is_some() {
log::warn!("Replacing a pending request with id {:?}", id);
}
let res = sender.send_text(request).await;
Expand Down Expand Up @@ -459,6 +459,10 @@ pub mod compat {
/// Compatibility layer between async-std and tokio
#[cfg(feature = "ws-tokio")]
pub mod compat {
use std::io;
use tokio::io::AsyncRead;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

/// async-std compatible TcpStream.
pub type TcpStream = Compat<tokio::net::TcpStream>;
/// async-std compatible TcpListener.
Expand All @@ -470,62 +474,14 @@ pub mod compat {
#[cfg(not(feature = "ws-tls-tokio"))]
pub type TlsStream = TcpStream;

use std::{
io,
pin::Pin,
task::{Context, Poll},
};

/// Create new TcpStream object.
pub async fn raw_tcp_stream(addrs: String) -> io::Result<tokio::net::TcpStream> {
Ok(tokio::net::TcpStream::connect(addrs).await?)
}

/// Wrap given argument into compatibility layer.
pub fn compat<T>(t: T) -> Compat<T> {
Compat(t)
}

/// Compatibility layer.
pub struct Compat<T>(T);
impl<T: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for Compat<T> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
}
}

impl<T: tokio::io::AsyncWrite + Unpin> futures::AsyncWrite for Compat<T> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
}
}

impl<T: tokio::io::AsyncRead + Unpin> futures::AsyncRead for Compat<T> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
}
}

impl<T: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for Compat<T> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
}
pub fn compat<T: AsyncRead>(t: T) -> Compat<T> {
TokioAsyncReadCompatExt::compat(t)
}
}

Expand All @@ -538,6 +494,7 @@ mod tests {
StreamExt,
};
use soketto::handshake;
use tokio_stream::wrappers::TcpListenerStream;

#[test]
fn bounds_matching() {
Expand Down Expand Up @@ -566,8 +523,8 @@ mod tests {
assert_eq!(res.await, Ok(rpc::Value::String("x".into())));
}

async fn server(mut listener: compat::TcpListener, addr: &str) {
let mut incoming = listener.incoming();
async fn server(listener: compat::TcpListener, addr: &str) {
let mut incoming = TcpListenerStream::new(listener);
println!("Listening on: {}", addr);
while let Some(Ok(socket)) = incoming.next().await {
let socket = compat::compat(socket);
Expand Down

0 comments on commit 88b32b3

Please sign in to comment.