diff --git a/.buildkite/docker/Dockerfile b/.buildkite/docker/Dockerfile index 4fe468ca8..cae7a0845 100644 --- a/.buildkite/docker/Dockerfile +++ b/.buildkite/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.50 +FROM rust:1.51 RUN rustup component add rustfmt && \ rustup component add clippy diff --git a/Cargo.toml b/Cargo.toml index a4543f56b..4d79882bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ base64 = "0.13" crossbeam = "0.8" dashmap = "4.0" derive_more = "0.99" -displaydoc = "0.1" +displaydoc = "0.2" futures = "0.3" itertools = "0.10" once_cell = "1.5" @@ -28,6 +28,7 @@ slotmap = "1.0" thiserror = "1.0" tokio = { version = "1.1", features = ["rt", "rt-multi-thread", "parking_lot"] } tracing = { version = "0.1", features = ["log"] } +tracing-futures = "0.2" tracing-opentelemetry = "0.11" tracing-subscriber = "0.2" url = "2.2" @@ -46,7 +47,7 @@ path = "fsm" [dev-dependencies] assert_matches = "1.4" mockall = "0.9" -rstest = "0.6" +rstest = "0.7" [build-dependencies] tonic-build = "0.4" diff --git a/src/lib.rs b/src/lib.rs index d2ee2bb0d..30aead89b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,5 @@ -#![warn(missing_docs)] -// error if there are missing docs -// TODO: Turn on when rust 1.51 docker image available -// #![allow(clippy::upper_case_acronyms)] +#![warn(missing_docs)] // error if there are missing docs +#![allow(clippy::upper_case_acronyms)] //! This crate provides a basis for creating new Temporal SDKs without completely starting from //! scratch @@ -54,9 +52,8 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, - time::Duration, }; -use tokio::runtime::Runtime; +use tokio::sync::Notify; use tonic::codegen::http::uri::InvalidUri; use tracing::Span; @@ -66,29 +63,34 @@ pub type Result = std::result::Result; /// This trait is the primary way by which language specific SDKs interact with the core SDK. It is /// expected that only one instance of an implementation will exist for the lifetime of the /// worker(s) using it. 
+#[async_trait::async_trait] pub trait Core: Send + Sync { /// Ask the core for some work, returning a [WfActivation]. It is then the language SDK's /// responsibility to call the appropriate workflow code with the provided inputs. Blocks /// indefinitely until such work is available or [shutdown] is called. /// /// TODO: Examples - fn poll_workflow_task(&self, task_queue: &str) -> Result; + async fn poll_workflow_task(&self, task_queue: &str) -> Result; /// Ask the core for some work, returning an [ActivityTask]. It is then the language SDK's /// responsibility to call the appropriate activity code with the provided inputs. Blocks /// indefinitely until such work is available or [shutdown] is called. /// /// TODO: Examples - fn poll_activity_task(&self, task_queue: &str) -> Result; + async fn poll_activity_task(&self, task_queue: &str) -> Result; /// Tell the core that a workflow activation has completed - fn complete_workflow_task(&self, completion: WfActivationCompletion) -> Result<()>; + async fn complete_workflow_task(&self, completion: WfActivationCompletion) -> Result<()>; /// Tell the core that an activity has finished executing - fn complete_activity_task(&self, task_token: Vec, result: ActivityResult) -> Result<()>; + async fn complete_activity_task( + &self, + task_token: Vec, + result: ActivityResult, + ) -> Result<()>; /// Indicate that a long running activity is still making progress - fn send_activity_heartbeat(&self, task_token: ActivityHeartbeat) -> Result<()>; + async fn send_activity_heartbeat(&self, task_token: ActivityHeartbeat) -> Result<()>; /// Returns core's instance of the [ServerGatewayApis] implementor it is using. fn server_gateway(&self) -> Arc; @@ -114,23 +116,14 @@ pub struct CoreInitOptions { -/// # Panics -/// * Will panic if called from within an async context, as it will construct a runtime and you -/// cannot construct a runtime from within a runtime. +/// Must be awaited from within an existing async runtime, as it does not construct one itself. +/// # Errors +/// * Returns a [CoreInitError] if the connection to the server gateway cannot be established. 
-pub fn init(opts: CoreInitOptions) -> Result { - let runtime = Runtime::new().map_err(CoreInitError::TokioInitError)?; +pub async fn init(opts: CoreInitOptions) -> Result { // Initialize server client - let work_provider = runtime.block_on(opts.gateway_opts.connect())?; - - Ok(CoreSDK { - runtime, - server_gateway: Arc::new(work_provider), - workflow_machines: WorkflowConcurrencyManager::new(), - workflow_task_tokens: Default::default(), - pending_activations: Default::default(), - shutdown_requested: AtomicBool::new(false), - }) + let work_provider = opts.gateway_opts.connect().await?; + + Ok(CoreSDK::new(work_provider)) } struct CoreSDK { - runtime: Runtime, /// Provides work in the form of responses the server would send from polling task Qs server_gateway: Arc, /// Key is run id @@ -144,38 +137,39 @@ struct CoreSDK { /// Has shutdown been called? shutdown_requested: AtomicBool, + /// Used to wake up future which checks shutdown state + shutdown_notify: Notify, } /// Can be used inside the CoreSDK impl to block on any method that polls the server until it /// responds, or until the shutdown flag is set (aborting the poll) macro_rules! abort_on_shutdown { - ($self:ident, $gateway_fn:tt, $poll_arg:expr) => { - $self.runtime.block_on(async { - let shutdownfut = async { - loop { - if $self.shutdown_requested.load(Ordering::Relaxed) { - break; - } - tokio::time::sleep(Duration::from_millis(100)).await; - } - }; - let poll_result_future = $self.server_gateway.$gateway_fn($poll_arg); - tokio::select! { - _ = shutdownfut => { - Err(CoreError::ShuttingDown) + ($self:ident, $gateway_fn:tt, $poll_arg:expr) => {{ + let shutdownfut = async { + loop { + $self.shutdown_notify.notified().await; + if $self.shutdown_requested.load(Ordering::SeqCst) { + break; } - r = poll_result_future => r.map_err(Into::into) } - }) - }; + }; + let poll_result_future = $self.server_gateway.$gateway_fn($poll_arg); + tokio::select! 
{ + _ = shutdownfut => { + Err(CoreError::ShuttingDown) + } + r = poll_result_future => r.map_err(Into::into) + } + }}; } +#[async_trait::async_trait] impl Core for CoreSDK where WP: ServerGatewayApis + Send + Sync + 'static, { #[instrument(skip(self), fields(pending_activation))] - fn poll_workflow_task(&self, task_queue: &str) -> Result { + async fn poll_workflow_task(&self, task_queue: &str) -> Result { // The poll needs to be in a loop because we can't guarantee tail call optimization in Rust // (simply) and we really, really need that for long-poll retries. loop { @@ -236,7 +230,7 @@ where } #[instrument(skip(self))] - fn poll_activity_task(&self, task_queue: &str) -> Result { + async fn poll_activity_task(&self, task_queue: &str) -> Result { match abort_on_shutdown!(self, poll_activity_task, task_queue.to_owned()) { Ok(work) => { let task_token = work.task_token.clone(); @@ -247,7 +241,7 @@ where } #[instrument(skip(self))] - fn complete_workflow_task(&self, completion: WfActivationCompletion) -> Result<()> { + async fn complete_workflow_task(&self, completion: WfActivationCompletion) -> Result<()> { let task_token = completion.task_token; let wfstatus = completion.status; let run_id = self @@ -280,11 +274,9 @@ where // no more pending activations -- in other words the lang SDK has caught // up on replay. 
if !self.pending_activations.has_pending(&run_id) { - self.runtime - .block_on( - self.server_gateway - .complete_workflow_task(task_token, commands), - ) + self.server_gateway + .complete_workflow_task(task_token, commands) + .await .map_err(|ts| { if ts.code() == tonic::Code::InvalidArgument && ts.message() == "UnhandledCommand" @@ -300,12 +292,13 @@ where // Blow up any cached data associated with the workflow self.evict_run(&run_id); - self.runtime - .block_on(self.server_gateway.fail_workflow_task( + self.server_gateway + .fail_workflow_task( task_token, WorkflowTaskFailedCause::Unspecified, failure.failure.map(Into::into), - ))?; + ) + .await?; } None => { return Err(CoreError::MalformedWorkflowCompletion { @@ -318,7 +311,11 @@ where } #[instrument(skip(self))] - fn complete_activity_task(&self, task_token: Vec, result: ActivityResult) -> Result<()> { + async fn complete_activity_task( + &self, + task_token: Vec, + result: ActivityResult, + ) -> Result<()> { let status = if let Some(s) = result.status { s } else { @@ -329,28 +326,25 @@ where }; match status { activity_result::Status::Completed(ar::Success { result }) => { - self.runtime.block_on( - self.server_gateway - .complete_activity_task(task_token, result.map(Into::into)), - )?; + self.server_gateway + .complete_activity_task(task_token, result.map(Into::into)) + .await?; } activity_result::Status::Failed(ar::Failure { failure }) => { - self.runtime.block_on( - self.server_gateway - .fail_activity_task(task_token, failure.map(Into::into)), - )?; + self.server_gateway + .fail_activity_task(task_token, failure.map(Into::into)) + .await?; } activity_result::Status::Canceled(ar::Cancelation { details }) => { - self.runtime.block_on( - self.server_gateway - .cancel_activity_task(task_token, details.map(Into::into)), - )?; + self.server_gateway + .cancel_activity_task(task_token, details.map(Into::into)) + .await?; } } Ok(()) } - fn send_activity_heartbeat(&self, _task_token: ActivityHeartbeat) -> Result<()> 
{ + async fn send_activity_heartbeat(&self, _task_token: ActivityHeartbeat) -> Result<()> { unimplemented!() } @@ -360,11 +354,23 @@ where fn shutdown(&self) { self.shutdown_requested.store(true, Ordering::SeqCst); + self.shutdown_notify.notify_one(); self.workflow_machines.shutdown(); } } impl CoreSDK { + pub(crate) fn new(wp: WP) -> Self { + Self { + server_gateway: Arc::new(wp), + workflow_machines: WorkflowConcurrencyManager::new(), + workflow_task_tokens: Default::default(), + pending_activations: Default::default(), + shutdown_requested: AtomicBool::new(false), + shutdown_notify: Notify::new(), + } + } + /// Will create a new workflow manager if needed for the workflow task, if not, it will /// feed the existing manager the updated history we received from the server. /// @@ -436,8 +442,8 @@ impl CoreSDK { #[allow(clippy::large_enum_variant)] // NOTE: Docstrings take the place of #[error("xxxx")] here b/c of displaydoc pub enum CoreError { - /// [Core::shutdown] was called, and there are no more replay tasks to be handled. You must - /// call [Core::complete_task] for any remaining tasks, and then may exit. + /** [Core::shutdown] was called, and there are no more replay tasks to be handled. You must + call [Core::complete_workflow_task] for any remaining tasks, and then may exit.*/ ShuttingDown, /// Poll workflow response from server was malformed: {0:?} BadPollResponseFromServer(PollWorkflowTaskQueueResponse), @@ -462,9 +468,9 @@ pub enum CoreError { /// The run id of the erring workflow run_id: String, }, - /// There exists a pending command in this workflow's history which has not yet been handled. - /// When thrown from [Core::complete_task], it means you should poll for a new task, receive a - /// new task token, and complete that new task. + /** There exists a pending command in this workflow's history which has not yet been handled. 
+ When thrown from [Core::complete_workflow_task], it means you should poll for a new task, receive a + new task token, and complete that new task. */ UnhandledCommandWhenCompleting, /// Unhandled error when calling the temporal server: {0:?} TonicError(#[from] tonic::Status), @@ -473,8 +479,6 @@ pub enum CoreError { /// Errors thrown during initialization of [Core] #[derive(thiserror::Error, Debug, displaydoc::Display)] pub enum CoreInitError { - /// Failed to initialize tokio runtime: {0:?} - TokioInitError(std::io::Error), /// Invalid URI: {0:?} InvalidUri(#[from] InvalidUri), /// Server connection error: {0:?} @@ -533,11 +537,12 @@ mod test { } #[rstest(core, - case::incremental(single_timer_setup(&[1, 2])), - case::replay(single_timer_setup(&[2])) + case::incremental(single_timer_setup(&[1, 2])), + case::replay(single_timer_setup(&[2])) )] - fn single_timer_test_across_wf_bridge(core: FakeCore) { - let res = core.poll_workflow_task(TASK_Q).unwrap(); + #[tokio::test] + async fn single_timer_test_across_wf_bridge(core: FakeCore) { + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -555,9 +560,10 @@ mod test { .into()], task_tok, )) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -569,17 +575,19 @@ mod test { vec![CompleteWorkflowExecution { result: None }.into()], task_tok, )) + .await .unwrap(); } #[rstest(core, - case::incremental(single_activity_setup(&[1, 2])), - case::incremental_activity_failure(single_activity_failure_setup(&[1, 2])), - case::replay(single_activity_setup(&[2])), - case::replay_activity_failure(single_activity_failure_setup(&[2])) + case::incremental(single_activity_setup(&[1, 2])), + case::incremental_activity_failure(single_activity_failure_setup(&[1, 2])), + case::replay(single_activity_setup(&[2])), + 
case::replay_activity_failure(single_activity_failure_setup(&[2])) )] - fn single_activity_completion(core: FakeCore) { - let res = core.poll_workflow_task(TASK_Q).unwrap(); + #[tokio::test] + async fn single_activity_completion(core: FakeCore) { + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -597,9 +605,10 @@ mod test { .into()], task_tok, )) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -611,11 +620,13 @@ mod test { vec![CompleteWorkflowExecution { result: None }.into()], task_tok, )) + .await .unwrap(); } #[rstest(hist_batches, case::incremental(&[1, 2]), case::replay(&[2]))] - fn parallel_timer_test_across_wf_bridge(hist_batches: &[usize]) { + #[tokio::test] + async fn parallel_timer_test_across_wf_bridge(hist_batches: &[usize]) { let wfid = "fake_wf_id"; let run_id = "fake_run_id"; let timer_1_id = "timer1"; @@ -624,7 +635,7 @@ mod test { let mut t = canned_histories::parallel_timer(timer_1_id, timer_2_id); let core = build_fake_core(wfid, run_id, &mut t, hist_batches); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -649,9 +660,10 @@ mod test { ], task_tok, )) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [ @@ -675,11 +687,13 @@ mod test { vec![CompleteWorkflowExecution { result: None }.into()], task_tok, )) + .await .unwrap(); } #[rstest(hist_batches, case::incremental(&[1, 2]), case::replay(&[2]))] - fn timer_cancel_test_across_wf_bridge(hist_batches: &[usize]) { + #[tokio::test] + async fn timer_cancel_test_across_wf_bridge(hist_batches: &[usize]) { let wfid = "fake_wf_id"; let run_id = 
"fake_run_id"; let timer_id = "wait_timer"; @@ -688,7 +702,7 @@ mod test { let mut t = canned_histories::cancel_timer(timer_id, cancel_timer_id); let core = build_fake_core(wfid, run_id, &mut t, hist_batches); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -713,9 +727,10 @@ mod test { ], task_tok, )) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -733,23 +748,28 @@ mod test { ], task_tok, )) + .await .unwrap(); } #[rstest(single_timer_setup(&[1]))] - fn after_shutdown_server_is_not_polled(single_timer_setup: FakeCore) { - let res = single_timer_setup.poll_workflow_task(TASK_Q).unwrap(); + #[tokio::test] + async fn after_shutdown_server_is_not_polled(single_timer_setup: FakeCore) { + let res = single_timer_setup.poll_workflow_task(TASK_Q).await.unwrap(); assert_eq!(res.jobs.len(), 1); single_timer_setup.shutdown(); assert_matches!( - single_timer_setup.poll_workflow_task(TASK_Q).unwrap_err(), + single_timer_setup + .poll_workflow_task(TASK_Q) + .await + .unwrap_err(), CoreError::ShuttingDown ); } - #[test] - fn workflow_update_random_seed_on_workflow_reset() { + #[tokio::test] + async fn workflow_update_random_seed_on_workflow_reset() { let wfid = "fake_wf_id"; let run_id = "CA733AB0-8133-45F6-A4C1-8D375F61AE8B"; let original_run_id = "86E39A5F-AE31-4626-BDFE-398EE072D156"; @@ -759,7 +779,7 @@ mod test { canned_histories::workflow_fails_with_reset_after_timer(timer_1_id, original_run_id); let core = build_fake_core(wfid, run_id, &mut t, &[2]); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); let randomness_seed_from_start: u64; assert_matches!( res.jobs.as_slice(), @@ -782,9 +802,10 @@ mod test { .into()], task_tok, )) + .await 
.unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -801,13 +822,15 @@ mod test { vec![CompleteWorkflowExecution { result: None }.into()], task_tok, )) + .await .unwrap(); } // The incremental version only does one batch here, because the workflow completes right away // and any subsequent poll would block forever with nothing to do. #[rstest(hist_batches, case::incremental(&[1]), case::replay(&[2]))] - fn cancel_timer_before_sent_wf_bridge(hist_batches: &[usize]) { + #[tokio::test] + async fn cancel_timer_before_sent_wf_bridge(hist_batches: &[usize]) { let wfid = "fake_wf_id"; let run_id = "fake_run_id"; let cancel_timer_id = "cancel_timer"; @@ -819,7 +842,7 @@ mod test { let core = build_fake_core(wfid, run_id, &mut t, hist_batches); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -843,11 +866,12 @@ mod test { ], task_tok, )) + .await .unwrap(); } - #[test] - fn complete_activation_with_failure() { + #[tokio::test] + async fn complete_activation_with_failure() { let wfid = "fake_wf_id"; let timer_id = "timer"; @@ -860,7 +884,7 @@ mod test { .times(1) .returning(|_, _, _| Ok(RespondWorkflowTaskFailedResponse {})); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![StartTimer { timer_id: timer_id.to_string(), @@ -869,9 +893,10 @@ mod test { .into()], res.task_token, )) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::fail( res.task_token, UserCodeFailure { @@ -879,9 +904,10 @@ mod test { ..Default::default() }, )) + .await .unwrap(); - let res = 
core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -897,9 +923,10 @@ mod test { .into()], res.task_token, )) + .await .unwrap(); // Now we may complete the workflow - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -910,11 +937,13 @@ mod test { vec![CompleteWorkflowExecution { result: None }.into()], res.task_token, )) + .await .unwrap(); } #[rstest(hist_batches, case::incremental(&[1, 2]), case::replay(&[2]))] - fn simple_timer_fail_wf_execution(hist_batches: &[usize]) { + #[tokio::test] + async fn simple_timer_fail_wf_execution(hist_batches: &[usize]) { let wfid = "fake_wf_id"; let run_id = "fake_run_id"; let timer_id = "timer1"; @@ -922,7 +951,7 @@ mod test { let mut t = canned_histories::single_timer(timer_id); let core = build_fake_core(wfid, run_id, &mut t, hist_batches); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![StartTimer { timer_id: timer_id.to_string(), @@ -931,9 +960,10 @@ mod test { .into()], res.task_token, )) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![FailWorkflowExecution { failure: Some(UserCodeFailure { @@ -944,23 +974,26 @@ mod test { .into()], res.task_token, )) + .await .unwrap(); } #[rstest(hist_batches, case::incremental(&[1, 2]), case::replay(&[2]))] - fn two_signals(hist_batches: &[usize]) { + #[tokio::test] + async fn two_signals(hist_batches: &[usize]) { let wfid = "fake_wf_id"; let run_id = "fake_run_id"; let mut t = canned_histories::two_signals("sig1", "sig2"); let core = build_fake_core(wfid, run_id, &mut 
t, hist_batches); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); // Task is completed with no commands core.complete_workflow_task(WfActivationCompletion::ok_from_cmds(vec![], res.task_token)) + .await .unwrap(); - let res = core.poll_workflow_task(TASK_Q).unwrap(); + let res = core.poll_workflow_task(TASK_Q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [ diff --git a/src/machines/test_help/mod.rs b/src/machines/test_help/mod.rs index 011da2805..42965053f 100644 --- a/src/machines/test_help/mod.rs +++ b/src/machines/test_help/mod.rs @@ -6,7 +6,6 @@ mod history_builder; pub(super) use async_workflow_driver::{CommandSender, TestWorkflowDriver}; pub(crate) use history_builder::TestHistoryBuilder; -use crate::workflow::WorkflowConcurrencyManager; use crate::{ pollers::MockServerGatewayApis, protos::temporal::api::common::v1::WorkflowExecution, @@ -14,12 +13,10 @@ use crate::{ protos::temporal::api::workflowservice::v1::{ PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse, }, - CoreSDK, ServerGatewayApis, + CoreSDK, }; use rand::{thread_rng, Rng}; -use std::sync::atomic::AtomicBool; -use std::{collections::VecDeque, sync::Arc}; -use tokio::runtime::Runtime; +use std::collections::VecDeque; pub(crate) type FakeCore = CoreSDK; @@ -65,20 +62,5 @@ pub(crate) fn build_fake_core( .expect_complete_workflow_task() .returning(|_, _| Ok(RespondWorkflowTaskCompletedResponse::default())); - fake_core_from_mock(mock_gateway) -} - -pub(crate) fn fake_core_from_mock(mock_gateway: MT) -> CoreSDK -where - MT: ServerGatewayApis, -{ - let runtime = Runtime::new().unwrap(); - CoreSDK { - runtime, - server_gateway: Arc::new(mock_gateway), - workflow_machines: WorkflowConcurrencyManager::new(), - workflow_task_tokens: Default::default(), - pending_activations: Default::default(), - shutdown_requested: AtomicBool::new(false), - } + CoreSDK::new(mock_gateway) } diff --git a/src/workflow/mod.rs 
b/src/workflow/mod.rs index 00880f639..0bb82109f 100644 --- a/src/workflow/mod.rs +++ b/src/workflow/mod.rs @@ -31,11 +31,11 @@ pub enum WorkflowError { UnderlyingMachinesError(#[from] WFMachinesError), /// There was an error in the history associated with the workflow: {0:?} HistoryError(#[from] HistoryInfoError), - /// Error buffering commands coming in from the lang side. This shouldn't happen unless we've - /// run out of memory or there is a logic bug. Considered fatal. + /** Error buffering commands coming in from the lang side. This shouldn't happen unless we've + run out of memory or there is a logic bug. Considered fatal. */ CommandBufferingError(#[from] SendError>), - /// We tried to instantiate a workflow instance, but the provided history resulted in no - /// new activations. There is nothing to do. + /** We tried to instantiate a workflow instance, but the provided history resulted in no + new activations. There is nothing to do. */ MachineWasCreatedWithNoActivations { run_id: String }, } @@ -54,7 +54,7 @@ pub(crate) struct WorkflowManager { } impl WorkflowManager { - /// Create a new workflow manager given workflow history and exection info as would be found + /// Create a new workflow manager given workflow history and execution info as would be found /// in [PollWorkflowTaskQueueResponse] pub fn new( history: History, diff --git a/tests/integ_tests/simple_wf_tests.rs b/tests/integ_tests/simple_wf_tests.rs index a1880ff78..61af33631 100644 --- a/tests/integ_tests/simple_wf_tests.rs +++ b/tests/integ_tests/simple_wf_tests.rs @@ -1,17 +1,8 @@ use assert_matches::assert_matches; use crossbeam::channel::{unbounded, RecvTimeoutError}; -use futures::Future; +use futures::{channel::mpsc::UnboundedReceiver, future, Future, SinkExt, StreamExt}; use rand::{self, Rng}; -use std::{ - collections::HashMap, - convert::TryFrom, - env, - sync::{ - mpsc::{channel, Receiver}, - Arc, - }, - time::Duration, -}; +use std::{collections::HashMap, convert::TryFrom, env, 
sync::Arc, time::Duration}; use temporal_sdk_core::{ protos::coresdk::{ activity_result::{self, activity_result as act_res, ActivityResult}, @@ -40,7 +31,7 @@ use temporal_sdk_core::{ const NAMESPACE: &str = "default"; type GwApi = Arc; -fn create_workflow( +async fn create_workflow( core: &dyn Core, task_q: &str, workflow_id: &str, @@ -57,9 +48,9 @@ fn create_workflow( .unwrap() .run_id }) + .await } -#[tokio::main] async fn with_gw Fout, Fout: Future>(core: &dyn Core, fun: F) -> Fout::Output { let gw = core.server_gateway(); fun(gw).await @@ -80,20 +71,22 @@ fn get_integ_server_options() -> ServerGatewayOptions { } } -fn get_integ_core() -> impl Core { +async fn get_integ_core() -> impl Core { let gateway_opts = get_integ_server_options(); - temporal_sdk_core::init(CoreInitOptions { gateway_opts }).unwrap() + temporal_sdk_core::init(CoreInitOptions { gateway_opts }) + .await + .unwrap() } -#[test] -fn timer_workflow() { +#[tokio::test] +async fn timer_workflow() { let task_q = "timer_workflow"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let timer_id: String = rng.gen::().to_string(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![StartTimer { timer_id, @@ -102,30 +95,33 @@ fn timer_workflow() { .into()], task.task_token, )) + .await .unwrap(); - let task = dbg!(core.poll_workflow_task(task_q).unwrap()); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap(); } -#[test] -fn activity_workflow() { +#[tokio::test] +async fn 
activity_workflow() { let mut rng = rand::thread_rng(); let task_q_salt: u32 = rng.gen(); let task_q = &format!("activity_workflow_{}", task_q_salt.to_string()); - let core = get_integ_core(); + let core = get_integ_core().await; let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let activity_id: String = rng.gen::().to_string(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); // Complete workflow task and schedule activity core.complete_workflow_task(activity_completion_req(task_q, &activity_id, task)) + .await .unwrap(); // Poll activity and verify that it's been scheduled with correct parameters - let task = dbg!(core.poll_activity_task(task_q).unwrap()); + let task = dbg!(core.poll_activity_task(task_q).await.unwrap()); assert_matches!( task.variant, Some(act_task::Variant::Start(start_activity)) => { @@ -141,9 +137,10 @@ fn activity_workflow() { task.task_token, ActivityResult::ok(response_payload.clone()), ) + .await .unwrap(); // Poll workflow task and verify that activity has succeeded. 
- let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( task.jobs.as_slice(), [ @@ -163,24 +160,26 @@ fn activity_workflow() { vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap() } -#[test] -fn activity_non_retryable_failure() { +#[tokio::test] +async fn activity_non_retryable_failure() { let mut rng = rand::thread_rng(); let task_q_salt: u32 = rng.gen(); let task_q = &format!("activity_failed_workflow_{}", task_q_salt.to_string()); - let core = get_integ_core(); + let core = get_integ_core().await; let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let activity_id: String = rng.gen::().to_string(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); // Complete workflow task and schedule activity core.complete_workflow_task(activity_completion_req(task_q, &activity_id, task)) + .await .unwrap(); // Poll activity and verify that it's been scheduled with correct parameters - let task = dbg!(core.poll_activity_task(task_q).unwrap()); + let task = dbg!(core.poll_activity_task(task_q).await.unwrap()); assert_matches!( task.variant, Some(act_task::Variant::Start(start_activity)) => { @@ -203,9 +202,10 @@ fn activity_non_retryable_failure() { )), }, ) + .await .unwrap(); // Poll workflow task and verify that activity has failed. 
- let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( task.jobs.as_slice(), [ @@ -224,24 +224,26 @@ fn activity_non_retryable_failure() { vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap() } -#[test] -fn activity_retry() { +#[tokio::test] +async fn activity_retry() { let mut rng = rand::thread_rng(); let task_q_salt: u32 = rng.gen(); let task_q = &format!("activity_failed_workflow_{}", task_q_salt.to_string()); - let core = get_integ_core(); + let core = get_integ_core().await; let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let activity_id: String = rng.gen::().to_string(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); // Complete workflow task and schedule activity core.complete_workflow_task(activity_completion_req(task_q, &activity_id, task)) + .await .unwrap(); // Poll activity 1st time - let task = dbg!(core.poll_activity_task(task_q).unwrap()); + let task = core.poll_activity_task(task_q).await.unwrap(); assert_matches!( task.variant, Some(act_task::Variant::Start(start_activity)) => { @@ -264,9 +266,10 @@ fn activity_retry() { )), }, ) + .await .unwrap(); // Poll 2nd time - let task = dbg!(core.poll_activity_task(task_q).unwrap()); + let task = dbg!(core.poll_activity_task(task_q).await.unwrap()); assert_matches!( task.variant, Some(act_task::Variant::Start(start_activity)) => { @@ -288,9 +291,10 @@ fn activity_retry() { )), }, ) + .await .unwrap(); // Poll workflow task and verify activity has succeeded. 
- let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( task.jobs.as_slice(), [ @@ -309,6 +313,7 @@ fn activity_retry() { vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap() } @@ -334,16 +339,16 @@ fn activity_completion_req( ) } -#[test] -fn parallel_timer_workflow() { +#[tokio::test] +async fn parallel_timer_workflow() { let task_q = "parallel_timer_workflow"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let timer_id = "timer 1".to_string(); let timer_2_id = "timer 2".to_string(); - let task = dbg!(core.poll_workflow_task(task_q).unwrap()); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![ StartTimer { @@ -359,11 +364,12 @@ fn parallel_timer_workflow() { ], task.task_token, )) + .await .unwrap(); // Wait long enough for both timers to complete. Server seems to be a bit weird about actually // sending both of these in one go, so we need to wait longer than you would expect. 
std::thread::sleep(Duration::from_millis(1500)); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( task.jobs.as_slice(), [ @@ -386,24 +392,20 @@ fn parallel_timer_workflow() { vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap(); } -#[test] -fn timer_cancel_workflow() { +#[tokio::test] +async fn timer_cancel_workflow() { let task_q = "timer_cancel_workflow"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - dbg!(create_workflow( - &core, - task_q, - &workflow_id.to_string(), - None - )); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let timer_id = "wait_timer"; let cancel_timer_id = "cancel_timer"; - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![ StartTimer { @@ -419,8 +421,9 @@ fn timer_cancel_workflow() { ], task.task_token, )) + .await .unwrap(); - let task = dbg!(core.poll_workflow_task(task_q).unwrap()); + let task = dbg!(core.poll_workflow_task(task_q).await.unwrap()); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![ CancelTimer { @@ -431,18 +434,19 @@ fn timer_cancel_workflow() { ], task.task_token, )) + .await .unwrap(); } -#[test] -fn timer_immediate_cancel_workflow() { +#[tokio::test] +async fn timer_immediate_cancel_workflow() { let task_q = "timer_immediate_cancel_workflow"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let cancel_timer_id = "cancel_timer"; - let task = core.poll_workflow_task(task_q).unwrap(); + let task = 
core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![ StartTimer { @@ -458,23 +462,29 @@ fn timer_immediate_cancel_workflow() { ], task.task_token, )) + .await .unwrap(); } -#[test] -fn parallel_workflows_same_queue() { +#[tokio::test] +async fn parallel_workflows_same_queue() { let task_q = "parallel_workflows_same_queue"; - let core = get_integ_core(); + let core = get_integ_core().await; let num_workflows = 25; - let run_ids: Vec<_> = (0..num_workflows) - .map(|i| create_workflow(&core, task_q, &format!("wf-id-{}", i), Some("wf-type-1"))) - .collect(); + let run_ids: Vec<_> = + future::join_all((0..num_workflows).map(|i| { + let core = &core; + async move { + create_workflow(core, task_q, &format!("wf-id-{}", i), Some("wf-type-1")).await + } + })) + .await; let mut send_chans = HashMap::new(); - fn wf_thread(core: Arc, task_chan: Receiver) { - let task = task_chan.recv().unwrap(); + async fn wf_task(core: Arc, mut task_chan: UnboundedReceiver) { + let task = task_chan.next().await.unwrap(); assert_matches!( task.jobs.as_slice(), [WfActivationJob { @@ -494,12 +504,14 @@ fn parallel_workflows_same_queue() { .into()], task.task_token, )) + .await .unwrap(); - let task = task_chan.recv().unwrap(); + let task = task_chan.next().await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap(); } @@ -507,28 +519,35 @@ fn parallel_workflows_same_queue() { let handles: Vec<_> = run_ids .iter() .map(|run_id| { - let (tx, rx) = channel(); + let (tx, rx) = futures::channel::mpsc::unbounded(); send_chans.insert(run_id.clone(), tx); let core_c = core.clone(); - std::thread::spawn(move || wf_thread(core_c, rx)) + tokio::spawn(wf_task(core_c, rx)) }) .collect(); for _ in 0..num_workflows * 2 { - let task = core.poll_workflow_task(task_q).unwrap(); - 
send_chans.get(&task.run_id).unwrap().send(task).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); + send_chans + .get(&task.run_id) + .unwrap() + .send(task) + .await + .unwrap(); } - handles.into_iter().for_each(|h| h.join().unwrap()); + for handle in handles { + handle.await.unwrap() + } } // Ideally this would be a unit test, but returning a pending future with mockall bloats the mock // code a bunch and just isn't worth it. Do it when https://github.com/asomers/mockall/issues/189 is // fixed. -#[test] -fn shutdown_aborts_actively_blocked_poll() { +#[tokio::test] +async fn shutdown_aborts_actively_blocked_poll() { let task_q = "shutdown_aborts_actively_blocked_poll"; - let core = Arc::new(get_integ_core()); + let core = Arc::new(get_integ_core().await); // Begin the poll, and request shutdown from another thread after a small period of time. let tcore = core.clone(); let handle = std::thread::spawn(move || { @@ -536,28 +555,28 @@ fn shutdown_aborts_actively_blocked_poll() { tcore.shutdown(); }); assert_matches!( - core.poll_workflow_task(task_q).unwrap_err(), + core.poll_workflow_task(task_q).await.unwrap_err(), CoreError::ShuttingDown ); handle.join().unwrap(); // Ensure double-shutdown doesn't explode core.shutdown(); assert_matches!( - core.poll_workflow_task(task_q).unwrap_err(), + core.poll_workflow_task(task_q).await.unwrap_err(), CoreError::ShuttingDown ); } -#[test] -fn fail_wf_task() { +#[tokio::test] +async fn fail_wf_task() { let task_q = "fail_wf_task"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; // Start with a timer - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( 
vec![StartTimer { timer_id: "best-timer".to_string(), @@ -566,13 +585,14 @@ fn fail_wf_task() { .into()], task.task_token, )) + .await .unwrap(); // Allow timer to fire std::thread::sleep(Duration::from_millis(500)); // Then break for whatever reason - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::fail( task.task_token, UserCodeFailure { @@ -580,11 +600,12 @@ fn fail_wf_task() { ..Default::default() }, )) + .await .unwrap(); // The server will want to retry the task. This time we finish the workflow -- but we need // to poll a couple of times as there will be more than one required workflow activation. - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![StartTimer { timer_id: "best-timer".to_string(), @@ -593,24 +614,26 @@ fn fail_wf_task() { .into()], task.task_token, )) + .await .unwrap(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![CompleteWorkflowExecution { result: None }.into()], task.task_token, )) + .await .unwrap(); } -#[test] -fn fail_workflow_execution() { +#[tokio::test] +async fn fail_workflow_execution() { let task_q = "fail_workflow_execution"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let timer_id: String = rng.gen::().to_string(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![StartTimer { 
timer_id, @@ -619,8 +642,9 @@ fn fail_workflow_execution() { .into()], task.task_token, )) + .await .unwrap(); - let task = core.poll_workflow_task(task_q).unwrap(); + let task = core.poll_workflow_task(task_q).await.unwrap(); core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![FailWorkflowExecution { failure: Some(UserCodeFailure { @@ -631,25 +655,27 @@ fn fail_workflow_execution() { .into()], task.task_token, )) + .await .unwrap(); } -#[test] -fn signal_workflow() { +#[tokio::test] +async fn signal_workflow() { let task_q = "signal_workflow"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, &workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let signal_id_1 = "signal1"; let signal_id_2 = "signal2"; - let res = core.poll_workflow_task(task_q).unwrap(); + let res = core.poll_workflow_task(task_q).await.unwrap(); // Task is completed with no commands core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![], res.task_token.clone(), )) + .await .unwrap(); // Send the signals to the server @@ -670,9 +696,10 @@ fn signal_workflow() { ) .await .unwrap(); - }); + }) + .await; - let res = core.poll_workflow_task(task_q).unwrap(); + let res = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [ @@ -688,19 +715,20 @@ fn signal_workflow() { vec![CompleteWorkflowExecution { result: None }.into()], res.task_token, )) + .await .unwrap(); } -#[test] -fn signal_workflow_signal_not_handled_on_workflow_completion() { +#[tokio::test] +async fn signal_workflow_signal_not_handled_on_workflow_completion() { let task_q = "signal_workflow_signal_not_handled_on_workflow_completion"; - let core = get_integ_core(); + let core = get_integ_core().await; let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); - create_workflow(&core, task_q, 
&workflow_id.to_string(), None); + create_workflow(&core, task_q, &workflow_id.to_string(), None).await; let signal_id_1 = "signal1"; - let res = core.poll_workflow_task(task_q).unwrap(); + let res = core.poll_workflow_task(task_q).await.unwrap(); // Task is completed with a timer core.complete_workflow_task(WfActivationCompletion::ok_from_cmds( vec![StartTimer { @@ -710,10 +738,11 @@ fn signal_workflow_signal_not_handled_on_workflow_completion() { .into()], res.task_token, )) + .await .unwrap(); // Poll before sending the signal - we should have the timer job - let res = core.poll_workflow_task(task_q).unwrap(); + let res = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -732,7 +761,8 @@ fn signal_workflow_signal_not_handled_on_workflow_completion() { ) .await .unwrap(); - }); + }) + .await; // Send completion - not having seen a poll response with a signal in it yet let unhandled = core @@ -740,11 +770,12 @@ fn signal_workflow_signal_not_handled_on_workflow_completion() { vec![CompleteWorkflowExecution { result: None }.into()], task_token, )) + .await .unwrap_err(); assert_matches!(unhandled, CoreError::UnhandledCommandWhenCompleting); // We should get a new task with the signal - let res = core.poll_workflow_task(task_q).unwrap(); + let res = core.poll_workflow_task(task_q).await.unwrap(); assert_matches!( res.jobs.as_slice(), [WfActivationJob { @@ -755,19 +786,22 @@ fn signal_workflow_signal_not_handled_on_workflow_completion() { vec![CompleteWorkflowExecution { result: None }.into()], res.task_token, )) + .await .unwrap(); } -#[test] -fn long_poll_timeout_is_retried() { +#[tokio::test] +async fn long_poll_timeout_is_retried() { let mut gateway_opts = get_integ_server_options(); // Server whines unless long poll > 2 seconds gateway_opts.long_poll_timeout = Duration::from_secs(3); - let core = temporal_sdk_core::init(CoreInitOptions { gateway_opts }).unwrap(); + let core = 
temporal_sdk_core::init(CoreInitOptions { gateway_opts }) + .await + .unwrap(); // Should block for more than 3 seconds, since we internally retry long poll let (tx, rx) = unbounded(); - std::thread::spawn(move || { - core.poll_workflow_task("some_task_q").unwrap(); + tokio::spawn(async move { + core.poll_workflow_task("some_task_q").await.unwrap(); tx.send(()) }); let err = rx.recv_timeout(Duration::from_secs(4)).unwrap_err();