Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ async def decode_activation(
activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
data_converter: temporalio.converter.DataConverter,
decode_headers: bool,
concurrency_limit: int,
storage_concurrency_limit: int,
) -> temporalio.converter._extstore.StorageOperationMetrics:
"""Decode all payloads in the activation.

Expand All @@ -315,27 +315,48 @@ async def decode_activation(
await CommandAwarePayloadVisitor(
skip_search_attributes=True,
skip_headers=not decode_headers,
concurrency_limit=concurrency_limit,
).visit(_Visitor(data_converter._decode_payload_sequence), activation)
concurrency_limit=storage_concurrency_limit,
).visit(
_Visitor(data_converter._external_retrieve_payload_sequence), activation
)

await CommandAwarePayloadVisitor(
skip_search_attributes=True,
skip_headers=not decode_headers,
).visit(_Visitor(data_converter._decode_payload_sequence), activation)

return metrics


async def encode_completion(
    completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
    data_converter: temporalio.converter.DataConverter,
    encode_headers: bool,
    storage_concurrency_limit: int,
) -> temporalio.converter._extstore.StorageOperationMetrics:
    """Encode all payloads in the completion, in place.

    Runs two sequential visitor passes whose order matters: first a
    codec-encode pass over every payload, then an external-storage pass that
    stores payloads and validates size limits on the stored result. Only the
    storage pass is bounded by ``storage_concurrency_limit`` and tracked in
    the returned metrics.

    Args:
        completion: Workflow activation completion whose payloads are
            transformed in place.
        data_converter: Converter supplying the codec and external storage.
        encode_headers: If False, header payloads are skipped by both passes.
        storage_concurrency_limit: Maximum concurrent external storage
            operations within the second pass.

    Returns:
        Metrics from any external storage store operations that occurred.
    """
    # Pass 1: codec encode only — no concurrency limit and not metered.
    await CommandAwarePayloadVisitor(
        skip_search_attributes=True,
        skip_headers=not encode_headers,
    ).visit(_Visitor(data_converter._encode_payload_sequence), completion)

    async def _store_and_validate(
        payloads: Sequence[Payload],
    ) -> list[Payload]:
        # Store externally, then check the (possibly replaced) payloads
        # against the configured size limits.
        stored = await data_converter._external_store_payload_sequence(payloads)
        data_converter._validate_payload_limits(stored)
        return stored

    # Pass 2: external storage offload, bounded by the storage concurrency
    # limit and timed/counted via the metrics tracker.
    metrics = temporalio.converter._extstore.StorageOperationMetrics()
    with metrics.track():
        await CommandAwarePayloadVisitor(
            skip_search_attributes=True,
            skip_headers=not encode_headers,
            concurrency_limit=storage_concurrency_limit,
        ).visit(_Visitor(_store_and_validate), completion)

    return metrics
2 changes: 1 addition & 1 deletion temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9185,7 +9185,7 @@ async def _apply_headers(
return
if encode_headers:
for payload in source.values():
payload.CopyFrom(await data_converter._encode_payload(payload))
payload.CopyFrom(await data_converter._transform_outbound_payload(payload))
temporalio.common._apply_headers(source, dest)


Expand Down
70 changes: 46 additions & 24 deletions temporalio/converter/_data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ async def encode(
"""
payloads = self.payload_converter.to_payloads(values)
payloads = await self._encode_payload_sequence(payloads)
payloads = await self._external_store_payload_sequence(payloads)
self._validate_payload_limits(payloads)
return payloads

async def decode(
Expand All @@ -128,6 +130,7 @@ async def decode(
Returns:
Decoded and converted values.
"""
payloads = await self._external_retrieve_payload_sequence(payloads)
payloads = await self._decode_payload_sequence(payloads)
return self.payload_converter.from_payloads(payloads, type_hints)

Expand Down Expand Up @@ -156,13 +159,13 @@ async def encode_failure(
) -> None:
"""Convert and encode failure."""
self.failure_converter.to_failure(exception, self.payload_converter, failure)
await _apply_to_failure_payloads(failure, self._encode_payloads)
await _apply_to_failure_payloads(failure, self._transform_outbound_payloads)

async def decode_failure(
    self, failure: temporalio.api.failure.v1.Failure
) -> BaseException:
    """Apply inbound payload transforms to *failure*, then convert it into
    an exception via the failure converter."""
    await _apply_to_failure_payloads(failure, self._transform_inbound_payloads)
    exception = self.failure_converter.from_failure(failure, self.payload_converter)
    return exception

def with_context(self, context: SerializationContext) -> Self:
Expand Down Expand Up @@ -250,7 +253,7 @@ async def _encode_memo_existing(
"[TMPRL1103] Attempted to upload memo with size that exceeded the warning limit.",
)

async def _encode_payload(
async def _transform_outbound_payload(
self, payload: temporalio.api.common.v1.Payload
) -> temporalio.api.common.v1.Payload:
if self.payload_codec:
Expand All @@ -260,27 +263,16 @@ async def _encode_payload(
self._validate_payload_limits([payload])
return payload

async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
async def _transform_outbound_payloads(
    self, payloads: temporalio.api.common.v1.Payloads
) -> None:
    """Apply the full outbound pipeline to *payloads* in place: codec
    encode, then external-storage store, then size-limit validation.

    Validation runs last so limits are checked against the final
    (encoded/stored) payloads.
    """
    if self.payload_codec:
        await self.payload_codec.encode_wrapper(payloads)
    if self.external_storage:
        await self.external_storage._store_payloads(payloads)
    self._validate_payload_limits(payloads.payloads)

async def _encode_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """Codec encode, external-storage store, and size-limit validation in
    a single pass; returns a new list, leaving the input untouched."""
    encoded_payloads = list(payloads)
    if self.payload_codec:
        encoded_payloads = await self.payload_codec.encode(encoded_payloads)
    if self.external_storage:
        encoded_payloads = await self.external_storage._store_payload_sequence(
            encoded_payloads
        )
    # Validate the final form, after both codec and storage have run.
    self._validate_payload_limits(encoded_payloads)
    return encoded_payloads

async def _decode_payload(
async def _transform_inbound_payload(
self, payload: temporalio.api.common.v1.Payload
) -> temporalio.api.common.v1.Payload:
if self.external_storage:
Expand All @@ -289,7 +281,9 @@ async def _decode_payload(
payload = (await self.payload_codec.decode([payload]))[0]
return payload

async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
async def _transform_inbound_payloads(
self, payloads: temporalio.api.common.v1.Payloads
):
if self.external_storage:
await self.external_storage._retrieve_payloads(payloads)
else:
Expand All @@ -304,23 +298,51 @@ async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
if self.payload_codec:
await self.payload_codec.decode_wrapper(payloads)

async def _decode_payload_sequence(
async def _encode_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """Codec encode only; no external storage and no limit validation.

    Returns a new list; with no codec configured this is just a copy.
    """
    if not self.payload_codec:
        return list(payloads)
    return await self.payload_codec.encode(list(payloads))

async def _external_store_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """External storage store only.

    NOTE: this does NOT validate payload limits — callers are expected to
    validate the returned payloads themselves. Returns a copy of the input
    when no external storage is configured.
    """
    stored_payloads = list(payloads)
    if self.external_storage:
        stored_payloads = await self.external_storage._store_payload_sequence(
            stored_payloads
        )
    return stored_payloads

async def _external_retrieve_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """External storage retrieve only."""
    result = list(payloads)
    if self.external_storage:
        return await self.external_storage._retrieve_payload_sequence(result)
    # No storage configured: externally-stored references cannot be
    # resolved here, so warn if any are present and pass them through.
    has_reference = any(
        p.metadata.get("encoding") == _REFERENCE_ENCODING for p in result
    )
    if has_reference:
        warnings.warn(
            "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.",
            StorageWarning,
        )
    return result

async def _decode_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """Codec decode only; no external storage retrieval.

    Returns a new list; with no codec configured this is just a copy.
    """
    if not self.payload_codec:
        return list(payloads)
    return await self.payload_codec.decode(list(payloads))
Expand Down
4 changes: 3 additions & 1 deletion temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,9 @@ async def _execute_activity(

if self._encode_headers:
for payload in start.header_fields.values():
payload.CopyFrom(await data_converter._decode_payload(payload))
payload.CopyFrom(
await data_converter._transform_inbound_payload(payload)
)

running_activity.info = info
input = ExecuteActivityInput(
Expand Down
2 changes: 1 addition & 1 deletion temporalio/worker/_replayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def on_eviction_hook(
"header_codec_behavior", HeaderCodecBehavior.NO_CODEC
)
!= HeaderCodecBehavior.NO_CODEC,
max_workflow_task_payload_concurrency=1,
max_workflow_task_external_storage_concurrency=1,
)
external_storage = data_converter.external_storage
storage_driver_types = (
Expand Down
33 changes: 19 additions & 14 deletions temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
from ._nexus import _NexusWorker
from ._plugin import Plugin
from ._tuning import WorkerTuner
from ._workflow import _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY, _WorkflowWorker
from ._workflow import (
_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY,
_WorkflowWorker,
)
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
from .workflow_sandbox import SandboxedWorkflowRunner

Expand Down Expand Up @@ -142,7 +145,7 @@ def __init__(
maximum=5
),
disable_payload_error_limit: bool = False,
max_workflow_task_payload_concurrency: int = _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY,
max_workflow_task_external_storage_concurrency: int = _DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY,
) -> None:
"""Create a worker to process workflows and/or activities.

Expand Down Expand Up @@ -317,10 +320,10 @@ def __init__(
and cause a task failure if the size limit is exceeded. The default is False.
See https://docs.temporal.io/troubleshooting/blob-size-limit-error for more
details.
max_workflow_task_payload_concurrency: Maximum number of payload
operations (codec encode/decode, external storage I/O, etc.)
that may run concurrently within a single workflow task
activation. Defaults to 1. WARNING: This setting is experimental.
max_workflow_task_external_storage_concurrency: Maximum number of
external storage I/O operations (store/retrieve) that may run
concurrently within a single workflow task activation.
Defaults to 10. WARNING: This setting is experimental.

"""
config = WorkerConfig(
Expand Down Expand Up @@ -366,7 +369,7 @@ def __init__(
activity_task_poller_behavior=activity_task_poller_behavior,
nexus_task_poller_behavior=nexus_task_poller_behavior,
disable_payload_error_limit=disable_payload_error_limit,
max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency,
max_workflow_task_external_storage_concurrency=max_workflow_task_external_storage_concurrency,
)

plugins_from_client = cast(
Expand Down Expand Up @@ -420,12 +423,14 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
raise ValueError(
"default_versioning_behavior must be UNSPECIFIED when use_worker_versioning is False"
)
max_workflow_task_payload_concurrency = config.get(
"max_workflow_task_payload_concurrency",
_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY,
max_workflow_task_external_storage_concurrency = config.get(
"max_workflow_task_external_storage_concurrency",
_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY,
)
if max_workflow_task_payload_concurrency < 1:
raise ValueError("max_workflow_task_payload_concurrency must be positive")
if max_workflow_task_external_storage_concurrency < 1:
raise ValueError(
"max_workflow_task_external_storage_concurrency must be positive"
)

# Prepend applicable client interceptors to the given ones
client_config = config["client"].config(active_config=True) # type: ignore[reportTypedDictNotRequiredAccess]
Expand Down Expand Up @@ -530,7 +535,7 @@ def check_activity(activity: str):
assert_local_activity_valid=check_activity,
encode_headers=client_config["header_codec_behavior"]
!= HeaderCodecBehavior.NO_CODEC,
max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency,
max_workflow_task_external_storage_concurrency=max_workflow_task_external_storage_concurrency,
)

tuner = config.get("tuner")
Expand Down Expand Up @@ -977,7 +982,7 @@ class WorkerConfig(TypedDict, total=False):
activity_task_poller_behavior: PollerBehavior
nexus_task_poller_behavior: PollerBehavior
disable_payload_error_limit: bool
max_workflow_task_payload_concurrency: int
max_workflow_task_external_storage_concurrency: int


def _warn_if_activity_executor_max_workers_is_inconsistent(
Expand Down
30 changes: 24 additions & 6 deletions temporalio/worker/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# Set to true to log all activations and completions
LOG_PROTOS = False

_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY: int = 1
_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY: int = 10


class _WorkflowWorker: # type:ignore[reportUnusedClass]
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
should_enforce_versioning_behavior: bool,
assert_local_activity_valid: Callable[[str], None],
encode_headers: bool,
max_workflow_task_payload_concurrency: int,
max_workflow_task_external_storage_concurrency: int,
) -> None:
self._bridge_worker = bridge_worker
self._namespace = namespace
Expand Down Expand Up @@ -115,8 +115,8 @@ def __init__(
self._on_eviction_hook = on_eviction_hook
self._disable_safe_eviction = disable_safe_eviction
self._encode_headers = encode_headers
self._max_workflow_task_payload_concurrency = (
max_workflow_task_payload_concurrency
self._max_workflow_task_external_storage_concurrency = (
max_workflow_task_external_storage_concurrency
)
self._throw_after_activation: Exception | None = None

Expand Down Expand Up @@ -300,7 +300,7 @@ async def _handle_activation(
act,
data_converter,
decode_headers=self._encode_headers,
concurrency_limit=self._max_workflow_task_payload_concurrency,
storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency,
)
if not workflow:
assert init_job
Expand Down Expand Up @@ -409,7 +409,7 @@ async def _handle_activation(
completion,
data_converter,
encode_headers=self._encode_headers,
concurrency_limit=self._max_workflow_task_payload_concurrency,
storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency,
)
except temporalio.converter._payload_limits._PayloadSizeError as err:
logger.warning(err.message)
Expand Down Expand Up @@ -893,11 +893,29 @@ async def _encode_payload_sequence(
) -> list[temporalio.api.common.v1.Payload]:
return await self._get_current_dc()._encode_payload_sequence(payloads)

async def _external_store_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """Delegate external-storage store to the data converter returned by
    ``_get_current_dc()``."""
    dc = self._get_current_dc()
    return await dc._external_store_payload_sequence(payloads)

async def _external_retrieve_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """Delegate external-storage retrieval to the data converter returned
    by ``_get_current_dc()``."""
    dc = self._get_current_dc()
    return await dc._external_retrieve_payload_sequence(payloads)

async def _decode_payload_sequence(
    self, payloads: Sequence[temporalio.api.common.v1.Payload]
) -> list[temporalio.api.common.v1.Payload]:
    """Delegate codec decoding to the data converter returned by
    ``_get_current_dc()``."""
    dc = self._get_current_dc()
    return await dc._decode_payload_sequence(payloads)

def _validate_payload_limits(
    self,
    payloads: Sequence[temporalio.api.common.v1.Payload],
) -> None:
    """Delegate payload size-limit validation to the data converter
    returned by ``_get_current_dc()``."""
    dc = self._get_current_dc()
    dc._validate_payload_limits(payloads)


class _InterruptDeadlockError(BaseException):
pass
Loading
Loading