Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use DuplexStream instead of UnixStream to communicate with workers #320

Merged
merged 4 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions crates/base/src/deno_runtime.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::inspector_server::Inspector;
use crate::rt_worker::rt;
use crate::rt_worker::supervisor::{CPUUsage, CPUUsageMetrics};
use crate::rt_worker::worker::UnixStreamEntry;
use crate::rt_worker::worker::DuplexStreamEntry;
use crate::utils::units::{bytes_to_display, mib_to_bytes};

use anyhow::{anyhow, bail, Context, Error};
Expand Down Expand Up @@ -33,7 +33,6 @@ use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::c_void;
use std::fmt;
use std::os::fd::RawFd;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
Expand Down Expand Up @@ -525,7 +524,7 @@ impl DenoRuntime {
}

if conf.is_main_worker() || conf.is_user_worker() {
op_state.put::<HashMap<RawFd, watch::Receiver<ConnSync>>>(HashMap::new());
op_state.put::<HashMap<usize, watch::Receiver<ConnSync>>>(HashMap::new());
}

if conf.is_user_worker() {
Expand Down Expand Up @@ -596,14 +595,15 @@ impl DenoRuntime {

pub async fn run(
&mut self,
unix_stream_rx: mpsc::UnboundedReceiver<UnixStreamEntry>,
duplex_stream_rx: mpsc::UnboundedReceiver<DuplexStreamEntry>,
maybe_cpu_usage_metrics_tx: Option<mpsc::UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
) -> (Result<(), Error>, i64) {
{
let op_state_rc = self.js_runtime.op_state();
let mut op_state = op_state_rc.borrow_mut();
op_state.put::<mpsc::UnboundedReceiver<UnixStreamEntry>>(unix_stream_rx);

op_state.put::<mpsc::UnboundedReceiver<DuplexStreamEntry>>(duplex_stream_rx);

if self.conf.is_main_worker() {
op_state.put::<mpsc::UnboundedSender<UserWorkerMsgs>>(
Expand Down Expand Up @@ -887,7 +887,7 @@ extern "C" fn mem_check_gc_prologue_callback_fn(
#[cfg(test)]
mod test {
use crate::deno_runtime::DenoRuntime;
use crate::rt_worker::worker::UnixStreamEntry;
use crate::rt_worker::worker::DuplexStreamEntry;
use deno_core::{FastString, ModuleCodeString, PollEventLoopOptions};
use sb_graph::emitter::EmitterFactory;
use sb_graph::{generate_binary_eszip, EszipPayloadKind};
Expand Down Expand Up @@ -1467,8 +1467,8 @@ mod test {
let mut user_rt =
create_basic_user_runtime("./test_cases/array_buffers", 20, 1000, &[]).await;

let (_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;
let (_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;

assert!(result.is_ok(), "expected no errors");

Expand All @@ -1482,8 +1482,8 @@ mod test {
let mut user_rt =
create_basic_user_runtime("./test_cases/array_buffers", 15, 1000, &[]).await;

let (_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;
let (_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;

match result {
Err(err) => {
Expand All @@ -1501,7 +1501,7 @@ mod test {
memory_limit_mb: u64,
worker_timeout_ms: u64,
) {
let (_unix_stream_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (_duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (callback_tx, mut callback_rx) = mpsc::unbounded_channel::<()>();
let mut user_rt =
create_basic_user_runtime(path, memory_limit_mb, worker_timeout_ms, static_patterns)
Expand All @@ -1518,7 +1518,7 @@ mod test {
});

let wait_fut = async move {
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;

assert_eq!(
result.unwrap_err().to_string(),
Expand Down
6 changes: 3 additions & 3 deletions crates/base/src/rt_worker/implementation/default_handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::deno_runtime::DenoRuntime;
use crate::rt_worker::supervisor::CPUUsageMetrics;
use crate::rt_worker::worker::{HandleCreationType, UnixStreamEntry, Worker, WorkerHandler};
use crate::rt_worker::worker::{DuplexStreamEntry, HandleCreationType, Worker, WorkerHandler};
use anyhow::Error;
use event_worker::events::{BootFailureEvent, PseudoEvent, UncaughtExceptionEvent, WorkerEvents};
use log::error;
Expand All @@ -19,14 +19,14 @@ impl WorkerHandler for Worker {
fn handle_creation<'r>(
&self,
created_rt: &'r mut DenoRuntime,
unix_stream_rx: UnboundedReceiver<UnixStreamEntry>,
duplex_stream_rx: UnboundedReceiver<DuplexStreamEntry>,
termination_event_rx: Receiver<WorkerEvents>,
maybe_cpu_usage_metrics_tx: Option<UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
) -> HandleCreationType<'r> {
let run_worker_rt = async move {
match created_rt
.run(unix_stream_rx, maybe_cpu_usage_metrics_tx, name)
.run(duplex_stream_rx, maybe_cpu_usage_metrics_tx, name)
.await
{
// if the error is execution terminated, check termination event reason
Expand Down
18 changes: 9 additions & 9 deletions crates/base/src/rt_worker/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use sb_workers::context::{UserWorkerMsgs, WorkerContextInitOpts};
use std::any::Any;
use std::future::{pending, Future};
use std::pin::Pin;
use tokio::net::UnixStream;
use tokio::io;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot::{Receiver, Sender};
use tokio::sync::{oneshot, watch};
Expand All @@ -44,14 +44,14 @@ pub struct Worker {
}

pub type HandleCreationType<'r> = Pin<Box<dyn Future<Output = Result<WorkerEvents, Error>> + 'r>>;
pub type UnixStreamEntry = (UnixStream, Option<watch::Receiver<ConnSync>>);
pub type DuplexStreamEntry = (io::DuplexStream, Option<watch::Receiver<ConnSync>>);

pub trait WorkerHandler: Send {
fn handle_error(&self, error: Error) -> Result<WorkerEvents, Error>;
fn handle_creation<'r>(
&self,
created_rt: &'r mut DenoRuntime,
unix_stream_rx: UnboundedReceiver<UnixStreamEntry>,
duplex_stream_rx: UnboundedReceiver<DuplexStreamEntry>,
termination_event_rx: Receiver<WorkerEvents>,
maybe_cpu_metrics_tx: Option<UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
Expand Down Expand Up @@ -91,9 +91,9 @@ impl Worker {
pub fn start(
&self,
mut opts: WorkerContextInitOpts,
unix_stream_pair: (
UnboundedSender<UnixStreamEntry>,
UnboundedReceiver<UnixStreamEntry>,
duplex_stream_pair: (
UnboundedSender<DuplexStreamEntry>,
UnboundedReceiver<DuplexStreamEntry>,
),
booter_signal: Sender<Result<MetricSource, Error>>,
termination_token: Option<TerminationToken>,
Expand All @@ -104,7 +104,7 @@ impl Worker {
let event_metadata = self.event_metadata.clone();
let supervisor_policy = self.supervisor_policy;

let (unix_stream_tx, unix_stream_rx) = unix_stream_pair;
let (duplex_stream_tx, duplex_stream_rx) = duplex_stream_pair;
let events_msg_tx = self.events_msg_tx.clone();
let pool_msg_tx = self.pool_msg_tx.clone();

Expand Down Expand Up @@ -244,7 +244,7 @@ impl Worker {
let result = method_cloner
.handle_creation(
&mut runtime,
unix_stream_rx,
duplex_stream_rx,
termination_event_rx,
maybe_cpu_usage_metrics_tx,
Some(worker_name),
Expand Down Expand Up @@ -283,7 +283,7 @@ impl Worker {
}
};

drop(unix_stream_tx);
drop(duplex_stream_tx);

match result {
Ok(event) => {
Expand Down
31 changes: 14 additions & 17 deletions crates/base/src/rt_worker/worker_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use std::io::ErrorKind;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::copy_bidirectional;
use tokio::net::{TcpStream, UnixStream};
use tokio::io::{self, copy_bidirectional};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, watch, Mutex};
use tokio_rustls::server::TlsStream;
Expand All @@ -42,7 +42,7 @@ use uuid::Uuid;

use super::rt;
use super::supervisor::{self, CPUTimerParam, CPUUsageMetrics};
use super::worker::UnixStreamEntry;
use super::worker::DuplexStreamEntry;
use super::worker_pool::{SupervisorPolicy, WorkerPoolPolicy};

#[derive(Clone)]
Expand Down Expand Up @@ -93,28 +93,25 @@ impl TerminationToken {
}

async fn handle_request(
unix_stream_tx: mpsc::UnboundedSender<UnixStreamEntry>,
duplex_stream_tx: mpsc::UnboundedSender<DuplexStreamEntry>,
msg: WorkerRequestMsg,
) -> Result<(), Error> {
// create a unix socket pair
let (sender_stream, recv_stream) = UnixStream::pair()?;
let (ours, theirs) = io::duplex(1024);
let WorkerRequestMsg {
mut req,
res_tx,
conn_watch,
} = msg;

let _ = unix_stream_tx.send((recv_stream, conn_watch.clone()));
let _ = duplex_stream_tx.send((theirs, conn_watch.clone()));
let req_upgrade_type = get_upgrade_type(req.headers());
let req_upgrade = req_upgrade_type
.clone()
.and_then(|it| Some(it).zip(req.extensions_mut().remove::<OnUpgrade>()));

// send the HTTP request to the worker over Unix stream
let (mut request_sender, connection) = http1::Builder::new()
.writev(true)
.handshake(sender_stream)
.await?;
// send the HTTP request to the worker over duplex stream
let (mut request_sender, connection) =
http1::Builder::new().writev(true).handshake(ours).await?;

let (upgrade_tx, upgrade_rx) = oneshot::channel();

Expand Down Expand Up @@ -177,7 +174,7 @@ async fn handle_request(

async fn relay_upgraded_request_and_response(
downstream: OnUpgrade,
parts: http1::Parts<UnixStream>,
parts: http1::Parts<io::DuplexStream>,
) {
let mut upstream = Upgraded2::new(parts.io, parts.read_buf);
let mut downstream = downstream.await.expect("failed to upgrade request");
Expand All @@ -190,7 +187,7 @@ async fn relay_upgraded_request_and_response(
// `close_notify` before shutdown an upstream if downstream is a
// TLS stream.

// INVARIANT: `UnexpectedEof` due to shutdown `UnixStream` is
// INVARIANT: `UnexpectedEof` due to shutdown `DuplexStream` is
// only expected to occur in the context of `TlsStream`.
panic!("unhandleable unexpected eof");
};
Expand Down Expand Up @@ -516,7 +513,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
init_opts: Opt,
inspector: Option<Inspector>,
) -> Result<(MetricSource, mpsc::UnboundedSender<WorkerRequestMsg>), Error> {
let (unix_stream_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (worker_boot_result_tx, worker_boot_result_rx) =
oneshot::channel::<Result<MetricSource, Error>>();

Expand All @@ -539,7 +536,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
if let Some(worker_struct_ref) = downcast_reference {
worker_struct_ref.start(
init_opts,
(unix_stream_tx.clone(), unix_stream_rx),
(duplex_stream_tx.clone(), duplex_stream_rx),
worker_boot_result_tx,
maybe_termination_token.clone(),
inspector,
Expand All @@ -549,7 +546,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
let (worker_req_tx, mut worker_req_rx) = mpsc::unbounded_channel::<WorkerRequestMsg>();

let worker_req_handle: tokio::task::JoinHandle<Result<(), Error>> = tokio::task::spawn({
let stream_tx = unix_stream_tx;
let stream_tx = duplex_stream_tx;
async move {
while let Some(msg) = worker_req_rx.recv().await {
tokio::task::spawn({
Expand Down
17 changes: 16 additions & 1 deletion crates/base/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ pub struct WorkerEntrypoints {
pub struct ServerFlags {
pub no_module_cache: bool,
pub allow_main_inspector: bool,
pub tcp_nodelay: bool,
pub graceful_exit_deadline_sec: u64,
}

Expand Down Expand Up @@ -446,7 +447,13 @@ impl Server {
}

let event_tx = can_receive_event.then_some(event_tx.clone());
let mut graceful_exit_deadline = flags.graceful_exit_deadline_sec;
let ServerFlags {
tcp_nodelay,
graceful_exit_deadline_sec,
..
} = flags;

let mut graceful_exit_deadline = graceful_exit_deadline_sec;

loop {
let main_worker_req_tx = self.main_worker_req_tx.clone();
Expand All @@ -457,6 +464,10 @@ impl Server {
msg = non_secure_listener.accept() => {
match msg {
Ok((stream, _)) => {
if tcp_nodelay {
let _ = stream.set_nodelay(true);
}

accept_stream(stream, main_worker_req_tx, event_tx, metric_src)
}
Err(e) => error!("socket error: {}", e)
Expand All @@ -473,6 +484,10 @@ impl Server {
} => {
match msg {
Ok((stream, _)) => {
if tcp_nodelay {
let _ = stream.get_ref().0.set_nodelay(true);
}

accept_stream(stream, main_worker_req_tx, event_tx, metric_src);
}
Err(e) => error!("socket error: {}", e)
Expand Down
14 changes: 13 additions & 1 deletion crates/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use base::deno_runtime::MAYBE_DENO_VERSION;
use base::rt_worker::worker_pool::{SupervisorPolicy, WorkerPoolPolicy};
use base::server::{ServerFlags, Tls, WorkerEntrypoints};
use base::{DecoratorType, InspectorOption};
use clap::builder::{FalseyValueParser, TypedValueParser};
use clap::builder::{BoolishValueParser, FalseyValueParser, TypedValueParser};
use clap::{arg, crate_version, value_parser, ArgAction, ArgGroup, ArgMatches, Command};
use deno_core::url::Url;
use log::warn;
Expand Down Expand Up @@ -145,6 +145,13 @@ fn cli() -> Command {
.action(ArgAction::SetTrue)
)
.arg(arg!(--"static" <Path> "Glob pattern for static files to be included"))
.arg(arg!(--"tcp-nodelay" [BOOL] "Disables Nagle's algorithm")
.num_args(0..=1)
.value_parser(BoolishValueParser::new())
.require_equals(true)
.default_value("true")
.default_missing_value("true")
)
)
.subcommand(
Command::new("bundle")
Expand Down Expand Up @@ -269,6 +276,10 @@ fn main() -> Result<(), anyhow::Error> {
None
};

let tcp_nodelay =sub_matches.get_one::<bool>("tcp-nodelay")
.copied()
.unwrap();

start_server(
ip.as_str(),
port,
Expand Down Expand Up @@ -298,6 +309,7 @@ fn main() -> Result<(), anyhow::Error> {
ServerFlags {
no_module_cache,
allow_main_inspector,
tcp_nodelay,
graceful_exit_deadline_sec: graceful_exit_timeout.unwrap_or(0),
},
None,
Expand Down
Loading
Loading