Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port task graph execution #5791

Merged
merged 10 commits into from
Aug 24, 2023
150 changes: 150 additions & 0 deletions crates/turborepo-lib/src/engine/execute.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use std::sync::{Arc, Mutex};

use futures::{stream::FuturesUnordered, StreamExt};
use tokio::sync::{mpsc, oneshot, Semaphore};
use tracing::log::debug;

use super::{Engine, TaskNode};
use crate::{graph::Walker, run::task_id::TaskId};

pub struct Message<T, U> {
pub info: T,
pub callback: oneshot::Sender<U>,
}

// Type alias used just to make altering the data sent to the visitor easier in
// the future
type VisitorData = TaskId<'static>;
type VisitorResult = Result<(), StopExecution>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExecutionOptions {
parallel: bool,
concurrency: usize,
}

impl ExecutionOptions {
pub fn new(parallel: bool, concurrency: usize) -> Self {
Self {
parallel,
concurrency,
}
}
}

#[derive(Debug, thiserror::Error)]
pub enum ExecuteError {
#[error("Semaphore closed before all tasks finished")]
Semaphore(#[from] tokio::sync::AcquireError),
#[error("Engine visitor closed channel before walk finished")]
Visitor,
}

impl From<mpsc::error::SendError<Message<VisitorData, VisitorResult>>> for ExecuteError {
fn from(
_: mpsc::error::SendError<Message<TaskId<'static>, Result<(), StopExecution>>>,
) -> Self {
ExecuteError::Visitor
}
}

#[derive(Debug, Clone, Copy)]
pub struct StopExecution;

impl Engine {
/// Execute a task graph by sending task ids to the visitor
/// while respecting concurrency limits.
/// The visitor is expected to handle any error handling on it's end.
/// We enforce this by only allowing the returning of a sentinel error
/// type which will stop any further execution of tasks.
/// This will not stop any task which is currently running, simply it will
/// stop scheduling new tasks.
// (olszewski) The current impl requires that the visitor receiver is read until
// finish even once a task sends back the stop signal. This is suboptimal
// since it would mean the visitor would need to also track if
// it is cancelled :)
pub async fn execute(
self: Arc<Self>,
options: ExecutionOptions,
visitor: mpsc::Sender<Message<VisitorData, VisitorResult>>,
) -> Result<(), ExecuteError> {
let ExecutionOptions {
parallel,
concurrency,
} = options;
let sema = Arc::new(Semaphore::new(concurrency));
let mut tasks: FuturesUnordered<tokio::task::JoinHandle<Result<(), ExecuteError>>> =
FuturesUnordered::new();

let (walker, mut nodes) = Walker::new(&self.task_graph).walk();
let walker = Arc::new(Mutex::new(walker));

while let Some((node_id, done)) = nodes.recv().await {
let visitor = visitor.clone();
let sema = sema.clone();
let walker = walker.clone();
let this = self.clone();

tasks.push(tokio::spawn(async move {
let TaskNode::Task(task_id) = this
.task_graph
.node_weight(node_id)
.expect("node id should be present")
else {
// Root task has nothing to do so we don't emit any event for it
if done.send(()).is_err() {
debug!(
"Graph walker done callback receiver was closed before done signal \
could be sent"
);
}
return Ok(());
};

// Acquire the semaphore unless parallel
let _permit = match parallel {
false => Some(sema.acquire().await.expect(
"Graph concurrency semaphore closed while tasks are still attempting to \
acquire permits",
)),
true => None,
};

let (message, result) = Message::new(task_id.clone());
visitor.send(message).await?;

if let Err(StopExecution) = result.await.unwrap_or_else(|_| {
// If the visitor doesn't send a callback, then we assume the task finished
debug!("Engine visitor dropped callback sender without sending result");
Ok(())
}) {
if walker
.lock()
.expect("Walker mutex poisoned")
.cancel()
.is_err()
{
debug!("Unable to cancel graph walk");
}
}
if done.send(()).is_err() {
debug!("Graph walk done receiver closed before node was finished processing");
}
Ok(())
}));
}

while let Some(res) = tasks.next().await {
res.expect("unable to join task")?;
}

Ok(())
}
}

impl<T, U> Message<T, U> {
pub fn new(info: T) -> (Self, oneshot::Receiver<U>) {
let (callback, receiver) = oneshot::channel();
(Self { info, callback }, receiver)
}
}
9 changes: 9 additions & 0 deletions crates/turborepo-lib/src/engine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mod builder;
mod execute;

mod dot;

use std::{
Expand All @@ -7,6 +9,7 @@ use std::{
};

pub use builder::EngineBuilder;
pub use execute::{ExecuteError, ExecutionOptions, Message};
use petgraph::Graph;

use crate::{
Expand Down Expand Up @@ -116,6 +119,12 @@ impl Engine<Built> {
)
}

// TODO get rid of static lifetime and figure out right way to tell compiler the
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what we'll want to eventually do is make an owned/borrowed version of task id ala https://github.com/sunshowers-code/borrow-complex-key-example.

// lifetime of the return ref
pub fn task_definition(&self, task_id: &TaskId<'static>) -> Option<&TaskDefinition> {
self.task_definitions.get(task_id)
}

pub fn validate(
&self,
package_graph: &PackageGraph,
Expand Down
11 changes: 10 additions & 1 deletion crates/turborepo-lib/src/run/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ mod global_hash;
mod scope;
pub mod task_id;

use std::io::{BufWriter, IsTerminal};
use std::{
io::{BufWriter, IsTerminal},
sync::Arc,
};

use anyhow::{anyhow, Context as ErrorContext, Result};
use itertools::Itertools;
Expand All @@ -27,6 +30,7 @@ use crate::{
package_graph::{PackageGraph, WorkspaceName},
package_json::PackageJson,
run::{cache::RunCache, global_hash::get_global_hash_inputs},
task_graph::Visitor,
};

#[derive(Debug)]
Expand Down Expand Up @@ -206,6 +210,11 @@ impl Run {
self.base.ui,
);

let pkg_dep_graph = Arc::new(pkg_dep_graph);
let engine = Arc::new(engine);
let visitor = Visitor::new(pkg_dep_graph, &opts);
visitor.visit(engine).await?;

Ok(())
}
}
3 changes: 3 additions & 0 deletions crates/turborepo-lib/src/task_graph/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
mod visitor;

use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
use turbopath::RelativeUnixPathBuf;
pub use visitor::{Error, Visitor};

use crate::{
cli::OutputLogsMode,
Expand Down
108 changes: 108 additions & 0 deletions crates/turborepo-lib/src/task_graph/visitor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::sync::{Arc, OnceLock};

use futures::{stream::FuturesUnordered, StreamExt};
use regex::Regex;
use tokio::sync::mpsc;

use crate::{
engine::{Engine, ExecutionOptions},
opts::Opts,
package_graph::{PackageGraph, WorkspaceName},
run::task_id::{self, TaskId},
};

// This holds the whole world
pub struct Visitor<'a> {
package_graph: Arc<PackageGraph>,
opts: &'a Opts<'a>,
}

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("cannot find package {package_name} for task {task_id}")]
MissingPackage {
package_name: WorkspaceName,
task_id: TaskId<'static>,
},
#[error(
"root task {task_name} ({command}) looks like it invokes turbo and might cause a loop"
)]
RecursiveTurbo { task_name: String, command: String },
#[error("Could not find definition for task")]
MissingDefinition,
#[error("error while executing engine: {0}")]
Engine(#[from] crate::engine::ExecuteError),
}

impl<'a> Visitor<'a> {
pub fn new(package_graph: Arc<PackageGraph>, opts: &'a Opts) -> Self {
Self {
package_graph,
opts,
}
}

pub async fn visit(&self, engine: Arc<Engine>) -> Result<(), Error> {
let concurrency = self.opts.run_opts.concurrency as usize;
let (node_sender, mut node_stream) = mpsc::channel(concurrency);

let engine_handle = {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the engine to execute in another task so it can send over the tasks via the node stream.

let engine = engine.clone();
tokio::spawn(engine.execute(ExecutionOptions::new(false, concurrency), node_sender))
};

let mut tasks = FuturesUnordered::new();

while let Some(message) = node_stream.recv().await {
let crate::engine::Message { info, callback } = message;
let package_name = WorkspaceName::from(info.package());
let package_json = self
.package_graph
.package_json(&package_name)
.ok_or_else(|| Error::MissingPackage {
package_name: package_name.clone(),
task_id: info.clone(),
})?;

let command = package_json.scripts.get(info.task()).cloned();

match command {
Some(cmd)
if info.package() == task_id::ROOT_PKG_NAME && turbo_regex().is_match(&cmd) =>
{
return Err(Error::RecursiveTurbo {
task_name: info.to_string(),
command: cmd.to_string(),
})
}
_ => (),
}

let _task_def = engine
.task_definition(&info)
.ok_or(Error::MissingDefinition)?;

tasks.push(tokio::spawn(async move {
println!(
"Executing {info}: {}",
command.as_deref().unwrap_or("no script def")
);
Comment on lines +86 to +89
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we have the chance of task failures we'll hook up error handling. I'm thinking that the visitor will manage the collection of errors via a mpsc channel setup, but this might change when we get to implementing it.

callback.send(Ok(())).unwrap();
}));
}

// Wait for the engine task to finish and for all of our tasks to finish
engine_handle.await.expect("engine execution panicked")?;
// This will poll the futures until they are all completed
while let Some(result) = tasks.next().await {
result.expect("task executor panicked");
}

Ok(())
}
}

fn turbo_regex() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new(r"(?:^|\s)turbo(?:$|\s)").unwrap())
}
Loading