feat(workflows): add worker instance failover
MasterPtato committed Jun 4, 2024
1 parent a3511ca commit b61ffea
Showing 14 changed files with 243 additions and 52 deletions.
3 changes: 3 additions & 0 deletions docs/libraries/workflow/ERRORS.md
@@ -0,0 +1,3 @@
# Errors

Only errors from inside activities will be retried. Errors thrown in the workflow body will not be retried because they will never succeed (the state is consistent up to the point of error).
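
A sketch of the distinction as a hypothetical workflow (the `ctx.activity(...)` call shape, the input/output types, and the `ensure!` usage are illustrative assumptions, not this library's exact API):

```rust
// Hypothetical workflow body illustrating the retry rule above.
async fn process_user(ctx: &mut WorkflowCtx, input: Input) -> GlobalResult<Output> {
	// Retried on error: the failure happens inside an activity
	// (e.g. a transient database or network error).
	let user = ctx.activity(GetUser { user_id: input.user_id }).await?;

	// NOT retried on error: this runs in the workflow body, so a failure
	// here is permanent -- replaying the same deterministic history would
	// fail identically every time.
	ensure!(!user.banned, "cannot process banned user");

	Ok(Output { display_name: user.display_name })
}
```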
4 changes: 0 additions & 4 deletions lib/chirp-workflow/core/src/ctx/test.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use global_error::{GlobalError, GlobalResult};
use serde::Serialize;
use tokio::time::Duration;
@@ -10,8 +8,6 @@ use crate::{
	Workflow, WorkflowError, WorkflowInput,
};

-pub type TestCtxHandle = Arc<TestCtx>;
-
pub struct TestCtx {
	name: String,
	ray_id: Uuid,
6 changes: 5 additions & 1 deletion lib/chirp-workflow/core/src/db/mod.rs
@@ -19,7 +19,11 @@ pub trait Database: Send {
		input: serde_json::Value,
	) -> WorkflowResult<()>;
	async fn get_workflow(&self, id: Uuid) -> WorkflowResult<Option<WorkflowRow>>;
-	async fn pull_workflows(&self, filter: &[&str]) -> WorkflowResult<Vec<PulledWorkflow>>;
+	async fn pull_workflows(
+		&self,
+		worker_instance_id: Uuid,
+		filter: &[&str],
+	) -> WorkflowResult<Vec<PulledWorkflow>>;

	// When a workflow is completed
	async fn commit_workflow(
87 changes: 50 additions & 37 deletions lib/chirp-workflow/core/src/db/postgres.rs
@@ -10,8 +10,6 @@ use super::{
};
use crate::{schema::ActivityId, WorkflowError, WorkflowResult};

-const NODE_ID: Uuid = Uuid::nil();
-
pub struct DatabasePostgres {
	pool: PgPool,
}
@@ -99,44 +97,59 @@ impl Database for DatabasePostgres {
		.map_err(WorkflowError::Sqlx)
	}

-	async fn pull_workflows(&self, filter: &[&str]) -> WorkflowResult<Vec<PulledWorkflow>> {
+	async fn pull_workflows(
+		&self,
+		worker_instance_id: Uuid,
+		filter: &[&str],
+	) -> WorkflowResult<Vec<PulledWorkflow>> {
		// TODO(RVT-3753): include limit on query to allow better workflow spread between nodes?
		// Select all workflows that haven't started or that have a wake condition
		let rows = sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
			"
-			UPDATE db_workflow.workflows as w
-			-- Assign this node to this workflow
-			SET node_id = $1
-			WHERE
-				-- Filter
-				workflow_name = ANY($2) AND
-				-- Not already complete
-				output IS NULL AND
-				-- No assigned node (not running)
-				node_id IS NULL AND
-				-- Check for wake condition
-				(
-					wake_immediate OR
-					wake_deadline_ts IS NOT NULL OR
-					(
-						SELECT true
-						FROM db_workflow.signals AS s
-						WHERE s.signal_name = ANY(wake_signals)
-						LIMIT 1
-					) OR
-					(
-						SELECT true
-						FROM db_workflow.workflows AS w2
-						WHERE
-							w2.workflow_id = w.wake_sub_workflow_id AND
-							output IS NOT NULL
-					)
-				)
-			RETURNING workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts
+			WITH
+				pull_workflows AS (
+					UPDATE db_workflow.workflows as w
+					-- Assign this node to this workflow
+					SET worker_instance_id = $1
+					WHERE
+						-- Filter
+						workflow_name = ANY($2) AND
+						-- Not already complete
+						output IS NULL AND
+						-- No assigned node (not running)
+						worker_instance_id IS NULL AND
+						-- Check for wake condition
+						(
+							wake_immediate OR
+							wake_deadline_ts IS NOT NULL OR
+							(
+								SELECT true
+								FROM db_workflow.signals AS s
+								WHERE s.signal_name = ANY(wake_signals)
+								LIMIT 1
+							) OR
+							(
+								SELECT true
+								FROM db_workflow.workflows AS w2
+								WHERE
+									w2.workflow_id = w.wake_sub_workflow_id AND
+									output IS NOT NULL
+							)
+						)
+					RETURNING workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts
+				),
+				-- Update last ping
+				worker_instance_update AS (
+					UPSERT INTO db_workflow.worker_instances (worker_instance_id, last_ping_ts)
+					VALUES ($1, $3)
+					RETURNING 1
+				)
+			SELECT * FROM pull_workflows
			"
		))
-		.bind(NODE_ID)
+		.bind(worker_instance_id)
		.bind(filter)
+		.bind(rivet_util::timestamp::now())
		.fetch_all(&mut *self.conn().await?)
		.await
		.map_err(WorkflowError::Sqlx)?;
@@ -199,12 +212,12 @@ impl Database for DatabasePostgres {
		sqlx::query_as::<_, SubWorkflowEventRow>(indoc!(
			"
			SELECT
-				sw.workflow_id, sw.location, sw.sub_workflow_id, w.name as sub_workflow_name
+				sw.workflow_id, sw.location, sw.sub_workflow_id, w.workflow_name AS sub_workflow_name
			FROM db_workflow.workflow_sub_workflow_events AS sw
			JOIN db_workflow.workflows AS w
			ON sw.sub_workflow_id = w.workflow_id
-			WHERE workflow_id = ANY($1)
-			ORDER BY workflow_id, location ASC
+			WHERE sw.workflow_id = ANY($1)
+			ORDER BY sw.workflow_id, sw.location ASC
			"
		))
		.bind(&workflow_ids)
@@ -274,7 +287,7 @@ impl Database for DatabasePostgres {
			"
			UPDATE db_workflow.workflows
			SET
-				node_id = NULL,
+				worker_instance_id = NULL,
				wake_immediate = $2,
				wake_deadline_ts = $3,
				wake_signals = $4,
@@ -307,7 +320,7 @@
			UPSERT INTO db_workflow.workflow_activity_events (
				workflow_id, location, activity_name, input_hash, input, output
			)
-			VALUES ($1, $2, $3, $4, $5)
+			VALUES ($1, $2, $3, $4, $5, $6)
			"
		))
		.bind(workflow_id)
4 changes: 4 additions & 0 deletions lib/chirp-workflow/core/src/registry.rs
@@ -75,6 +75,10 @@ impl Registry {
			.get(name)
			.ok_or(WorkflowError::WorkflowMissingFromRegistry(name.to_string()))
	}
+
+	pub fn size(&self) -> usize {
+		self.workflows.len()
+	}
}

pub struct RegistryWorkflow {
19 changes: 17 additions & 2 deletions lib/chirp-workflow/core/src/worker.rs
@@ -1,6 +1,7 @@
use global_error::GlobalResult;
use tokio::time::Duration;
use tracing::Instrument;
+use uuid::Uuid;

use crate::{util, DatabaseHandle, RegistryHandle, WorkflowCtx};

@@ -10,16 +11,27 @@ const TICK_INTERVAL: Duration = Duration::from_millis(50);
/// that are registered in its registry. After pulling, the workflows are run and their state is written to
/// the database.
pub struct Worker {
+	worker_instance_id: Uuid,
	registry: RegistryHandle,
	db: DatabaseHandle,
}

impl Worker {
	pub fn new(registry: RegistryHandle, db: DatabaseHandle) -> Self {
-		Worker { registry, db }
+		Worker {
+			worker_instance_id: Uuid::new_v4(),
+			registry,
+			db,
+		}
	}

	pub async fn start(mut self, pools: rivet_pools::Pools) -> GlobalResult<()> {
+		tracing::info!(
+			worker_instance_id=?self.worker_instance_id,
+			"starting worker instance with {} registered workflows",
+			self.registry.size(),
+		);
+
		let mut interval = tokio::time::interval(TICK_INTERVAL);
		interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

@@ -49,7 +61,10 @@ impl Worker {
			.collect::<Vec<_>>();

		// Query awake workflows
-		let workflows = self.db.pull_workflows(&registered_workflows).await?;
+		let workflows = self
+			.db
+			.pull_workflows(self.worker_instance_id, &registered_workflows)
+			.await?;
		for workflow in workflows {
			let conn = util::new_conn(
				&shared_client,
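
For reference, a minimal wiring sketch (the constructor and handle names here are assumptions inferred from this diff, not verified against the crate):

```rust
// Hypothetical setup; Registry/database constructor names are assumed.
let mut registry = Registry::new();
registry.register_workflow::<MyWorkflow>()?;

let db = DatabasePostgres::new(&database_url).await?;

// Every `Worker::new` call generates a fresh `worker_instance_id`, so
// multiple workers can share one database without contention: the pull
// query only claims rows whose `worker_instance_id` IS NULL.
let worker = Worker::new(registry.handle(), db);
worker.start(pools).await?;
```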
16 changes: 16 additions & 0 deletions svc/Cargo.lock


3 changes: 2 additions & 1 deletion svc/Cargo.toml
@@ -261,7 +261,8 @@ members = [
	"pkg/user/ops/token-create",
	"pkg/user/standalone/delete-pending",
	"pkg/user/standalone/search-user-gc",
-	"pkg/user/worker"
+	"pkg/user/worker",
+	"pkg/workflow/standalone/gc"
]

# Speed up compilation
@@ -1,17 +1,16 @@
-CREATE TABLE nodes (
-	node_id UUID PRIMARY KEY,
+CREATE TABLE worker_instances (
+	worker_instance_id UUID PRIMARY KEY,
	last_ping_ts INT
);

--- TODO: In the event of a node failure, clear all of the wake conditions and remove the node id. This can be
--- done in a periodic GC service
+-- NOTE: If a row has `worker_instance_id` set and `output` unset, it is currently running
CREATE TABLE workflows (
	workflow_id UUID PRIMARY KEY,
	workflow_name TEXT NOT NULL,
	create_ts INT NOT NULL,
	ray_id UUID NOT NULL,
-	-- The node that's running this workflow
-	node_id UUID,
+	-- The worker instance that's running this workflow
+	worker_instance_id UUID,

	input JSONB NOT NULL,
	-- Null if incomplete
@@ -24,7 +23,10 @@ CREATE TABLE workflows (

	INDEX (wake_immediate),
	INDEX (wake_deadline_ts),
-	INDEX (wake_sub_workflow_id)
+	INDEX (wake_sub_workflow_id),
+
+	-- Query by worker_instance_id for failover
+	INDEX(worker_instance_id)
);

CREATE INDEX gin_workflows_wake_signals
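
Per the NOTE above, a row with `worker_instance_id` set and `output` unset is currently running. As an illustration (a hand-written query, not part of this commit), counting in-flight workflows per worker instance would look like:

```rust
// Illustrative sqlx query based on the NOTE in the migration above.
let running_per_worker: Vec<(uuid::Uuid, i64)> = sqlx::query_as(
	"
	SELECT worker_instance_id, COUNT(*)
	FROM db_workflow.workflows
	WHERE worker_instance_id IS NOT NULL AND output IS NULL
	GROUP BY worker_instance_id
	",
)
.fetch_all(&mut *conn)
.await?;
```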
20 changes: 20 additions & 0 deletions svc/pkg/workflow/standalone/gc/Cargo.toml
@@ -0,0 +1,20 @@
[package]
name = "workflow-gc"
version = "0.0.1"
edition = "2021"
authors = ["Rivet Gaming, LLC <developer@rivet.gg>"]
license = "Apache-2.0"

[dependencies]
chirp-client = { path = "../../../../../lib/chirp/client" }
rivet-connection = { path = "../../../../../lib/connection" }
rivet-health-checks = { path = "../../../../../lib/health-checks" }
rivet-metrics = { path = "../../../../../lib/metrics" }
rivet-operation = { path = "../../../../../lib/operation/core" }
rivet-runtime = { path = "../../../../../lib/runtime" }
tokio = { version = "1.29", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "json", "ansi"] }

[dev-dependencies]
chirp-worker = { path = "../../../../../lib/chirp/worker" }
11 changes: 11 additions & 0 deletions svc/pkg/workflow/standalone/gc/Service.toml
@@ -0,0 +1,11 @@
[service]
name = "workflow-gc"

[runtime]
kind = "rust"

[headless]
singleton = true

[databases]
db-workflow = {}
58 changes: 58 additions & 0 deletions svc/pkg/workflow/standalone/gc/src/lib.rs
@@ -0,0 +1,58 @@
use std::{collections::HashSet, time::Duration};

use rivet_operation::prelude::*;

const WORKER_INSTANCE_LOST_THRESHOLD: i64 = util::duration::seconds(30);

#[tracing::instrument(skip_all)]
pub async fn run_from_env(ts: i64, pools: rivet_pools::Pools) -> GlobalResult<()> {
	let client = chirp_client::SharedClient::from_env(pools.clone())?.wrap_new("workflow-gc");
	let cache = rivet_cache::CacheInner::from_env(pools.clone())?;
	let ctx = OperationContext::new(
		"workflow-gc".into(),
		Duration::from_secs(60),
		rivet_connection::Connection::new(client, pools, cache),
		Uuid::new_v4(),
		Uuid::new_v4(),
		util::timestamp::now(),
		util::timestamp::now(),
		(),
	);

	// Reset all workflows on worker instances that have not had a ping in the last 30 seconds
	let rows = sql_fetch_all!(
		[ctx, (Uuid, Uuid,)]
		"
		UPDATE db_workflow.workflows AS w
		SET
			worker_instance_id = NULL,
			wake_immediate = true,
			wake_deadline_ts = NULL,
			wake_signals = ARRAY[],
			wake_sub_workflow_id = NULL
		FROM db_workflow.worker_instances AS wi
		WHERE
			wi.last_ping_ts < $1 AND
			wi.worker_instance_id = w.worker_instance_id AND
			w.output IS NULL
		RETURNING w.workflow_id, wi.worker_instance_id
		",
		ts - WORKER_INSTANCE_LOST_THRESHOLD,
	)
	.await?;

	if !rows.is_empty() {
		let unique_worker_instance_ids = rows
			.iter()
			.map(|(_, worker_instance_id)| worker_instance_id)
			.collect::<HashSet<_>>();

		tracing::info!(
			worker_instance_ids=?unique_worker_instance_ids,
			total_workflows=%rows.len(),
			"handled failover",
		);
	}

	Ok(())
}