diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 415db733f..1aaecbc72 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -31,4 +31,14 @@ steps: - docker-compose#v3.0.0: run: unit-test config: .buildkite/docker/docker-compose.yaml + - label: "integ-test" + agents: + queue: "default" + docker: "*" + command: "cargo test --test integ_tests" + timeout_in_minutes: 15 + plugins: + - docker-compose#v3.0.0: + run: unit-test + config: .buildkite/docker/docker-compose.yaml - wait diff --git a/Cargo.toml b/Cargo.toml index 6eb064d9c..13f9c1b2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,11 +21,17 @@ prost = "0.7" prost-types = "0.7" thiserror = "1.0" tokio = { version = "1.1", features = ["rt", "rt-multi-thread"] } -tonic = "0.4" tracing = { version = "0.1", features = ["log"] } -tracing-opentelemetry = "0.10" +tracing-opentelemetry = "0.11" tracing-subscriber = "0.2" url = "2.2" +rand = "0.8.3" + +[dependencies.tonic] +version = "0.4" +#path = "../tonic/tonic" +# Using our fork for now which fixes grpc-timeout header getting stripped +git = "https://github.com/temporalio/tonic" [dependencies.rustfsm] path = "fsm" @@ -40,3 +46,10 @@ tonic-build = "0.4" [workspace] members = [".", "fsm"] + +[[test]] +name = "integ_tests" +path = "tests/main.rs" +# Prevents autodiscovery, and hence these getting run with `cargo test`. Run with +# `cargo test --test integ_tests` +test = false \ No newline at end of file diff --git a/README.md b/README.md index cc352b7ba..d4f1667b4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -Core SDK that can be used as a base for all other SDKs. +[![Build status](https://badge.buildkite.com/c23f47f4a827f04daece909963bd3a248496f0cdbabfbecee4.svg)](https://buildkite.com/temporal/core-sdk?branch=master) + +Core SDK that can be used as a base for all other Temporal SDKs. # Getting started This repo uses a submodule for upstream protobuf files. 
The path `protos/api_upstream` is a diff --git a/src/lib.rs b/src/lib.rs index 7a760a37d..385e0c067 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,23 +3,28 @@ //! This crate provides a basis for creating new Temporal SDKs without completely starting from //! scratch +#[cfg(test)] +#[macro_use] +pub extern crate assert_matches; #[macro_use] extern crate tracing; #[cfg(test)] #[macro_use] -extern crate assert_matches; - -pub mod protos; +extern crate mockall; mod machines; mod pollers; +pub mod protos; mod protosext; +pub use pollers::{ServerGateway, ServerGatewayOptions}; +pub use url::Url; + use crate::{ machines::{ - ActivationListener, DrivenWorkflow, InconvertibleCommandError, WFCommand, WorkflowMachines, + ActivationListener, DrivenWorkflow, InconvertibleCommandError, ProtoCommand, WFCommand, + WorkflowMachines, }, - pollers::ServerGatewayOptions, protos::{ coresdk::{ complete_task_req::Completion, wf_activation_completion::Status, CompleteTaskReq, Task, @@ -31,7 +36,9 @@ use crate::{ WorkflowExecutionCanceledEventAttributes, WorkflowExecutionSignaledEventAttributes, WorkflowExecutionStartedEventAttributes, }, - workflowservice::v1::PollWorkflowTaskQueueResponse, + workflowservice::v1::{ + PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse, + }, }, }, protosext::{HistoryInfo, HistoryInfoError}, @@ -41,12 +48,14 @@ use dashmap::DashMap; use std::{ convert::TryInto, sync::mpsc::{self, Receiver, SendError, Sender}, + time::Duration, }; use tokio::runtime::Runtime; -use url::Url; +use tonic::codegen::http::uri::InvalidUri; /// A result alias having [CoreError] as the error type pub type Result = std::result::Result; +const DEFAULT_LONG_POLL_TIMEOUT: u64 = 60; /// This trait is the primary way by which language specific SDKs interact with the core SDK. 
It is /// expected that only one instance of an implementation will exist for the lifetime of the @@ -57,7 +66,7 @@ pub trait Core { /// language SDK's responsibility to call the appropriate code with the provided inputs. /// /// TODO: Examples - fn poll_task(&self) -> Result; + fn poll_task(&self, task_queue: &str) -> Result; /// Tell the core that some work has been completed - whether as a result of running workflow /// code or executing an activity. @@ -79,33 +88,48 @@ pub struct CoreInitOptions { /// A string that should be unique to the exact worker code/binary being executed pub worker_binary_id: String, + + /// Optional tokio runtime to use. If `None`, [init] creates a new one. + pub runtime: Option, } /// Initializes an instance of the core sdk and establishes a connection to the temporal server. /// /// Note: Also creates tokio runtime that will be used for all client-server interactions. pub fn init(opts: CoreInitOptions) -> Result { - let runtime = Runtime::new().map_err(CoreError::TokioInitError)?; + let runtime = opts + .runtime + .map(Ok) + .unwrap_or_else(|| Runtime::new().map_err(CoreError::TokioInitError))?; let gateway_opts = ServerGatewayOptions { namespace: opts.namespace, identity: opts.identity, worker_binary_id: opts.worker_binary_id, + long_poll_timeout: Duration::from_secs(DEFAULT_LONG_POLL_TIMEOUT), }; // Initialize server client let work_provider = runtime.block_on(gateway_opts.connect(opts.target_url))?; Ok(CoreSDK { runtime, - work_provider, + server_gateway: work_provider, workflow_machines: Default::default(), workflow_task_tokens: Default::default(), }) } +/// Type of task queue to poll.
+pub enum TaskQueue { + /// Workflow task + Workflow(String), + /// Activity task + _Activity(String), +} + struct CoreSDK { runtime: Runtime, /// Provides work in the form of responses the server would send from polling task Qs - work_provider: WP, + server_gateway: WP, /// Key is run id workflow_machines: DashMap>)>, /// Maps task tokens to workflow run ids @@ -114,14 +138,14 @@ struct CoreSDK { impl Core for CoreSDK where - WP: WorkflowTaskProvider, + WP: PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi, { #[instrument(skip(self))] - fn poll_task(&self) -> Result { + fn poll_task(&self, task_queue: &str) -> Result { // This will block forever in the event there is no work from the server let work = self .runtime - .block_on(self.work_provider.get_work("TODO: Real task queue"))?; + .block_on(self.server_gateway.poll(task_queue))?; let run_id = match &work.workflow_execution { Some(we) => { self.instantiate_workflow_if_needed(we); @@ -167,13 +191,20 @@ where status: Some(wfstatus), })), } => { - let wf_run_id = self + let run_id = self .workflow_task_tokens .get(&task_token) .map(|x| x.value().clone()) - .ok_or(CoreError::NothingFoundForTaskToken(task_token))?; + .ok_or(CoreError::NothingFoundForTaskToken(task_token.clone()))?; match wfstatus { - Status::Successful(success) => self.push_lang_commands(&wf_run_id, success)?, + Status::Successful(success) => { + self.push_lang_commands(&run_id, success)?; + if let Some(mut machines) = self.workflow_machines.get_mut(&run_id) { + let commands = machines.0.get_commands(); + self.runtime + .block_on(self.server_gateway.complete(task_token, commands))?; + } + } Status::Failed(_) => {} } Ok(()) @@ -186,13 +217,12 @@ where } _ => Err(CoreError::MalformedCompletion(req)), } - // TODO: Get fsm commands and send them to server (get_commands) } } impl CoreSDK where - WP: WorkflowTaskProvider, + WP: PollWorkflowTaskQueueApi, { fn instantiate_workflow_if_needed(&self, workflow_execution: &WorkflowExecution) { if self @@ 
-230,6 +260,11 @@ where .unwrap() .1 .send(cmds)?; + self.workflow_machines + .get_mut(run_id) + .unwrap() + .0 + .event_loop(); Ok(()) } } @@ -238,9 +273,22 @@ where /// implementor. #[cfg_attr(test, mockall::automock)] #[async_trait::async_trait] -pub trait WorkflowTaskProvider { +pub trait PollWorkflowTaskQueueApi { /// Fetch new work. Should block indefinitely if there is no work. - async fn get_work(&self, task_queue: &str) -> Result; + async fn poll(&self, task_queue: &str) -> Result; +} + +/// Implementors can report completed workflow tasks to the server. The connection to the +/// server is the real implementor. +#[cfg_attr(test, mockall::automock)] +#[async_trait::async_trait] +pub trait RespondWorkflowTaskCompletedApi { + /// Complete the workflow task identified by `task_token`, sending `commands` to the server. + async fn complete( + &self, + task_token: Vec, + commands: Vec, + ) -> Result; } /// The [DrivenWorkflow] trait expects to be called to make progress, but the [CoreSDKService] @@ -324,11 +372,16 @@ pub enum CoreError { TonicTransportError(#[from] tonic::transport::Error), /// Failed to initialize tokio runtime: {0:?} TokioInitError(std::io::Error), + /// Invalid URI: {0:?} + InvalidUri(#[from] InvalidUri), } #[cfg(test)] mod test { - use super::*; + use std::collections::VecDeque; + + use tracing::Level; + use crate::{ machines::test_help::TestHistoryBuilder, protos::{ @@ -342,9 +395,20 @@ mod test { }, }, }; - use std::collections::VecDeque; - use tracing::Level; + mock!
{ + ServerGateway {} + #[async_trait::async_trait] + impl PollWorkflowTaskQueueApi for ServerGateway { + async fn poll(&self, task_queue: &str) -> Result; + } + #[async_trait::async_trait] + impl RespondWorkflowTaskCompletedApi for ServerGateway { + async fn complete(&self, task_token: Vec, commands: Vec) -> Result; + } + } + + use super::*; #[test] fn workflow_bridge() { let s = span!(Level::DEBUG, "Test start"); @@ -353,6 +417,7 @@ mod test { let wfid = "fake_wf_id"; let run_id = "fake_run_id"; let timer_id = "fake_timer"; + let task_queue = "test-task-queue"; let mut t = TestHistoryBuilder::default(); t.add_by_type(EventType::WorkflowExecutionStarted); @@ -401,20 +466,20 @@ mod test { let responses = vec![first_response, second_response]; let mut tasks = VecDeque::from(responses); - let mut mock_provider = MockWorkflowTaskProvider::new(); + let mut mock_provider = MockServerGateway::new(); mock_provider - .expect_get_work() + .expect_poll() .returning(move |_| Ok(tasks.pop_front().unwrap())); let runtime = Runtime::new().unwrap(); let core = CoreSDK { runtime, - work_provider: mock_provider, + server_gateway: mock_provider, workflow_machines: DashMap::new(), workflow_task_tokens: DashMap::new(), }; - let res = dbg!(core.poll_task().unwrap()); + let res = dbg!(core.poll_task(task_queue).unwrap()); // TODO: uggo assert_matches!( res.get_wf_jobs().as_slice(), @@ -436,7 +501,7 @@ mod test { .unwrap(); dbg!("sent completion w/ start timer"); - let res = dbg!(core.poll_task().unwrap()); + let res = dbg!(core.poll_task(task_queue).unwrap()); // TODO: uggo assert_matches!( res.get_wf_jobs().as_slice(), diff --git a/src/machines/workflow_machines.rs b/src/machines/workflow_machines.rs index 1315e0df1..5b98767c9 100644 --- a/src/machines/workflow_machines.rs +++ b/src/machines/workflow_machines.rs @@ -381,7 +381,7 @@ impl WorkflowMachines { .expect("We have just ensured this is populated") } - fn event_loop(&mut self) -> Result<()> { + pub fn event_loop(&mut self) -> 
Result<()> { let results = self.drive_me.iterate_wf()?; self.handle_driven_results(results); diff --git a/src/pollers/mod.rs b/src/pollers/mod.rs index d26b5d9d5..f8ecb808b 100644 --- a/src/pollers/mod.rs +++ b/src/pollers/mod.rs @@ -1,38 +1,64 @@ -use crate::protos::temporal::api::enums::v1::TaskQueueKind; -use crate::protos::temporal::api::taskqueue::v1::TaskQueue; -use crate::protos::temporal::api::workflowservice::v1::workflow_service_client::WorkflowServiceClient; +use std::time::Duration; + +use crate::machines::ProtoCommand; use crate::protos::temporal::api::workflowservice::v1::{ - PollWorkflowTaskQueueRequest, PollWorkflowTaskQueueResponse, + RespondWorkflowTaskCompletedRequest, RespondWorkflowTaskCompletedResponse, +}; +use crate::{ + protos::temporal::api::enums::v1::TaskQueueKind, + protos::temporal::api::taskqueue::v1::TaskQueue, + protos::temporal::api::workflowservice::v1::workflow_service_client::WorkflowServiceClient, + protos::temporal::api::workflowservice::v1::{ + PollWorkflowTaskQueueRequest, PollWorkflowTaskQueueResponse, + }, + PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, Result, }; -use crate::Result; -use crate::WorkflowTaskProvider; +use tonic::{transport::Channel, Request, Status}; use url::Url; #[derive(Clone)] -pub(crate) struct ServerGatewayOptions { +pub struct ServerGatewayOptions { pub namespace: String, pub identity: String, pub worker_binary_id: String, + pub long_poll_timeout: Duration, } impl ServerGatewayOptions { - pub(crate) async fn connect(&self, target_url: Url) -> Result { - let service = WorkflowServiceClient::connect(target_url.to_string()).await?; + pub async fn connect(&self, target_url: Url) -> Result { + let channel = Channel::from_shared(target_url.to_string())? + .connect() + .await?; + let service = WorkflowServiceClient::with_interceptor(channel, intercept); Ok(ServerGateway { service, opts: self.clone(), }) } } + +/// This function will get called on each outbound request. 
Returning a +/// `Status` here will cancel the request and have that status returned to +/// the caller. +fn intercept(mut req: Request<()>) -> Result, Status> { + // TODO convert error + let metadata = req.metadata_mut(); + metadata.insert("grpc-timeout", "50000m".parse().unwrap()); + metadata.insert("client-name", "core-sdk".parse().unwrap()); + println!("Intercepting request: {:?}", req); + Ok(req) +} + /// Provides -pub(crate) struct ServerGateway { - service: WorkflowServiceClient, - opts: ServerGatewayOptions, +pub struct ServerGateway { + pub service: WorkflowServiceClient, + pub opts: ServerGatewayOptions, } -impl ServerGateway { +#[async_trait::async_trait] +impl PollWorkflowTaskQueueApi for ServerGateway { async fn poll(&self, task_queue: &str) -> Result { - let request = tonic::Request::new(PollWorkflowTaskQueueRequest { + let request = PollWorkflowTaskQueueRequest { namespace: self.opts.namespace.to_string(), task_queue: Some(TaskQueue { name: task_queue.to_string(), @@ -40,7 +66,7 @@ impl ServerGateway { }), identity: self.opts.identity.to_string(), binary_checksum: self.opts.worker_binary_id.to_string(), - }); + }; Ok(self .service @@ -52,8 +78,25 @@ impl ServerGateway { } #[async_trait::async_trait] -impl WorkflowTaskProvider for ServerGateway { - async fn get_work(&self, task_queue: &str) -> Result { - self.poll(task_queue).await +impl RespondWorkflowTaskCompletedApi for ServerGateway { + async fn complete( + &self, + task_token: Vec, + commands: Vec, + ) -> Result { + let request = RespondWorkflowTaskCompletedRequest { + task_token, + commands, + identity: self.opts.identity.to_string(), + binary_checksum: self.opts.worker_binary_id.to_string(), + namespace: self.opts.namespace.to_string(), + ..Default::default() + }; + Ok(self + .service + .clone() + .respond_workflow_task_completed(request) + .await? 
+ .into_inner()) } } diff --git a/tests/integ_tests/poller_test.rs b/tests/integ_tests/poller_test.rs new file mode 100644 index 000000000..b4a35153b --- /dev/null +++ b/tests/integ_tests/poller_test.rs @@ -0,0 +1,95 @@ +use rand::{self, Rng}; +use std::{convert::TryFrom, time::Duration}; +use temporal_sdk_core::{ + protos::{ + coresdk::CompleteTaskReq, + temporal::api::command::v1::{ + CompleteWorkflowExecutionCommandAttributes, StartTimerCommandAttributes, + }, + temporal::api::common::v1::WorkflowType, + temporal::api::taskqueue::v1::TaskQueue, + temporal::api::workflowservice::v1::StartWorkflowExecutionRequest, + }, + Core, CoreInitOptions, ServerGatewayOptions, Url, +}; + +const TASK_QUEUE: &str = "test-tq"; +const NAMESPACE: &str = "default"; + +const TARGET_URI: &'static str = "http://localhost:7233"; + +// TODO try to consolidate this into the SDK code so we don't need to create another runtime. +#[tokio::main] +async fn create_workflow() -> (String, String) { + let mut rng = rand::thread_rng(); + let workflow_id: u32 = rng.gen(); + let request_id: u32 = rng.gen(); + let gateway_opts = ServerGatewayOptions { + namespace: NAMESPACE.to_string(), + identity: "none".to_string(), + worker_binary_id: "".to_string(), + long_poll_timeout: Duration::from_secs(60), + }; + let mut gateway = gateway_opts + .connect(Url::try_from(TARGET_URI).unwrap()) + .await + .unwrap(); + let response = gateway + .service + .start_workflow_execution(StartWorkflowExecutionRequest { + namespace: NAMESPACE.to_string(), + workflow_id: workflow_id.to_string(), + workflow_type: Some(WorkflowType { + name: "test-workflow".to_string(), + }), + task_queue: Some(TaskQueue { + name: TASK_QUEUE.to_string(), + kind: 0, + }), + request_id: request_id.to_string(), + ..Default::default() + }) + .await + .unwrap(); + (workflow_id.to_string(), response.into_inner().run_id) +} + +#[test] +fn timer_workflow() { + let (workflow_id, run_id) = dbg!(create_workflow()); + let core = 
temporal_sdk_core::init(CoreInitOptions { + target_url: Url::try_from(TARGET_URI).unwrap(), + namespace: NAMESPACE.to_string(), + identity: "none".to_string(), + worker_binary_id: "".to_string(), + runtime: None, + }) + .unwrap(); + let mut rng = rand::thread_rng(); + let timer_id: String = rng.gen::().to_string(); + let task = dbg!(core.poll_task(TASK_QUEUE).unwrap()); + // TODO verify + core.complete_task(CompleteTaskReq::ok_from_api_attrs( + StartTimerCommandAttributes { + timer_id: timer_id.to_string(), + start_to_fire_timeout: Some(Duration::from_secs(1).into()), + ..Default::default() + } + .into(), + task.task_token, + )) + .unwrap(); + dbg!("sent completion w/ start timer"); + let task = dbg!(core.poll_task(TASK_QUEUE).unwrap()); + // TODO verify + core.complete_task(CompleteTaskReq::ok_from_api_attrs( + CompleteWorkflowExecutionCommandAttributes { result: None }.into(), + task.task_token, + )) + .unwrap(); + dbg!( + "sent workflow done, completed workflow", + workflow_id, + run_id + ); +} diff --git a/tests/main.rs b/tests/main.rs new file mode 100644 index 000000000..976cf28e4 --- /dev/null +++ b/tests/main.rs @@ -0,0 +1,4 @@ +#[cfg(test)] +mod integ_tests { + mod poller_test; +}