diff --git a/Cargo.toml b/Cargo.toml index 13f9c1b2e..63703200b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ tracing-opentelemetry = "0.11" tracing-subscriber = "0.2" url = "2.2" rand = "0.8.3" - +uuid = { version = "0.8.2", features = ["v4"] } [dependencies.tonic] version = "0.4" #path = "../tonic/tonic" diff --git a/src/lib.rs b/src/lib.rs index 7e87cfcfb..438c22a49 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,8 @@ mod protosext; pub use pollers::{ServerGateway, ServerGatewayOptions}; pub use url::Url; +use crate::pollers::ServerGatewayApis; +use crate::protos::temporal::api::workflowservice::v1::StartWorkflowExecutionResponse; use crate::{ machines::{ ActivationListener, DrivenWorkflow, InconvertibleCommandError, ProtoCommand, WFCommand, @@ -41,6 +43,7 @@ use crate::{ protosext::{HistoryInfo, HistoryInfoError}, }; use dashmap::DashMap; +use std::sync::Arc; use std::{ convert::TryInto, sync::mpsc::{self, Receiver, SendError, Sender}, @@ -65,6 +68,9 @@ pub trait Core { /// Tell the core that some work has been completed - whether as a result of running workflow /// code or executing an activity. fn complete_task(&self, req: CompleteTaskReq) -> Result<()>; + + /// Returns an instance of ServerGateway. + fn server_gateway(&self) -> Result>; } /// Holds various configuration information required to call [init] @@ -87,7 +93,7 @@ pub fn init(opts: CoreInitOptions) -> Result { Ok(CoreSDK { runtime, - server_gateway: work_provider, + server_gateway: Arc::new(work_provider), workflow_machines: Default::default(), workflow_task_tokens: Default::default(), }) @@ -101,10 +107,13 @@ pub enum TaskQueue { _Activity(String), } -struct CoreSDK { +struct CoreSDK +where + WP: ServerGatewayApis + 'static, +{ runtime: Runtime, /// Provides work in the form of responses the server would send from polling task Qs - server_gateway: WP, + server_gateway: Arc, /// Key is run id workflow_machines: DashMap>)>, /// Maps task tokens to workflow run ids @@ -113,14 +122,14 @@ struct CoreSDK { impl Core for CoreSDK where - WP: PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi, + WP: ServerGatewayApis, { #[instrument(skip(self))] fn poll_task(&self, task_queue: &str) -> Result { // This will block forever in the event there is no work from the server let work = self .runtime - .block_on(self.server_gateway.poll(task_queue))?; + .block_on(self.server_gateway.poll_workflow_task(task_queue))?; let run_id = match &work.workflow_execution { Some(we) => { self.instantiate_workflow_if_needed(we); @@ -176,8 +185,10 @@ where self.push_lang_commands(&run_id, success)?; if let Some(mut machines) = self.workflow_machines.get_mut(&run_id) { let commands = machines.0.get_commands(); - self.runtime - .block_on(self.server_gateway.complete(task_token, commands))?; + self.runtime.block_on( + self.server_gateway + .complete_workflow_task(task_token, commands), + )?; } } Status::Failed(_) => {} @@ -193,12 +204,13 @@ where _ => Err(CoreError::MalformedCompletion(req)), } } + + fn server_gateway(&self) -> Result> { + Ok(self.server_gateway.clone()) + } } -impl CoreSDK -where - WP: PollWorkflowTaskQueueApi, -{ +impl CoreSDK { fn instantiate_workflow_if_needed(&self, workflow_execution: &WorkflowExecution) { if self .workflow_machines @@ -244,26 +256,40 @@ where /// implementor. #[cfg_attr(test, mockall::automock)] #[async_trait::async_trait] -pub(crate) trait PollWorkflowTaskQueueApi { +pub trait PollWorkflowTaskQueueApi { /// Fetch new work. Should block indefinitely if there is no work. - async fn poll(&self, task_queue: &str) -> Result; + async fn poll_workflow_task(&self, task_queue: &str) -> Result; } /// Implementors can complete tasks as would've been issued by [Core::poll]. The real implementor /// sends the completed tasks to the server. #[cfg_attr(test, mockall::automock)] #[async_trait::async_trait] -pub(crate) trait RespondWorkflowTaskCompletedApi { +pub trait RespondWorkflowTaskCompletedApi { /// Complete a task by sending it to the server. `task_token` is the task token that would've /// been received from [PollWorkflowTaskQueueApi::poll]. `commands` is a list of new commands /// to send to the server, such as starting a timer. - async fn complete( + async fn complete_workflow_task( &self, task_token: Vec, commands: Vec, ) -> Result; } +/// Implementors should send StartWorkflowExecutionRequest to the server and pass the response back. +#[cfg_attr(test, mockall::automock)] +#[async_trait::async_trait] +pub trait StartWorkflowExecutionApi { + /// Starts workflow execution. + async fn start_workflow( + &self, + namespace: &str, + task_queue: &str, + workflow_id: &str, + workflow_type: &str, + ) -> Result; +} + /// The [DrivenWorkflow] trait expects to be called to make progress, but the [CoreSDKService] /// expects to be polled by the lang sdk. This struct acts as the bridge between the two, buffering /// output from calls to [DrivenWorkflow] and offering them to [CoreSDKService] @@ -424,17 +450,17 @@ mod test { let mut tasks = VecDeque::from(responses); let mut mock_gateway = MockServerGateway::new(); mock_gateway - .expect_poll() + .expect_poll_workflow_task() .returning(move |_| Ok(tasks.pop_front().unwrap())); // Response not really important here mock_gateway - .expect_complete() + .expect_complete_workflow_task() .returning(|_, _| Ok(RespondWorkflowTaskCompletedResponse::default())); let runtime = Runtime::new().unwrap(); let core = CoreSDK { runtime, - server_gateway: mock_gateway, + server_gateway: Arc::new(mock_gateway), workflow_machines: DashMap::new(), workflow_task_tokens: DashMap::new(), }; diff --git a/src/pollers/mod.rs b/src/pollers/mod.rs index cf3c6125f..f59bb786a 100644 --- a/src/pollers/mod.rs +++ b/src/pollers/mod.rs @@ -1,8 +1,10 @@ use std::time::Duration; use crate::machines::ProtoCommand; +use crate::protos::temporal::api::common::v1::WorkflowType; use crate::protos::temporal::api::workflowservice::v1::{ RespondWorkflowTaskCompletedRequest, RespondWorkflowTaskCompletedResponse, + StartWorkflowExecutionRequest, StartWorkflowExecutionResponse, }; use crate::{ protos::temporal::api::enums::v1::TaskQueueKind, @@ -11,10 +13,11 @@ use crate::{ protos::temporal::api::workflowservice::v1::{ PollWorkflowTaskQueueRequest, PollWorkflowTaskQueueResponse, }, - PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, Result, + PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, Result, StartWorkflowExecutionApi, }; use tonic::{transport::Channel, Request, Status}; use url::Url; +use uuid::Uuid; /// Options for the connection to the temporal server #[derive(Clone, Debug)] @@ -76,9 +79,19 @@ pub struct ServerGateway { pub opts: ServerGatewayOptions, } +pub trait ServerGatewayApis: + PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi + StartWorkflowExecutionApi +{ +} + +impl ServerGatewayApis for T where + T: PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi + StartWorkflowExecutionApi +{ +} + #[async_trait::async_trait] impl PollWorkflowTaskQueueApi for ServerGateway { - async fn poll(&self, task_queue: &str) -> Result { + async fn poll_workflow_task(&self, task_queue: &str) -> Result { let request = PollWorkflowTaskQueueRequest { namespace: self.opts.namespace.to_string(), task_queue: Some(TaskQueue { @@ -100,7 +113,7 @@ impl PollWorkflowTaskQueueApi for ServerGateway { #[async_trait::async_trait] impl RespondWorkflowTaskCompletedApi for ServerGateway { - async fn complete( + async fn complete_workflow_task( &self, task_token: Vec, commands: Vec, @@ -122,15 +135,57 @@ impl RespondWorkflowTaskCompletedApi for ServerGateway { } } +#[async_trait::async_trait] +impl StartWorkflowExecutionApi for ServerGateway { + async fn start_workflow( + &self, + namespace: &str, + task_queue: &str, + workflow_id: &str, + workflow_type: &str, + ) -> Result { + let request_id = Uuid::new_v4().to_string(); + + Ok(self + .service + .clone() + .start_workflow_execution(StartWorkflowExecutionRequest { + namespace: namespace.to_string(), + workflow_id: workflow_id.to_string(), + workflow_type: Some(WorkflowType { + name: workflow_type.to_string(), + }), + task_queue: Some(TaskQueue { + name: task_queue.to_string(), + kind: 0, + }), + request_id, + ..Default::default() + }) + .await? + .into_inner()) + } +} + #[cfg(test)] mockall::mock! { - pub(crate) ServerGateway {} + pub ServerGateway {} #[async_trait::async_trait] impl PollWorkflowTaskQueueApi for ServerGateway { - async fn poll(&self, task_queue: &str) -> Result; + async fn poll_workflow_task(&self, task_queue: &str) -> Result; } #[async_trait::async_trait] impl RespondWorkflowTaskCompletedApi for ServerGateway { - async fn complete(&self, task_token: Vec, commands: Vec) -> Result; + async fn complete_workflow_task(&self, task_token: Vec, commands: Vec) -> Result; + } + #[async_trait::async_trait] + impl StartWorkflowExecutionApi for ServerGateway { + async fn start_workflow( + &self, + namespace: &str, + task_queue: &str, + workflow_id: &str, + workflow_type: &str, + ) -> Result; } } diff --git a/tests/integ_tests/poller_test.rs b/tests/integ_tests/poller_test.rs index cf0cd866c..e256d2d3c 100644 --- a/tests/integ_tests/poller_test.rs +++ b/tests/integ_tests/poller_test.rs @@ -16,18 +16,23 @@ use temporal_sdk_core::{ const TASK_QUEUE: &str = "test-tq"; const NAMESPACE: &str = "default"; -// TODO try to consolidate this into the SDK code so we don't need to create another runtime. #[tokio::main] -async fn create_workflow() -> (String, String, ServerGatewayOptions) { +async fn create_workflow(core: &dyn Core, workflow_id: &str) -> String { + core.server_gateway() + .unwrap() + .start_workflow(NAMESPACE, TASK_QUEUE, workflow_id, "test-workflow") + .await + .unwrap() + .run_id +} + +#[test] +fn timer_workflow() { let temporal_server_address = match env::var("TEMPORAL_SERVICE_ADDRESS") { Ok(addr) => addr, Err(_) => "http://localhost:7233".to_owned(), }; - - let mut rng = rand::thread_rng(); let url = Url::try_from(&*temporal_server_address).unwrap(); - let workflow_id: u32 = rng.gen(); - let request_id: u32 = rng.gen(); let gateway_opts = ServerGatewayOptions { namespace: NAMESPACE.to_string(), identity: "none".to_string(), @@ -35,36 +40,10 @@ async fn create_workflow() -> (String, String, ServerGatewayOptions) { long_poll_timeout: Duration::from_secs(60), target_url: url, }; - let mut gateway = gateway_opts.connect().await.unwrap(); - let response = gateway - .service - .start_workflow_execution(StartWorkflowExecutionRequest { - namespace: NAMESPACE.to_string(), - workflow_id: workflow_id.to_string(), - workflow_type: Some(WorkflowType { - name: "test-workflow".to_string(), - }), - task_queue: Some(TaskQueue { - name: TASK_QUEUE.to_string(), - kind: 0, - }), - request_id: request_id.to_string(), - ..Default::default() - }) - .await - .unwrap(); - ( - workflow_id.to_string(), - response.into_inner().run_id, - gateway_opts, - ) -} - -#[test] -fn timer_workflow() { - let (workflow_id, run_id, gateway_opts) = dbg!(create_workflow()); let core = temporal_sdk_core::init(CoreInitOptions { gateway_opts }).unwrap(); let mut rng = rand::thread_rng(); + let workflow_id: u32 = rng.gen(); + let run_id = dbg!(create_workflow(&core, &workflow_id.to_string())); let timer_id: String = rng.gen::().to_string(); let task = dbg!(core.poll_task(TASK_QUEUE).unwrap()); core.complete_task(CompleteTaskReq::ok_from_api_attrs(