Skip to content
Merged
76 changes: 50 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use std::{
mpsc::SendError,
Arc,
},
time::Duration,
};
use tokio::runtime::Runtime;
use tonic::codegen::http::uri::InvalidUri;
Expand Down Expand Up @@ -111,10 +112,7 @@ pub fn init(opts: CoreInitOptions) -> Result<impl Core> {
})
}

struct CoreSDK<WP>
where
WP: ServerGatewayApis + 'static,
{
struct CoreSDK<WP> {
runtime: Runtime,
/// Provides work in the form of responses the server would send from polling task Qs
server_gateway: Arc<WP>,
Expand All @@ -133,7 +131,7 @@ where

impl<WP> Core for CoreSDK<WP>
where
WP: ServerGatewayApis + Send + Sync,
WP: ServerGatewayApis + Send + Sync + 'static,
{
#[instrument(skip(self), fields(pending_activation))]
fn poll_task(&self, task_queue: &str) -> Result<Task> {
Expand All @@ -158,29 +156,33 @@ where
return Err(CoreError::ShuttingDown);
}

// This will block forever in the event there is no work from the server
let work = self
.runtime
.block_on(self.server_gateway.poll_workflow_task(task_queue))?;
let task_token = work.task_token.clone();
debug!(
task_token = %fmt_task_token(&task_token),
"Received workflow task from server"
);

let (next_activation, run_id) = self.instantiate_or_update_workflow(work)?;
// This will block forever (unless interrupted by shutdown) in the event there is no work
// from the server
match self.poll_server(task_queue) {
Ok(work) => {
let task_token = work.task_token.clone();
debug!(
task_token = %fmt_task_token(&task_token),
"Received workflow task from server"
);

let (next_activation, run_id) = self.instantiate_or_update_workflow(work)?;

if next_activation.more_activations_needed {
self.pending_activations.push(PendingActivation {
run_id,
task_token: task_token.clone(),
});
}

if next_activation.more_activations_needed {
self.pending_activations.push(PendingActivation {
run_id,
task_token: task_token.clone(),
});
Ok(Task {
task_token,
variant: next_activation.activation.map(Into::into),
})
}
Err(CoreError::ShuttingDown) => self.poll_task(task_queue),
Err(e) => Err(e),
}

Ok(Task {
task_token,
variant: next_activation.activation.map(Into::into),
})
}

#[instrument(skip(self))]
Expand Down Expand Up @@ -297,6 +299,28 @@ impl<WP: ServerGatewayApis> CoreSDK<WP> {
Ok(())
}

/// Blocks polling the server until it responds, or until the shutdown flag is set (aborting
/// the poll)
fn poll_server(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse> {
self.runtime.block_on(async {
let shutdownfut = async {
loop {
if self.shutdown_requested.load(Ordering::Relaxed) {
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
};
let pollfut = self.server_gateway.poll_workflow_task(task_queue);
tokio::select! {
_ = shutdownfut => {
Err(CoreError::ShuttingDown)
}
r = pollfut => r
}
})
}

/// Remove a workflow run from the cache entirely
fn evict_run(&self, run_id: &str) {
self.workflow_machines.evict(run_id);
Expand Down
9 changes: 8 additions & 1 deletion src/machines/test_help/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
protos::temporal::api::workflowservice::v1::{
PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse,
},
CoreSDK,
CoreSDK, ServerGatewayApis,
};
use rand::{thread_rng, Rng};
use std::sync::atomic::AtomicBool;
Expand Down Expand Up @@ -65,6 +65,13 @@ pub(crate) fn build_fake_core(
.expect_complete_workflow_task()
.returning(|_, _| Ok(RespondWorkflowTaskCompletedResponse::default()));

fake_core_from_mock(mock_gateway)
}

pub(crate) fn fake_core_from_mock<MT>(mock_gateway: MT) -> CoreSDK<MT>
where
MT: ServerGatewayApis,
{
let runtime = Runtime::new().unwrap();
CoreSDK {
runtime,
Expand Down
7 changes: 5 additions & 2 deletions src/workflow/concurrency_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,12 @@ impl WorkflowConcurrencyManager {
/// # Panics
/// If the workflow machine thread panicked
pub fn shutdown(&self) {
let mut wf_thread = self.wf_thread.lock();
if wf_thread.is_none() {
return;
}
let _ = self.shutdown_chan.send(true);
self.wf_thread
.lock()
wf_thread
.take()
.unwrap()
.join()
Expand Down
22 changes: 21 additions & 1 deletion tests/integ_tests/simple_wf_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ fn timer_cancel_workflow() {

#[test]
fn timer_immediate_cancel_workflow() {
let task_q = "timer_cancel_workflow";
let task_q = "timer_immediate_cancel_workflow";
let core = get_integ_core();
let mut rng = rand::thread_rng();
let workflow_id: u32 = rng.gen();
Expand Down Expand Up @@ -305,6 +305,26 @@ fn parallel_workflows_same_queue() {
handles.into_iter().for_each(|h| h.join().unwrap());
}

// Ideally this would be a unit test, but returning a pending future with mockall bloats the mock
// code a bunch and just isn't worth it. Do it when https://github.com/asomers/mockall/issues/189 is
// fixed.
#[test]
fn shutdown_aborts_actively_blocked_poll() {
let task_q = "shutdown_aborts_actively_blocked_poll";
let core = Arc::new(get_integ_core());
// Begin the poll, and request shutdown from another thread after a small period of time.
let tcore = core.clone();
let handle = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(100));
tcore.shutdown().unwrap();
});
assert_matches!(core.poll_task(task_q).unwrap_err(), CoreError::ShuttingDown);
handle.join().unwrap();
// Ensure double-shutdown doesn't explode
core.shutdown().unwrap();
assert_matches!(core.poll_task(task_q).unwrap_err(), CoreError::ShuttingDown);
}

#[test]
fn fail_wf_task() {
let task_q = "fail_wf_task";
Expand Down