From 95495ca1c3001a2051b41fe7d3f092e0572abf65 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Tue, 23 Sep 2025 15:26:40 -0700 Subject: [PATCH] Refactor client traits Introduce mockable versions of all services --- client/Cargo.toml | 1 + client/src/lib.rs | 496 +++++++------ client/src/metrics.rs | 3 +- client/src/raw.rs | 687 ++++++++++++------- client/src/workflow_handle/mod.rs | 34 +- core-c-bridge/src/client.rs | 613 ++++++++++++----- core-c-bridge/src/tests/mod.rs | 26 +- core/src/lib.rs | 14 +- core/src/worker/client.rs | 121 ++-- tests/common/mod.rs | 4 +- tests/integ_tests/client_tests.rs | 37 +- tests/integ_tests/ephemeral_server_tests.rs | 14 +- tests/integ_tests/metrics_tests.rs | 23 +- tests/integ_tests/update_tests.rs | 4 +- tests/integ_tests/worker_versioning_tests.rs | 57 +- tests/integ_tests/workflow_tests/resets.rs | 49 +- tests/main.rs | 32 +- 17 files changed, 1364 insertions(+), 851 deletions(-) diff --git a/client/Cargo.toml b/client/Cargo.toml index 50095bfde..f083ed777 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -20,6 +20,7 @@ backoff = "0.4" base64 = "0.22" derive_builder = { workspace = true } derive_more = { workspace = true } +dyn-clone = "1.0" bytes = "1.10" futures-util = { version = "0.3", default-features = false } futures-retry = "0.6.0" diff --git a/client/src/lib.rs b/client/src/lib.rs index ea0607ce3..8bdde001d 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -41,7 +41,7 @@ pub use workflow_handle::{ use crate::{ metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext}, - raw::{AttachMetricLabels, sealed::RawClientLike}, + raw::AttachMetricLabels, sealed::WfHandleClient, workflow_handle::UntypedWorkflowHandle, }; @@ -76,7 +76,7 @@ use temporal_sdk_core_protos::{ }, }; use tonic::{ - Code, + Code, IntoRequest, body::Body, client::GrpcService, codegen::InterceptedService, @@ -513,8 +513,7 @@ impl ClientOptions { pub async fn connect_no_namespace( &self, metrics_meter: Option, - ) -> Result>, ClientInitError> - { + ) -> Result>, ClientInitError> { self.connect_no_namespace_with_service_override(metrics_meter, None) .await } @@ -529,8 +528,7 @@ impl ClientOptions { &self, metrics_meter: Option, service_override: Option, - ) -> Result>, ClientInitError> - { + ) -> Result>, ClientInitError> { let service = if let Some(service_override) = service_override { GrpcMetricSvc { inner: ChannelOrGrpcOverride::GrpcOverride(service_override), @@ -590,7 +588,7 @@ impl ClientOptions { }; if !self.skip_get_system_info { match client - .get_system_info(GetSystemInfoRequest::default()) + .get_system_info(GetSystemInfoRequest::default().into_request()) .await { Ok(sysinfo) => { @@ -734,14 +732,13 @@ impl Interceptor for ServiceCallInterceptor { } /// Aggregates various services exposed by the Temporal server -#[derive(Debug, Clone)] -pub struct TemporalServiceClient { - svc: T, - workflow_svc_client: OnceLock>, - operator_svc_client: OnceLock>, - cloud_svc_client: OnceLock>, - test_svc_client: OnceLock>, - health_svc_client: OnceLock>, +#[derive(Clone)] +pub struct TemporalServiceClient { + workflow_svc_client: Box, + operator_svc_client: Box, + cloud_svc_client: Box, + test_svc_client: Box, + health_svc_client: Box, } /// We up the limit on incoming messages from server from the 4Mb default to 128Mb. If for @@ -756,136 +753,100 @@ fn get_decode_max_size() -> usize { }) } -impl TemporalServiceClient -where - T: Clone, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - ::Error: Into + Send, -{ - fn new(svc: T) -> Self { +impl TemporalServiceClient { + fn new(svc: T) -> Self + where + T: GrpcService + Send + Sync + Clone + 'static, + T::ResponseBody: tonic::codegen::Body + Send + 'static, + T::Error: Into, + ::Error: Into + Send, + >::Future: Send, + { + let workflow_svc_client = Box::new( + WorkflowServiceClient::new(svc.clone()) + .max_decoding_message_size(get_decode_max_size()), + ); + let operator_svc_client = Box::new( + OperatorServiceClient::new(svc.clone()) + .max_decoding_message_size(get_decode_max_size()), + ); + let cloud_svc_client = Box::new( + CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()), + ); + let test_svc_client = Box::new( + TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()), + ); + let health_svc_client = Box::new( + HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()), + ); + Self { - svc, - workflow_svc_client: OnceLock::new(), - operator_svc_client: OnceLock::new(), - cloud_svc_client: OnceLock::new(), - test_svc_client: OnceLock::new(), - health_svc_client: OnceLock::new(), + workflow_svc_client, + operator_svc_client, + cloud_svc_client, + test_svc_client, + health_svc_client, } } + + /// Create a service client from implementations of the individual underlying services. Useful + /// for mocking out service implementations. + pub fn from_services( + workflow: Box, + operator: Box, + cloud: Box, + test: Box, + health: Box, + ) -> Self { + Self { + workflow_svc_client: workflow, + operator_svc_client: operator, + cloud_svc_client: cloud, + test_svc_client: test, + health_svc_client: health, + } + } + /// Get the underlying workflow service client - pub fn workflow_svc(&self) -> &WorkflowServiceClient { - self.workflow_svc_client.get_or_init(|| { - WorkflowServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn workflow_svc(&self) -> Box { + self.workflow_svc_client.clone() } /// Get the underlying operator service client - pub fn operator_svc(&self) -> &OperatorServiceClient { - self.operator_svc_client.get_or_init(|| { - OperatorServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn operator_svc(&self) -> Box { + self.operator_svc_client.clone() } /// Get the underlying cloud service client - pub fn cloud_svc(&self) -> &CloudServiceClient { - self.cloud_svc_client.get_or_init(|| { - CloudServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn cloud_svc(&self) -> Box { + self.cloud_svc_client.clone() } /// Get the underlying test service client - pub fn test_svc(&self) -> &TestServiceClient { - self.test_svc_client.get_or_init(|| { - TestServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn test_svc(&self) -> Box { + self.test_svc_client.clone() } /// Get the underlying health service client - pub fn health_svc(&self) -> &HealthClient { - self.health_svc_client.get_or_init(|| { - HealthClient::new(self.svc.clone()).max_decoding_message_size(get_decode_max_size()) - }) - } - /// Get the underlying workflow service client mutably - pub fn workflow_svc_mut(&mut self) -> &mut WorkflowServiceClient { - let _ = self.workflow_svc(); - self.workflow_svc_client.get_mut().unwrap() - } - /// Get the underlying operator service client mutably - pub fn operator_svc_mut(&mut self) -> &mut OperatorServiceClient { - let _ = self.operator_svc(); - self.operator_svc_client.get_mut().unwrap() - } - /// Get the underlying cloud service client mutably - pub fn cloud_svc_mut(&mut self) -> &mut CloudServiceClient { - let _ = self.cloud_svc(); - self.cloud_svc_client.get_mut().unwrap() - } - /// Get the underlying test service client mutably - pub fn test_svc_mut(&mut self) -> &mut TestServiceClient { - let _ = self.test_svc(); - self.test_svc_client.get_mut().unwrap() - } - /// Get the underlying health service client mutably - pub fn health_svc_mut(&mut self) -> &mut HealthClient { - let _ = self.health_svc(); - self.health_svc_client.get_mut().unwrap() + pub fn health_svc(&self) -> Box { + self.health_svc_client.clone() } } -/// A [WorkflowServiceClient] with the default interceptors attached. -pub type WorkflowServiceClientWithMetrics = WorkflowServiceClient; -/// An [OperatorServiceClient] with the default interceptors attached. -pub type OperatorServiceClientWithMetrics = OperatorServiceClient; -/// An [TestServiceClient] with the default interceptors attached. -pub type TestServiceClientWithMetrics = TestServiceClient; -/// A [TemporalServiceClient] with the default interceptors attached. -pub type TemporalServiceClientWithMetrics = TemporalServiceClient; -type InterceptedMetricsSvc = InterceptedService; - /// Contains an instance of a namespace-bound client for interacting with the Temporal server -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Client { /// Client for interacting with workflow service - inner: ConfiguredClient, + inner: ConfiguredClient, /// The namespace this client interacts with namespace: String, } impl Client { /// Create a new client from an existing configured lower level client and a namespace - pub fn new( - client: ConfiguredClient, - namespace: String, - ) -> Self { + pub fn new(client: ConfiguredClient, namespace: String) -> Self { Client { inner: client, namespace, } } - /// Return an auto-retrying version of the underling grpc client (instrumented with metrics - /// collection, if enabled). - /// - /// Note that it is reasonably cheap to clone the returned type if you need to own it. Such - /// clones will keep re-using the same channel. - pub fn raw_retry_client(&self) -> RetryClient { - RetryClient::new( - self.raw_client().clone(), - self.inner.options.retry_config.clone(), - ) - } - - /// Access the underling grpc client. This raw client is not bound to a specific namespace. - /// - /// Note that it is reasonably cheap to clone the returned type if you need to own it. Such - /// clones will keep re-using the same channel. - pub fn raw_client(&self) -> &WorkflowServiceClientWithMetrics { - self.inner.workflow_svc() - } - /// Return the options this client was initialized with pub fn options(&self) -> &ClientOptions { &self.inner.options @@ -897,12 +858,12 @@ impl Client { } /// Returns a reference to the underlying client - pub fn inner(&self) -> &ConfiguredClient { + pub fn inner(&self) -> &ConfiguredClient { &self.inner } /// Consumes self and returns the underlying client - pub fn into_inner(self) -> ConfiguredClient { + pub fn into_inner(self) -> ConfiguredClient { self.inner } } @@ -1378,15 +1339,7 @@ impl From for Priority { #[async_trait::async_trait] impl WorkflowClientTrait for T where - T: RawClientLike + NamespacedClient + Clone + Send + Sync + 'static, - ::SvcType: GrpcService + Send + Clone + 'static, - <::SvcType as GrpcService>::ResponseBody: - tonic::codegen::Body + Send + 'static, - <::SvcType as GrpcService>::Error: - Into, - <::SvcType as GrpcService>::Future: Send, - <<::SvcType as GrpcService>::ResponseBody - as tonic::codegen::Body>::Error: Into + Send, + T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static, { async fn start_workflow( &self, @@ -1399,35 +1352,38 @@ where ) -> Result { Ok(self .clone() - .start_workflow_execution(StartWorkflowExecutionRequest { - namespace: self.namespace(), - input: input.into_payloads(), - workflow_id, - workflow_type: Some(WorkflowType { - name: workflow_type, - }), - task_queue: Some(TaskQueue { - name: task_queue, - kind: TaskQueueKind::Unspecified as i32, - normal_name: "".to_string(), - }), - request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), - workflow_id_reuse_policy: options.id_reuse_policy as i32, - workflow_id_conflict_policy: options.id_conflict_policy as i32, - workflow_execution_timeout: options - .execution_timeout - .and_then(|d| d.try_into().ok()), - workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), - workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), - search_attributes: options.search_attributes.map(|d| d.into()), - cron_schedule: options.cron_schedule.unwrap_or_default(), - request_eager_execution: options.enable_eager_workflow_start, - retry_policy: options.retry_policy, - links: options.links, - completion_callbacks: options.completion_callbacks, - priority: options.priority.map(Into::into), - ..Default::default() - }) + .start_workflow_execution( + StartWorkflowExecutionRequest { + namespace: self.namespace(), + input: input.into_payloads(), + workflow_id, + workflow_type: Some(WorkflowType { + name: workflow_type, + }), + task_queue: Some(TaskQueue { + name: task_queue, + kind: TaskQueueKind::Unspecified as i32, + normal_name: "".to_string(), + }), + request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), + workflow_id_reuse_policy: options.id_reuse_policy as i32, + workflow_id_conflict_policy: options.id_conflict_policy as i32, + workflow_execution_timeout: options + .execution_timeout + .and_then(|d| d.try_into().ok()), + workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), + workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), + search_attributes: options.search_attributes.map(|d| d.into()), + cron_schedule: options.cron_schedule.unwrap_or_default(), + request_eager_execution: options.enable_eager_workflow_start, + retry_policy: options.retry_policy, + links: options.links, + completion_callbacks: options.completion_callbacks, + priority: options.priority.map(Into::into), + ..Default::default() + } + .into_request(), + ) .await? .into_inner()) } @@ -1445,7 +1401,7 @@ where }), }; Ok( - WorkflowService::reset_sticky_task_queue(&mut self.clone(), request) + WorkflowService::reset_sticky_task_queue(&mut self.clone(), request.into_request()) .await? .into_inner(), ) @@ -1456,17 +1412,20 @@ where task_token: TaskToken, result: Option, ) -> Result { - Ok(self.clone().respond_activity_task_completed( - RespondActivityTaskCompletedRequest { - task_token: task_token.0, - result, - identity: self.identity(), - namespace: self.namespace(), - ..Default::default() - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .respond_activity_task_completed( + RespondActivityTaskCompletedRequest { + task_token: task_token.0, + result, + identity: self.identity(), + namespace: self.namespace(), + ..Default::default() + } + .into_request(), + ) + .await? + .into_inner()) } async fn record_activity_heartbeat( @@ -1474,16 +1433,19 @@ where task_token: TaskToken, details: Option, ) -> Result { - Ok(self.clone().record_activity_task_heartbeat( - RecordActivityTaskHeartbeatRequest { - task_token: task_token.0, - details, - identity: self.identity(), - namespace: self.namespace(), - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .record_activity_task_heartbeat( + RecordActivityTaskHeartbeatRequest { + task_token: task_token.0, + details, + identity: self.identity(), + namespace: self.namespace(), + } + .into_request(), + ) + .await? + .into_inner()) } async fn cancel_activity_task( @@ -1491,17 +1453,20 @@ where task_token: TaskToken, details: Option, ) -> Result { - Ok(self.clone().respond_activity_task_canceled( - RespondActivityTaskCanceledRequest { - task_token: task_token.0, - details, - identity: self.identity(), - namespace: self.namespace(), - ..Default::default() - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .respond_activity_task_canceled( + RespondActivityTaskCanceledRequest { + task_token: task_token.0, + details, + identity: self.identity(), + namespace: self.namespace(), + ..Default::default() + } + .into_request(), + ) + .await? + .into_inner()) } async fn signal_workflow_execution( @@ -1512,7 +1477,8 @@ where payloads: Option, request_id: Option, ) -> Result { - Ok(WorkflowService::signal_workflow_execution(&mut self.clone(), + Ok(WorkflowService::signal_workflow_execution( + &mut self.clone(), SignalWorkflowExecutionRequest { namespace: self.namespace(), workflow_execution: Some(WorkflowExecution { @@ -1524,7 +1490,8 @@ where identity: self.identity(), request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1535,7 +1502,8 @@ where options: SignalWithStartOptions, workflow_options: WorkflowOptions, ) -> Result { - Ok(WorkflowService::signal_with_start_workflow_execution(&mut self.clone(), + Ok(WorkflowService::signal_with_start_workflow_execution( + &mut self.clone(), SignalWithStartWorkflowExecutionRequest { namespace: self.namespace(), workflow_id: options.workflow_id, @@ -1567,7 +1535,8 @@ where cron_schedule: workflow_options.cron_schedule.unwrap_or_default(), header: options.signal_header, ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1579,19 +1548,22 @@ where run_id: String, query: WorkflowQuery, ) -> Result { - Ok(self.clone().query_workflow( - QueryWorkflowRequest { - namespace: self.namespace(), - execution: Some(WorkflowExecution { - workflow_id, - run_id, - }), - query: Some(query), - query_reject_condition: 1, - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .query_workflow( + QueryWorkflowRequest { + namespace: self.namespace(), + execution: Some(WorkflowExecution { + workflow_id, + run_id, + }), + query: Some(query), + query_reject_condition: 1, + } + .into_request(), + ) + .await? + .into_inner()) } async fn describe_workflow_execution( @@ -1599,14 +1571,16 @@ where workflow_id: String, run_id: Option, ) -> Result { - Ok(WorkflowService::describe_workflow_execution(&mut self.clone(), + Ok(WorkflowService::describe_workflow_execution( + &mut self.clone(), DescribeWorkflowExecutionRequest { namespace: self.namespace(), execution: Some(WorkflowExecution { workflow_id, run_id: run_id.unwrap_or_default(), }), - }, + } + .into_request(), ) .await? .into_inner()) @@ -1618,7 +1592,8 @@ where run_id: Option, page_token: Vec, ) -> Result { - Ok(WorkflowService::get_workflow_execution_history(&mut self.clone(), + Ok(WorkflowService::get_workflow_execution_history( + &mut self.clone(), GetWorkflowExecutionHistoryRequest { namespace: self.namespace(), execution: Some(WorkflowExecution { @@ -1627,7 +1602,8 @@ where }), next_page_token: page_token, ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1640,22 +1616,25 @@ where reason: String, request_id: Option, ) -> Result { - Ok(self.clone().request_cancel_workflow_execution( - RequestCancelWorkflowExecutionRequest { - namespace: self.namespace(), - workflow_execution: Some(WorkflowExecution { - workflow_id, - run_id: run_id.unwrap_or_default(), - }), - identity: self.identity(), - request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), - first_execution_run_id: "".to_string(), - reason, - links: vec![], - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .request_cancel_workflow_execution( + RequestCancelWorkflowExecutionRequest { + namespace: self.namespace(), + workflow_execution: Some(WorkflowExecution { + workflow_id, + run_id: run_id.unwrap_or_default(), + }), + identity: self.identity(), + request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), + first_execution_run_id: "".to_string(), + reason, + links: vec![], + } + .into_request(), + ) + .await? + .into_inner()) } async fn terminate_workflow_execution( @@ -1663,7 +1642,8 @@ where workflow_id: String, run_id: Option, ) -> Result { - Ok(WorkflowService::terminate_workflow_execution(&mut self.clone(), + Ok(WorkflowService::terminate_workflow_execution( + &mut self.clone(), TerminateWorkflowExecutionRequest { namespace: self.namespace(), workflow_execution: Some(WorkflowExecution { @@ -1675,7 +1655,8 @@ where identity: self.identity(), first_execution_run_id: "".to_string(), links: vec![], - }, + } + .into_request(), ) .await? .into_inner()) @@ -1687,23 +1668,25 @@ where ) -> Result { let req = Into::::into(options); Ok( - WorkflowService::register_namespace(&mut self.clone(),req) + WorkflowService::register_namespace(&mut self.clone(), req.into_request()) .await? .into_inner(), ) } async fn list_namespaces(&self) -> Result { - Ok(WorkflowService::list_namespaces(&mut self.clone(), - ListNamespacesRequest::default(), + Ok(WorkflowService::list_namespaces( + &mut self.clone(), + ListNamespacesRequest::default().into_request(), ) .await? .into_inner()) } async fn describe_namespace(&self, namespace: Namespace) -> Result { - Ok(WorkflowService::describe_namespace(&mut self.clone(), - namespace.into_describe_namespace_request(), + Ok(WorkflowService::describe_namespace( + &mut self.clone(), + namespace.into_describe_namespace_request().into_request(), ) .await? .into_inner()) @@ -1716,14 +1699,16 @@ where start_time_filter: Option, filters: Option, ) -> Result { - Ok(WorkflowService::list_open_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_open_workflow_executions( + &mut self.clone(), ListOpenWorkflowExecutionsRequest { namespace: self.namespace(), maximum_page_size, next_page_token, start_time_filter, filters, - }, + } + .into_request(), ) .await? .into_inner()) @@ -1736,14 +1721,16 @@ where start_time_filter: Option, filters: Option, ) -> Result { - Ok(WorkflowService::list_closed_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_closed_workflow_executions( + &mut self.clone(), ListClosedWorkflowExecutionsRequest { namespace: self.namespace(), maximum_page_size, next_page_token, start_time_filter, filters, - }, + } + .into_request(), ) .await? .into_inner()) @@ -1755,13 +1742,15 @@ where next_page_token: Vec, query: String, ) -> Result { - Ok(WorkflowService::list_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_workflow_executions( + &mut self.clone(), ListWorkflowExecutionsRequest { namespace: self.namespace(), page_size, next_page_token, query, - }, + } + .into_request(), ) .await? .into_inner()) @@ -1773,21 +1762,24 @@ where next_page_token: Vec, query: String, ) -> Result { - Ok(WorkflowService::list_archived_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_archived_workflow_executions( + &mut self.clone(), ListArchivedWorkflowExecutionsRequest { namespace: self.namespace(), page_size, next_page_token, query, - }, + } + .into_request(), ) .await? .into_inner()) } async fn get_search_attributes(&self) -> Result { - Ok(WorkflowService::get_search_attributes(&mut self.clone(), - GetSearchAttributesRequest {}, + Ok(WorkflowService::get_search_attributes( + &mut self.clone(), + GetSearchAttributesRequest {}.into_request(), ) .await? .into_inner()) @@ -1801,7 +1793,8 @@ where wait_policy: update::v1::WaitPolicy, args: Option, ) -> Result { - Ok(WorkflowService::update_workflow_execution(&mut self.clone(), + Ok(WorkflowService::update_workflow_execution( + &mut self.clone(), UpdateWorkflowExecutionRequest { namespace: self.namespace(), workflow_execution: Some(WorkflowExecution { @@ -1821,7 +1814,8 @@ where }), }), ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1829,17 +1823,9 @@ where } mod sealed { - use crate::{InterceptedMetricsSvc, RawClientLike, WorkflowClientTrait}; - - pub trait WfHandleClient: - WorkflowClientTrait + RawClientLike - { - } - - impl WfHandleClient for T where - T: WorkflowClientTrait + RawClientLike - { - } + use crate::{WorkflowClientTrait, WorkflowService}; + pub trait WfHandleClient: WorkflowClientTrait + WorkflowService {} + impl WfHandleClient for T where T: WorkflowClientTrait + WorkflowService {} } /// Additional methods for workflow clients diff --git a/client/src/metrics.rs b/client/src/metrics.rs index d2dbbf279..aeceb4a67 100644 --- a/client/src/metrics.rs +++ b/client/src/metrics.rs @@ -208,7 +208,7 @@ fn code_as_screaming_snake(code: &Code) -> &'static str { /// Implements metrics functionality for gRPC (really, any http) calls #[derive(Debug, Clone)] -pub struct GrpcMetricSvc { +pub(crate) struct GrpcMetricSvc { pub(crate) inner: ChannelOrGrpcOverride, // If set to none, metrics are a no-op pub(crate) metrics: Option, @@ -230,6 +230,7 @@ impl fmt::Debug for ChannelOrGrpcOverride { } } +// TODO: Rewrite as a RawGrpcCaller implementation impl Service> for GrpcMetricSvc { type Response = http::Response; type Error = Box; diff --git a/client/src/raw.rs b/client/src/raw.rs index 95f88c4af..2dddd87e2 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -3,14 +3,14 @@ //! happen. use crate::{ - Client, ConfiguredClient, InterceptedMetricsSvc, LONG_POLL_TIMEOUT, RequestExt, RetryClient, - SharedReplaceableClient, TEMPORAL_NAMESPACE_HEADER_KEY, TemporalServiceClient, + Client, ConfiguredClient, LONG_POLL_TIMEOUT, RequestExt, RetryClient, SharedReplaceableClient, + TEMPORAL_NAMESPACE_HEADER_KEY, TemporalServiceClient, metrics::{namespace_kv, task_queue_kv}, - raw::sealed::RawClientLike, worker_registry::{Slot, SlotManager}, }; +use dyn_clone::DynClone; use futures_util::{FutureExt, TryFutureExt, future::BoxFuture}; -use std::sync::Arc; +use std::{any::Any, marker::PhantomData, sync::Arc}; use temporal_sdk_core_api::telemetry::metrics::MetricKeyValue; use temporal_sdk_core_protos::{ grpc::health::v1::{health_client::HealthClient, *}, @@ -29,80 +29,182 @@ use tonic::{ metadata::{AsciiMetadataValue, KeyAndValueRef}, }; -pub(super) mod sealed { - use super::*; +/// Something that has access to the raw grpc services +trait RawClientProducer { + /// Returns information about workers associated with this client. Implementers outside of + /// core can safely return `None`. + fn get_workers_info(&self) -> Option>; - /// Something that has access to the raw grpc services - #[async_trait::async_trait] - pub trait RawClientLike: Send { - type SvcType: Send + Sync + Clone + 'static; + /// Return a workflow service client instance + fn workflow_client(&mut self) -> Box; - /// Return a mutable ref to the workflow service client instance - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient; + /// Return a mutable ref to the operator service client instance + fn operator_client(&mut self) -> Box; - /// Return a mutable ref to the operator service client instance - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient; + /// Return a mutable ref to the cloud service client instance + fn cloud_client(&mut self) -> Box; - /// Return a mutable ref to the cloud service client instance - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient; + /// Return a mutable ref to the test service client instance + fn test_client(&mut self) -> Box; - /// Return a mutable ref to the test service client instance - fn test_client_mut(&mut self) -> &mut TestServiceClient; + /// Return a mutable ref to the health service client instance + fn health_client(&mut self) -> Box; +} - /// Return a mutable ref to the health service client instance - fn health_client_mut(&mut self) -> &mut HealthClient; +/// Any client that can make gRPC calls. The default implementation simply invokes the passed-in +/// function. Implementers may override this to provide things like retry behavior, ex: +/// [RetryClient]. +#[async_trait::async_trait] +trait RawGrpcCaller: Send + Sync + 'static { + async fn call( + &mut self, + _call_name: &'static str, + mut callfn: F, + req: Request, + ) -> Result, Status> + where + Req: Clone + Unpin + Send + Sync + 'static, + Resp: Send + 'static, + F: Send + Sync + Unpin + 'static, + for<'a> F: + FnMut(&'a mut Self, Request) -> BoxFuture<'static, Result, Status>>, + { + callfn(self, req).await + } +} - /// Return a registry with workers using this client instance - fn get_workers_info(&self) -> Option>; +trait ErasedRawClient: Send + Sync + 'static { + fn erased_call( + &mut self, + call_name: &'static str, + op: &mut dyn ErasedCallOp, + ) -> BoxFuture<'static, Result>, Status>>; +} - async fn call( - &mut self, - _call_name: &'static str, - mut callfn: F, - req: Request, - ) -> Result, Status> - where - Req: Clone + Unpin + Send + Sync + 'static, - F: FnMut(&mut Self, Request) -> BoxFuture<'static, Result, Status>>, - F: Send + Sync + Unpin + 'static, - { - callfn(self, req).await +trait ErasedCallOp: Send { + fn invoke( + &mut self, + raw: &mut dyn ErasedRawClient, + call_name: &'static str, + ) -> BoxFuture<'static, Result>, Status>>; +} + +struct CallShim { + callfn: F, + seed_req: Option>, + _resp: PhantomData, +} + +impl CallShim { + fn new(callfn: F, seed_req: Request) -> Self { + Self { + callfn, + seed_req: Some(seed_req), + _resp: PhantomData, } } } +impl ErasedCallOp for CallShim +where + Req: Clone + Unpin + Send + Sync + 'static, + Resp: Send + 'static, + F: Send + Sync + Unpin + 'static, + for<'a> F: FnMut( + &'a mut dyn ErasedRawClient, + Request, + ) -> BoxFuture<'static, Result, Status>>, +{ + fn invoke( + &mut self, + raw: &mut dyn ErasedRawClient, + _call_name: &'static str, + ) -> BoxFuture<'static, Result>, Status>> { + (self.callfn)( + raw, + self.seed_req + .take() + .expect("CallShim must have request populated"), + ) + .map(|res| res.map(|payload| payload.map(|t| Box::new(t) as Box))) + .boxed() + } +} #[async_trait::async_trait] -impl RawClientLike for RetryClient +impl RawGrpcCaller for dyn ErasedRawClient { + async fn call( + &mut self, + call_name: &'static str, + callfn: F, + req: Request, + ) -> Result, Status> + where + Req: Clone + Unpin + Send + Sync + 'static, + Resp: Send + 'static, + F: Send + Sync + Unpin + 'static, + for<'a> F: FnMut( + &'a mut dyn ErasedRawClient, + Request, + ) -> BoxFuture<'static, Result, Status>>, + { + let mut shim = CallShim::new(callfn, req); + let erased_resp = ErasedRawClient::erased_call(self, call_name, &mut shim).await?; + Ok(erased_resp.map(|boxed| { + *boxed + .downcast() + .expect("RawGrpcCaller erased response type mismatch") + })) + } +} + +impl ErasedRawClient for T where - RC: RawClientLike + 'static, - T: Send + Sync + Clone + 'static, + T: RawGrpcCaller + 'static, { - type SvcType = T; + fn erased_call( + &mut self, + call_name: &'static str, + op: &mut dyn ErasedCallOp, + ) -> BoxFuture<'static, Result>, Status>> { + let raw: &mut dyn ErasedRawClient = self; + op.invoke(raw, call_name) + } +} - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.get_client_mut().workflow_client_mut() +impl RawClientProducer for RetryClient +where + RC: RawClientProducer + 'static, +{ + fn get_workers_info(&self) -> Option> { + self.get_client().get_workers_info() } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.get_client_mut().operator_client_mut() + fn workflow_client(&mut self) -> Box { + self.get_client_mut().workflow_client() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.get_client_mut().cloud_client_mut() + fn operator_client(&mut self) -> Box { + self.get_client_mut().operator_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.get_client_mut().test_client_mut() + fn cloud_client(&mut self) -> Box { + self.get_client_mut().cloud_client() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.get_client_mut().health_client_mut() + fn test_client(&mut self) -> Box { + self.get_client_mut().test_client() } - fn get_workers_info(&self) -> Option> { - self.get_client().get_workers_info() + fn health_client(&mut self) -> Box { + self.get_client_mut().health_client() } +} +#[async_trait::async_trait] +impl RawGrpcCaller for RetryClient +where + RC: RawGrpcCaller + 'static, +{ async fn call( &mut self, call_name: &'static str, @@ -128,155 +230,142 @@ where } } -#[async_trait::async_trait] -impl RawClientLike for SharedReplaceableClient +/// Helper for cloning a tonic request as long as the inner message may be cloned. +fn req_cloner(cloneme: &Request) -> Request { + let msg = cloneme.get_ref().clone(); + let mut new_req = Request::new(msg); + let new_met = new_req.metadata_mut(); + for kv in cloneme.metadata().iter() { + match kv { + KeyAndValueRef::Ascii(k, v) => { + new_met.insert(k, v.clone()); + } + KeyAndValueRef::Binary(k, v) => { + new_met.insert_bin(k, v.clone()); + } + } + } + *new_req.extensions_mut() = cloneme.extensions().clone(); + new_req +} + +impl RawClientProducer for SharedReplaceableClient where - RC: RawClientLike + Clone + Sync + 'static, - T: Send + Sync + Clone + 'static, + RC: RawClientProducer + Clone + Send + Sync + 'static, { - type SvcType = T; - - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.inner_mut_refreshed().workflow_client_mut() + fn get_workers_info(&self) -> Option> { + self.inner_cow().get_workers_info() } - - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.inner_mut_refreshed().operator_client_mut() + fn workflow_client(&mut self) -> Box { + self.inner_mut_refreshed().workflow_client() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.inner_mut_refreshed().cloud_client_mut() + fn operator_client(&mut self) -> Box { + self.inner_mut_refreshed().operator_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.inner_mut_refreshed().test_client_mut() + fn cloud_client(&mut self) -> Box { + self.inner_mut_refreshed().cloud_client() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.inner_mut_refreshed().health_client_mut() + fn test_client(&mut self) -> Box { + self.inner_mut_refreshed().test_client() } - fn get_workers_info(&self) -> Option> { - self.inner_cow().get_workers_info() + fn health_client(&mut self) -> Box { + self.inner_mut_refreshed().health_client() } } -impl RawClientLike for TemporalServiceClient -where - T: Send + Sync + Clone + 'static, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - ::Error: Into + Send, +#[async_trait::async_trait] +impl RawGrpcCaller for SharedReplaceableClient where + RC: RawGrpcCaller + Clone + Sync + 'static { - type SvcType = T; +} - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.workflow_svc_mut() +impl RawClientProducer for TemporalServiceClient { + fn get_workers_info(&self) -> Option> { + None } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.operator_svc_mut() + fn workflow_client(&mut self) -> Box { + self.workflow_svc() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.cloud_svc_mut() + fn operator_client(&mut self) -> Box { + self.operator_svc() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.test_svc_mut() + fn cloud_client(&mut self) -> Box { + self.cloud_svc() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.health_svc_mut() + fn test_client(&mut self) -> Box { + self.test_svc() } - fn get_workers_info(&self) -> Option> { - None + fn health_client(&mut self) -> Box { + self.health_svc() } } -impl RawClientLike for ConfiguredClient> -where - T: Send + Sync + Clone + 'static, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - ::Error: Into + Send, -{ - type SvcType = T; +impl RawGrpcCaller for TemporalServiceClient {} - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.client.workflow_client_mut() +impl RawClientProducer for ConfiguredClient { + fn get_workers_info(&self) -> Option> { + Some(self.workers()) } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.client.operator_client_mut() + fn workflow_client(&mut self) -> Box { + self.client.workflow_client() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.client.cloud_client_mut() + fn operator_client(&mut self) -> Box { + self.client.operator_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.client.test_client_mut() + fn cloud_client(&mut self) -> Box { + self.client.cloud_client() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.client.health_client_mut() + fn test_client(&mut self) -> Box { + self.client.test_client() } - fn get_workers_info(&self) -> Option> { - Some(self.workers()) + fn health_client(&mut self) -> Box { + self.client.health_client() } } -impl RawClientLike for Client { - type SvcType = InterceptedMetricsSvc; +impl RawGrpcCaller for ConfiguredClient {} - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.inner.workflow_client_mut() +impl RawClientProducer for Client { + fn get_workers_info(&self) -> Option> { + self.inner.get_workers_info() } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.inner.operator_client_mut() + fn workflow_client(&mut self) -> Box { + self.inner.workflow_client() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.inner.cloud_client_mut() + fn operator_client(&mut self) -> Box { + self.inner.operator_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.inner.test_client_mut() + fn cloud_client(&mut self) -> Box { + self.inner.cloud_client() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.inner.health_client_mut() + fn test_client(&mut self) -> Box { + self.inner.test_client() } - fn get_workers_info(&self) -> Option> { - self.inner.get_workers_info() + fn health_client(&mut self) -> Box { + self.inner.health_client() } } -/// Helper for cloning a tonic request as long as the inner message may be cloned. -fn req_cloner(cloneme: &Request) -> Request { - let msg = cloneme.get_ref().clone(); - let mut new_req = Request::new(msg); - let new_met = new_req.metadata_mut(); - for kv in cloneme.metadata().iter() { - match kv { - KeyAndValueRef::Ascii(k, v) => { - new_met.insert(k, v.clone()); - } - KeyAndValueRef::Binary(k, v) => { - new_met.insert_bin(k, v.clone()); - } - } - } - *new_req.extensions_mut() = cloneme.extensions().clone(); - new_req -} +impl RawGrpcCaller for Client {} #[derive(Clone, Debug)] pub(super) struct AttachMetricLabels { @@ -306,62 +395,27 @@ impl AttachMetricLabels { #[derive(Copy, Clone, Debug)] pub(super) struct IsUserLongPoll; -// Blanket impl the trait for all raw-client-like things. Since the trait default-implements -// everything, there's nothing to actually implement. -impl WorkflowService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl OperatorService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl CloudService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl TestService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl HealthService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ +macro_rules! proxy_def { + ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty, defaults) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + _request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + async { Ok(tonic::Response::new(<$resp>::default())) }.boxed() + } + }; + ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + _request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>>; + }; } - /// Helps re-declare gRPC client methods /// -/// There are two forms: +/// There are four forms: /// /// * The first takes a closure that can modify the request. This is only called once, before the /// actual rpc call is made, and before determinations are made about the kind of call (long poll @@ -369,22 +423,23 @@ where /// * The second takes three closures. The first can modify the request like in the first form. /// The second can modify the request and return a value, and is called right before every call /// (including on retries). The third is called with the response to the call after it resolves. -macro_rules! proxy { +/// * The third and fourth are equivalents of the above that skip calling through the `call` method +/// and are implemented directly on the generated gRPC clients (IE: the bottom of the stack). +macro_rules! proxy_impl { ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty $(, $closure:expr)?) => { #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] fn $method( &mut self, - request: impl tonic::IntoRequest<$req>, - ) -> BoxFuture<'_, Result, tonic::Status>> { #[allow(unused_mut)] - let mut as_req = request.into_request(); - $( type_closure_arg(&mut as_req, $closure); )* + mut request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + $( type_closure_arg(&mut request, $closure); )* #[allow(unused_mut)] let fact = |c: &mut Self, mut req: tonic::Request<$req>| { - let mut c = c.$client_meth().clone(); + let mut c = c.$client_meth(); async move { c.$method(req).await }.boxed() }; - self.call(stringify!($method), fact, as_req) + self.call(stringify!($method), fact, request) } }; ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty, @@ -392,52 +447,108 @@ macro_rules! proxy { #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] fn $method( &mut self, - request: impl tonic::IntoRequest<$req>, + mut request: tonic::Request<$req>, ) -> BoxFuture<'_, Result, tonic::Status>> { - #[allow(unused_mut)] - let mut as_req = request.into_request(); - type_closure_arg(&mut as_req, $closure_request); + type_closure_arg(&mut request, $closure_request); #[allow(unused_mut)] let fact = |c: &mut Self, mut req: tonic::Request<$req>| { - let data = type_closure_two_arg(&mut req, c.get_workers_info().unwrap(), - $closure_before); - let mut c = c.$client_meth().clone(); + let data = type_closure_two_arg(&mut req, c.get_workers_info(), $closure_before); + let mut c = c.$client_meth(); async move { type_closure_two_arg(c.$method(req).await, data, $closure_after) }.boxed() }; - self.call(stringify!($method), fact, as_req) + self.call(stringify!($method), fact, request) + } + }; + ($client_type:tt, $method:ident, $req:ty, $resp:ty $(, $closure:expr)?) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + #[allow(unused_mut)] + mut request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + $( type_closure_arg(&mut request, $closure); )* + async move { <$client_type<_>>::$method(self, request).await }.boxed() + } + }; + ($client_type:tt, $method:ident, $req:ty, $resp:ty, + $closure_request:expr, $closure_before:expr, $closure_after:expr) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + mut request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + type_closure_arg(&mut request, $closure_request); + let data = type_closure_two_arg(&mut request, Option::>::None, + $closure_before); + async move { + type_closure_two_arg(<$client_type<_>>::$method(self, request).await, + data, $closure_after) + }.boxed() } }; } -macro_rules! proxier { - ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; - $(($method:ident, $req:ty, $resp:ty - $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { +macro_rules! proxier_impl { + ($trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; + [$( proxy_def!($($def_args:tt)*); )*]; + $(($method:ident, $req:ty, $resp:ty + $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { #[cfg(test)] const $impl_list_name: &'static [&'static str] = &[$(stringify!($method)),*]; - /// Trait version of the generated client with modifications to attach appropriate metric - /// labels or whatever else to requests - pub trait $trait_name: RawClientLike + + #[doc = concat!("Trait version of [", stringify!($client_type), "]")] + pub trait $trait_name: Send + Sync + DynClone + { + $( proxy_def!($($def_args)*); )* + } + dyn_clone::clone_trait_object!($trait_name); + + impl $trait_name for RC + where + RC: RawGrpcCaller + RawClientProducer + Clone, + { + $( + proxy_impl!($client_type, $client_meth, $method, $req, $resp + $(,$closure $(,$closure_before, $closure_after)*)*); + )* + } + + impl RawGrpcCaller for $client_type {} + + impl $trait_name for $client_type where - // Yo this is wild - ::SvcType: GrpcService + Send + Clone + 'static, - <::SvcType as GrpcService>::ResponseBody: - tonic::codegen::Body + Send + 'static, - <::SvcType as GrpcService>::Error: - Into, - <::SvcType as GrpcService>::Future: Send, - <<::SvcType as GrpcService>::ResponseBody - as tonic::codegen::Body>::Error: Into + Send, + T: GrpcService + Clone + Send + Sync + 'static, + T::ResponseBody: tonic::codegen::Body + Send + 'static, + T::Error: Into, + ::Error: Into + Send, + >::Future: Send { $( - proxy!($client_type, $client_meth, $method, $req, $resp - $(,$closure $(,$closure_before, $closure_after)*)*); + proxy_impl!($client_type, $method, $req, $resp + $(,$closure $(,$closure_before, $closure_after)*)*); )* } }; } +macro_rules! proxier { + ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; + $(($method:ident, $req:ty, $resp:ty + $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { + proxier_impl!($trait_name; $impl_list_name; $client_type; $client_meth; + [$(proxy_def!($client_type, $client_meth, $method, $req, $resp);)*]; + $(($method, $req, $resp $(, $closure $(, $closure_before, $closure_after)?)?);)*); + }; + ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; defaults; + $(($method:ident, $req:ty, $resp:ty + $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { + proxier_impl!($trait_name; $impl_list_name; $client_type; $client_meth; + [$(proxy_def!($client_type, $client_meth, $method, $req, $resp, defaults);)*]; + $(($method, $req, $resp $(, $closure $(, $closure_before, $closure_after)?)?);)*); + }; +} + macro_rules! namespaced_request { ($req:ident) => {{ let ns_str = $req.get_ref().namespace.clone(); @@ -464,7 +575,7 @@ fn type_closure_two_arg(arg1: R, arg2: T, f: impl FnOnce(R, T) -> S) -> } proxier! { - WorkflowService; ALL_IMPLEMENTED_WORKFLOW_SERVICE_RPCS; WorkflowServiceClient; workflow_client_mut; + WorkflowService; ALL_IMPLEMENTED_WORKFLOW_SERVICE_RPCS; WorkflowServiceClient; workflow_client; defaults; ( register_namespace, RegisterNamespaceRequest, @@ -516,21 +627,25 @@ proxier! { r.extensions_mut().insert(labels); }, |r, workers| { - let mut slot: Option> = None; - let req_mut = r.get_mut(); - if req_mut.request_eager_execution { - let namespace = req_mut.namespace.clone(); - let task_queue = req_mut.task_queue.as_ref() - .map(|tq| tq.name.clone()).unwrap_or_default(); - match workers.try_reserve_wft_slot(namespace, task_queue) { - Some(s) => slot = Some(s), - None => req_mut.request_eager_execution = false + if let Some(workers) = workers { + let mut slot: Option> = None; + let req_mut = r.get_mut(); + if req_mut.request_eager_execution { + let namespace = req_mut.namespace.clone(); + let task_queue = req_mut.task_queue.as_ref() + .map(|tq| tq.name.clone()).unwrap_or_default(); + match workers.try_reserve_wft_slot(namespace, task_queue) { + Some(s) => slot = Some(s), + None => req_mut.request_eager_execution = false + } } + slot + } else { + None } - slot }, |resp, slot| { - if let Some(mut s) = slot + if let Some(s) = slot && let Ok(response) = resp.as_ref() && let Some(task) = response.get_ref().clone().eager_workflow_task && let Err(e) = s.schedule_wft(task) { @@ -1323,7 +1438,7 @@ proxier! { } proxier! { - OperatorService; ALL_IMPLEMENTED_OPERATOR_SERVICE_RPCS; OperatorServiceClient; operator_client_mut; + OperatorService; ALL_IMPLEMENTED_OPERATOR_SERVICE_RPCS; OperatorServiceClient; operator_client; defaults; (add_search_attributes, AddSearchAttributesRequest, AddSearchAttributesResponse); (remove_search_attributes, RemoveSearchAttributesRequest, RemoveSearchAttributesResponse); (list_search_attributes, ListSearchAttributesRequest, ListSearchAttributesResponse); @@ -1344,7 +1459,7 @@ proxier! { } proxier! { - CloudService; ALL_IMPLEMENTED_CLOUD_SERVICE_RPCS; CloudServiceClient; cloud_client_mut; + CloudService; ALL_IMPLEMENTED_CLOUD_SERVICE_RPCS; CloudServiceClient; cloud_client; defaults; (get_users, cloudreq::GetUsersRequest, cloudreq::GetUsersResponse); (get_user, cloudreq::GetUserRequest, cloudreq::GetUserResponse); (create_user, cloudreq::CreateUserRequest, cloudreq::CreateUserResponse); @@ -1419,7 +1534,7 @@ proxier! { } proxier! { - TestService; ALL_IMPLEMENTED_TEST_SERVICE_RPCS; TestServiceClient; test_client_mut; + TestService; ALL_IMPLEMENTED_TEST_SERVICE_RPCS; TestServiceClient; test_client; defaults; (lock_time_skipping, LockTimeSkippingRequest, LockTimeSkippingResponse); (unlock_time_skipping, UnlockTimeSkippingRequest, UnlockTimeSkippingResponse); (sleep, SleepRequest, SleepResponse); @@ -1429,7 +1544,7 @@ proxier! { } proxier! { - HealthService; ALL_IMPLEMENTED_HEALTH_SERVICE_RPCS; HealthClient; health_client_mut; + HealthService; ALL_IMPLEMENTED_HEALTH_SERVICE_RPCS; HealthClient; health_client; (check, HealthCheckRequest, HealthCheckResponse); (watch, HealthCheckRequest, tonic::codec::Streaming); } @@ -1442,6 +1557,7 @@ mod tests { use temporal_sdk_core_protos::temporal::api::{ operatorservice::v1::DeleteNamespaceRequest, workflowservice::v1::ListNamespacesRequest, }; + use tonic::IntoRequest; // Just to help make sure some stuff compiles. Not run. #[allow(dead_code)] @@ -1452,7 +1568,7 @@ mod tests { let list_ns_req = ListNamespacesRequest::default(); let fact = |c: &mut RetryClient<_>, req| { - let mut c = c.workflow_client_mut().clone(); + let mut c = c.workflow_client(); async move { c.list_namespaces(req).await }.boxed() }; retry_client @@ -1463,7 +1579,7 @@ mod tests { // Operator svc method let op_del_ns_req = DeleteNamespaceRequest::default(); let fact = |c: &mut RetryClient<_>, req| { - let mut c = c.operator_client_mut().clone(); + let mut c = c.operator_client(); async move { c.delete_namespace(req).await }.boxed() }; retry_client @@ -1474,7 +1590,7 @@ mod tests { // Cloud svc method let cloud_del_ns_req = cloudreq::DeleteNamespaceRequest::default(); let fact = |c: &mut RetryClient<_>, req| { - let mut c = c.cloud_client_mut().clone(); + let mut c = c.cloud_client(); async move { c.delete_namespace(req).await }.boxed() }; retry_client @@ -1483,17 +1599,23 @@ mod tests { .unwrap(); // Verify calling through traits works - retry_client.list_namespaces(list_ns_req).await.unwrap(); + retry_client + .list_namespaces(list_ns_req.into_request()) + .await + .unwrap(); // Have to disambiguate operator and cloud service - OperatorService::delete_namespace(&mut retry_client, op_del_ns_req) + OperatorService::delete_namespace(&mut retry_client, op_del_ns_req.into_request()) .await .unwrap(); - CloudService::delete_namespace(&mut retry_client, cloud_del_ns_req) + CloudService::delete_namespace(&mut retry_client, cloud_del_ns_req.into_request()) .await .unwrap(); - retry_client.get_current_time(()).await.unwrap(); retry_client - .check(HealthCheckRequest::default()) + .get_current_time(().into_request()) + .await + .unwrap(); + retry_client + .check(HealthCheckRequest::default().into_request()) .await .unwrap(); } @@ -1559,4 +1681,65 @@ mod tests { let proto_def = include_str!("../../sdk-core-protos/protos/grpc/health/v1/health.proto"); verify_methods(proto_def, ALL_IMPLEMENTED_HEALTH_SERVICE_RPCS); } + + #[tokio::test] + async fn can_mock_workflow_service() { + #[derive(Clone)] + struct MyFakeServices {} + impl RawGrpcCaller for MyFakeServices {} + impl WorkflowService for MyFakeServices { + fn list_namespaces( + &mut self, + _request: Request, + ) -> BoxFuture<'_, Result, Status>> { + async { + Ok(Response::new(ListNamespacesResponse { + namespaces: vec![DescribeNamespaceResponse { + failover_version: 12345, + ..Default::default() + }], + ..Default::default() + })) + } + .boxed() + } + } + impl OperatorService for MyFakeServices {} + impl CloudService for MyFakeServices {} + impl TestService for MyFakeServices {} + // Health service isn't possible to create a default impl for. + impl HealthService for MyFakeServices { + fn check( + &mut self, + _request: tonic::Request, + ) -> BoxFuture<'_, Result, tonic::Status>> + { + todo!() + } + fn watch( + &mut self, + _request: tonic::Request, + ) -> BoxFuture< + '_, + Result< + tonic::Response>, + tonic::Status, + >, + > { + todo!() + } + } + let mut mocked_client = TemporalServiceClient::from_services( + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + ); + let r = mocked_client + .list_namespaces(ListNamespacesRequest::default().into_request()) + .await + .unwrap(); + assert_eq!(r.into_inner().namespaces[0].failover_version, 12345); + } } diff --git a/client/src/workflow_handle/mod.rs b/client/src/workflow_handle/mod.rs index af6d7ab16..6c112a888 100644 --- a/client/src/workflow_handle/mod.rs +++ b/client/src/workflow_handle/mod.rs @@ -1,4 +1,4 @@ -use crate::{InterceptedMetricsSvc, RawClientLike, WorkflowService}; +use crate::WorkflowService; use anyhow::{anyhow, bail}; use std::{fmt::Debug, marker::PhantomData}; use temporal_sdk_core_protos::{ @@ -11,6 +11,7 @@ use temporal_sdk_core_protos::{ workflowservice::v1::GetWorkflowExecutionHistoryRequest, }, }; +use tonic::IntoRequest; /// Enumerates terminal states for a particular workflow execution // TODO: Add non-proto failure types, flesh out details, etc. @@ -81,7 +82,7 @@ impl WorkflowExecutionInfo { /// Bind the workflow info to a specific client, turning it into a workflow handle pub fn bind_untyped(self, client: CT) -> UntypedWorkflowHandle where - CT: RawClientLike + Clone, + CT: WorkflowService + Clone, { UntypedWorkflowHandle::new(client, self) } @@ -92,7 +93,7 @@ pub(crate) type UntypedWorkflowHandle = WorkflowHandle>; impl WorkflowHandle where - CT: RawClientLike + Clone, + CT: WorkflowService + Clone, // TODO: Make more generic, capable of (de)serialization w/ serde RT: FromPayloadsExt, { @@ -125,18 +126,21 @@ where let server_res = self .client .clone() - .get_workflow_execution_history(GetWorkflowExecutionHistoryRequest { - namespace: self.info.namespace.to_string(), - execution: Some(WorkflowExecution { - workflow_id: self.info.workflow_id.clone(), - run_id: run_id.clone(), - }), - skip_archival: true, - wait_new_event: true, - history_event_filter_type: HistoryEventFilterType::CloseEvent as i32, - next_page_token: next_page_tok.clone(), - ..Default::default() - }) + .get_workflow_execution_history( + GetWorkflowExecutionHistoryRequest { + namespace: self.info.namespace.to_string(), + execution: Some(WorkflowExecution { + workflow_id: self.info.workflow_id.clone(), + run_id: run_id.clone(), + }), + skip_archival: true, + wait_new_event: true, + history_event_filter_type: HistoryEventFilterType::CloseEvent as i32, + next_page_token: next_page_tok.clone(), + ..Default::default() + } + .into_request(), + ) .await? .into_inner(); diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index ccdd660fd..e5b01f09c 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -16,8 +16,8 @@ use std::{ use temporal_client::{ ClientKeepAliveConfig, ClientOptions as CoreClientOptions, ClientOptionsBuilder, ClientTlsConfig, CloudService, ConfiguredClient, HealthService, HttpConnectProxyOptions, - OperatorService, RetryClient, RetryConfig, TemporalServiceClientWithMetrics, TestService, - TlsConfig, WorkflowService, callback_based, + OperatorService, RetryClient, RetryConfig, TemporalServiceClient, TestService, TlsConfig, + WorkflowService, callback_based, }; use tokio::sync::oneshot; use tonic::metadata::MetadataKey; @@ -79,7 +79,7 @@ pub struct ClientHttpConnectProxyOptions { pub password: ByteArrayRef, } -type CoreClient = RetryClient>; +type CoreClient = RetryClient>; pub struct Client { pub(crate) runtime: Runtime, @@ -528,16 +528,6 @@ pub extern "C" fn temporal_core_client_rpc_call( }); } -macro_rules! rpc_call { - ($client:ident, $call:ident, $call_name:ident) => { - if $call.retry { - rpc_resp($client.$call_name(rpc_req($call)?).await) - } else { - rpc_resp($client.into_inner().$call_name(rpc_req($call)?).await) - } - }; -} - macro_rules! rpc_call_on_trait { ($client:ident, $call:ident, $trait:tt, $call_name:ident) => { if $call.retry { @@ -555,115 +545,307 @@ async fn call_workflow_service( let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "CountWorkflowExecutions" => rpc_call!(client, call, count_workflow_executions), - "CreateSchedule" => rpc_call!(client, call, create_schedule), - "CreateWorkflowRule" => rpc_call!(client, call, create_workflow_rule), - "DeleteSchedule" => rpc_call!(client, call, delete_schedule), - "DeleteWorkerDeployment" => rpc_call!(client, call, delete_worker_deployment), + "CountWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, count_workflow_executions) + } + "CreateSchedule" => rpc_call_on_trait!(client, call, WorkflowService, create_schedule), + "CreateWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, create_workflow_rule) + } + "DeleteSchedule" => rpc_call_on_trait!(client, call, WorkflowService, delete_schedule), + "DeleteWorkerDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, delete_worker_deployment) + } "DeleteWorkerDeploymentVersion" => { - rpc_call!(client, call, delete_worker_deployment_version) - } - "DeleteWorkflowExecution" => rpc_call!(client, call, delete_workflow_execution), - "DeleteWorkflowRule" => rpc_call!(client, call, delete_workflow_rule), - "DeprecateNamespace" => rpc_call!(client, call, deprecate_namespace), - "DescribeBatchOperation" => rpc_call!(client, call, describe_batch_operation), - "DescribeDeployment" => rpc_call!(client, call, describe_deployment), - "DescribeNamespace" => rpc_call!(client, call, describe_namespace), - "DescribeSchedule" => rpc_call!(client, call, describe_schedule), - "DescribeTaskQueue" => rpc_call!(client, call, describe_task_queue), - "DescribeWorkerDeployment" => rpc_call!(client, call, describe_worker_deployment), + rpc_call_on_trait!( + client, + call, + WorkflowService, + delete_worker_deployment_version + ) + } + "DeleteWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, delete_workflow_execution) + } + "DeleteWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, delete_workflow_rule) + } + "DeprecateNamespace" => { + rpc_call_on_trait!(client, call, WorkflowService, deprecate_namespace) + } + "DescribeBatchOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_batch_operation) + } + "DescribeDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_deployment) + } + "DescribeNamespace" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_namespace) + } + "DescribeSchedule" => rpc_call_on_trait!(client, call, WorkflowService, describe_schedule), + "DescribeTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_task_queue) + } + "DescribeWorkerDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_worker_deployment) + } "DescribeWorkerDeploymentVersion" => { - rpc_call!(client, call, describe_worker_deployment_version) - } - "DescribeWorkflowExecution" => rpc_call!(client, call, describe_workflow_execution), - "DescribeWorkflowRule" => rpc_call!(client, call, describe_workflow_rule), - "ExecuteMultiOperation" => rpc_call!(client, call, execute_multi_operation), - "FetchWorkerConfig" => rpc_call!(client, call, fetch_worker_config), - "GetClusterInfo" => rpc_call!(client, call, get_cluster_info), - "GetCurrentDeployment" => rpc_call!(client, call, get_current_deployment), - "GetDeploymentReachability" => rpc_call!(client, call, get_deployment_reachability), - "GetSearchAttributes" => rpc_call!(client, call, get_search_attributes), - "GetSystemInfo" => rpc_call!(client, call, get_system_info), + rpc_call_on_trait!( + client, + call, + WorkflowService, + describe_worker_deployment_version + ) + } + "DescribeWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_workflow_execution) + } + "DescribeWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_workflow_rule) + } + "ExecuteMultiOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, execute_multi_operation) + } + "FetchWorkerConfig" => { + rpc_call_on_trait!(client, call, WorkflowService, fetch_worker_config) + } + "GetClusterInfo" => rpc_call_on_trait!(client, call, WorkflowService, get_cluster_info), + "GetCurrentDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, get_current_deployment) + } + "GetDeploymentReachability" => { + rpc_call_on_trait!(client, call, WorkflowService, get_deployment_reachability) + } + "GetSearchAttributes" => { + rpc_call_on_trait!(client, call, WorkflowService, get_search_attributes) + } + "GetSystemInfo" => rpc_call_on_trait!(client, call, WorkflowService, get_system_info), "GetWorkerBuildIdCompatibility" => { - rpc_call!(client, call, get_worker_build_id_compatibility) + rpc_call_on_trait!( + client, + call, + WorkflowService, + get_worker_build_id_compatibility + ) } "GetWorkerTaskReachability" => { - rpc_call!(client, call, get_worker_task_reachability) + rpc_call_on_trait!(client, call, WorkflowService, get_worker_task_reachability) } - "GetWorkerVersioningRules" => rpc_call!(client, call, get_worker_versioning_rules), - "GetWorkflowExecutionHistory" => rpc_call!(client, call, get_workflow_execution_history), + "GetWorkerVersioningRules" => { + rpc_call_on_trait!(client, call, WorkflowService, get_worker_versioning_rules) + } + "GetWorkflowExecutionHistory" => rpc_call_on_trait!( + client, + call, + WorkflowService, + get_workflow_execution_history + ), "GetWorkflowExecutionHistoryReverse" => { - rpc_call!(client, call, get_workflow_execution_history_reverse) + rpc_call_on_trait!( + client, + call, + WorkflowService, + get_workflow_execution_history_reverse + ) } "ListArchivedWorkflowExecutions" => { - rpc_call!(client, call, list_archived_workflow_executions) - } - "ListBatchOperations" => rpc_call!(client, call, list_batch_operations), - "ListClosedWorkflowExecutions" => rpc_call!(client, call, list_closed_workflow_executions), - "ListDeployments" => rpc_call!(client, call, list_deployments), - "ListNamespaces" => rpc_call!(client, call, list_namespaces), - "ListOpenWorkflowExecutions" => rpc_call!(client, call, list_open_workflow_executions), - "ListScheduleMatchingTimes" => rpc_call!(client, call, list_schedule_matching_times), - "ListSchedules" => rpc_call!(client, call, list_schedules), - "ListTaskQueuePartitions" => rpc_call!(client, call, list_task_queue_partitions), - "ListWorkerDeployments" => rpc_call!(client, call, list_worker_deployments), - "ListWorkers" => rpc_call!(client, call, list_workers), - "ListWorkflowExecutions" => rpc_call!(client, call, list_workflow_executions), - "ListWorkflowRules" => rpc_call!(client, call, list_workflow_rules), - "PatchSchedule" => rpc_call!(client, call, patch_schedule), - "PauseActivity" => rpc_call!(client, call, pause_activity), - "PollActivityTaskQueue" => rpc_call!(client, call, poll_activity_task_queue), - "PollNexusTaskQueue" => rpc_call!(client, call, poll_nexus_task_queue), - "PollWorkflowExecutionUpdate" => rpc_call!(client, call, poll_workflow_execution_update), - "PollWorkflowTaskQueue" => rpc_call!(client, call, poll_workflow_task_queue), - "QueryWorkflow" => rpc_call!(client, call, query_workflow), - "RecordActivityTaskHeartbeat" => rpc_call!(client, call, record_activity_task_heartbeat), + rpc_call_on_trait!( + client, + call, + WorkflowService, + list_archived_workflow_executions + ) + } + "ListBatchOperations" => { + rpc_call_on_trait!(client, call, WorkflowService, list_batch_operations) + } + "ListClosedWorkflowExecutions" => rpc_call_on_trait!( + client, + call, + WorkflowService, + list_closed_workflow_executions + ), + "ListDeployments" => rpc_call_on_trait!(client, call, WorkflowService, list_deployments), + "ListNamespaces" => rpc_call_on_trait!(client, call, WorkflowService, list_namespaces), + "ListOpenWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, list_open_workflow_executions) + } + "ListScheduleMatchingTimes" => { + rpc_call_on_trait!(client, call, WorkflowService, list_schedule_matching_times) + } + "ListSchedules" => rpc_call_on_trait!(client, call, WorkflowService, list_schedules), + "ListTaskQueuePartitions" => { + rpc_call_on_trait!(client, call, WorkflowService, list_task_queue_partitions) + } + "ListWorkerDeployments" => { + rpc_call_on_trait!(client, call, WorkflowService, list_worker_deployments) + } + "ListWorkers" => rpc_call_on_trait!(client, call, WorkflowService, list_workers), + "ListWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, list_workflow_executions) + } + "ListWorkflowRules" => { + rpc_call_on_trait!(client, call, WorkflowService, list_workflow_rules) + } + "PatchSchedule" => rpc_call_on_trait!(client, call, WorkflowService, patch_schedule), + "PauseActivity" => rpc_call_on_trait!(client, call, WorkflowService, pause_activity), + "PollActivityTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, poll_activity_task_queue) + } + "PollNexusTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, poll_nexus_task_queue) + } + "PollWorkflowExecutionUpdate" => rpc_call_on_trait!( + client, + call, + WorkflowService, + poll_workflow_execution_update + ), + "PollWorkflowTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, poll_workflow_task_queue) + } + "QueryWorkflow" => rpc_call_on_trait!(client, call, WorkflowService, query_workflow), + "RecordActivityTaskHeartbeat" => rpc_call_on_trait!( + client, + call, + WorkflowService, + record_activity_task_heartbeat + ), "RecordActivityTaskHeartbeatById" => { - rpc_call!(client, call, record_activity_task_heartbeat_by_id) + rpc_call_on_trait!( + client, + call, + WorkflowService, + record_activity_task_heartbeat_by_id + ) + } + "RecordWorkerHeartbeat" => { + rpc_call_on_trait!(client, call, WorkflowService, record_worker_heartbeat) + } + "RegisterNamespace" => { + rpc_call_on_trait!(client, call, WorkflowService, register_namespace) } - "RecordWorkerHeartbeat" => rpc_call!(client, call, record_worker_heartbeat), - "RegisterNamespace" => rpc_call!(client, call, register_namespace), "RequestCancelWorkflowExecution" => { - rpc_call!(client, call, request_cancel_workflow_execution) + rpc_call_on_trait!( + client, + call, + WorkflowService, + request_cancel_workflow_execution + ) + } + "ResetActivity" => rpc_call_on_trait!(client, call, WorkflowService, reset_activity), + "ResetStickyTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, reset_sticky_task_queue) } - "ResetActivity" => rpc_call!(client, call, reset_activity), - "ResetStickyTaskQueue" => rpc_call!(client, call, reset_sticky_task_queue), - "ResetWorkflowExecution" => rpc_call!(client, call, reset_workflow_execution), - "RespondActivityTaskCanceled" => rpc_call!(client, call, respond_activity_task_canceled), + "ResetWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, reset_workflow_execution) + } + "RespondActivityTaskCanceled" => rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_canceled + ), "RespondActivityTaskCanceledById" => { - rpc_call!(client, call, respond_activity_task_canceled_by_id) + rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_canceled_by_id + ) } - "RespondActivityTaskCompleted" => rpc_call!(client, call, respond_activity_task_completed), + "RespondActivityTaskCompleted" => rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_completed + ), "RespondActivityTaskCompletedById" => { - rpc_call!(client, call, respond_activity_task_completed_by_id) + rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_completed_by_id + ) + } + "RespondActivityTaskFailed" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_activity_task_failed) } - "RespondActivityTaskFailed" => rpc_call!(client, call, respond_activity_task_failed), "RespondActivityTaskFailedById" => { - rpc_call!(client, call, respond_activity_task_failed_by_id) - } - "RespondNexusTaskCompleted" => rpc_call!(client, call, respond_nexus_task_completed), - "RespondNexusTaskFailed" => rpc_call!(client, call, respond_nexus_task_failed), - "RespondQueryTaskCompleted" => rpc_call!(client, call, respond_query_task_completed), - "RespondWorkflowTaskCompleted" => rpc_call!(client, call, respond_workflow_task_completed), - "RespondWorkflowTaskFailed" => rpc_call!(client, call, respond_workflow_task_failed), - "ScanWorkflowExecutions" => rpc_call!(client, call, scan_workflow_executions), - "SetCurrentDeployment" => rpc_call!(client, call, set_current_deployment), + rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_failed_by_id + ) + } + "RespondNexusTaskCompleted" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_nexus_task_completed) + } + "RespondNexusTaskFailed" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_nexus_task_failed) + } + "RespondQueryTaskCompleted" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_query_task_completed) + } + "RespondWorkflowTaskCompleted" => rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_workflow_task_completed + ), + "RespondWorkflowTaskFailed" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_workflow_task_failed) + } + "ScanWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, scan_workflow_executions) + } + "SetCurrentDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, set_current_deployment) + } "SetWorkerDeploymentCurrentVersion" => { - rpc_call!(client, call, set_worker_deployment_current_version) + rpc_call_on_trait!( + client, + call, + WorkflowService, + set_worker_deployment_current_version + ) } "SetWorkerDeploymentRampingVersion" => { - rpc_call!(client, call, set_worker_deployment_ramping_version) + rpc_call_on_trait!( + client, + call, + WorkflowService, + set_worker_deployment_ramping_version + ) } - "ShutdownWorker" => rpc_call!(client, call, shutdown_worker), + "ShutdownWorker" => rpc_call_on_trait!(client, call, WorkflowService, shutdown_worker), "SignalWithStartWorkflowExecution" => { - rpc_call!(client, call, signal_with_start_workflow_execution) - } - "SignalWorkflowExecution" => rpc_call!(client, call, signal_workflow_execution), - "StartWorkflowExecution" => rpc_call!(client, call, start_workflow_execution), - "StartBatchOperation" => rpc_call!(client, call, start_batch_operation), - "StopBatchOperation" => rpc_call!(client, call, stop_batch_operation), - "TerminateWorkflowExecution" => rpc_call!(client, call, terminate_workflow_execution), - "TriggerWorkflowRule" => rpc_call!(client, call, trigger_workflow_rule), + rpc_call_on_trait!( + client, + call, + WorkflowService, + signal_with_start_workflow_execution + ) + } + "SignalWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, signal_workflow_execution) + } + "StartWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, start_workflow_execution) + } + "StartBatchOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, start_batch_operation) + } + "StopBatchOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, stop_batch_operation) + } + "TerminateWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, terminate_workflow_execution) + } + "TriggerWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, trigger_workflow_rule) + } "UnpauseActivity" => { rpc_call_on_trait!(client, call, WorkflowService, unpause_activity) } @@ -671,19 +853,45 @@ async fn call_workflow_service( rpc_call_on_trait!(client, call, WorkflowService, update_activity_options) } "UpdateNamespace" => rpc_call_on_trait!(client, call, WorkflowService, update_namespace), - "UpdateSchedule" => rpc_call!(client, call, update_schedule), - "UpdateTaskQueueConfig" => rpc_call!(client, call, update_task_queue_config), - "UpdateWorkerConfig" => rpc_call!(client, call, update_worker_config), + "UpdateSchedule" => rpc_call_on_trait!(client, call, WorkflowService, update_schedule), + "UpdateTaskQueueConfig" => { + rpc_call_on_trait!(client, call, WorkflowService, update_task_queue_config) + } + "UpdateWorkerConfig" => { + rpc_call_on_trait!(client, call, WorkflowService, update_worker_config) + } "UpdateWorkerDeploymentVersionMetadata" => { - rpc_call!(client, call, update_worker_deployment_version_metadata) + rpc_call_on_trait!( + client, + call, + WorkflowService, + update_worker_deployment_version_metadata + ) + } + "UpdateWorkerVersioningRules" => rpc_call_on_trait!( + client, + call, + WorkflowService, + update_worker_versioning_rules + ), + "UpdateWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, update_workflow_execution) } - "UpdateWorkerVersioningRules" => rpc_call!(client, call, update_worker_versioning_rules), - "UpdateWorkflowExecution" => rpc_call!(client, call, update_workflow_execution), "UpdateWorkflowExecutionOptions" => { - rpc_call!(client, call, update_workflow_execution_options) + rpc_call_on_trait!( + client, + call, + WorkflowService, + update_workflow_execution_options + ) } "UpdateWorkerBuildIdCompatibility" => { - rpc_call!(client, call, update_worker_build_id_compatibility) + rpc_call_on_trait!( + client, + call, + WorkflowService, + update_worker_build_id_compatibility + ) } rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } @@ -696,8 +904,12 @@ async fn call_operator_service( let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "AddOrUpdateRemoteCluster" => rpc_call!(client, call, add_or_update_remote_cluster), - "AddSearchAttributes" => rpc_call!(client, call, add_search_attributes), + "AddOrUpdateRemoteCluster" => { + rpc_call_on_trait!(client, call, OperatorService, add_or_update_remote_cluster) + } + "AddSearchAttributes" => { + rpc_call_on_trait!(client, call, OperatorService, add_search_attributes) + } "CreateNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, create_nexus_endpoint) } @@ -705,13 +917,20 @@ async fn call_operator_service( "DeleteNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, delete_nexus_endpoint) } - "DeleteWorkflowExecution" => rpc_call!(client, call, delete_workflow_execution), "GetNexusEndpoint" => rpc_call_on_trait!(client, call, OperatorService, get_nexus_endpoint), - "ListClusters" => rpc_call!(client, call, list_clusters), - "ListNexusEndpoints" => rpc_call!(client, call, list_nexus_endpoints), - "ListSearchAttributes" => rpc_call!(client, call, list_search_attributes), - "RemoveRemoteCluster" => rpc_call!(client, call, remove_remote_cluster), - "RemoveSearchAttributes" => rpc_call!(client, call, remove_search_attributes), + "ListClusters" => rpc_call_on_trait!(client, call, OperatorService, list_clusters), + "ListNexusEndpoints" => { + rpc_call_on_trait!(client, call, OperatorService, list_nexus_endpoints) + } + "ListSearchAttributes" => { + rpc_call_on_trait!(client, call, OperatorService, list_search_attributes) + } + "RemoveRemoteCluster" => { + rpc_call_on_trait!(client, call, OperatorService, remove_remote_cluster) + } + "RemoveSearchAttributes" => { + rpc_call_on_trait!(client, call, OperatorService, remove_search_attributes) + } "UpdateNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, update_nexus_endpoint) } @@ -723,68 +942,116 @@ async fn call_cloud_service(client: &CoreClient, call: &RpcCallOptions) -> anyho let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "AddNamespaceRegion" => rpc_call!(client, call, add_namespace_region), - "AddUserGroupMember" => rpc_call!(client, call, add_user_group_member), - "CreateApiKey" => rpc_call!(client, call, create_api_key), - "CreateNamespace" => rpc_call!(client, call, create_namespace), - "CreateNamespaceExportSink" => rpc_call!(client, call, create_namespace_export_sink), + "AddNamespaceRegion" => { + rpc_call_on_trait!(client, call, CloudService, add_namespace_region) + } + "AddUserGroupMember" => { + rpc_call_on_trait!(client, call, CloudService, add_user_group_member) + } + "CreateApiKey" => rpc_call_on_trait!(client, call, CloudService, create_api_key), + "CreateNamespace" => rpc_call_on_trait!(client, call, CloudService, create_namespace), + "CreateNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, create_namespace_export_sink) + } "CreateNexusEndpoint" => { rpc_call_on_trait!(client, call, CloudService, create_nexus_endpoint) } - "CreateServiceAccount" => rpc_call!(client, call, create_service_account), - "CreateUserGroup" => rpc_call!(client, call, create_user_group), - "CreateUser" => rpc_call!(client, call, create_user), - "DeleteApiKey" => rpc_call!(client, call, delete_api_key), + "CreateServiceAccount" => { + rpc_call_on_trait!(client, call, CloudService, create_service_account) + } + "CreateUserGroup" => rpc_call_on_trait!(client, call, CloudService, create_user_group), + "CreateUser" => rpc_call_on_trait!(client, call, CloudService, create_user), + "DeleteApiKey" => rpc_call_on_trait!(client, call, CloudService, delete_api_key), "DeleteNamespace" => rpc_call_on_trait!(client, call, CloudService, delete_namespace), - "DeleteNamespaceExportSink" => rpc_call!(client, call, delete_namespace_export_sink), - "DeleteNamespaceRegion" => rpc_call!(client, call, delete_namespace_region), + "DeleteNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, delete_namespace_export_sink) + } + "DeleteNamespaceRegion" => { + rpc_call_on_trait!(client, call, CloudService, delete_namespace_region) + } "DeleteNexusEndpoint" => { rpc_call_on_trait!(client, call, CloudService, delete_nexus_endpoint) } - "DeleteServiceAccount" => rpc_call!(client, call, delete_service_account), - "DeleteUserGroup" => rpc_call!(client, call, delete_user_group), - "DeleteUser" => rpc_call!(client, call, delete_user), - "FailoverNamespaceRegion" => rpc_call!(client, call, failover_namespace_region), - "GetAccount" => rpc_call!(client, call, get_account), - "GetApiKey" => rpc_call!(client, call, get_api_key), - "GetApiKeys" => rpc_call!(client, call, get_api_keys), - "GetAsyncOperation" => rpc_call!(client, call, get_async_operation), - "GetNamespace" => rpc_call!(client, call, get_namespace), - "GetNamespaceExportSink" => rpc_call!(client, call, get_namespace_export_sink), - "GetNamespaceExportSinks" => rpc_call!(client, call, get_namespace_export_sinks), - "GetNamespaces" => rpc_call!(client, call, get_namespaces), + "DeleteServiceAccount" => { + rpc_call_on_trait!(client, call, CloudService, delete_service_account) + } + "DeleteUserGroup" => rpc_call_on_trait!(client, call, CloudService, delete_user_group), + "DeleteUser" => rpc_call_on_trait!(client, call, CloudService, delete_user), + "FailoverNamespaceRegion" => { + rpc_call_on_trait!(client, call, CloudService, failover_namespace_region) + } + "GetAccount" => rpc_call_on_trait!(client, call, CloudService, get_account), + "GetApiKey" => rpc_call_on_trait!(client, call, CloudService, get_api_key), + "GetApiKeys" => rpc_call_on_trait!(client, call, CloudService, get_api_keys), + "GetAsyncOperation" => rpc_call_on_trait!(client, call, CloudService, get_async_operation), + "GetNamespace" => rpc_call_on_trait!(client, call, CloudService, get_namespace), + "GetNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, get_namespace_export_sink) + } + "GetNamespaceExportSinks" => { + rpc_call_on_trait!(client, call, CloudService, get_namespace_export_sinks) + } + "GetNamespaces" => rpc_call_on_trait!(client, call, CloudService, get_namespaces), "GetNexusEndpoint" => rpc_call_on_trait!(client, call, CloudService, get_nexus_endpoint), - "GetNexusEndpoints" => rpc_call!(client, call, get_nexus_endpoints), - "GetRegion" => rpc_call!(client, call, get_region), - "GetRegions" => rpc_call!(client, call, get_regions), - "GetServiceAccount" => rpc_call!(client, call, get_service_account), - "GetServiceAccounts" => rpc_call!(client, call, get_service_accounts), - "GetUsage" => rpc_call!(client, call, get_usage), - "GetUserGroup" => rpc_call!(client, call, get_user_group), - "GetUserGroupMembers" => rpc_call!(client, call, get_user_group_members), - "GetUserGroups" => rpc_call!(client, call, get_user_groups), - "GetUser" => rpc_call!(client, call, get_user), - "GetUsers" => rpc_call!(client, call, get_users), - "RemoveUserGroupMember" => rpc_call!(client, call, remove_user_group_member), - "RenameCustomSearchAttribute" => rpc_call!(client, call, rename_custom_search_attribute), - "SetUserGroupNamespaceAccess" => rpc_call!(client, call, set_user_group_namespace_access), - "SetUserNamespaceAccess" => rpc_call!(client, call, set_user_namespace_access), - "UpdateAccount" => rpc_call!(client, call, update_account), - "UpdateApiKey" => rpc_call!(client, call, update_api_key), + "GetNexusEndpoints" => rpc_call_on_trait!(client, call, CloudService, get_nexus_endpoints), + "GetRegion" => rpc_call_on_trait!(client, call, CloudService, get_region), + "GetRegions" => rpc_call_on_trait!(client, call, CloudService, get_regions), + "GetServiceAccount" => rpc_call_on_trait!(client, call, CloudService, get_service_account), + "GetServiceAccounts" => { + rpc_call_on_trait!(client, call, CloudService, get_service_accounts) + } + "GetUsage" => rpc_call_on_trait!(client, call, CloudService, get_usage), + "GetUserGroup" => rpc_call_on_trait!(client, call, CloudService, get_user_group), + "GetUserGroupMembers" => { + rpc_call_on_trait!(client, call, CloudService, get_user_group_members) + } + "GetUserGroups" => rpc_call_on_trait!(client, call, CloudService, get_user_groups), + "GetUser" => rpc_call_on_trait!(client, call, CloudService, get_user), + "GetUsers" => rpc_call_on_trait!(client, call, CloudService, get_users), + "RemoveUserGroupMember" => { + rpc_call_on_trait!(client, call, CloudService, remove_user_group_member) + } + "RenameCustomSearchAttribute" => { + rpc_call_on_trait!(client, call, CloudService, rename_custom_search_attribute) + } + "SetUserGroupNamespaceAccess" => { + rpc_call_on_trait!(client, call, CloudService, set_user_group_namespace_access) + } + "SetUserNamespaceAccess" => { + rpc_call_on_trait!(client, call, CloudService, set_user_namespace_access) + } + "UpdateAccount" => rpc_call_on_trait!(client, call, CloudService, update_account), + "UpdateApiKey" => rpc_call_on_trait!(client, call, CloudService, update_api_key), "UpdateNamespace" => rpc_call_on_trait!(client, call, CloudService, update_namespace), - "UpdateNamespaceExportSink" => rpc_call!(client, call, update_namespace_export_sink), + "UpdateNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, update_namespace_export_sink) + } "UpdateNexusEndpoint" => { rpc_call_on_trait!(client, call, CloudService, update_nexus_endpoint) } - "UpdateServiceAccount" => rpc_call!(client, call, update_service_account), - "UpdateUserGroup" => rpc_call!(client, call, update_user_group), - "UpdateUser" => rpc_call!(client, call, update_user), - "ValidateNamespaceExportSink" => rpc_call!(client, call, validate_namespace_export_sink), - "UpdateNamespaceTags" => rpc_call!(client, call, update_namespace_tags), - "CreateConnectivityRule" => rpc_call!(client, call, create_connectivity_rule), - "GetConnectivityRule" => rpc_call!(client, call, get_connectivity_rule), - "GetConnectivityRules" => rpc_call!(client, call, get_connectivity_rules), - "DeleteConnectivityRule" => rpc_call!(client, call, delete_connectivity_rule), + "UpdateServiceAccount" => { + rpc_call_on_trait!(client, call, CloudService, update_service_account) + } + "UpdateUserGroup" => rpc_call_on_trait!(client, call, CloudService, update_user_group), + "UpdateUser" => rpc_call_on_trait!(client, call, CloudService, update_user), + "ValidateNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, validate_namespace_export_sink) + } + "UpdateNamespaceTags" => { + rpc_call_on_trait!(client, call, CloudService, update_namespace_tags) + } + "CreateConnectivityRule" => { + rpc_call_on_trait!(client, call, CloudService, create_connectivity_rule) + } + "GetConnectivityRule" => { + rpc_call_on_trait!(client, call, CloudService, get_connectivity_rule) + } + "GetConnectivityRules" => { + rpc_call_on_trait!(client, call, CloudService, get_connectivity_rules) + } + "DeleteConnectivityRule" => { + rpc_call_on_trait!(client, call, CloudService, delete_connectivity_rule) + } rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -793,12 +1060,14 @@ async fn call_test_service(client: &CoreClient, call: &RpcCallOptions) -> anyhow let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "GetCurrentTime" => rpc_call!(client, call, get_current_time), - "LockTimeSkipping" => rpc_call!(client, call, lock_time_skipping), - "SleepUntil" => rpc_call!(client, call, sleep_until), - "Sleep" => rpc_call!(client, call, sleep), - "UnlockTimeSkippingWithSleep" => rpc_call!(client, call, unlock_time_skipping_with_sleep), - "UnlockTimeSkipping" => rpc_call!(client, call, unlock_time_skipping), + "GetCurrentTime" => rpc_call_on_trait!(client, call, TestService, get_current_time), + "LockTimeSkipping" => rpc_call_on_trait!(client, call, TestService, lock_time_skipping), + "SleepUntil" => rpc_call_on_trait!(client, call, TestService, sleep_until), + "Sleep" => rpc_call_on_trait!(client, call, TestService, sleep), + "UnlockTimeSkippingWithSleep" => { + rpc_call_on_trait!(client, call, TestService, unlock_time_skipping_with_sleep) + } + "UnlockTimeSkipping" => rpc_call_on_trait!(client, call, TestService, unlock_time_skipping), rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -810,7 +1079,7 @@ async fn call_health_service( let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "Check" => rpc_call!(client, call, check), + "Check" => rpc_call_on_trait!(client, call, HealthService, check), "Watch" => Err(anyhow::anyhow!( "Health service Watch method is not implemented in C bridge" )), diff --git a/core-c-bridge/src/tests/mod.rs b/core-c-bridge/src/tests/mod.rs index 697f2cafc..169847d43 100644 --- a/core-c-bridge/src/tests/mod.rs +++ b/core-c-bridge/src/tests/mod.rs @@ -317,7 +317,7 @@ fn test_simple_callback_override() { })) .unwrap(); let start_resp = StartWorkflowExecutionResponse::decode(&*start_resp_raw).unwrap(); - assert!(start_resp.run_id == "run-id for my-workflow-id"); + assert_eq!(start_resp.run_id, "run-id for my-workflow-id"); // Try a query where a query failure will actually be delivered as failure details. // However, we don't currently have temporal_sdk_core_protos::google::rpc::Status in @@ -336,23 +336,23 @@ fn test_simple_callback_override() { .unwrap_err() .downcast::() .unwrap(); - assert!(query_err.status_code == tonic::Code::Internal as u32); - assert!(query_err.message == "query-fail"); - assert!( + assert_eq!(query_err.status_code, tonic::Code::Internal as u32); + assert_eq!(query_err.message, "query-fail"); + assert_eq!( Failure::decode(query_err.details.as_ref().unwrap().as_slice()) .unwrap() - .message - == "intentional failure" + .message, + "intentional failure" ); // Confirm we got the expected calls - assert!( - *CALLBACK_OVERRIDE_CALLS.lock().unwrap() - == vec![ - "service: temporal.api.workflowservice.v1.WorkflowService, rpc: GetSystemInfo", - "service: temporal.api.workflowservice.v1.WorkflowService, rpc: StartWorkflowExecution", - "service: temporal.api.workflowservice.v1.WorkflowService, rpc: QueryWorkflow" - ] + assert_eq!( + *CALLBACK_OVERRIDE_CALLS.lock().unwrap(), + vec![ + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: GetSystemInfo", + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: StartWorkflowExecution", + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: QueryWorkflow" + ] ); }); } diff --git a/core/src/lib.rs b/core/src/lib.rs index c3aa1862d..35c20fdce 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -62,9 +62,7 @@ use crate::{ use anyhow::bail; use futures_util::Stream; use std::sync::{Arc, OnceLock}; -use temporal_client::{ - ConfiguredClient, NamespacedClient, SharedReplaceableClient, TemporalServiceClientWithMetrics, -}; +use temporal_client::{ConfiguredClient, NamespacedClient, SharedReplaceableClient}; use temporal_sdk_core_api::{ Worker as WorkerTrait, errors::{CompleteActivityError, PollError}, @@ -174,23 +172,23 @@ pub(crate) fn sticky_q_name_for_worker( mod sealed { use super::*; - use temporal_client::SharedReplaceableClient; + use temporal_client::{SharedReplaceableClient, TemporalServiceClient}; /// Allows passing different kinds of clients into things that want to be flexible. Motivating /// use-case was worker initialization. /// /// Needs to exist in this crate to avoid blanket impl conflicts. pub struct AnyClient { - pub(crate) inner: Box>, + pub(crate) inner: Box>, } impl AnyClient { - pub(crate) fn into_inner(self) -> Box> { + pub(crate) fn into_inner(self) -> Box> { self.inner } } - impl From> for AnyClient { - fn from(c: ConfiguredClient) -> Self { + impl From> for AnyClient { + fn from(c: ConfiguredClient) -> Self { Self { inner: Box::new(c) } } } diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index 53577713d..a480180d9 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -434,7 +434,7 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .respond_workflow_task_completed(request) + .respond_workflow_task_completed(request.into_request()) .await? .into_inner()) } @@ -458,7 +458,8 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - }, + } + .into_request(), ) .await? .into_inner()) @@ -472,12 +473,15 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .respond_nexus_task_completed(RespondNexusTaskCompletedRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - task_token: task_token.0, - response: Some(response), - }) + .respond_nexus_task_completed( + RespondNexusTaskCompletedRequest { + namespace: self.namespace.clone(), + identity: self.identity.clone(), + task_token: task_token.0, + response: Some(response), + } + .into_request(), + ) .await? .into_inner()) } @@ -490,12 +494,15 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .record_activity_task_heartbeat(RecordActivityTaskHeartbeatRequest { - task_token: task_token.0, - details, - identity: self.identity.clone(), - namespace: self.namespace.clone(), - }) + .record_activity_task_heartbeat( + RecordActivityTaskHeartbeatRequest { + task_token: task_token.0, + details, + identity: self.identity.clone(), + namespace: self.namespace.clone(), + } + .into_request(), + ) .await? .into_inner()) } @@ -519,7 +526,8 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - }, + } + .into_request(), ) .await? .into_inner()) @@ -546,7 +554,8 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - }, + } + .into_request(), ) .await? .into_inner()) @@ -575,7 +584,7 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .respond_workflow_task_failed(request) + .respond_workflow_task_failed(request.into_request()) .await? .into_inner()) } @@ -588,12 +597,15 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .respond_nexus_task_failed(RespondNexusTaskFailedRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - task_token: task_token.0, - error: Some(error), - }) + .respond_nexus_task_failed( + RespondNexusTaskFailedRequest { + namespace: self.namespace.clone(), + identity: self.identity.clone(), + task_token: task_token.0, + error: Some(error), + } + .into_request(), + ) .await? .into_inner()) } @@ -607,15 +619,18 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .get_workflow_execution_history(GetWorkflowExecutionHistoryRequest { - namespace: self.namespace.clone(), - execution: Some(WorkflowExecution { - workflow_id, - run_id: run_id.unwrap_or_default(), - }), - next_page_token: page_token, - ..Default::default() - }) + .get_workflow_execution_history( + GetWorkflowExecutionHistoryRequest { + namespace: self.namespace.clone(), + execution: Some(WorkflowExecution { + workflow_id, + run_id: run_id.unwrap_or_default(), + }), + next_page_token: page_token, + ..Default::default() + } + .into_request(), + ) .await? .into_inner()) } @@ -640,15 +655,18 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .respond_query_task_completed(RespondQueryTaskCompletedRequest { - task_token: task_token.into(), - completed_type: completed_type as i32, - query_result, - error_message, - namespace: self.namespace.clone(), - failure, - cause: cause.into(), - }) + .respond_query_task_completed( + RespondQueryTaskCompletedRequest { + task_token: task_token.into(), + completed_type: completed_type as i32, + query_result, + error_message, + namespace: self.namespace.clone(), + failure, + cause: cause.into(), + } + .into_request(), + ) .await? .into_inner()) } @@ -658,7 +676,9 @@ impl WorkerClient for WorkerClientBag { .client .clone() .describe_namespace( - Namespace::Name(self.namespace.clone()).into_describe_namespace_request(), + Namespace::Name(self.namespace.clone()) + .into_describe_namespace_request() + .into_request(), ) .await? .into_inner()) @@ -674,7 +694,7 @@ impl WorkerClient for WorkerClientBag { }; Ok( - WorkflowService::shutdown_worker(&mut self.client.clone(), request) + WorkflowService::shutdown_worker(&mut self.client.clone(), request.into_request()) .await? .into_inner(), ) @@ -691,11 +711,14 @@ impl WorkerClient for WorkerClientBag { Ok(self .client .clone() - .record_worker_heartbeat(RecordWorkerHeartbeatRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - worker_heartbeat: vec![heartbeat], - }) + .record_worker_heartbeat( + RecordWorkerHeartbeatRequest { + namespace: self.namespace.clone(), + identity: self.identity.clone(), + worker_heartbeat: vec![heartbeat], + } + .into_request(), + ) .await? .into_inner()) } @@ -719,7 +742,7 @@ impl WorkerClient for WorkerClientBag { fn sdk_name_and_version(&self) -> (String, String) { let inner = self.client.get_client().inner_cow(); - let opts = inner.inner().options(); + let opts = inner.options(); (opts.client_name.clone(), opts.client_version.clone()) } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d3c9dd3a0..a53d702d0 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -65,6 +65,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::{sync::OnceCell, task::AbortHandle}; +use tonic::IntoRequest; use tracing::{debug, warn}; use url::Url; @@ -256,8 +257,7 @@ impl CoreWfStarter { .get_client() .inner() .workflow_svc() - .clone() - .get_cluster_info(GetClusterInfoRequest::default()) + .get_cluster_info(GetClusterInfoRequest::default().into_request()) .await; let srv_ver = semver::Version::parse( &clustinfo diff --git a/tests/integ_tests/client_tests.rs b/tests/integ_tests/client_tests.rs index 9a4473b35..4dfa6a4ec 100644 --- a/tests/integ_tests/client_tests.rs +++ b/tests/integ_tests/client_tests.rs @@ -32,7 +32,7 @@ use tokio::{ sync::{mpsc::UnboundedSender, oneshot}, }; use tonic::{ - Code, Request, Status, + Code, IntoRequest, Request, Status, body::Body, codegen::{Service, http::Response}, server::NamedService, @@ -56,10 +56,13 @@ async fn can_use_retry_raw_client() { let opts = get_integ_server_options(); let mut client = opts.connect_no_namespace(None).await.unwrap(); client - .describe_namespace(DescribeNamespaceRequest { - namespace: NAMESPACE.to_string(), - ..Default::default() - }) + .describe_namespace( + DescribeNamespaceRequest { + namespace: NAMESPACE.to_string(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); } @@ -79,10 +82,13 @@ async fn per_call_timeout_respected_whole_client() { hm.insert("grpc-timeout".to_string(), "0S".to_string()); raw_client.get_client().set_headers(hm).unwrap(); let err = raw_client - .describe_namespace(DescribeNamespaceRequest { - namespace: NAMESPACE.to_string(), - ..Default::default() - }) + .describe_namespace( + DescribeNamespaceRequest { + namespace: NAMESPACE.to_string(), + ..Default::default() + } + .into_request(), + ) .await .unwrap_err(); assert_matches!(err.code(), Code::DeadlineExceeded | Code::Cancelled); @@ -409,12 +415,15 @@ async fn cloud_ops_test() { hm.insert("temporal-cloud-api-version".to_string(), api_version); hm }); - let mut client = opts.connect_no_namespace(None).await.unwrap().into_inner(); - let cloud_client = client.cloud_svc_mut(); + let client = opts.connect_no_namespace(None).await.unwrap().into_inner(); + let mut cloud_client = client.cloud_svc(); let res = cloud_client - .get_namespace(GetNamespaceRequest { - namespace: namespace.clone(), - }) + .get_namespace( + GetNamespaceRequest { + namespace: namespace.clone(), + } + .into_request(), + ) .await .unwrap(); assert_eq!(res.into_inner().namespace.unwrap().namespace, namespace); diff --git a/tests/integ_tests/ephemeral_server_tests.rs b/tests/integ_tests/ephemeral_server_tests.rs index ef3a5abe4..1b9b612a3 100644 --- a/tests/integ_tests/ephemeral_server_tests.rs +++ b/tests/integ_tests/ephemeral_server_tests.rs @@ -7,6 +7,7 @@ use temporal_sdk_core::ephemeral_server::{ default_cached_download, }; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::DescribeNamespaceRequest; +use tonic::IntoRequest; use url::Url; #[tokio::test] @@ -147,17 +148,20 @@ async fn assert_ephemeral_server(server: &EphemeralServer) { .await .unwrap(); let resp = client - .describe_namespace(DescribeNamespaceRequest { - namespace: NAMESPACE.to_string(), - ..Default::default() - }) + .describe_namespace( + DescribeNamespaceRequest { + namespace: NAMESPACE.to_string(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); assert!(resp.into_inner().namespace_info.unwrap().name == "default"); // If it has test service, make sure we can use it too if server.has_test_service { - let resp = client.get_current_time(()).await.unwrap(); + let resp = client.get_current_time(().into_request()).await.unwrap(); // Make sure it's within 5 mins of now let resp_seconds = resp.get_ref().time.as_ref().unwrap().seconds as u64; let curr_seconds = SystemTime::now() diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs index dd4068ac4..f6432cd40 100644 --- a/tests/integ_tests/metrics_tests.rs +++ b/tests/integ_tests/metrics_tests.rs @@ -70,6 +70,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::{join, sync::Barrier}; +use tonic::IntoRequest; use url::Url; pub(crate) async fn get_text(endpoint: String) -> String { @@ -106,7 +107,7 @@ async fn prometheus_metrics_exported( assert!(raw_client.get_client().capabilities().is_some()); let _ = raw_client - .list_namespaces(ListNamespacesRequest::default()) + .list_namespaces(ListNamespacesRequest::default().into_request()) .await .unwrap(); @@ -538,7 +539,7 @@ fn runtime_new() { .unwrap(); assert!(raw_client.get_client().capabilities().is_some()); let _ = raw_client - .list_namespaces(ListNamespacesRequest::default()) + .list_namespaces(ListNamespacesRequest::default().into_request()) .await .unwrap(); let body = get_text(format!("http://{addr}/metrics")).await; @@ -632,9 +633,12 @@ async fn request_fail_codes() { .unwrap(); // Describe namespace w/ invalid argument (unset namespace field) - WorkflowService::describe_namespace(&mut client, DescribeNamespaceRequest::default()) - .await - .unwrap_err(); + WorkflowService::describe_namespace( + &mut client, + DescribeNamespaceRequest::default().into_request(), + ) + .await + .unwrap_err(); let body = get_text(format!("http://{addr}/metrics")).await; let matching_line = body @@ -677,9 +681,12 @@ async fn request_fail_codes_otel() { for _ in 0..10 { // Describe namespace w/ invalid argument (unset namespace field) - WorkflowService::describe_namespace(&mut client, DescribeNamespaceRequest::default()) - .await - .unwrap_err(); + WorkflowService::describe_namespace( + &mut client, + DescribeNamespaceRequest::default().into_request(), + ) + .await + .unwrap_err(); tokio::time::sleep(Duration::from_secs(1)).await; } diff --git a/tests/integ_tests/update_tests.rs b/tests/integ_tests/update_tests.rs index 3a23668e3..748541f03 100644 --- a/tests/integ_tests/update_tests.rs +++ b/tests/integ_tests/update_tests.rs @@ -42,6 +42,7 @@ use temporal_sdk_core_protos::{ test_utils::start_timer_cmd, }; use tokio::{join, sync::Barrier}; +use tonic::IntoRequest; use uuid::Uuid; #[derive(Clone, Copy)] @@ -125,7 +126,8 @@ async fn reapplied_updates_due_to_reset() { reset_reapply_type: ResetReapplyType::AllEligible as i32, request_id: Uuid::new_v4().to_string(), ..Default::default() - }, + } + .into_request(), ) .await .unwrap() diff --git a/tests/integ_tests/worker_versioning_tests.rs b/tests/integ_tests/worker_versioning_tests.rs index dcc2e27bd..f532dafe4 100644 --- a/tests/integ_tests/worker_versioning_tests.rs +++ b/tests/integ_tests/worker_versioning_tests.rs @@ -23,6 +23,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::join; +use tonic::IntoRequest; #[rstest::rstest] #[tokio::test] @@ -76,10 +77,13 @@ async fn sets_deployment_info_on_task_responses(#[values(true, false)] use_defau client .get_client() .clone() - .describe_worker_deployment(DescribeWorkerDeploymentRequest { - namespace: client.namespace(), - deployment_name: deploy_name.clone(), - }) + .describe_worker_deployment( + DescribeWorkerDeploymentRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + } + .into_request(), + ) .await }, Duration::from_secs(5), @@ -92,13 +96,16 @@ async fn sets_deployment_info_on_task_responses(#[values(true, false)] use_defau client .get_client() .clone() - .set_worker_deployment_current_version(SetWorkerDeploymentCurrentVersionRequest { - namespace: client.namespace(), - deployment_name: deploy_name.clone(), - version: format!("{deploy_name}.1.0"), - conflict_token: desc_resp.conflict_token, - ..Default::default() - }) + .set_worker_deployment_current_version( + SetWorkerDeploymentCurrentVersionRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + version: format!("{deploy_name}.1.0"), + conflict_token: desc_resp.conflict_token, + ..Default::default() + } + .into_request(), + ) .await .unwrap(); @@ -178,10 +185,13 @@ async fn activity_has_deployment_stamp() { client .get_client() .clone() - .describe_worker_deployment(DescribeWorkerDeploymentRequest { - namespace: client.namespace(), - deployment_name: deploy_name.clone(), - }) + .describe_worker_deployment( + DescribeWorkerDeploymentRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + } + .into_request(), + ) .await }, Duration::from_secs(50), @@ -194,13 +204,16 @@ async fn activity_has_deployment_stamp() { client .get_client() .clone() - .set_worker_deployment_current_version(SetWorkerDeploymentCurrentVersionRequest { - namespace: client.namespace(), - deployment_name: deploy_name.clone(), - version: format!("{deploy_name}.1.0"), - conflict_token: desc_resp.conflict_token, - ..Default::default() - }) + .set_worker_deployment_current_version( + SetWorkerDeploymentCurrentVersionRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + version: format!("{deploy_name}.1.0"), + conflict_token: desc_resp.conflict_token, + ..Default::default() + } + .into_request(), + ) .await .unwrap(); diff --git a/tests/integ_tests/workflow_tests/resets.rs b/tests/integ_tests/workflow_tests/resets.rs index 6b356da9d..cd6fec88a 100644 --- a/tests/integ_tests/workflow_tests/resets.rs +++ b/tests/integ_tests/workflow_tests/resets.rs @@ -19,6 +19,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::sync::Notify; +use tonic::IntoRequest; const POST_RESET_SIG: &str = "post-reset"; @@ -65,17 +66,20 @@ async fn reset_workflow() { notify.notified().await; // Do the reset client - .reset_workflow_execution(ResetWorkflowExecutionRequest { - namespace: NAMESPACE.to_owned(), - workflow_execution: Some(WorkflowExecution { - workflow_id: wf_name.to_owned(), - run_id: run_id.clone(), - }), - // End of first WFT - workflow_task_finish_event_id: 4, - request_id: "test-req-id".to_owned(), - ..Default::default() - }) + .reset_workflow_execution( + ResetWorkflowExecutionRequest { + namespace: NAMESPACE.to_owned(), + workflow_execution: Some(WorkflowExecution { + workflow_id: wf_name.to_owned(), + run_id: run_id.clone(), + }), + // End of first WFT + workflow_task_finish_event_id: 4, + request_id: "test-req-id".to_owned(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); @@ -191,16 +195,19 @@ async fn reset_randomseed() { notify.notified().await; // Reset the workflow to be after first timer has fired client - .reset_workflow_execution(ResetWorkflowExecutionRequest { - namespace: NAMESPACE.to_owned(), - workflow_execution: Some(WorkflowExecution { - workflow_id: wf_name.to_owned(), - run_id: run_id.clone(), - }), - workflow_task_finish_event_id: 14, - request_id: "test-req-id".to_owned(), - ..Default::default() - }) + .reset_workflow_execution( + ResetWorkflowExecutionRequest { + namespace: NAMESPACE.to_owned(), + workflow_execution: Some(WorkflowExecution { + workflow_id: wf_name.to_owned(), + run_id: run_id.clone(), + }), + workflow_task_finish_event_id: 14, + request_id: "test-req-id".to_owned(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); diff --git a/tests/main.rs b/tests/main.rs index 4e97e33ef..8828d4899 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -38,6 +38,7 @@ mod integ_tests { operatorservice::v1::CreateNexusEndpointRequest, workflowservice::v1::ListNamespacesRequest, }; + use tonic::IntoRequest; // Create a worker like a bridge would (unwraps aside) #[tokio::test] @@ -64,27 +65,32 @@ mod integ_tests { // Do things with worker or client let _ = retrying_client - .list_namespaces(ListNamespacesRequest::default()) + .list_namespaces(ListNamespacesRequest::default().into_request()) .await; } pub(crate) async fn mk_nexus_endpoint(starter: &mut CoreWfStarter) -> String { let client = starter.get_client().await; let endpoint = format!("mycoolendpoint-{}", rand_6_chars()); - let mut op_client = client.get_client().inner().operator_svc().clone(); + let mut op_client = client.get_client().inner().operator_svc(); op_client - .create_nexus_endpoint(CreateNexusEndpointRequest { - spec: Some(EndpointSpec { - name: endpoint.to_owned(), - description: None, - target: Some(EndpointTarget { - variant: Some(endpoint_target::Variant::Worker(endpoint_target::Worker { - namespace: client.namespace(), - task_queue: starter.get_task_queue().to_owned(), - })), + .create_nexus_endpoint( + CreateNexusEndpointRequest { + spec: Some(EndpointSpec { + name: endpoint.to_owned(), + description: None, + target: Some(EndpointTarget { + variant: Some(endpoint_target::Variant::Worker( + endpoint_target::Worker { + namespace: client.namespace(), + task_queue: starter.get_task_queue().to_owned(), + }, + )), + }), }), - }), - }) + } + .into_request(), + ) .await .unwrap(); // Endpoint creation can (as of server 1.25.2 at least) return before they are actually usable.