diff --git a/protos/local/core_interface.proto b/protos/local/core_interface.proto index 7bb0f5cff..aa7a91d15 100644 --- a/protos/local/core_interface.proto +++ b/protos/local/core_interface.proto @@ -21,3 +21,9 @@ message ActivityHeartbeat { string activity_id = 1; repeated common.Payload details = 2; } + +// A request as given to [crate::Core::complete_activity_task] +message ActivityTaskCompletion { + bytes task_token = 1; + activity_result.ActivityResult result = 2; +} diff --git a/src/lib.rs b/src/lib.rs index 30aead89b..558ae5b16 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ use crate::{ activity_task::ActivityTask, workflow_activation::WfActivation, workflow_completion::{wf_activation_completion, WfActivationCompletion}, - ActivityHeartbeat, + ActivityHeartbeat, ActivityTaskCompletion, }, temporal::api::{ enums::v1::WorkflowTaskFailedCause, workflowservice::v1::PollWorkflowTaskQueueResponse, @@ -83,11 +83,7 @@ pub trait Core: Send + Sync { async fn complete_workflow_task(&self, completion: WfActivationCompletion) -> Result<()>; /// Tell the core that an activity has finished executing - async fn complete_activity_task( - &self, - task_token: Vec, - result: ActivityResult, - ) -> Result<()>; + async fn complete_activity_task(&self, completion: ActivityTaskCompletion) -> Result<()>; /// Indicate that a long running activity is still making progress async fn send_activity_heartbeat(&self, task_token: ActivityHeartbeat) -> Result<()>; @@ -311,11 +307,16 @@ where } #[instrument(skip(self))] - async fn complete_activity_task( - &self, - task_token: Vec, - result: ActivityResult, - ) -> Result<()> { + async fn complete_activity_task(&self, completion: ActivityTaskCompletion) -> Result<()> { + let task_token = completion.task_token; + let result = if let Some(r) = completion.result { + r + } else { + return Err(CoreError::MalformedActivityCompletion { + reason: "Activity completion had empty result field".to_owned(), + completion: None, + }); + }; let status = if let Some(s) = result.status { s } else { diff --git a/tests/integ_tests/simple_wf_tests.rs b/tests/integ_tests/simple_wf_tests.rs index 61af33631..eaf6690bc 100644 --- a/tests/integ_tests/simple_wf_tests.rs +++ b/tests/integ_tests/simple_wf_tests.rs @@ -3,6 +3,7 @@ use crossbeam::channel::{unbounded, RecvTimeoutError}; use futures::{channel::mpsc::UnboundedReceiver, future, Future, SinkExt, StreamExt}; use rand::{self, Rng}; use std::{collections::HashMap, convert::TryFrom, env, sync::Arc, time::Duration}; +use temporal_sdk_core::protos::coresdk::ActivityTaskCompletion; use temporal_sdk_core::{ protos::coresdk::{ activity_result::{self, activity_result as act_res, ActivityResult}, @@ -133,10 +134,10 @@ async fn activity_workflow() { metadata: Default::default(), }; // Complete activity successfully. - core.complete_activity_task( - task.task_token, - ActivityResult::ok(response_payload.clone()), - ) + core.complete_activity_task(ActivityTaskCompletion { + task_token: task.task_token, + result: Some(ActivityResult::ok(response_payload.clone())), + }) .await .unwrap(); // Poll workflow task and verify that activity has succeeded. @@ -192,16 +193,16 @@ async fn activity_non_retryable_failure() { non_retryable: true, ..Default::default() }; - core.complete_activity_task( - task.task_token, - ActivityResult { + core.complete_activity_task(ActivityTaskCompletion { + task_token: task.task_token, + result: Some(ActivityResult { status: Some(activity_result::activity_result::Status::Failed( activity_result::Failure { failure: Some(failure.clone()), }, )), - }, - ) + }), + }) .await .unwrap(); // Poll workflow task and verify that activity has failed. @@ -256,16 +257,16 @@ async fn activity_retry() { non_retryable: false, ..Default::default() }; - core.complete_activity_task( - task.task_token, - ActivityResult { + core.complete_activity_task(ActivityTaskCompletion { + task_token: task.task_token, + result: Some(ActivityResult { status: Some(activity_result::activity_result::Status::Failed( activity_result::Failure { failure: Some(failure), }, )), - }, - ) + }), + }) .await .unwrap(); // Poll 2nd time @@ -281,16 +282,10 @@ async fn activity_retry() { data: b"hello ".to_vec(), metadata: Default::default(), }; - core.complete_activity_task( - task.task_token, - ActivityResult { - status: Some(activity_result::activity_result::Status::Completed( - activity_result::Success { - result: Some(response_payload.clone()), - }, - )), - }, - ) + core.complete_activity_task(ActivityTaskCompletion { + task_token: task.task_token, + result: Some(ActivityResult::ok(response_payload.clone())), + }) .await .unwrap(); // Poll workflow task and verify activity has succeeded.