diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 297765be0..a9c857373 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -303,7 +303,7 @@ async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, data_converter: temporalio.converter.DataConverter, decode_headers: bool, - concurrency_limit: int, + storage_concurrency_limit: int, ) -> temporalio.converter._extstore.StorageOperationMetrics: """Decode all payloads in the activation. @@ -315,8 +315,16 @@ async def decode_activation( await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not decode_headers, - concurrency_limit=concurrency_limit, - ).visit(_Visitor(data_converter._decode_payload_sequence), activation) + concurrency_limit=storage_concurrency_limit, + ).visit( + _Visitor(data_converter._external_retrieve_payload_sequence), activation + ) + + await CommandAwarePayloadVisitor( + skip_search_attributes=True, + skip_headers=not decode_headers, + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) + return metrics @@ -324,18 +332,31 @@ async def encode_completion( completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, data_converter: temporalio.converter.DataConverter, encode_headers: bool, - concurrency_limit: int, + storage_concurrency_limit: int, ) -> temporalio.converter._extstore.StorageOperationMetrics: """Encode all payloads in the completion. Returns: Metrics from any external storage store operations that occurred. 
""" + await CommandAwarePayloadVisitor( + skip_search_attributes=True, + skip_headers=not encode_headers, + ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + + async def _store_and_validate( + payloads: Sequence[Payload], + ) -> list[Payload]: + stored = await data_converter._external_store_payload_sequence(payloads) + data_converter._validate_payload_limits(stored) + return stored + metrics = temporalio.converter._extstore.StorageOperationMetrics() with metrics.track(): await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers, - concurrency_limit=concurrency_limit, - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + concurrency_limit=storage_concurrency_limit, + ).visit(_Visitor(_store_and_validate), completion) + return metrics diff --git a/temporalio/client.py b/temporalio/client.py index 22b07b1c1..cc2750ec6 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -9185,7 +9185,7 @@ async def _apply_headers( return if encode_headers: for payload in source.values(): - payload.CopyFrom(await data_converter._encode_payload(payload)) + payload.CopyFrom(await data_converter._transform_outbound_payload(payload)) temporalio.common._apply_headers(source, dest) diff --git a/temporalio/converter/_data_converter.py b/temporalio/converter/_data_converter.py index 9c2163774..99de876ea 100644 --- a/temporalio/converter/_data_converter.py +++ b/temporalio/converter/_data_converter.py @@ -111,6 +111,8 @@ async def encode( """ payloads = self.payload_converter.to_payloads(values) payloads = await self._encode_payload_sequence(payloads) + payloads = await self._external_store_payload_sequence(payloads) + self._validate_payload_limits(payloads) return payloads async def decode( @@ -128,6 +130,7 @@ async def decode( Returns: Decoded and converted values. 
""" + payloads = await self._external_retrieve_payload_sequence(payloads) payloads = await self._decode_payload_sequence(payloads) return self.payload_converter.from_payloads(payloads, type_hints) @@ -156,13 +159,13 @@ async def encode_failure( ) -> None: """Convert and encode failure.""" self.failure_converter.to_failure(exception, self.payload_converter, failure) - await _apply_to_failure_payloads(failure, self._encode_payloads) + await _apply_to_failure_payloads(failure, self._transform_outbound_payloads) async def decode_failure( self, failure: temporalio.api.failure.v1.Failure ) -> BaseException: """Decode and convert failure.""" - await _apply_to_failure_payloads(failure, self._decode_payloads) + await _apply_to_failure_payloads(failure, self._transform_inbound_payloads) return self.failure_converter.from_failure(failure, self.payload_converter) def with_context(self, context: SerializationContext) -> Self: @@ -250,7 +253,7 @@ async def _encode_memo_existing( "[TMPRL1103] Attempted to upload memo with size that exceeded the warning limit.", ) - async def _encode_payload( + async def _transform_outbound_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: if self.payload_codec: @@ -260,27 +263,16 @@ async def _encode_payload( self._validate_payload_limits([payload]) return payload - async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + async def _transform_outbound_payloads( + self, payloads: temporalio.api.common.v1.Payloads + ): if self.payload_codec: await self.payload_codec.encode_wrapper(payloads) if self.external_storage: await self.external_storage._store_payloads(payloads) self._validate_payload_limits(payloads.payloads) - async def _encode_payload_sequence( - self, payloads: Sequence[temporalio.api.common.v1.Payload] - ) -> list[temporalio.api.common.v1.Payload]: - encoded_payloads = list(payloads) - if self.payload_codec: - encoded_payloads = await 
self.payload_codec.encode(encoded_payloads) - if self.external_storage: - encoded_payloads = await self.external_storage._store_payload_sequence( - encoded_payloads - ) - self._validate_payload_limits(encoded_payloads) - return encoded_payloads - - async def _decode_payload( + async def _transform_inbound_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: if self.external_storage: @@ -289,7 +281,9 @@ payload = (await self.payload_codec.decode([payload]))[0] return payload - async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + async def _transform_inbound_payloads( + self, payloads: temporalio.api.common.v1.Payloads + ): if self.external_storage: await self.external_storage._retrieve_payloads(payloads) else: @@ -304,23 +298,51 @@ async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): if self.payload_codec: await self.payload_codec.decode_wrapper(payloads) - async def _decode_payload_sequence( + async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - decoded_payloads = list(payloads) + """Codec encode only.""" + encoded_payloads = list(payloads) + if self.payload_codec: + encoded_payloads = await self.payload_codec.encode(encoded_payloads) + return encoded_payloads + + async def _external_store_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + """External storage store only; callers are responsible for validating payload limits afterwards.""" + stored_payloads = list(payloads) + if self.external_storage: + stored_payloads = await self.external_storage._store_payload_sequence( + stored_payloads + ) + return stored_payloads + + async def _external_retrieve_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + """External storage retrieve only.""" + 
retrieved_payloads = list(payloads) if self.external_storage: - decoded_payloads = await self.external_storage._retrieve_payload_sequence( - decoded_payloads + retrieved_payloads = await self.external_storage._retrieve_payload_sequence( + retrieved_payloads ) else: if any( p.metadata.get("encoding") == _REFERENCE_ENCODING - for p in decoded_payloads + for p in retrieved_payloads ): warnings.warn( "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", StorageWarning, ) + return retrieved_payloads + + async def _decode_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + """Codec decode only.""" + decoded_payloads = list(payloads) if self.payload_codec: decoded_payloads = await self.payload_codec.decode(decoded_payloads) return decoded_payloads diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 7b67734d9..c7a1032fe 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -631,7 +631,9 @@ async def _execute_activity( if self._encode_headers: for payload in start.header_fields.values(): - payload.CopyFrom(await data_converter._decode_payload(payload)) + payload.CopyFrom( + await data_converter._transform_inbound_payload(payload) + ) running_activity.info = info input = ExecuteActivityInput( diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 30a0f35df..508d5f708 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -268,7 +268,7 @@ def on_eviction_hook( "header_codec_behavior", HeaderCodecBehavior.NO_CODEC ) != HeaderCodecBehavior.NO_CODEC, - max_workflow_task_payload_concurrency=1, + max_workflow_task_external_storage_concurrency=1, ) external_storage = data_converter.external_storage storage_driver_types = ( diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 0baccbe95..9057e1449 100644 --- 
a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -36,7 +36,10 @@ from ._nexus import _NexusWorker from ._plugin import Plugin from ._tuning import WorkerTuner -from ._workflow import _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY, _WorkflowWorker +from ._workflow import ( + _DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY, + _WorkflowWorker, +) from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner from .workflow_sandbox import SandboxedWorkflowRunner @@ -142,7 +145,7 @@ def __init__( maximum=5 ), disable_payload_error_limit: bool = False, - max_workflow_task_payload_concurrency: int = _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY, + max_workflow_task_external_storage_concurrency: int = _DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY, ) -> None: """Create a worker to process workflows and/or activities. @@ -317,10 +320,10 @@ def __init__( and cause a task failure if the size limit is exceeded. The default is False. See https://docs.temporal.io/troubleshooting/blob-size-limit-error for more details. - max_workflow_task_payload_concurrency: Maximum number of payload - operations (codec encode/decode, external storage I/O, etc.) - that may run concurrently within a single workflow task - activation. Defaults to 1. WARNING: This setting is experimental. + max_workflow_task_external_storage_concurrency: Maximum number of + external storage I/O operations (store/retrieve) that may run + concurrently within a single workflow task activation. + Defaults to 10. WARNING: This setting is experimental. 
""" config = WorkerConfig( @@ -366,7 +369,7 @@ def __init__( activity_task_poller_behavior=activity_task_poller_behavior, nexus_task_poller_behavior=nexus_task_poller_behavior, disable_payload_error_limit=disable_payload_error_limit, - max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency, + max_workflow_task_external_storage_concurrency=max_workflow_task_external_storage_concurrency, ) plugins_from_client = cast( @@ -420,12 +423,14 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf raise ValueError( "default_versioning_behavior must be UNSPECIFIED when use_worker_versioning is False" ) - max_workflow_task_payload_concurrency = config.get( - "max_workflow_task_payload_concurrency", - _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY, + max_workflow_task_external_storage_concurrency = config.get( + "max_workflow_task_external_storage_concurrency", + _DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY, ) - if max_workflow_task_payload_concurrency < 1: - raise ValueError("max_workflow_task_payload_concurrency must be positive") + if max_workflow_task_external_storage_concurrency < 1: + raise ValueError( + "max_workflow_task_external_storage_concurrency must be positive" + ) # Prepend applicable client interceptors to the given ones client_config = config["client"].config(active_config=True) # type: ignore[reportTypedDictNotRequiredAccess] @@ -530,7 +535,7 @@ def check_activity(activity: str): assert_local_activity_valid=check_activity, encode_headers=client_config["header_codec_behavior"] != HeaderCodecBehavior.NO_CODEC, - max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency, + max_workflow_task_external_storage_concurrency=max_workflow_task_external_storage_concurrency, ) tuner = config.get("tuner") @@ -977,7 +982,7 @@ class WorkerConfig(TypedDict, total=False): activity_task_poller_behavior: PollerBehavior nexus_task_poller_behavior: PollerBehavior disable_payload_error_limit: bool - 
max_workflow_task_payload_concurrency: int + max_workflow_task_external_storage_concurrency: int def _warn_if_activity_executor_max_workers_is_inconsistent( diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 4844ec198..914b14370 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -47,7 +47,7 @@ # Set to true to log all activations and completions LOG_PROTOS = False -_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY: int = 1 +_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY: int = 10 class _WorkflowWorker: # type:ignore[reportUnusedClass] @@ -76,7 +76,7 @@ def __init__( should_enforce_versioning_behavior: bool, assert_local_activity_valid: Callable[[str], None], encode_headers: bool, - max_workflow_task_payload_concurrency: int, + max_workflow_task_external_storage_concurrency: int, ) -> None: self._bridge_worker = bridge_worker self._namespace = namespace @@ -115,8 +115,8 @@ def __init__( self._on_eviction_hook = on_eviction_hook self._disable_safe_eviction = disable_safe_eviction self._encode_headers = encode_headers - self._max_workflow_task_payload_concurrency = ( - max_workflow_task_payload_concurrency + self._max_workflow_task_external_storage_concurrency = ( + max_workflow_task_external_storage_concurrency ) self._throw_after_activation: Exception | None = None @@ -300,7 +300,7 @@ async def _handle_activation( act, data_converter, decode_headers=self._encode_headers, - concurrency_limit=self._max_workflow_task_payload_concurrency, + storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency, ) if not workflow: assert init_job @@ -409,7 +409,7 @@ async def _handle_activation( completion, data_converter, encode_headers=self._encode_headers, - concurrency_limit=self._max_workflow_task_payload_concurrency, + storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency, ) except temporalio.converter._payload_limits._PayloadSizeError as err: logger.warning(err.message) 
@@ -893,11 +893,29 @@ async def _encode_payload_sequence( ) -> list[temporalio.api.common.v1.Payload]: return await self._get_current_dc()._encode_payload_sequence(payloads) + async def _external_store_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + return await self._get_current_dc()._external_store_payload_sequence(payloads) + + async def _external_retrieve_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + return await self._get_current_dc()._external_retrieve_payload_sequence( + payloads + ) + async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: return await self._get_current_dc()._decode_payload_sequence(payloads) + def _validate_payload_limits( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + ) -> None: + self._get_current_dc()._validate_payload_limits(payloads) + class _InterruptDeadlockError(BaseException): pass diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 9d8463015..15860f58c 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -303,7 +303,9 @@ async def test_bridge_encoding(): payload_codec=SimpleCodec(), ) - await temporalio.bridge.worker.encode_completion(comp, data_converter, True, 1) + await temporalio.bridge.worker.encode_completion( + comp, data_converter, True, storage_concurrency_limit=1 + ) cmd = comp.successful.commands[0] sa = cmd.schedule_activity