Skip to content

Commit

Permalink
Implement Sink and Stream for WsClient (#907)
Browse files Browse the repository at this point in the history
* Switch WsClient from tokio's unbounded channel to futures unbounded channel

* Implement Sink for WsClient

* Implement Stream for WsClient

* Test sink and stream

* Use WsError for the Sink implementation
  • Loading branch information
FSMaxB committed Oct 30, 2021
1 parent 25eedf6 commit 1948044
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 10 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ all-features = true
async-compression = { version = "0.3.7", features = ["brotli", "deflate", "gzip", "tokio"], optional = true }
bytes = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
futures-channel = { version = "0.3.17", features = ["sink"]}
headers = "0.3"
http = "0.2"
hyper = { version = "0.14", features = ["stream", "server", "http1", "tcp", "client"] }
Expand Down
75 changes: 65 additions & 10 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ use std::net::SocketAddr;
#[cfg(feature = "websocket")]
use std::pin::Pin;
#[cfg(feature = "websocket")]
use std::task::Context;
#[cfg(feature = "websocket")]
use std::task::{self, Poll};

use bytes::Bytes;
#[cfg(feature = "websocket")]
use futures_channel::mpsc;
#[cfg(feature = "websocket")]
use futures_util::StreamExt;
use futures_util::{future, FutureExt, TryFutureExt};
use http::{
Expand All @@ -102,15 +106,17 @@ use http::{
use serde::Serialize;
use serde_json;
#[cfg(feature = "websocket")]
use tokio::sync::{mpsc, oneshot};
#[cfg(feature = "websocket")]
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio::sync::oneshot;

use crate::filter::Filter;
#[cfg(feature = "websocket")]
use crate::filters::ws::Message;
use crate::reject::IsReject;
use crate::reply::Reply;
use crate::route::{self, Route};
use crate::Request;
#[cfg(feature = "websocket")]
use crate::{Sink, Stream};

use self::inner::OneOrTuple;

Expand Down Expand Up @@ -484,9 +490,8 @@ impl WsBuilder {
F::Error: IsReject + Send,
{
let (upgraded_tx, upgraded_rx) = oneshot::channel();
let (wr_tx, wr_rx) = mpsc::unbounded_channel();
let wr_rx = UnboundedReceiverStream::new(wr_rx);
let (rd_tx, rd_rx) = mpsc::unbounded_channel();
let (wr_tx, wr_rx) = mpsc::unbounded();
let (rd_tx, rd_rx) = mpsc::unbounded();

tokio::spawn(async move {
use tokio_tungstenite::tungstenite::protocol;
Expand Down Expand Up @@ -546,7 +551,7 @@ impl WsBuilder {
Ok(m) => future::ready(!m.is_close()),
})
.for_each(move |item| {
rd_tx.send(item).expect("ws receive error");
rd_tx.unbounded_send(item).expect("ws receive error");
future::ready(())
});

Expand All @@ -573,13 +578,13 @@ impl WsClient {

/// Send a websocket message to the server.
pub async fn send(&mut self, msg: crate::ws::Message) {
self.tx.send(msg).unwrap();
self.tx.unbounded_send(msg).unwrap();
}

/// Receive a websocket message from the server.
pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
self.rx
.recv()
.next()
.await
.map(|result| result.map_err(WsError::new))
.unwrap_or_else(|| {
Expand All @@ -591,7 +596,7 @@ impl WsClient {
/// Assert the server has closed the connection.
pub async fn recv_closed(&mut self) -> Result<(), WsError> {
self.rx
.recv()
.next()
.await
.map(|result| match result {
Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))),
Expand All @@ -602,6 +607,11 @@ impl WsClient {
Ok(())
})
}

fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
let this = Pin::into_inner(self);
Pin::new(&mut this.tx)
}
}

#[cfg(feature = "websocket")]
Expand All @@ -611,6 +621,51 @@ impl fmt::Debug for WsClient {
}
}

#[cfg(feature = "websocket")]
impl Sink<crate::ws::Message> for WsClient {
type Error = WsError;

fn poll_ready(
self: Pin<&mut Self>,
context: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.pinned_tx().poll_ready(context).map_err(WsError::new)
}

fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
self.pinned_tx().start_send(message).map_err(WsError::new)
}

fn poll_flush(
self: Pin<&mut Self>,
context: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.pinned_tx().poll_flush(context).map_err(WsError::new)
}

fn poll_close(
self: Pin<&mut Self>,
context: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.pinned_tx().poll_close(context).map_err(WsError::new)
}
}

#[cfg(feature = "websocket")]
impl Stream for WsClient {
type Item = Result<crate::ws::Message, WsError>;

fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = Pin::into_inner(self);
let rx = Pin::new(&mut this.rx);
match rx.poll_next(context) {
Poll::Ready(Some(result)) => Poll::Ready(Some(result.map_err(WsError::new))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}

// ===== impl WsError =====

#[cfg(feature = "websocket")]
Expand Down
15 changes: 15 additions & 0 deletions tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@ async fn binary() {
assert_eq!(msg.as_bytes(), &b"bonk"[..]);
}

#[tokio::test]
async fn wsclient_sink_and_stream() {
let _ = pretty_env_logger::try_init();

let mut client = warp::test::ws()
.handshake(ws_echo())
.await
.expect("handshake");

let message = warp::ws::Message::text("hello");
SinkExt::send(&mut client, message.clone()).await.unwrap();
let received_message = client.next().await.unwrap().unwrap();
assert_eq!(message, received_message);
}

#[tokio::test]
async fn close_frame() {
let _ = pretty_env_logger::try_init();
Expand Down

0 comments on commit 1948044

Please sign in to comment.