Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ impl<WP: ServerGatewayApis> CoreSDK<WP> {
tokio::time::sleep(Duration::from_millis(100)).await;
}
};
let pollfut = self.server_gateway.poll_workflow_task(task_queue);
let pollfut = self
.server_gateway
.poll_workflow_task(task_queue.to_owned());
tokio::select! {
_ = shutdownfut => {
Err(CoreError::ShuttingDown)
Expand Down
6 changes: 3 additions & 3 deletions src/machines/test_help/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub(super) use workflow_driver::{CommandSender, TestWorkflowDriver};

use crate::workflow::WorkflowConcurrencyManager;
use crate::{
pollers::MockServerGateway,
pollers::MockServerGatewayApis,
protos::temporal::api::common::v1::WorkflowExecution,
protos::temporal::api::history::v1::History,
protos::temporal::api::workflowservice::v1::{
Expand All @@ -21,7 +21,7 @@ use std::sync::atomic::AtomicBool;
use std::{collections::VecDeque, sync::Arc};
use tokio::runtime::Runtime;

pub(crate) type FakeCore = CoreSDK<MockServerGateway>;
pub(crate) type FakeCore = CoreSDK<MockServerGatewayApis>;

/// Given identifiers for a workflow/run, and a test history builder, construct an instance of
/// the core SDK with a mock server gateway that will produce the responses as appropriate.
Expand Down Expand Up @@ -55,7 +55,7 @@ pub(crate) fn build_fake_core(
.collect();

let mut tasks = VecDeque::from(responses);
let mut mock_gateway = MockServerGateway::new();
let mut mock_gateway = MockServerGatewayApis::new();
mock_gateway
.expect_poll_workflow_task()
.times(response_batches.len())
Expand Down
195 changes: 104 additions & 91 deletions src/pollers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::time::Duration;

use crate::protos::temporal::api::common::v1::{Payloads, WorkflowExecution};
use crate::protos::temporal::api::workflowservice::v1::{
SignalWorkflowExecutionRequest, SignalWorkflowExecutionResponse,
};
use crate::{
machines::ProtoCommand,
protos::temporal::api::{
Expand All @@ -15,10 +19,6 @@ use crate::{
StartWorkflowExecutionResponse,
},
},
workflow::{
PollWorkflowTaskQueueApi, RespondWorkflowTaskCompletedApi, RespondWorkflowTaskFailedApi,
StartWorkflowExecutionApi,
},
CoreError, Result,
};
use tonic::{transport::Channel, Request, Status};
Expand Down Expand Up @@ -86,34 +86,94 @@ pub struct ServerGateway {
pub opts: ServerGatewayOptions,
}

/// This trait provides ways to call the temporal server itself
pub trait ServerGatewayApis:
PollWorkflowTaskQueueApi
+ RespondWorkflowTaskCompletedApi
+ StartWorkflowExecutionApi
+ RespondWorkflowTaskFailedApi
{
}
/// This trait provides ways to call the temporal server
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait ServerGatewayApis {
/// Starts workflow execution.
async fn start_workflow(
&self,
namespace: String,
task_queue: String,
workflow_id: String,
workflow_type: String,
) -> Result<StartWorkflowExecutionResponse>;

/// Fetch new work. Should block indefinitely if there is no work.
async fn poll_workflow_task(&self, task_queue: String)
-> Result<PollWorkflowTaskQueueResponse>;

impl<T> ServerGatewayApis for T where
T: PollWorkflowTaskQueueApi
+ RespondWorkflowTaskCompletedApi
+ StartWorkflowExecutionApi
+ RespondWorkflowTaskFailedApi
{
/// Complete a task by sending it to the server. `task_token` is the task token that would've
/// been received from [PollWorkflowTaskQueueApi::poll]. `commands` is a list of new commands
/// to send to the server, such as starting a timer.
async fn complete_workflow_task(
&self,
task_token: Vec<u8>,
commands: Vec<ProtoCommand>,
) -> Result<RespondWorkflowTaskCompletedResponse>;

/// Fail task by sending the failure to the server. `task_token` is the task token that would've
/// been received from [PollWorkflowTaskQueueApi::poll].
async fn fail_workflow_task(
&self,
task_token: Vec<u8>,
cause: WorkflowTaskFailedCause,
failure: Option<Failure>,
) -> Result<RespondWorkflowTaskFailedResponse>;

/// Send a signal to a certain workflow instance
async fn signal_workflow_execution(
&self,
workflow_id: String,
run_id: String,
signal_name: String,
payloads: Option<Payloads>,
) -> Result<SignalWorkflowExecutionResponse>;
}

#[async_trait::async_trait]
impl PollWorkflowTaskQueueApi for ServerGateway {
async fn poll_workflow_task(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse> {
impl ServerGatewayApis for ServerGateway {
async fn start_workflow(
&self,
namespace: String,
task_queue: String,
workflow_id: String,
workflow_type: String,
) -> Result<StartWorkflowExecutionResponse> {
let request_id = Uuid::new_v4().to_string();

Ok(self
.service
.clone()
.start_workflow_execution(StartWorkflowExecutionRequest {
namespace,
workflow_id,
workflow_type: Some(WorkflowType {
name: workflow_type,
}),
task_queue: Some(TaskQueue {
name: task_queue,
kind: 0,
}),
request_id,
..Default::default()
})
.await?
.into_inner())
}

async fn poll_workflow_task(
&self,
task_queue: String,
) -> Result<PollWorkflowTaskQueueResponse> {
let request = PollWorkflowTaskQueueRequest {
namespace: self.opts.namespace.to_string(),
namespace: self.opts.namespace.clone(),
task_queue: Some(TaskQueue {
name: task_queue.to_string(),
name: task_queue,
kind: TaskQueueKind::Unspecified as i32,
}),
identity: self.opts.identity.to_string(),
binary_checksum: self.opts.worker_binary_id.to_string(),
identity: self.opts.identity.clone(),
binary_checksum: self.opts.worker_binary_id.clone(),
};

Ok(self
Expand All @@ -123,10 +183,7 @@ impl PollWorkflowTaskQueueApi for ServerGateway {
.await?
.into_inner())
}
}

#[async_trait::async_trait]
impl RespondWorkflowTaskCompletedApi for ServerGateway {
async fn complete_workflow_task(
&self,
task_token: Vec<u8>,
Expand All @@ -135,9 +192,9 @@ impl RespondWorkflowTaskCompletedApi for ServerGateway {
let request = RespondWorkflowTaskCompletedRequest {
task_token,
commands,
identity: self.opts.identity.to_string(),
binary_checksum: self.opts.worker_binary_id.to_string(),
namespace: self.opts.namespace.to_string(),
identity: self.opts.identity.clone(),
binary_checksum: self.opts.worker_binary_id.clone(),
namespace: self.opts.namespace.clone(),
..Default::default()
};
match self
Expand All @@ -156,10 +213,7 @@ impl RespondWorkflowTaskCompletedApi for ServerGateway {
}
}
}
}

#[async_trait::async_trait]
impl RespondWorkflowTaskFailedApi for ServerGateway {
async fn fail_workflow_task(
&self,
task_token: Vec<u8>,
Expand All @@ -170,9 +224,9 @@ impl RespondWorkflowTaskFailedApi for ServerGateway {
task_token,
cause: cause as i32,
failure,
identity: self.opts.identity.to_string(),
binary_checksum: self.opts.worker_binary_id.to_string(),
namespace: self.opts.namespace.to_string(),
identity: self.opts.identity.clone(),
binary_checksum: self.opts.worker_binary_id.clone(),
namespace: self.opts.namespace.clone(),
};
Ok(self
.service
Expand All @@ -181,70 +235,29 @@ impl RespondWorkflowTaskFailedApi for ServerGateway {
.await?
.into_inner())
}
}

#[async_trait::async_trait]
impl StartWorkflowExecutionApi for ServerGateway {
async fn start_workflow(
async fn signal_workflow_execution(
&self,
namespace: &str,
task_queue: &str,
workflow_id: &str,
workflow_type: &str,
) -> Result<StartWorkflowExecutionResponse> {
let request_id = Uuid::new_v4().to_string();

workflow_id: String,
run_id: String,
signal_name: String,
payloads: Option<Payloads>,
) -> Result<SignalWorkflowExecutionResponse> {
Ok(self
.service
.clone()
.start_workflow_execution(StartWorkflowExecutionRequest {
namespace: namespace.to_string(),
workflow_id: workflow_id.to_string(),
workflow_type: Some(WorkflowType {
name: workflow_type.to_string(),
.signal_workflow_execution(SignalWorkflowExecutionRequest {
namespace: self.opts.namespace.clone(),
workflow_execution: Some(WorkflowExecution {
workflow_id,
run_id,
}),
task_queue: Some(TaskQueue {
name: task_queue.to_string(),
kind: 0,
}),
request_id,
signal_name,
input: payloads,
identity: self.opts.identity.clone(),
..Default::default()
})
.await?
.into_inner())
}
}

#[cfg(test)]
mockall::mock! {
pub ServerGateway {}
#[async_trait::async_trait]
impl PollWorkflowTaskQueueApi for ServerGateway {
async fn poll_workflow_task(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse>;
}
#[async_trait::async_trait]
impl RespondWorkflowTaskCompletedApi for ServerGateway {
async fn complete_workflow_task(&self, task_token: Vec<u8>, commands: Vec<ProtoCommand>) -> Result<RespondWorkflowTaskCompletedResponse>;
}

#[async_trait::async_trait]
impl RespondWorkflowTaskFailedApi for ServerGateway {
async fn fail_workflow_task(
&self,
task_token: Vec<u8>,
cause: WorkflowTaskFailedCause,
failure: Option<Failure>,
) -> Result<RespondWorkflowTaskFailedResponse>;
}

#[async_trait::async_trait]
impl StartWorkflowExecutionApi for ServerGateway {
async fn start_workflow(
&self,
namespace: &str,
task_queue: &str,
workflow_id: &str,
workflow_type: &str,
) -> Result<StartWorkflowExecutionResponse>;
}
}
65 changes: 2 additions & 63 deletions src/workflow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,77 +7,16 @@ pub(crate) use concurrency_manager::WorkflowConcurrencyManager;
pub(crate) use driven_workflow::{ActivationListener, DrivenWorkflow, WorkflowFetcher};

use crate::{
machines::{ProtoCommand, WFCommand, WorkflowMachines},
machines::{WFCommand, WorkflowMachines},
protos::{
coresdk::WfActivation,
temporal::api::{
enums::v1::WorkflowTaskFailedCause,
failure::v1::Failure,
history::v1::History,
workflowservice::v1::{
PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse,
RespondWorkflowTaskFailedResponse, StartWorkflowExecutionResponse,
},
},
temporal::api::{history::v1::History, workflowservice::v1::PollWorkflowTaskQueueResponse},
},
protosext::HistoryInfo,
CoreError, Result,
};
use std::sync::mpsc::Sender;

/// Implementors can provide new workflow tasks to the SDK. The connection to the server is the real
/// implementor.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait PollWorkflowTaskQueueApi {
/// Fetch new work. Should block indefinitely if there is no work.
async fn poll_workflow_task(&self, task_queue: &str) -> Result<PollWorkflowTaskQueueResponse>;
}

/// Implementors can complete tasks issued by [Core::poll]. The real implementor sends the completed
/// tasks to the server.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait RespondWorkflowTaskCompletedApi {
/// Complete a task by sending it to the server. `task_token` is the task token that would've
/// been received from [PollWorkflowTaskQueueApi::poll]. `commands` is a list of new commands
/// to send to the server, such as starting a timer.
async fn complete_workflow_task(
&self,
task_token: Vec<u8>,
commands: Vec<ProtoCommand>,
) -> Result<RespondWorkflowTaskCompletedResponse>;
}

/// Implementors can fail workflow tasks issued by [Core::poll]. The real implementor sends the
/// failed tasks to the server.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait RespondWorkflowTaskFailedApi {
/// Fail task by sending the failure to the server. `task_token` is the task token that would've
/// been received from [PollWorkflowTaskQueueApi::poll].
async fn fail_workflow_task(
&self,
task_token: Vec<u8>,
cause: WorkflowTaskFailedCause,
failure: Option<Failure>,
) -> Result<RespondWorkflowTaskFailedResponse>;
}

/// Implementors should send StartWorkflowExecutionRequest to the server and pass the response back.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait StartWorkflowExecutionApi {
/// Starts workflow execution.
async fn start_workflow(
&self,
namespace: &str,
task_queue: &str,
workflow_id: &str,
workflow_type: &str,
) -> Result<StartWorkflowExecutionResponse>;
}

/// Manages an instance of a [WorkflowMachines], which is not thread-safe, as well as other data
/// associated with that specific workflow run.
pub(crate) struct WorkflowManager {
Expand Down
Loading