diff --git a/backend/src/worker.rs b/backend/src/worker.rs
index 05800a18195e6..56ebe6645a7ea 100644
--- a/backend/src/worker.rs
+++ b/backend/src/worker.rs
@@ -8,12 +8,10 @@ use itertools::Itertools;
 use std::{
+    borrow::Borrow,
     collections::HashMap,
+    io, panic,
     process::{ExitStatus, Stdio},
-    sync::{
-        atomic::{AtomicBool, Ordering},
-        Arc,
-    },
     time::Duration,
 };
 use uuid::Uuid;
@@ -41,8 +39,13 @@ use tokio::{
     fs::{DirBuilder, File},
     io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
     process::{Child, Command},
-    sync::mpsc,
-    time::Instant,
+    sync::watch,
+    time::{interval, sleep, Instant, MissedTickBehavior},
+};
+
+use futures::{
+    future::{self, ready, FutureExt},
+    stream::{self, StreamExt},
 };
 
 use async_recursion::async_recursion;
@@ -1343,6 +1346,12 @@ async fn get_reserved_variables(
         .collect())
 }
 
+/// - wait until child exits and return with exit status
+/// - read lines from stdout and stderr and append them to the "queue"."logs",
+///   quitting early if output exceeds MAX_LOG_SIZE characters (not bytes)
+/// - update the `last_line` and `logs` strings with the program output
+/// - update "queue"."last_ping" every five seconds
+/// - kill process if we exceed timeout or "queue"."canceled" is set
 async fn handle_child(
     job_id: &Uuid,
     db: &DB,
@@ -1351,175 +1360,257 @@ async fn handle_child(
     timeout: i32,
     mut child: Child,
 ) -> crate::error::Result<ExitStatus> {
+    let timeout = Duration::from_secs(u64::try_from(timeout).expect("invalid timeout"));
+    let ping_interval = Duration::from_secs(5);
+    let cancel_check_interval = Duration::from_millis(500);
+    let write_logs_delay = Duration::from_millis(500);
+
+    let (set_too_many_logs, mut too_many_logs) = watch::channel::<bool>(false);
+
+    let output = child_joined_output_stream(&mut child);
     let job_id = job_id.clone();
 
-    let stderr = child
-        .stderr
-        .take()
-        .expect("child did not have a handle to stdout");
-    let stdout = child
-        .stdout
-        .take()
-        .expect("child did not have a handle to stdout");
+    /* the cancellation future is polled on by `wait_on_child` while
+     * waiting for the child to exit normally */
+    let cancel_check = async {
+        let db = db.clone();
 
-    let mut reader = BufReader::new(stdout).lines();
-    let mut stderr_reader = BufReader::new(stderr).lines();
-
-    let done = Arc::new(AtomicBool::new(false));
-
-    let done2 = done.clone();
-    let done3 = done.clone();
-    let done4 = done.clone();
-    // Ensure the child process is spawned in the runtime so it can
-    // make progress on its own while we await for any output.
-    let handle = tokio::spawn(async move {
-        let inner_done = done2.clone();
-        let r: Result<ExitStatus, anyhow::Error> = tokio::select! {
-            r = child.wait() => {
-                inner_done.store(true, Ordering::Relaxed);
-                Ok(r?)
+        let mut interval = interval_skipping_missed(cancel_check_interval).boxed();
+        while let Some(_) = interval.next().await {
+            if sqlx::query_scalar!("SELECT canceled FROM queue WHERE id = $1", job_id)
+                .fetch_optional(&db)
+                .await
+                .map(|v| Some(true) == v)
+                .unwrap_or_else(|err| {
+                    tracing::error!(%job_id, %err, "error checking cancellation for job {job_id}: {err}");
+                    false
+                })
+            {
+                break;
             }
-            _ = async move {
-                while !done2.load(Ordering::Relaxed) {
-                    tokio::time::sleep(Duration::from_secs(1)).await;
+        }
+    };
+
+    /* a future that completes when the child process exits */
+    let wait_on_child = async {
+        let db = db.clone();
+
+        let timed_out = tokio::select! {
+            biased;
+            result = child.wait() => return result.map(Some),
+            _ = too_many_logs.changed() => false,
+            _ = cancel_check => false,
+            _ = sleep(timeout) => true,
+        };
+
+        let set_reason = async {
+            if timed_out {
+                if let Err(err) = sqlx::query(
+                    r#"
+                       UPDATE queue
+                          SET canceled = true
+                            , canceled_by = 'timeout'
+                            , canceled_reason = $1
+                        WHERE id = $2
+                    "#,
+                )
+                .bind(format!("duration > {}", timeout.as_secs()))
+                .bind(job_id)
+                .execute(&db)
+                .await
+                {
+                    tracing::error!(%job_id, %err, "error setting cancellation reason for job {job_id}: {err}");
                 }
-            } => {
-                child.kill().await?;
-                return Err(Error::ExecutionErr("execution interrupted".to_string()).into())
             }
         };
-        r
-    });
-
-    let (tx, mut rx) = mpsc::channel::<String>(100);
-
-    tokio::spawn(async move {
-        while !done4.load(Ordering::Relaxed) {
-            let send = tokio::select! {
-                Ok(Some(out)) = reader.next_line() => {
-                    if out.len() > MAX_LOG_SIZE as usize {
-                        tracing::info!("Line is too big");
-                        let _ = tx.send(format!("Line is too big")).await;
-                        done4.store(true, Ordering::Relaxed);
-                        break;
-                    } else {
-                        tx.send(out).await
+
+        /* send SIGKILL and reap child process */
+        let (_, kill) = future::join(set_reason, child.kill()).await;
+        kill.map(|()| None)
+    };
+
+    /* a future that reads output from the child and appends to the database */
+    let lines = async move {
+        /* log_remaining is zero when output limit was reached */
+        let mut log_remaining = (MAX_LOG_SIZE as usize).saturating_sub(logs.chars().count());
+        let mut result = io::Result::Ok(());
+        let mut output = output;
+        /* `do_write` resolves the task, but does not contain the Result.
+         * It's useful to know if the task completed. */
+        let (mut do_write, mut write_result) = tokio::spawn(ready(())).remote_handle();
+
+        while let Some(line) = output.by_ref().next().await {
+            let do_write_ = do_write.shared();
+
+            let mut read_lines = stream::once(async { line })
+                .chain(output.by_ref())
+                /* after receiving a line, continue until some delay has passed
+                 * _and_ the previous database write is complete */
+                .take_until(future::join(sleep(write_logs_delay), do_write_.clone()))
+                .boxed();
+
+            /* Read up until an error is encountered,
+             * handle log lines first and then the error... */
+            let mut joined = String::new();
+
+            while let Some(line) = read_lines.next().await {
+                match line {
+                    Ok(_) if log_remaining == 0 => (),
+                    Ok(line) => {
+                        append_with_limit(&mut joined, &line, &mut log_remaining);
+
+                        *last_line = line;
+
+                        if log_remaining == 0 {
+                            tracing::info!(%job_id, "too many log lines for job {job_id}");
+                            let _ = set_too_many_logs.send(true);
+                            joined.push_str(&format!(
+                                "Job logs or result reached character limit of {MAX_LOG_SIZE}; killing job."
+                            ));
+                            /* stop reading and drop our streams fairly quickly */
+                            break;
+                        }
                     }
-                },
-                Ok(Some(err)) = stderr_reader.next_line() => {
-                    if err.len() > MAX_LOG_SIZE as usize {
-                        tracing::info!("Line is too big");
-                        let _ = tx.send(format!("Line is too big")).await;
-                        done4.store(true, Ordering::Relaxed);
+                    Err(err) => {
+                        result = Err(err);
                         break;
-                    } else {
-                        tx.send(err).await
                     }
-                },
-                else => {
-                    break
-                },
-            };
-            if send.err().is_some() {
-                tracing::error!("error sending log line");
-            };
-        }
-    });
+                }
+            }
 
-    let db2 = db.clone();
+            logs.push_str(&joined);
 
-    tokio::spawn(async move {
-        while !&done3.load(Ordering::Relaxed) {
-            let q = sqlx::query!("UPDATE queue SET last_ping = now() WHERE id = $1", job_id)
-                .execute(&db2)
-                .await;
+            /* Ensure the last flush completed before starting a new one.
+             *
+             * This shouldn't pause since `take_until()` reads lines until `do_write`
+             * resolves. We only stop reading lines before `take_until()` resolves if we reach
+             * EOF or a read error. In those cases, waiting on a database query to complete is
+             * fine because we're done. */
-            if q.is_err() {
-                tracing::error!("error setting last ping for id {}", job_id);
+            if let Some(Ok(p)) = do_write_
+                .then(|()| write_result)
+                .await
+                .err()
+                .map(|err| err.try_into_panic())
+            {
+                panic::resume_unwind(p);
             }
-            tokio::time::sleep(Duration::from_secs(5)).await;
-        }
-    });
+            (do_write, write_result) =
+                tokio::spawn(append_logs(job_id, joined, db.clone())).remote_handle();
 
-    let mut start = logs.chars().count();
-    let mut last_update = chrono::Utc::now().timestamp_millis();
-    let initial_start = chrono::Utc::now();
+            if let Err(err) = result {
+                tracing::error!(%job_id, %err, "error reading output for job {job_id}: {err}");
+                break;
+            }
 
-    while !done.load(Ordering::Relaxed) {
-        let diff = 500 - (chrono::Utc::now().timestamp_millis() - last_update);
-        let sleeping_future = if diff > 0 as i64 {
-            tokio::time::sleep(Duration::from_millis(diff as u64))
-        } else {
-            //TODO make it just resolve immediately
-            tokio::time::sleep(Duration::from_millis(0))
-        };
-        tokio::select! {
-            _ = sleeping_future => {
-                let end = logs.chars().count();
+            if *set_too_many_logs.borrow() {
+                break;
+            }
+        }
 
-                let to_send = logs.chars().skip(start).collect::<String>();
+        /* drop our end of the pipe */
+        drop(output);
 
-                if start != end {
-                    concat_logs(&to_send, &job_id, db).await;
-                    start = end;
-                }
+        if let Some(Ok(p)) = do_write
+            .then(|()| write_result)
+            .await
+            .err()
+            .map(|err| err.try_into_panic())
+        {
+            panic::resume_unwind(p);
+        }
+    };
 
-                let canceled = sqlx::query_scalar!("SELECT canceled FROM queue WHERE id = $1", job_id)
-                    .fetch_one(db)
+    /* a stream updating "queue"."last_ping" at an interval */
+
+    let ping = interval_skipping_missed(ping_interval)
+        .map(|_| db.clone())
+        .then(move |db| async move {
+            if let Err(err) =
+                sqlx::query!("UPDATE queue SET last_ping = now() WHERE id = $1", job_id)
+                    .execute(&db)
                     .await
-                    .map_err(|e| tracing::error!("error getting canceled for id {}: {e}", job_id))
-                    .unwrap_or(false);
+            {
+                tracing::error!(%job_id, %err, "error setting last ping for job {job_id}: {err}");
+            }
+        });
 
-                if canceled {
-                    tracing::info!("killed after cancel: {}", job_id);
-                    done.store(true, Ordering::Relaxed);
-                }
+    let wait_result = tokio::select! {
+        (w, _) = future::join(wait_on_child, lines) => w,
+        /* ping should repeat forever without stopping */
+        _ = ping.collect::<()>() => unreachable!("job ping stopped"),
+    };
 
-                let has_timeout = (chrono::Utc::now() - initial_start).num_seconds() > timeout as i64;
+    match wait_result {
+        _ if *too_many_logs.borrow() => Err(Error::ExecutionErr(
+            "logs or result reached limit".to_string(),
+        )),
+        Ok(Some(status)) => Ok(status),
+        Ok(None) => Err(Error::ExecutionErr("job process killed".to_string())),
+        Err(err) => Err(Error::ExecutionErr(format!("job process io error: {err}"))),
+    }
+}
 
-                if has_timeout {
-                    let q = sqlx::query(&format!(
-                        "UPDATE queue SET canceled = true, canceled_by = 'timeout', \
-                         canceled_reason = 'duration > {}' WHERE id = $1",
-                        timeout
-                    ))
-                    .bind(job_id)
-                    .execute(db)
-                    .await;
+fn interval_skipping_missed(period: Duration) -> impl futures::Stream<Item = Instant> {
+    let mut interval = interval(period);
+    interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
+    stream::poll_fn(move |cx| interval.poll_tick(cx).map(Some))
+}
 
-                    if q.is_err() {
-                        tracing::error!("error setting canceled for id {}", job_id);
-                    }
-                }
-                last_update = chrono::Utc::now().timestamp_millis();
-            },
-            nl = rx.recv() => {
+/// takes stdout and stderr from Child, panics if either are not present
+///
+/// builds a stream joining both stdout and stderr each read line by line
+fn child_joined_output_stream(
+    child: &mut Child,
+) -> impl stream::FusedStream<Item = io::Result<String>> {
+    let stderr = child
+        .stderr
+        .take()
+        .expect("child did not have a handle to stderr");
 
-                if let Some(nl) = nl {
-                    if logs.chars().count() > MAX_LOG_SIZE as usize{
-                        tracing::info!("Too many logs lines: {}", job_id);
-                        logs.push_str("Too many logs lines. Limit is 200000 chars. Killing job.");
-                        done.store(true, Ordering::Relaxed);
-                    }
+    let stdout = child
+        .stdout
+        .take()
+        .expect("child did not have a handle to stdout");
 
-                    logs.push('\n');
-                    logs.push_str(&nl);
+    let stdout = BufReader::new(stdout).lines();
+    let stderr = BufReader::new(stderr).lines();
+    stream::select(lines_to_stream(stderr), lines_to_stream(stdout))
+}
 
-                    *last_line = nl;
-                } else {
-                    let to_send = logs.chars().skip(start).collect::<String>();
-                    concat_logs(&to_send, &job_id, db).await;
-                    break;
-                }
-            },
-        }
+fn lines_to_stream<R: tokio::io::AsyncBufRead + Unpin>(
+    mut lines: tokio::io::Lines<R>,
+) -> impl futures::Stream<Item = io::Result<String>> {
+    stream::poll_fn(move |cx| {
+        std::pin::Pin::new(&mut lines)
+            .poll_next_line(cx)
+            .map(|result| result.transpose())
+    })
+}
+
+// as a detail, `BufReader::lines()` removes \n and \r\n from the strings it yields,
+// so this pushes \n to the destination string in each call
+fn append_with_limit(dst: &mut String, src: &str, limit: &mut usize) {
+    /* the decrement stays inside the guard so a zero limit cannot underflow */
+    if *limit > 0 {
+        dst.push('\n');
+        *limit -= 1;
+    }
 
-    let status = handle
-        .await
-        .map_err(|e| Error::ExecutionErr(e.to_string()))??;
-    Ok(status)
+    let src_len = src.chars().count();
+    if src_len <= *limit {
+        dst.push_str(&src);
+        *limit -= src_len;
+    } else {
+        let byte_pos = src
+            .char_indices()
+            .skip(*limit)
+            .next()
+            .map(|(byte_pos, _)| byte_pos)
+            .unwrap_or(0);
+        dst.push_str(&src[0..byte_pos]);
+        *limit = 0;
+    }
 }
 
 async fn set_logs(logs: &str, id: uuid::Uuid, db: &DB) {
@@ -1532,22 +1623,26 @@ async fn set_logs(logs: &str, id: uuid::Uuid, db: &DB) {
         .await
         .is_err()
     {
-        tracing::error!("error updating logs for id {}", id)
+        tracing::error!(%id, "error updating logs for id {id}")
     };
 }
 
-async fn concat_logs(logs: &str, id: &Uuid, db: &DB) {
-    if sqlx::query!(
+/* TODO retry this? */
+async fn append_logs(job_id: uuid::Uuid, logs: impl AsRef<str>, db: impl Borrow<DB>) {
+    if logs.as_ref().is_empty() {
+        return;
+    }
+
+    if let Err(err) = sqlx::query!(
         "UPDATE queue SET logs = concat(logs, $1::text) WHERE id = $2",
-        logs.to_owned(),
-        id
+        logs.as_ref(),
+        job_id,
     )
-    .execute(db)
+    .execute(db.borrow())
     .await
-    .is_err()
     {
-        tracing::error!("error updating logs for id {}", id)
-    };
+        tracing::error!(%job_id, %err, "error updating logs for job {job_id}: {err}");
+    }
 }
 
 pub async fn restart_zombie_jobs_periodically(
@@ -2822,20 +2917,17 @@ def main(error, port):
     let mut listener = PgListener::connect_with(db).await.unwrap();
     listener.listen(channel).await.unwrap();
 
-    Box::pin(futures::stream::unfold(
-        listener,
-        |mut listener| async move {
-            let uuid = listener
-                .try_recv()
-                .await
-                .unwrap()
-                .expect("lost database connection")
-                .payload()
-                .parse::<Uuid>()
-                .expect("invalid uuid");
-            Some((uuid, listener))
-        },
-    ))
+    Box::pin(stream::unfold(listener, |mut listener| async move {
+        let uuid = listener
+            .try_recv()
+            .await
+            .unwrap()
+            .expect("lost database connection")
+            .payload()
+            .parse::<Uuid>()
+            .expect("invalid uuid");
+        Some((uuid, listener))
+    }))
 }
 
 async fn completed_job_result(uuid: Uuid, db: &DB) -> Value {
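
A standalone sanity sketch (not part of the patch) of the character-budget logic in `append_with_limit` from the diff above: the `main` harness and example inputs are hypothetical, and `.skip(n).next()` is written as the equivalent `.nth(n)`.

fn append_with_limit(dst: &mut String, src: &str, limit: &mut usize) {
    // each appended line costs one character of the budget for the '\n' separator
    if *limit > 0 {
        dst.push('\n');
        *limit -= 1;
    }

    let src_len = src.chars().count();
    if src_len <= *limit {
        dst.push_str(src);
        *limit -= src_len;
    } else {
        // truncate on a char boundary so multi-byte UTF-8 is never split
        let byte_pos = src
            .char_indices()
            .nth(*limit)
            .map(|(byte_pos, _)| byte_pos)
            .unwrap_or(0);
        dst.push_str(&src[0..byte_pos]);
        *limit = 0;
    }
}

fn main() {
    let mut dst = String::new();
    let mut limit = 8; // stand-in for MAX_LOG_SIZE minus the characters already logged

    append_with_limit(&mut dst, "ab", &mut limit); // appends "\nab", 5 chars left
    append_with_limit(&mut dst, "héllo", &mut limit); // '\n' plus "héll", budget exhausted
    assert_eq!(dst, "\nab\nhéll");
    assert_eq!(limit, 0);
}

The budget is counted in characters rather than bytes, which is why the truncation goes through `char_indices()` instead of slicing `src` at `*limit` directly.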