Skip to content

Commit

Permalink
Add proxy waiters to allow the Stream to trigger AsyncWrite oper… (#6)
Browse files Browse the repository at this point in the history
* Updates for tokio master

* More logging

* Remove trace on start send

* Update to tokio 0.2 release

* Run travis with stable Rust toolchain again

* Add proxy waiters to allow the Stream to trigger AsyncWrite operations

As a side effect also gets rid of unsafe code and raw pointers.

We have the problem that external read operations (i.e. the Stream impl)
can trigger both read (AsyncRead) and write (AsyncWrite) operations on
the underyling stream. At the same time write operations (i.e. the Sink
impl) can trigger write operations (AsyncWrite) too.
Both the Stream and the Sink can be used on two different tasks, but it
is required that AsyncRead and AsyncWrite are only ever used by a single
task (or better: with a single waker) at a time.

Doing otherwise would cause only the latest waker to be remembered, so
in our case either the Stream or the Sink impl would potentially wait
forever to be woken up because only the other one would've been woken
up.

To solve this we implement a waker proxy that has two slots (one for
read, one for write) to store wakers. One waker proxy is always passed
to the AsyncRead, the other to AsyncWrite so that they will only ever
have to store a single waker, but internally we dispatch any wakeups to
up to two actual wakers (one from the Sink impl and one from the Stream
impl).

* Remove WebSocketStream::send()

The same functionality is already provided via StreamExt::send() and
unlike the custom implementation it handles WouldBlock correctly and not
as an error.

* Remove custom WebSocketStream::close() implementation

Instead simply send an owned Closed message. This simplifies the code
and among other things also handles errors like WouldBlock correctly
instead of handling them like a real error.

* Update dependencies

* Remove unused #[pin_project] attribute

* Use #[pin_project] in stream implementation to get rid of remaining unsafe code

Now the whole crate has no unsafe code left.

* Depend only on the necessary tokio features instead of "full"

Reduces the number of dependencies and code to compile.

Based on a PR by Maksym Vorobiov.

* Remove unused bytes dependency
  • Loading branch information
sdroege authored and dbcfd committed Nov 30, 2019
1 parent 320cf71 commit 6e566ec
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 197 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: rust
rust:
- nightly-2019-09-05
- stable

before_script:
- export PATH="$PATH:$HOME/.cargo/bin"
Expand Down
20 changes: 8 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,29 @@ edition = "2018"

[features]
default = ["connect"]
connect = ["stream"]
tls = ["native-tls", "stream", "tungstenite/tls"]
stream = ["bytes"]
connect = ["stream", "tokio/net"]
tls = ["native-tls", "tokio-tls", "stream", "tungstenite/tls"]
stream = []

[dependencies]
log = "0.4"
futures = "0.3"
pin-project = "0.4"
tokio = { git = "https://github.com/tokio-rs/tokio.git", branch = "master", features = ["full"] }
tokio = "0.2"

[dependencies.tungstenite]
version = "0.9.2"
default-features = false

[dependencies.bytes]
optional = true
version = "0.4.8"

[dependencies.native-tls]
optional = true
version = "0.2.0"

#[dependencies.tokio-tls]
#optional = true
#git = "https://github.com/tokio-rs/tokio.git"
#branch = "master"
[dependencies.tokio-tls]
optional = true
version = "0.3"

[dev-dependencies]
tokio = { version = "0.2", features = ["net", "macros", "rt-threaded", "io-std", "io-util"] }
url = "2.0.0"
env_logger = "0.6"
2 changes: 1 addition & 1 deletion examples/autobahn-client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::StreamExt;
use futures::{SinkExt, StreamExt};
use log::*;
use tokio_tungstenite::{connect_async, tungstenite::Result};
use url::Url;
Expand Down
2 changes: 1 addition & 1 deletion examples/autobahn-server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::StreamExt;
use futures::{SinkExt, StreamExt};
use log::*;
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::net::{TcpListener, TcpStream};
Expand Down
2 changes: 1 addition & 1 deletion examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
use std::env;
use std::io::{self, Write};

use futures::StreamExt;
use futures::{SinkExt, StreamExt};
use log::*;
use tungstenite::protocol::Message;

Expand Down
2 changes: 1 addition & 1 deletion examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::env;
use std::io::Error;

use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures::StreamExt;
use futures::{SinkExt, StreamExt};
use log::*;
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::net::{TcpListener, TcpStream};
Expand Down
138 changes: 107 additions & 31 deletions src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,55 +3,131 @@ use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::task;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::{Error as WsError, WebSocket};
use tungstenite::Error as WsError;

pub(crate) trait HasContext {
fn set_context(&mut self, context: (bool, *mut ()));
pub(crate) enum ContextWaker {
Read,
Write,
}

#[derive(Debug)]
pub(crate) struct AllowStd<S> {
pub(crate) inner: S,
pub(crate) context: (bool, *mut ()),
inner: S,
// We have the problem that external read operations (i.e. the Stream impl)
// can trigger both read (AsyncRead) and write (AsyncWrite) operations on
// the underyling stream. At the same time write operations (i.e. the Sink
// impl) can trigger write operations (AsyncWrite) too.
// Both the Stream and the Sink can be used on two different tasks, but it
// is required that AsyncRead and AsyncWrite are only ever used by a single
// task (or better: with a single waker) at a time.
//
// Doing otherwise would cause only the latest waker to be remembered, so
// in our case either the Stream or the Sink impl would potentially wait
// forever to be woken up because only the other one would've been woken
// up.
//
// To solve this we implement a waker proxy that has two slots (one for
// read, one for write) to store wakers. One waker proxy is always passed
// to the AsyncRead, the other to AsyncWrite so that they will only ever
// have to store a single waker, but internally we dispatch any wakeups to
// up to two actual wakers (one from the Sink impl and one from the Stream
// impl).
//
// write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
// AsyncRead. The read_waker slots of both are used for the Stream impl
// (and handshaking), the write_waker slots for the Sink impl.
write_waker_proxy: Arc<WakerProxy>,
read_waker_proxy: Arc<WakerProxy>,
}

// Internal trait used only in the Handshake module for registering
// the waker for the context used during handshaking. We're using the
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
}

impl<S> HasContext for AllowStd<S> {
fn set_context(&mut self, context: (bool, *mut ())) {
self.context = context;
impl<S> SetWaker for AllowStd<S> {
fn set_waker(&self, waker: &task::Waker) {
self.set_waker(ContextWaker::Read, waker);
}
}

pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>);
impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
read_waker_proxy: Default::default(),
};

// Register the handshake waker as read waker for both proxies,
// see also the SetWaker trait.
res.write_waker_proxy.read_waker.register(waker);
res.read_waker_proxy.read_waker.register(waker);

impl<S> Drop for Guard<'_, S> {
fn drop(&mut self) {
trace!("{}:{} Guard.drop", file!(), line!());
(self.0).get_mut().context = (true, std::ptr::null_mut());
res
}

// Set the read or write waker for our proxies.
//
// Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
// impl on the WebSocketStream.
// Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
self.read_waker_proxy.read_waker.register(waker);
}
ContextWaker::Write => {
self.write_waker_proxy.write_waker.register(waker);
self.read_waker_proxy.write_waker.register(waker);
}
}
}
}

// *mut () context is neither Send nor Sync
unsafe impl<S: Send> Send for AllowStd<S> {}
unsafe impl<S: Sync> Sync for AllowStd<S> {}
// Proxy Waker that we pass to the internal AsyncRead/Write of the
// stream underlying the websocket. We have two slots here for the
// actual wakers to allow external read operations to trigger both
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
}

impl task::ArcWake for WakerProxy {
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.read_waker.wake();
arc_self.write_waker.wake();
}
}

impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> Poll<std::io::Result<R>>
fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());
unsafe {
if !self.context.0 {
//was called by start_send without context
return Poll::Pending;
}
assert!(!self.context.1.is_null());
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
let waker = match kind {
ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
};
let mut context = task::Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))
}

pub(crate) fn get_mut(&mut self) -> &mut S {
Expand All @@ -69,7 +145,7 @@ where
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!());
match self.with_context(|ctx, stream| {
match self.with_context(ContextWaker::Read, |ctx, stream| {
trace!(
"{}:{} Read.with_context read -> poll_read",
file!(),
Expand All @@ -89,7 +165,7 @@ where
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!());
match self.with_context(|ctx, stream| {
match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!(
"{}:{} Write.with_context write -> poll_write",
file!(),
Expand All @@ -104,7 +180,7 @@ where

fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(|ctx, stream| {
match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!(
"{}:{} Write.with_context flush -> poll_flush",
file!(),
Expand All @@ -122,9 +198,9 @@ pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
debug!("WouldBlock");
trace!("WouldBlock");
Poll::Pending
},
}
Err(e) => Poll::Ready(Err(e)),
}
}
4 changes: 2 additions & 2 deletions src/connect.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Connection helper.
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::tcp::TcpStream;
use tokio::net::TcpStream;

use tungstenite::client::url_mode;
use tungstenite::handshake::client::Response;
Expand All @@ -13,7 +13,7 @@ pub(crate) mod encryption {
use native_tls::TlsConnector;
use tokio_tls::{TlsConnector as TokioTlsConnector, TlsStream};

use tokio_io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};

use tungstenite::stream::Mode;
use tungstenite::Error;
Expand Down
32 changes: 9 additions & 23 deletions src/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::compat::{AllowStd, HasContext};
use crate::compat::{AllowStd, SetWaker};
use crate::WebSocketStream;
use log::*;
use pin_project::pin_project;
Expand Down Expand Up @@ -45,10 +45,7 @@ where
.take()
.expect("future polled after completion");
trace!("Setting context when skipping handshake");
let stream = AllowStd {
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
let stream = AllowStd::new(inner.stream, ctx.waker());

Poll::Ready((inner.f)(stream))
}
Expand All @@ -71,7 +68,7 @@ struct StartedHandshakeFutureInner<F, S> {
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
Role::InternalStream: SetWaker,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
Expand Down Expand Up @@ -125,7 +122,7 @@ where
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: HasContext,
Role::InternalStream: SetWaker,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
Expand All @@ -135,18 +132,11 @@ where
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake");
let stream = AllowStd {
inner: inner.stream,
context: (true, ctx as *mut _ as *mut ()),
};
let stream = AllowStd::new(inner.stream, ctx.waker());

match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Interrupted(mid)) => Poll::Ready(Ok(StartedHandshake::Mid(mid))),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
}
}
Expand All @@ -155,7 +145,7 @@ where
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: HasContext,
Role::InternalStream: SetWaker,
{
type Output = Result<Role::FinalResult, Error<Role>>;

Expand All @@ -165,16 +155,12 @@ where

let machine = s.get_mut();
trace!("Setting context in handshake");
machine
.get_mut()
.set_context((true, cx as *mut _ as *mut ()));
machine.get_mut().set_waker(cx.waker());

match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context((true, std::ptr::null_mut()));
Err(Error::Interrupted(mid)) => {
*this.0 = Some(mid);
Poll::Pending
}
Expand Down

0 comments on commit 6e566ec

Please sign in to comment.