Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tracing-opentelemetry = "0.11"
tracing-subscriber = "0.2"
url = "2.2"
rand = "0.8.3"

uuid = { version = "0.8.2", features = ["v4"] }
[dependencies.tonic]
version = "0.4"
#path = "../tonic/tonic"
Expand Down
62 changes: 44 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ mod protosext;
pub use pollers::{ServerGateway, ServerGatewayOptions};
pub use url::Url;

use crate::pollers::ServerGatewayApis;
use crate::protos::temporal::api::workflowservice::v1::StartWorkflowExecutionResponse;
use crate::{
machines::{
ActivationListener, DrivenWorkflow, InconvertibleCommandError, ProtoCommand, WFCommand,
Expand All @@ -41,6 +43,7 @@ use crate::{
protosext::{HistoryInfo, HistoryInfoError},
};
use dashmap::DashMap;
use std::sync::Arc;
use std::{
convert::TryInto,
sync::mpsc::{self, Receiver, SendError, Sender},
Expand All @@ -65,6 +68,9 @@ pub trait Core {
/// Tell the core that some work has been completed - whether as a result of running workflow
/// code or executing an activity.
fn complete_task(&self, req: CompleteTaskReq) -> Result<()>;

/// Returns an instance of ServerGateway.
fn server_gateway(&self) -> Result<Arc<dyn ServerGatewayApis>>;
}

/// Holds various configuration information required to call [init]
Expand All @@ -87,7 +93,7 @@ pub fn init(opts: CoreInitOptions) -> Result<impl Core> {

Ok(CoreSDK {
runtime,
server_gateway: work_provider,
server_gateway: Arc::new(work_provider),
workflow_machines: Default::default(),
workflow_task_tokens: Default::default(),
})
Expand All @@ -101,10 +107,13 @@ pub enum TaskQueue {
_Activity(String),
}

struct CoreSDK<WP> {
struct CoreSDK<WP>
where
WP: ServerGatewayApis + 'static,
{
runtime: Runtime,
/// Provides work in the form of responses the server would send from polling task Qs
server_gateway: WP,
server_gateway: Arc<WP>,
/// Key is run id
workflow_machines: DashMap<String, (WorkflowMachines, Sender<Vec<WFCommand>>)>,
/// Maps task tokens to workflow run ids
Expand All @@ -113,14 +122,14 @@ struct CoreSDK<WP> {

impl<WP> Core for CoreSDK<WP>
where
WP: PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi,
WP: ServerGatewayApis,
{
#[instrument(skip(self))]
fn poll_task(&self, task_queue: &str) -> Result<Task, CoreError> {
// This will block forever in the event there is no work from the server
let work = self
.runtime
.block_on(self.server_gateway.poll(task_queue))?;
.block_on(self.server_gateway.poll_workflow_task(task_queue))?;
let run_id = match &work.workflow_execution {
Some(we) => {
self.instantiate_workflow_if_needed(we);
Expand Down Expand Up @@ -176,8 +185,10 @@ where
self.push_lang_commands(&run_id, success)?;
if let Some(mut machines) = self.workflow_machines.get_mut(&run_id) {
let commands = machines.0.get_commands();
self.runtime
.block_on(self.server_gateway.complete(task_token, commands))?;
self.runtime.block_on(
self.server_gateway
.complete_workflow_task(task_token, commands),
)?;
}
}
Status::Failed(_) => {}
Expand All @@ -193,12 +204,13 @@ where
_ => Err(CoreError::MalformedCompletion(req)),
}
}

fn server_gateway(&self) -> Result<Arc<dyn ServerGatewayApis>> {
Ok(self.server_gateway.clone())
}
}

impl<WP> CoreSDK<WP>
where
WP: PollWorkflowTaskQueueApi,
{
impl<WP: ServerGatewayApis> CoreSDK<WP> {
fn instantiate_workflow_if_needed(&self, workflow_execution: &WorkflowExecution) {
if self
.workflow_machines
Expand Down Expand Up @@ -244,26 +256,40 @@ where
/// implementor.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub(crate) trait PollWorkflowTaskQueueApi {
pub trait PollWorkflowTaskQueueApi {
/// Fetch new work. Should block indefinitely if there is no work.
async fn poll(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse>;
async fn poll_workflow_task(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse>;
}

/// Implementors can complete tasks as would've been issued by [Core::poll]. The real implementor
/// sends the completed tasks to the server.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub(crate) trait RespondWorkflowTaskCompletedApi {
pub trait RespondWorkflowTaskCompletedApi {
/// Complete a task by sending it to the server. `task_token` is the task token that would've
/// been received from [PollWorkflowTaskQueueApi::poll]. `commands` is a list of new commands
/// to send to the server, such as starting a timer.
async fn complete(
async fn complete_workflow_task(
&self,
task_token: Vec<u8>,
commands: Vec<ProtoCommand>,
) -> Result<RespondWorkflowTaskCompletedResponse>;
}

/// Implementors should send StartWorkflowExecutionRequest to the server and pass the response back.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait StartWorkflowExecutionApi {
/// Starts workflow execution.
async fn start_workflow(
&self,
namespace: &str,
task_queue: &str,
workflow_id: &str,
workflow_type: &str,
) -> Result<StartWorkflowExecutionResponse>;
}

/// The [DrivenWorkflow] trait expects to be called to make progress, but the [CoreSDKService]
/// expects to be polled by the lang sdk. This struct acts as the bridge between the two, buffering
/// output from calls to [DrivenWorkflow] and offering them to [CoreSDKService]
Expand Down Expand Up @@ -424,17 +450,17 @@ mod test {
let mut tasks = VecDeque::from(responses);
let mut mock_gateway = MockServerGateway::new();
mock_gateway
.expect_poll()
.expect_poll_workflow_task()
.returning(move |_| Ok(tasks.pop_front().unwrap()));
// Response not really important here
mock_gateway
.expect_complete()
.expect_complete_workflow_task()
.returning(|_, _| Ok(RespondWorkflowTaskCompletedResponse::default()));

let runtime = Runtime::new().unwrap();
let core = CoreSDK {
runtime,
server_gateway: mock_gateway,
server_gateway: Arc::new(mock_gateway),
workflow_machines: DashMap::new(),
workflow_task_tokens: DashMap::new(),
};
Expand Down
67 changes: 61 additions & 6 deletions src/pollers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::time::Duration;

use crate::machines::ProtoCommand;
use crate::protos::temporal::api::common::v1::WorkflowType;
use crate::protos::temporal::api::workflowservice::v1::{
RespondWorkflowTaskCompletedRequest, RespondWorkflowTaskCompletedResponse,
StartWorkflowExecutionRequest, StartWorkflowExecutionResponse,
};
use crate::{
protos::temporal::api::enums::v1::TaskQueueKind,
Expand All @@ -11,10 +13,11 @@ use crate::{
protos::temporal::api::workflowservice::v1::{
PollWorkflowTaskQueueRequest, PollWorkflowTaskQueueResponse,
},
PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, Result,
PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, Result, StartWorkflowExecutionApi,
};
use tonic::{transport::Channel, Request, Status};
use url::Url;
use uuid::Uuid;

/// Options for the connection to the temporal server
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -76,9 +79,19 @@ pub struct ServerGateway {
pub opts: ServerGatewayOptions,
}

pub trait ServerGatewayApis:
PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi + StartWorkflowExecutionApi
{
}

impl<T> ServerGatewayApis for T where
T: PollWorkflowTaskQueueApi + RespondWorkflowTaskCompletedApi + StartWorkflowExecutionApi
{
}

#[async_trait::async_trait]
impl PollWorkflowTaskQueueApi for ServerGateway {
async fn poll(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse> {
async fn poll_workflow_task(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse> {
let request = PollWorkflowTaskQueueRequest {
namespace: self.opts.namespace.to_string(),
task_queue: Some(TaskQueue {
Expand All @@ -100,7 +113,7 @@ impl PollWorkflowTaskQueueApi for ServerGateway {

#[async_trait::async_trait]
impl RespondWorkflowTaskCompletedApi for ServerGateway {
async fn complete(
async fn complete_workflow_task(
&self,
task_token: Vec<u8>,
commands: Vec<ProtoCommand>,
Expand All @@ -122,15 +135,57 @@ impl RespondWorkflowTaskCompletedApi for ServerGateway {
}
}

#[async_trait::async_trait]
impl StartWorkflowExecutionApi for ServerGateway {
async fn start_workflow(
&self,
namespace: &str,
task_queue: &str,
workflow_id: &str,
workflow_type: &str,
) -> Result<StartWorkflowExecutionResponse> {
let request_id = Uuid::new_v4().to_string();

Ok(self
.service
.clone()
.start_workflow_execution(StartWorkflowExecutionRequest {
namespace: namespace.to_string(),
workflow_id: workflow_id.to_string(),
workflow_type: Some(WorkflowType {
name: workflow_type.to_string(),
}),
task_queue: Some(TaskQueue {
name: task_queue.to_string(),
kind: 0,
}),
request_id,
..Default::default()
})
.await?
.into_inner())
}
}

#[cfg(test)]
mockall::mock! {
pub(crate) ServerGateway {}
pub ServerGateway {}
#[async_trait::async_trait]
impl PollWorkflowTaskQueueApi for ServerGateway {
async fn poll(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse>;
async fn poll_workflow_task(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse>;
}
#[async_trait::async_trait]
impl RespondWorkflowTaskCompletedApi for ServerGateway {
async fn complete(&self, task_token: Vec<u8>, commands: Vec<ProtoCommand>) -> Result<RespondWorkflowTaskCompletedResponse>;
async fn complete_workflow_task(&self, task_token: Vec<u8>, commands: Vec<ProtoCommand>) -> Result<RespondWorkflowTaskCompletedResponse>;
}
#[async_trait::async_trait]
impl StartWorkflowExecutionApi for ServerGateway {
async fn start_workflow(
&self,
namespace: &str,
task_queue: &str,
workflow_id: &str,
workflow_type: &str,
) -> Result<StartWorkflowExecutionResponse>;
}
}
47 changes: 13 additions & 34 deletions tests/integ_tests/poller_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,55 +16,34 @@ use temporal_sdk_core::{
const TASK_QUEUE: &str = "test-tq";
const NAMESPACE: &str = "default";

// TODO try to consolidate this into the SDK code so we don't need to create another runtime.
#[tokio::main]
async fn create_workflow() -> (String, String, ServerGatewayOptions) {
async fn create_workflow(core: &dyn Core, workflow_id: &str) -> String {
core.server_gateway()
.unwrap()
.start_workflow(NAMESPACE, TASK_QUEUE, workflow_id, "test-workflow")
.await
.unwrap()
.run_id
}

#[test]
fn timer_workflow() {
let temporal_server_address = match env::var("TEMPORAL_SERVICE_ADDRESS") {
Ok(addr) => addr,
Err(_) => "http://localhost:7233".to_owned(),
};

let mut rng = rand::thread_rng();
let url = Url::try_from(&*temporal_server_address).unwrap();
let workflow_id: u32 = rng.gen();
let request_id: u32 = rng.gen();
let gateway_opts = ServerGatewayOptions {
namespace: NAMESPACE.to_string(),
identity: "none".to_string(),
worker_binary_id: "".to_string(),
long_poll_timeout: Duration::from_secs(60),
target_url: url,
};
let mut gateway = gateway_opts.connect().await.unwrap();
let response = gateway
.service
.start_workflow_execution(StartWorkflowExecutionRequest {
namespace: NAMESPACE.to_string(),
workflow_id: workflow_id.to_string(),
workflow_type: Some(WorkflowType {
name: "test-workflow".to_string(),
}),
task_queue: Some(TaskQueue {
name: TASK_QUEUE.to_string(),
kind: 0,
}),
request_id: request_id.to_string(),
..Default::default()
})
.await
.unwrap();
(
workflow_id.to_string(),
response.into_inner().run_id,
gateway_opts,
)
}

#[test]
fn timer_workflow() {
let (workflow_id, run_id, gateway_opts) = dbg!(create_workflow());
let core = temporal_sdk_core::init(CoreInitOptions { gateway_opts }).unwrap();
let mut rng = rand::thread_rng();
let workflow_id: u32 = rng.gen();
let run_id = dbg!(create_workflow(&core, &workflow_id.to_string()));
let timer_id: String = rng.gen::<u32>().to_string();
let task = dbg!(core.poll_task(TASK_QUEUE).unwrap());
core.complete_task(CompleteTaskReq::ok_from_api_attrs(
Expand Down