diff --git a/src/lib.rs b/src/lib.rs index 9fc45b150..49a752131 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -311,7 +311,9 @@ impl CoreSDK { tokio::time::sleep(Duration::from_millis(100)).await; } }; - let pollfut = self.server_gateway.poll_workflow_task(task_queue); + let pollfut = self + .server_gateway + .poll_workflow_task(task_queue.to_owned()); tokio::select! { _ = shutdownfut => { Err(CoreError::ShuttingDown) diff --git a/src/machines/test_help/mod.rs b/src/machines/test_help/mod.rs index f865268a2..dc5aaec8d 100644 --- a/src/machines/test_help/mod.rs +++ b/src/machines/test_help/mod.rs @@ -8,7 +8,7 @@ pub(super) use workflow_driver::{CommandSender, TestWorkflowDriver}; use crate::workflow::WorkflowConcurrencyManager; use crate::{ - pollers::MockServerGateway, + pollers::MockServerGatewayApis, protos::temporal::api::common::v1::WorkflowExecution, protos::temporal::api::history::v1::History, protos::temporal::api::workflowservice::v1::{ @@ -21,7 +21,7 @@ use std::sync::atomic::AtomicBool; use std::{collections::VecDeque, sync::Arc}; use tokio::runtime::Runtime; -pub(crate) type FakeCore = CoreSDK; +pub(crate) type FakeCore = CoreSDK; /// Given identifiers for a workflow/run, and a test history builder, construct an instance of /// the core SDK with a mock server gateway that will produce the responses as appropriate. @@ -55,7 +55,7 @@ pub(crate) fn build_fake_core( .collect(); let mut tasks = VecDeque::from(responses); - let mut mock_gateway = MockServerGateway::new(); + let mut mock_gateway = MockServerGatewayApis::new(); mock_gateway .expect_poll_workflow_task() .times(response_batches.len()) diff --git a/src/pollers/mod.rs b/src/pollers/mod.rs index e5080880c..6563033bc 100644 --- a/src/pollers/mod.rs +++ b/src/pollers/mod.rs @@ -1,5 +1,9 @@ use std::time::Duration; +use crate::protos::temporal::api::common::v1::{Payloads, WorkflowExecution}; +use crate::protos::temporal::api::workflowservice::v1::{ + SignalWorkflowExecutionRequest, SignalWorkflowExecutionResponse, +}; use crate::{ machines::ProtoCommand, protos::temporal::api::{ @@ -15,10 +19,6 @@ use crate::{ StartWorkflowExecutionResponse, }, }, - workflow::{ - PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, RespondWorkflowTaskFailedApi, - StartWorkflowExecutionApi, - }, CoreError, Result, }; use tonic::{transport::Channel, Request, Status}; @@ -86,34 +86,94 @@ pub struct ServerGateway { pub opts: ServerGatewayOptions, } -/// This trait provides ways to call the temporal server itself -pub trait ServerGatewayApis: - PollWorkflowTaskQueueApi - + RespondWorkflowTaskCompletedApi - + StartWorkflowExecutionApi - + RespondWorkflowTaskFailedApi -{ -} +/// This trait provides ways to call the temporal server +#[cfg_attr(test, mockall::automock)] +#[async_trait::async_trait] +pub trait ServerGatewayApis { + /// Starts workflow execution. + async fn start_workflow( + &self, + namespace: String, + task_queue: String, + workflow_id: String, + workflow_type: String, + ) -> Result; + + /// Fetch new work. Should block indefinitely if there is no work. + async fn poll_workflow_task(&self, task_queue: String) + -> Result; -impl ServerGatewayApis for T where - T: PollWorkflowTaskQueueApi - + RespondWorkflowTaskCompletedApi - + StartWorkflowExecutionApi - + RespondWorkflowTaskFailedApi -{ + /// Complete a task by sending it to the server. `task_token` is the task token that would've + /// been received from [PollWorkflowTaskQueueApi::poll]. `commands` is a list of new commands + /// to send to the server, such as starting a timer. + async fn complete_workflow_task( + &self, + task_token: Vec, + commands: Vec, + ) -> Result; + + /// Fail task by sending the failure to the server. `task_token` is the task token that would've + /// been received from [PollWorkflowTaskQueueApi::poll]. + async fn fail_workflow_task( + &self, + task_token: Vec, + cause: WorkflowTaskFailedCause, + failure: Option, + ) -> Result; + + /// Send a signal to a certain workflow instance + async fn signal_workflow_execution( + &self, + workflow_id: String, + run_id: String, + signal_name: String, + payloads: Option, + ) -> Result; } #[async_trait::async_trait] -impl PollWorkflowTaskQueueApi for ServerGateway { - async fn poll_workflow_task(&self, task_queue: &str) -> Result { +impl ServerGatewayApis for ServerGateway { + async fn start_workflow( + &self, + namespace: String, + task_queue: String, + workflow_id: String, + workflow_type: String, + ) -> Result { + let request_id = Uuid::new_v4().to_string(); + + Ok(self + .service + .clone() + .start_workflow_execution(StartWorkflowExecutionRequest { + namespace, + workflow_id, + workflow_type: Some(WorkflowType { + name: workflow_type, + }), + task_queue: Some(TaskQueue { + name: task_queue, + kind: 0, + }), + request_id, + ..Default::default() + }) + .await? + .into_inner()) + } + + async fn poll_workflow_task( + &self, + task_queue: String, + ) -> Result { let request = PollWorkflowTaskQueueRequest { - namespace: self.opts.namespace.to_string(), + namespace: self.opts.namespace.clone(), task_queue: Some(TaskQueue { - name: task_queue.to_string(), + name: task_queue, kind: TaskQueueKind::Unspecified as i32, }), - identity: self.opts.identity.to_string(), - binary_checksum: self.opts.worker_binary_id.to_string(), + identity: self.opts.identity.clone(), + binary_checksum: self.opts.worker_binary_id.clone(), }; Ok(self @@ -123,10 +183,7 @@ impl PollWorkflowTaskQueueApi for ServerGateway { .await? .into_inner()) } -} -#[async_trait::async_trait] -impl RespondWorkflowTaskCompletedApi for ServerGateway { async fn complete_workflow_task( &self, task_token: Vec, @@ -135,9 +192,9 @@ impl RespondWorkflowTaskCompletedApi for ServerGateway { let request = RespondWorkflowTaskCompletedRequest { task_token, commands, - identity: self.opts.identity.to_string(), - binary_checksum: self.opts.worker_binary_id.to_string(), - namespace: self.opts.namespace.to_string(), + identity: self.opts.identity.clone(), + binary_checksum: self.opts.worker_binary_id.clone(), + namespace: self.opts.namespace.clone(), ..Default::default() }; match self @@ -156,10 +213,7 @@ impl RespondWorkflowTaskCompletedApi for ServerGateway { } } } -} -#[async_trait::async_trait] -impl RespondWorkflowTaskFailedApi for ServerGateway { async fn fail_workflow_task( &self, task_token: Vec, @@ -170,9 +224,9 @@ impl RespondWorkflowTaskFailedApi for ServerGateway { task_token, cause: cause as i32, failure, - identity: self.opts.identity.to_string(), - binary_checksum: self.opts.worker_binary_id.to_string(), - namespace: self.opts.namespace.to_string(), + identity: self.opts.identity.clone(), + binary_checksum: self.opts.worker_binary_id.clone(), + namespace: self.opts.namespace.clone(), }; Ok(self .service @@ -181,70 +235,29 @@ impl RespondWorkflowTaskFailedApi for ServerGateway { .await? .into_inner()) } -} -#[async_trait::async_trait] -impl StartWorkflowExecutionApi for ServerGateway { - async fn start_workflow( + async fn signal_workflow_execution( &self, - namespace: &str, - task_queue: &str, - workflow_id: &str, - workflow_type: &str, - ) -> Result { - let request_id = Uuid::new_v4().to_string(); - + workflow_id: String, + run_id: String, + signal_name: String, + payloads: Option, + ) -> Result { Ok(self .service .clone() - .start_workflow_execution(StartWorkflowExecutionRequest { - namespace: namespace.to_string(), - workflow_id: workflow_id.to_string(), - workflow_type: Some(WorkflowType { - name: workflow_type.to_string(), + .signal_workflow_execution(SignalWorkflowExecutionRequest { + namespace: self.opts.namespace.clone(), + workflow_execution: Some(WorkflowExecution { + workflow_id, + run_id, }), - task_queue: Some(TaskQueue { - name: task_queue.to_string(), - kind: 0, - }), - request_id, + signal_name, + input: payloads, + identity: self.opts.identity.clone(), ..Default::default() }) .await? .into_inner()) } } - -#[cfg(test)] -mockall::mock! { - pub ServerGateway {} - #[async_trait::async_trait] - impl PollWorkflowTaskQueueApi for ServerGateway { - async fn poll_workflow_task(&self, task_queue: &str) -> Result; - } - #[async_trait::async_trait] - impl RespondWorkflowTaskCompletedApi for ServerGateway { - async fn complete_workflow_task(&self, task_token: Vec, commands: Vec) -> Result; - } - - #[async_trait::async_trait] - impl RespondWorkflowTaskFailedApi for ServerGateway { - async fn fail_workflow_task( - &self, - task_token: Vec, - cause: WorkflowTaskFailedCause, - failure: Option, - ) -> Result; - } - - #[async_trait::async_trait] - impl StartWorkflowExecutionApi for ServerGateway { - async fn start_workflow( - &self, - namespace: &str, - task_queue: &str, - workflow_id: &str, - workflow_type: &str, - ) -> Result; - } -} diff --git a/src/workflow/mod.rs b/src/workflow/mod.rs index 2950de4a4..712e01a47 100644 --- a/src/workflow/mod.rs +++ b/src/workflow/mod.rs @@ -7,77 +7,16 @@ pub(crate) use concurrency_manager::WorkflowConcurrencyManager; pub(crate) use driven_workflow::{ActivationListener, DrivenWorkflow, WorkflowFetcher}; use crate::{ - machines::{ProtoCommand, WFCommand, WorkflowMachines}, + machines::{WFCommand, WorkflowMachines}, protos::{ coresdk::WfActivation, - temporal::api::{ - enums::v1::WorkflowTaskFailedCause, - failure::v1::Failure, - history::v1::History, - workflowservice::v1::{ - PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse, - RespondWorkflowTaskFailedResponse, StartWorkflowExecutionResponse, - }, - }, + temporal::api::{history::v1::History, workflowservice::v1::PollWorkflowTaskQueueResponse}, }, protosext::HistoryInfo, CoreError, Result, }; use std::sync::mpsc::Sender; -/// Implementors can provide new workflow tasks to the SDK. The connection to the server is the real -/// implementor. -#[cfg_attr(test, mockall::automock)] -#[async_trait::async_trait] -pub trait PollWorkflowTaskQueueApi { - /// Fetch new work. Should block indefinitely if there is no work. - async fn poll_workflow_task(&self, task_queue: &str) -> Result; -} - -/// Implementors can complete tasks issued by [Core::poll]. The real implementor sends the completed -/// tasks to the server. -#[cfg_attr(test, mockall::automock)] -#[async_trait::async_trait] -pub trait RespondWorkflowTaskCompletedApi { - /// Complete a task by sending it to the server. `task_token` is the task token that would've - /// been received from [PollWorkflowTaskQueueApi::poll]. `commands` is a list of new commands - /// to send to the server, such as starting a timer. - async fn complete_workflow_task( - &self, - task_token: Vec, - commands: Vec, - ) -> Result; -} - -/// Implementors can fail workflow tasks issued by [Core::poll]. The real implementor sends the -/// failed tasks to the server. -#[cfg_attr(test, mockall::automock)] -#[async_trait::async_trait] -pub trait RespondWorkflowTaskFailedApi { - /// Fail task by sending the failure to the server. `task_token` is the task token that would've - /// been received from [PollWorkflowTaskQueueApi::poll]. - async fn fail_workflow_task( - &self, - task_token: Vec, - cause: WorkflowTaskFailedCause, - failure: Option, - ) -> Result; -} - -/// Implementors should send StartWorkflowExecutionRequest to the server and pass the response back. -#[cfg_attr(test, mockall::automock)] -#[async_trait::async_trait] -pub trait StartWorkflowExecutionApi { - /// Starts workflow execution. - async fn start_workflow( - &self, - namespace: &str, - task_queue: &str, - workflow_id: &str, - workflow_type: &str, - ) -> Result; -} - /// Manages an instance of a [WorkflowMachines], which is not thread-safe, as well as other data /// associated with that specific workflow run. pub(crate) struct WorkflowManager { diff --git a/tests/integ_tests/simple_wf_tests.rs b/tests/integ_tests/simple_wf_tests.rs index a5dd3347c..36ded4eca 100644 --- a/tests/integ_tests/simple_wf_tests.rs +++ b/tests/integ_tests/simple_wf_tests.rs @@ -1,4 +1,5 @@ use assert_matches::assert_matches; +use futures::Future; use rand::{self, Rng}; use std::{ collections::HashMap, @@ -20,15 +21,12 @@ use temporal_sdk_core::{ CancelTimerCommandAttributes, CompleteWorkflowExecutionCommandAttributes, FailWorkflowExecutionCommandAttributes, StartTimerCommandAttributes, }, - common::v1::WorkflowExecution, enums::v1::WorkflowTaskFailedCause, failure::v1::Failure, - workflowservice::v1::SignalWorkflowExecutionRequest, }, }, - Core, CoreError, CoreInitOptions, ServerGatewayOptions, Url, + Core, CoreError, CoreInitOptions, ServerGatewayApis, ServerGatewayOptions, Url, }; -use tokio::runtime::Runtime; // TODO: These tests can get broken permanently if they break one time and the server is not // restarted, because pulling from the same task queue produces tasks for the previous failed @@ -38,25 +36,31 @@ use tokio::runtime::Runtime; // at the end matches. const NAMESPACE: &str = "default"; +type GwApi = Arc; -#[tokio::main] -async fn create_workflow( +fn create_workflow( core: &dyn Core, task_q: &str, workflow_id: &str, wf_type: Option<&str>, ) -> String { - core.server_gateway() - .unwrap() - .start_workflow( - NAMESPACE, - task_q, - workflow_id, - wf_type.unwrap_or("test-workflow"), + with_gw(core, |gw: GwApi| async move { + gw.start_workflow( + NAMESPACE.to_owned(), + task_q.to_owned(), + workflow_id.to_owned(), + wf_type.unwrap_or("test-workflow").to_owned(), ) .await .unwrap() .run_id + }) +} + +#[tokio::main] +async fn with_gw Fout, Fout: Future>(core: &dyn Core, fun: F) -> Fout::Output { + let gw = core.server_gateway().unwrap(); + fun(gw).await } fn get_integ_server_options() -> ServerGatewayOptions { @@ -434,33 +438,23 @@ fn signal_workflow() { .unwrap(); // Send the signals to the server - let rt = Runtime::new().unwrap(); - let mut client = rt.block_on(async { get_integ_server_options().connect().await.unwrap() }); - let wfe = WorkflowExecution { - workflow_id: workflow_id.to_string(), - run_id: res.get_run_id().unwrap().to_string(), - }; - rt.block_on(async { - client - .service - .signal_workflow_execution(SignalWorkflowExecutionRequest { - namespace: "default".to_string(), - workflow_execution: Some(wfe.clone()), - signal_name: signal_id_1.to_string(), - ..Default::default() - }) - .await - .unwrap(); - client - .service - .signal_workflow_execution(SignalWorkflowExecutionRequest { - namespace: "default".to_string(), - workflow_execution: Some(wfe), - signal_name: signal_id_2.to_string(), - ..Default::default() - }) - .await - .unwrap(); + with_gw(&core, |gw: GwApi| async move { + gw.signal_workflow_execution( + workflow_id.to_string(), + res.get_run_id().unwrap().to_string(), + signal_id_1.to_string(), + None, + ) + .await + .unwrap(); + gw.signal_workflow_execution( + workflow_id.to_string(), + res.get_run_id().unwrap().to_string(), + signal_id_2.to_string(), + None, + ) + .await + .unwrap(); }); let res = core.poll_task(task_q).unwrap(); @@ -513,31 +507,24 @@ fn signal_workflow_signal_not_handled_on_workflow_completion() { },] ); - // Send a signal to the server before we complete the workflow - let rt = Runtime::new().unwrap(); - let mut client = rt.block_on(async { get_integ_server_options().connect().await.unwrap() }); - let wfe = WorkflowExecution { - workflow_id: workflow_id.to_string(), - run_id: res.get_run_id().unwrap().to_string(), - }; - rt.block_on(async { - client - .service - .signal_workflow_execution(SignalWorkflowExecutionRequest { - namespace: "default".to_string(), - workflow_execution: Some(wfe.clone()), - signal_name: signal_id_1.to_string(), - ..Default::default() - }) - .await - .unwrap(); + let task_token = res.task_token.clone(); + // Send the signals to the server + with_gw(&core, |gw: GwApi| async move { + gw.signal_workflow_execution( + workflow_id.to_string(), + res.get_run_id().unwrap().to_string(), + signal_id_1.to_string(), + None, + ) + .await + .unwrap(); }); // Send completion - not having seen a poll response with a signal in it yet let unhandled = core .complete_task(TaskCompletion::ok_from_api_attrs( vec![CompleteWorkflowExecutionCommandAttributes { result: None }.into()], - res.task_token, + task_token, )) .unwrap_err(); assert_matches!(unhandled, CoreError::UnhandledCommandWhenCompleting);