Skip to content

Commit

Permalink
Allow closing the connecting by returning ControlFlow::Break (#33)
Browse files Browse the repository at this point in the history
* Specify what should happen when `ControlFlow::Break` is returned

* Implement `ControlFlow::Break` for `native_tungstenite`

* Add placeholder code for native_tungstenite_tokio backend

Not sure how to implement it there

* Start work on the web backend

* Fix borrow checker stuff

* Use the same close logic everywhere

* Make `close` non-fallible

* Add warnings for tokio backend

* Fix warning

* Add warning about the tokio backend sucking
  • Loading branch information
emilk committed Apr 18, 2024
1 parent 8d99d31 commit e145ce5
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 117 deletions.
18 changes: 0 additions & 18 deletions check.sh

This file was deleted.

2 changes: 2 additions & 0 deletions ewebsock/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ tls = ["tungstenite/rustls-tls-webpki-roots"]
## This adds a lot of dependencies,
## but may yield lower latency and CPU usage
## when using `ws_connect`.
##
## Will ignore any `ControlFlow::Break` returned from the `on_event` callback.
tokio = [
"dep:async-stream",
"dep:futures",
Expand Down
20 changes: 17 additions & 3 deletions ewebsock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#[cfg(not(feature = "tokio"))]
mod native_tungstenite;

use std::ops::ControlFlow;

#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(feature = "tokio"))]
pub use native_tungstenite::*;
Expand Down Expand Up @@ -98,9 +100,9 @@ impl WsReceiver {
let on_event = Box::new(move |event| {
wake_up(); // wake up UI thread
if tx.send(event).is_ok() {
std::ops::ControlFlow::Continue(())
ControlFlow::Continue(())
} else {
std::ops::ControlFlow::Break(())
ControlFlow::Break(())
}
});
let ws_receiver = WsReceiver { rx };
Expand All @@ -119,7 +121,7 @@ pub type Error = String;
/// Short for `Result<T, ewebsock::Error>`.
pub type Result<T> = std::result::Result<T, Error>;

pub(crate) type EventHandler = Box<dyn Send + Fn(WsEvent) -> std::ops::ControlFlow<()>>;
pub(crate) type EventHandler = Box<dyn Send + Fn(WsEvent) -> ControlFlow<()>>;

/// Options for a connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand All @@ -143,6 +145,9 @@ impl Default for Options {

/// Connect to the given URL, and return a sender and receiver.
///
/// If `on_event` returns [`ControlFlow::Break`], the connection will be closed
/// without calling `on_event` again.
///
/// This is a wrapper around [`ws_connect`].
///
/// # Errors
Expand All @@ -161,6 +166,9 @@ pub fn connect(url: impl Into<String>, options: Options) -> Result<(WsSender, Ws
///
/// This allows you to wake up the UI thread, for instance.
///
/// If `on_event` returns [`ControlFlow::Break`], the connection will be closed
/// without calling `on_event` again.
///
/// This is a wrapper around [`ws_connect`].
///
/// # Errors
Expand All @@ -180,6 +188,9 @@ pub fn connect_with_wakeup(

/// Connect and call the given event handler on each received event.
///
/// If `on_event` returns [`ControlFlow::Break`], the connection will be closed
/// without calling `on_event` again.
///
/// See [`crate::connect`] for a more high-level version.
///
/// # Errors
Expand All @@ -196,6 +207,9 @@ pub fn ws_connect(url: String, options: Options, on_event: EventHandler) -> Resu
///
/// This can be slightly more efficient when you don't need to send messages.
///
/// If `on_event` returns [`ControlFlow::Break`], the connection will be closed
/// without calling `on_event` again.
///
/// # Errors
/// * On native: failure to spawn receiver thread.
/// * On web: failure to use `WebSocket` API.
Expand Down
72 changes: 47 additions & 25 deletions ewebsock/src/native_tungstenite.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#![allow(deprecated)] // TODO(emilk): Remove when we update tungstenite

use std::sync::mpsc::{Receiver, TryRecvError};
use std::{
ops::ControlFlow,
sync::mpsc::{Receiver, TryRecvError},
};

use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

Expand All @@ -13,9 +16,7 @@ pub struct WsSender {

impl Drop for WsSender {
fn drop(&mut self) {
if let Err(err) = self.close() {
log::warn!("Failed to close web-socket: {err:?}");
}
self.close();
}
}

Expand All @@ -32,16 +33,11 @@ impl WsSender {
/// Close the connection.
///
/// This is called automatically when the sender is dropped.
///
/// # Errors
/// This should never fail, except _maybe_ on Web.
#[allow(clippy::unnecessary_wraps)] // To keep the same signature as the Web version
pub fn close(&mut self) -> Result<()> {
pub fn close(&mut self) {
if self.tx.is_some() {
log::debug!("Closing WebSocket");
}
self.tx = None;
Ok(())
}

/// Forget about this sender without closing the connection.
Expand Down Expand Up @@ -90,33 +86,46 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
response.headers()
);

on_event(WsEvent::Opened);
let control = on_event(WsEvent::Opened);
if control.is_break() {
log::trace!("Closing connection due to Break");
return socket
.close(None)
.map_err(|err| format!("Failed to close connection: {err}"));
}

loop {
match socket.read_message() {
let control = match socket.read_message() {
Ok(incoming_msg) => match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)));
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)));
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)));
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)));
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("WebSocket close received: {close:?}");
return Ok(());
}
tungstenite::protocol::Message::Frame(_) => {}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
},
Err(err) => {
return Err(format!("read: {err}"));
}
};

if control.is_break() {
log::trace!("Closing connection due to Break");
return socket
.close(None)
.map_err(|err| format!("Failed to close connection: {err}"));
}

std::thread::sleep(std::time::Duration::from_millis(10));
Expand Down Expand Up @@ -172,7 +181,13 @@ pub fn ws_connect_blocking(
response.headers()
);

on_event(WsEvent::Opened);
let control = on_event(WsEvent::Opened);
if control.is_break() {
log::trace!("Closing connection due to Break");
return socket
.close(None)
.map_err(|err| format!("Failed to close connection: {err}"));
}

match socket.get_mut() {
tungstenite::stream::MaybeTlsStream::Plain(stream) => stream.set_nonblocking(true),
Expand Down Expand Up @@ -216,38 +231,45 @@ pub fn ws_connect_blocking(
Err(TryRecvError::Empty) => {}
};

match socket.read_message() {
let control = match socket.read_message() {
Ok(incoming_msg) => {
did_work = true;
match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)));
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)));
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)));
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)));
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("Close received: {close:?}");
return Ok(());
}
tungstenite::protocol::Message::Frame(_) => {}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
}
}
Err(tungstenite::Error::Io(io_err))
if io_err.kind() == std::io::ErrorKind::WouldBlock =>
{
// Ignore
ControlFlow::Continue(()) // Ignore
}
Err(err) => {
return Err(format!("read: {err}"));
}
};

if control.is_break() {
log::trace!("Closing connection due to Break");
return socket
.close(None)
.map_err(|err| format!("Failed to close connection: {err}"));
}

if !did_work {
Expand Down
44 changes: 21 additions & 23 deletions ewebsock/src/native_tungstenite_tokio.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::ControlFlow;

use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

/// This is how you send [`WsMessage`]s to the server.
Expand All @@ -9,9 +11,7 @@ pub struct WsSender {

impl Drop for WsSender {
fn drop(&mut self) {
if let Err(err) = self.close() {
log::warn!("Failed to close web-socket: {err:?}");
}
self.close();
}
}

Expand All @@ -28,16 +28,11 @@ impl WsSender {
/// Close the connection.
///
/// This is called automatically when the sender is dropped.
///
/// # Errors
/// This should never fail, except _maybe_ on Web.
#[allow(clippy::unnecessary_wraps)] // To keep the same signature as the Web version
pub fn close(&mut self) -> Result<()> {
pub fn close(&mut self) {
if self.tx.is_some() {
log::debug!("Closing WebSocket");
}
self.tx = None;
Ok(())
}

/// Forget about this sender without closing the connection.
Expand All @@ -57,7 +52,7 @@ async fn ws_connect_async(

let config = tungstenite::protocol::WebSocketConfig::from(options);
let disable_nagle = false; // God damn everyone who adds negations to the names of their variables
let (ws_stream, _) = match tokio_tungstenite::connect_async_with_config(
let (ws_stream, _response) = match tokio_tungstenite::connect_async_with_config(
url,
Some(config),
disable_nagle,
Expand All @@ -72,7 +67,11 @@ async fn ws_connect_async(
};

log::info!("WebSocket handshake has been successfully completed");
on_event(WsEvent::Opened);

let control = on_event(WsEvent::Opened);
if control.is_break() {
log::warn!("ControlFlow::Break not implemented for the tungstenite tokio backend");
}

let (write, read) = ws_stream.split();

Expand All @@ -88,29 +87,28 @@ async fn ws_connect_async(
.forward(write);

let reader = read.for_each(move |event| {
match event {
let control = match event {
Ok(message) => match message {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)));
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)));
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)));
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)));
}
tungstenite::protocol::Message::Close(_) => {
on_event(WsEvent::Closed);
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Frame(_) => {}
tungstenite::protocol::Message::Close(_) => on_event(WsEvent::Closed),
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
},
Err(err) => {
on_event(WsEvent::Error(err.to_string()));
}
Err(err) => on_event(WsEvent::Error(err.to_string())),
};
if control.is_break() {
log::warn!("ControlFlow::Break not implemented for the tungstenite tokio backend");
}
async {}
});

Expand Down

0 comments on commit e145ce5

Please sign in to comment.