Skip to content

Commit

Permalink
Merge pull request #215 from CBenoit/master
Browse files Browse the repository at this point in the history
Fix `poll_close` returning WouldBlock error kind
  • Loading branch information
daniel-abramov committed Mar 3, 2022
2 parents b25e863 + 75b5bfd commit cb4b133
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ where
#[derive(Debug)]
pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>,
closing: bool,
}

impl<S> WebSocketStream<S> {
Expand Down Expand Up @@ -215,7 +216,7 @@ impl<S> WebSocketStream<S> {
}

pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
WebSocketStream { inner: ws }
WebSocketStream { inner: ws, closing: false }
}

fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
Expand Down Expand Up @@ -294,9 +295,7 @@ where
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match (*self).with_context(None, |s| s.write_message(item)) {
Ok(()) => Ok(()),
Err(::tungstenite::Error::Io(ref err))
if err.kind() == std::io::ErrorKind::WouldBlock =>
{
Err(::tungstenite::Error::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued
// isn't an error.
Ok(())
Expand All @@ -313,9 +312,21 @@ where
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) {
let res = if self.closing {
// After queueing it, we call `write_pending` to drive the close handshake to completion.
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.write_pending())
} else {
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
};

match res {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
trace!("WouldBlock");
self.closing = true;
Poll::Pending
}
Err(err) => {
debug!("websocket close error: {}", err);
Poll::Ready(Err(err))
Expand Down

0 comments on commit cb4b133

Please sign in to comment.