-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: port task graph execution #5791
Changes from all commits
9a39139
d164ed7
688bd5b
9db3044
2dfd031
c1306b3
f533019
ccdce01
457a471
1c0b6b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
use std::sync::{Arc, Mutex}; | ||
|
||
use futures::{stream::FuturesUnordered, StreamExt}; | ||
use tokio::sync::{mpsc, oneshot, Semaphore}; | ||
use tracing::log::debug; | ||
|
||
use super::{Engine, TaskNode}; | ||
use crate::{graph::Walker, run::task_id::TaskId}; | ||
|
||
pub struct Message<T, U> { | ||
pub info: T, | ||
pub callback: oneshot::Sender<U>, | ||
} | ||
|
||
// Type alias used just to make altering the data sent to the visitor easier in | ||
// the future | ||
type VisitorData = TaskId<'static>; | ||
type VisitorResult = Result<(), StopExecution>; | ||
|
||
/// Options controlling how the task graph gets walked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExecutionOptions {
    // When true, tasks are scheduled without acquiring the concurrency
    // semaphore.
    parallel: bool,
    // Maximum number of tasks allowed to run at once when not parallel.
    concurrency: usize,
}

impl ExecutionOptions {
    /// Builds a new set of execution options.
    pub fn new(parallel: bool, concurrency: usize) -> Self {
        Self {
            parallel,
            concurrency,
        }
    }
}
|
||
#[derive(Debug, thiserror::Error)] | ||
pub enum ExecuteError { | ||
#[error("Semaphore closed before all tasks finished")] | ||
Semaphore(#[from] tokio::sync::AcquireError), | ||
#[error("Engine visitor closed channel before walk finished")] | ||
Visitor, | ||
} | ||
|
||
impl From<mpsc::error::SendError<Message<VisitorData, VisitorResult>>> for ExecuteError { | ||
fn from( | ||
_: mpsc::error::SendError<Message<TaskId<'static>, Result<(), StopExecution>>>, | ||
) -> Self { | ||
ExecuteError::Visitor | ||
} | ||
} | ||
|
||
/// Sentinel returned by the visitor to signal that no further tasks should be
/// scheduled. Tasks already running are unaffected.
#[derive(Debug, Clone, Copy)]
pub struct StopExecution;
|
||
impl Engine { | ||
/// Execute a task graph by sending task ids to the visitor | ||
/// while respecting concurrency limits. | ||
/// The visitor is expected to handle any error handling on it's end. | ||
/// We enforce this by only allowing the returning of a sentinel error | ||
/// type which will stop any further execution of tasks. | ||
/// This will not stop any task which is currently running, simply it will | ||
/// stop scheduling new tasks. | ||
// (olszewski) The current impl requires that the visitor receiver is read until | ||
// finish even once a task sends back the stop signal. This is suboptimal | ||
// since it would mean the visitor would need to also track if | ||
// it is cancelled :) | ||
pub async fn execute( | ||
self: Arc<Self>, | ||
options: ExecutionOptions, | ||
visitor: mpsc::Sender<Message<VisitorData, VisitorResult>>, | ||
) -> Result<(), ExecuteError> { | ||
let ExecutionOptions { | ||
parallel, | ||
concurrency, | ||
} = options; | ||
let sema = Arc::new(Semaphore::new(concurrency)); | ||
let mut tasks: FuturesUnordered<tokio::task::JoinHandle<Result<(), ExecuteError>>> = | ||
FuturesUnordered::new(); | ||
|
||
let (walker, mut nodes) = Walker::new(&self.task_graph).walk(); | ||
let walker = Arc::new(Mutex::new(walker)); | ||
|
||
while let Some((node_id, done)) = nodes.recv().await { | ||
let visitor = visitor.clone(); | ||
let sema = sema.clone(); | ||
let walker = walker.clone(); | ||
let this = self.clone(); | ||
|
||
tasks.push(tokio::spawn(async move { | ||
let TaskNode::Task(task_id) = this | ||
.task_graph | ||
.node_weight(node_id) | ||
.expect("node id should be present") | ||
else { | ||
// Root task has nothing to do so we don't emit any event for it | ||
if done.send(()).is_err() { | ||
debug!( | ||
"Graph walker done callback receiver was closed before done signal \ | ||
could be sent" | ||
); | ||
} | ||
return Ok(()); | ||
}; | ||
|
||
// Acquire the semaphore unless parallel | ||
let _permit = match parallel { | ||
false => Some(sema.acquire().await.expect( | ||
"Graph concurrency semaphore closed while tasks are still attempting to \ | ||
acquire permits", | ||
)), | ||
true => None, | ||
}; | ||
|
||
let (message, result) = Message::new(task_id.clone()); | ||
visitor.send(message).await?; | ||
|
||
if let Err(StopExecution) = result.await.unwrap_or_else(|_| { | ||
// If the visitor doesn't send a callback, then we assume the task finished | ||
debug!("Engine visitor dropped callback sender without sending result"); | ||
Ok(()) | ||
}) { | ||
if walker | ||
.lock() | ||
.expect("Walker mutex poisoned") | ||
.cancel() | ||
.is_err() | ||
{ | ||
debug!("Unable to cancel graph walk"); | ||
} | ||
} | ||
if done.send(()).is_err() { | ||
debug!("Graph walk done receiver closed before node was finished processing"); | ||
} | ||
Ok(()) | ||
})); | ||
} | ||
|
||
while let Some(res) = tasks.next().await { | ||
res.expect("unable to join task")?; | ||
} | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
impl<T, U> Message<T, U> { | ||
pub fn new(info: T) -> (Self, oneshot::Receiver<U>) { | ||
let (callback, receiver) = oneshot::channel(); | ||
(Self { info, callback }, receiver) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use std::sync::{Arc, OnceLock}; | ||
|
||
use futures::{stream::FuturesUnordered, StreamExt}; | ||
use regex::Regex; | ||
use tokio::sync::mpsc; | ||
|
||
use crate::{ | ||
engine::{Engine, ExecutionOptions}, | ||
opts::Opts, | ||
package_graph::{PackageGraph, WorkspaceName}, | ||
run::task_id::{self, TaskId}, | ||
}; | ||
|
||
// This holds the whole world | ||
pub struct Visitor<'a> { | ||
package_graph: Arc<PackageGraph>, | ||
opts: &'a Opts<'a>, | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
pub enum Error { | ||
#[error("cannot find package {package_name} for task {task_id}")] | ||
MissingPackage { | ||
package_name: WorkspaceName, | ||
task_id: TaskId<'static>, | ||
}, | ||
#[error( | ||
"root task {task_name} ({command}) looks like it invokes turbo and might cause a loop" | ||
)] | ||
RecursiveTurbo { task_name: String, command: String }, | ||
#[error("Could not find definition for task")] | ||
MissingDefinition, | ||
#[error("error while executing engine: {0}")] | ||
Engine(#[from] crate::engine::ExecuteError), | ||
} | ||
|
||
impl<'a> Visitor<'a> { | ||
pub fn new(package_graph: Arc<PackageGraph>, opts: &'a Opts) -> Self { | ||
Self { | ||
package_graph, | ||
opts, | ||
} | ||
} | ||
|
||
pub async fn visit(&self, engine: Arc<Engine>) -> Result<(), Error> { | ||
let concurrency = self.opts.run_opts.concurrency as usize; | ||
let (node_sender, mut node_stream) = mpsc::channel(concurrency); | ||
|
||
let engine_handle = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need the engine to execute in another task so it can send over the tasks via the node stream. |
||
let engine = engine.clone(); | ||
tokio::spawn(engine.execute(ExecutionOptions::new(false, concurrency), node_sender)) | ||
}; | ||
|
||
let mut tasks = FuturesUnordered::new(); | ||
|
||
while let Some(message) = node_stream.recv().await { | ||
let crate::engine::Message { info, callback } = message; | ||
let package_name = WorkspaceName::from(info.package()); | ||
let package_json = self | ||
.package_graph | ||
.package_json(&package_name) | ||
.ok_or_else(|| Error::MissingPackage { | ||
package_name: package_name.clone(), | ||
task_id: info.clone(), | ||
})?; | ||
|
||
let command = package_json.scripts.get(info.task()).cloned(); | ||
|
||
match command { | ||
Some(cmd) | ||
if info.package() == task_id::ROOT_PKG_NAME && turbo_regex().is_match(&cmd) => | ||
{ | ||
return Err(Error::RecursiveTurbo { | ||
task_name: info.to_string(), | ||
command: cmd.to_string(), | ||
}) | ||
} | ||
_ => (), | ||
} | ||
|
||
let _task_def = engine | ||
.task_definition(&info) | ||
.ok_or(Error::MissingDefinition)?; | ||
|
||
tasks.push(tokio::spawn(async move { | ||
println!( | ||
"Executing {info}: {}", | ||
command.as_deref().unwrap_or("no script def") | ||
); | ||
Comment on lines
+86
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once we have the chance of task failures we'll hook up error handling. I'm thinking that the visitor will manage the collection of errors via a |
||
callback.send(Ok(())).unwrap(); | ||
})); | ||
} | ||
|
||
// Wait for the engine task to finish and for all of our tasks to finish | ||
engine_handle.await.expect("engine execution panicked")?; | ||
// This will poll the futures until they are all completed | ||
while let Some(result) = tasks.next().await { | ||
result.expect("task executor panicked"); | ||
} | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
fn turbo_regex() -> &'static Regex { | ||
static RE: OnceLock<Regex> = OnceLock::new(); | ||
RE.get_or_init(|| Regex::new(r"(?:^|\s)turbo(?:$|\s)").unwrap()) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think what we'll eventually want to do is make an owned/borrowed version of the task id, à la https://github.com/sunshowers-code/borrow-complex-key-example.