-
Notifications
You must be signed in to change notification settings - Fork 94
Description
I figured I'd leave this here as an example in case anyone else wants to implement websockets with dropshot. At the end of the code samples, I've included a proposal for what nice built-in websocket support in dropshot could look like, along with some commentary.
websocket.rs
Helper utilities for dealing with websockets.
//! Websocket specific functions.
use std::{
borrow::Cow,
pin::Pin,
task::{Context, Poll},
};
use anyhow::Result;
use futures_util::{future, ready, FutureExt, Sink, Stream};
use tokio_tungstenite::{tungstenite::protocol, WebSocketStream};
/// A websocket `Stream` and `Sink`.
///
/// Ping messages sent from the client are handled internally by replying with a Pong message.
/// Close messages must be handled explicitly, usually by closing the `Sink` half of the
/// `WebSocket`.
///
/// **Note:**
/// because Rust futures are lazy, pings are only handled while the read half of the
/// `WebSocket` is being polled.
pub struct WebSocket {
    // tungstenite stream running over the connection returned by hyper's upgrade.
    inner: WebSocketStream<hyper::upgrade::Upgraded>,
}
impl WebSocket {
    /// Wraps an already-upgraded hyper connection in a tungstenite websocket stream.
    ///
    /// `role` selects the side of the protocol this end plays (e.g. `Role::Server`), and
    /// `config` is an optional tungstenite `WebSocketConfig` passed through unchanged.
    pub(crate) async fn from_raw_socket(
        upgraded: hyper::upgrade::Upgraded,
        role: protocol::Role,
        config: Option<protocol::WebSocketConfig>,
    ) -> Self {
        let pending = WebSocketStream::from_raw_socket(upgraded, role, config);
        pending.map(|stream| Self { inner: stream }).await
    }

    /// Gracefully closes this websocket by driving `Sink::poll_close` to completion.
    // Public API that may have no caller inside this crate; silence the lint for it.
    #[allow(dead_code)]
    pub async fn close(mut self) -> Result<()> {
        let mut pinned = Pin::new(&mut self);
        future::poll_fn(move |cx| pinned.as_mut().poll_close(cx)).await
    }
}
impl Stream for WebSocket {
    type Item = Result<Message>;

    /// Polls the wrapped tungstenite stream.
    ///
    /// Each frame is wrapped in a [`Message`], tungstenite errors are converted into
    /// `anyhow::Error`, and end-of-stream is reported as `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let next: Self::Item = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
            None => {
                tracing::trace!("websocket closed");
                return Poll::Ready(None);
            }
            Some(Err(e)) => {
                tracing::debug!("websocket poll error: {}", e);
                Err(e.into())
            }
            Some(Ok(raw)) => Ok(Message { inner: raw }),
        };
        Poll::Ready(Some(next))
    }
}
impl Sink<Message> for WebSocket {
type Error = anyhow::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(e.into())),
}
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> {
match Pin::new(&mut self.inner).start_send(item.inner) {
Ok(()) => Ok(()),
Err(e) => {
tracing::debug!("websocket start_send error: {}", e);
Err(e.into())
}
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(e.into())),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => {
tracing::debug!("websocket close error: {}", err);
Poll::Ready(Err(err.into()))
}
}
}
}
impl std::fmt::Debug for WebSocket {
    /// Prints only the type name; the wrapped stream is intentionally not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WebSocket")
    }
}
/// A WebSocket message.
///
/// Thin wrapper around tungstenite's `protocol::Message`. The field is private, so the
/// wrapped representation is not part of this type's public API.
///
/// NOTE(review): an earlier note said this "will likely become a `non-exhaustive` enum
/// once that language feature has stabilized". `#[non_exhaustive]` has been stable since
/// Rust 1.40, so that caveat is outdated.
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
    inner: protocol::Message,
}
impl Message {
/// Construct a new Text `Message`.
pub fn text<S: Into<String>>(s: S) -> Message {
Message {
inner: protocol::Message::text(s),
}
}
/// Construct a new Json `Message`.
pub fn json<T>(j: T) -> Result<Message>
where
T: serde::Serialize,
{
Ok(Message {
inner: protocol::Message::text(serde_json::to_string(&j)?),
})
}
/// Construct a new Binary `Message`.
pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::binary(v),
}
}
/// Construct a new Ping `Message`.
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Ping(v.into()),
}
}
/// Construct a new Pong `Message`.
///
/// Note that one rarely needs to manually construct a Pong message because the underlying tungstenite socket
/// automatically responds to the Ping messages it receives. Manual construction might still be useful in some cases
/// like in tests or to send unidirectional heartbeats.
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Pong(v.into()),
}
}
/// Construct the default Close `Message`.
pub fn close() -> Message {
Message {
inner: protocol::Message::Close(None),
}
}
/// Construct a Close `Message` with a code and reason.
pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
Message {
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
code: protocol::frame::coding::CloseCode::from(code.into()),
reason: reason.into(),
})),
}
}
/// Returns true if this message is a Text message.
pub fn is_text(&self) -> bool {
self.inner.is_text()
}
/// Returns true if this message is a Binary message.
pub fn is_binary(&self) -> bool {
self.inner.is_binary()
}
/// Returns true if this message a is a Close message.
pub fn is_close(&self) -> bool {
self.inner.is_close()
}
/// Returns true if this message is a Ping message.
pub fn is_ping(&self) -> bool {
self.inner.is_ping()
}
/// Returns true if this message is a Pong message.
pub fn is_pong(&self) -> bool {
self.inner.is_pong()
}
/// Try to get the close frame (close code and reason)
pub fn close_frame(&self) -> Option<(u16, &str)> {
if let protocol::Message::Close(Some(ref close_frame)) = self.inner {
Some((close_frame.code.into(), close_frame.reason.as_ref()))
} else {
None
}
}
/// Try to get a reference to the string text, if this is a Text message.
pub fn to_str(&self) -> Result<&str> {
match self.inner {
protocol::Message::Text(ref s) => Ok(s),
_ => anyhow::bail!("not a text message"),
}
}
/// Return the bytes of this message, if the message can contain data.
pub fn as_bytes(&self) -> &[u8] {
match self.inner {
protocol::Message::Text(ref s) => s.as_bytes(),
protocol::Message::Binary(ref v) => v,
protocol::Message::Ping(ref v) => v,
protocol::Message::Pong(ref v) => v,
protocol::Message::Frame(ref frame) => frame.payload().as_slice(),
protocol::Message::Close(_) => &[],
}
}
/// Return the type as decoded json.
pub fn as_json<T>(&self) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
serde_json::from_slice(self.as_bytes()).map_err(Into::into)
}
/// Destructure this message into binary data.
pub fn into_bytes(self) -> Vec<u8> {
self.inner.into_data()
}
}
impl std::fmt::Debug for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.inner, f)
}
}
impl From<Message> for Vec<u8> {
fn from(m: Message) -> Self {
m.into_bytes()
}
}
sample endpoint
This is just an example of wrapping Docker's attach endpoint, which also happens to be a websocket.
/**
* Attach to a docker container to create an interactive terminal.
*/
#[endpoint {
method = GET,
path = "/term",
tags = ["term"],
}]
pub async fn create_term(
rqctx: Arc<RequestContext<Arc<Context>>>,
) -> Result<http::response::Response<hyper::body::Body>, HttpError> {
// We want to parse the headers from the request context.
let request = &rqctx.request.lock().await;
// We want to make sure we have the websocket upgrade header.
let h: Option<headers::Connection> = request.headers().typed_get();
if let Some(h) = h {
if !h.contains(&http::header::UPGRADE) {
anyhow::bail!("Connection header did not include 'upgrade'");
}
} else {
anyhow::bail!("Connection header not sent");
}
// Now get the secure websocket key.
let h: Option<headers::SecWebsocketKey> = request.headers().typed_get();
let websocket_key = if let Some(h) = h {
h
} else {
anyhow::bail!("Websocket key not sent");
};
let ctx = rqctx.context();
// Create and start the term container.
let container_id = ctx.create_ws_term_container(token).await?;
// We reply with:
// - Status of `101 Switching Protocols`
// - Header `connection: upgrade`
// - Header `upgrade: websocket`
// - Header `sec-websocket-accept` with the hash value of the received key.
let mut response = http::Response::default();
*response.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
response.headers_mut().typed_insert(headers::Connection::upgrade());
response.headers_mut().typed_insert(headers::Upgrade::websocket());
response
.headers_mut()
.typed_insert(headers::SecWebsocketAccept::from(websocket_key));
// Spawn a task to handle the websocket connection.
tokio::spawn(common::enclose! { (rqctx, ctx) async move {
let mut request = rqctx.request.lock().await;
// Attach to the container.
let bollard::container::AttachContainerResults { mut output, mut input } = ctx
.docker
.attach_container(
&container_id,
Some(bollard::container::AttachContainerOptions::<String> {
stdout: Some(true),
stderr: Some(true),
stdin: Some(true),
stream: Some(true),
..Default::default()
}),
)
.await?;
// Use the hyper feature of upgrading a connection.
match hyper::upgrade::on(request.deref_mut()).await {
//if successfully upgraded
Ok(upgraded) => {
//create a websocket stream from the upgraded object
let ws = crate::server::websocket::WebSocket::from_raw_socket(
// Pass the upgraded object as the base layer stream of the Websocket.
upgraded,
tokio_tungstenite::tungstenite::protocol::Role::Server,
None,
)
.await;
// We split the stream into a sink and a stream.
let (mut ws_write, mut ws_read) = ws.split();
// For every message we get on the container websocket, we want
// to send it back out to our browser websocket.
tokio::spawn(async move {
while let Some(byte) = output.next().await {
ws_write.send(crate::server::websocket::Message::json(crate::types::WebsocketMessage::Stdout{
data: match byte? {
bollard::container::LogOutput::StdErr{message} => {
String::from_utf8(message.to_vec())?
}
bollard::container::LogOutput::StdOut{message} => {
String::from_utf8(message.to_vec())?
}
bollard::container::LogOutput::StdIn{message} => {
String::from_utf8(message.to_vec())?
}
bollard::container::LogOutput::Console{message} => {
String::from_utf8(message.to_vec())?
}
},
})?).await?;
}
Ok::<(), anyhow::Error>(())
});
// Everytime we get a message from the server, we need to handle it.
while let Some(result) = ws_read.next().await {
match result {
Ok(msg) =>{
// If the message is a close message we want to stop the container.
if msg.is_close() {
break;
}
if msg.is_ping(){
// We can continue early.
continue;
}
// Let's parse the message as JSON.
let m: crate::types::WebsocketMessage = msg.as_json()?;
match m {
crate::types::WebsocketMessage::Stdin{data} => {
// We want to send the message to the container.
input.write_all(data.as_bytes()).await?;
}
crate::types::WebsocketMessage::Stdout{data} => {
// We should not get stdout, since that is not a valid
// incoming type, it is the type we send back to the
// browser.
log::warn!("Received stdout from browser: {}", data);
}
crate::types::WebsocketMessage::Resize{height, width} => {
// Resize the container.
ctx.docker.resize_container_tty(&container_id,
bollard::container::ResizeContainerTtyOptions{
height,
width,
}).await?;
}
crate::types::WebsocketMessage::KeepAlive => {
// Do nothing.
continue;
}
}
},
Err(e) => {
log::warn!("browser websocket message error: {}", e);
break;
}
};
}
// If we get here, it means the websocket connection was closed.
// Meaning the user closed the connection.
log::debug!("browser websocket connection closed");
}
Err(e) => {
anyhow::bail!("trying to upgrade connection to websocket connection failed: {}", e);
}
}
// If we get here we need to remove the container.
ctx.remove_container(&container_id).await?;
Ok(())
}});
Ok(response)
}
My proposal for a UX fully integrated into dropshot is an extension trait on `RequestContext` that provides an `upgrade` function. `upgrade` would handle all the boilerplate of checking that the required handshake headers are present and correct. It would take a generic async function to run once the upgrade completes, such as all of my logic above that runs after the upgrade. Does that make sense? I'm happy to implement it if it sounds good. I figure we will eventually need this, and it's much nicer for end users not to have to write all that boilerplate themselves.