diff --git a/src/lib.rs b/src/lib.rs index edcae0014..9fc45b150 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,7 @@ use std::{ mpsc::SendError, Arc, }, + time::Duration, }; use tokio::runtime::Runtime; use tonic::codegen::http::uri::InvalidUri; @@ -111,10 +112,7 @@ pub fn init(opts: CoreInitOptions) -> Result { }) } -struct CoreSDK -where - WP: ServerGatewayApis + 'static, -{ +struct CoreSDK { runtime: Runtime, /// Provides work in the form of responses the server would send from polling task Qs server_gateway: Arc, @@ -133,7 +131,7 @@ where impl Core for CoreSDK where - WP: ServerGatewayApis + Send + Sync, + WP: ServerGatewayApis + Send + Sync + 'static, { #[instrument(skip(self), fields(pending_activation))] fn poll_task(&self, task_queue: &str) -> Result { @@ -158,29 +156,33 @@ where return Err(CoreError::ShuttingDown); } - // This will block forever in the event there is no work from the server - let work = self - .runtime - .block_on(self.server_gateway.poll_workflow_task(task_queue))?; - let task_token = work.task_token.clone(); - debug!( - task_token = %fmt_task_token(&task_token), - "Received workflow task from server" - ); - - let (next_activation, run_id) = self.instantiate_or_update_workflow(work)?; + // This will block forever (unless interrupted by shutdown) in the event there is no work + // from the server + match self.poll_server(task_queue) { + Ok(work) => { + let task_token = work.task_token.clone(); + debug!( + task_token = %fmt_task_token(&task_token), + "Received workflow task from server" + ); + + let (next_activation, run_id) = self.instantiate_or_update_workflow(work)?; + + if next_activation.more_activations_needed { + self.pending_activations.push(PendingActivation { + run_id, + task_token: task_token.clone(), + }); + } - if next_activation.more_activations_needed { - self.pending_activations.push(PendingActivation { - run_id, - task_token: task_token.clone(), - }); + Ok(Task { + task_token, + variant: next_activation.activation.map(Into::into), + }) + } + Err(CoreError::ShuttingDown) => self.poll_task(task_queue), + Err(e) => Err(e), } - - Ok(Task { - task_token, - variant: next_activation.activation.map(Into::into), - }) } #[instrument(skip(self))] @@ -297,6 +299,28 @@ impl CoreSDK { Ok(()) } + /// Blocks polling the server until it responds, or until the shutdown flag is set (aborting + /// the poll) + fn poll_server(&self, task_queue: &str) -> Result { + self.runtime.block_on(async { + let shutdownfut = async { + loop { + if self.shutdown_requested.load(Ordering::Relaxed) { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }; + let pollfut = self.server_gateway.poll_workflow_task(task_queue); + tokio::select! { + _ = shutdownfut => { + Err(CoreError::ShuttingDown) + } + r = pollfut => r + } + }) + } + /// Remove a workflow run from the cache entirely fn evict_run(&self, run_id: &str) { self.workflow_machines.evict(run_id); diff --git a/src/machines/test_help/mod.rs b/src/machines/test_help/mod.rs index 636697b90..f865268a2 100644 --- a/src/machines/test_help/mod.rs +++ b/src/machines/test_help/mod.rs @@ -14,7 +14,7 @@ use crate::{ protos::temporal::api::workflowservice::v1::{ PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse, }, - CoreSDK, + CoreSDK, ServerGatewayApis, }; use rand::{thread_rng, Rng}; use std::sync::atomic::AtomicBool; @@ -65,6 +65,13 @@ pub(crate) fn build_fake_core( .expect_complete_workflow_task() .returning(|_, _| Ok(RespondWorkflowTaskCompletedResponse::default())); + fake_core_from_mock(mock_gateway) +} + +pub(crate) fn fake_core_from_mock(mock_gateway: MT) -> CoreSDK +where + MT: ServerGatewayApis, +{ let runtime = Runtime::new().unwrap(); CoreSDK { runtime, diff --git a/src/workflow/concurrency_manager.rs b/src/workflow/concurrency_manager.rs index 22bb2dbb7..b6200b66b 100644 --- a/src/workflow/concurrency_manager.rs +++ b/src/workflow/concurrency_manager.rs @@ -126,9 +126,12 @@ impl WorkflowConcurrencyManager { /// # Panics /// If the workflow machine thread panicked pub fn shutdown(&self) { + let mut wf_thread = self.wf_thread.lock(); + if wf_thread.is_none() { + return; + } let _ = self.shutdown_chan.send(true); - self.wf_thread - .lock() + wf_thread .take() .unwrap() .join() diff --git a/tests/integ_tests/simple_wf_tests.rs b/tests/integ_tests/simple_wf_tests.rs index 28ddb0828..a5dd3347c 100644 --- a/tests/integ_tests/simple_wf_tests.rs +++ b/tests/integ_tests/simple_wf_tests.rs @@ -213,7 +213,7 @@ fn timer_cancel_workflow() { #[test] fn timer_immediate_cancel_workflow() { - let task_q = "timer_cancel_workflow"; + let task_q = "timer_immediate_cancel_workflow"; let core = get_integ_core(); let mut rng = rand::thread_rng(); let workflow_id: u32 = rng.gen(); @@ -305,6 +305,26 @@ fn parallel_workflows_same_queue() { handles.into_iter().for_each(|h| h.join().unwrap()); } +// Ideally this would be a unit test, but returning a pending future with mockall bloats the mock +// code a bunch and just isn't worth it. Do it when https://github.com/asomers/mockall/issues/189 is +// fixed. +#[test] +fn shutdown_aborts_actively_blocked_poll() { + let task_q = "shutdown_aborts_actively_blocked_poll"; + let core = Arc::new(get_integ_core()); + // Begin the poll, and request shutdown from another thread after a small period of time. + let tcore = core.clone(); + let handle = std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(100)); + tcore.shutdown().unwrap(); + }); + assert_matches!(core.poll_task(task_q).unwrap_err(), CoreError::ShuttingDown); + handle.join().unwrap(); + // Ensure double-shutdown doesn't explode + core.shutdown().unwrap(); + assert_matches!(core.poll_task(task_q).unwrap_err(), CoreError::ShuttingDown); +} + #[test] fn fail_wf_task() { let task_q = "fail_wf_task";