diff --git a/src/lib.rs b/src/lib.rs index 4d65ddf1b..53f1838ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,7 +36,11 @@ use dashmap::DashMap; use std::{ convert::TryInto, fmt::Debug, - sync::{mpsc::SendError, Arc}, + sync::{ + atomic::{AtomicBool, Ordering}, + mpsc::SendError, + Arc, + }, }; use tokio::runtime::Runtime; use tonic::codegen::http::uri::InvalidUri; @@ -62,6 +66,13 @@ pub trait Core: Send + Sync { /// Returns an instance of ServerGateway. fn server_gateway(&self) -> Result>; + + /// Eventually ceases all polling of the server. [Core::poll_task] should be called until it + /// returns [CoreError::ShuttingDown] to ensure that any workflows which are still undergoing + /// replay have an opportunity to finish. This means that the lang sdk will need to call + /// [Core::complete_task] for those workflows until they are done. At that point, the lang + /// SDK can end the process, or drop the [Core] instance, which will close the connection. + fn shutdown(&self) -> Result<()>; } /// Holds various configuration information required to call [init] @@ -89,6 +100,7 @@ pub fn init(opts: CoreInitOptions) -> Result { workflow_machines: WorkflowConcurrencyManager::new(), workflow_task_tokens: Default::default(), pending_activations: Default::default(), + shutdown_requested: AtomicBool::new(false), }) } @@ -115,6 +127,9 @@ where /// Workflows that are currently under replay will queue their run ID here, indicating that /// there are more workflow tasks / activations to be performed. pending_activations: SegQueue, + + /// Has shutdown been called? + shutdown_requested: AtomicBool, } #[derive(Debug)] @@ -146,6 +161,10 @@ where }); } + if self.shutdown_requested.load(Ordering::SeqCst) { + return Err(CoreError::ShuttingDown); + } + // This will block forever in the event there is no work from the server let work = self .runtime @@ -215,6 +234,11 @@ where fn server_gateway(&self) -> Result> { Ok(self.server_gateway.clone()) } + + fn shutdown(&self) -> Result<(), CoreError> { + self.shutdown_requested.store(true, Ordering::SeqCst); + Ok(()) + } } impl CoreSDK { @@ -271,8 +295,9 @@ impl CoreSDK { #[allow(clippy::large_enum_variant)] // NOTE: Docstrings take the place of #[error("xxxx")] here b/c of displaydoc pub enum CoreError { - /// No tasks to perform for now - NoWork, + /// [Core::shutdown] was called, and there are no more replay tasks to be handled. You must + /// call [Core::complete_task] for any remaining tasks, and then may exit. + ShuttingDown, /// Poll response from server was malformed: {0:?} BadDataFromWorkProvider(PollWorkflowTaskQueueResponse), /// Lang SDK sent us a malformed completion: {0:?} @@ -302,6 +327,7 @@ pub enum CoreError { #[cfg(test)] mod test { use super::*; + use crate::machines::test_help::FakeCore; use crate::{ machines::test_help::{build_fake_core, TestHistoryBuilder}, protos::{ @@ -317,6 +343,7 @@ mod test { }, }, }; + use rstest::{fixture, rstest}; #[test] fn single_timer_test_across_wf_bridge() { @@ -484,11 +511,8 @@ mod test { .unwrap(); } - #[test] - fn single_timer_whole_replay_test_across_wf_bridge() { - let s = span!(Level::DEBUG, "Test start", t = "bridge"); - let _enter = s.enter(); - + #[fixture] + fn single_timer_whole_replay() -> FakeCore { let wfid = "fake_wf_id"; let run_id = "fake_run_id"; let timer_1_id = "timer1".to_string(); @@ -543,5 +567,22 @@ mod test { task_tok, )) .unwrap(); + core + } + + #[rstest] + fn single_timer_whole_replay_test_across_wf_bridge(_single_timer_whole_replay: FakeCore) { + // Nothing to do here -- whole real test is in fixture. Rstest properly handles leading `_` + } + + #[rstest] + fn after_shutdown_server_is_not_polled(single_timer_whole_replay: FakeCore) { + single_timer_whole_replay.shutdown().unwrap(); + assert_matches!( + single_timer_whole_replay + .poll_task("irrelevant") + .unwrap_err(), + CoreError::ShuttingDown + ); } } diff --git a/src/machines/test_help/mod.rs b/src/machines/test_help/mod.rs index f669750d9..9522ae745 100644 --- a/src/machines/test_help/mod.rs +++ b/src/machines/test_help/mod.rs @@ -17,9 +17,12 @@ use crate::{ CoreSDK, }; use rand::{thread_rng, Rng}; +use std::sync::atomic::AtomicBool; use std::{collections::VecDeque, sync::Arc}; use tokio::runtime::Runtime; +pub(crate) type FakeCore = CoreSDK; + /// Given identifiers for a workflow/run, and a test history builder, construct an instance of /// the core SDK with a mock server gateway that will produce the responses as appropriate. /// @@ -31,7 +34,7 @@ pub(crate) fn build_fake_core( run_id: &str, t: &mut TestHistoryBuilder, response_batches: &[usize], -) -> CoreSDK { +) -> FakeCore { let wf = Some(WorkflowExecution { workflow_id: wf_id.to_string(), run_id: run_id.to_string(), @@ -69,5 +72,6 @@ pub(crate) fn build_fake_core( workflow_machines: WorkflowConcurrencyManager::new(), workflow_task_tokens: Default::default(), pending_activations: Default::default(), + shutdown_requested: AtomicBool::new(false), } }