Skip to content

Commit c477ba9

Browse files
committed
feat(workflows): add sleep fn (#1077)
<!-- Please make sure there is an issue that this PR is correlated to. --> ## Changes <!-- If there are frontend changes, please include screenshots. -->
1 parent 0c58f83 commit c477ba9

File tree

11 files changed

+371
-186
lines changed

11 files changed

+371
-186
lines changed

docs/libraries/workflow/GOTCHAS.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,15 @@ ctx
122122
})
123123
.await?;
124124
```
125+
126+
## Nested options with serde
127+
128+
Nested `Option`s do not round-trip through serde: `Some(None)` and `None` both serialize to `null`, and `null` always deserializes to the outer `None`.
129+
130+
```rust
131+
Some(Some(1234)) -> "1234" -> Some(Some(1234))
132+
Some(None) -> "null" -> None
133+
None -> "null" -> None
134+
```
135+
136+
Be careful when writing your struct definitions.

lib/chirp-workflow/core/src/ctx/workflow.rs

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ use crate::{
1818
metrics,
1919
registry::RegistryHandle,
2020
signal::Signal,
21-
util::{self, GlobalErrorExt, Location},
21+
util::{
22+
self,
23+
time::{DurationToMillis, TsToMillis},
24+
GlobalErrorExt, Location,
25+
},
2226
workflow::{Workflow, WorkflowInput},
27+
worker,
2328
};
2429

2530
// Time to delay a workflow from retrying after an error
@@ -211,7 +216,7 @@ impl WorkflowCtx {
211216
}
212217
Err(err) => {
213218
// Retry the workflow if it's recoverable
214-
let deadline_ts = if let Some(deadline_ts) = err.backoff() {
219+
let deadline_ts = if let Some(deadline_ts) = err.deadline_ts() {
215220
Some(deadline_ts)
216221
} else if err.is_retryable() {
217222
Some(rivet_util::timestamp::now() + RETRY_TIMEOUT_MS as i64)
@@ -1251,7 +1256,61 @@ impl WorkflowCtx {
12511256
Ok(output)
12521257
}
12531258

1254-
// TODO: sleep_for, sleep_until
1259+
pub async fn sleep<T: DurationToMillis>(&mut self, duration: T) -> GlobalResult<()> {
1260+
self.sleep_until(rivet_util::timestamp::now() + duration.to_millis()?)
1261+
.await
1262+
}
1263+
1264+
pub async fn sleep_until<T: TsToMillis>(&mut self, time: T) -> GlobalResult<()> {
1265+
let event = self.relevant_history().nth(self.location_idx);
1266+
1267+
// Slept before
1268+
if let Some(event) = event {
1269+
// Validate history is consistent
1270+
let Event::Sleep(_) = event else {
1271+
return Err(WorkflowError::HistoryDiverged(format!(
1272+
"expected {event} at {}, found sleep",
1273+
self.loc(),
1274+
)))
1275+
.map_err(GlobalError::raw);
1276+
};
1277+
1278+
tracing::debug!(name=%self.name, id=%self.workflow_id, "skipping replayed sleep");
1279+
}
1280+
// Sleep
1281+
else {
1282+
let ts = time.to_millis()?;
1283+
1284+
self.db
1285+
.commit_workflow_sleep_event(
1286+
self.workflow_id,
1287+
self.full_location().as_ref(),
1288+
ts,
1289+
self.loop_location(),
1290+
)
1291+
.await?;
1292+
1293+
let duration = ts - rivet_util::timestamp::now();
1294+
if duration < 0 {
1295+
// No-op
1296+
tracing::warn!("tried to sleep for a negative duration");
1297+
} else if duration < worker::TICK_INTERVAL.as_millis() as i64 + 1 {
1298+
tracing::info!(name=%self.name, id=%self.workflow_id, until_ts=%ts, "sleeping in memory");
1299+
1300+
// Sleep in memory if duration is shorter than the worker tick
1301+
tokio::time::sleep(std::time::Duration::from_millis(duration.try_into()?)).await;
1302+
} else {
1303+
tracing::info!(name=%self.name, id=%self.workflow_id, until_ts=%ts, "sleeping");
1304+
1305+
return Err(WorkflowError::Sleep(ts)).map_err(GlobalError::raw);
1306+
}
1307+
}
1308+
1309+
// Move to next event
1310+
self.location_idx += 1;
1311+
1312+
Ok(())
1313+
}
12551314
}
12561315

12571316
impl WorkflowCtx {

lib/chirp-workflow/core/src/db/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,15 @@ pub trait Database: Send {
169169
output: Option<serde_json::Value>,
170170
loop_location: Option<&[usize]>,
171171
) -> WorkflowResult<()>;
172+
173+
/// Writes a workflow sleep event to history.
174+
async fn commit_workflow_sleep_event(
175+
&self,
176+
from_workflow_id: Uuid,
177+
location: &[usize],
178+
util_ts: i64,
179+
loop_location: Option<&[usize]>,
180+
) -> WorkflowResult<()>;
172181
}
173182

174183
#[derive(sqlx::FromRow)]
@@ -266,3 +275,9 @@ pub struct LoopEventRow {
266275
pub output: Option<serde_json::Value>,
267276
pub iteration: i64,
268277
}
278+
279+
#[derive(sqlx::FromRow)]
280+
pub struct SleepEventRow {
281+
pub workflow_id: Uuid,
282+
pub location: Vec<i64>,
283+
}

lib/chirp-workflow/core/src/db/pg_nats.rs

Lines changed: 127 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ use uuid::Uuid;
88

99
use super::{
1010
ActivityEventRow, Database, LoopEventRow, MessageSendEventRow, PulledWorkflow,
11-
PulledWorkflowRow, SignalEventRow, SignalRow, SignalSendEventRow, SubWorkflowEventRow,
12-
WorkflowRow,
11+
PulledWorkflowRow, SignalEventRow, SignalRow, SignalSendEventRow, SleepEventRow,
12+
SubWorkflowEventRow, WorkflowRow,
1313
};
1414
use crate::{
1515
activity::ActivityId,
1616
error::{WorkflowError, WorkflowResult},
1717
event::combine_events,
18-
message,
18+
message, worker,
1919
};
2020

2121
/// Max amount of workflows pulled from the database with each call to `pull_workflows`.
@@ -152,74 +152,84 @@ impl Database for DatabasePgNats {
152152
filter: &[&str],
153153
) -> WorkflowResult<Vec<PulledWorkflow>> {
154154
// Select all workflows that haven't started or that have a wake condition
155-
let workflow_rows = sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
156-
"
157-
WITH
158-
pull_workflows AS (
159-
UPDATE db_workflow.workflows AS w
160-
-- Assign this node to this workflow
161-
SET worker_instance_id = $1
162-
WHERE
163-
-- Filter
164-
workflow_name = ANY($2) AND
165-
-- Not already complete
166-
output IS NULL AND
167-
-- No assigned node (not running)
168-
worker_instance_id IS NULL AND
169-
-- Check for wake condition
170-
(
171-
-- Immediate
172-
wake_immediate OR
173-
-- After deadline
174-
wake_deadline_ts IS NOT NULL OR
175-
-- Signal exists
176-
(
177-
SELECT true
178-
FROM db_workflow.signals AS s
179-
WHERE
180-
s.workflow_id = w.workflow_id AND
181-
s.signal_name = ANY(w.wake_signals) AND
182-
s.ack_ts IS NULL
183-
LIMIT 1
184-
) OR
185-
-- Tagged signal exists
186-
(
187-
SELECT true
188-
FROM db_workflow.tagged_signals AS s
189-
WHERE
190-
s.signal_name = ANY(w.wake_signals) AND
191-
s.tags <@ w.tags AND
192-
s.ack_ts IS NULL
193-
LIMIT 1
194-
) OR
195-
-- Sub workflow completed
196-
(
197-
SELECT true
198-
FROM db_workflow.workflows AS w2
199-
WHERE
200-
w2.workflow_id = w.wake_sub_workflow_id AND
201-
output IS NOT NULL
202-
)
155+
let workflow_rows = self
156+
.query(|| async {
157+
sqlx::query_as::<_, PulledWorkflowRow>(indoc!(
158+
"
159+
WITH
160+
pull_workflows AS (
161+
UPDATE db_workflow.workflows AS w
162+
-- Assign this node to this workflow
163+
SET worker_instance_id = $1
164+
WHERE
165+
-- Filter
166+
workflow_name = ANY($2) AND
167+
-- Not already complete
168+
output IS NULL AND
169+
-- No assigned node (not running)
170+
worker_instance_id IS NULL AND
171+
-- Check for wake condition
172+
(
173+
-- Immediate
174+
wake_immediate OR
175+
-- After deadline
176+
(
177+
wake_deadline_ts IS NOT NULL AND
178+
$3 > wake_deadline_ts - $4
179+
) OR
180+
-- Signal exists
181+
(
182+
SELECT true
183+
FROM db_workflow.signals AS s
184+
WHERE
185+
s.workflow_id = w.workflow_id AND
186+
s.signal_name = ANY(w.wake_signals) AND
187+
s.ack_ts IS NULL
188+
LIMIT 1
189+
) OR
190+
-- Tagged signal exists
191+
(
192+
SELECT true
193+
FROM db_workflow.tagged_signals AS s
194+
WHERE
195+
s.signal_name = ANY(w.wake_signals) AND
196+
s.tags <@ w.tags AND
197+
s.ack_ts IS NULL
198+
LIMIT 1
199+
) OR
200+
-- Sub workflow completed
201+
(
202+
SELECT true
203+
FROM db_workflow.workflows AS w2
204+
WHERE
205+
w2.workflow_id = w.wake_sub_workflow_id AND
206+
output IS NOT NULL
207+
)
208+
)
209+
LIMIT $5
210+
RETURNING workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts
211+
),
212+
-- Update last ping
213+
worker_instance_update AS (
214+
UPSERT INTO db_workflow.worker_instances (worker_instance_id, last_ping_ts)
215+
VALUES ($1, $3)
216+
RETURNING 1
203217
)
204-
LIMIT $4
205-
RETURNING workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts
206-
),
207-
-- Update last ping
208-
worker_instance_update AS (
209-
UPSERT INTO db_workflow.worker_instances (worker_instance_id, last_ping_ts)
210-
VALUES ($1, $3)
211-
RETURNING 1
212-
)
213-
SELECT * FROM pull_workflows
214-
",
215-
))
216-
.bind(worker_instance_id)
217-
.bind(filter)
218-
.bind(rivet_util::timestamp::now())
219-
.bind(MAX_PULLED_WORKFLOWS)
220-
.fetch_all(&mut *self.conn().await?)
221-
.await
222-
.map_err(WorkflowError::Sqlx)?;
218+
SELECT * FROM pull_workflows
219+
",
220+
))
221+
.bind(worker_instance_id)
222+
.bind(filter)
223+
.bind(rivet_util::timestamp::now())
224+
// Add padding to the tick interval so that the workflow deadline is never passed before it's pulled.
225+
// The worker sleeps internally to handle this
226+
.bind(worker::TICK_INTERVAL.as_millis() as i64 + 1)
227+
.bind(MAX_PULLED_WORKFLOWS)
228+
.fetch_all(&mut *self.conn().await?)
229+
.await
230+
.map_err(WorkflowError::Sqlx)
231+
})
232+
.await?;
223233

224234
if workflow_rows.is_empty() {
225235
return Ok(Vec::new());
@@ -240,6 +250,7 @@ impl Database for DatabasePgNats {
240250
msg_send_events,
241251
sub_workflow_events,
242252
loop_events,
253+
sleep_events,
243254
) = tokio::try_join!(
244255
async {
245256
sqlx::query_as::<_, ActivityEventRow>(indoc!(
@@ -347,6 +358,21 @@ impl Database for DatabasePgNats {
347358
.await
348359
.map_err(WorkflowError::Sqlx)
349360
},
361+
async {
362+
sqlx::query_as::<_, SleepEventRow>(indoc!(
363+
"
364+
SELECT
365+
workflow_id, location
366+
FROM db_workflow.workflow_sleep_events
367+
WHERE workflow_id = ANY($1) AND forgotten = FALSE
368+
ORDER BY workflow_id, location ASC
369+
",
370+
))
371+
.bind(&workflow_ids)
372+
.fetch_all(&mut *self.conn().await?)
373+
.await
374+
.map_err(WorkflowError::Sqlx)
375+
},
350376
)?;
351377

352378
let workflows = combine_events(
@@ -357,6 +383,7 @@ impl Database for DatabasePgNats {
357383
msg_send_events,
358384
sub_workflow_events,
359385
loop_events,
386+
sleep_events,
360387
)?;
361388

362389
Ok(workflows)
@@ -397,7 +424,6 @@ impl Database for DatabasePgNats {
397424
wake_sub_workflow_id: Option<Uuid>,
398425
error: &str,
399426
) -> WorkflowResult<()> {
400-
// TODO(RVT-3762): Should this compare `wake_deadline_ts` before setting it?
401427
self.query(|| async {
402428
sqlx::query(indoc!(
403429
"
@@ -1017,4 +1043,34 @@ impl Database for DatabasePgNats {
10171043

10181044
Ok(())
10191045
}
1046+
1047+
async fn commit_workflow_sleep_event(
1048+
&self,
1049+
from_workflow_id: Uuid,
1050+
location: &[usize],
1051+
until_ts: i64,
1052+
loop_location: Option<&[usize]>,
1053+
) -> WorkflowResult<()> {
1054+
self.query(|| async {
1055+
sqlx::query(indoc!(
1056+
"
1057+
INSERT INTO db_workflow.workflow_sleep_events(
1058+
workflow_id, location, until_ts, loop_location
1059+
)
1060+
VALUES($1, $2, $3, $4)
1061+
RETURNING 1
1062+
",
1063+
))
1064+
.bind(from_workflow_id)
1065+
.bind(location.iter().map(|x| *x as i64).collect::<Vec<_>>())
1066+
.bind(until_ts)
1067+
.bind(loop_location.map(|l| l.iter().map(|x| *x as i64).collect::<Vec<_>>()))
1068+
.execute(&mut *self.conn().await?)
1069+
.await
1070+
.map_err(WorkflowError::Sqlx)
1071+
})
1072+
.await?;
1073+
1074+
Ok(())
1075+
}
10201076
}

0 commit comments

Comments
 (0)