websocket example with dropshot #387

@jessfraz

Figured I'd leave this here as an example in case anyone else ever wants to implement websockets with dropshot. At the end of the code samples, I've included some commentary and a proposal for what really nice built-in websocket functionality might look like.

websocket.rs

Helper utilities for dealing with websockets.

//! Websocket specific functions.

use std::{
    borrow::Cow,
    pin::Pin,
    task::{Context, Poll},
};

use anyhow::Result;
use futures_util::{future, ready, FutureExt, Sink, Stream};
use tokio_tungstenite::{tungstenite::protocol, WebSocketStream};

/// A websocket `Stream` and `Sink`.
///
/// Ping messages sent from the client will be handled internally by replying with a Pong message.
/// Close messages need to be handled explicitly: usually by closing the `Sink` end of the
/// `WebSocket`.
///
/// **Note!**
/// Due to the nature of Rust futures, pings won't be handled until the read half of the
/// `WebSocket` is polled.
pub struct WebSocket {
    inner: WebSocketStream<hyper::upgrade::Upgraded>,
}

impl WebSocket {
    pub(crate) async fn from_raw_socket(
        upgraded: hyper::upgrade::Upgraded,
        role: protocol::Role,
        config: Option<protocol::WebSocketConfig>,
    ) -> Self {
        WebSocketStream::from_raw_socket(upgraded, role, config)
            .map(|inner| WebSocket { inner })
            .await
    }

    /// Gracefully close this websocket.
    #[allow(dead_code)]
    pub async fn close(mut self) -> Result<()> {
        future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
    }
}

impl Stream for WebSocket {
    type Item = Result<Message>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
            Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
            Some(Err(e)) => {
                tracing::debug!("websocket poll error: {}", e);
                Poll::Ready(Some(Err(e.into())))
            }
            None => {
                tracing::trace!("websocket closed");
                Poll::Ready(None)
            }
        }
    }
}

impl Sink<Message> for WebSocket {
    type Error = anyhow::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
            Ok(()) => Poll::Ready(Ok(())),
            Err(e) => Poll::Ready(Err(e.into())),
        }
    }

    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> {
        match Pin::new(&mut self.inner).start_send(item.inner) {
            Ok(()) => Ok(()),
            Err(e) => {
                tracing::debug!("websocket start_send error: {}", e);
                Err(e.into())
            }
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
            Ok(()) => Poll::Ready(Ok(())),
            Err(e) => Poll::Ready(Err(e.into())),
        }
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
            Ok(()) => Poll::Ready(Ok(())),
            Err(err) => {
                tracing::debug!("websocket close error: {}", err);
                Poll::Ready(Err(err.into()))
            }
        }
    }
}

impl std::fmt::Debug for WebSocket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WebSocket").finish()
    }
}

/// A WebSocket message.
///
/// This may eventually become a `#[non_exhaustive]` enum.
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
    inner: protocol::Message,
}

impl Message {
    /// Construct a new Text `Message`.
    pub fn text<S: Into<String>>(s: S) -> Message {
        Message {
            inner: protocol::Message::text(s),
        }
    }

    /// Construct a new Json `Message`.
    pub fn json<T>(j: T) -> Result<Message>
    where
        T: serde::Serialize,
    {
        Ok(Message {
            inner: protocol::Message::text(serde_json::to_string(&j)?),
        })
    }

    /// Construct a new Binary `Message`.
    pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
        Message {
            inner: protocol::Message::binary(v),
        }
    }

    /// Construct a new Ping `Message`.
    pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
        Message {
            inner: protocol::Message::Ping(v.into()),
        }
    }

    /// Construct a new Pong `Message`.
    ///
    /// Note that one rarely needs to manually construct a Pong message because the underlying tungstenite socket
    /// automatically responds to the Ping messages it receives. Manual construction might still be useful in some cases
    /// like in tests or to send unidirectional heartbeats.
    pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
        Message {
            inner: protocol::Message::Pong(v.into()),
        }
    }

    /// Construct the default Close `Message`.
    pub fn close() -> Message {
        Message {
            inner: protocol::Message::Close(None),
        }
    }

    /// Construct a Close `Message` with a code and reason.
    pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
        Message {
            inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
                code: protocol::frame::coding::CloseCode::from(code.into()),
                reason: reason.into(),
            })),
        }
    }

    /// Returns true if this message is a Text message.
    pub fn is_text(&self) -> bool {
        self.inner.is_text()
    }

    /// Returns true if this message is a Binary message.
    pub fn is_binary(&self) -> bool {
        self.inner.is_binary()
    }

    /// Returns true if this message is a Close message.
    pub fn is_close(&self) -> bool {
        self.inner.is_close()
    }

    /// Returns true if this message is a Ping message.
    pub fn is_ping(&self) -> bool {
        self.inner.is_ping()
    }

    /// Returns true if this message is a Pong message.
    pub fn is_pong(&self) -> bool {
        self.inner.is_pong()
    }

    /// Try to get the close frame (close code and reason).
    pub fn close_frame(&self) -> Option<(u16, &str)> {
        if let protocol::Message::Close(Some(ref close_frame)) = self.inner {
            Some((close_frame.code.into(), close_frame.reason.as_ref()))
        } else {
            None
        }
    }

    /// Try to get a reference to the string text, if this is a Text message.
    pub fn to_str(&self) -> Result<&str> {
        match self.inner {
            protocol::Message::Text(ref s) => Ok(s),
            _ => anyhow::bail!("not a text message"),
        }
    }

    /// Return the bytes of this message, if the message can contain data.
    pub fn as_bytes(&self) -> &[u8] {
        match self.inner {
            protocol::Message::Text(ref s) => s.as_bytes(),
            protocol::Message::Binary(ref v) => v,
            protocol::Message::Ping(ref v) => v,
            protocol::Message::Pong(ref v) => v,
            protocol::Message::Frame(ref frame) => frame.payload().as_slice(),
            protocol::Message::Close(_) => &[],
        }
    }

    /// Return the type as decoded json.
    pub fn as_json<T>(&self) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        serde_json::from_slice(self.as_bytes()).map_err(Into::into)
    }

    /// Destructure this message into binary data.
    pub fn into_bytes(self) -> Vec<u8> {
        self.inner.into_data()
    }
}

impl std::fmt::Debug for Message {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.inner, f)
    }
}

impl From<Message> for Vec<u8> {
    fn from(m: Message) -> Self {
        m.into_bytes()
    }
}
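For reference, this is roughly what `Message::close_with(code, reason)` and `close_frame()` correspond to on the wire. tungstenite builds and parses close frames itself, so the wrapper above never touches these bytes; the std-only sketch below (the function names are hypothetical, not part of tungstenite) just illustrates the Close payload layout from RFC 6455 section 5.5.1: a 2-byte big-endian status code followed by an optional UTF-8 reason.

```rust
/// Hypothetical helper: encode the payload of a WebSocket Close frame
/// (RFC 6455, section 5.5.1) as a 2-byte big-endian status code followed
/// by an optional UTF-8 reason.
fn encode_close_payload(code: u16, reason: &str) -> Vec<u8> {
    let mut payload = Vec::with_capacity(2 + reason.len());
    payload.extend_from_slice(&code.to_be_bytes());
    payload.extend_from_slice(reason.as_bytes());
    payload
}

/// Hypothetical helper: decode a Close payload back into `(code, reason)`,
/// mirroring the shape returned by `Message::close_frame()`.
/// An empty body carries no code, so this returns `None`; it also returns
/// `None` for a malformed 1-byte body or a reason that is not valid UTF-8.
fn decode_close_payload(payload: &[u8]) -> Option<(u16, &str)> {
    if payload.len() < 2 {
        return None;
    }
    let code = u16::from_be_bytes([payload[0], payload[1]]);
    let reason = std::str::from_utf8(&payload[2..]).ok()?;
    Some((code, reason))
}
```

For example, `encode_close_payload(1000, "bye")` yields `[0x03, 0xE8, b'b', b'y', b'e']`, since 1000 (normal closure) is `0x03E8`.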

sample endpoint

This is just an example endpoint that wraps the Docker attach endpoint, which also happens to be a websocket, and exposes it to the browser.

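use std::{ops::DerefMut, sync::Arc};

use dropshot::{endpoint, HttpError, RequestContext};
use futures_util::{SinkExt, StreamExt};
use headers::HeaderMapExt;
use tokio::io::AsyncWriteExt;

// `Context`, `common::enclose!`, `crate::types::WebsocketMessage`, and the
// `crate::server::websocket` module come from the surrounding application.

/**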
 * Attach to a docker container to create an interactive terminal.
 */
#[endpoint {
    method = GET,
    path = "/term",
    tags = ["term"],
}]
pub async fn create_term(
    rqctx: Arc<RequestContext<Arc<Context>>>,
) -> Result<http::response::Response<hyper::body::Body>, HttpError> {
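    // We want to parse the headers from the request context.
    let request = rqctx.request.lock().await;

    // We want to make sure we have the websocket upgrade header.
    // `anyhow::bail!` cannot be used here because the handler returns
    // `HttpError`, so reject bad requests with a 400 instead.
    let h: Option<headers::Connection> = request.headers().typed_get();
    if let Some(h) = h {
        if !h.contains(&http::header::UPGRADE) {
            return Err(HttpError::for_bad_request(
                None,
                "Connection header did not include 'upgrade'".to_string(),
            ));
        }
    } else {
        return Err(HttpError::for_bad_request(
            None,
            "Connection header not sent".to_string(),
        ));
    }

    // Now get the secure websocket key.
    let h: Option<headers::SecWebsocketKey> = request.headers().typed_get();
    let websocket_key = if let Some(h) = h {
        h
    } else {
        return Err(HttpError::for_bad_request(
            None,
            "Websocket key not sent".to_string(),
        ));
    };

    // Release the request lock; the spawned task below locks it again to
    // perform the upgrade.
    drop(request);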
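    let ctx = rqctx.context();

    // Create and start the term container. (`token` comes from authentication
    // code that is not shown in this example.)
    let container_id = ctx
        .create_ws_term_container(token)
        .await
        .map_err(|e| HttpError::for_internal_error(e.to_string()))?;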

    // We reply with:
    // - Status of `101 Switching Protocols`
    // - Header `connection: upgrade`
    // - Header `upgrade: websocket`
    // - Header `sec-websocket-accept` with the hash value of the received key.
    let mut response = http::Response::default();
    *response.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
    response.headers_mut().typed_insert(headers::Connection::upgrade());
    response.headers_mut().typed_insert(headers::Upgrade::websocket());
    response
        .headers_mut()
        .typed_insert(headers::SecWebsocketAccept::from(websocket_key));

    // Spawn a task to handle the websocket connection.
    tokio::spawn(common::enclose! { (rqctx, ctx) async move {
        let mut request = rqctx.request.lock().await;

        // Attach to the container.
        let bollard::container::AttachContainerResults { mut output, mut input } = ctx
            .docker
            .attach_container(
                &container_id,
                Some(bollard::container::AttachContainerOptions::<String> {
                    stdout: Some(true),
                    stderr: Some(true),
                    stdin: Some(true),
                    stream: Some(true),
                    ..Default::default()
                }),
            )
            .await?;


        // Use the hyper feature of upgrading a connection.
        match hyper::upgrade::on(request.deref_mut()).await {
            // If the upgrade succeeded...
            Ok(upgraded) => {
                // ...create a websocket stream from the upgraded connection.
                let ws = crate::server::websocket::WebSocket::from_raw_socket(
                    // Pass the upgraded object as the base layer stream of the Websocket.
                    upgraded,
                    tokio_tungstenite::tungstenite::protocol::Role::Server,
                    None,
                )
                .await;

                // We split the stream into a sink and a stream.
                let (mut ws_write, mut ws_read) = ws.split();

                // For every message we get on the container websocket, we want
                // to send it back out to our browser websocket.
                tokio::spawn(async move {
                    while let Some(byte) = output.next().await {
                        ws_write.send(crate::server::websocket::Message::json(crate::types::WebsocketMessage::Stdout{
                            data: match byte? {
                                bollard::container::LogOutput::StdErr{message} => {
                                    String::from_utf8(message.to_vec())?
                                }
                                bollard::container::LogOutput::StdOut{message} => {
                                    String::from_utf8(message.to_vec())?
                                }
                                bollard::container::LogOutput::StdIn{message} => {
                                    String::from_utf8(message.to_vec())?
                                }
                                bollard::container::LogOutput::Console{message} => {
                                    String::from_utf8(message.to_vec())?
                                }
                            },
                        })?).await?;
                    }

                    Ok::<(), anyhow::Error>(())
                });

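                // Every time we get a message from the browser, we need to handle it.
                while let Some(result) = ws_read.next().await {
                    match result {
                        Ok(msg) => {
                            // If the message is a close message, stop reading; the
                            // container is removed after the loop.
                            if msg.is_close() {
                                break;
                            }

                            // Pings are answered automatically by tungstenite, and
                            // pongs carry no JSON payload, so skip both.
                            if msg.is_ping() || msg.is_pong() {
                                continue;
                            }

                            // Parse the message as JSON. Log and skip anything that
                            // does not parse, rather than exiting the task early with
                            // `?` and never removing the container.
                            let m: crate::types::WebsocketMessage = match msg.as_json() {
                                Ok(m) => m,
                                Err(e) => {
                                    log::warn!("invalid browser websocket message: {}", e);
                                    continue;
                                }
                            };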

                            match m {
                                crate::types::WebsocketMessage::Stdin{data} => {
                                    // We want to send the message to the container.
                                    input.write_all(data.as_bytes()).await?;
                                }
                                crate::types::WebsocketMessage::Stdout{data} => {
                                    // We should not get stdout, since that is not a valid
                                    // incoming type, it is the type we send back to the
                                    // browser.
                                    log::warn!("Received stdout from browser: {}", data);
                                }
                                crate::types::WebsocketMessage::Resize{height, width} => {
                                    // Resize the container.
                                    ctx.docker.resize_container_tty(&container_id,
                                    bollard::container::ResizeContainerTtyOptions{
                                        height,
                                        width,
                                    }).await?;
                                }
                                crate::types::WebsocketMessage::KeepAlive => {
                                    // Do nothing.
                                    continue;
                                }
                            }
                        },
                        Err(e) => {
                            log::warn!("browser websocket message error: {}", e);
                            break;
                        }
                    };
                }

                // If we get here, the websocket connection was closed, either by the
                // user or because reading from the browser websocket failed.
                log::debug!("browser websocket connection closed");
            }
            Err(e) => {
                anyhow::bail!("trying to upgrade connection to websocket connection failed: {}", e);
            }
        }

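        // Once the browser connection is gone, remove the container.
        ctx.remove_container(&container_id).await?;

        // Annotate the error type so `?` inside this async block can infer it.
        Ok::<(), anyhow::Error>(())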
    }});

    Ok(response)
}

So my proposal for a fully integrated UX in dropshot would be an extension trait on `RequestContext` with an `upgrade` function. `upgrade` would do all the tedious work of making sure the required headers are present and correct, and it would take a generic async function to run once the connection is upgraded (like all of my logic above that runs after the upgrade). Does that make sense? I'm happy to implement it if it sounds good. I figure we will eventually need this, and it's even nicer as an end user if you don't have to write all the boilerplate.
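A rough, std-only sketch of the shape such an extension trait could take. Everything here is hypothetical: `FakeRequest` and `Upgraded` stand in for dropshot's request and `hyper::upgrade::Upgraded`, `WebsocketUpgradeExt` and `UpgradeError` are made-up names, and the tiny `block_on` exists only so the example runs without tokio. A real version would also compute the accept key, build the 101 response, and `tokio::spawn` a task that awaits `hyper::upgrade::on` before calling the handler.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// Hypothetical stand-in for the request behind dropshot's `RequestContext`.
pub struct FakeRequest {
    pub headers: Vec<(String, String)>,
}

/// Hypothetical stand-in for `hyper::upgrade::Upgraded`.
pub struct Upgraded;

#[derive(Debug, PartialEq)]
pub enum UpgradeError {
    MissingConnectionUpgrade,
    MissingUpgradeWebsocket,
    MissingKey,
}

/// The work to run once the connection has been upgraded.
pub type UpgradeTask = Pin<Box<dyn Future<Output = ()> + Send>>;

pub trait WebsocketUpgradeExt {
    /// Validate the handshake headers; on success return the client's
    /// `sec-websocket-key` and a task that runs `handler` on the upgraded socket.
    fn upgrade<F, Fut>(&self, handler: F) -> Result<(String, UpgradeTask), UpgradeError>
    where
        F: FnOnce(Upgraded) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static;
}

impl FakeRequest {
    /// Header names are case-insensitive, so compare them that way.
    fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }

    /// True if a comma-separated header value contains `token` (case-insensitively).
    fn has_token(&self, name: &str, token: &str) -> bool {
        self.header(name)
            .map_or(false, |v| v.split(',').any(|t| t.trim().eq_ignore_ascii_case(token)))
    }
}

impl WebsocketUpgradeExt for FakeRequest {
    fn upgrade<F, Fut>(&self, handler: F) -> Result<(String, UpgradeTask), UpgradeError>
    where
        F: FnOnce(Upgraded) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        if !self.has_token("connection", "upgrade") {
            return Err(UpgradeError::MissingConnectionUpgrade);
        }
        if !self.has_token("upgrade", "websocket") {
            return Err(UpgradeError::MissingUpgradeWebsocket);
        }
        let key = self.header("sec-websocket-key").ok_or(UpgradeError::MissingKey)?;
        // A real implementation would await `hyper::upgrade::on(..)` here first.
        let task: UpgradeTask = Box::pin(async move { handler(Upgraded).await });
        Ok((key.to_string(), task))
    }
}

/// Minimal single-future executor so the sketch runs without tokio.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}
```

Returning the task instead of spawning it keeps the sketch executor-agnostic; a real trait on `RequestContext` could just as well spawn it internally and hand back the 101 response.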
